From d5af89624edc74d5dca5b63ca8d1f76a5a300308 Mon Sep 17 00:00:00 2001 From: Gayathri Murali Date: Mon, 7 Sep 2015 12:57:05 -0700 Subject: [PATCH 0001/2089] [SPARK-10380] - Confusing examples in pyspark SQL docs --- python/pyspark/sql/column.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 56e75e8caee88..0338825839987 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -329,6 +329,8 @@ def cast(self, dataType): [Row(ages=u'2'), Row(ages=u'5')] >>> df.select(df.age.cast(StringType()).alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] + + (Note : astype is alias for cast) """ if isinstance(dataType, basestring): jc = self._jc.cast(dataType) From b415f7e7e382b7e69322a1d01b3b5063c408a5f2 Mon Sep 17 00:00:00 2001 From: Gayathri Murali Date: Mon, 7 Sep 2015 13:23:00 -0700 Subject: [PATCH 0002/2089] [SPARK-10380]- Confusing examples in pyspark SQL docs : Added drop_duplicates documentation change --- python/pyspark/sql/dataframe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e269ef4304f3f..71df09a83974d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -943,7 +943,9 @@ def dropDuplicates(self, subset=None): +---+------+-----+ | 5| 80|Alice| +---+------+-----+ - """ + + (Note: drop_duplicates is an alias for dropDuplicates) + """ if subset is None: jdf = self._jdf.dropDuplicates() else: From 5ffe752b59e468d55363f1f24b17a3677927ca8f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 7 Sep 2015 10:42:30 -1000 Subject: [PATCH 0003/2089] [SPARK-9767] Remove ConnectionManager. We introduced the Netty network module for shuffle in Spark 1.2, and has turned it on by default for 3 releases. The old ConnectionManager is difficult to maintain. If we merge the patch now, by the time it is released, it would be 1 yr for which ConnectionManager is off by default. It's time to remove it. Author: Reynold Xin Closes #8161 from rxin/SPARK-9767. --- .../scala/org/apache/spark/SparkEnv.scala | 11 +- .../netty/NettyBlockTransferService.scala | 8 +- .../spark/network/nio/BlockMessage.scala | 175 --- .../spark/network/nio/BlockMessageArray.scala | 140 -- .../spark/network/nio/BufferMessage.scala | 114 -- .../apache/spark/network/nio/Connection.scala | 619 -------- .../spark/network/nio/ConnectionId.scala | 36 - .../spark/network/nio/ConnectionManager.scala | 1157 --------------- .../network/nio/ConnectionManagerId.scala | 37 - .../apache/spark/network/nio/Message.scala | 114 -- .../spark/network/nio/MessageChunk.scala | 41 - .../network/nio/MessageChunkHeader.scala | 83 -- .../network/nio/NioBlockTransferService.scala | 217 --- .../spark/network/nio/SecurityMessage.scala | 160 --- .../spark/serializer/KryoSerializer.scala | 4 - .../network/nio/ConnectionManagerSuite.scala | 296 ---- .../BlockManagerReplicationSuite.scala | 6 +- .../spark/storage/BlockManagerSuite.scala | 10 +- docs/configuration.md | 11 - project/MimaExcludes.scala | 1257 +++++++++-------- .../streaming/ReceivedBlockHandlerSuite.scala | 10 +- 21 files changed, 651 insertions(+), 3855 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/Connection.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/Message.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0f1e2e069568d..c6fef7f91f00c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -33,7 +33,6 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} @@ -326,15 +325,7 @@ object SparkEnv extends Logging { val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) - val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { - case "netty" => - new NettyBlockTransferService(conf, securityManager, numUsableCores) - case "nio" => - logWarning("NIO-based block transfer service is deprecated, " + - "and will be removed in Spark 1.6.0.") - new NioBlockTransferService(conf, securityManager) - } + val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index d5ad2c9ad00e8..4b851bcb36597 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -149,7 +149,11 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } override def close(): Unit = { - server.close() - clientFactory.close() + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala deleted file mode 100644 index 79cb0640c8672..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -// private[spark] because we need to register them in Kryo -private[spark] case class GetBlock(id: BlockId) -private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) -private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) - -private[nio] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: BlockId = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = BlockId(idBuilder.toString) - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = typ - def getId: BlockId = id - def getData: ByteBuffer = data - def getLevel: StorageLevel = level - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) - buffer.putInt(typ).putInt(id.name.length) - id.name.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[nio] object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala deleted file mode 100644 index f1c9ea8b64ca3..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark._ -import org.apache.spark.storage.{StorageLevel, TestBlockId} - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) - extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int): BlockMessage = blockMessages(i) - - def iterator: Iterator[BlockMessage] = blockMessages.iterator - - def length: Int = blockMessages.length - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + - (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - Message.createBufferMessage(buffers) - } -} - -private[nio] object BlockMessageArray extends Logging { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear() - BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, - StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - logDebug("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - logDebug("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - logDebug("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - logDebug("Converted back to block message array") - // scalastyle:off println - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - // scalastyle:on println - } -} - - diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala deleted file mode 100644 index 9a9e22b0c2366..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.storage.BlockManager - - -private[nio] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size: Int = initialSize - - def currentSize(): Int = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - val security = if (isSecurityNeg) 1 else 0 - if (size == 0 && !gotChunkForSendingOnce) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, - hasError, security, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - val security = if (isSecurityNeg) 1 else 0 - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId(): Boolean = ackId != 0 - - def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - - override def toString: String = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala deleted file mode 100644 index 8d9ebadaf79d4..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ /dev/null @@ -1,619 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.LinkedList - -import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.control.NonFatal - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} - -private[nio] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, - val securityMgr: SecurityManager) - extends Logging { - - var sparkSaslServer: SparkSaslServer = null - var sparkSaslClient: SparkSaslClient = null - - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, - securityMgr_ : SecurityManager) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), - id_, securityMgr_) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /* channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def isSaslComplete(): Boolean - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! - // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - private def disposeSasl() { - if (sparkSaslServer != null) { - sparkSaslServer.dispose() - } - - if (sparkSaslClient != null) { - sparkSaslClient.dispose() - } - } - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. - // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key(): SelectionKey = channel.keyFor(selector) - - def getRemoteAddress(): InetSocketAddress = { - channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - } - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. - def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - disposeSasl() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Throwable) => Unit) { - onExceptionCallbacks.add(callback) - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks.asScala.foreach { - callback => - try { - callback(this, e) - } catch { - case NonFatal(e) => { - logWarning("Ignored error in onExceptionCallback", e) - } - } - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.length + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } -} - - -private[nio] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslClient != null) sparkSaslClient.isComplete() else false - } - - private class Outbox { - val messages = new LinkedList[Message]() - val defaultChunkSize = 65536 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized { - messages.add(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /* val message = messages(nextMessageToBeUsed) */ - - val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { - // only allow sending of security messages until sasl is complete - var pos = 0 - var securityMsg: Message = null - while (pos < messages.size() && securityMsg == null) { - if (messages.get(pos).isSecurityNeg) { - securityMsg = messages.remove(pos) - } - pos = pos + 1 - } - // didn't find any security messages and auth isn't completed so return - if (securityMsg == null) return None - securityMsg - } else { - messages.removeFirst() - } - - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.add(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox() - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /* channel.socket.setSendBufferSize(256 * 1024) */ - - override def getRemoteAddress(): InetSocketAddress = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def registerAfterAuth(): Unit = { - outbox.synchronized { - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try { - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => - logError("Error connecting to " + address, e) - callOnExceptionCallbacks(e) - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. - val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallbacks(e) - } - } - true - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as - // normal registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /* key.interestOps(0) */ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. - try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), - e) - callOnExceptionCallbacks(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection( - channel_ : SocketChannel, - selector_ : Selector, - id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(channel_, selector_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslServer != null) sparkSaslServer.isComplete() else false - } - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - newMessage.isSecurityNeg = header.securityNeg == 1 - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The receiver's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection, Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... - return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... - return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... - return true - } else if (bytesRead == -1) { - close() - return false - } - - /* logDebug("Read " + bytesRead + " bytes for the buffer") */ - - if (currentChunk.buffer.remaining == 0) { - /* println("Filled buffer at " + System.currentTimeMillis) */ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip() - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala deleted file mode 100644 index b3b281ff465f1..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString: String = { - connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId - } -} - -private[nio] object ConnectionId { - - def createConnectionIdFromString(connectionIdString: String): ConnectionId = { - val res = connectionIdString.split("_").map(_.trim()) - if (res.size != 3) { - throw new Exception("Error converting ConnectionId string: " + connectionIdString + - " to a ConnectionId Object") - } - new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala deleted file mode 100644 index 9143918790381..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ /dev/null @@ -1,1157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.lang.ref.WeakReference -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} -import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future, Promise} -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.{ThreadUtils, Utils} - -import scala.util.Try -import scala.util.control.NonFatal - -private[nio] class ConnectionManager( - port: Int, - conf: SparkConf, - securityManager: SecurityManager, - name: String = "Connection manager") - extends Logging { - - /** - * Used by sendMessageReliably to track messages being sent. - * @param message the message that was sent - * @param connectionManagerId the connection manager that sent this message - * @param completionHandler callback that's invoked when the send has completed or failed - */ - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: Try[Message] => Unit) { - - def success(ackMessage: Message) { - if (ackMessage == null) { - failure(new NullPointerException) - } - else { - completionHandler(scala.util.Success(ackMessage)) - } - } - - def failWithoutAck() { - completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) - } - - def failure(e: Throwable) { - completionHandler(scala.util.Failure(e)) - } - } - - private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = - new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) - - private val ackTimeout = - conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", - conf.get("spark.network.timeout", "120s")) - - // Get the thread counts from the Spark Configuration. - // - // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, - // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is - // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" - // parameter is necessary. - private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) - private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4) - private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1) - - private val handleMessageExecutor = new ThreadPoolExecutor( - handlerThreadCount, - handlerThreadCount, - conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-message-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleMessageExecutor is not handled properly", t) - } - } - } - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - ioThreadCount, - ioThreadCount, - conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-read-write-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleReadWriteExecutor is not handled properly", t) - } - } - } - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : - // which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - connectThreadCount, - connectThreadCount, - conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-connect-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleConnectExecutor is not handled properly", t) - } - } - } - - private val serverChannel = ServerSocketChannel.open() - // used to track the SendingConnections waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] - with SynchronizedMap[ConnectionId, SendingConnection] - private val connectionsByKey = - new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] - with SynchronizedMap[ConnectionManagerId, SendingConnection] - // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this - // map when messages are sent and are removed when acknowledgement messages are received or when - // acknowledgement timeouts expire - private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor( - ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) - - @volatile - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null - - private val authEnabled = securityManager.isAuthenticationEnabled() - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - private def startService(port: Int): (ServerSocketChannel, Int) = { - serverChannel.socket.bind(new InetSocketAddress(port)) - (serverChannel, serverChannel.socket.getLocalPort) - } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - // used in combination with the ConnectionManagerId to create unique Connection ids - // to be able to track asynchronous messages - private val idCount: AtomicInteger = new AtomicInteger(1) - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - @volatile private var isActive = true - private val selectorThread = new Thread("connection-manager-thread") { - override def run(): Unit = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) - // start this thread last, since it invokes run(), which accesses members above - selectorThread.start() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. - if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. - if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in - // triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } catch { - case NonFatal(e) => { - logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - // MUST be called within selector loop - else deadlock. - private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallbacks(e) - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while (isActive) { - while (!registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue() - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue() - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. - if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + - connection.getRemoteConnectionManagerId() + "] changed from [" + - intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions - // should be dealt with differently. - case e: CancelledKeyException => - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - 0 - - case e: ClosedSelectorException => - logDebug("Failed select() as selector is closed.", e) - return - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + - " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, - // key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, - securityManager) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - connection match { - case sendingConnection: SendingConnection => - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId - - messageStatuses.synchronized { - messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) - .foreach(status => { - logInfo("Notifying " + status) - status.failWithoutAck() - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case receivingConnection: ReceivingConnection => - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (!sendingConnectionOpt.isDefined) { - logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert(sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values - if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.failWithoutAck() - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case _ => logError("Unsupported type of connection.") - } - } finally { - // So that the selection keys can be removed. - wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Throwable) { - logInfo("Handling connection error on connection to " + - connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // so that registrations happen ! - wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - try { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } catch { - case NonFatal(e) => { - logError("Error when handling messages from " + - connection.getRemoteConnectionManagerId(), e) - connection.callOnExceptionCallbacks(e) - } - } - } - } - handleMessageExecutor.execute(runnable) - /* handleMessage(connection, message) */ - } - - private def handleClientAuthentication( - waitingConn: SendingConnection, - securityMsg: SecurityMessage, - connectionId : ConnectionId) { - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } else { - var replyToken : Array[Byte] = null - try { - replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new IOException("Error creating security message") - sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => - logError("Error handling sasl client authentication", e) - waitingConn.close() - throw new IOException("Error evaluating sasl response: ", e) - } - } - } - - private def handleServerAuthentication( - connection: Connection, - securityMsg: SecurityMessage, - connectionId: ConnectionId) { - if (!connection.isSaslComplete()) { - logDebug("saslContext not established") - var replyToken : Array[Byte] = null - try { - connection.synchronized { - if (connection.sparkSaslServer == null) { - logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) - } - } - replyToken = connection.sparkSaslServer.response(securityMsg.getToken) - if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId + - " for: " + connectionId) - } else { - logDebug("Server sasl not completed: " + connection.connectionId + - " for: " + connectionId) - } - if (replyToken != null) { - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security Message") - sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } - } catch { - case e: Exception => { - logError("Error in server auth negotiation: " + e) - // It would probably be better to send an error message telling other side auth failed - // but for now just close - connection.close() - } - } - } else { - logDebug("connection already established for this connection id: " + connection.connectionId) - } - } - - - private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { - if (bufferMessage.isSecurityNeg) { - logDebug("This is security neg message") - - // parse as SecurityMessage - val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) - val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) - - connectionsAwaitingSasl.get(connectionId) match { - case Some(waitingConn) => { - // Client - this must be in response to us doing Send - logDebug("Client handleAuth for id: " + waitingConn.connectionId) - handleClientAuthentication(waitingConn, securityMsg, connectionId) - } - case None => { - // Server - someone sent us something and we haven't authenticated yet - logDebug("Server handleAuth for id: " + connectionId) - handleServerAuthentication(conn, securityMsg, connectionId) - } - } - return true - } else { - if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. - logError("message sent that is not security negotiation message on connection " + - "not authenticated yet, ignoring it!!") - return true - } - } - false - } - - private def handleMessage( - connectionManagerId: ConnectionManagerId, - message: Message, - connection: Connection) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (authEnabled) { - val res = handleAuthentication(connection, bufferMessage) - if (res) { - // message was security negotiation so skip the rest - logDebug("After handleAuth result was true, returning") - return - } - } - if (bufferMessage.hasAckId()) { - messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status.success(message) - } - case None => { - /** - * We can fall down on this code because of following 2 cases - * - * (1) Invalid ack sent due to buggy code. - * - * (2) Late-arriving ack for a SendMessageStatus - * To avoid unwilling late-arriving ack - * caused by long pause like GC, you can set - * larger value than default to spark.core.connection.ack.wait.timeout - */ - logWarning(s"Could not find reference for received ack Message ${message.id}") - } - } - } - } else { - var ackMessage : Option[Message] = None - try { - ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - } catch { - case e: Exception => { - logError(s"Exception was thrown while processing message", e) - ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) - } - } finally { - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { - // see if we need to do sasl before writing - // this should only be the first negotiation as the Client!!! - if (!conn.isSaslComplete()) { - conn.synchronized { - if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) - var firstResponse: Array[Byte] = null - try { - firstResponse = conn.sparkSaslClient.firstToken() - val securityMsg = SecurityMessage.fromResponse(firstResponse, - conn.connectionId.toString()) - val message = securityMsg.toBufferMessage - if (message == null) throw new Exception("Error creating security message") - connectionsAwaitingSasl += ((conn.connectionId, conn)) - sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + - " to: " + connManagerId) - } catch { - case e: Exception => { - logError("Error getting first response from the SaslClient.", e) - conn.close() - throw new Exception("Error getting first response from the SaslClient") - } - } - } - } - } else { - logDebug("Sasl already established ") - } - } - - // allow us to add messages to the inbox for doing sasl negotiating - private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId, securityManager) - logInfo("creating new sending connection for security! " + newConnectionId ) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? - // We did not find it useful in our test-env ... - // If we do re-add it, we should consistently use it everywhere I guess ? - message.senderAddress = id.toSocketAddress() - logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") - val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) - - // send security message until going connection has been authenticated - connection.send(message) - - wakeupSelector() - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, - connectionManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId, securityManager) - newConnection.onException { - case (conn, e) => { - logError("Exception while sending message.", e) - reportSendingMessageFailure(message.id, e) - } - } - logTrace("creating new sending connection: " + newConnectionId) - registerRequests.enqueue(newConnection) - - newConnection - } - val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - - message.senderAddress = id.toSocketAddress() - logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + - "connectionid: " + connection.connectionId) - - if (authEnabled) { - try { - checkSendAuthFirst(connectionManagerId, connection) - } catch { - case NonFatal(e) => { - reportSendingMessageFailure(message.id, e) - } - } - } - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - wakeupSelector() - } - - private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(messageId) - s match { - case Some(msgStatus) => { - messageStatuses -= messageId - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.failure(e) - } - case None => { - logError("no messageStatus for failed message id: " + messageId) - } - } - } - } - - private def wakeupSelector() { - selector.wakeup() - } - - /** - * Send a message and block until an acknowledgment is received or an error occurs. - * @param connectionManagerId the message's destination - * @param message the message being sent - * @return a Future that either returns the acknowledgment message or captures an exception. - */ - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Message] = { - val promise = Promise[Message]() - - // It's important that the TimerTask doesn't capture a reference to `message`, which can cause - // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time - // at which they would originally be scheduled to run. Therefore, extract the message id - // from outside of the TimerTask closure (see SPARK-4393 for more context). - val messageId = message.id - // Keep a weak reference to the promise so that the completed promise may be garbage-collected - val promiseReference = new WeakReference(promise) - val timeoutTask: TimerTask = new TimerTask { - override def run(timeout: Timeout): Unit = { - messageStatuses.synchronized { - messageStatuses.remove(messageId).foreach { s => - val e = new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec") - val p = promiseReference.get - if (p != null) { - // Attempt to fail the promise with a Timeout exception - if (!p.tryFailure(e)) { - // If we reach here, then someone else has already signalled success or failure - // on this promise, so log a warning: - logError("Ignore error because promise is completed", e) - } - } else { - // The WeakReference was empty, which should never happen because - // sendMessageReliably's caller should have a strong reference to promise.future; - logError("Promise was garbage collected; this should never happen!", e) - } - } - } - } - } - - val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) - - val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTaskHandle.cancel() - s match { - case scala.util.Failure(e) => - // Indicates a failure where we either never sent or never got ACK'd - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - case scala.util.Success(ackMessage) => - if (ackMessage.hasError) { - val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head - val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) - errorMsgByteBuf.get(errorMsgBytes) - val errorMsg = new String(errorMsgBytes, UTF_8) - val e = new IOException( - s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - } else { - if (!promise.trySuccess(ackMessage)) { - logWarning("Drop ackMessage because promise is completed") - } - } - } - }) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - - sendMessage(connectionManagerId, message) - promise.future - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - isActive = false - ackTimeoutMonitor.stop() - selector.close() - selectorThread.interrupt() - selectorThread.join() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - import scala.concurrent.ExecutionContext.Implicits.global - - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // scalastyle:off println - println("Received [" + msg + "] from [" + id + "]") - // scalastyle:on println - None - }) - - /* testSequentialSending(manager) */ - /* System.gc() */ - - /* testParallelSending(manager) */ - /* System.gc() */ - - /* testParallelDecreasingSending(manager) */ - /* System.gc() */ - - testContinuousSending(manager) - System.gc() - } - - // scalastyle:off println - def testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms " + - "(" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count) { i => - val bufferLen = size * (i + 1) - val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte) - ByteBuffer.allocate(bufferLen).put(bufferContent) - } - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /* println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } - // scalastyle:on println -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala deleted file mode 100644 index 1cd13d887c6f6..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.InetSocketAddress - -import org.apache.spark.util.Utils - -private[nio] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - - def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) -} - - -private[nio] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala deleted file mode 100644 index 85d2fe2bf9c20..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Charsets.UTF_8 - -import org.apache.spark.util.Utils - -private[nio] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - var isSecurityNeg = false - var hasError = false - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString: String = { - this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" - } -} - - -private[nio] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId(): Int = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - /** - * Create a "negative acknowledgment" to notify a sender that an error occurred - * while processing its message. The exception's stacktrace will be formatted - * as a string, serialized into a byte array, and sent as the message payload. - */ - def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { - val exceptionString = Utils.exceptionString(exception) - val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) - val errorMessage = createBufferMessage(serializedExceptionString, ackId) - errorMessage.hasError = true - errorMessage - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.hasError = header.hasError - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala deleted file mode 100644 index a4568e849fa13..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size: Int = if (buffer == null) 0 else buffer.remaining - - lazy val buffers: ArrayBuffer[ByteBuffer] = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } - - override def toString: String = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala deleted file mode 100644 index 7b3da4bb9d5ee..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.{InetAddress, InetSocketAddress} -import java.nio.ByteBuffer - -private[nio] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val hasError: Boolean, - val securityNeg: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). - putInt(securityNeg). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). - flip.asInstanceOf[ByteBuffer] - } - - override def toString: String = { - "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg - } - -} - - -private[nio] object MessageChunkHeader { - val HEADER_SIZE = 45 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val hasError = buffer.get() != 0 - val securityNeg = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, - new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala deleted file mode 100644 index b2aec160635c7..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} - -import scala.concurrent.Future - - -/** - * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom - * implementation using Java NIO. - */ -final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) - extends BlockTransferService with Logging { - - private var cm: ConnectionManager = _ - - private var blockDataManager: BlockDataManager = _ - - /** - * Port number the service is listening on, available only after [[init]] is invoked. - */ - override def port: Int = { - checkInit() - cm.id.port - } - - /** - * Host name the service is listening on, available only after [[init]] is invoked. - */ - override def hostName: String = { - checkInit() - cm.id.host - } - - /** - * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch - * local blocks or put local blocks. - */ - override def init(blockDataManager: BlockDataManager): Unit = { - this.blockDataManager = blockDataManager - cm = new ConnectionManager( - conf.getInt("spark.blockManager.port", 0), - conf, - securityManager, - "Connection manager for block manager") - cm.onReceiveMessage(onBlockMessageReceive) - } - - /** - * Tear down the transfer service. - */ - override def close(): Unit = { - if (cm != null) { - cm.stop() - } - } - - override def fetchBlocks( - host: String, - port: Int, - execId: String, - blockIds: Array[String], - listener: BlockFetchingListener): Unit = { - checkInit() - - val cmId = new ConnectionManagerId(host, port) - val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => - BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) - }) - - val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - - // Register the listener on success/failure future callback. - future.onSuccess { case message => - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - - // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. - if (blockMessageArray.isEmpty) { - blockIds.foreach { id => - listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) - } - } else { - for (blockMessage: BlockMessage <- blockMessageArray) { - val msgType = blockMessage.getType - if (msgType != BlockMessage.TYPE_GOT_BLOCK) { - if (blockMessage.getId != null) { - listener.onBlockFetchFailure(blockMessage.getId.toString, - new SparkException(s"Unexpected message $msgType received from $cmId")) - } - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioManagedBuffer(blockMessage.getData)) - } - } - } - }(cm.futureExecContext) - - future.onFailure { case exception => - blockIds.foreach { blockId => - listener.onBlockFetchFailure(blockId, exception) - } - }(cm.futureExecContext) - } - - /** - * Upload a single block to a remote node, available only after [[init]] is invoked. - * - * This call blocks until the upload completes, or throws an exception upon failures. - */ - override def uploadBlock( - hostname: String, - port: Int, - execId: String, - blockId: BlockId, - blockData: ManagedBuffer, - level: StorageLevel) - : Future[Unit] = { - checkInit() - val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) - val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) - val remoteCmId = new ConnectionManagerId(hostName, port) - val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) - reply.map(x => ())(cm.futureExecContext) - } - - private def checkInit(): Unit = if (cm == null) { - throw new IllegalStateException(getClass.getName + " has not been initialized") - } - - private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => - logError("Exception handling buffer message", e) - Some(Message.createErrorMessage(e, msg.id)) - } - - case otherMessage: Any => - val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" - logError(errorMsg) - Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) - } - } - - private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => - val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + msg + "]") - putBlock(msg.id, msg.data, msg.level) - None - - case BlockMessage.TYPE_GET_BLOCK => - val msg = new GetBlock(blockMessage.getId) - logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) - - case _ => None - } - } - - private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) - logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(blockId: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId) - logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer.nioByteBuffer() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala deleted file mode 100644 index 232c552f9865d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -import org.apache.spark._ - -/** - * SecurityMessage is class that contains the connectionId and sasl token - * used in SASL negotiation. SecurityMessage has routines for converting - * it to and from a BufferMessage so that it can be sent by the ConnectionManager - * and easily consumed by users when received. - * The api was modeled after BlockMessage. - * - * The connectionId is the connectionId of the client side. Since - * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * - * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side - * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This - * is where the connectionId field is used. node_0 can lookup the connectionId to see if - * it is in response to it being a client or if its in response to someone sending other data. - * - * The format of a SecurityMessage as its sent is: - * - Length of the ConnectionId - * - ConnectionId - * - Length of the token - * - Token - */ -private[nio] class SecurityMessage extends Logging { - - private var connectionId: String = null - private var token: Array[Byte] = null - - def set(byteArr: Array[Byte], newconnectionId: String) { - if (byteArr == null) { - token = new Array[Byte](0) - } else { - token = byteArr - } - connectionId = newconnectionId - } - - /** - * Read the given buffer and set the members of this class. - */ - def set(buffer: ByteBuffer) { - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - connectionId = idBuilder.toString() - - val tokenLength = buffer.getInt() - token = new Array[Byte](tokenLength) - if (tokenLength > 0) { - buffer.get(token, 0, tokenLength) - } - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getConnectionId: String = { - return connectionId - } - - def getToken: Array[Byte] = { - return token - } - - /** - * Create a BufferMessage that can be sent by the ConnectionManager containing - * the security information from this class. - * @return BufferMessage - */ - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes - // 4 bytes for the length of token - // token is a byte buffer so just take the length - var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) - buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) - buffer.putInt(token.length) - - if (token.length > 0) { - buffer.put(token) - } - buffer.flip() - buffers += buffer - - var message = Message.createBufferMessage(buffers) - logDebug("message total size is : " + message.size) - message.isSecurityNeg = true - return message - } - - override def toString: String = { - "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" - } -} - -private[nio] object SecurityMessage { - - /** - * Convert the given BufferMessage to a SecurityMessage by parsing the contents - * of the BufferMessage and populating the SecurityMessage fields. - * @param bufferMessage is a BufferMessage that was received - * @return new SecurityMessage - */ - def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(bufferMessage) - newSecurityMessage - } - - /** - * Create a SecurityMessage to send from a given saslResponse. - * @param response is the response to a challenge from the SaslClient or Saslserver - * @param connectionId the client connectionId we are negotiation authentication for - * @return a new SecurityMessage - */ - def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, connectionId) - newSecurityMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index b977711e7d5ad..c5195c1143a8f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -35,7 +35,6 @@ import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, Roaring import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ @@ -362,9 +361,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[PutBlock], - classOf[GotBlock], - classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], classOf[RoaringBitmap], diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala deleted file mode 100644 index 5e364cc0edeb2..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.nio._ - -import scala.concurrent.duration._ -import scala.concurrent.{Await, TimeoutException} -import scala.language.postfixOps - -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils - -/** - * Test the ConnectionManager with various security settings. - */ -class ConnectionManagerSuite extends SparkFunSuite { - - test("security default off") { - val conf = new SparkConf - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var receivedMessage = false - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - receivedMessage = true - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) - - assert(receivedMessage == true) - - manager.stop() - } - - test("security on same password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - conf.set("spark.app.id", "app-id") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - val managerServer = new ConnectionManager(0, conf, securityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val count = 10 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - }) - - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.app.id", "app-id") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = conf.clone.set("spark.authenticate.secret", "bad") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - // Expect managerServer to close connection, which we'll report as an error: - intercept[IOException] { - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - } - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "good") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 1).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - assert(false) - } catch { - case i: IOException => - assert(true) - case e: TimeoutException => { - // we should timeout here since the client can't do the negotiation - assert(true) - } - } - }) - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - manager.stop() - managerServer.stop() - } - - test("security auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 10).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - } catch { - case e: Exception => { - assert(false) - } - } - }) - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("Ack error message") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - val managerServer = new ConnectionManager(0, conf, securityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception("Custom exception text") - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - val exception = intercept[IOException] { - Await.result(future, 1 second) - } - assert(Utils.exceptionString(exception).contains("Custom exception text")) - - manager.stop() - managerServer.stop() - - } - - test("sendMessageReliably timeout") { - val clientConf = new SparkConf - clientConf.set("spark.authenticate", "false") - val ackTimeoutS = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") - - val clientSecurityManager = new SecurityManager(clientConf) - val manager = new ConnectionManager(0, clientConf, clientSecurityManager) - - val serverConf = new SparkConf - serverConf.set("spark.authenticate", "false") - val serverSecurityManager = new SecurityManager(serverConf) - val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeoutS * 3 * 1000) - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - // Future should throw IOException in 30 sec. - // Otherwise TimeoutExcepton is thrown from Await.result. - // We expect TimeoutException is not thrown. - intercept[IOException] { - Await.result(future, (ackTimeoutS * 2) second) - } - - manager.stop() - managerServer.stop() - } - -} - diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 0f5ba46f69c2f..eb5af70d57aec 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,10 +26,10 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -38,7 +38,7 @@ import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) @@ -59,7 +59,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e5b54d66c8157..34bb4952e7246 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -30,10 +30,10 @@ import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -44,7 +44,7 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var store: BlockManager = null var store2: BlockManager = null var store3: BlockManager = null @@ -66,7 +66,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") @@ -819,7 +819,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -833,7 +833,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1") == None, "a1 should not be in store") + assert(store.getSingle("a1").isEmpty, "a1 should not be in store") } } diff --git a/docs/configuration.md b/docs/configuration.md index 29a36bd67f28b..a2cc7a37e2240 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -382,17 +382,6 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. - - spark.shuffle.blockTransferService - netty - - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2, and nio block transfer is deprecated in Spark 1.5.0 and will - be removed in Spark 1.6.0. - - spark.shuffle.compress true diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 714ce3cd9b1de..3b8b6c8ffa375 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -32,635 +32,638 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") */ object MimaExcludes { - def excludes(version: String) = - version match { - case v if v.startsWith("1.5") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("network"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // JavaRDDLike is not meant to be extended by user programs - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.partitioner"), - // Modification of private static method - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), - // Mima false positive (was a private[spark] class) - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator"), - // Removing a testing method from a private class - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), - // While private MiMa is still not happy about the changes, - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticCostFun.this"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution"), - // The old JSON RDD is removed in favor of streaming Jackson - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), - // local function inside a method - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") - ) ++ Seq( - // SPARK-8479 Add numNonzeros and numActives to Matrix. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numActives") - ) ++ Seq( - // SPARK-8914 Remove RDDApi - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") - ) ++ Seq( - // SPARK-7292 Provide operator to truncate lineage cheaply - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.RDDCheckpointData"), - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.CheckpointRDD") - ) ++ Seq( - // SPARK-8701 Add input metadata in the batch page. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo") - ) ++ Seq( - // SPARK-6797 Support YARN modes for SparkR - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.PairwiseRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.createRWorker"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.StringRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.BaseRRDD.this") - ) ++ Seq( - // SPARK-7422 add argmax for sparse vectors - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.argmax") - ) ++ Seq( - // SPARK-8906 Move all internal data source classes into execution.datasources - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), - // SPARK-9763 Minimize exposure of internal SQL classes - excludePackage("org.apache.spark.sql.parquet"), - excludePackage("org.apache.spark.sql.json"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") - ) ++ Seq( - // SPARK-4751 Dynamic allocation for standalone mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.supportDynamicAllocation") - ) ++ Seq( - // SPARK-9580: Remove SQL test singletons - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.LocalSQLContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.TestSQLContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.TestSQLContext$") - ) ++ Seq( - // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.mllib.linalg.VectorUDT.serialize") - ) + def excludes(version: String) = version match { + case v if v.startsWith("1.6") => + Seq( + MimaBuild.excludeSparkPackage("network") + ) + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // The old JSON RDD is removed in favor of streaming Jackson + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") + ) ++ Seq( + // SPARK-8479 Add numNonzeros and numActives to Matrix. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numActives") + ) ++ Seq( + // SPARK-8914 Remove RDDApi + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-7292 Provide operator to truncate lineage cheaply + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.RDDCheckpointData"), + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.CheckpointRDD") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") + ) ++ Seq( + // SPARK-6797 Support YARN modes for SparkR + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.PairwiseRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.createRWorker"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.StringRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.BaseRRDD.this") + ) ++ Seq( + // SPARK-7422 add argmax for sparse vectors + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.argmax") + ) ++ Seq( + // SPARK-8906 Move all internal data source classes into execution.datasources + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") + ) ++ Seq( + // SPARK-4751 Dynamic allocation for standalone mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) - case v if v.startsWith("1.4") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // SPARK-7910 Adding a method to get the partioner to JavaRDD, - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.rdd.JdbcRDD.compute"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") - ) ++ Seq( - // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though - // the stage class is defined as private[spark] - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") - ) ++ Seq( - // SPARK-6510 Add a Graph#minus method acting as Set#difference - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") - ) ++ Seq( - // SPARK-6492 Fix deadlock in SparkContext.stop() - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + - "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") - )++ Seq( - // SPARK-6693 add tostring with max lines and width for matrix - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.toString") - )++ Seq( - // SPARK-6703 Add getOrCreate method to SparkContext - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") - )++ Seq( - // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.mllib.clustering.LDA$EMOptimizer") - ) ++ Seq( - // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.compressed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toDense"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toSparse"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numActives"), - // SPARK-7681 add SparseVector support for gemv - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.multiply") - ) ++ Seq( - // Execution should never be included as its always internal. - MimaBuild.excludeSparkPackage("sql.execution"), - // This `protected[sql]` method was removed in 1.3.1 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.checkAnalysis"), - // These `private[sql]` class were removed in 1.4.0: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), - // These test support classes were moved out of src/main and into src/test: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.TestGroupWriteSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), - // TODO: Remove the following rule once ParquetTest has been moved to src/test. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTest") - ) ++ Seq( - // SPARK-7530 Added StreamingContext.getState() - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.StreamingContext.state_=") - ) ++ Seq( - // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some - // unnecessary type bounds in order to fix some compiler warnings that occurred when - // implementing this interface in Java. Note that ShuffleWriter is private[spark]. - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.shuffle.ShuffleWriter") - ) ++ Seq( - // SPARK-6888 make jdbc driver handling user definable - // This patch renames some classes to API friendly names. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") - ) + case v if v.startsWith("1.4") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives"), + // SPARK-7681 add SparseVector support for gemv + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.multiply") + ) ++ Seq( + // Execution should never be included as its always internal. + MimaBuild.excludeSparkPackage("sql.execution"), + // This `protected[sql]` method was removed in 1.3.1 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.checkAnalysis"), + // These `private[sql]` class were removed in 1.4.0: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), + // These test support classes were moved out of src/main and into src/test: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.TestGroupWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") + ) ++ Seq( + // SPARK-7530 Added StreamingContext.getState() + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.StreamingContext.state_=") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") + ) - case v if v.startsWith("1.3") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in the 1.2 build. - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") - ) ++ Seq( - // SPARK-2321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfoImpl.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfo.submissionTime") - ) ++ Seq( - // SPARK-4614 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.randn"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.rand") - ) ++ Seq( - // SPARK-5321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.transpose"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + - "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.isTransposed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.foreachActive") - ) ++ Seq( - // SPARK-5540 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), - // SPARK-5536 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") - ) ++ Seq( - // SPARK-3325 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), - // SPARK-2757 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." + - "removeAndGetProcessor") - ) ++ Seq( - // SPARK-5123 (SparkSQL data type change) - alpha component only - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.HashingTF.outputDataType"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.outputDataType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.validateInputType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") - ) ++ Seq( - // SPARK-4014 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.taskAttemptId"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.attemptNumber") - ) ++ Seq( - // SPARK-5166 Spark SQL API stabilization - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") - ) ++ Seq( - // SPARK-5270 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.isEmpty") - ) ++ Seq( - // SPARK-5430 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeReduce"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeAggregate") - ) ++ Seq( - // SPARK-5297 Java FileStream do not work with custom key/values - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") - ) ++ Seq( - // SPARK-5315 Spark Streaming Java API returns Scala DStream - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") - ) ++ Seq( - // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.getCheckpointFiles"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.isCheckpointed") - ) ++ Seq( - // SPARK-4789 Standardize ML Prediction APIs - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") - ) ++ Seq( - // SPARK-5814 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") - ) ++ Seq( - // SPARK-4682 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") - ) ++ Seq( - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") - ) + case v if v.startsWith("1.3") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in the 1.2 build. + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ) ++ Seq( + // SPARK-2321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfoImpl.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfo.submissionTime") + ) ++ Seq( + // SPARK-4614 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.randn"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") + ) ++ Seq( + // SPARK-5540 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), + // SPARK-5536 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") + ) ++ Seq( + // SPARK-3325 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), + // SPARK-2757 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." + + "removeAndGetProcessor") + ) ++ Seq( + // SPARK-5123 (SparkSQL data type change) - alpha component only + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.HashingTF.outputDataType"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.outputDataType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.validateInputType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") + ) ++ Seq( + // SPARK-4014 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.taskAttemptId"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.attemptNumber") + ) ++ Seq( + // SPARK-5166 Spark SQL API stabilization + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") + ) ++ Seq( + // SPARK-5270 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") + ) ++ Seq( + // SPARK-5297 Java FileStream do not work with custom key/values + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") + ) ++ Seq( + // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.getCheckpointFiles"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-5814 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") + ) ++ Seq( + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") + ) - case v if v.startsWith("1.2") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ - Seq( - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.TaskLocation"), - // Added normL1 and normL2 to trait MultivariateStatisticalSummary - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), - // MapStatus should be private[spark] - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.PathResolver"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.client.BlockClientListener"), + case v if v.startsWith("1.2") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ + MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ + Seq( + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.TaskLocation"), + // Added normL1 and normL2 to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), + // MapStatus should be private[spark] + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener"), - // TaskContext was promoted to Abstract class - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.TaskContext"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.util.collection.SortDataFormat") - ) ++ Seq( - // Adding new methods to the JavaRDDLike trait: - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.takeAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.collectAsync") - ) ++ Seq( - // SPARK-3822 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-1209 - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.rdd.PairRDDFunctions") - ) ++ Seq( - // SPARK-4062 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") - ) + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.util.collection.SortDataFormat") + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") + ) ++ Seq( + // SPARK-3822 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") + ) ++ Seq( + // SPARK-4062 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") + ) - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // Should probably mark this as Experimental - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.getValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry") - ) ++ - Seq( - // Serializer interface change. See SPARK-3045. - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.DeserializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.Serializer"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializerInstance") - )++ - Seq( - // Renamed putValues -> putArray + putIterator - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.TachyonStore.putValues") - ) ++ - Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.flume.FlumeReceiver.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver.this") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - // Class was missing "@DeveloperApi" annotation in 1.0. - MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ - Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Variance.calculate") - ) ++ - Seq( // Package-private classes removed in SPARK-2341 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ - Seq( // package-private classes removed in MLlib - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") - ) ++ - Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") - ) ++ - Seq( // synthetic methods generated in LabeledPoint - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") - ) ++ - Seq ( // Scala 2.11 compatibility fix - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") - ) - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") - case _ => Seq() - } -} + case v if v.startsWith("1.1") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // Should probably mark this as Experimental + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values + // for countApproxDistinct* functions, which does not work in Java. We later removed + // them, and use the following to tell Mima to not care about them. + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.getValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry") + ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ + Seq( + // Renamed putValues -> putArray + putIterator + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.TachyonStore.putValues") + ) ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") + ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. + "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ + MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ + MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + // Class was missing "@DeveloperApi" annotation in 1.0. + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) ++ + Seq( // Package-private classes removed in SPARK-2341 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") + ) ++ + Seq( // package-private classes removed in MLlib + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") + ) + case v if v.startsWith("1.0") => + Seq( + MimaBuild.excludeSparkPackage("api.java"), + MimaBuild.excludeSparkPackage("mllib"), + MimaBuild.excludeSparkPackage("streaming") + ) ++ + MimaBuild.excludeSparkClass("rdd.ClassTags") ++ + MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ + MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ + MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ + MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ + MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ + MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + case _ => Seq() + } +} \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 6c0c926755c20..13cfe29d7b304 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -47,7 +47,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + val conf = new SparkConf() + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + .set("spark.app.id", "streaming-test") val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) @@ -184,7 +186,7 @@ class ReceivedBlockHandlerSuite } test("Test Block - isFullyConsumed") { - val sparkConf = new SparkConf() + val sparkConf = new SparkConf().set("spark.app.id", "streaming-test") sparkConf.set("spark.storage.unrollMemoryThreshold", "512") // spark.storage.unrollFraction set to 0.4 for BlockManager sparkConf.set("spark.storage.unrollFraction", "0.4") @@ -251,7 +253,7 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") From 9d8e838d883ed21f9ef562e7e3ac074c7e4adb88 Mon Sep 17 00:00:00 2001 From: Stephen Hopper Date: Tue, 8 Sep 2015 14:36:34 +0100 Subject: [PATCH 0004/2089] =?UTF-8?q?[DOC]=20Added=20R=20to=20the=20list?= =?UTF-8?q?=20of=20languages=20with=20"high-level=20API"=20support=20in=20?= =?UTF-8?q?the=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … main README. Author: Stephen Hopper Closes #8646 from enragedginger/master. --- README.md | 4 ++-- docs/quick-start.md | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 380422ca00dbe..76e29b4235666 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Apache Spark Spark is a fast and general cluster computing system for Big Data. It provides -high-level APIs in Scala, Java, and Python, and an optimized engine that +high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, @@ -94,5 +94,5 @@ distribution. ## Configuration -Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) +Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. diff --git a/docs/quick-start.md b/docs/quick-start.md index ce2cc9d2169cd..d481fe0ea6d70 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -126,7 +126,7 @@ scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (w wordCounts: spark.RDD[(String, Int)] = spark.ShuffledAggregatedRDD@71f027b8 {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight scala %} scala> wordCounts.collect() @@ -163,7 +163,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b) {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight python %} >>> wordCounts.collect() @@ -217,13 +217,13 @@ a cluster, as described in the [programming guide](programming-guide.html#initia # Self-Contained Applications -Now say we wanted to write a self-contained application using the Spark API. We will walk through a -simple application in both Scala (with SBT), Java (with Maven), and Python. +Suppose we wish to write a self-contained application using the Spark API. We will walk through a +simple application in Scala (with sbt), Java (with Maven), and Python.
-We'll create a very simple Spark application in Scala. So simple, in fact, that it's +We'll create a very simple Spark application in Scala--so simple, in fact, that it's named `SimpleApp.scala`: {% highlight scala %} @@ -259,7 +259,7 @@ object which contains information about our application. Our application depends on the Spark API, so we'll also include an sbt configuration file, -`simple.sbt` which explains that Spark is a dependency. This file also adds a repository that +`simple.sbt`, which explains that Spark is a dependency. This file also adds a repository that Spark depends on: {% highlight scala %} @@ -302,7 +302,7 @@ Lines with a: 46, Lines with b: 23
-This example will use Maven to compile an application jar, but any similar build system will work. +This example will use Maven to compile an application JAR, but any similar build system will work. We'll create a very simple Spark application, `SimpleApp.java`: @@ -374,7 +374,7 @@ $ find . Now, we can package the application using Maven and execute it with `./bin/spark-submit`. {% highlight bash %} -# Package a jar containing your application +# Package a JAR containing your application $ mvn package ... [INFO] Building jar: {..}/{..}/target/simple-project-1.0.jar From 6ceed852ab716d8acc46ce90cba9cfcff6d3616f Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 8 Sep 2015 14:38:10 +0100 Subject: [PATCH 0005/2089] Docs small fixes Author: Jacek Laskowski Closes #8629 from jaceklaskowski/docs-fixes. --- docs/building-spark.md | 23 +++++++++++------------ docs/cluster-overview.md | 15 ++++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index f133eb96d9a21..4db32cfd628bc 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -61,12 +61,13 @@ If you don't run this, you may see errors like the following: You can fix this by setting the `MAVEN_OPTS` variable as discussed before. **Note:** -* *For Java 8 and above this step is not required.* -* *If using `build/mvn` and `MAVEN_OPTS` were not already set, the script will automate this for you.* + +* For Java 8 and above this step is not required. +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. # Specifying the Hadoop Version -Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: +Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the `hadoop.version` property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: @@ -91,7 +92,7 @@ mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package {% endhighlight %} -You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. +You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. Examples: @@ -125,7 +126,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip # Building for Scala 2.11 To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.11 mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package Spark does not yet support its JDBC component for Scala 2.11. @@ -163,11 +164,9 @@ the `spark-parent` module). Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: -``` - $ mvn install - $ cd core - $ mvn scala:cc -``` + $ mvn install + $ cd core + $ mvn scala:cc # Building Spark with IntelliJ IDEA or Eclipse @@ -193,11 +192,11 @@ then ship it over to the cluster. We are investigating the exact cause for this. # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT -Maven is the official recommendation for packaging Spark, and is the "build of reference". +Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. But SBT is supported for day-to-day development since it can provide much faster iterative compilation. More advanced developers may wish to use SBT. diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 7079de546e2f5..faaf154d243f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -5,18 +5,19 @@ title: Cluster Mode Overview This document gives a short overview of how Spark runs on clusters, to make it easier to understand the components involved. Read through the [application submission guide](submitting-applications.html) -to submit applications to a cluster. +to learn about launching applications on a cluster. # Components -Spark applications run as independent sets of processes on a cluster, coordinated by the SparkContext +Spark applications run as independent sets of processes on a cluster, coordinated by the `SparkContext` object in your main program (called the _driver program_). + Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_ -(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across +(either Spark's own standalone cluster manager, Mesos or YARN), which allocate resources across applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are processes that run computations and store data for your application. Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to -the executors. Finally, SparkContext sends *tasks* for the executors to run. +the executors. Finally, SparkContext sends *tasks* to the executors to run.

Spark cluster components @@ -33,9 +34,9 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. The driver program must listen for and accept incoming connections from its executors throughout - its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config - section](configuration.html#networking)). As such, the driver program must be network +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network addressable from the worker nodes. 4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the From 990c9f79c28db501018a0a3af446ff879962475d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Sep 2015 23:07:34 +0800 Subject: [PATCH 0006/2089] [SPARK-9170] [SQL] Use OrcStructInspector to be case preserving when writing ORC files JIRA: https://issues.apache.org/jira/browse/SPARK-9170 `StandardStructObjectInspector` will implicitly lowercase column names. But I think Orc format doesn't have such requirement. In fact, there is a `OrcStructInspector` specified for Orc format. We should use it when serialize rows to Orc file. It can be case preserving when writing ORC files. Author: Liang-Chi Hsieh Closes #7520 from viirya/use_orcstruct. --- .../spark/sql/hive/orc/OrcRelation.scala | 47 ++++++++++--------- .../spark/sql/hive/orc/OrcQuerySuite.scala | 14 ++++++ 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 4eeca9aec12bd..7e89109259955 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -25,9 +25,9 @@ import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct} +import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, StructTypeInfo} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -89,21 +89,10 @@ private[orc] class OrcOutputWriter( TypeInfoUtils.getTypeInfoFromTypeString( HiveMetastoreTypes.toMetastoreType(dataSchema)) - TypeInfoUtils - .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) - .asInstanceOf[StructObjectInspector] + OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) + .asInstanceOf[SettableStructObjectInspector] } - // Used to hold temporary `Writable` fields of the next row to be written. - private val reusableOutputBuffer = new Array[Any](dataSchema.length) - - // Used to convert Catalyst values into Hadoop `Writable`s. - private val wrappers = structOI.getAllStructFieldRefs.asScala - .zip(dataSchema.fields.map(_.dataType)) - .map { case (ref, dt) => - wrapperFor(ref.getFieldObjectInspector, dt) - }.toArray - // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this // flag to decide whether `OrcRecordWriter.close()` needs to be called. private var recordWriterInstantiated = false @@ -127,16 +116,32 @@ private[orc] class OrcOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - override protected[sql] def writeInternal(row: InternalRow): Unit = { + private def wrapOrcStruct( + struct: OrcStruct, + oi: SettableStructObjectInspector, + row: InternalRow): Unit = { + val fieldRefs = oi.getAllStructFieldRefs var i = 0 - while (i < row.numFields) { - reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) + while (i < fieldRefs.size) { + oi.setStructFieldData( + struct, + fieldRefs.get(i), + wrap( + row.get(i, dataSchema(i).dataType), + fieldRefs.get(i).getFieldObjectInspector, + dataSchema(i).dataType)) i += 1 } + } + + val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + wrapOrcStruct(cachedOrcStruct, structOI, row) recordWriter.write( NullWritable.get(), - serializer.serialize(reusableOutputBuffer, structOI)) + serializer.serialize(cachedOrcStruct, structOI)) } override def close(): Unit = { @@ -259,7 +264,7 @@ private[orc] case class OrcTableScan( maybeStructOI.map { soi => val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + soi.getStructFieldRef(attr.name) -> ordinal }.unzip val unwrappers = fieldRefs.map(unwrapperFor) // Map each tuple to a row object diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 744d462938141..8bc33fcf5d906 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -287,6 +287,20 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("SPARK-9170: Don't implicitly lowercase of user-provided columns") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path) + sqlContext.read.format("orc").load(path).schema("Acol") + intercept[IllegalArgumentException] { + sqlContext.read.format("orc").load(path).schema("acol") + } + checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"), + (0 until 10).map(Row(_))) + } + } + test("SPARK-8501: Avoids discovery schema from empty ORC files") { withTempPath { dir => val path = dir.getCanonicalPath From 5b2192e846b843d8a0cb9427d19bb677431194a0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 8 Sep 2015 11:11:35 -0700 Subject: [PATCH 0007/2089] [SPARK-10480] [ML] Fix ML.LinearRegressionModel.copy() This PR fix two model ```copy()``` related issues: [SPARK-10480](https://issues.apache.org/jira/browse/SPARK-10480) ```ML.LinearRegressionModel.copy()``` ignored argument ```extra```, it will not take effect when users setting this parameter. [SPARK-10479](https://issues.apache.org/jira/browse/SPARK-10479) ```ML.LogisticRegressionModel.copy()``` should copy model summary if available. Author: Yanbo Liang Closes #8641 from yanboliang/linear-regression-copy. --- .../apache/spark/ml/classification/LogisticRegression.scala | 4 +++- .../org/apache/spark/ml/regression/LinearRegression.scala | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21fbe38ca8233..a460262b87e43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) + val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel.setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 884003eb38524..e4602d36ccc87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -310,7 +310,7 @@ class LinearRegressionModel private[ml] ( } override def copy(extra: ParamMap): LinearRegressionModel = { - val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } From 5fd57955ef477347408f68eb1cb6ad1881fdb6e0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Sep 2015 12:05:41 -0700 Subject: [PATCH 0008/2089] [SPARK-10316] [SQL] respect nondeterministic expressions in PhysicalOperation We did a lot of special handling for non-deterministic expressions in `Optimizer`. However, `PhysicalOperation` just collects all Projects and Filters and mess it up. We should respect the operators order caused by non-deterministic expressions in `PhysicalOperation`. Author: Wenchen Fan Closes #8486 from cloud-fan/fix. --- .../sql/catalyst/planning/patterns.scala | 38 ++++--------------- .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++ 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index e8abcd63f7d85..53537799517ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,35 +17,12 @@ package org.apache.spark.sql.catalyst.planning -import scala.annotation.tailrec - import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -/** - * A pattern that matches any number of filter operations on top of another relational operator. - * Adjacent filter operators are collected and their conditions are broken up and returned as a - * sequence of conjunctive predicates. - * - * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the - * output and a relational operator. - */ -object FilteredOperation extends PredicateHelper { - type ReturnType = (Seq[Expression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan)) - - @tailrec - private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match { - case Filter(condition, child) => - collectFilters(filters ++ splitConjunctivePredicates(condition), child) - case other => (filters, other) - } -} - /** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned @@ -62,8 +39,9 @@ object PhysicalOperation extends PredicateHelper { } /** - * Collects projects and filters, in-lining/substituting aliases if necessary. Here are two - * examples for alias in-lining/substitution. Before: + * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. + * Here are two examples for alias in-lining/substitution. + * Before: * {{{ * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 @@ -74,15 +52,15 @@ object PhysicalOperation extends PredicateHelper { * SELECT key AS c2 FROM t1 WHERE key > 10 * }}} */ - def collectProjectsAndFilters(plan: LogicalPlan): + private def collectProjectsAndFilters(plan: LogicalPlan): (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = plan match { - case Project(fields, child) => + case Project(fields, child) if fields.forall(_.deterministic) => val (_, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) - case Filter(condition, child) => + case Filter(condition, child) if condition.deterministic => val (fields, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) @@ -91,11 +69,11 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { + private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b5b9f11785074..dbed4fc247140 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -22,6 +22,8 @@ import java.io.File import scala.language.postfixOps import scala.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -895,4 +897,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .orderBy(sum('j)) checkAnswer(query, Row(1, 2)) } + + test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { + val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + (1 to 10).map(i => s"""{"id": $i}"""))) + + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } + } } From f7b55dbfc3343cad988e2490478fce1a11343c73 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 8 Sep 2015 12:48:21 -0700 Subject: [PATCH 0009/2089] [SPARK-10470] [ML] ml.IsotonicRegressionModel.copy should set parent Copied model must have the same parent, but ml.IsotonicRegressionModel.copy did not set parent. Here fix it and add test case. Author: Yanbo Liang Closes #8637 from yanboliang/spark-10470. --- .../org/apache/spark/ml/regression/IsotonicRegression.scala | 2 +- .../apache/spark/ml/regression/IsotonicRegressionSuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index d43a3447d3975..2ff500f291abc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -203,7 +203,7 @@ class IsotonicRegressionModel private[ml] ( def predictions: Vector = Vectors.dense(oldModel.predictions) override def copy(extra: ParamMap): IsotonicRegressionModel = { - copyValues(new IsotonicRegressionModel(uid, oldModel), extra) + copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } override def transform(dataset: DataFrame): DataFrame = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index c0ab00b68a2f3..59f4193abc8f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -89,6 +90,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ir.getFeatureIndex === 0) val model = ir.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "features", "prediction", "weight") .collect() From 7a9dcbc91d55dbc0cbf4812319bde65f4509b467 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 8 Sep 2015 14:10:12 -0700 Subject: [PATCH 0010/2089] [SPARK-10441] [SQL] Save data correctly to json. https://issues.apache.org/jira/browse/SPARK-10441 Author: Yin Huai Closes #8597 from yhuai/timestampJson. --- .../spark/sql/RandomDataGenerator.scala | 41 +++++++++- .../datasources/json/JacksonGenerator.scala | 11 ++- .../datasources/json/JacksonParser.scala | 31 ++++++++ .../hive/orc/OrcHadoopFsRelationSuite.scala | 8 ++ .../sources/JsonHadoopFsRelationSuite.scala | 8 ++ .../ParquetHadoopFsRelationSuite.scala | 9 ++- .../SimpleTextHadoopFsRelationSuite.scala | 19 ++++- .../sql/sources/SimpleTextRelation.scala | 7 +- .../sql/sources/hadoopFsRelationSuites.scala | 79 +++++++++++++++++++ 9 files changed, 205 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 11e0c120f4072..4025cbcec1019 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -23,6 +23,8 @@ import java.math.MathContext import scala.util.Random +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -84,6 +86,7 @@ object RandomDataGenerator { * random data generator is defined for that data type. The generated values will use an external * representation of the data type; for example, the random generator for [[DateType]] will return * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. + * For a [[UserDefinedType]] for a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated @@ -106,7 +109,22 @@ object RandomDataGenerator { }) case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt())) - case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) + case TimestampType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L + } + // DateTimeUtils.toJavaTimestamp takes microsecond. + DateTimeUtils.toJavaTimestamp(milliseconds * 1000) + } + Some(generator) case CalendarIntervalType => Some(() => { val months = rand.nextInt(1000) val ns = rand.nextLong() @@ -159,6 +177,27 @@ object RandomDataGenerator { None } } + case udt: UserDefinedType[_] => { + val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed) + // Because random data generator at here returns scala value, we need to + // convert it to catalyst value to call udt's deserialize. + val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) + + if (maybeSqlTypeGenerator.isDefined) { + val sqlTypeGenerator = maybeSqlTypeGenerator.get + val generator = () => { + val generatedScalaValue = sqlTypeGenerator.apply() + if (generatedScalaValue == null) { + null + } else { + udt.deserialize(toCatalystType(generatedScalaValue)) + } + } + Some(generator) + } else { + None + } + } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 330ba907b2ef9..f65c7bbd6e29d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import scala.collection.Map @@ -89,7 +90,7 @@ private[sql] object JacksonGenerator { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v) => gen.writeString(v.toString) - case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) + case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) case (FloatType, v: Float) => gen.writeNumber(v) @@ -99,8 +100,12 @@ private[sql] object JacksonGenerator { case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v) => gen.writeString(v.toString) - case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) + case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString) + // For UDT values, they should be in the SQL type's corresponding value type. + // We should not see values in the user-defined class at here. + // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is + // an ArrayData at here, instead of a Vector. + case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) case (ArrayType(ty, _), v: ArrayData) => gen.writeStartArray() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index cd68bd667c5c4..ff4d8c04e8eaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -81,9 +81,37 @@ private[sql] object JacksonParser { case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => parser.getFloatValue + case (VALUE_STRING, FloatType) => + // Special case handling for NaN and Infinity. + val value = parser.getText + val lowerCaseValue = value.toLowerCase() + if (lowerCaseValue.equals("nan") || + lowerCaseValue.equals("infinity") || + lowerCaseValue.equals("-infinity") || + lowerCaseValue.equals("inf") || + lowerCaseValue.equals("-inf")) { + value.toFloat + } else { + sys.error(s"Cannot parse $value as FloatType.") + } + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => parser.getDoubleValue + case (VALUE_STRING, DoubleType) => + // Special case handling for NaN and Infinity. + val value = parser.getText + val lowerCaseValue = value.toLowerCase() + if (lowerCaseValue.equals("nan") || + lowerCaseValue.equals("infinity") || + lowerCaseValue.equals("-infinity") || + lowerCaseValue.equals("inf") || + lowerCaseValue.equals("-inf")) { + value.toDouble + } else { + sys.error(s"Cannot parse $value as DoubleType.") + } + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) => Decimal(parser.getDecimalValue, dt.precision, dt.scale) @@ -126,6 +154,9 @@ private[sql] object JacksonParser { case (_, udt: UserDefinedType[_]) => convertField(factory, parser, udt.sqlType) + + case (token, dataType) => + sys.error(s"Failed to parse a value for data type $dataType (current token: $token).") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 9a299c3f9d1f3..92043d66c914f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -28,6 +28,14 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + // ORC does not play well with NullType and UDT. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _: UserDefinedType[_] => false + case _ => true + } + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 1945b15002337..ef37787137d07 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,6 +28,14 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" + // JSON does not write data of NullType and does not play well with BinaryType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: BinaryType => false + case _: CalendarIntervalType => false + case _ => true + } + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 08c3c17973043..e2d754e806403 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.{execution, AnalysisException, SaveMode} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -32,6 +32,13 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "parquet" + // Parquet does not play well with NullType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _ => true + } + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 1125ca670107b..a3a124488d983 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -20,11 +20,28 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + // We have a very limited number of supported types at here since it is just for a + // test relation and we do very basic testing at here. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: BinaryType => false + // We are using random data generator and the generated strings are not really valid string. + case _: StringType => false + case _: BooleanType => false // see https://issues.apache.org/jira/browse/SPARK-10442 + case _: CalendarIntervalType => false + case _: DateType => false + case _: TimestampType => false + case _: ArrayType => false + case _: MapType => false + case _: StructType => false + case _: UserDefinedType[_] => false + case _ => true + } + test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 527ca7a81cad8..aeaaa3e1c5220 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -68,7 +68,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { - val serialized = row.toSeq.map(_.toString).mkString(",") + val serialized = row.toSeq.map { v => + if (v == null) "" else v.toString + }.mkString(",") recordWriter.write(null, new Text(serialized)) } @@ -112,7 +114,8 @@ class SimpleTextRelation( val fields = dataSchema.map(_.dataType) sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => - Row(record.split(",").zip(fields).map { case (value, dataType) => + Row(record.split(",", -1).zip(fields).map { case (v, dataType) => + val value = if (v == "") null else v // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) val catalystValue = Cast(Literal(value), dataType).eval() // Here we're converting Catalyst values to Scala values to test `needsConversion` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2ad2618dfc436..24f43cf7c15ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -38,6 +38,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val dataSourceName: String + protected def supportsDataType(dataType: DataType): Boolean = true + val dataSchema = StructType( Seq( @@ -98,6 +100,83 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } + test("test all data types") { + withTempPath { file => + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + // TODO: add CalendarIntervalType to here once we can save it out. + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. + val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schema, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schema")) + val data = (1 to 10).map { i => + dataGenerator.apply() match { + case row: Row => row + case null => Row.fromSeq(Seq.fill(schema.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 10) + val df = sqlContext.createDataFrame(rdd, schema) + + // All columns that have supported data types of this source. + val supportedColumns = schema.fields.collect { + case StructField(name, dataType, _, _) if supportsDataType(dataType) => name + } + val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) + + val dfToBeSaved = df.selectExpr(selectedColumns: _*) + + // Save the data out. + dfToBeSaved + .write + .format(dataSourceName) + .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. + .save(file.getCanonicalPath) + + val loadedDF = + sqlContext + .read + .format(dataSourceName) + .schema(dfToBeSaved.schema) + .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. + .load(file.getCanonicalPath) + .selectExpr(selectedColumns: _*) + + // Read the data back. + checkAnswer( + loadedDF, + dfToBeSaved + ) + } + } + test("save()/load() - non-partitioned table - Overwrite") { withTempPath { file => testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) From e6f8d3686016a305a747c5bcc85f46fd4c0cbe83 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 8 Sep 2015 14:44:05 -0700 Subject: [PATCH 0011/2089] [SPARK-10468] [ MLLIB ] Verify schema before Dataframe select API call Loader.checkSchema was called to verify the schema after dataframe.select(...). Schema verification should be done before dataframe.select(...) Author: Vinod K C Closes #8636 from vinodkc/fix_GaussianMixtureModel_load_verification. --- .../apache/spark/mllib/clustering/GaussianMixtureModel.scala | 3 +-- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 7f6163e04bf17..a5902190d4637 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -168,10 +168,9 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) val dataFrame = sqlContext.read.parquet(dataPath) - val dataArray = dataFrame.select("weight", "mu", "sigma").collect() - // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() val (weights, gaussians) = dataArray.map { case Row(weight: Double, mu: Vector, sigma: Matrix) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 36b124c5d2966..58857c338f546 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -590,12 +590,10 @@ object Word2VecModel extends Loader[Word2VecModel] { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) val dataFrame = sqlContext.read.parquet(dataPath) - - val dataArray = dataFrame.select("word", "vector").collect() - // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) + val dataArray = dataFrame.select("word", "vector").collect() val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap new Word2VecModel(word2VecMap) } From 52b24a602ad615a7f6aa427aefb1c7444c05d298 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 8 Sep 2015 14:54:43 -0700 Subject: [PATCH 0012/2089] [SPARK-10492] [STREAMING] [DOCUMENTATION] Update Streaming documentation about rate limiting and backpressure Author: Tathagata Das Closes #8656 from tdas/SPARK-10492 and squashes the following commits: 986cdd6 [Tathagata Das] Added information on backpressure --- docs/configuration.md | 13 +++++++++++++ docs/streaming-programming-guide.md | 13 ++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index a2cc7a37e2240..e287591f3fda1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1433,6 +1433,19 @@ Apart from these, the following properties are also available, and may be useful #### Spark Streaming

+ + + + + diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index a1acf83f75245..c751dbb41785a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1807,7 +1807,7 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, +- *Configuring write ahead logs* - Since Spark 1.2, we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver @@ -1822,6 +1822,17 @@ To run a Spark Streaming applications, you need to have the following. stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. +- *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming + application to process data as fast as it is being received, the receivers can be rate limited + by setting a maximum rate limit in terms of records / sec. + See the [configuration parameters](configuration.html#spark-streaming) + `spark.streaming.receiver.maxRate` for receivers and `spark.streaming.kafka.maxRatePerPartition` + for Direct Kafka approach. In Spark 1.5, we have introduced a feature called *backpressure* that + eliminate the need to set this rate limit, as Spark Streaming automatically figures out the + rate limits and dynamically adjusts them if the processing conditions change. This backpressure + can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) + `spark.streaming.backpressure.enabled` to `true`. + ### Upgrading Application Code {:.no_toc} From d637a666d5932002c8ce0bd23c06064fbfdc1c97 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 8 Sep 2015 16:16:50 -0700 Subject: [PATCH 0013/2089] [SPARK-10327] [SQL] Cache Table is not working while subquery has alias in its project list ```scala import org.apache.spark.sql.hive.execution.HiveTableScan sql("select key, value, key + 1 from src").registerTempTable("abc") cacheTable("abc") val sparkPlan = sql( """select a.key, b.key, c.key from |abc a join abc b on a.key=b.key |join abc c on a.key=c.key""".stripMargin).queryExecution.sparkPlan assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) // failed assert(sparkPlan.collect { case e: HiveTableScan => e }.size === 0) // failed ``` The actual plan is: ``` == Parsed Logical Plan == 'Project [unresolvedalias('a.key),unresolvedalias('b.key),unresolvedalias('c.key)] 'Join Inner, Some(('a.key = 'c.key)) 'Join Inner, Some(('a.key = 'b.key)) 'UnresolvedRelation [abc], Some(a) 'UnresolvedRelation [abc], Some(b) 'UnresolvedRelation [abc], Some(c) == Analyzed Logical Plan == key: int, key: int, key: int Project [key#14,key#61,key#66] Join Inner, Some((key#14 = key#66)) Join Inner, Some((key#14 = key#61)) Subquery a Subquery abc Project [key#14,value#15,(key#14 + 1) AS _c2#16] MetastoreRelation default, src, None Subquery b Subquery abc Project [key#61,value#62,(key#61 + 1) AS _c2#58] MetastoreRelation default, src, None Subquery c Subquery abc Project [key#66,value#67,(key#66 + 1) AS _c2#63] MetastoreRelation default, src, None == Optimized Logical Plan == Project [key#14,key#61,key#66] Join Inner, Some((key#14 = key#66)) Project [key#14,key#61] Join Inner, Some((key#14 = key#61)) Project [key#14] InMemoryRelation [key#14,value#15,_c2#16], true, 10000, StorageLevel(true, true, false, true, 1), (Project [key#14,value#15,(key#14 + 1) AS _c2#16]), Some(abc) Project [key#61] MetastoreRelation default, src, None Project [key#66] MetastoreRelation default, src, None == Physical Plan == TungstenProject [key#14,key#61,key#66] BroadcastHashJoin [key#14], [key#66], BuildRight TungstenProject [key#14,key#61] BroadcastHashJoin [key#14], [key#61], BuildRight ConvertToUnsafe InMemoryColumnarTableScan [key#14], (InMemoryRelation [key#14,value#15,_c2#16], true, 10000, StorageLevel(true, true, false, true, 1), (Project [key#14,value#15,(key#14 + 1) AS _c2#16]), Some(abc)) ConvertToUnsafe HiveTableScan [key#61], (MetastoreRelation default, src, None) ConvertToUnsafe HiveTableScan [key#66], (MetastoreRelation default, src, None) ``` Author: Cheng Hao Closes #8494 from chenghao-intel/weird_cache. --- .../sql/catalyst/plans/logical/LogicalPlan.scala | 15 ++++++++++++--- .../org/apache/spark/sql/CachedTableSuite.scala | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9bb466ac2d29c..8f8747e105932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -135,16 +135,25 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { val input = children.flatMap(_.output) + def cleanExpression(e: Expression) = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, we need + // to erase that for equality testing. + val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) + BindReferences.bindReference(cleanedExprId, input, allowFailures = true) + case other => BindReferences.bindReference(other, input, allowFailures = true) + } + productIterator.map { // Children are checked using sameResult above. case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case s: Option[_] => s.map { - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case other => other } case s: Seq[_] => s.map { - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case other => other } case other => other diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 3a3541a8429b6..84e66b50cf1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.PhysicalRDD + import scala.concurrent.duration._ import scala.language.postfixOps @@ -338,4 +340,18 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { assert((accsSize - 2) == Accumulators.originals.size) } } + + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { + ctx.sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) + .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") + ctx.cacheTable("abc") + + val sparkPlan = sql( + """select a.key, b.key, c.key from + |abc a join abc b on a.key=b.key + |join abc c on a.key=c.key""".stripMargin).queryExecution.sparkPlan + + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) + assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0) + } } From 2143d592c802ec8f83a1eb5ce9b33ad8e48d7196 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Sep 2015 16:51:45 -0700 Subject: [PATCH 0014/2089] [HOTFIX] Fix build break caused by #8494 Author: Michael Armbrust Closes #8659 from marmbrus/testBuildBreak. --- .../test/scala/org/apache/spark/sql/CachedTableSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 84e66b50cf1d7..356d4ff3fa837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -342,9 +342,9 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { - ctx.sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) + sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") - ctx.cacheTable("abc") + sqlContext.cacheTable("abc") val sparkPlan = sql( """select a.key, b.key, c.key from From ae74c3fa84885dac18f39368e6dda48a2f7a25a5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 8 Sep 2015 17:36:00 -0700 Subject: [PATCH 0015/2089] [RELEASE] Add more contributors & only show names in release notes. Author: Reynold Xin Closes #8660 from rxin/contrib. --- dev/create-release/generate-contributors.py | 20 +++++++++------ dev/create-release/known_translations | 27 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 8aaa250bd7e29..db9c680a4bad3 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -178,13 +178,16 @@ def populate(issue_type, components): author_info[author][issue_type].add(component) # Find issues and components associated with this commit for issue in issues: - jira_issue = jira_client.issue(issue) - jira_type = jira_issue.fields.issuetype.name - jira_type = translate_issue_type(jira_type, issue, warnings) - jira_components = [translate_component(c.name, _hash, warnings)\ - for c in jira_issue.fields.components] - all_components = set(jira_components + commit_components) - populate(jira_type, all_components) + try: + jira_issue = jira_client.issue(issue) + jira_type = jira_issue.fields.issuetype.name + jira_type = translate_issue_type(jira_type, issue, warnings) + jira_components = [translate_component(c.name, _hash, warnings)\ + for c in jira_issue.fields.components] + all_components = set(jira_components + commit_components) + populate(jira_type, all_components) + except Exception as e: + print "Unexpected error:", e # For docs without an associated JIRA, manually add it ourselves if is_docs(title) and not issues: populate("documentation", commit_components) @@ -223,7 +226,8 @@ def populate(issue_type, components): # E.g. andrewor14/SPARK-3425/SPARK-1157/SPARK-6672 if author in invalid_authors and invalid_authors[author]: author = author + "/" + "/".join(invalid_authors[author]) - line = " * %s -- %s" % (author, contribution) + #line = " * %s -- %s" % (author, contribution) + line = author contributors_file.write(line + "\n") contributors_file.close() print "Contributors list is successfully written to %s!" % contributors_file_name diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index e462302f28423..3563fe3cc3c03 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -138,3 +138,30 @@ lee19 - Lee lockwobr - Brian Lockwood navis - Navis Ryu pparkkin - Paavo Parkkinen +HyukjinKwon - Hyukjin Kwon +JDrit - Joseph Batchik +JuhongPark - Juhong Park +KaiXinXiaoLei - KaiXinXIaoLei +NamelessAnalyst - NamelessAnalyst +alyaxey - Alex Slusarenko +baishuo - Shuo Bai +fe2s - Oleksiy Dyagilev +felixcheung - Felix Cheung +feynmanliang - Feynman Liang +josepablocam - Jose Cambronero +kai-zeng - Kai Zeng +mosessky - mosessky +msannell - Michael Sannella +nishkamravi2 - Nishkam Ravi +noel-smith - Noel Smith +petz2000 - Patrick Baier +qiansl127 - Shilei Qian +rahulpalamuttam - Rahul Palamuttam +rowan000 - Rowan Chattaway +sarutak - Kousuke Saruta +sethah - Seth Hendrickson +small-wang - Wang Wei +stanzhai - Stan Zhai +tien-dungle - Tien-Dung Le +xuchenCN - Xu Chen +zhangjiajin - Zhang JiaJin From 084e83c38d1c5dfc1c07aeb30ee1bb3c1087954b Mon Sep 17 00:00:00 2001 From: Gayathri Murali Date: Tue, 8 Sep 2015 19:17:49 -0700 Subject: [PATCH 0016/2089] [SPARK-10380] - Confusing examples in pyspark SQL docs --- python/pyspark/sql/column.py | 27 ++++++++++++++++++++++-- python/pyspark/sql/dataframe.py | 37 +++++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0338825839987..1d7dfd1b0a3f5 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -330,7 +330,7 @@ def cast(self, dataType): >>> df.select(df.age.cast(StringType()).alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] - (Note : astype is alias for cast) + (Note : cast is alias for astype) """ if isinstance(dataType, basestring): jc = self._jc.cast(dataType) @@ -343,7 +343,30 @@ def cast(self, dataType): raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) - astype = cast + @ignore_unicode_prefix + @since(1.3) + def astype(self, dataType): + """ Convert the column into type ``dataType``. + + >>> df.select(df.age.astype("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.astype(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + + (Note : astype is alias for cast) + """ + if isinstance(dataType, basestring): + jc = self._jc.astype(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.astype(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + # astype = cast @since(1.3) def between(self, lowerBound, upperBound): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 71df09a83974d..e7d138a6eb143 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -944,7 +944,7 @@ def dropDuplicates(self, subset=None): | 5| 80|Alice| +---+------+-----+ - (Note: drop_duplicates is an alias for dropDuplicates) + (Note: dropDuplicates is an alias for drop_duplicates) """ if subset is None: jdf = self._jdf.dropDuplicates() @@ -952,6 +952,39 @@ def dropDuplicates(self, subset=None): jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sql_ctx) + @since(1.4) + def drop_duplicates(self, subset=None): + """Return a new :class:`DataFrame` with duplicate rows removed, + optionally only considering certain columns. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([ \ + Row(name='Alice', age=5, height=80), \ + Row(name='Alice', age=5, height=80), \ + Row(name='Alice', age=10, height=80)]).toDF() + >>> df.drop_duplicates().show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 5| 80|Alice| + | 10| 80|Alice| + +---+------+-----+ + + >>> df.drop_duplicates(['name', 'height']).show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 5| 80|Alice| + +---+------+-----+ + + (Note: drop_duplicates is an alias for dropDuplicates) + """ + if subset is None: + jdf = self._jdf.drop_duplicates() + else: + jdf = self._jdf.drop_duplicates(self._jseq(subset)) + return DataFrame(jdf, self.sql_ctx) + @since("1.3.1") def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. @@ -1277,7 +1310,7 @@ def toPandas(self): ########################################################################################## groupby = groupBy - drop_duplicates = dropDuplicates + # drop_duplicates = dropDuplicates # Having SchemaRDD for backward compatibility (for docs) From 820913f554bef610d07ca2dadaead657f916ae63 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 8 Sep 2015 20:39:15 -0700 Subject: [PATCH 0017/2089] [SPARK-10071] [STREAMING] Output a warning when writing QueueInputDStream and throw a better exception when reading QueueInputDStream Output a warning when serializing QueueInputDStream rather than throwing an exception to allow unit tests use it. Moreover, this PR also throws an better exception when deserializing QueueInputDStream to make the user find out the problem easily. The previous exception is hard to understand: https://issues.apache.org/jira/browse/SPARK-8553 Author: zsxwing Closes #8624 from zsxwing/SPARK-10071 and squashes the following commits: 847cfa8 [zsxwing] Output a warning when writing QueueInputDStream and throw a better exception when reading QueueInputDStream --- .../apache/spark/streaming/Checkpoint.scala | 6 ++-- .../streaming/dstream/QueueInputDStream.scala | 9 ++++-- .../streaming/StreamingContextSuite.scala | 28 +++++++++++++------ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 3985e1a3d9dfa..27024ecfd9101 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -321,7 +321,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) - val compressionCodec = CompressionCodec.createCodec(conf) + var readError: Exception = null checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { @@ -332,13 +332,15 @@ object CheckpointReader extends Logging { return Some(cp) } catch { case e: Exception => + readError = e logWarning("Error reading checkpoint from file " + file, e) } }) // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + throw new SparkException( + s"Failed to read checkpoint from directory $checkpointPath", readError) } None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index a2f5d82a79bd3..bab78a3536b47 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import java.io.{NotSerializableException, ObjectOutputStream} +import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag @@ -37,8 +37,13 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def readObject(in: ObjectInputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing. " + + "Please don't use queueStream when checkpointing is enabled.") + } + private def writeObject(oos: ObjectOutputStream): Unit = { - throw new NotSerializableException("queueStream doesn't support checkpointing") + logWarning("queueStream doesn't support checkpointing") } override def compute(validTime: Time): Option[RDD[T]] = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 7423ef6bcb6ea..d26894e88fc26 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel @@ -726,16 +726,26 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } test("queueStream doesn't support checkpointing") { - val checkpointDir = Utils.createTempDir() - ssc = new StreamingContext(master, appName, batchDuration) - val rdd = ssc.sparkContext.parallelize(1 to 10) - ssc.queueStream[Int](Queue(rdd)).print() - ssc.checkpoint(checkpointDir.getAbsolutePath) - val e = intercept[NotSerializableException] { - ssc.start() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + def creatingFunction(): StreamingContext = { + val _ssc = new StreamingContext(conf, batchDuration) + val rdd = _ssc.sparkContext.parallelize(1 to 10) + _ssc.checkpoint(checkpointDirectory) + _ssc.queueStream[Int](Queue(rdd)).register() + _ssc + } + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + val e = intercept[SparkException] { + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) } // StreamingContext.validate changes the message, so use "contains" here - assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + assert(e.getCause.getMessage.contains("queueStream doesn't support checkpointing. " + + "Please don't use queueStream when checkpointing is enabled.")) } def addInputStream(s: StreamingContext): DStream[Int] = { From 52fe32f6ac7a04fa9b4478fda1307c5b0c61c4a2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 8 Sep 2015 20:51:20 -0700 Subject: [PATCH 0018/2089] [SPARK-9834] [MLLIB] implement weighted least squares via normal equation The goal of this PR is to have a weighted least squares implementation that takes the normal equation approach, and hence to be able to provide R-like summary statistics and support IRLS (used by GLMs). The tests match R's lm and glmnet. There are couple TODOs that can be addressed in future PRs: * consolidate summary statistics aggregators * move `dspr` to `BLAS` * etc It would be nice to have this merged first because it blocks couple other features. dbtsai Author: Xiangrui Meng Closes #8588 from mengxr/SPARK-9834. --- .../spark/ml/optim/WeightedLeastSquares.scala | 296 ++++++++++++++++++ .../org/apache/spark/mllib/linalg/BLAS.scala | 7 + .../mllib/linalg/distributed/RowMatrix.scala | 3 +- .../ml/optim/WeightedLeastSquaresSuite.scala | 133 ++++++++ 4 files changed, 438 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala new file mode 100644 index 0000000000000..a99e2ac4c6913 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[WeightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + */ +private[ml] class WeightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double) extends Serializable + +/** + * Weighted least squares solver via normal equation. + * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares + * formulation: + * + * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i + * + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^, + * + * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by + * [[standardizeLabel]] and [[standardizeFeatures]], respectively. + * + * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to + * match R's `lm`. + * Turn on [[standardizeLabel]] to match R's `glmnet`. + * + * @param fitIntercept whether to fit intercept. If false, z is 0.0. + * @param regParam L2 regularization parameter (lambda) + * @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the + * population standard deviation of the j-th column of A. Otherwise, + * sigma,,j,, is 1.0. + * @param standardizeLabel whether to standardize label. If true, delta is the population standard + * deviation of the label column b. Otherwise, delta is 1.0. + */ +private[ml] class WeightedLeastSquares( + val fitIntercept: Boolean, + val regParam: Double, + val standardizeFeatures: Boolean, + val standardizeLabel: Boolean) extends Logging with Serializable { + import WeightedLeastSquares._ + + require(regParam >= 0.0, s"regParam cannot be negative: $regParam") + if (regParam == 0.0) { + logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + + /** + * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. + */ + def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) + summary.validate() + logInfo(s"Number of instances: ${summary.count}.") + val triK = summary.triK + val bBar = summary.bBar + val bStd = summary.bStd + val aBar = summary.aBar + val aVar = summary.aVar + val abBar = summary.abBar + val aaBar = summary.aaBar + val aaValues = aaBar.values + + if (fitIntercept) { + // shift centers + // A^T A - aBar aBar^T + RowMatrix.dspr(-1.0, aBar, aaValues) + // A^T b - bBar aBar + BLAS.axpy(-bBar, aBar, abBar) + } + + // add regularization to diagonals + var i = 0 + var j = 2 + while (i < triK) { + var lambda = regParam + if (standardizeFeatures) { + lambda *= aVar(j - 2) + } + if (standardizeLabel) { + // TODO: handle the case when bStd = 0 + lambda /= bStd + } + aaValues(i) += lambda + i += j + j += 1 + } + + val x = choleskySolve(aaBar.values, abBar) + + // compute intercept + val intercept = if (fitIntercept) { + bBar - BLAS.dot(aBar, x) + } else { + 0.0 + } + + new WeightedLeastSquaresModel(x, intercept) + } + + /** + * Solves a symmetric positive definite linear system via Cholesky factorization. + * The input arguments are modified in-place to store the factorization and the solution. + * @param A the upper triangular part of A + * @param bx right-hand side + * @return the solution vector + */ + // TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS + private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = { + val k = bx.size + val info = new intW(0) + lapack.dppsv("U", k, 1, A, bx.values, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dpotrs returned $code.") + bx + } +} + +private[ml] object WeightedLeastSquares { + + /** + * Case class for weighted observations. + * @param w weight, must be positive + * @param a features + * @param b label + */ + case class Instance(w: Double, a: Vector, b: Double) { + require(w >= 0.0, s"Weight cannot be negative: $w.") + } + + /** + * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. + */ + // TODO: consolidate aggregates for summary statistics + private class Aggregator extends Serializable { + var initialized: Boolean = false + var k: Int = _ + var count: Long = _ + var triK: Int = _ + private var wSum: Double = _ + private var wwSum: Double = _ + private var bSum: Double = _ + private var bbSum: Double = _ + private var aSum: DenseVector = _ + private var abSum: DenseVector = _ + private var aaSum: DenseVector = _ + + private def init(k: Int): Unit = { + require(k <= 4096, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to 4096 but got $k.") + this.k = k + triK = k * (k + 1) / 2 + count = 0L + wSum = 0.0 + wwSum = 0.0 + bSum = 0.0 + bbSum = 0.0 + aSum = new DenseVector(Array.ofDim(k)) + abSum = new DenseVector(Array.ofDim(k)) + aaSum = new DenseVector(Array.ofDim(triK)) + initialized = true + } + + /** + * Adds an instance. + */ + def add(instance: Instance): this.type = { + val Instance(w, a, b) = instance + val ak = a.size + if (!initialized) { + init(ak) + initialized = true + } + assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") + count += 1L + wSum += w + wwSum += w * w + bSum += w * b + bbSum += w * b * b + BLAS.axpy(w, a, aSum) + BLAS.axpy(w * b, a, abSum) + RowMatrix.dspr(w, a, aaSum.values) + this + } + + /** + * Merges another [[Aggregator]]. + */ + def merge(other: Aggregator): this.type = { + if (!other.initialized) { + this + } else { + if (!initialized) { + init(other.k) + } + assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}") + count += other.count + wSum += other.wSum + wwSum += other.wwSum + bSum += other.bSum + bbSum += other.bbSum + BLAS.axpy(1.0, other.aSum, aSum) + BLAS.axpy(1.0, other.abSum, abSum) + BLAS.axpy(1.0, other.aaSum, aaSum) + this + } + } + + /** + * Validates that we have seen observations. + */ + def validate(): Unit = { + assert(initialized, "Training dataset is empty.") + assert(wSum > 0.0, "Sum of weights cannot be zero.") + } + + /** + * Weighted mean of features. + */ + def aBar: DenseVector = { + val output = aSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted mean of labels. + */ + def bBar: Double = bSum / wSum + + /** + * Weighted population standard deviation of labels. + */ + def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar) + + /** + * Weighted mean of (label * features). + */ + def abBar: DenseVector = { + val output = abSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted mean of (features * features^T^). + */ + def aaBar: DenseVector = { + val output = aaSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted population variance of features. + */ + def aVar: DenseVector = { + val variance = Array.ofDim[Double](k) + var i = 0 + var j = 2 + val aaValues = aaSum.values + while (i < triK) { + val l = j - 2 + val aw = aSum(l) / wSum + variance(l) = aaValues(i) / wSum - aw * aw + i += j + j += 1 + } + new DenseVector(variance) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index ab475af264dd3..9ee81eda8a8c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging { } } + /** Y += a * x */ + private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = { + require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " + + s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.") + f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1) + } + /** * dot(x, y) */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 9a423ddafdc09..83779ac88989b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -678,7 +678,8 @@ object RowMatrix { * * @param U the upper triangular part of the matrix packed in an array (column major) */ - private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + // TODO: SPARK-10491 - move this method to linalg.BLAS + private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { // TODO: Find a better home (breeze?) for this method. val n = v.size v match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala new file mode 100644 index 0000000000000..652f3adb984d3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.optim.WeightedLeastSquares.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + */ + instances = sc.parallelize(Seq( + Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0), + Instance(2.0, Vectors.dense(1.0, 7.0), 19.0), + Instance(3.0, Vectors.dense(2.0, 11.0), 23.0), + Instance(4.0, Vectors.dense(3.0, 13.0), 29.0) + ), 2) + } + + test("WLS against lm") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- lm(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -3.727121 3.009983 + [1] 18.08 6.08 -0.60 + */ + + val expected = Seq( + Vectors.dense(0.0, -3.727121, 3.009983), + Vectors.dense(18.08, 6.08, -0.60)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("WLS against glmnet") { + /* + R code: + + library(glmnet) + + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + for (standardize in c(FALSE, TRUE)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda, + standardize=standardize, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + + [1] 0.000000 -3.727117 3.009982 + [1] 0.000000 -3.727117 3.009982 + [1] 0.000000 -3.307532 2.924206 + [1] 0.000000 -2.914790 2.840627 + [1] 0.000000 -1.526575 2.558158 + [1] 0.00000000 0.06984238 2.20488344 + [1] 18.0799727 6.0799832 -0.5999941 + [1] 18.0799727 6.0799832 -0.5999941 + [1] 13.5356178 3.2714044 0.3770744 + [1] 14.064629 3.565802 0.269593 + [1] 10.1238013 0.9708569 1.1475466 + [1] 13.1860638 2.1761382 0.6213134 + */ + + val expected = Seq( + Vectors.dense(0.0, -3.727117, 3.009982), + Vectors.dense(0.0, -3.727117, 3.009982), + Vectors.dense(0.0, -3.307532, 2.924206), + Vectors.dense(0.0, -2.914790, 2.840627), + Vectors.dense(0.0, -1.526575, 2.558158), + Vectors.dense(0.0, 0.06984238, 2.20488344), + Vectors.dense(18.0799727, 6.0799832, -0.5999941), + Vectors.dense(18.0799727, 6.0799832, -0.5999941), + Vectors.dense(13.5356178, 3.2714044, 0.3770744), + Vectors.dense(14.064629, 3.565802, 0.269593), + Vectors.dense(10.1238013, 0.9708569, 1.1475466), + Vectors.dense(13.1860638, 2.1761382, 0.6213134)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0); + standardizeFeatures <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam, standardizeFeatures, standardizeLabel = true) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} From a1573489a37def97b7c26b798898ffbbdc4defa8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 8 Sep 2015 20:54:02 -0700 Subject: [PATCH 0019/2089] [SPARK-10464] [MLLIB] Add WeibullGenerator for RandomDataGenerator Add WeibullGenerator for RandomDataGenerator. #8611 need use WeibullGenerator to generate random data based on Weibull distribution. Author: Yanbo Liang Closes #8622 from yanboliang/spark-10464. --- .../mllib/random/RandomDataGenerator.scala | 27 +++++++++++++++++-- .../random/RandomDataGeneratorSuite.scala | 16 ++++++++++- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index a2d85a68cd327..9eab7efc160da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.random -import org.apache.commons.math3.distribution.{ExponentialDistribution, - GammaDistribution, LogNormalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution._ import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} @@ -195,3 +194,27 @@ class LogNormalGenerator @Since("1.3.0") ( @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } + +/** + * :: DeveloperApi :: + * Generates i.i.d. samples from the Weibull distribution with the + * given shape and scale parameter. + * + * @param alpha shape parameter for the Weibull distribution. + * @param beta scale parameter for the Weibull distribution. + */ +@DeveloperApi +class WeibullGenerator( + val alpha: Double, + val beta: Double) extends RandomDataGenerator[Double] { + + private val rng = new WeibullDistribution(alpha, beta) + + override def nextValue(): Double = rng.sample() + + override def setSeed(seed: Long): Unit = { + rng.reseedRandomGenerator(seed) + } + + override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index a5ca1518f82f5..8416771552fd3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.random -import scala.math +import org.apache.commons.math3.special.Gamma import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter @@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite { distributionChecks(gamma, expectedMean, expectedStd, 0.1) } } + + test("WeibullGenerator") { + List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map { + case (alpha: Double, beta: Double) => + val weibull = new WeibullGenerator(alpha, beta) + apiChecks(weibull) + + val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta + val expectedVariance = math.exp( + Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean + val expectedStd = math.sqrt(expectedVariance) + distributionChecks(weibull, expectedMean, expectedStd, 0.1) + } + } } From 3a11e50e21ececbec9708eb487b08196f195cd87 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 8 Sep 2015 20:56:22 -0700 Subject: [PATCH 0020/2089] [SPARK-10373] [PYSPARK] move @since into pyspark from sql cc mengxr Author: Davies Liu Closes #8657 from davies/move_since. --- python/pyspark/__init__.py | 16 ++++++++++++++++ python/pyspark/sql/__init__.py | 15 --------------- python/pyspark/sql/column.py | 2 +- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 3 +-- python/pyspark/sql/group.py | 2 +- python/pyspark/sql/readwriter.py | 3 +-- python/pyspark/sql/window.py | 3 +-- 9 files changed, 23 insertions(+), 25 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 5f70ac6ed8fe6..8475dfb1c6ad0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -48,6 +48,22 @@ from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler + +def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + + def deco(f): + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) + return f + return deco + + # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index ad9c891ba1c04..98eaf52866d23 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -44,21 +44,6 @@ from __future__ import absolute_import -def since(version): - """ - A decorator that annotates a function to append the version of Spark the function was added. - """ - import re - indent_p = re.compile(r'\n( +)') - - def deco(f): - indents = indent_p.findall(f.__doc__) - indent = ' ' * (min(len(m) for m in indents) if indents else 0) - f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) - return f - return deco - - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 56e75e8caee88..573f65f5bf096 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -22,9 +22,9 @@ basestring = str long = int +from pyspark import since from pyspark.context import SparkContext from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql import since from pyspark.sql.types import * __all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 0ef46c44644ab..89c8c6e0d94f1 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -26,9 +26,9 @@ from py4j.protocol import Py4JError +from pyspark import since from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter from pyspark.sql.dataframe import DataFrame diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e269ef4304f3f..c5bf55791240b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -26,11 +26,11 @@ else: from itertools import imap as map +from pyspark import since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from pyspark.sql import since from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4b74a501521a5..26b8662718a60 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,10 +24,9 @@ if sys.version < "3": from itertools import imap as map -from pyspark import SparkContext +from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql import since from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 04594d5a836ce..71c0bccc5eeff 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,8 +15,8 @@ # limitations under the License. # +from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql import since from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3fa6895880a97..f43d8bf646a9e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,8 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD -from pyspark.sql import since +from pyspark import RDD, since from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index eaf4d7e98620a..57bbe340bbd4d 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -17,8 +17,7 @@ import sys -from pyspark import SparkContext -from pyspark.sql import since +from pyspark import since, SparkContext from pyspark.sql.column import _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] From 0e2f2163314972f6be18e3453c64314d1bee7bb9 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Tue, 8 Sep 2015 21:26:20 -0700 Subject: [PATCH 0021/2089] [SPARK-10094] Pyspark ML Feature transformers marked as experimental Modified class-level docstrings to mark all feature transformers in pyspark.ml as experimental. Author: noelsmith Closes #8623 from noel-smith/SPARK-10094-mark-pyspark-ml-trans-exp. --- python/pyspark/ml/feature.py | 52 ++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index d955307e27efd..a7c5b2b62e764 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -36,6 +36,8 @@ @inherit_doc class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Binarize a column of continuous features given a threshold. >>> df = sqlContext.createDataFrame([(0.5,)], ["values"]) @@ -92,6 +94,8 @@ def getThreshold(self): @inherit_doc class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Maps a column of continuous features to a column of feature buckets. >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) @@ -169,6 +173,8 @@ def getSplits(self): @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero padding is performed on the input vector. It returns a real vector of the same length representing the DCT. @@ -232,6 +238,8 @@ def getInverse(self): @inherit_doc class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided "weight" vector. In other words, it scales each column of the dataset by a scalar multiplier. @@ -289,6 +297,8 @@ def getScalingVec(self): @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): """ + .. note:: Experimental + Maps a sequence of terms to their term frequencies using the hashing trick. @@ -327,6 +337,8 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): @inherit_doc class IDF(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Compute the Inverse Document Frequency (IDF) given a collection of documents. >>> from pyspark.mllib.linalg import DenseVector @@ -387,6 +399,8 @@ def _create_model(self, java_model): class IDFModel(JavaModel): """ + .. note:: Experimental + Model fitted by IDF. """ @@ -395,6 +409,8 @@ class IDFModel(JavaModel): @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A feature transformer that converts the input array of strings into an array of n-grams. Null values in the input array are ignored. It returns an array of n-grams where each n-gram is represented by a space-separated string of @@ -463,6 +479,8 @@ def getN(self): @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Normalize a vector to have unit norm using the given p-norm. >>> from pyspark.mllib.linalg import Vectors @@ -519,6 +537,8 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. @@ -591,6 +611,8 @@ def getDropLast(self): @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, which is available at `http://en.wikipedia.org/wiki/Polynomial_expansion`, "In mathematics, an expansion of a product of sums expresses it as a sum of products by using the fact that @@ -649,6 +671,8 @@ def getDegree(self): @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false). @@ -746,6 +770,8 @@ def getPattern(self): @inherit_doc class SQLTransformer(JavaTransformer): """ + .. note:: Experimental + Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' where '__THIS__' represents the underlying table of the input dataset. @@ -797,6 +823,8 @@ def getStatement(self): @inherit_doc class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. @@ -870,6 +898,8 @@ def _create_model(self, java_model): class StandardScalerModel(JavaModel): """ + .. note:: Experimental + Model fitted by StandardScaler. """ @@ -891,6 +921,8 @@ def mean(self): @inherit_doc class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. The indices are in [0, numLabels), ordered by label frequencies. @@ -929,6 +961,8 @@ def _create_model(self, java_model): class StringIndexerModel(JavaModel): """ + .. note:: Experimental + Model fitted by StringIndexer. """ @@ -1006,6 +1040,8 @@ def getCaseSensitive(self): @ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A tokenizer that converts the input string to lowercase and then splits it by white spaces. @@ -1051,6 +1087,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): """ + .. note:: Experimental + A feature transformer that merges multiple columns into a vector column. >>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) @@ -1087,6 +1125,8 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Class for indexing categorical feature columns in a dataset of [[Vector]]. This has 2 usage modes: @@ -1186,6 +1226,8 @@ def _create_model(self, java_model): class VectorIndexerModel(JavaModel): """ + .. note:: Experimental + Model fitted by VectorIndexer. """ @@ -1194,6 +1236,8 @@ class VectorIndexerModel(JavaModel): @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further natural language processing or machine learning process. @@ -1307,6 +1351,8 @@ def _create_model(self, java_model): class Word2VecModel(JavaModel): """ + .. note:: Experimental + Model fitted by Word2Vec. """ @@ -1332,6 +1378,8 @@ def findSynonyms(self, word, num): @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + PCA trains a model to project vectors to a low-dimensional space using PCA. >>> from pyspark.mllib.linalg import Vectors @@ -1387,6 +1435,8 @@ def _create_model(self, java_model): class PCAModel(JavaModel): """ + .. note:: Experimental + Model fitted by PCA. """ @@ -1470,6 +1520,8 @@ def _create_model(self, java_model): class RFormulaModel(JavaModel): """ + .. note:: Experimental + Model fitted by :py:class:`RFormula`. """ From 2f6fd5256c6650868916a3eefaa0beb091187cbb Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 8 Sep 2015 22:13:05 -0700 Subject: [PATCH 0022/2089] [SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark Adds IndexToString to PySpark. Author: Holden Karau Closes #7976 from holdenk/SPARK-9654-add-string-indexer-inverse-in-pyspark. --- .../spark/ml/feature/StringIndexer.scala | 2 +- python/pyspark/ml/feature.py | 74 ++++++++++++++++++- python/pyspark/ml/wrapper.py | 3 +- 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 77aeed0ab0370..b6482ffe0b2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -102,7 +102,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. * - * @param labels Ordered list of labels, corresponding to indices to be assigned + * @param labels Ordered list of labels, corresponding to indices to be assigned. */ @Experimental class StringIndexerModel ( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a7c5b2b62e764..8c26cfbd5a47d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,10 +27,11 @@ from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', - 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', - 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover'] + 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', + 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', + 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', + 'StopWordsRemover'] @inherit_doc @@ -934,6 +935,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> itd = inverter.transform(td) + >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), + ... key=lambda x: x[0]) + [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] """ @keyword_only @@ -965,6 +971,66 @@ class StringIndexerModel(JavaModel): Model fitted by StringIndexer. """ + @property + def labels(self): + """ + Ordered list of labels, corresponding to indices to be assigned. + """ + return self._java_obj.labels + + +@inherit_doc +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A :py:class:`Transformer` that maps a column of string indices back to a new column of + corresponding string values using either the ML attributes of the input column, or if + provided using the labels supplied by the user. + All original columns are kept during transformation. + See L{StringIndexer} for converting strings into indices. + """ + + # a placeholder to make the labels show up in generated doc + labels = Param(Params._dummy(), "labels", + "Optional array of labels to be provided by the user, if not supplied or " + + "empty, column metadata is read for labels") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, labels=None): + """ + __init__(self, inputCol=None, outputCol=None, labels=None) + """ + super(IndexToString, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", + self.uid) + self.labels = Param(self, "labels", + "Optional array of labels to be provided by the user, if not " + + "supplied or empty, column metadata is read for labels") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, labels=None): + """ + setParams(self, inputCol=None, outputCol=None, labels=None) + Sets params for this IndexToString. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setLabels(self, value): + """ + Sets the value of :py:attr:`labels`. + """ + self._paramMap[self.labels] = value + return self + + def getLabels(self): + """ + Gets the value of :py:attr:`labels` or its default value. + """ + return self.getOrDefault(self.labels) class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 253705bde913e..8218c7c5f801c 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -136,7 +136,8 @@ def _fit(self, dataset): class JavaTransformer(Transformer, JavaWrapper): """ Base class for :py:class:`Transformer`s that wrap Java/Scala - implementations. + implementations. Subclasses should ensure they have the transformer Java object + available as _java_obj. """ __metaclass__ = ABCMeta From 91a577d2778ab5946f0c40cb80c89de24e3d10e8 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Sep 2015 22:33:23 -0700 Subject: [PATCH 0023/2089] [SPARK-10249] [ML] [DOC] Add Python Code Example to StopWordsRemover User Guide jira: https://issues.apache.org/jira/browse/SPARK-10249 update user guide since python support added. Author: Yuhao Yang Closes #8620 from hhbyyh/swPyDocExample. --- docs/ml-features.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 90654d1e5a248..58b31a5a5cc47 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -512,6 +512,25 @@ DataFrame dataset = jsql.createDataFrame(rdd, schema); remover.transform(dataset).show(); {% endhighlight %} + +
+[`StopWordsRemover`](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight python %} +from pyspark.ml.feature import StopWordsRemover + +sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) +], ["label", "raw"]) + +remover = StopWordsRemover(inputCol="raw", outputCol="filtered") +remover.transform(sentenceData).show(truncate=False) +{% endhighlight %} +
## $n$-gram From c1bc4f439f54625c01a585691e5293cd9961eb0c Mon Sep 17 00:00:00 2001 From: Luc Bourlier Date: Wed, 9 Sep 2015 09:57:58 +0100 Subject: [PATCH 0024/2089] [SPARK-10227] fatal warnings with sbt on Scala 2.11 The bulk of the changes are on `transient` annotation on class parameter. Often the compiler doesn't generate a field for this parameters, so the the transient annotation would be unnecessary. But if the class parameter are used in methods, then fields are created. So it is safer to keep the annotations. The remainder are some potential bugs, and deprecated syntax. Author: Luc Bourlier Closes #8433 from skyluc/issue/sbt-2.11. --- .../scala/org/apache/spark/Accumulators.scala | 2 +- .../scala/org/apache/spark/Dependency.scala | 2 +- .../scala/org/apache/spark/Partitioner.scala | 4 +- .../org/apache/spark/SparkHadoopWriter.scala | 2 +- .../apache/spark/api/python/PythonRDD.scala | 4 +- .../spark/input/PortableDataStream.scala | 4 +- .../netty/NettyBlockTransferService.scala | 2 +- .../org/apache/spark/rdd/BinaryFileRDD.scala | 6 +-- .../scala/org/apache/spark/rdd/BlockRDD.scala | 4 +- .../org/apache/spark/rdd/CartesianRDD.scala | 4 +- .../org/apache/spark/rdd/CheckpointRDD.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 8 ++-- .../apache/spark/rdd/LocalCheckpointRDD.scala | 2 +- .../spark/rdd/LocalRDDCheckpointData.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 18 ++++---- .../spark/rdd/ParallelCollectionRDD.scala | 4 +- .../spark/rdd/PartitionPruningRDD.scala | 6 +-- .../spark/rdd/PartitionwiseSampledRDD.scala | 4 +- .../apache/spark/rdd/RDDCheckpointData.scala | 2 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../spark/rdd/ReliableRDDCheckpointData.scala | 2 +- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 6 +-- .../scala/org/apache/spark/rdd/UnionRDD.scala | 4 +- .../spark/rdd/ZippedPartitionsRDD.scala | 2 +- .../apache/spark/rdd/ZippedWithIndexRDD.scala | 2 +- .../org/apache/spark/rpc/RpcEndpointRef.scala | 2 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 21 ++++++--- .../apache/spark/scheduler/ResultTask.scala | 2 +- .../shuffle/unsafe/UnsafeShuffleManager.scala | 2 +- .../apache/spark/util/ClosureCleaner.scala | 2 +- .../spark/util/ShutdownHookManager.scala | 4 +- .../util/logging/RollingFileAppender.scala | 46 +++++++++---------- .../streaming/flume/FlumeInputDStream.scala | 2 +- .../flume/FlumePollingInputDStream.scala | 2 +- .../kafka/DirectKafkaInputDStream.scala | 4 +- .../streaming/kafka/KafkaInputDStream.scala | 2 +- .../streaming/mqtt/MQTTInputDStream.scala | 2 +- .../twitter/TwitterInputDStream.scala | 2 +- .../org/apache/spark/graphx/EdgeRDD.scala | 4 +- .../scala/org/apache/spark/graphx/Graph.scala | 6 +-- .../org/apache/spark/graphx/VertexRDD.scala | 4 +- .../apache/spark/mllib/rdd/RandomRDD.scala | 12 ++--- .../org/apache/spark/repl/SparkILoop.scala | 2 +- .../expressions/stringExpressions.scala | 2 +- .../spark/sql/types/AbstractDataType.scala | 2 +- .../datasources/WriterContainer.scala | 10 ++-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../apache/spark/sql/hive/TableReader.scala | 8 ++-- .../hive/execution/ScriptTransformation.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 10 ++-- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/api/python/PythonDStream.scala | 12 ++--- .../streaming/dstream/FileInputDStream.scala | 2 +- .../streaming/dstream/InputDStream.scala | 2 +- .../dstream/PluggableInputDStream.scala | 2 +- .../streaming/dstream/QueueInputDStream.scala | 4 +- .../streaming/dstream/RawInputDStream.scala | 2 +- .../dstream/ReceiverInputDStream.scala | 2 +- .../dstream/SocketInputDStream.scala | 2 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 22 ++++----- 60 files changed, 158 insertions(+), 151 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index c39c8667d013e..5592b75afb75b 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils * @tparam T partial data that can be added in */ class Accumulable[R, T] private[spark] ( - @transient initialValue: R, + initialValue: R, param: AccumulableParam[R, T], val name: Option[String], internal: Boolean) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index fc8cdde9348ee..cfeeb3902c033 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -66,7 +66,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { */ @DeveloperApi class ShuffleDependency[K, V, C]( - @transient _rdd: RDD[_ <: Product2[K, V]], + @transient private val _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 29e581bb57cbc..e4df7af81a6d2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -104,8 +104,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { * the value of `partitions`. */ class RangePartitioner[K : Ordering : ClassTag, V]( - @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K, V]], + partitions: Int, + rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index f5dd36cbcfe6d..ae5926dd534a6 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ private[spark] -class SparkHadoopWriter(@transient jobConf: JobConf) +class SparkHadoopWriter(jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b4d152b336602..69da180593bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal private[spark] class PythonRDD( - @transient parent: RDD[_], + parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -785,7 +785,7 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. */ -private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index a5ad47293f1c2..e2ffc3b64e5db 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -131,8 +131,8 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat */ @Experimental class PortableDataStream( - @transient isplit: CombineFileSplit, - @transient context: TaskAttemptContext, + isplit: CombineFileSplit, + context: TaskAttemptContext, index: Integer) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 4b851bcb36597..70a42f9045e6b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -137,7 +137,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") - result.success() + result.success((): Unit) } override def onFailure(e: Throwable): Unit = { logError(s"Error while uploading block $blockId", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 1f755db485812..6fec00dcd0d85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -28,7 +28,7 @@ private[spark] class BinaryFileRDD[T]( inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { @@ -36,10 +36,10 @@ private[spark] class BinaryFileRDD[T]( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(getConf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(getConf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 922030263756b..fc1710fbad0a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -28,7 +28,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) @@ -64,7 +64,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds */ private[spark] def removeBlocks() { blockIds.foreach { blockId => - sc.env.blockManager.master.removeBlock(blockId) + sparkContext.env.blockManager.master.removeBlock(blockId) } _isValid = false } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index c1d6971787572..18e8cddbc40db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -27,8 +27,8 @@ import org.apache.spark.util.Utils private[spark] class CartesianPartition( idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], + @transient private val rdd1: RDD[_], + @transient private val rdd2: RDD[_], s1Index: Int, s2Index: Int ) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 72fe215dae73e..b0364623af4cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -29,7 +29,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * An RDD that recovers checkpointed data from storage. */ -private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext) +private[spark] abstract class CheckpointRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { // CheckpointRDD should not be checkpointed again diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1f8719eead02..8f2655d63b797 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) +private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -99,7 +99,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - @transient sc: SparkContext, + sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -109,7 +109,7 @@ class HadoopRDD[K, V]( extends RDD[(K, V)](sc, Nil) with Logging { if (initLocalJobConfFuncOpt.isDefined) { - sc.clean(initLocalJobConfFuncOpt.get) + sparkContext.clean(initLocalJobConfFuncOpt.get) } def this( @@ -137,7 +137,7 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() - private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala index daa5779d688cc..bfe19195fcd37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.storage.RDDBlockId * @param numPartitions the number of partitions in the checkpointed RDD */ private[spark] class LocalCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, rddId: Int, numPartitions: Int) extends CheckpointRDD[T](sc) { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index d6fad896845f6..c115e0ff74d3c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils * is written to the local, ephemeral block storage that lives in each executor. This is useful * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). */ -private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 6a9c004d65cff..174979aaeb231 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -40,7 +40,7 @@ import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -68,14 +68,14 @@ class NewHadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient private val _conf: Configuration) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) - // private val serializableConf = new SerializableWritable(conf) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) + // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -88,10 +88,10 @@ class NewHadoopRDD[K, V]( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -262,7 +262,7 @@ private[spark] class WholeTextFileRDD( inputFormatClass: Class[_ <: WholeTextFileInputFormat], keyClass: Class[String], valueClass: Class[String], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { @@ -270,10 +270,10 @@ private[spark] class WholeTextFileRDD( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(getConf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(getConf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e2394e28f8d26..582fa93afe34e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -83,8 +83,8 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( } private[spark] class ParallelCollectionRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient data: Seq[T], + sc: SparkContext, + @transient private val data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) extends RDD[T](sc, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index a00f4c1cdff91..d6a37e8cc5dac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -32,7 +32,7 @@ private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Par * Represents a dependency between the PartitionPruningRDD and its parent. In this * case, the child RDD contains a subset of partitions of the parents'. */ -private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) +private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient @@ -55,8 +55,8 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF */ @DeveloperApi class PartitionPruningRDD[T: ClassTag]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) + prev: RDD[T], + partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index a637d6f15b7e5..3b1acacf409b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -47,8 +47,8 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], - @transient preservesPartitioning: Boolean, - @transient seed: Long = Utils.random.nextLong) + preservesPartitioning: Boolean, + @transient private val seed: Long = Utils.random.nextLong) extends RDD[U](prev) { @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 0e43520870c0a..429514b4f6bee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -36,7 +36,7 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ -private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends Serializable { import CheckpointState._ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 35d8b0bfd18c5..1c3b5da19ceba 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} * An RDD that reads from checkpoint files previously written to reliable storage. */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, val checkpointPath: String) extends CheckpointRDD[T](sc) { diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 1df8eef5ff2b9..e9f6060301ba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.SerializableConfiguration * An implementation of checkpointing that writes the RDD data to reliable storage. * This allows drivers to be restarted on failure with previously computed state. */ -private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { // The directory to which the associated RDD has been checkpointed to diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index fa3fecc80cb63..9babe56267e08 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Ut private[spark] class SqlNewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends SparkPartition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -61,9 +61,9 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - @transient sc : SparkContext, + sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], - @transient initDriverSideJobFuncOpt: Option[Job => Unit], + @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 3986645350a82..66cf4369da2ef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class UnionPartition[T: ClassTag]( idx: Int, - @transient rdd: RDD[T], + @transient private val rdd: RDD[T], val parentRddIndex: Int, - @transient parentRddPartitionIndex: Int) + @transient private val parentRddPartitionIndex: Int) extends Partition { var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index b3c64394abc76..70bf04de6400d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]], + @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e277ae28d588f..32931d59acb18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -37,7 +37,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) * @tparam T parent RDD item type */ private[spark] -class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { +class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) { /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 7409ac8859991..f25710bb5bd6e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) +private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging { private[this] val maxRetries = RpcUtils.numRetries(conf) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index fc17542abf81d..ad67e1c5ad4d5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -87,9 +87,9 @@ private[spark] class AkkaRpcEnv private[akka] ( override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use lazy because the Actor needs to use `endpointRef`. + // Use defered function because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`. - lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { assert(endpointRef != null) @@ -272,13 +272,20 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging } private[akka] class AkkaRpcEndpointRef( - @transient defaultAddress: RpcAddress, - @transient _actorRef: => ActorRef, - @transient conf: SparkConf, - @transient initInConstructor: Boolean = true) + @transient private val defaultAddress: RpcAddress, + @transient private val _actorRef: () => ActorRef, + conf: SparkConf, + initInConstructor: Boolean) extends RpcEndpointRef(conf) with Logging { - lazy val actorRef = _actorRef + def this( + defaultAddress: RpcAddress, + _actorRef: ActorRef, + conf: SparkConf) = { + this(defaultAddress, () => _actorRef, conf, true) + } + + lazy val actorRef = _actorRef() override lazy val address: RpcAddress = { val akkaAddress = actorRef.path.address diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c4dc080e2b22b..fb693721a9cb6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -44,7 +44,7 @@ private[spark] class ResultTask[T, U]( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient locs: Seq[TaskLocation], + locs: Seq[TaskLocation], val outputId: Int, internalAccumulators: Seq[Accumulator[Long]]) extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index df7bbd64247dd..75f22f642b9d1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -159,7 +159,7 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage mapId: Int, context: TaskContext): ShuffleWriter[K, V] = { handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) val env = SparkEnv.get new UnsafeShuffleWriter( diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 150d82b3930ef..1b49dca9dc78b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -94,7 +94,7 @@ private[spark] object ClosureCleaner extends Logging { if (cls.isPrimitive) { cls match { case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Character.TYPE => new java.lang.Character('\u0000') case java.lang.Void.TYPE => // This should not happen because `Foo(void x) {}` does not compile. throw new IllegalStateException("Unexpected void parameter in constructor") diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 61ff9b89ec1c1..db4a8b304ec3e 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -217,7 +217,9 @@ private [util] class SparkShutdownHookManager { } Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + val fsPriority = classOf[FileSystem] + .getField("SHUTDOWN_HOOK_PRIORITY") + .get(null) // static field, the value is not used .asInstanceOf[Int] val shm = shmClass.getMethod("get").invoke(null) shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 7138b4b8e4533..1e8476c4a047e 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -79,32 +79,30 @@ private[spark] class RollingFileAppender( val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() val rolloverFile = new File( activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile - try { - logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") - if (activeFile.exists) { - if (!rolloverFile.exists) { - Files.move(activeFile, rolloverFile) - logInfo(s"Rolled over $activeFile to $rolloverFile") - } else { - // In case the rollover file name clashes, make a unique file name. - // The resultant file names are long and ugly, so this is used only - // if there is a name collision. This can be avoided by the using - // the right pattern such that name collisions do not occur. - var i = 0 - var altRolloverFile: File = null - do { - altRolloverFile = new File(activeFile.getParent, - s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile - i += 1 - } while (i < 10000 && altRolloverFile.exists) - - logWarning(s"Rollover file $rolloverFile already exists, " + - s"rolled over $activeFile to file $altRolloverFile") - Files.move(activeFile, altRolloverFile) - } + logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") + if (activeFile.exists) { + if (!rolloverFile.exists) { + Files.move(activeFile, rolloverFile) + logInfo(s"Rolled over $activeFile to $rolloverFile") } else { - logWarning(s"File $activeFile does not exist") + // In case the rollover file name clashes, make a unique file name. + // The resultant file names are long and ugly, so this is used only + // if there is a name collision. This can be avoided by the using + // the right pattern such that name collisions do not occur. + var i = 0 + var altRolloverFile: File = null + do { + altRolloverFile = new File(activeFile.getParent, + s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile + i += 1 + } while (i < 10000 && altRolloverFile.exists) + + logWarning(s"Rollover file $rolloverFile already exists, " + + s"rolled over $activeFile to file $altRolloverFile") + Files.move(activeFile, altRolloverFile) } + } else { + logWarning(s"File $activeFile does not exist") } } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2bf99cb3cba1f..c8780aa83bdbd 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -43,7 +43,7 @@ import org.jboss.netty.handler.codec.compression._ private[streaming] class FlumeInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel, diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 0bc46209b8369..3b936d88abd3e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -46,7 +46,7 @@ import org.apache.spark.streaming.flume.sink._ * @tparam T Class type of the object of this stream */ private[streaming] class FlumePollingInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, val addresses: Seq[InetSocketAddress], val maxBatchSize: Int, val parallelism: Int, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 1000094e93cb3..8a087474d3169 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -58,7 +58,7 @@ class DirectKafkaInputDStream[ U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R @@ -79,7 +79,7 @@ class DirectKafkaInputDStream[ override protected[streaming] val rateController: Option[RateController] = { if (RateController.isBackPressureEnabled(ssc.conf)) { Some(new DirectKafkaRateController(id, - RateEstimator.create(ssc.conf, ssc_.graph.batchDuration))) + RateEstimator.create(ssc.conf, context.graph.batchDuration))) } else { None } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 04b2dc10d39ea..38730fecf332a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -48,7 +48,7 @@ class KafkaInputDStream[ V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], useReliableReceiver: Boolean, diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 7c2f18cb35bda..116c170489e96 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class MQTTInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 7cf02d85d73d3..d7de74b350543 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class TwitterInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 4611a3ace219b..ee7302a1edbf6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -38,8 +38,8 @@ import org.apache.spark.graphx.impl.EdgeRDDImpl * `impl.ReplicatedVertexView`. */ abstract class EdgeRDD[ED]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index db73a8abc5733..869caa340f52b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -46,7 +46,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @note vertex ids are unique. * @return an RDD containing the vertices in this graph */ - @transient val vertices: VertexRDD[VD] + val vertices: VertexRDD[VD] /** * An RDD containing the edges and their associated attributes. The entries in the RDD contain @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. * */ - @transient val edges: EdgeRDD[ED] + val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -77,7 +77,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * val numInvalid = graph.triplets.map(e => if (e.src.data == e.dst.data) 1 else 0).sum * }}} */ - @transient val triplets: RDD[EdgeTriplet[VD, ED]] + val triplets: RDD[EdgeTriplet[VD, ED]] /** * Caches the vertices and edges associated with this graph at the specified storage level, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index a9f04b559c3d1..1ef7a78fbcd00 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -55,8 +55,8 @@ import org.apache.spark.graphx.impl.VertexRDDImpl * @tparam VD the vertex attribute associated with each vertex in the set. */ abstract class VertexRDD[VD]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { implicit protected def vdTag: ClassTag[VD] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index 910eff9540a47..f8cea7ecea6bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -35,11 +35,11 @@ private[mllib] class RandomRDDPartition[T](override val index: Int, } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: RandomDataGenerator[T], - @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { + @transient private val rng: RandomDataGenerator[T], + @transient private val seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") @@ -56,12 +56,12 @@ private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, } } -private[mllib] class RandomVectorRDD(@transient sc: SparkContext, +private[mllib] class RandomVectorRDD(sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: RandomDataGenerator[Double], - @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { + @transient private val rng: RandomDataGenerator[Double], + @transient private val seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index bf609ff0f65fc..33d262558b1fc 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -118,5 +118,5 @@ object SparkILoop { } } } - def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 48d02bb534501..a09d5b6e3ad14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -255,7 +255,7 @@ object StringTranslate { val dict = new HashMap[Character, Character]() var i = 0 while (i < matching.length()) { - val rep = if (i < replace.length()) replace.charAt(i) else '\0' + val rep = if (i < replace.length()) replace.charAt(i) else '\u0000' if (null == dict.get(matching.charAt(i))) { dict.put(matching.charAt(i), rep) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index e0667c629486d..1d2d007c2b4d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -126,7 +126,7 @@ protected[sql] object AnyDataType extends AbstractDataType { */ protected[sql] abstract class AtomicType extends DataType { private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val tag: TypeTag[InternalType] private[sql] val ordering: Ordering[InternalType] @transient private[sql] val classTag = ScalaReflectionLock.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 879fd69863211..9a573db0c023a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.SerializableConfiguration private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, - @transient job: Job, + @transient private val job: Job, isAppend: Boolean) extends SparkHadoopMapReduceUtil with Logging @@ -222,8 +222,8 @@ private[sql] abstract class BaseWriterContainer( * A writer that writes all of the rows in a partition to a single file. */ private[sql] class DefaultWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, + relation: HadoopFsRelation, + job: Job, isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { @@ -286,8 +286,8 @@ private[sql] class DefaultWriterContainer( * writer externally sorts the remaining rows and then writes out them out one file at a time. */ private[sql] class DynamicPartitionWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, + relation: HadoopFsRelation, + job: Job, partitionColumns: Seq[Attribute], dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b8da0840ae569..0a5569b0a4446 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -767,7 +767,7 @@ private[hive] case class InsertIntoHiveTable( private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) - (@transient sqlContext: SQLContext) + (@transient private val sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index dc355690852bd..e35468a624c3e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -54,10 +54,10 @@ private[hive] sealed trait TableReader { */ private[hive] class HadoopTableReader( - @transient attributes: Seq[Attribute], - @transient relation: MetastoreRelation, - @transient sc: HiveContext, - @transient hiveExtraConf: HiveConf) + @transient private val attributes: Seq[Attribute], + @transient private val relation: MetastoreRelation, + @transient private val sc: HiveContext, + hiveExtraConf: HiveConf) extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index c7651daffe36e..32bddbaeaeaf9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -53,7 +53,7 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) + ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { override def otherCopyArgs: Seq[HiveContext] = sc :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8dc796b056a72..29a6f08f40728 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc) extends Logging with SparkHadoopMapRedUtil @@ -163,7 +163,7 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { } private[spark] class SparkHiveDynamicPartitionWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String]) extends SparkHiveWriterContainer(jobConf, fileSinkConf) { @@ -194,10 +194,10 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does: // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then // load it with loadDynamicPartitions/loadPartition/loadTable. - val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) + val oldMarker = conf.value.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) super.commitJob() - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } override def getLocalFileWriter(row: InternalRow, schema: StructType) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 27024ecfd9101..8a6050f5227bf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) +class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 2c373640d2fd9..dfc569451df86 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -170,7 +170,7 @@ private[python] object PythonDStream { */ private[python] abstract class PythonDStream( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -187,7 +187,7 @@ private[python] abstract class PythonDStream( */ private[python] class PythonTransformedDStream ( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends PythonDStream(parent, pfunc) { override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { @@ -206,7 +206,7 @@ private[python] class PythonTransformedDStream ( private[python] class PythonTransformed2DStream( parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -230,7 +230,7 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - @transient reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -252,8 +252,8 @@ private[python] class PythonStateDStream( */ private[python] class PythonReducedWindowedDStream( parent: DStream[Array[Byte]], - @transient preduceFunc: PythonTransformFunction, - @transient pinvReduceFunc: PythonTransformFunction, + preduceFunc: PythonTransformFunction, + @transient private val pinvReduceFunc: PythonTransformFunction, _windowDuration: Duration, _slideDuration: Duration) extends PythonDStream(parent, preduceFunc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index c358f5b5bd70b..40208a64861fb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -70,7 +70,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index a6c4cd220e42f..95994c983c0cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * * @param ssc_ Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) +abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) extends DStream[T](ssc_) { private[streaming] var lastValidTime: Time = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 186e1bf03a944..002aac9f43617 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { def getReceiver(): Receiver[T] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index bab78a3536b47..a2685046e03d4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Time, StreamingContext} private[streaming] class QueueInputDStream[T: ClassTag]( - @transient ssc: StreamingContext, + ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] @@ -57,7 +57,7 @@ class QueueInputDStream[T: ClassTag]( if (oneAtATime) { Some(buffer.head) } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) + Some(new UnionRDD(context.sc, buffer.toSeq)) } } else if (defaultRDD != null) { Some(defaultRDD) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index e2925b9e03ec3..5a9eda7c12776 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 6c139f32da31d..87c20afd5c13c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.{StreamingContext, Time} * @param ssc_ Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) +abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) extends InputDStream[T](ssc_) { /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 5ce5b7aae6e69..de84e0c9a498d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class SocketInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index e081ffe46f502..f811784b25c82 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -61,7 +61,7 @@ class WriteAheadLogBackedBlockRDDPartition( * * * @param sc SparkContext - * @param blockIds Ids of the blocks that contains this RDD's data + * @param _blockIds Ids of the blocks that contains this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark * executors). If not, then block lookups by the block ids will be skipped. @@ -73,23 +73,23 @@ class WriteAheadLogBackedBlockRDDPartition( */ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient blockIds: Array[BlockId], + sc: SparkContext, + @transient private val _blockIds: Array[BlockId], @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) - extends BlockRDD[T](sc, blockIds) { + extends BlockRDD[T](sc, _blockIds) { require( - blockIds.length == walRecordHandles.length, - s"Number of block Ids (${blockIds.length}) must be " + + _blockIds.length == walRecordHandles.length, + s"Number of block Ids (${_blockIds.length}) must be " + s" same as number of WAL record handles (${walRecordHandles.length})") require( - isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, + isBlockIdValid.isEmpty || isBlockIdValid.length == _blockIds.length, s"Number of elements in isBlockIdValid (${isBlockIdValid.length}) must be " + - s" same as number of block Ids (${blockIds.length})") + s" same as number of block Ids (${_blockIds.length})") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration @@ -99,9 +99,9 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), isValid, walRecordHandles(i)) + new WriteAheadLogBackedBlockRDDPartition(i, _blockIds(i), isValid, walRecordHandles(i)) } } From 2ddeb63126d26149eda197e85b7b26ef16a6e97c Mon Sep 17 00:00:00 2001 From: lewuathe Date: Wed, 9 Sep 2015 09:29:10 -0700 Subject: [PATCH 0025/2089] [SPARK-10117] [MLLIB] Implement SQL data source API for reading LIBSVM data It is convenient to implement data source API for LIBSVM format to have a better integration with DataFrames and ML pipeline API. Two option is implemented. * `numFeatures`: Specify the dimension of features vector * `featuresType`: Specify the type of output vector. `sparse` is default. Author: lewuathe Closes #8537 from Lewuathe/SPARK-10117 and squashes the following commits: 986999d [lewuathe] Change unit test phrase 11d513f [lewuathe] Fix some reviews 21600a4 [lewuathe] Merge branch 'master' into SPARK-10117 9ce63c7 [lewuathe] Rewrite service loader file 1fdd2df [lewuathe] Merge branch 'SPARK-10117' of github.com:Lewuathe/spark into SPARK-10117 ba3657c [lewuathe] Merge branch 'master' into SPARK-10117 0ea1c1c [lewuathe] LibSVMRelation is registered into META-INF 4f40891 [lewuathe] Improve test suites 5ab62ab [lewuathe] Merge branch 'master' into SPARK-10117 8660d0e [lewuathe] Fix Java unit test b56a948 [lewuathe] Merge branch 'master' into SPARK-10117 2c12894 [lewuathe] Remove unnecessary tag 7d693c2 [lewuathe] Resolv conflict 62010af [lewuathe] Merge branch 'master' into SPARK-10117 a97ee97 [lewuathe] Fix some points aef9564 [lewuathe] Fix 70ee4dd [lewuathe] Add Java test 3fd8dce [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data 40d3027 [lewuathe] Add Java test 7056d4a [lewuathe] Merge branch 'master' into SPARK-10117 99accaa [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../ml/source/libsvm/LibSVMRelation.scala | 99 +++++++++++++++++++ .../ml/source/JavaLibSVMRelationSuite.java | 80 +++++++++++++++ .../spark/ml/source/LibSVMRelationSuite.scala | 76 ++++++++++++++ 4 files changed, 256 insertions(+) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..f632dd603c449 --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.ml.source.libsvm.DefaultSource diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala new file mode 100644 index 0000000000000..b12cb62a4ef15 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm + +import com.google.common.base.Objects + +import org.apache.spark.Logging +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{StructType, StructField, DoubleType} +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources._ + +/** + * LibSVMRelation provides the DataFrame constructed from LibSVM format data. + * @param path File path of LibSVM format + * @param numFeatures The number of features + * @param vectorType The type of vector. It can be 'sparse' or 'dense' + * @param sqlContext The Spark SQLContext + */ +private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) + (@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan with Logging with Serializable { + + override def schema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil + ) + + override def buildScan(): RDD[Row] = { + val sc = sqlContext.sparkContext + val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) + + baseRdd.map { pt => + val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse + Row(pt.label, features) + } + } + + override def hashCode(): Int = { + Objects.hashCode(path, schema) + } + + override def equals(other: Any): Boolean = other match { + case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema) + case _ => false + } + +} + +/** + * This is used for creating DataFrame from LibSVM format file. + * The LibSVM file path must be specified to DefaultSource. + */ +@Since("1.6.0") +class DefaultSource extends RelationProvider with DataSourceRegister { + + @Since("1.6.0") + override def shortName(): String = "libsvm" + + private def checkPath(parameters: Map[String, String]): String = { + require(parameters.contains("path"), "'path' must be specified") + parameters.get("path").get + } + + /** + * Returns a new base relation with the given parameters. + * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. + */ + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) + : BaseRelation = { + val path = checkPath(parameters) + val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt + /** + * featuresType can be selected "dense" or "sparse". + * This parameter decides the type of returned feature vector. + */ + val vectorType = parameters.getOrElse("vectorType", "sparse") + new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java new file mode 100644 index 0000000000000..11fa4eec0ccf0 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source; + +import java.io.File; +import java.io.IOException; + +import com.google.common.base.Charsets; +import com.google.common.io.Files; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.DenseVector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + + +/** + * Test LibSVMRelation in Java. + */ +public class JavaLibSVMRelationSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + + private File tmpDir; + private File path; + + @Before + public void setUp() throws IOException { + jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); + jsql = new SQLContext(jsc); + + tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + path = new File(tmpDir.getPath(), "part-00000"); + + String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; + Files.write(s, path, Charsets.US_ASCII); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + Utils.deleteRecursively(tmpDir); + } + + @Test + public void verifyLibSVMDF() { + dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath()); + Assert.assertEquals("label", dataset.columns()[0]); + Assert.assertEquals("features", dataset.columns()[1]); + Row r = dataset.first(); + Assert.assertEquals(1.0, r.getDouble(0), 1e-15); + DenseVector v = r.getAs(1); + Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala new file mode 100644 index 0000000000000..8ed134128c8d2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + +class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var path: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val lines = + """ + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + path = tempDir.toURI.toString + } + + test("select as sparse vector") { + val df = sqlContext.read.format("libsvm").load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("select as dense vector") { + val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + .load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + assert(df.count() == 3) + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[DenseVector](1) + assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) + } + + test("select a vector with specifying the longer dimension") { + val df = sqlContext.read.option("numFeatures", "100").format("libsvm") + .load(path) + val row1 = df.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } +} From c0052d8d09eebadadb5ed35ac512caaf73919551 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 9 Sep 2015 10:26:53 -0700 Subject: [PATCH 0026/2089] =?UTF-8?q?[SPARK-10481]=20[YARN]=20SPARK=5FPREP?= =?UTF-8?q?END=5FCLASSES=20make=20spark-yarn=20related=20jar=20could=20n?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Throw a more readable exception. Please help review. Thanks Author: Jeff Zhang Closes #8649 from zjffdu/SPARK-10481. --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e9a02baafd28e..a2c4bc2f5480b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1045,7 +1045,10 @@ object Client extends Logging { s"in favor of the $CONF_SPARK_JAR configuration variable.") System.getenv(ENV_SPARK_JAR) } else { - SparkContext.jarOfClass(this.getClass).head + SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not " + + "find jar containing Spark classes. The jar can be defined using the " + + "spark.yarn.jar configuration option. If testing Spark, either set that option or " + + "make sure SPARK_PREPEND_CLASSES is not set.")) } } From 71da1633c4dcc4e748fbe3b3236af90032b695ae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 9 Sep 2015 10:57:29 -0700 Subject: [PATCH 0027/2089] [SPARK-10461] [SQL] make sure `input.primitive` is always variable name not code at `GenerateUnsafeProjection` When we generate unsafe code inside `createCodeForXXX`, we always assign the `input.primitive` to a temp variable in case `input.primitive` is expression code. This PR did some refactor to make sure `input.primitive` is always variable name, and some other typo and style fixes. Author: Wenchen Fan Closes #8613 from cloud-fan/minor. --- .../expressions/codegen/CodeGenerator.scala | 10 +-- .../codegen/GenerateMutableProjection.scala | 2 - .../codegen/GenerateProjection.scala | 12 +-- .../codegen/GenerateUnsafeProjection.scala | 85 ++++++++++--------- .../expressions/complexTypeCreator.scala | 33 ++++--- 5 files changed, 75 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf96248feaef7..da3103b4ebb6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -84,11 +84,11 @@ class CodeGenContext { /** * Holding all the functions those will be added into generated class. */ - val addedFuntions: mutable.Map[String, String] = + val addedFunctions: mutable.Map[String, String] = mutable.Map.empty[String, String] def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFuntions += ((funcName, funcCode)) + addedFunctions += ((funcName, funcCode)) } final val JAVA_BOOLEAN = "boolean" @@ -298,8 +298,8 @@ class CodeGenContext { | $body |} """.stripMargin - addNewFunction(name, code) - name + addNewFunction(name, code) + name } functions.map(name => s"$name($row);").mkString("\n") @@ -337,7 +337,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b4d4df8934bd4..793023b9fbed3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types.DecimalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d822e8..2164ddf03d1b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -48,7 +48,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val columns = expressions.zipWithIndex.map { case (e, i) => s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n ") + }.mkString("\n") val initColumns = expressions.zipWithIndex.map { case (e, i) => @@ -67,18 +67,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val getCases = (0 until expressions.size).map { i => s"case $i: return c$i;" - }.mkString("\n ") + }.mkString("\n") val updateCases = expressions.zipWithIndex.map { case (e, i) => s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n ") + }.mkString("\n") val specificAccessorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { case (e, i) if ctx.javaType(e.dataType) == jt => Some(s"case $i: return c$i;") case _ => None - }.mkString("\n ") + }.mkString("\n") if (cases.length > 0) { val getter = "get" + ctx.primitiveTypeName(jt) s""" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case (e, i) if ctx.javaType(e.dataType) == jt => Some(s"case $i: { c$i = value; return; }") case _ => None - }.mkString("\n ") + }.mkString("\n") if (cases.length > 0) { val setter = "set" + ctx.primitiveTypeName(jt) s""" @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val copyColumns = expressions.zipWithIndex.map { case (e, i) => s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n ") + }.mkString("\n") val code = s""" public SpecificProjection generate($exprType[] expr) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b570fe86db1aa..03c5f449bf9ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -134,7 +134,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") ctx.addMutableState("int", cursor, s"this.$cursor = 0;") - val tmp = ctx.freshName("tmpBuffer") + val tmpBuffer = ctx.freshName("tmpBuffer") val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => val ev = createConvertCode(ctx, input, dt) @@ -144,10 +144,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); if ($buffer.length < $numBytes) { // This will not happen frequently, because the buffer is re-used. - byte[] $tmp = new byte[$numBytes * 2]; + byte[] $tmpBuffer = new byte[$numBytes * 2]; Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, - $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); - $buffer = $tmp; + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; } $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); """ @@ -207,20 +207,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val buffer = ctx.freshName("buffer") ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") val numElements = ctx.freshName("numElements") val fixedSize = ctx.freshName("fixedSize") val numBytes = ctx.freshName("numBytes") val elements = ctx.freshName("elements") val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") + val elementName = ctx.freshName("elementName") - val element = GeneratedExpressionCode( - code = "", - isNull = s"$tmp.isNullAt($index)", - primitive = s"${ctx.getValue(tmp, elementType, index)}" - ) - val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType) + val element = { + val code = s"${ctx.javaType(elementType)} $elementName = " + + s"${ctx.getValue(input.primitive, elementType, index)};" + val isNull = s"${input.primitive}.isNullAt($index)" + GeneratedExpressionCode(code, isNull, elementName) + } + + val convertedElement = createConvertCode(ctx, element, elementType) // go through the input array to calculate how many bytes we need. val calculateNumBytes = elementType match { @@ -272,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" + ${convertedElement.code} Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -280,6 +283,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" + ${convertedElement.code} Platform.putLong( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -307,11 +311,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${input.code} final boolean $outputIsNull = ${input.isNull}; if (!$outputIsNull) { - final ArrayData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeArrayData) { - $output = (UnsafeArrayData) $tmp; + if (${input.primitive} instanceof UnsafeArrayData) { + $output = (UnsafeArrayData) ${input.primitive}; } else { - final int $numElements = $tmp.numElements(); + final int $numElements = ${input.primitive}.numElements(); final int $fixedSize = 4 * $numElements; int $numBytes = $fixedSize; @@ -350,29 +353,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro valueType: DataType): GeneratedExpressionCode = { val output = ctx.freshName("convertedMap") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") - - val keyArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.keyArray()" - ) - val valueArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.valueArray()" - ) - val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType) - val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType) + val keyArrayName = ctx.freshName("keyArrayName") + val valueArrayName = ctx.freshName("valueArrayName") + + val keyArray = { + val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();" + val isNull = "false" + GeneratedExpressionCode(code, isNull, keyArrayName) + } + + val valueArray = { + val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();" + val isNull = "false" + GeneratedExpressionCode(code, isNull, valueArrayName) + } + + val convertedKeys = createCodeForArray(ctx, keyArray, keyType) + val convertedValues = createCodeForArray(ctx, valueArray, valueType) val code = s""" ${input.code} final boolean $outputIsNull = ${input.isNull}; UnsafeMapData $output = null; if (!$outputIsNull) { - final MapData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeMapData) { - $output = (UnsafeMapData) $tmp; + if (${input.primitive} instanceof UnsafeMapData) { + $output = (UnsafeMapData) ${input.primitive}; } else { ${convertedKeys.code} ${convertedValues.code} @@ -393,22 +398,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => val output = ctx.freshName("convertedStruct") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") val fieldTypes = t.fields.map(_.dataType) val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val getFieldCode = ctx.getValue(tmp, dt, i.toString) - val fieldIsNull = s"$tmp.isNullAt($i)" - GeneratedExpressionCode("", fieldIsNull, getFieldCode) + val fieldName = ctx.freshName("fieldName") + val code = s"${ctx.javaType(dt)} $fieldName = " + + s"${ctx.getValue(input.primitive, dt, i.toString)};" + val isNull = s"${input.primitive}.isNullAt($i)" + GeneratedExpressionCode(code, isNull, fieldName) } - val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; final boolean $outputIsNull = ${input.isNull}; if (!$outputIsNull) { - final InternalRow $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeRow) { - $output = (UnsafeRow) $tmp; + if (${input.primitive} instanceof UnsafeRow) { + $output = (UnsafeRow) ${input.primitive}; } else { ${converter.code} $output = ${converter.primitive}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1c546719730b7..82eab5fb3d03a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") s""" final boolean ${ev.isNull} = false; - final Object[] values = new Object[${children.size}]; + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - values[$i] = null; + $values[$i] = null; } else { - values[$i] = ${eval.primitive}; + $values[$i] = ${eval.primitive}; } """ }.mkString("\n") + - s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" + s"final ArrayData ${ev.primitive} = new $arrayClass($values);" } override def prettyName: String = "array" @@ -94,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.primitive} = new $rowClass($values);" } override def prettyName: String = "struct" @@ -161,21 +164,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + final Object[] $values = new Object[${valExprs.size}]; """ + valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.primitive} = new $rowClass($values);" } override def prettyName: String = "named_struct" From 45de518742446ddfbd4816c9d0f8501139f9bc2d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Sep 2015 16:02:27 -0700 Subject: [PATCH 0028/2089] [SPARK-9730] [SQL] Add Full Outer Join support for SortMergeJoin This PR is based on #8383 , thanks to viirya JIRA: https://issues.apache.org/jira/browse/SPARK-9730 This patch adds the Full Outer Join support for SortMergeJoin. A new class SortMergeFullJoinScanner is added to scan rows from left and right iterators. FullOuterIterator is simply a wrapper of type RowIterator to consume joined rows from SortMergeFullJoinScanner. Closes #8383 Author: Liang-Chi Hsieh Author: Davies Liu Closes #8579 from davies/smj_fullouter. --- .../apache/spark/util/collection/BitSet.scala | 11 + .../spark/sql/execution/SparkStrategies.scala | 9 +- .../execution/joins/SortMergeOuterJoin.scala | 227 +++++++++++++++++- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../sql/execution/joins/OuterJoinSuite.scala | 44 ++-- 5 files changed, 259 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 9c15b1188d91c..7ab67fc3a2de9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable { */ def capacity: Int = numWords * 64 + /** + * Clear all set bits. + */ + def clear(): Unit = { + var i = 0 + while (i < numWords) { + words(i) = 0L + i += 1 + } + } + /** * Set all the bits up to a given index */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2170bc73a0fd6..4572d5efc92bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -132,15 +132,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => joins.SortMergeOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil - - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - joins.SortMergeOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index dea9e5e580a1e..ab20ee573ab5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -17,20 +17,21 @@ package org.apache.spark.sql.execution.joins +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * :: DeveloperApi :: * Performs an sort merge outer join of two child relations. - * - * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. */ @DeveloperApi case class SortMergeOuterJoin( @@ -52,6 +53,8 @@ case class SortMergeOuterJoin( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -62,6 +65,7 @@ case class SortMergeOuterJoin( // For left and right outer joins, the output is partitioned by the streamed input's join keys. case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -71,6 +75,8 @@ case class SortMergeOuterJoin( // For left and right outer joins, the output is ordered by the streamed input's join keys. case LeftOuter => requiredOrders(leftKeys) case RightOuter => requiredOrders(rightKeys) + // there are null rows in both streams, so there is no order + case FullOuter => Nil case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") } @@ -165,6 +171,26 @@ case class SortMergeOuterJoin( new RightOuterIterator( smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + numLeftRows, + rightIter = RowIterator.fromScala(rightIter), + numRightRows, + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator( + smjScanner, + resultProj, + numOutputRows).toScala + case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") @@ -271,3 +297,196 @@ private class RightOuterIterator( override def getRow: InternalRow = resultProj(joinedRow) } + +private class SortMergeFullOuterJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + numLeftRows: LongSQLMetric, + rightIter: RowIterator, + numRightRows: LongSQLMetric, + boundCondition: InternalRow => Boolean, + leftNullRow: InternalRow, + rightNullRow: InternalRow) { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftRow: InternalRow = _ + private[this] var leftRowKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightRowKey: InternalRow = _ + + private[this] var leftIndex: Int = 0 + private[this] var rightIndex: Int = 0 + private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatched: BitSet = new BitSet(1) + private[this] var rightMatched: BitSet = new BitSet(1) + + advancedLeft() + advancedRight() + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. + */ + private def advancedLeft(): Boolean = { + if (leftIter.advanceNext()) { + leftRow = leftIter.getRow + leftRowKey = leftKeyGenerator(leftRow) + numLeftRows += 1 + true + } else { + leftRow = null + leftRowKey = null + false + } + } + + /** + * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. + */ + private def advancedRight(): Boolean = { + if (rightIter.advanceNext()) { + rightRow = rightIter.getRow + rightRowKey = rightKeyGenerator(rightRow) + numRightRows += 1 + true + } else { + rightRow = null + rightRowKey = null + false + } + } + + /** + * Populate the left and right buffers with rows matching the provided key. + * This consumes rows from both iterators until their keys are different from the matching key. + */ + private def findMatchingRows(matchingKey: InternalRow): Unit = { + leftMatches.clear() + rightMatches.clear() + leftIndex = 0 + rightIndex = 0 + + while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { + leftMatches += leftRow.copy() + advancedLeft() + } + while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { + rightMatches += rightRow.copy() + advancedRight() + } + + if (leftMatches.size <= leftMatched.capacity) { + leftMatched.clear() + } else { + leftMatched = new BitSet(leftMatches.size) + } + if (rightMatches.size <= rightMatched.capacity) { + rightMatched.clear() + } else { + rightMatched = new BitSet(rightMatches.size) + } + } + + /** + * Scan the left and right buffers for the next valid match. + * + * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. + * If a left row has no valid matches on the right, or a right row has no valid matches on the + * left, then the row is joined with the null row and the result is considered a valid match. + * + * @return true if a valid match is found, false otherwise. + */ + private def scanNextInBuffered(): Boolean = { + while (leftIndex < leftMatches.size) { + while (rightIndex < rightMatches.size) { + joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + if (boundCondition(joinedRow)) { + leftMatched.set(leftIndex) + rightMatched.set(rightIndex) + rightIndex += 1 + return true + } + rightIndex += 1 + } + rightIndex = 0 + if (!leftMatched.get(leftIndex)) { + // the left row has never matched any right row, join it with null row + joinedRow(leftMatches(leftIndex), rightNullRow) + leftIndex += 1 + return true + } + leftIndex += 1 + } + + while (rightIndex < rightMatches.size) { + if (!rightMatched.get(rightIndex)) { + // the right row has never matched any left row, join it with null row + joinedRow(leftNullRow, rightMatches(rightIndex)) + rightIndex += 1 + return true + } + rightIndex += 1 + } + + // There are no more valid matches in the left and right buffers + false + } + + // --- Public methods -------------------------------------------------------------------------- + + def getJoinedRow(): JoinedRow = joinedRow + + def advanceNext(): Boolean = { + // If we already buffered some matching rows, use them directly + if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (scanNextInBuffered()) { + return true + } + } + + if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { + joinedRow(leftRow.copy(), rightNullRow) + advancedLeft() + true + } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { + joinedRow(leftNullRow, rightRow.copy()) + advancedRight() + true + } else if (leftRow != null && rightRow != null) { + // Both rows are present and neither have null values, + // so we populate the buffers with rows matching the next key + val comp = keyOrdering.compare(leftRowKey, rightRowKey) + if (comp <= 0) { + findMatchingRows(leftRowKey.copy()) + } else { + findMatchingRows(rightRowKey.copy()) + } + scanNextInBuffered() + true + } else { + // Both iterators have been consumed + false + } + } +} + +private class FullOuterIterator( + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() + + override def advanceNext(): Boolean = { + val r = smjScanner.advanceNext() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b05435bad5c5a..7a027e13089e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -83,7 +83,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c2e0bdac17968..09e0237a7cc50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using ShuffledHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } if (joinType != FullOuter) { test(s"$testName using BroadcastHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } + } - test(s"$testName using SortMergeOuterJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } } } From 56a0fe5c6e4ae2929c48fae2d6225558d020e5f9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 9 Sep 2015 18:02:33 -0700 Subject: [PATCH 0029/2089] [SPARK-9772] [PYSPARK] [ML] Add Python API for ml.feature.VectorSlicer Add Python API for ml.feature.VectorSlicer. Author: Yanbo Liang Closes #8102 from yanboliang/SPARK-9772. --- python/pyspark/ml/feature.py | 95 ++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 8c26cfbd5a47d..1c423486be8d9 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,11 +27,11 @@ from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', - 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', - 'StopWordsRemover'] + 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -1298,6 +1298,91 @@ class VectorIndexerModel(JavaModel): """ +@inherit_doc +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + This class takes a feature vector and outputs a new feature vector with a subarray + of the original features. + + The subset of features can be specified with either indices (`setIndices()`) + or names (`setNames()`). At least one feature must be selected. Duplicate features + are not allowed, so there can be no overlap between selected indices and names. + + The output vector will order features with the selected indices first (in the order given), + followed by the selected names (in the order given). + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),), + ... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),), + ... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"]) + >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4]) + >>> vs.transform(df).head().sliced + DenseVector([2.3, 1.0]) + """ + + # a placeholder to make it appear in the generated doc + indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + + "a vector column. There can be no overlap with names.") + names = Param(Params._dummy(), "names", "An array of feature names to select features from " + + "a vector column. These names must be specified by ML " + + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " + + "indices.") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): + """ + __init__(self, inputCol=None, outputCol=None, indices=None, names=None) + """ + super(VectorSlicer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) + self.indices = Param(self, "indices", "An array of indices to select features from " + + "a vector column. There can be no overlap with names.") + self.names = Param(self, "names", "An array of feature names to select features from " + + "a vector column. These names must be specified by ML " + + "org.apache.spark.ml.attribute.Attribute. There can be no overlap " + + "with indices.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): + """ + setParams(self, inputCol=None, outputCol=None, indices=None, names=None): + Sets params for this VectorSlicer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setIndices(self, value): + """ + Sets the value of :py:attr:`indices`. + """ + self._paramMap[self.indices] = value + return self + + def getIndices(self): + """ + Gets the value of indices or its default value. + """ + return self.getOrDefault(self.indices) + + def setNames(self, value): + """ + Sets the value of :py:attr:`names`. + """ + self._paramMap[self.names] = value + return self + + def getNames(self): + """ + Gets the value of names or its default value. + """ + return self.getOrDefault(self.names) + + @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): From 1dc7548c598c4eb4ecc7d5bb8962a735bbd2c0f7 Mon Sep 17 00:00:00 2001 From: Sean Paradiso Date: Wed, 9 Sep 2015 22:09:33 -0700 Subject: [PATCH 0030/2089] [MINOR] [MLLIB] [ML] [DOC] fixed typo: label for negative result should be 0.0 (original: 1.0) Small typo in the example for `LabelledPoint` in the MLLib docs. Author: Sean Paradiso Closes #8680 from sparadiso/docs_mllib_smalltypo. --- docs/mllib-data-types.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 065bf4727624f..d8c7bdc63c70e 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -144,7 +144,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)); // Create a labeled point with a negative label and a sparse feature vector. -LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); +LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); {% endhighlight %} From 48817cc111a9705f40b7c842315eee24291c2198 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 10 Sep 2015 16:42:12 +0200 Subject: [PATCH 0031/2089] [SPARK-10497] [BUILD] [TRIVIAL] Handle both locations for JIRAError with python-jira Location of JIRAError has moved between old and new versions of python-jira package. Longer term it probably makes sense to pin to specific versions (as mentioned in https://issues.apache.org/jira/browse/SPARK-10498 ) but for now, making release tools works with both new and old versions of python-jira. Author: Holden Karau Closes #8661 from holdenk/SPARK-10497-release-utils-does-not-work-with-new-jira-python. --- dev/create-release/releaseutils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 51ab25a6a5bd8..7f152b7f53559 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -24,7 +24,11 @@ try: from jira.client import JIRA - from jira.exceptions import JIRAError + # Old versions have JIRAError in exceptions package, new (0.5+) in utils. + try: + from jira.exceptions import JIRAError + except ImportError: + from jira.utils import JIRAError except ImportError: print "This tool requires the jira-python library" print "Install using 'sudo pip install jira'" From 4f1daa1ef6b36440962f3c8faea3928599e33784 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 10 Sep 2015 10:04:38 -0700 Subject: [PATCH 0032/2089] [SPARK-10065] [SQL] avoid the extra copy when generate unsafe array The reason for this extra copy is that we iterate the array twice: calculate elements data size and copy elements to array buffer. A simple solution is to follow `createCodeForStruct`, we can dynamically grow the buffer when needed and thus don't need to know the data size ahead. This PR also include some typo and style fixes, and did some minor refactor to make sure `input.primitive` is always variable name not code when generate unsafe code. Author: Wenchen Fan Closes #8496 from cloud-fan/avoid-copy. --- .../codegen/GenerateUnsafeProjection.scala | 84 ++++++------------- 1 file changed, 24 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 03c5f449bf9ac..55562facf9652 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") val buffer = ctx.freshName("buffer") ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val tmpBuffer = ctx.freshName("tmpBuffer") val outputIsNull = ctx.freshName("isNull") val numElements = ctx.freshName("numElements") val fixedSize = ctx.freshName("fixedSize") val numBytes = ctx.freshName("numBytes") - val elements = ctx.freshName("elements") val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") val elementName = ctx.freshName("elementName") @@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val convertedElement = createConvertCode(ctx, element, elementType) - // go through the input array to calculate how many bytes we need. - val calculateNumBytes = elementType match { - case _ if ctx.isPrimitiveType(elementType) => - // Should we do word align? - val elementSize = elementType.defaultSize - s""" - $numBytes += $elementSize * $numElements; - """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - $numBytes += 8 * $numElements; - """ - case _ => - val writer = getWriter(elementType) - val elementSize = s"$writer.getSize($elements[$index])" - // TODO(davies): avoid the copy - val unsafeType = elementType match { - case _: StructType => "UnsafeRow" - case _: ArrayType => "UnsafeArrayData" - case _: MapType => "UnsafeMapData" - case _ => ctx.javaType(elementType) - } - val copy = elementType match { - // We reuse the buffer during conversion, need copy it before process next element. - case _: StructType | _: ArrayType | _: MapType => ".copy()" - case _ => "" - } - - val newElements = if (elementType == BinaryType) { - s"new byte[$numElements][]" - } else { - s"new $unsafeType[$numElements]" - } - s""" - final $unsafeType[] $elements = $newElements; - for (int $index = 0; $index < $numElements; $index++) { - ${convertedElement.code} - if (!${convertedElement.isNull}) { - $elements[$index] = ${convertedElement.primitive}$copy; - $numBytes += $elementSize; - } - } - """ - } - val writeElement = elementType match { case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" - ${convertedElement.code} Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - ${convertedElement.code} Platform.putLong( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $cursor += $writer.write( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, - $elements[$index]); + ${convertedElement.primitive}); """ } - val checkNull = elementType match { - case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}" - case t: DecimalType => s"$elements[$index] == null" + - s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})" - case _ => s"$elements[$index] == null" + val checkNull = convertedElement.isNull + (elementType match { + case t: DecimalType => + s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})" + case _ => "" + }) + + val elementSize = elementType match { + // Should we do word align for primitive types? + case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8" + case _ => + val writer = getWriter(elementType) + s"$writer.getSize(${convertedElement.primitive})" } val code = s""" @@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $fixedSize = 4 * $numElements; int $numBytes = $fixedSize; - $calculateNumBytes - - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - int $cursor = $fixedSize; for (int $index = 0; $index < $numElements; $index++) { + ${convertedElement.code} if ($checkNull) { // If element is null, write the negative value address into offset region. Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { + $numBytes += $elementSize; + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. + byte[] $tmpBuffer = new byte[$numBytes * 2]; + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; + } Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } From f892d927d7246856dd3ea617b2942873359454bc Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Thu, 10 Sep 2015 10:34:00 -0700 Subject: [PATCH 0033/2089] [SPARK-7142] [SQL] Minor enhancement to BooleanSimplification Optimizer rule Use these in the optimizer as well: A and (not(A) or B) => A and B not(A and B) => not(A) or not(B) not(A or B) => not(A) and not(B) Author: Yash Datta Closes #5700 from saucam/bool_simp. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 9 +++++++++ .../optimizer/BooleanSimplificationSuite.scala | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a430000bef653..d9b50f3c97da0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -434,6 +434,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case (_, Literal(false, BooleanType)) => Literal(false) // a && a => a case (l, r) if l fastEquals r => l + // a && (not(a) || b) => a && b + case (l, Or(l1, r)) if (Not(l) fastEquals l1) => And(l, r) + case (l, Or(r, l1)) if (Not(l) fastEquals l1) => And(l, r) + case (Or(l, l1), r) if (l1 fastEquals Not(r)) => And(l, r) + case (Or(l1, l), r) if (l1 fastEquals Not(r)) => And(l, r) // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, @@ -512,6 +517,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case LessThan(l, r) => GreaterThanOrEqual(l, r) // not(l <= r) => l > r case LessThanOrEqual(l, r) => GreaterThan(l, r) + // not(l || r) => not(l) && not(r) + case Or(l, r) => And(Not(l), Not(r)) + // not(l && r) => not(l) or not(r) + case And(l, r) => Or(Not(l), Not(r)) // not(not(e)) => e case Not(e) => e case _ => not diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1877cff1334bd..cde346e99eb17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -89,6 +89,22 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) } + test("a && (!a || b)") { + checkCondition(('a && (!('a) || 'b )), ('a && 'b)) + + checkCondition(('a && ('b || !('a) )), ('a && 'b)) + + checkCondition(((!('a) || 'b ) && 'a), ('b && 'a)) + + checkCondition((('b || !('a) ) && 'a), ('b && 'a)) + } + + test("!(a && b) , !(a || b)") { + checkCondition((!('a && 'b)), (!('a) || !('b))) + + checkCondition(!('a || 'b), (!('a) && !('b))) + } + private val caseInsensitiveAnalyzer = new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false)) From 49da38e5f728e05e8e929c4dcd37145ba060151d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 10 Sep 2015 11:01:08 -0700 Subject: [PATCH 0034/2089] [SPARK-10301] [SPARK-10428] [SQL] Addresses comments of PR #8583 and #8509 for master Author: Cheng Lian Closes #8670 from liancheng/spark-10301/address-pr-comments. --- .../parquet/CatalystReadSupport.scala | 20 +- .../parquet/CatalystRowConverter.scala | 10 + .../parquet/ParquetQuerySuite.scala | 291 ++++++++++++++++-- .../parquet/ParquetSchemaSuite.scala | 246 ++++++++++++++- 4 files changed, 522 insertions(+), 45 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index dc4ff06df6f22..5a8166fac5418 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -117,14 +117,18 @@ private[parquet] object CatalystReadSupport { // Only clips array types with nested type as element type. clipParquetListType(parquetType.asGroupType(), t.elementType) - case t: MapType if !isPrimitiveCatalystType(t.valueType) => - // Only clips map types with nested type as value type. + case t: MapType + if !isPrimitiveCatalystType(t.keyType) || + !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested key type or value type clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) case t: StructType => clipParquetGroup(parquetType.asGroupType(), t) case _ => + // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able + // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. parquetType } } @@ -204,14 +208,14 @@ private[parquet] object CatalystReadSupport { } /** - * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. The value type - * of the [[MapType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a - * [[StructType]]. Note that key type of any [[MapType]] is always a primitive type. + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or + * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], or + * a [[StructType]]. */ private def clipParquetMapType( parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { - // Precondition of this method, should only be called for maps with nested value types. - assert(!isPrimitiveCatalystType(valueType)) + // Precondition of this method, only handles maps with nested key types or value types. + assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) val repeatedGroup = parquetMap.getType(0).asGroupType() val parquetKeyType = repeatedGroup.getType(0) @@ -221,7 +225,7 @@ private[parquet] object CatalystReadSupport { Types .repeatedGroup() .as(repeatedGroup.getOriginalType) - .addField(parquetKeyType) + .addField(clipParquetType(parquetKeyType, keyType)) .addField(clipParquetType(parquetValueType, valueType)) .named(repeatedGroup.getName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index f17e794b76650..2ff2fda3610b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -123,6 +123,16 @@ private[parquet] class CatalystRowConverter( updater: ParentContainerUpdater) extends CatalystGroupConverter(updater) with Logging { + assert( + parquetType.getFieldCount == catalystType.length, + s"""Field counts of the Parquet schema and the Catalyst schema don't match: + | + |Parquet schema: + |$parquetType + |Catalyst schema: + |${catalystType.prettyJson} + """.stripMargin) + logDebug( s"""Building row converter for the following schema: | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 9edbb522684eb..1c1cfa34ad04b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,6 +22,9 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -228,54 +231,168 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - test("SPARK-10301 Clipping nested structs in requested schema") { + test("SPARK-10301 requested schema clipping - same schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + } + + // This test case is ignored because of parquet-mr bug PARQUET-370 + ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(null, null))) + } + } + + test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L, null, null))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, null, null, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { withTempPath { dir => val path = dir.getCanonicalPath val df = sqlContext .range(1) - .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) - df.write.mode("append").parquet(path) + df.write.parquet(path) - val userDefinedSchema = new StructType() - .add("s", new StructType().add("a", LongType, nullable = true), nullable = true) + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) checkAnswer( sqlContext.read.schema(userDefinedSchema).parquet(path), - Row(Row(0))) + Row(Row(0L, 1L))) } withTempPath { dir => val path = dir.getCanonicalPath - - val df1 = sqlContext + val df = sqlContext .range(1) - .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) - val df2 = sqlContext - .range(1, 2) - .selectExpr("NAMED_STRUCT('b', id, 'c', id) AS s") + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) - df1.write.parquet(path) - df2.write.mode(SaveMode.Append).parquet(path) + df.write.parquet(path) - val userDefinedSchema = new StructType() - .add("s", - new StructType() - .add("a", LongType, nullable = true) - .add("c", LongType, nullable = true), - nullable = true) + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) checkAnswer( sqlContext.read.schema(userDefinedSchema).parquet(path), - Seq( - Row(Row(0, null)), - Row(Row(null, 1)))) + Row(Row(1L, 2L, null))) } + } + test("SPARK-10301 requested schema clipping - deeply nested struct") { withTempPath { dir => val path = dir.getCanonicalPath @@ -304,4 +421,132 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext Row(Row(Seq(Row(0, null))))) } } + + test("SPARK-10301 requested schema clipping - out of order") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, 1, null)), + Row(Row(null, 2, 4)))) + } + } + + test("SPARK-10301 requested schema clipping - schema merging") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df1.write.mode(SaveMode.Append).parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + checkAnswer( + sqlContext + .read + .option("mergeSchema", "true") + .parquet(path) + .selectExpr("s.a", "s.b", "s.c"), + Seq( + Row(0, null, 2), + Row(1, 2, 3))) + } + } + + test("SPARK-10301 requested schema clipping - UDT") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr( + """NAMED_STRUCT( + | 'f0', CAST(id AS STRING), + | 'f1', NAMED_STRUCT( + | 'a', CAST(id + 1 AS INT), + | 'b', CAST(id + 2 AS LONG), + | 'c', CAST(id + 3.5 AS DOUBLE) + | ) + |) AS s + """.stripMargin) + .coalesce(1) + + df.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("f1", new NestedStructUDT, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(NestedStruct(1, 2L, 3.5D)))) + } + } +} + +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + case class NestedStruct(a: Integer, b: Long, c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = + new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def serialize(obj: Any): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + obj match { + case n: NestedStruct => + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = { + datum match { + case row: InternalRow => + NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 5331d7c035236..5a8f772c32289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -1012,12 +1012,17 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11Type = new StructType().add("f011", DoubleType, nullable = true) - val f01Type = ArrayType(StringType, containsNull = false) + val f00Type = ArrayType(StringType, containsNull = false) + val f01Type = ArrayType( + new StructType() + .add("f011", DoubleType, nullable = true), + containsNull = false) + val f0Type = new StructType() - .add("f00", f01Type, nullable = false) - .add("f01", f11Type, nullable = false) + .add("f00", f00Type, nullable = false) + .add("f01", f01Type, nullable = false) val f1Type = ArrayType(IntegerType, containsNull = true) + new StructType() .add("f0", f0Type, nullable = false) .add("f1", f1Type, nullable = true) @@ -1046,7 +1051,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { parquetSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary f00_tuple (UTF8); | } | @@ -1061,13 +1066,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11ElementType = new StructType() + val f01ElementType = new StructType() .add("f011", DoubleType, nullable = true) .add("f012", LongType, nullable = true) val f0Type = new StructType() - .add("f00", ArrayType(StringType, containsNull = false), nullable = false) - .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) new StructType().add("f0", f0Type, nullable = false) }, @@ -1075,7 +1080,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { expectedSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary f00_tuple (UTF8); | } | @@ -1095,7 +1100,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { parquetSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary array (UTF8); | } | @@ -1110,13 +1115,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11ElementType = new StructType() + val f01ElementType = new StructType() .add("f011", DoubleType, nullable = true) .add("f012", LongType, nullable = true) val f0Type = new StructType() - .add("f00", ArrayType(StringType, containsNull = false), nullable = false) - .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) new StructType().add("f0", f0Type, nullable = false) }, @@ -1124,7 +1129,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { expectedSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary array (UTF8); | } | @@ -1236,6 +1241,63 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin) + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + testSchemaClipping( "empty requested schema", @@ -1251,4 +1313,160 @@ class ParquetSchemaSuite extends ParquetSchemaTest { catalystSchema = new StructType(), expectedSchema = "message root {}") + + testSchemaClipping( + "disjoint field sets", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = + new StructType() + .add( + "f0", + new StructType() + .add("f02", FloatType, nullable = true) + .add("f03", DoubleType, nullable = true), + nullable = true), + + expectedSchema = + """message root { + | required group f0 { + | optional float f02; + | optional double f03; + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int32 value_f0; + | required int64 value_f1; + | } + | required int32 value; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int64 value_f1; + | required double value_f2; + | } + | required int32 value; + | } + | } + |} + """.stripMargin) } From e04811137680f937669cdcc78771227aeb7cd849 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 10 Sep 2015 11:48:43 -0700 Subject: [PATCH 0035/2089] [SPARK-10466] [SQL] UnsafeRow SerDe exception with data spill Data Spill with UnsafeRow causes assert failure. ``` java.lang.AssertionError: assertion failed at scala.Predef$.assert(Predef.scala:165) at org.apache.spark.sql.execution.UnsafeRowSerializerInstance$$anon$2.writeKey(UnsafeRowSerializer.scala:75) at org.apache.spark.storage.DiskBlockObjectWriter.write(DiskBlockObjectWriter.scala:180) at org.apache.spark.util.collection.ExternalSorter$$anonfun$writePartitionedFile$2$$anonfun$apply$1.apply(ExternalSorter.scala:688) at org.apache.spark.util.collection.ExternalSorter$$anonfun$writePartitionedFile$2$$anonfun$apply$1.apply(ExternalSorter.scala:687) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at org.apache.spark.util.collection.ExternalSorter$$anonfun$writePartitionedFile$2.apply(ExternalSorter.scala:687) at org.apache.spark.util.collection.ExternalSorter$$anonfun$writePartitionedFile$2.apply(ExternalSorter.scala:683) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at org.apache.spark.util.collection.ExternalSorter.writePartitionedFile(ExternalSorter.scala:683) at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:80) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) ``` To reproduce that with code (thanks andrewor14): ```scala bin/spark-shell --master local --conf spark.shuffle.memoryFraction=0.005 --conf spark.shuffle.sort.bypassMergeThreshold=0 sc.parallelize(1 to 2 * 1000 * 1000, 10) .map { i => (i, i) }.toDF("a", "b").groupBy("b").avg().count() ``` Author: Cheng Hao Closes #8635 from chenghao-intel/unsafe_spill. --- .../util/collection/ExternalSorter.scala | 6 ++ .../sql/execution/UnsafeRowSerializer.scala | 2 +- .../execution/UnsafeRowSerializerSuite.scala | 64 +++++++++++++++++-- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 19287edbaf166..138c05dff19e4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -188,6 +188,12 @@ private[spark] class ExternalSorter[K, V, C]( private val spills = new ArrayBuffer[SpilledFile] + /** + * Number of files this sorter has spilled so far. + * Exposed for testing. + */ + private[spark] def numSpills: Int = spills.size + override def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 5c18558f9bde7..e060c06d9e2a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -72,7 +72,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def writeKey[T: ClassTag](key: T): SerializationStream = { // The key is only needed on the map side when computing partition ids. It does not need to // be shuffled. - assert(key.isInstanceOf[Int]) + assert(null == key || key.isInstanceOf[Int]) this } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index bd02c73a26ace..0113d052e338d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.execution -import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} -import org.apache.spark.SparkFunSuite +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark._ /** @@ -40,9 +44,15 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { - val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] + val converter = unsafeRowConverter(schema) + converter(row) + } + + private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = { val converter = UnsafeProjection.create(schema) - converter.apply(internalRow) + (row: Row) => { + converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]) + } } test("toUnsafeRow() test helper method") { @@ -87,4 +97,50 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(!deserializerIter.hasNext) assert(input.closed) } + + test("SPARK-10466: external sorter spilling with unsafe row serializer") { + var sc: SparkContext = null + var outputFile: File = null + val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten + Utils.tryWithSafeFinally { + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1024") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.shuffle.memoryFraction", "0.0001") + + sc = new SparkContext("local", "test", conf) + outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 1000).iterator.map { i => + (i, converter(Row(i))) + } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + partitioner = Some(new HashPartitioner(10)), + serializer = Some(new UnsafeRowSerializer(numFields = 1))) + + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) + + // Merging spilled files should not throw assertion error + val taskContext = + new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc)) + taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile) + } { + // Clean up + if (sc != null) { + sc.stop() + } + + // restore the spark env + SparkEnv.set(oldEnv) + + if (outputFile != null) { + outputFile.delete() + } + } + } } From a76bde9dae54c4641e21f3c1ceb4870e3dc91881 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 10 Sep 2015 11:49:53 -0700 Subject: [PATCH 0036/2089] [SPARK-10469] [DOC] Try and document the three options From JIRA: Add documentation for tungsten-sort. From the mailing list "I saw a new "spark.shuffle.manager=tungsten-sort" implemented in https://issues.apache.org/jira/browse/SPARK-7081, but it can't be found its corresponding description in http://people.apache.org/~pwendell/spark-releases/spark-1.5.0-rc3-docs/configuration.html(Currenlty there are only 'sort' and 'hash' two options)." Author: Holden Karau Closes #8638 from holdenk/SPARK-10469-document-tungsten-sort. --- docs/configuration.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index e287591f3fda1..0b1a273916314 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -447,9 +447,12 @@ Apart from these, the following properties are also available, and may be useful
From af3bc59d1f5d9d952c2d7ad1af599c49f1dbdaf0 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 10 Sep 2015 11:58:54 -0700 Subject: [PATCH 0037/2089] [SPARK-8167] Make tasks that fail from YARN preemption not fail job The architecture is that, in YARN mode, if the driver detects that an executor has disconnected, it asks the ApplicationMaster why the executor died. If the ApplicationMaster is aware that the executor died because of preemption, all tasks associated with that executor are not marked as failed. The executor is still removed from the driver's list of available executors, however. There's a few open questions: 1. Should standalone mode have a similar "get executor loss reason" as well? I localized this change as much as possible to affect only YARN, but there could be a valid case to differentiate executor losses in standalone mode as well. 2. I make a pretty strong assumption in YarnAllocator that getExecutorLossReason(executorId) will only be called once per executor id; I do this so that I can remove the metadata from the in-memory map to avoid object accumulation. It's not clear if I'm being overly zealous to save space, however. cc vanzin specifically for review because it collided with some earlier YARN scheduling work. cc JoshRosen because it's similar to output commit coordination we did in the past cc andrewor14 for our discussion on how to get executor exit codes and loss reasons Author: mcheah Closes #8007 from mccheah/feature/preemption-handling. --- .../org/apache/spark/TaskEndReason.scala | 18 +++- .../spark/scheduler/ExecutorLossReason.scala | 14 ++- .../org/apache/spark/scheduler/Pool.scala | 4 +- .../apache/spark/scheduler/Schedulable.scala | 2 +- .../spark/scheduler/TaskSchedulerImpl.scala | 9 +- .../spark/scheduler/TaskSetManager.scala | 21 +++-- .../cluster/CoarseGrainedClusterMessage.scala | 8 +- .../CoarseGrainedSchedulerBackend.scala | 24 +++-- .../cluster/SparkDeploySchedulerBackend.scala | 6 +- .../cluster/YarnSchedulerBackend.scala | 77 ++++++++++++++-- .../mesos/CoarseMesosSchedulerBackend.scala | 4 +- .../cluster/mesos/MesosSchedulerBackend.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 9 +- .../spark/scheduler/TaskSetManagerSuite.scala | 33 ++++++- .../apache/spark/util/JsonProtocolSuite.scala | 10 +- .../spark/deploy/yarn/ApplicationMaster.scala | 7 ++ .../spark/deploy/yarn/YarnAllocator.scala | 92 ++++++++++++++----- 17 files changed, 261 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 934d00dc708b9..2ae878b3e6087 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -48,6 +48,8 @@ case object Success extends TaskEndReason sealed trait TaskFailedReason extends TaskEndReason { /** Error message displayed in the web UI. */ def toErrorString: String + + def shouldEventuallyFailJob: Boolean = true } /** @@ -194,6 +196,12 @@ case object TaskKilled extends TaskFailedReason { case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + /** + * If a task failed because its attempt to commit was denied, do not count this failure + * towards failing the stage. This is intended to prevent spurious stage failures in cases + * where many speculative tasks are launched and denied to commit. + */ + override def shouldEventuallyFailJob: Boolean = false } /** @@ -202,8 +210,14 @@ case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extend * the task crashed the JVM. */ @DeveloperApi -case class ExecutorLostFailure(execId: String) extends TaskFailedReason { - override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" +case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false) + extends TaskFailedReason { + override def toErrorString: String = { + val exitBehavior = if (isNormalExit) "normally" else "abnormally" + s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + } + + override def shouldEventuallyFailJob: Boolean = !isNormalExit } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 2bc43a9186449..0a98c69b89ea5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -23,16 +23,20 @@ import org.apache.spark.executor.ExecutorExitCode * Represents an explanation for a executor or whole slave failing or exiting. */ private[spark] -class ExecutorLossReason(val message: String) { +class ExecutorLossReason(val message: String) extends Serializable { override def toString: String = message } private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String) + extends ExecutorLossReason(reason) + +private[spark] object ExecutorExited { + def apply(exitCode: Int, isNormalExit: Boolean): ExecutorExited = { + ExecutorExited(exitCode, isNormalExit, ExecutorExitCode.explainExitCode(exitCode)) + } } private[spark] case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} + extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 5821afea98982..551e39a81b695 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -83,8 +83,8 @@ private[spark] class Pool( null } - override def executorLost(executorId: String, host: String) { - schedulableQueue.asScala.foreach(_.executorLost(executorId, host)) + override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) { + schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } override def checkSpeculatableTasks(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index a87ef030e69c2..ab00bc8f0bf4e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,7 +42,7 @@ private[spark] trait Schedulable { def addSchedulable(schedulable: Schedulable): Unit def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit + def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1705e7f962de2..1c7bfe89c02ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -332,7 +332,8 @@ private[spark] class TaskSchedulerImpl( // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) + removeExecutor(execId, + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) } } @@ -464,7 +465,7 @@ private[spark] class TaskSchedulerImpl( if (activeExecutorIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) + removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { // We may get multiple executorLost() calls with different loss reasons. For example, one @@ -482,7 +483,7 @@ private[spark] class TaskSchedulerImpl( } /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { + private def removeExecutor(executorId: String, reason: ExecutorLossReason) { activeExecutorIds -= executorId val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) @@ -497,7 +498,7 @@ private[spark] class TaskSchedulerImpl( } } executorIdToHost -= executorId - rootPool.executorLost(executorId, host) + rootPool.executorLost(executorId, host, reason) } def executorAdded(execId: String, host: String) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 818b95d67f6be..62af9031b9f8b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -709,6 +709,11 @@ private[spark] class TaskSetManager( } ef.exception + case e: ExecutorLostFailure if e.isNormalExit => + logInfo(s"Task $tid failed because while it was being computed, its executor" + + s" exited normally. Not marking the task as failed.") + None + case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) None @@ -722,10 +727,9 @@ private[spark] class TaskSetManager( put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { - // If a task failed because its attempt to commit was denied, do not count this failure - // towards failing the stage. This is intended to prevent spurious stage failures in cases - // where many speculative tasks are launched and denied to commit. + if (!isZombie && state != TaskState.KILLED + && reason.isInstanceOf[TaskFailedReason] + && reason.asInstanceOf[TaskFailedReason].shouldEventuallyFailJob) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -778,7 +782,7 @@ private[spark] class TaskSetManager( } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { + override def executorLost(execId: String, host: String, reason: ExecutorLossReason) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) // Re-enqueue pending tasks for this host based on the status of the cluster. Note @@ -809,9 +813,12 @@ private[spark] class TaskSetManager( } } } - // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) + val isNormalExit: Boolean = reason match { + case exited: ExecutorExited => exited.isNormalExit + case _ => false + } + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, isNormalExit)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 06f5438433b6e..d94743677783f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.ExecutorLossReason import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -70,7 +71,8 @@ private[spark] object CoarseGrainedClusterMessages { case object StopExecutors extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) + extends CoarseGrainedClusterMessage case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage @@ -92,6 +94,10 @@ private[spark] object CoarseGrainedClusterMessages { hostToLocalTaskCount: Map[String, Int]) extends CoarseGrainedClusterMessage + // Check if an executor was force-killed but for a normal reason. + // This could be the case if the executor is preempted, for instance. + case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage + case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5730a87f960a0..18771f79b44bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** @@ -82,7 +83,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[RpcAddress, String] + protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") @@ -128,6 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { @@ -185,8 +187,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, - "remote Rpc client disassociated")) + addressToExecutorId + .get(remoteAddress) + .foreach(removeExecutor(_, SlaveLost("remote Rpc client disassociated"))) } // Make fake resource offers on just one executor @@ -227,7 +230,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String): Unit = { + def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -239,9 +242,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, SlaveLost(reason)) + scheduler.executorLost(executorId, reason) listenerBus.post( - SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -263,8 +266,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint( - CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) + driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + } + + protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new DriverEndpoint(rpcEnv, properties) } def stopExecutors() { @@ -304,7 +310,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: ExecutorLossReason) { try { driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bbe51b4a09a22..27491ecf8b97d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,7 +23,7 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( @@ -135,11 +135,11 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) + case Some(code) => ExecutorExited(code, isNormalExit = true, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) - removeExecutor(fullId.split("/")(1), reason.toString) + removeExecutor(fullId.split("/")(1), reason) } override def sufficientResourcesRegistered(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 044f6288fabdd..6a4b536dee191 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,12 +17,13 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Future, ExecutionContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler._ import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{ThreadUtils, RpcUtils} @@ -43,8 +44,10 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( - YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) + private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) + + private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) @@ -53,7 +56,7 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean]( + yarnSchedulerEndpointRef.askWithRetry[Boolean]( RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } @@ -61,7 +64,7 @@ private[spark] abstract class YarnSchedulerBackend( * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -90,6 +93,41 @@ private[spark] abstract class YarnSchedulerBackend( } } + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new YarnDriverEndpoint(rpcEnv, properties) + } + + /** + * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. + * This endpoint communicates with the executors and queries the AM for an executor's exit + * status when the executor is disconnected. + */ + private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + /** + * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint + * handles it by assuming the Executor was lost for a bad reason and removes the executor + * immediately. + * + * In YARN's case however it is crucial to talk to the application master and ask why the + * executor had exited. In particular, the executor may have exited due to the executor + * having been preempted. If the executor "exited normally" according to the application + * master then we pass that information down to the TaskSetManager to inform the + * TaskSetManager that tasks on that lost executor should not count towards a job failure. + * + * TODO there's a race condition where while we are querying the ApplicationMaster for + * the executor loss reason, there is the potential that tasks will be scheduled on + * the executor that failed. We should fix this by having this onDisconnected event + * also "blacklist" executors so that tasks are not assigned to them. + */ + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } + } + } + /** * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ @@ -101,6 +139,33 @@ private[spark] abstract class YarnSchedulerBackend( ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) + private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( + executorId: String, + executorRpcAddress: RpcAddress): Unit = { + amEndpoint match { + case Some(am) => + val lossReasonRequest = GetExecutorLossReason(executorId) + val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + future onSuccess { + case reason: ExecutorLossReason => { + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) + } + } + future onFailure { + case NonFatal(e) => { + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. Marking as slave lost.", e) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) + } + case t => throw t + } + case None => + logWarning("Attempted to check for an executor loss reason" + + " before the AM has registered!") + } + } + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") @@ -113,6 +178,7 @@ private[spark] abstract class YarnSchedulerBackend( removeExecutor(executorId, reason) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => amEndpoint match { @@ -143,7 +209,6 @@ private[spark] abstract class YarnSchedulerBackend( logWarning("Attempted to kill executors before the AM has registered!") context.reply(false) } - } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 452c32d5411cd..65df8874774ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -32,7 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -364,7 +364,7 @@ private[spark] class CoarseMesosSchedulerBackend( if (slaveIdToTaskId.containsKey(slaveId)) { val taskId: Int = slaveIdToTaskId.get(slaveId) taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) } // TODO: This assumes one Spark executor per Mesos slave, // which may no longer be true after SPARK-5095 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 2e424054be785..18da6d2491280 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -390,7 +390,7 @@ private[spark] class MesosSchedulerBackend( slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) + recordSlaveLost(d, slaveId, ExecutorExited(status, isNormalExit = false)) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index cbc94fd6d54d9..24f78744ad74c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,8 +362,9 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) - case ExecutorLostFailure(executorId) => - ("Executor ID" -> executorId) + case ExecutorLostFailure(executorId, isNormalExit) => + ("Executor ID" -> executorId) ~ + ("Normal Exit" -> isNormalExit) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -794,8 +795,10 @@ private[spark] object JsonProtocol { case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => + val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). + map(_.extract[Boolean]) val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - ExecutorLostFailure(executorId.getOrElse("Unknown")) + ExecutorLostFailure(executorId.getOrElse("Unknown"), isNormalExit.getOrElse(false)) case `unknownReason` => UnknownReason } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index edbdb485c5ea4..f0eadf240943e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -334,7 +334,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Now mark host2 as dead sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") + manager.executorLost("exec2", "host2", SlaveLost()) // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) @@ -504,13 +504,36 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") - manager.executorLost("execC", "host2") + manager.executorLost("execC", "host2", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) sched.removeExecutor("execD") - manager.executorLost("execD", "host1") + manager.executorLost("execD", "host1", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } + test("Executors are added but exit normally while running tasks") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execB")), + Seq(TaskLocation("host2", "execC")), + Seq()) + val manager = new TaskSetManager(sched, taskSet, 1, new ManualClock) + sched.addExecutor("execA", "host1") + manager.executorAdded() + sched.addExecutor("execC", "host2") + manager.executorAdded() + assert(manager.resourceOffer("exec1", "host1", ANY).isDefined) + sched.removeExecutor("execA") + manager.executorLost("execA", "host1", ExecutorExited(143, true, "Normal termination")) + assert(!sched.taskSetsFailed.contains(taskSet.id)) + assert(manager.resourceOffer("execC", "host2", ANY).isDefined) + sched.removeExecutor("execC") + manager.executorLost("execC", "host2", ExecutorExited(1, false, "Abnormal termination")) + assert(sched.taskSetsFailed.contains(taskSet.id)) + } + test("test RACK_LOCAL tasks") { // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") @@ -721,8 +744,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") - manager.executorLost("execA", "host1") - manager.executorLost("execB.2", "host2") + manager.executorLost("execA", "host1", SlaveLost()) + manager.executorLost("execB.2", "host2", SlaveLost()) clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 343a4139b0ca8..47e548ef0d442 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,7 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) - testTaskEndReason(ExecutorLostFailure("100")) + testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) // BlockId @@ -295,10 +295,10 @@ class JsonProtocolSuite extends SparkFunSuite { test("ExecutorLostFailure backward compatibility") { // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. - val executorLostFailure = ExecutorLostFailure("100") + val executorLostFailure = ExecutorLostFailure("100", true) val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) .removeField({ _._1 == "Executor ID" }) - val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown", true) assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } @@ -577,8 +577,10 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => - case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) => + case (ExecutorLostFailure(execId1, isNormalExit1), + ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) + assert(isNormalExit1 === isNormalExit2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 991b5cec00bd8..93621b44c9183 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -590,6 +590,13 @@ private[spark] class ApplicationMaster( case None => logWarning("Container allocator is not ready to kill executors yet.") } context.reply(true) + + case GetExecutorLossReason(eid) => + Option(allocator) match { + case Some(a) => a.enqueueGetLossReasonRequest(eid, context) + case None => logWarning(s"Container allocator is not ready to find" + + s" executor loss reasons yet.") + } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 5f897cbcb4e9f..fd88b8b7fe3b9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,8 +21,9 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -36,8 +37,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.util.Utils /** @@ -93,6 +95,11 @@ private[yarn] class YarnAllocator( sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) } + // Executor loss reason requests that are pending - maps from executor ID for inquiry to a + // list of requesters that should be responded to once we find out why the given executor + // was lost. + private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] + // Keep track of which container is running which executor to remove the executors later // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] @@ -235,9 +242,7 @@ private[yarn] class YarnAllocator( val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - processCompletedContainers(completedContainers.asScala) - logDebug("Finished processing %d completed containers. Current running executor count: %d." .format(completedContainers.size, numExecutorsRunning)) } @@ -429,7 +434,7 @@ private[yarn] class YarnAllocator( for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId val alreadyReleased = releasedContainers.remove(containerId) - if (!alreadyReleased) { + val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 @@ -440,22 +445,42 @@ private[yarn] class YarnAllocator( // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == ContainerExitStatus.PREEMPTED) { - logInfo("Container preempted: " + containerId) - } else if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed += 1 + val exitStatus = completedContainer.getExitStatus + val (isNormalExit, containerExitReason) = exitStatus match { + case ContainerExitStatus.SUCCESS => + (true, s"Executor for container $containerId exited normally.") + case ContainerExitStatus.PREEMPTED => + // Preemption should count as a normal exit, since YARN preempts containers merely + // to do resource sharing, and tasks that fail due to preempted executors could + // just as easily finish on any other executor. See SPARK-8167. + (true, s"Container $containerId was preempted.") + // Should probably still count memory exceeded exit codes towards task failures + case VMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + case PMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + case unknown => + numExecutorsFailed += 1 + (false, "Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". Diagnostics: " + completedContainer.getDiagnostics) + + } + if (isNormalExit) { + logInfo(containerExitReason) + } else { + logWarning(containerExitReason) } + ExecutorExited(0, isNormalExit, containerExitReason) + } else { + // If we have already released this container, then it must mean + // that the driver has explicitly requested it to be killed + ExecutorExited(completedContainer.getExitStatus, isNormalExit = true, + s"Container $containerId exited from explicit termination request.") } if (allocatedContainerToHostMap.contains(containerId)) { @@ -474,18 +499,35 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - + pendingLossReasonRequests.remove(eid).foreach { pendingRequests => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) // Notify backend about the failure of the executor numUnexpectedContainerRelease += 1 - driverRef.send(RemoveExecutor(eid, - s"Yarn deallocated the executor $eid (container $containerId)")) + driverRef.send(RemoveExecutor(eid, exitReason)) } } } } + /** + * Register that some RpcCallContext has asked the AM why the executor was lost. Note that + * we can only find the loss reason to send back in the next call to allocateResources(). + */ + private[yarn] def enqueueGetLossReasonRequest( + eid: String, + context: RpcCallContext): Unit = synchronized { + if (executorIdToContainer.contains(eid)) { + pendingLossReasonRequests + .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else { + logWarning(s"Tried to get the loss reason for non-existent executor $eid") + } + } + private def internalReleaseContainer(container: Container): Unit = { releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) @@ -501,6 +543,8 @@ private object YarnAllocator { Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + val VMEM_EXCEEDED_EXIT_CODE = -103 + val PMEM_EXCEEDED_EXIT_CODE = -104 def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) From f0562e8cdbab7ce40f3186da98595312252f8b5c Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Thu, 10 Sep 2015 12:00:21 -0700 Subject: [PATCH 0038/2089] [SPARK-6350] [MESOS] Fine-grained mode scheduler respects mesosExecutor.cores This is a regression introduced in #4960, this commit fixes it and adds a test. tnachen andrewor14 please review, this should be an easy one. Author: Iulian Dragos Closes #8653 from dragos/issue/mesos/fine-grained-maxExecutorCores. --- .../cluster/mesos/MesosSchedulerBackend.scala | 3 +- .../mesos/MesosSchedulerBackendSuite.scala | 33 ++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 18da6d2491280..8edf7007a5daf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils - /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks @@ -127,7 +126,7 @@ private[spark] class MesosSchedulerBackend( } val builder = MesosExecutorInfo.newBuilder() val (resourcesAfterCpu, usedCpuResources) = - partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 319b3173e7a6e..c4dc560031207 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -42,6 +42,38 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + test("Use configured mesosExecutor.cores for ExecutorInfo") { + val mesosExecutorCores = 3 + val conf = new SparkConf + conf.set("spark.mesos.mesosExecutor.cores", mesosExecutorCores.toString) + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val resources = Arrays.asList( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) + // uri is null. + val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") + val executorResources = executorInfo.getResourcesList + val cpus = executorResources.asScala.find(_.getName.equals("cpus")).get.getScalar.getValue + + assert(cpus === mesosExecutorCores) + } + test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") @@ -263,7 +295,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) .setHostname(s"host${id.toString}").build() - val mesosOffers = new java.util.ArrayList[Offer] mesosOffers.add(offer) From a5ef2d0600d5e23ca05fabc1005bb81e5ada0727 Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Thu, 10 Sep 2015 12:03:11 -0700 Subject: [PATCH 0039/2089] [SPARK-10514] [MESOS] waiting for min no of total cores acquired by Spark by implementing the sufficientResourcesRegistered method spark.scheduler.minRegisteredResourcesRatio configuration parameter works for YARN mode but not for Mesos Coarse grained mode. If the parameter specified default value of 0 will be set for spark.scheduler.minRegisteredResourcesRatio in base class and this method will always return true. There are no existing test for YARN mode too. Hence not added test for the same. Author: Akash Mishra Closes #8672 from SleepyThread/master. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 4 ++++ docs/configuration.md | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 65df8874774ca..65cb5016cfcc9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -222,6 +222,10 @@ private[spark] class CoarseMesosSchedulerBackend( markRegistered() } + override def sufficientResourcesRegistered(): Boolean = { + totalCoresAcquired >= maxCores * minRegisteredRatio + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} diff --git a/docs/configuration.md b/docs/configuration.md index 0b1a273916314..1a701f18881fe 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1112,10 +1112,11 @@ Apart from these, the following properties are also available, and may be useful - + - + " * 5)) + } + /** * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. @@ -67,12 +75,19 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { // Simulate a stage in job progress listener val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") - val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) - jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) - jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markSuccessful() - jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) + val peakExecutionMemory = 10 + taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, + Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + } jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) } From 7b6c856367b9c36348e80e83959150da9656c4dd Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 14 Sep 2015 15:09:43 -0700 Subject: [PATCH 0083/2089] [SPARK-10564] ThreadingSuite: assertion failures in threads don't fail the test (round 2) This is a follow-up patch to #8723. I missed one case there. Author: Andrew Or Closes #8727 from andrewor14/fix-threading-suite. --- .../org/apache/spark/ThreadingSuite.scala | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index cda2b245526f7..a96a4ce201c21 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,12 +147,12 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) + throwable.foreach { t => throw t } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } - throwable.foreach { t => throw t } } test("set local properties in different thread") { @@ -178,8 +178,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - assert(sc.getLocalProperty("test") === null) throwable.foreach { t => throw t } + assert(sc.getLocalProperty("test") === null) } test("set and get local properties in parent-children thread") { @@ -207,15 +207,16 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw t } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) - throwable.foreach { t => throw t } } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { val jobStarted = new Semaphore(0) val jobEnded = new Semaphore(0) @volatile var jobResult: JobResult = null + var throwable: Option[Throwable] = None sc = new SparkContext("local", "test") sc.setJobGroup("originalJobGroupId", "description") @@ -232,14 +233,19 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { // Create a new thread which will inherit the current thread's properties val thread = new Thread() { override def run(): Unit = { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") + // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task + try { + sc.parallelize(1 to 100).foreach { x => + Thread.sleep(100) + } + } catch { + case s: SparkException => // ignored so that we don't print noise in test logs } } catch { - case s: SparkException => // ignored so that we don't print noise in test logs + case t: Throwable => + throwable = Some(t) } } } @@ -252,6 +258,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { // modification of the properties object should not affect the properties of running jobs sc.cancelJobGroup("originalJobGroupId") jobEnded.tryAcquire(10, TimeUnit.SECONDS) + throwable.foreach { t => throw t } assert(jobResult.isInstanceOf[JobFailed]) } } From 1a0955250bb65cd6f5818ad60efb62ea4b45d18e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 14 Sep 2015 21:47:40 -0400 Subject: [PATCH 0084/2089] [SPARK-9851] Support submitting map stages individually in DAGScheduler This patch adds support for submitting map stages in a DAG individually so that we can make downstream decisions after seeing statistics about their output, as part of SPARK-9850. I also added more comments to many of the key classes in DAGScheduler. By itself, the patch is not super useful except maybe to switch between a shuffle and broadcast join, but with the other subtasks of SPARK-9850 we'll be able to do more interesting decisions. The main entry point is SparkContext.submitMapStage, which lets you run a map stage and see stats about the map output sizes. Other stats could also be collected through accumulators. See AdaptiveSchedulingSuite for a short example. Author: Matei Zaharia Closes #8180 from mateiz/spark-9851. --- .../apache/spark/MapOutputStatistics.scala | 27 ++ .../org/apache/spark/MapOutputTracker.scala | 49 +++- .../scala/org/apache/spark/SparkContext.scala | 17 ++ .../apache/spark/scheduler/ActiveJob.scala | 34 ++- .../apache/spark/scheduler/DAGScheduler.scala | 251 +++++++++++++++--- .../spark/scheduler/DAGSchedulerEvent.scala | 10 + .../apache/spark/scheduler/ResultStage.scala | 17 +- .../spark/scheduler/ShuffleMapStage.scala | 13 +- .../org/apache/spark/scheduler/Stage.scala | 26 +- .../scala/org/apache/spark/FailureSuite.scala | 21 ++ .../scheduler/AdaptiveSchedulingSuite.scala | 65 +++++ .../spark/scheduler/DAGSchedulerSuite.scala | 243 ++++++++++++++++- 12 files changed, 710 insertions(+), 63 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/MapOutputStatistics.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala new file mode 100644 index 0000000000000..f8a6f1d0d8cbb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future. + * + * @param shuffleId ID of the shuffle + * @param bytesByPartitionId approximate number of output bytes for each map output partition + * (may be inexact due to use of compressed map statuses) + */ +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a387592783850..94eb8daa85c53 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io._ +import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} @@ -132,13 +133,43 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") - val startTime = System.currentTimeMillis + val statuses = getStatuses(shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } + } + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + val statuses = getStatuses(dep.shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { // Someone else is fetching it; wait for them to be done @@ -160,7 +191,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so + // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { @@ -175,22 +206,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } - logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + return statuses } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e27b3c4962221..dee6091ce3caf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1984,6 +1984,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new SimpleFutureAction(waiter, resultFunc) } + /** + * Submit a map stage for execution. This is currently an internal API only, but might be + * promoted to DeveloperApi in the future. + */ + private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C]) + : SimpleFutureAction[MapOutputStatistics] = { + assertNotStopped() + val callSite = getCallSite() + var result: MapOutputStatistics = null + val waiter = dagScheduler.submitMapStage( + dependency, + (r: MapOutputStatistics) => { result = r }, + callSite, + localProperties.get) + new SimpleFutureAction[MapOutputStatistics](waiter, result) + } + /** * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] * for more information. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 50a69379412d2..a3d2db31301b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -23,18 +23,42 @@ import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** - * Tracks information about an active job in the DAGScheduler. + * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a + * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a + * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive + * query planning, to look at map output statistics before submitting later stages. We distinguish + * between these two types of jobs using the finalStage field of this class. + * + * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's + * submitJob or submitMapStage methods. However, either type of job may cause the execution of + * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of + * these previous stages. These dependencies are managed inside DAGScheduler. + * + * @param jobId A unique ID for this job. + * @param finalStage The stage that this job computes (either a ResultStage for an action or a + * ShuffleMapStage for submitMapStage). + * @param callSite Where this job was initiated in the user's program (shown on UI). + * @param listener A listener to notify if tasks in this job finish or the job fails. + * @param properties Scheduling properties attached to the job, such as fair scheduler pool name. */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: ResultStage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], + val finalStage: Stage, val callSite: CallSite, val listener: JobListener, val properties: Properties) { - val numPartitions = partitions.length + /** + * Number of partitions we need to compute for this job. Note that result stages may not need + * to compute all partitions in their target RDD, for actions like first() and lookup(). + */ + val numPartitions = finalStage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.rdd.partitions.length + } + + /** Which partitions of the stage have finished */ val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09e963f5cdf60..b4f90e8347894 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -45,17 +45,65 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. * - * In addition to coming up with a DAG of stages, this class also determines the preferred + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs (MappedRDD, FilteredRDD, etc). + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * * Here's a checklist to use when making or reviewing changes to this class: * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ @@ -295,12 +343,12 @@ class DAGScheduler( */ private def newResultStage( rdd: RDD[_], - numTasks: Int, + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) - val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) - + val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -500,12 +548,25 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - job.finalStage.resultOfJob = None + job.finalStage match { + case r: ResultStage => + r.resultOfJob = None + case m: ShuffleMapStage => + m.mapStageJobs = m.mapStageJobs.filter(_ != job) + } } /** - * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name */ def submitJob[T, U]( rdd: RDD[T], @@ -524,6 +585,7 @@ class DAGScheduler( val jobId = nextJobId.getAndIncrement() if (partitions.size == 0) { + // Return immediately if the job is running 0 tasks return new JobWaiter[U](this, jobId, 0, resultHandler) } @@ -536,6 +598,18 @@ class DAGScheduler( waiter } + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. Throws an exception if the job fials, or returns normally if successful. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -559,6 +633,17 @@ class DAGScheduler( } } + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -575,6 +660,41 @@ class DAGScheduler( listener.awaitResult() // Will throw an exception if the job fails } + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. + // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queries + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) + waiter + } + /** * Cancel a job that is running or waiting in the queue. */ @@ -583,6 +703,9 @@ class DAGScheduler( eventProcessLoop.post(JobCancelled(jobId)) } + /** + * Cancel all jobs in the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) @@ -720,31 +843,77 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) + finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } - if (finalStage != null) { - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions".format( - job.jobId, callSite.shortForm, partitions.length)) - logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val jobSubmissionTime = clock.getTimeMillis() - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + submitWaitingStages() + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties) { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. + finalStage = getShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.size)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.mapStageJobs = job :: finalStage.mapStageJobs + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (!finalStage.outputLocs.contains(Nil)) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) } + submitWaitingStages() } @@ -814,7 +983,7 @@ class DAGScheduler( case s: ResultStage => val job = s.resultOfJob.get partitionsToCompute.map { id => - val p = job.partitions(id) + val p = s.partitions(id) (id, getPreferredLocs(stage.rdd, p)) }.toMap } @@ -844,7 +1013,7 @@ class DAGScheduler( case stage: ShuffleMapStage => closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() } taskBinary = sc.broadcast(taskBinaryBytes) @@ -875,7 +1044,7 @@ class DAGScheduler( case stage: ResultStage => val job = stage.resultOfJob.get partitionsToCompute.map { id => - val p: Int = job.partitions(id) + val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, @@ -1052,13 +1221,21 @@ class DAGScheduler( logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + .map(_._2).mkString(", ")) submitStage(shuffleStage) + } else { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } } // Note: newly runnable stages will be submitted below when we submit waiting stages } - } + } case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") @@ -1412,6 +1589,17 @@ class DAGScheduler( Nil } + /** Mark a map stage job as finished with the given output stats, and report to its listener. */ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + def stop() { logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() @@ -1445,6 +1633,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index f72a52e85dc15..dda3b6cc7f960 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.CallSite */ private[scheduler] sealed trait DAGSchedulerEvent +/** A result-yielding job was submitted on a target RDD */ private[scheduler] case class JobSubmitted( jobId: Int, finalRDD: RDD[_], @@ -45,6 +46,15 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +/** A map stage as submitted to run as a separate job */ +private[scheduler] case class MapStageSubmitted( + jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index bf81b9aca4810..c0451da1f0247 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -17,23 +17,30 @@ package org.apache.spark.scheduler +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * The ResultStage represents the final stage in a job. + * ResultStages apply a function on some partitions of an RDD to compute the result of an action. + * The ResultStage object captures the function to execute, `func`, which will be applied to each + * partition, and the set of partition IDs, `partitions`. Some stages may not run on all partitions + * of the RDD, for actions like first() and lookup(). */ private[spark] class ResultStage( id: Int, rdd: RDD[_], - numTasks: Int, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], parents: List[Stage], firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { - // The active job for this result stage. Will be empty if the job has already finished - // (e.g., because the job was cancelled). + /** + * The active job for this result stage. Will be empty if the job has already finished + * (e.g., because the job was cancelled). + */ var resultOfJob: Option[ActiveJob] = None override def toString: String = "ResultStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 48d8d8e9c4b78..7d92960876403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -23,7 +23,15 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** - * The ShuffleMapStage represents the intermediate stages in a job. + * ShuffleMapStages are intermediate stages in the execution DAG that produce data for a shuffle. + * They occur right before each shuffle operation, and might contain multiple pipelined operations + * before that (e.g. map and filter). When executed, they save map output files that can later be + * fetched by reduce tasks. The `shuffleDep` field describes the shuffle each stage is part of, + * and variables like `outputLocs` and `numAvailableOutputs` track how many map outputs are ready. + * + * ShuffleMapStages can also be submitted independently as jobs with DAGScheduler.submitMapStage. + * For such stages, the ActiveJobs that submitted them are tracked in `mapStageJobs`. Note that + * there can be multiple ActiveJobs trying to compute the same shuffle map stage. */ private[spark] class ShuffleMapStage( id: Int, @@ -37,6 +45,9 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id + /** Running map-stage jobs that were submitted to execute this stage independently (if any) */ + var mapStageJobs: List[ActiveJob] = Nil + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index c086535782c23..b37eccbd0f7b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -24,27 +24,33 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * A stage is a set of independent tasks all computing the same function that need to run as part + * A stage is a set of parallel tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the * DAGScheduler runs these stages in topological order. * * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. + * other stage(s), or a result stage, in which case its tasks directly compute a Spark action + * (e.g. count(), save(), etc) by running a function on an RDD. For shuffle map stages, we also + * track the nodes that each output partition is on. * * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * - * The callSite provides a location in user code which relates to the stage. For a shuffle map - * stage, the callSite gives the user code that created the RDD being shuffled. For a result - * stage, the callSite gives the user code that executes the associated action (e.g. count()). - * - * A single stage can consist of multiple attempts. In that case, the latestInfo field will - * be updated for each attempt. + * Finally, a single stage can be re-executed in multiple attempts due to fault recovery. In that + * case, the Stage object will track multiple StageInfo objects to pass to listeners or the web UI. + * The latest one will be accessible through latestInfo. * + * @param id Unique stage ID + * @param rdd RDD that this stage runs on: for a shuffle map stage, it's the RDD we run map tasks + * on, while for a result stage, it's the target RDD that we ran an action on + * @param numTasks Total number of tasks in stage; result stages in particular may not need to + * compute all partitions, e.g. for first(), lookup(), and take(). + * @param parents List of stages that this stage depends on (through shuffle dependencies). + * @param firstJobId ID of the first job this stage was part of, for FIFO scheduling. + * @param callSite Location in the user program associated with this stage: either where the target + * RDD was created, for a shuffle map stage, or where the action for a result stage was called. */ private[scheduler] abstract class Stage( val id: Int, diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index aa50a49c50232..f58756e6f6179 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -217,6 +217,27 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + // Run a 3-task map stage where one task fails once. + test("failure in tasks in a submitMapStage") { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + sc.submitMapStage(dep).get() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. } diff --git a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala new file mode 100644 index 0000000000000..3fe28027c3c21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.{ShuffledRDDPartition, RDD, ShuffledRDD} +import org.apache.spark._ + +object AdaptiveSchedulingSuiteState { + var tasksRun = 0 + + def clear(): Unit = { + tasksRun = 0 + } +} + +/** A special ShuffledRDD where we can pass a ShuffleDependency object to use */ +class CustomShuffledRDD[K, V, C](@transient dep: ShuffleDependency[K, V, C]) + extends RDD[(K, C)](dep.rdd.context, Seq(dep)) { + + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[(K, C)]] + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](dep.partitioner.numPartitions)(i => new ShuffledRDDPartition(i)) + } +} + +class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext { + test("simple use of submitMapStage") { + try { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.parallelize(1 to 3, 3).map { x => + AdaptiveSchedulingSuiteState.tasksRun += 1 + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep) + sc.submitMapStage(dep).get() + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + assert(shuffled.collect().toSet == Set((1, 1), (2, 2), (3, 3))) + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + } finally { + AdaptiveSchedulingSuiteState.clear() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1b9ff740ff530..1c55f90ad9b44 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -152,6 +152,14 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception) = { failure = exception } } + /** A simple helper class for creating custom JobListeners */ + class SimpleListener extends JobListener { + val results = new HashMap[Int, Any] + var failure: Exception = null + override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result) + override def jobFailed(exception: Exception): Unit = { failure = exception } + } + before { sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() @@ -229,7 +237,7 @@ class DAGSchedulerSuite } } - /** Sends the rdd to the scheduler for scheduling and returns the job id. */ + /** Submits a job to the scheduler and returns the job id. */ private def submit( rdd: RDD[_], partitions: Array[Int], @@ -240,6 +248,15 @@ class DAGSchedulerSuite jobId } + /** Submits a map stage to the scheduler and returns the job id. */ + private def submitMapStage( + shuffleDep: ShuffleDependency[_, _, _], + listener: JobListener = jobListener): Int = { + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener)) + jobId + } + /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { runEvent(TaskSetFailed(taskSet, message, None)) @@ -1313,6 +1330,230 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("simple map stage submission") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(results.size === 0) // No results yet + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage; it should directly do the reduce + submit(reduceRdd, Array(0)) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with reduce stage also depending on the data") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit the map stage by itself + submitMapStage(shuffleDep) + + // Submit a reduce job that depends on this map stage + submit(reduceRdd, Array(0)) + + // Complete tasks for the map stage + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + + // Complete tasks for the reduce stage + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with fetch failure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage, but where one reduce will fail a fetch + submit(reduceRdd, Array(0, 1)) + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch + // from, then TaskSet 3 will run the reduce stage + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + + // Run another reduce job without a failure; this should just work + submit(reduceRdd, Array(0, 1)) + complete(taskSets(4), Seq( + (Success, 44), + (Success, 45))) + assert(results === Map(0 -> 44, 1 -> 45)) + results.clear() + assertDataStructuresEmpty() + + // Resubmit the map stage; this should also just work + submitMapStage(shuffleDep) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + } + + /** + * In this test, we have three RDDs with shuffle dependencies, and we submit map stage jobs + * that are waiting on each one, as well as a reduce job on the last one. We test that all of + * these jobs complete even if there are some fetch failures in both shuffles. + */ + test("map stage submission with multiple shared stages and failures") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1)) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + val rdd3 = new MyRDD(sc, 2, List(dep2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + val listener3 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + submit(rdd3, Array(0, 1), listener = listener3) + + // Complete the first stage + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.size)), + (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting the second stage, show a fetch failure + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + assert(listener2.results.size === 0) // Second stage listener should not have a result yet + + // Stage 0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result + + // Stage 1 should now be running as task set 3; make its first task succeed + assert(taskSets(3).stageId === 1) + complete(taskSets(3), Seq( + (Success, makeMapStatus("hostB", rdd2.partitions.size)), + (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(listener2.results.size === 1) + + // Finally, the reduce job should be running as task set 4; make it see a fetch failure, + // then make it run again and succeed + assert(taskSets(4).stageId === 2) + complete(taskSets(4), Seq( + (Success, 52), + (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 + assert(taskSets(5).stageId === 1) + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + complete(taskSets(6), Seq( + (Success, 53))) + assert(listener3.results === Map(0 -> 52, 1 -> 53)) + assertDataStructuresEmpty() + } + + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from that executor. We want to make sure the stage is not reported + * as done until all tasks have completed. + */ + test("map stage submission with executor failure late map task completions") { + val shuffleMapRdd = new MyRDD(sc, 3, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + + submitMapStage(shuffleDep) + + val oldTaskSet = taskSets(0) + runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Pretend host A was lost + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + + // Suppose we also get a completed event from task 1 on the same host; this should be ignored + runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // A completion from another task should work because it's a non-failed host + runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Now complete tasks in the second task set + val newTaskSet = taskSets(1) + assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 1) // Map stage job should now finally be complete + assertDataStructuresEmpty() + + // Also test that a reduce stage using this shuffled data can immediately run + val reduceRDD = new MyRDD(sc, 2, List(shuffleDep)) + results.clear() + submit(reduceRDD, Array(0, 1)) + complete(taskSets(2), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. From 55204181004c105c7a3e8c31a099b37e48bfd953 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Sep 2015 19:46:34 -0700 Subject: [PATCH 0085/2089] [SPARK-10542] [PYSPARK] fix serialize namedtuple Author: Davies Liu Closes #8707 from davies/fix_namedtuple. --- python/pyspark/cloudpickle.py | 15 ++++++++++++++- python/pyspark/serializers.py | 1 + python/pyspark/tests.py | 5 +++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 3b647985801b7..95b3abc74244b 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,6 +350,11 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ + # workaround for namedtuple (hijacked by PySpark) + if getattr(obj, '_is_namedtuple_', False): + self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) + return + self.save(_load_class) self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) d.pop('__doc__', None) @@ -382,7 +387,7 @@ def save_instancemethod(self, obj): self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + obj=obj) dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): @@ -744,6 +749,14 @@ def _load_class(cls, d): return cls +def _load_namedtuple(name, fields): + """ + Loads a class generated by namedtuple + """ + from collections import namedtuple + return namedtuple(name, fields) + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 411b4dbf481f1..2a1326947f4f5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -359,6 +359,7 @@ def _hack_namedtuple(cls): def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ + cls._is_namedtuple_ = True return cls diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8bfed074c9052..647504c32f156 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -218,6 +218,11 @@ def test_namedtuple(self): p2 = loads(dumps(p1, 2)) self.assertEqual(p1, p2) + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + def test_itemgetter(self): from operator import itemgetter ser = CloudPickleSerializer() From 4ae4d54794778042b2cc983e52757edac02412ab Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Sep 2015 21:37:43 -0700 Subject: [PATCH 0086/2089] [SPARK-9793] [MLLIB] [PYSPARK] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly PySpark DenseVector, SparseVector ```__eq__``` method should use semantics equality, and DenseVector can compared with SparseVector. Implement PySpark DenseVector, SparseVector ```__hash__``` method based on the first 16 entries. That will make PySpark Vector objects can be used in collections. Author: Yanbo Liang Closes #8166 from yanboliang/spark-9793. --- python/pyspark/mllib/linalg/__init__.py | 90 ++++++++++++++++++++----- python/pyspark/mllib/tests.py | 32 +++++++++ 2 files changed, 107 insertions(+), 15 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 334dc8e38bb8f..380f86e9b44f8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -25,6 +25,7 @@ import sys import array +import struct if sys.version >= '3': basestring = str @@ -122,6 +123,13 @@ def _format_float_list(l): return [_format_float(x) for x in l] +def _double_to_long_bits(value): + if np.isnan(value): + value = float('nan') + # pack double into 64 bits, then unpack as long int + return struct.unpack('Q', struct.pack('d', value))[0] + + class VectorUDT(UserDefinedType): """ SQL user-defined type (UDT) for Vector. @@ -404,11 +412,31 @@ def __repr__(self): return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): - return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) + if isinstance(other, DenseVector): + return np.array_equal(self.array, other.array) + elif isinstance(other, SparseVector): + if len(self) != other.size: + return False + return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return False def __ne__(self, other): return not self == other + def __hash__(self): + size = len(self) + result = 31 + size + nnz = 0 + i = 0 + while i < size and nnz < 128: + if self.array[i] != 0: + result = 31 * result + i + bits = _double_to_long_bits(self.array[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + def __getattr__(self, item): return getattr(self.array, item) @@ -704,20 +732,14 @@ def __repr__(self): return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): - """ - Test SparseVectors for equality. - - >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v1 == v2 - True - >>> v1 != v2 - False - """ - return (isinstance(other, self.__class__) - and other.size == self.size - and np.array_equal(other.indices, self.indices) - and np.array_equal(other.values, self.values)) + if isinstance(other, SparseVector): + return other.size == self.size and np.array_equal(other.indices, self.indices) \ + and np.array_equal(other.values, self.values) + elif isinstance(other, DenseVector): + if self.size != len(other): + return False + return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return False def __getitem__(self, index): inds = self.indices @@ -739,6 +761,19 @@ def __getitem__(self, index): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + result = 31 + self.size + nnz = 0 + i = 0 + while i < len(self.values) and nnz < 128: + if self.values[i] != 0: + result = 31 * result + int(self.indices[i]) + bits = _double_to_long_bits(self.values[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + class Vectors(object): @@ -841,6 +876,31 @@ def parse(s): def zeros(size): return DenseVector(np.zeros(size)) + @staticmethod + def _equals(v1_indices, v1_values, v2_indices, v2_values): + """ + Check equality between sparse/dense vectors, + v1_indices and v2_indices assume to be strictly increasing. + """ + v1_size = len(v1_values) + v2_size = len(v2_values) + k1 = 0 + k2 = 0 + all_equal = True + while all_equal: + while k1 < v1_size and v1_values[k1] == 0: + k1 += 1 + while k2 < v2_size and v2_values[k2] == 0: + k2 += 1 + + if k1 >= v1_size or k2 >= v2_size: + return k1 >= v1_size and k2 >= v2_size + + all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2] + k1 += 1 + k2 += 1 + return all_equal + class Matrix(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5097c5e8ba4cd..636f9a06cab7b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -194,6 +194,38 @@ def test_squared_distance(self): self.assertEquals(3.0, _squared_distance(sv, arr)) self.assertEquals(3.0, _squared_distance(sv, narr)) + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEquals(hash(v1), hash(v2)) + self.assertEquals(hash(v1), hash(v3)) + self.assertEquals(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEquals(v1, v2) + self.assertEquals(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + def test_conversion(self): # numpy arrays should be automatically upcast to float64 # tests for fix of [SPARK-5089] From 610971ecfe858b1a48ce69b25614afe52bcbe77f Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 14 Sep 2015 21:58:52 -0700 Subject: [PATCH 0087/2089] [SPARK-10273] Add @since annotation to pyspark.mllib.feature Duplicated the since decorator from pyspark.sql into pyspark (also tweaked to handle functions without docstrings). Added since to methods + "versionadded::" to classes (derived from the git file history in pyspark). Author: noelsmith Closes #8633 from noel-smith/SPARK-10273-since-mllib-feature. --- python/pyspark/mllib/feature.py | 58 ++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f921e3ad1a314..7b077b058c3fd 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -84,11 +84,14 @@ class Normalizer(VectorTransformer): >>> nor2 = Normalizer(float("inf")) >>> nor2.transform(v) DenseVector([0.0, 0.5, 1.0]) + + .. versionadded:: 1.2.0 """ def __init__(self, p=2.0): assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) + @since('1.2.0') def transform(self, vector): """ Applies unit length normalization on a vector. @@ -133,7 +136,11 @@ class StandardScalerModel(JavaVectorTransformer): .. note:: Experimental Represents a StandardScaler model that can transform vectors. + + .. versionadded:: 1.2.0 """ + + @since('1.2.0') def transform(self, vector): """ Applies standardization transformation on a vector. @@ -149,6 +156,7 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + @since('1.4.0') def setWithMean(self, withMean): """ Setter of the boolean which decides @@ -157,6 +165,7 @@ def setWithMean(self, withMean): self.call("setWithMean", withMean) return self + @since('1.4.0') def setWithStd(self, withStd): """ Setter of the boolean which decides @@ -189,6 +198,8 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + + .. versionadded:: 1.2.0 """ def __init__(self, withMean=False, withStd=True): if not (withMean or withStd): @@ -196,6 +207,7 @@ def __init__(self, withMean=False, withStd=True): self.withMean = withMean self.withStd = withStd + @since('1.2.0') def fit(self, dataset): """ Computes the mean and variance and stores as a model to be used @@ -215,7 +227,11 @@ class ChiSqSelectorModel(JavaVectorTransformer): .. note:: Experimental Represents a Chi Squared selector model. + + .. versionadded:: 1.4.0 """ + + @since('1.4.0') def transform(self, vector): """ Applies transformation on a vector. @@ -245,10 +261,13 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) + + .. versionadded:: 1.4.0 """ def __init__(self, numTopFeatures): self.numTopFeatures = int(numTopFeatures) + @since('1.4.0') def fit(self, data): """ Returns a ChiSquared feature selector. @@ -265,6 +284,8 @@ def fit(self, data): class PCAModel(JavaVectorTransformer): """ Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + + .. versionadded:: 1.5.0 """ @@ -281,6 +302,8 @@ class PCA(object): 1.648... >>> pcArray[1] -4.013... + + .. versionadded:: 1.5.0 """ def __init__(self, k): """ @@ -288,6 +311,7 @@ def __init__(self, k): """ self.k = int(k) + @since('1.5.0') def fit(self, data): """ Computes a [[PCAModel]] that contains the principal components of the input vectors. @@ -312,14 +336,18 @@ class HashingTF(object): >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) + + .. versionadded:: 1.2.0 """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + @since('1.2.0') def indexOf(self, term): """ Returns the index of the input term. """ return hash(term) % self.numFeatures + @since('1.2.0') def transform(self, document): """ Transforms the input document (list of terms) to term frequency @@ -339,7 +367,10 @@ def transform(self, document): class IDFModel(JavaVectorTransformer): """ Represents an IDF model that can transform term frequency vectors. + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, x): """ Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -358,6 +389,7 @@ def transform(self, x): """ return JavaVectorTransformer.transform(self, x) + @since('1.4.0') def idf(self): """ Returns the current IDF vector. @@ -401,10 +433,13 @@ class IDF(object): DenseVector([0.0, 0.0, 1.3863, 0.863]) >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0))) SparseVector(4, {1: 0.0, 3: 0.5754}) + + .. versionadded:: 1.2.0 """ def __init__(self, minDocFreq=0): self.minDocFreq = minDocFreq + @since('1.2.0') def fit(self, dataset): """ Computes the inverse document frequency. @@ -420,7 +455,10 @@ def fit(self, dataset): class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, word): """ Transforms a word to its vector representation @@ -435,6 +473,7 @@ def transform(self, word): except Py4JJavaError: raise ValueError("%s not found" % word) + @since('1.2.0') def findSynonyms(self, word, num): """ Find synonyms of a word @@ -450,6 +489,7 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + @since('1.4.0') def getVectors(self): """ Returns a map of words to their vector representations. @@ -457,7 +497,11 @@ def getVectors(self): return self.call("getVectors") @classmethod + @since('1.5.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) return Word2VecModel(jmodel) @@ -507,6 +551,8 @@ class Word2Vec(object): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.2.0 """ def __init__(self): """ @@ -519,6 +565,7 @@ def __init__(self): self.seed = random.randint(0, sys.maxsize) self.minCount = 5 + @since('1.2.0') def setVectorSize(self, vectorSize): """ Sets vector size (default: 100). @@ -526,6 +573,7 @@ def setVectorSize(self, vectorSize): self.vectorSize = vectorSize return self + @since('1.2.0') def setLearningRate(self, learningRate): """ Sets initial learning rate (default: 0.025). @@ -533,6 +581,7 @@ def setLearningRate(self, learningRate): self.learningRate = learningRate return self + @since('1.2.0') def setNumPartitions(self, numPartitions): """ Sets number of partitions (default: 1). Use a small number for @@ -541,6 +590,7 @@ def setNumPartitions(self, numPartitions): self.numPartitions = numPartitions return self + @since('1.2.0') def setNumIterations(self, numIterations): """ Sets number of iterations (default: 1), which should be smaller @@ -549,6 +599,7 @@ def setNumIterations(self, numIterations): self.numIterations = numIterations return self + @since('1.2.0') def setSeed(self, seed): """ Sets random seed. @@ -556,6 +607,7 @@ def setSeed(self, seed): self.seed = seed return self + @since('1.4.0') def setMinCount(self, minCount): """ Sets minCount, the minimum number of times a token must appear @@ -564,6 +616,7 @@ def setMinCount(self, minCount): self.minCount = minCount return self + @since('1.2.0') def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -596,10 +649,13 @@ class ElementwiseProduct(VectorTransformer): >>> rdd = sc.parallelize([a, b]) >>> eprod.transform(rdd).collect() [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + + .. versionadded:: 1.5.0 """ def __init__(self, scalingVector): self.scalingVector = _convert_to_vector(scalingVector) + @since('1.5.0') def transform(self, vector): """ Computes the Hadamard product of the vector. From a2249359d5b0368318a714b292bb1d0dc70c0e27 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 14 Sep 2015 21:59:40 -0700 Subject: [PATCH 0088/2089] [SPARK-10275] [MLLIB] Add @since annotation to pyspark.mllib.random Author: Yu ISHIKAWA Closes #8666 from yu-iskw/SPARK-10275. --- python/pyspark/mllib/random.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 06fbc0eb6aef0..9c733b1332bc0 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -21,6 +21,7 @@ from functools import wraps +from pyspark import since from pyspark.mllib.common import callMLlibFunc @@ -39,9 +40,12 @@ class RandomRDDs(object): """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + + .. addedversion:: 1.1.0 """ @staticmethod + @since("1.1.0") def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -72,6 +76,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.1.0") def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -100,6 +105,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.3.0") def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the log normal @@ -132,6 +138,7 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): size, numPartitions, seed) @staticmethod + @since("1.1.0") def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson @@ -158,6 +165,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Exponential @@ -184,6 +192,7 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Gamma @@ -216,6 +225,7 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -241,6 +251,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -266,6 +277,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -300,6 +312,7 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed @staticmethod @toArray + @since("1.1.0") def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -330,6 +343,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -360,6 +374,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No @staticmethod @toArray + @since("1.3.0") def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn From 833be73314b85b390a9007ed6ed63dc47bbd9e4f Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 14 Sep 2015 23:40:29 -0700 Subject: [PATCH 0089/2089] Small fixes to docs Links work now properly + consistent use of *Spark standalone cluster* (Spark uppercase + lowercase the rest -- seems agreed in the other places in the docs). Author: Jacek Laskowski Closes #8759 from jaceklaskowski/docs-submitting-apps. --- docs/submitting-applications.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e525..7ea4d6f1a3f8f 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -65,8 +65,8 @@ For Python applications, simply pass a `.py` file in the place of ` Date: Mon, 14 Sep 2015 23:41:06 -0700 Subject: [PATCH 0090/2089] [SPARK-10598] [DOCS] Comments preceding toMessage method state: "The edge partition is encoded in the lower * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int.". References to bytes should be changed to bits. This contribution is my original work and I license the work to the Spark project under it's open source license. Author: Robin East Closes #8756 from insidedctm/master. --- .../org/apache/spark/graphx/impl/RoutingTablePartition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index eb3c997e0f3c0..4f1260a5a67b2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -34,7 +34,7 @@ object RoutingTablePartition { /** * A message from an edge partition to a vertex specifying the position in which the edge * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower - * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int. */ type RoutingTableMessage = (VertexId, Int) From 09b7e7c19897549a8622aec095f27b8b38a1a4d3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 00:54:20 -0700 Subject: [PATCH 0091/2089] Update version to 1.6.0-SNAPSHOT. Author: Reynold Xin Closes #8350 from rxin/1.6. --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- core/src/main/scala/org/apache/spark/package.scala | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-assembly/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt-assembly/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/kinesis-asl-assembly/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- pom.xml | 2 +- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 13 +++++++++++-- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- unsafe/pom.xml | 2 +- yarn/pom.xml | 2 +- 38 files changed, 49 insertions(+), 40 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d0d7201f004a2..a3a16c42a6214 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.5.0 +Version: 1.6.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc7..4b60ee00ffbe5 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a96..3baf8d47b4dc7 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index a46292c13bcc0..e31d90f608892 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2e..7515aad09db73 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.6.0-SNAPSHOT" } diff --git a/docs/_config.yml b/docs/_config.yml index c0e031a83ba9c..c59cc465ef89d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.5.0 +SPARK_VERSION: 1.6.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.6.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/examples/pom.xml b/examples/pom.xml index e6884b09dca94..f5ab2a7fdc098 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 561ed4babe5d0..dceedcf23ed5b 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0664cfb2021e1..d7c2ac474a18d 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 14f7daaf417e0..132062f94fb45 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 6f4e2a89e9af7..a9ed39ef8c9a0 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index ded863bd985e8..05abd9e2e6810 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index 8412600633734..89713a28ca6a8 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 69b309876a0db..05e6338a08b0a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 178ae8de13b57..244ad58ae9593 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 37bfd10d43663..171df8682c848 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3636a9037d43f..81794a8536318 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 51af3e6f2225f..61ba4787fbf90 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 521b53e230c4a..6dd8ff69c2943 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 478d0019a25f0..87a4f05a05961 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 853dea9a7795e..202fc19002d12 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 2fd768d8119c4..ed38e66aa2467 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a5db14407b4fc..22c0c6008ba37 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/network/common/pom.xml b/network/common/pom.xml index 4141fcb8267a5..1cc054a8936c5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 3d2edf9d94515..7a66c968041ce 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a99f7c4392d3d..e745180eace78 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index 421357e141572..6535994641145 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f16bf989f200b..519052620246f 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.4.0" + val previousSparkVersion = "1.5.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b8b6c8ffa375..87b141cd3b058 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -35,8 +35,17 @@ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("1.6") => Seq( - MimaBuild.excludeSparkPackage("network") - ) + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("network"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution") + ) ++ + MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), diff --git a/repl/pom.xml b/repl/pom.xml index a5a0f1fc2c857..5cf416a4a5448 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 75ab575dfde83..6cfd53e868f83 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 349007789f634..465aa3a3888c2 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3566c87dd248c..f7fe085f34d84 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index be1607476e254..ac67fe5f47be9 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 697895e72fe5b..5cc9001b0e9ab 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 298ee2348b58e..1e64f280e5bed 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 89475ee3cf5a1..066abe92e51c0 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index f6737695307a2..d8e4a4bbead81 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml From c35fdcb7e9c01271ce560dba4e0bd37569c8f5d1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 15 Sep 2015 09:58:49 -0700 Subject: [PATCH 0092/2089] [SPARK-10491] [MLLIB] move RowMatrix.dspr to BLAS jira: https://issues.apache.org/jira/browse/SPARK-10491 We implemented dspr with sparse vector support in `RowMatrix`. This method is also used in WeightedLeastSquares and other places. It would be useful to move it to `linalg.BLAS`. Let me know if new UT needed. Author: Yuhao Yang Closes #8663 from hhbyyh/movedspr. --- .../spark/ml/optim/WeightedLeastSquares.scala | 4 +- .../org/apache/spark/mllib/linalg/BLAS.scala | 44 +++++++++++++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 40 +---------------- .../apache/spark/mllib/linalg/BLASSuite.scala | 25 +++++++++++ 4 files changed, 72 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index a99e2ac4c6913..0ff8931b0bab4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares( if (fitIntercept) { // shift centers // A^T A - aBar aBar^T - RowMatrix.dspr(-1.0, aBar, aaValues) + BLAS.spr(-1.0, aBar, aaValues) // A^T b - bBar aBar BLAS.axpy(-bBar, aBar, abBar) } @@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares { bbSum += w * b * b BLAS.axpy(w, a, aSum) BLAS.axpy(w * b, a, abSum) - RowMatrix.dspr(w, a, aaSum.values) + BLAS.spr(w, a, aaSum) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9ee81eda8a8c0..df9f4ae145b88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -236,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging { _nativeBLAS } + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix in a [[DenseVector]](column major) + */ + def spr(alpha: Double, v: Vector, U: DenseVector): Unit = { + spr(alpha, v, U.values) + } + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix packed in an array (column major) + */ + def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + val n = v.size + v match { + case DenseVector(values) => + NativeBLAS.dspr("U", n, alpha, values, 1, U) + case SparseVector(size, indices, values) => + val nnz = indices.length + var colStartIdx = 0 + var prevCol = 0 + var col = 0 + var j = 0 + var i = 0 + var av = 0.0 + while (j < nnz) { + col = indices(j) + // Skip empty columns. + colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 + col = indices(j) + av = alpha * values(j) + i = 0 + while (i <= j) { + U(colStartIdx + indices(i)) += av * values(i) + i += 1 + } + j += 1 + prevCol = col + } + } + } + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 83779ac88989b..e55ef26858adb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ @@ -123,7 +122,7 @@ class RowMatrix @Since("1.0.0") ( // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { - RowMatrix.dspr(1.0, v, U.data) + BLAS.spr(1.0, v, U.data) U }, combOp = (U1, U2) => U1 += U2) @@ -673,43 +672,6 @@ class RowMatrix @Since("1.0.0") ( @Experimental object RowMatrix { - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param U the upper triangular part of the matrix packed in an array (column major) - */ - // TODO: SPARK-10491 - move this method to linalg.BLAS - private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { - // TODO: Find a better home (breeze?) for this method. - val n = v.size - v match { - case DenseVector(values) => - blas.dspr("U", n, alpha, values, 1, U) - case SparseVector(size, indices, values) => - val nnz = indices.length - var colStartIdx = 0 - var prevCol = 0 - var col = 0 - var j = 0 - var i = 0 - var av = 0.0 - while (j < nnz) { - col = indices(j) - // Skip empty columns. - colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 - col = indices(j) - av = alpha * values(j) - i = 0 - while (i <= j) { - U(colStartIdx + indices(i)) += av * values(i) - i += 1 - } - j += 1 - prevCol = col - } - } - } - /** * Fills a full square matrix from its upper triangular part. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 8db5c8424abe9..96e5ffef7a131 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite { } } + test("spr") { + // test dense vector + val alpha = 0.1 + val x = new DenseVector(Array(1.0, 2, 2.1, 4)) + val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6)) + + spr(alpha, x, U) + assert(U ~== expected absTol 1e-9) + + val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5)) + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + spr(alpha, x, matrix33) + } + } + + // test sparse vector + val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2)) + val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + spr(0.1, sv, U2) + val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4)) + assert(U2 ~== expectedSparse absTol 1e-15) + } + test("syr") { val dA = new DenseMatrix(4, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) From 8abef21dac1a6538c4e4e0140323b83d804d602b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 10:45:02 -0700 Subject: [PATCH 0093/2089] [SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py. This change does two things: - tag a few tests and adds the mechanism in the build to be able to disable those tags, both in maven and sbt, for both junit and scalatest suites. - add some logic to run-tests.py to disable some tags depending on what files have changed; that's used to disable expensive tests when a module hasn't explicitly been changed, to speed up testing for changes that don't directly affect those modules. Author: Marcelo Vanzin Closes #8437 from vanzin/test-tags. --- core/pom.xml | 10 ------- dev/run-tests.py | 19 ++++++++++++-- dev/sparktestsupport/modules.py | 24 ++++++++++++++++- external/flume/pom.xml | 10 ------- external/kafka/pom.xml | 10 ------- external/mqtt/pom.xml | 10 ------- external/twitter/pom.xml | 10 ------- external/zeromq/pom.xml | 10 ------- extras/java8-tests/pom.xml | 10 ------- extras/kinesis-asl/pom.xml | 5 ---- launcher/pom.xml | 5 ---- mllib/pom.xml | 10 ------- network/common/pom.xml | 10 ------- network/shuffle/pom.xml | 10 ------- pom.xml | 17 ++++++++++-- project/SparkBuild.scala | 13 ++++++++-- sql/core/pom.xml | 5 ---- .../execution/HiveCompatibilitySuite.scala | 2 ++ sql/hive/pom.xml | 5 ---- .../spark/sql/hive/ExtendedHiveTest.java | 26 +++++++++++++++++++ .../spark/sql/hive/client/VersionsSuite.scala | 2 ++ streaming/pom.xml | 10 ------- unsafe/pom.xml | 10 ------- .../spark/deploy/yarn/ExtendedYarnTest.java | 26 +++++++++++++++++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 1 + .../yarn/YarnShuffleIntegrationSuite.scala | 1 + 26 files changed, 124 insertions(+), 147 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java create mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index e31d90f608892..8a20181096223 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,16 +331,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index d8b22e1665e7b..1a816585187d9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,6 +118,14 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) +def determine_tags_to_exclude(changed_modules): + tags = [] + for m in modules.all_modules: + if m not in changed_modules: + tags += m.test_tags + return tags + + # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -369,6 +377,7 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] + profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -392,7 +401,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules): +def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -401,6 +410,10 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + + if excluded_tags: + test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] + if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -500,8 +513,10 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) + excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] + excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -541,7 +556,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules) + run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 346452f3174e4..65397f1f3e0bc 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - should_run_r_tests=False): + test_tags=(), should_run_r_tests=False): """ Define a new module. @@ -50,6 +50,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. + :param test_tags A set of tags that will be excluded when running unit tests if the module + is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -60,6 +62,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations + self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -85,6 +88,9 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", + ], + test_tags=[ + "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -398,6 +404,22 @@ def contains_file(self, filename): ) +yarn = Module( + name="yarn", + dependencies=[], + source_file_regexes=[ + "yarn/", + "network/yarn/", + ], + sbt_test_goals=[ + "yarn/test", + "network-yarn/test", + ], + test_tags=[ + "org.apache.spark.deploy.yarn.ExtendedYarnTest" + ] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 132062f94fb45..3154e36c21ef5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,16 +66,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 05abd9e2e6810..7d0d46dadc727 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,16 +86,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 05e6338a08b0a..913c47d33f488 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 244ad58ae9593..9137bf25ee8ae 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 171df8682c848..6fec4f0e8a0f9 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,16 +57,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 81794a8536318..dba3dda8a9562 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,16 +58,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 6dd8ff69c2943..760f183a2ef37 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,11 +74,6 @@ scalacheck_${scala.binary.version} test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index ed38e66aa2467..80696280a1d18 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,11 +42,6 @@ log4j test - - junit - junit - test - org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 22c0c6008ba37..5dedacb38874e 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,16 +94,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 1cc054a8936c5..9c12cca0df609 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,16 +64,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7a66c968041ce..e4f4c57b683c8 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,16 +78,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/pom.xml b/pom.xml index 6535994641145..2927d3e107563 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 0.9.2 ${java.home} + @@ -1952,6 +1964,7 @@ __not_used__ + ${test.exclude.tags} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 901cfa538d23e..d80d300f1c3b2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,11 +567,20 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", + // Exclude tags defined in a system property + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.exclude.tags").map { tags => + tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq + }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, + sys.props.get("test.exclude.tags").map { tags => + Seq("--exclude-categories=" + tags) + }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 465aa3a3888c2..fa6732db183d8 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,11 +73,6 @@ jackson-databind ${fasterxml.jackson.version} - - junit - junit - test - org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ab309e0a1d36b..ffc4c32794ca4 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ +@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ac67fe5f47be9..82cfeb2bb95d3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,11 +160,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java new file mode 100644 index 0000000000000..e2183183fb559 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index f0bb77092c0cf..888d1b7b45532 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -32,6 +33,7 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ +@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 5cc9001b0e9ab..1e6ee009ca6d5 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,21 +84,11 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.seleniumhq.selenium selenium-java test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 066abe92e51c0..4e8b9a84bb67f 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,16 +55,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java new file mode 100644 index 0000000000000..7a8f2fe979c1f --- /dev/null +++ b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b5a42fd6afd98..105c3090d489d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ +@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 8d9c9b3004eda..4700e2428df08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ +@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 7ca30b505c3561dc2832b463be4c6301a90380e4 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Tue, 15 Sep 2015 12:23:20 -0700 Subject: [PATCH 0094/2089] [PYSPARK] [MLLIB] [DOCS] Replaced addversion with versionadded in mllib.random Missed this when reviewing `pyspark.mllib.random` for SPARK-10275. Author: noelsmith Closes #8773 from noel-smith/mllib-random-versionadded-fix. --- python/pyspark/mllib/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 9c733b1332bc0..6a3c643b66417 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -41,7 +41,7 @@ class RandomRDDs(object): Generator methods for creating RDDs comprised of i.i.d samples from some distribution. - .. addedversion:: 1.1.0 + .. versionadded:: 1.1.0 """ @staticmethod From 0d9ab016755d5b56ce4043f229602169fd752e88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 15 Sep 2015 12:25:31 -0700 Subject: [PATCH 0095/2089] Closes #8738 Closes #8767 Closes #2491 Closes #6795 Closes #2096 Closes #7722 From 416003b26401894ec712e1a5291a92adfbc5af01 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 15 Sep 2015 20:42:33 +0100 Subject: [PATCH 0096/2089] [DOCS] Small fixes to Spark on Yarn doc * a follow-up to 16b6d18613e150c7038c613992d80a7828413e66 as `--num-executors` flag is not suppported. * links + formatting Author: Jacek Laskowski Closes #8762 from jaceklaskowski/docs-spark-on-yarn. --- docs/running-on-yarn.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 5159ef9e3394e..d1244323edfff 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -18,16 +18,16 @@ Spark application's configuration (driver, executors, and the AM when running in There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. + To launch a Spark application in `yarn-cluster` mode: - `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ --master yarn-cluster \ - --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ @@ -37,7 +37,7 @@ For example: The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. -To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. The following shows how you can run `spark-shell` in `yarn-client` mode: $ ./bin/spark-shell --master yarn-client @@ -54,8 +54,8 @@ In `yarn-cluster` mode, the driver runs on a different machine than the client, # Preparations -Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the Spark project website. +Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. +Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration From b42059d2efdf3322334694205a6d951bcc291644 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 13:03:38 -0700 Subject: [PATCH 0097/2089] Revert "[SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py." This reverts commit 8abef21dac1a6538c4e4e0140323b83d804d602b. --- core/pom.xml | 10 +++++++ dev/run-tests.py | 19 ++------------ dev/sparktestsupport/modules.py | 24 +---------------- external/flume/pom.xml | 10 +++++++ external/kafka/pom.xml | 10 +++++++ external/mqtt/pom.xml | 10 +++++++ external/twitter/pom.xml | 10 +++++++ external/zeromq/pom.xml | 10 +++++++ extras/java8-tests/pom.xml | 10 +++++++ extras/kinesis-asl/pom.xml | 5 ++++ launcher/pom.xml | 5 ++++ mllib/pom.xml | 10 +++++++ network/common/pom.xml | 10 +++++++ network/shuffle/pom.xml | 10 +++++++ pom.xml | 17 ++---------- project/SparkBuild.scala | 13 ++-------- sql/core/pom.xml | 5 ++++ .../execution/HiveCompatibilitySuite.scala | 2 -- sql/hive/pom.xml | 5 ++++ .../spark/sql/hive/ExtendedHiveTest.java | 26 ------------------- .../spark/sql/hive/client/VersionsSuite.scala | 2 -- streaming/pom.xml | 10 +++++++ unsafe/pom.xml | 10 +++++++ .../spark/deploy/yarn/ExtendedYarnTest.java | 26 ------------------- .../spark/deploy/yarn/YarnClusterSuite.scala | 1 - .../yarn/YarnShuffleIntegrationSuite.scala | 1 - 26 files changed, 147 insertions(+), 124 deletions(-) delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java delete mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index 8a20181096223..e31d90f608892 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,6 +331,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index 1a816585187d9..d8b22e1665e7b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,14 +118,6 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) -def determine_tags_to_exclude(changed_modules): - tags = [] - for m in modules.all_modules: - if m not in changed_modules: - tags += m.test_tags - return tags - - # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -377,7 +369,6 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] - profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -401,7 +392,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): +def run_scala_tests(build_tool, hadoop_version, test_modules): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -410,10 +401,6 @@ def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) - - if excluded_tags: - test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] - if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -513,10 +500,8 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) - excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] - excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -556,7 +541,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) + run_scala_tests(build_tool, hadoop_version, test_modules) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 65397f1f3e0bc..346452f3174e4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + should_run_r_tests=False): """ Define a new module. @@ -50,8 +50,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. - :param test_tags A set of tags that will be excluded when running unit tests if the module - is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -62,7 +60,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations - self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -88,9 +85,6 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", - ], - test_tags=[ - "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -404,22 +398,6 @@ def contains_file(self, filename): ) -yarn = Module( - name="yarn", - dependencies=[], - source_file_regexes=[ - "yarn/", - "network/yarn/", - ], - sbt_test_goals=[ - "yarn/test", - "network-yarn/test", - ], - test_tags=[ - "org.apache.spark.deploy.yarn.ExtendedYarnTest" - ] -) - # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 3154e36c21ef5..132062f94fb45 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,6 +66,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 7d0d46dadc727..05abd9e2e6810 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,6 +86,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 913c47d33f488..05e6338a08b0a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 9137bf25ee8ae..244ad58ae9593 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 6fec4f0e8a0f9..171df8682c848 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,6 +57,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index dba3dda8a9562..81794a8536318 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,6 +58,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 760f183a2ef37..6dd8ff69c2943 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,6 +74,11 @@ scalacheck_${scala.binary.version} test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index 80696280a1d18..ed38e66aa2467 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,6 +42,11 @@ log4j test + + junit + junit + test + org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 5dedacb38874e..22c0c6008ba37 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,6 +94,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 9c12cca0df609..1cc054a8936c5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,6 +64,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index e4f4c57b683c8..7a66c968041ce 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,6 +78,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/pom.xml b/pom.xml index 2927d3e107563..6535994641145 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 0.9.2 ${java.home} - @@ -1964,7 +1952,6 @@ __not_used__ - ${test.exclude.tags} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d80d300f1c3b2..901cfa538d23e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,20 +567,11 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", - // Exclude tags defined in a system property - testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, - sys.props.get("test.exclude.tags").map { tags => - tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq - }.getOrElse(Nil): _*), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, - sys.props.get("test.exclude.tags").map { tags => - Seq("--exclude-categories=" + tags) - }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fa6732db183d8..465aa3a3888c2 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,6 +73,11 @@ jackson-databind ${fasterxml.jackson.version} + + junit + junit + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ffc4c32794ca4..ab309e0a1d36b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,13 +24,11 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 82cfeb2bb95d3..ac67fe5f47be9 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,6 +160,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java deleted file mode 100644 index e2183183fb559..0000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 888d1b7b45532..f0bb77092c0cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -33,7 +32,6 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 1e6ee009ca6d5..5cc9001b0e9ab 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,11 +84,21 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.seleniumhq.selenium selenium-java test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 4e8b9a84bb67f..066abe92e51c0 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,6 +55,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java deleted file mode 100644 index 7a8f2fe979c1f..0000000000000 --- a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 105c3090d489d..b5a42fd6afd98 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 4700e2428df08..8d9c9b3004eda 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ -@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 841972e22c653ba58e9a65433fed203ff288f13a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Sep 2015 13:33:32 -0700 Subject: [PATCH 0098/2089] [SPARK-10437] [SQL] Support aggregation expressions in Order By JIRA: https://issues.apache.org/jira/browse/SPARK-10437 If an expression in `SortOrder` is a resolved one, such as `count(1)`, the corresponding rule in `Analyzer` to make it work in order by will not be applied. Author: Liang-Chi Hsieh Closes #8599 from viirya/orderby-agg. --- .../sql/catalyst/analysis/Analyzer.scala | 14 +++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 591747b45c376..02f34cbf58ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -561,7 +561,7 @@ class Analyzer( } case sort @ Sort(sortOrder, global, aggregate: Aggregate) - if aggregate.resolved && !sort.resolved => + if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { @@ -598,9 +598,15 @@ class Analyzer( } } - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + // Since we don't rely on sort.resolved as the stop condition for this rule, + // we need to check this and prevent applying this rule multiple times + if (sortOrder == evaluatedOrderings) { + sort + } else { + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 962b100b532c9..f9981356f364f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1562,6 +1562,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { From 31a229aa739b6d05ec6d91b820fcca79b6b7d6fe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Sep 2015 13:36:52 -0700 Subject: [PATCH 0099/2089] [SPARK-10475] [SQL] improve column prunning for Project on Sort Sometimes we can't push down the whole `Project` though `Sort`, but we still have a chance to push down part of it. Author: Wenchen Fan Closes #8644 from cloud-fan/column-prune. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 +++++++++++++++---- .../optimizer/ColumnPruningSuite.scala | 11 +++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f4caec7451a2..648a65e7c0eb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -228,10 +228,21 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // Push down project if possible when the child is sort - case p @ Project(projectList, s @ Sort(_, _, grandChild)) - if s.references.subsetOf(p.outputSet) => - s.copy(child = Project(projectList, grandChild)) + // Push down project if possible when the child is sort. + case p @ Project(projectList, s @ Sort(_, _, grandChild)) => + if (s.references.subsetOf(p.outputSet)) { + s.copy(child = Project(projectList, grandChild)) + } else { + val neededReferences = s.references ++ p.references + if (neededReferences == grandChild.outputSet) { + // No column we can prune, return the original plan. + p + } else { + // Do not use neededReferences.toSeq directly, should respect grandChild's output order. + val newProjectList = grandChild.output.filter(neededReferences.contains) + p.copy(child = s.copy(child = Project(newProjectList, grandChild))) + } + } // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index dbebcb86809de..4a1e7ceaf394b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -80,5 +80,16 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Column pruning for Project on Sort") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + + val query = input.orderBy('b.asc).select('a).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze + + comparePlans(optimized, correctAnswer) + } + // todo: add more tests for column pruning } From a172d96f00e559a2c0ec0fa44ea67bce2f89cd28 Mon Sep 17 00:00:00 2001 From: Gayathri Murali Date: Tue, 15 Sep 2015 14:24:18 -0700 Subject: [PATCH 0100/2089] [SPARK-10380] - Confusing Pyspark sql examples --- python/pyspark/sql/column.py | 20 ++++---------------- python/pyspark/sql/dataframe.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 1d7dfd1b0a3f5..eb7ac2ed24bd6 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -342,33 +342,21 @@ def cast(self, dataType): else: raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) - + @ignore_unicode_prefix @since(1.3) def astype(self, dataType): """ Convert the column into type ``dataType``. - >>> df.select(df.age.astype("string").alias('ages')).collect() + >>> df.select(df.age.astype("string").alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] >>> df.select(df.age.astype(StringType()).alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] - + (Note : astype is alias for cast) """ - if isinstance(dataType, basestring): - jc = self._jc.astype(dataType) - elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.astype(jdt) - else: - raise TypeError("unexpected type: %s" % type(dataType)) - return Column(jc) - - # astype = cast + return self.cast(dataType) - @since(1.3) def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e7d138a6eb143..2921f857ff77d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -835,6 +835,32 @@ def groupBy(self, *cols): from pyspark.sql.group import GroupedData return GroupedData(jgd, self.sql_ctx) + @ignore_unicode_prefix + @since(1.3) + def groupby(self, *cols): + """Groups the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + :func:`groupBy` is an alias for :func:`groupby`. + + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + + >>> df.groupby().avg().collect() + [Row(avg(age)=3.5)] + >>> df.groupby('name').agg({'age': 'mean'}).collect() + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] + >>> df.groupby(df.name).avg().collect() + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] + >>> df.groupby(['name', df.age]).count().collect() + [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] + """ + jgd = self._jdf.groupby(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) def rollup(self, *cols): """ @@ -1309,8 +1335,6 @@ def toPandas(self): # Pandas compatibility ########################################################################################## - groupby = groupBy - # drop_duplicates = dropDuplicates # Having SchemaRDD for backward compatibility (for docs) From be52faa7c72fb4b95829f09a7dc5eb5dccd03524 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 15 Sep 2015 15:46:47 -0700 Subject: [PATCH 0101/2089] [SPARK-7685] [ML] Apply weights to different samples in Logistic Regression In fraud detection dataset, almost all the samples are negative while only couple of them are positive. This type of high imbalanced data will bias the models toward negative resulting poor performance. In python-scikit, they provide a correction allowing users to Over-/undersample the samples of each class according to the given weights. In auto mode, selects weights inversely proportional to class frequencies in the training set. This can be done in a more efficient way by multiplying the weights into loss and gradient instead of doing actual over/undersampling in the training dataset which is very expensive. http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html On the other hand, some of the training data maybe more important like the training samples from tenure users while the training samples from new users maybe less important. We should be able to provide another "weight: Double" information in the LabeledPoint to weight them differently in the learning algorithm. Author: DB Tsai Author: DB Tsai Closes #7884 from dbtsai/SPARK-7685. --- .../classification/LogisticRegression.scala | 199 +++++++++++------- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 12 +- .../stat/MultivariateOnlineSummarizer.scala | 75 ++++--- .../LogisticRegressionSuite.scala | 102 ++++++++- .../MultivariateOnlineSummarizerSuite.scala | 27 +++ project/MimaExcludes.scala | 10 +- 7 files changed, 303 insertions(+), 128 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a460262b87e43..bd96e8d000ff2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,12 +29,12 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.storage.StorageLevel /** @@ -42,7 +42,7 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization with HasThreshold { + with HasStandardization with HasWeightCol with HasThreshold { /** * Set threshold in binary classification, in range [0, 1]. @@ -146,6 +146,17 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[classification] case class Instance(label: Double, weight: Double, features: Vector) + /** * :: Experimental :: * Logistic regression. @@ -218,31 +229,42 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, labelSummarizer) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer), - (label: Double, features: Vector)) => - (summarizer.add(features), labelSummarizer.add(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, - classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer, - classSummarizer2: MultiClassSummarizer)) => - (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2)) - }) + val (summarizer, labelSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer), + c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp) + } val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -295,7 +317,7 @@ class LogisticRegression(override val uid: String) new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialWeightsWithIntercept = + val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) if ($(fitIntercept)) { @@ -312,14 +334,14 @@ class LogisticRegression(override val uid: String) b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialWeightsWithIntercept.toArray(numFeatures) - = math.log(histogram(1).toDouble / histogram(0).toDouble) + initialCoefficientsWithIntercept.toArray(numFeatures) + = math.log(histogram(1) / histogram(0)) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeightsWithIntercept.toBreeze.toDenseVector) + initialCoefficientsWithIntercept.toBreeze.toDenseVector) - val (weights, intercept, objectiveHistory) = { + val (coefficients, intercept, objectiveHistory) = { /* Note that in Logistic Regression, the objective history (loss + regularization) is log-likelihood which is invariance under feature standardization. As a result, @@ -339,28 +361,29 @@ class LogisticRegression(override val uid: String) } /* - The weights are trained in the scaled space; we're converting them back to + The coefficients are trained in the scaled space; we're converting them back to the original space. Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ - val rawWeights = state.x.toArray.clone() + val rawCoefficients = state.x.toArray.clone() var i = 0 while (i < numFeatures) { - rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } if ($(fitIntercept)) { - (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result()) + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) } else { - (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result()) + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) } } if (handlePersistence) instances.unpersist() - val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) val logRegSummary = new BinaryLogisticRegressionTrainingSummary( model.transform(dataset), $(probabilityCol), @@ -501,22 +524,29 @@ class LogisticRegressionModel private[ml] ( * corresponding joint dataset. */ private[classification] class MultiClassSummarizer extends Serializable { - private val distinctMap = new mutable.HashMap[Int, Long] + // The first element of value in distinctMap is the actually number of instances, + // and the second element of value is sum of the weights. + private val distinctMap = new mutable.HashMap[Int, (Long, Double)] private var totalInvalidCnt: Long = 0L /** * Add a new label into this MultilabelSummarizer, and update the distinct map. * @param label The label for this data point. + * @param weight The weight of this instances. * @return This MultilabelSummarizer */ - def add(label: Double): this.type = { + def add(label: Double, weight: Double = 1.0): this.type = { + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + + if (weight == 0.0) return this + if (label - label.toInt != 0.0 || label < 0) { totalInvalidCnt += 1 this } else { - val counts: Long = distinctMap.getOrElse(label.toInt, 0L) - distinctMap.put(label.toInt, counts + 1) + val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0)) + distinctMap.put(label.toInt, (counts + 1L, weightSum + weight)) this } } @@ -537,8 +567,8 @@ private[classification] class MultiClassSummarizer extends Serializable { } smallMap.distinctMap.foreach { case (key, value) => - val counts = largeMap.distinctMap.getOrElse(key, 0L) - largeMap.distinctMap.put(key, counts + value) + val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0)) + largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2)) } largeMap.totalInvalidCnt += smallMap.totalInvalidCnt largeMap @@ -550,13 +580,13 @@ private[classification] class MultiClassSummarizer extends Serializable { /** @return The number of distinct labels in the input dataset. */ def numClasses: Int = distinctMap.keySet.max + 1 - /** @return The counts of each label in the input dataset. */ - def histogram: Array[Long] = { - val result = Array.ofDim[Long](numClasses) + /** @return The weightSum of each label in the input dataset. */ + def histogram: Array[Double] = { + val result = Array.ofDim[Double](numClasses) var i = 0 val len = result.length while (i < len) { - result(i) = distinctMap.getOrElse(i, 0L) + result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2 i += 1 } result @@ -565,6 +595,8 @@ private[classification] class MultiClassSummarizer extends Serializable { /** * Abstraction for multinomial Logistic Regression Training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. */ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { @@ -584,10 +616,10 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Dataframe outputted by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each sample. */ + /** Field in "predictions" which gives the the true label of each instance. */ def labelCol: String } @@ -597,8 +629,8 @@ sealed trait LogisticRegressionSummary extends Serializable { * Logistic regression training results. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample as a vector. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance as a vector. + * @param labelCol field in "predictions" which gives the true label of each instance. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -617,8 +649,8 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * Binary Logistic regression results for a given model. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @@ -687,14 +719,14 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used - * in binary classification for samples in sparse or dense vector in a online fashion. + * in binary classification for instances in sparse or dense vector in a online fashion. * * Note that multinomial logistic loss is not supported yet! * * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. @@ -702,25 +734,25 @@ class BinaryLogisticRegressionSummary private[classification] ( * @param featuresMean The mean values of the features. */ private class LogisticAggregator( - weights: Vector, + coefficients: Vector, numClasses: Int, fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { - private var totalCnt: Long = 0L + private var weightSum = 0.0 private var lossSum = 0.0 - private val weightsArray = weights match { + private val coefficientsArray = coefficients match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") } - private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length + private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - private val gradientSumArray = Array.ofDim[Double](weightsArray.length) + private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) /** * Add a new training data to this LogisticAggregator, and update the loss and gradient @@ -729,13 +761,17 @@ private class LogisticAggregator( * @param label The label for this data point. * @param data The features for one data point in dense/sparse vector format to be added * into this aggregator. + * @param weight The weight for over-/undersamples each of training instance. Default is one. * @return This LogisticAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + + def add(label: Double, data: Vector, weight: Double = 1.0): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${data.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val localWeightsArray = weightsArray + if (weight == 0.0) return this + + val localCoefficientsArray = coefficientsArray val localGradientSumArray = gradientSumArray numClasses match { @@ -745,13 +781,13 @@ private class LogisticAggregator( var sum = 0.0 data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localWeightsArray(index) * (value / featuresStd(index)) + sum += localCoefficientsArray(index) * (value / featuresStd(index)) } } - sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 } + sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 } } - val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label) data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { @@ -765,15 +801,15 @@ private class LogisticAggregator( if (label > 0) { // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - lossSum += MLUtils.log1pExp(margin) + lossSum += weight * MLUtils.log1pExp(margin) } else { - lossSum += MLUtils.log1pExp(margin) - margin + lossSum += weight * (MLUtils.log1pExp(margin) - margin) } case _ => new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " + "binary classification for now.") } - totalCnt += 1 + weightSum += weight this } @@ -789,8 +825,8 @@ private class LogisticAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { - totalCnt += other.totalCnt + if (other.weightSum != 0.0) { + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -805,13 +841,17 @@ private class LogisticAggregator( this } - def count: Long = totalCnt - - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -823,7 +863,7 @@ private class LogisticAggregator( * It's used in Breeze's convex optimization routines. */ private class LogisticCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], numClasses: Int, fitIntercept: Boolean, standardization: Boolean, @@ -831,22 +871,23 @@ private class LogisticCostFun( featuresMean: Array[Double], regParamL2: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length - val w = Vectors.fromBreeze(weights) + val w = Vectors.fromBreeze(coefficients) - val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept, - featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + val logisticAggregator = { + val seqOp = (c: LogisticAggregator, instance: Instance) => + c.add(instance.label, instance.features, instance.weight) + val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) + + data.treeAggregate( + new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean) + )(seqOp, combOp) + } val totalGradientArray = logisticAggregator.gradient.toArray - // regVal is the sum of weight squares excluding intercept for L2 regularization. + // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e9e99ed1db40e..8049d51fee5ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -42,7 +42,7 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + " probabilities. Note: Not all models output well-calibrated probability estimates!" + - " These probabilities should be treated as confidences, not precise probabilities.", + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), @@ -65,10 +65,10 @@ private[shared] object SharedParamsCodeGen { "options may be added later.", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + - " before fitting the model.", Some("true")), + " before fitting the model", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + - " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 30092170863ad..aff47fc326c4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. * @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities") setDefault(probabilityCol, "probability") @@ -270,10 +270,10 @@ private[ml] trait HasHandleInvalid extends Params { private[ml] trait HasStandardization extends Params { /** - * Param for whether to standardize the training features before fitting the model.. + * Param for whether to standardize the training features before fitting the model. * @group param */ - final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.") + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model") setDefault(standardization, true) @@ -304,10 +304,10 @@ private[ml] trait HasSeed extends Params { private[ml] trait HasElasticNetParam extends Params { /** - * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. * @group param */ - final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getElasticNetParam: Double = $(elasticNetParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 51b713e263e0c..201333c3690df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -23,16 +23,19 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, - * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector + * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector * format in a online fashion. * * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of * the corresponding joint dataset. * - * A numerically stable algorithm is implemented to compute sample mean and variance: + * A numerically stable algorithm is implemented to compute the mean and variance of instances: * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * + * For weighted instances, the unbiased estimation of variance is defined by the reliability + * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]. */ @Since("1.1.0") @DeveloperApi @@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currM2: Array[Double] = _ private var currL1: Array[Double] = _ private var totalCnt: Long = 0 + private var weightSum: Double = 0.0 + private var weightSquareSum: Double = 0.0 private var nnz: Array[Double] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -55,10 +60,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. */ @Since("1.1.0") - def add(sample: Vector): this.type = { + def add(sample: Vector): this.type = add(sample, 1.0) + + private[spark] def add(instance: Vector, weight: Double): this.type = { + require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0") + if (weight == 0.0) return this + if (n == 0) { - require(sample.size > 0, s"Vector should have dimension larger than zero.") - n = sample.size + require(instance.size > 0, s"Vector should have dimension larger than zero.") + n = instance.size currMean = Array.ofDim[Double](n) currM2n = Array.ofDim[Double](n) @@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = Array.fill[Double](n)(Double.MaxValue) } - require(n == sample.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.size}.") + require(n == instance.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${instance.size}.") val localCurrMean = currMean val localCurrM2n = currM2n @@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localNnz = nnz val localCurrMax = currMax val localCurrMin = currMin - sample.foreachActive { (index, value) => + instance.foreachActive { (index, value) => if (value != 0.0) { if (localCurrMax(index) < value) { localCurrMax(index) = value @@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val prevMean = localCurrMean(index) val diff = value - prevMean - localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) - localCurrM2n(index) += (value - localCurrMean(index)) * diff - localCurrM2(index) += value * value - localCurrL1(index) += math.abs(value) + localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight) + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + localCurrM2(index) += weight * value * value + localCurrL1(index) += weight * math.abs(value) - localNnz(index) += 1.0 + localNnz(index) += weight } } + weightSum += weight + weightSquareSum += weight * weight totalCnt += 1 this } @@ -112,10 +124,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.weightSum != 0.0 && other.weightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt + weightSum += other.weightSum + weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { val thisNnz = nnz(i) @@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S nnz(i) = totalNnz i += 1 } - } else if (totalCnt == 0 && other.totalCnt != 0) { + } else if (weightSum == 0.0 && other.weightSum != 0.0) { this.n = other.n this.currMean = other.currMean.clone() this.currM2n = other.currM2n.clone() this.currM2 = other.currM2.clone() this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt + this.weightSum = other.weightSum + this.weightSquareSum = other.weightSquareSum this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -158,28 +174,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def mean: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (nnz(i) / totalCnt) + realMean(i) = currMean(i) * (nnz(i) / weightSum) i += 1 } Vectors.dense(realMean) } /** - * Sample variance of each dimension. + * Unbiased estimate of sample variance of each dimension. * */ @Since("1.1.0") override def variance: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realVariance = Array.ofDim[Double](n) - val denominator = totalCnt - 1.0 + val denominator = weightSum - (weightSquareSum / weightSum) // Sample variance is computed, if the denominator is less than 0, the variance is just 0. if (denominator > 0.0) { @@ -187,9 +203,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = - currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * + (weightSum - nnz(i)) / weightSum) / denominator i += 1 } } @@ -209,7 +224,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def numNonzeros: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } @@ -220,11 +235,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def max: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -236,11 +251,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def min: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) @@ -252,7 +267,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL2: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) @@ -271,7 +286,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL1: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(currL1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cce39f382f738..f5219f9f574be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.classification +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) - sqlContext.createDataFrame( - generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) + sqlContext.createDataFrame(sc.parallelize(testData, 4)) } } @@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol === "prediction") assert(lr.getRawPredictionCol === "rawPrediction") assert(lr.getProbabilityCol === "probability") + assert(lr.getWeightCol === "") assert(lr.getFitIntercept) assert(lr.getStandardization) val model = lr.fit(dataset) @@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("MultiClassSummarizer") { val summarizer1 = (new MultiClassSummarizer) .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0) - assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2)) + assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1)) assert(summarizer1.countInvalid === 0) assert(summarizer1.numClasses === 7) val summarizer2 = (new MultiClassSummarizer) .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0) - assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1)) assert(summarizer2.countInvalid === 0) assert(summarizer2.numClasses === 6) val summarizer3 = (new MultiClassSummarizer) .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0) - assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2)) + assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3)) assert(summarizer3.countInvalid === 3) assert(summarizer3.numClasses === 5) val summarizer4 = (new MultiClassSummarizer) .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0) - assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer4.histogram === Array[Double](0, 1, 1, 1)) assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) - assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1)) assert(summarizerA.countInvalid === 0) assert(summarizerA.numClasses === 7) // large map merges small one val summarizerB = summarizer3.merge(summarizer4) assert(summarizerB.hashCode() === summarizer3.hashCode()) - assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2)) + assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3)) assert(summarizerB.countInvalid === 5) assert(summarizerB.numClasses === 5) } + test("MultiClassSummarizer with weighted samples") { + val summarizer1 = (new MultiClassSummarizer) + .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1) + assert(Vectors.dense(summarizer1.histogram) ~== + Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10) + assert(summarizer1.countInvalid === 0) + assert(summarizer1.numClasses === 7) + + val summarizer2 = (new MultiClassSummarizer) + .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0) + assert(Vectors.dense(summarizer2.histogram) ~== + Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10) + assert(summarizer2.countInvalid === 0) + assert(summarizer2.numClasses === 6) + + val summarizer = summarizer1.merge(summarizer2) + assert(Vectors.dense(summarizer.histogram) ~== + Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10) + assert(summarizer.countInvalid === 0) + assert(summarizer.numClasses === 7) + } + test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) @@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) + val interceptTheory = math.log(histogram(1) / histogram(0)) val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) @@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .forall(x => x(0) >= x(1))) } + + test("binary logistic regression with weighted samples") { + val (dataset, weightedDataset) = { + val nPoints = 1000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + + // Let's over-sample the positive samples twice. + val data1 = testData.flatMap { case labeledPoint: LabeledPoint => + if (labeledPoint.label == 1.0) { + Iterator(labeledPoint, labeledPoint) + } else { + Iterator(labeledPoint) + } + } + + val rnd = new Random(8392) + val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) => + if (rnd.nextGaussian() > 0.0) { + if (label == 1.0) { + Iterator( + Instance(label, 1.2, features), + Instance(label, 0.8, features), + Instance(0.0, 0.0, features)) + } else { + Iterator( + Instance(label, 0.3, features), + Instance(1.0, 0.0, features), + Instance(label, 0.1, features), + Instance(label, 0.6, features)) + } + } else { + if (label == 1.0) { + Iterator(Instance(label, 2.0, features)) + } else { + Iterator(Instance(label, 1.0, features)) + } + } + } + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LogisticRegression).setFitIntercept(true) + .setRegParam(0.0).setStandardization(true) + val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") + .setRegParam(0.0).setStandardization(true) + val model1a0 = trainer1a.fit(dataset) + val model1a1 = trainer1a.fit(weightedDataset) + val model1b = trainer1b.fit(weightedDataset) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 07efde4f5e6dc..b6d41db69be0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { s0.merge(s1) assert(s0.mean(0) ~== 1.0 absTol 1e-14) } + + test("merging summarizer with weighted samples") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1) + .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge( + (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15) + .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05)) + + assert(summarizer.count === 4) + + // The following values are hand calculated using the formula: + // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + // which defines the reliability weight used for computing the unbiased estimation of variance + // for weighted instances. + assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44)) + absTol 1E-10, "mean mismatch") + assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) + absTol 1E-8, "variance mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4)) + absTol 1E-10, "numNonzeros mismatch") + assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") + assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch") + assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192) + absTol 1E-8, "normL2 mismatch") + assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 87b141cd3b058..46026c1e90ea0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,15 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.count") + ) case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From b6e998634e05db0bb6267173e7b28f885c808c16 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 16:45:47 -0700 Subject: [PATCH 0102/2089] [SPARK-10548] [SPARK-10563] [SQL] Fix concurrent SQL executions *Note: this is for master branch only.* The fix for branch-1.5 is at #8721. The query execution ID is currently passed from a thread to its children, which is not the intended behavior. This led to `IllegalArgumentException: spark.sql.execution.id is already set` when running queries in parallel, e.g.: ``` (1 to 100).par.foreach { _ => sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } ``` The cause is `SparkContext`'s local properties are inherited by default. This patch adds a way to exclude keys we don't want to be inherited, and makes SQL go through that code path. Author: Andrew Or Closes #8710 from andrewor14/concurrent-sql-executions. --- .../scala/org/apache/spark/SparkContext.scala | 9 +- .../org/apache/spark/ThreadingSuite.scala | 65 +++++------ .../sql/execution/SQLExecutionSuite.scala | 101 ++++++++++++++++++ 3 files changed, 132 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dee6091ce3caf..a2f34eafa2c38 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -347,8 +348,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index a96a4ce201c21..54c131cdae367 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,7 +147,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") @@ -178,7 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -207,58 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - var throwable: Option[Throwable] = None - + test("mutation in parent local property does not affect child (SPARK-10563)") { sc = new SparkContext("local", "test") - sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { try { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task - try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } - } catch { - case s: SparkException => // ignored so that we don't print noise in test logs - } + threadTestValue = sc.getLocalProperty("test") } catch { case t: Throwable => throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - throwable.foreach { t => throw t } - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. + */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 0000000000000..63639681ef80a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} From a63cdc769f511e98b38c3318bcc732c9a6c76c22 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 16:53:27 -0700 Subject: [PATCH 0103/2089] [SPARK-10612] [SQL] Add prepare to LocalNode. The idea is that we should separate the function call that does memory reservation (i.e. prepare) from the function call that consumes the input (e.g. open()), so all operators can be a chance to reserve memory before they are all consumed. Author: Reynold Xin Closes #8761 from rxin/SPARK-10612. --- .../org/apache/spark/sql/execution/local/LocalNode.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 9840080e16953..569cff565c092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -45,6 +45,14 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging def output: Seq[Attribute] + /** + * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume + * any input data. + * + * Implementations of this must also call the `prepare()` function of its children. + */ + def prepare(): Unit = children.foreach(_.prepare()) + /** * Initializes the iterator state. Must be called before calling `next()`. * From 99ecfa5945aedaa71765ecf5cce59964ae52eebe Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 15 Sep 2015 17:01:10 -0700 Subject: [PATCH 0104/2089] [SPARK-10575] [SPARK CORE] Wrapped RDD.takeSample with Scope Remove return statements in RDD.takeSample and wrap it withScope Author: vinodkc Author: vinodkc Author: Vinod K C Closes #8730 from vinodkc/fix_takesample_return. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7dd2bc5d7cd72..a56e542242d5f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -469,50 +469,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } - - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** From 38700ea40cb1dd0805cc926a9e629f93c99527ad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 15 Sep 2015 17:11:21 -0700 Subject: [PATCH 0105/2089] [SPARK-10381] Fix mixup of taskAttemptNumber & attemptId in OutputCommitCoordinator When speculative execution is enabled, consider a scenario where the authorized committer of a particular output partition fails during the OutputCommitter.commitTask() call. In this case, the OutputCommitCoordinator is supposed to release that committer's exclusive lock on committing once that task fails. However, due to a unit mismatch (we used task attempt number in one place and task attempt id in another) the lock will not be released, causing Spark to go into an infinite retry loop. This bug was masked by the fact that the OutputCommitCoordinator does not have enough end-to-end tests (the current tests use many mocks). Other factors contributing to this bug are the fact that we have many similarly-named identifiers that have different semantics but the same data types (e.g. attemptNumber and taskAttemptId, with inconsistent variable naming which makes them difficult to distinguish). This patch adds a regression test and fixes this bug by always using task attempt numbers throughout this code. Author: Josh Rosen Closes #8544 from JoshRosen/SPARK-10381. --- .../org/apache/spark/SparkHadoopWriter.scala | 3 +- .../org/apache/spark/TaskEndReason.scala | 7 +- .../executor/CommitDeniedException.scala | 4 +- .../spark/mapred/SparkHadoopMapRedUtil.scala | 20 ++---- .../apache/spark/scheduler/DAGScheduler.scala | 7 +- .../scheduler/OutputCommitCoordinator.scala | 48 +++++++------ .../org/apache/spark/scheduler/TaskInfo.scala | 7 +- .../status/api/v1/AllStagesResource.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- ...putCommitCoordinatorIntegrationSuite.scala | 68 +++++++++++++++++++ .../OutputCommitCoordinatorSuite.scala | 24 ++++--- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- project/MimaExcludes.scala | 36 +++++++++- .../datasources/WriterContainer.scala | 3 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 17 files changed, 174 insertions(+), 69 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ae5926dd534a6..ac6eaab20d8d2 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -104,8 +104,7 @@ class SparkHadoopWriter(jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 2ae878b3e6087..7137246bc34f2 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -193,9 +193,12 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. */ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" /** * If a task failed because its attempt to commit was denied, do not count this failure * towards failing the stage. This is intended to prevent spurious stage failures in cases diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da1..7d84889a2def0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f405b732e4725..f7298e8d5c62c 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -91,8 +91,7 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit = { + splitId: Int): Unit = { val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) @@ -122,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -132,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -143,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b4f90e8347894..3c9a66e504403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1128,8 +1128,11 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 5d926377ce86b..add0dedc03f44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private type CommittersByStageMap = + mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Called by DAGScheduler private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() } // Called by DAGScheduler @@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") + if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") authorizedCommitters.remove(partition) } } @@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => authorizedCommitters.get(partition) match { case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") false case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true } case None => - logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced77700..f113c2b1b8433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b3..24a0b5220695c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -127,7 +127,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2b71f55b7bb4f..712782d27b3cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -621,7 +621,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attempt + val attempt = taskInfo.attemptNumber val svgTag = if (totalExecutionTime == 0) { @@ -967,7 +967,7 @@ private[ui] class TaskDataSource( new TaskTableRowData( info.index, info.taskId, - info.attempt, + info.attemptNumber, info.speculative, info.status, info.taskLocality.toString, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 24f78744ad74c..99614a786bd93 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -266,7 +266,7 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ - ("Attempt" -> taskInfo.attempt) ~ + ("Attempt" -> taskInfo.attemptNumber) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala new file mode 100644 index 0000000000000..1ae5b030f0832 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.{Span, Seconds} + +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.util.Utils + +/** + * Integration tests for the OutputCommitCoordinator. + * + * See also: [[OutputCommitCoordinatorSuite]] for unit tests that use mocks. + */ +class OutputCommitCoordinatorIntegrationSuite + extends SparkFunSuite + with LocalSparkContext + with Timeouts { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .set("master", "local[2,4]") + .set("spark.speculation", "true") + .set("spark.hadoop.mapred.output.committer.class", + classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) + sc = new SparkContext("local[2, 4]", "test", conf) + } + + test("exception thrown in OutputCommitter.commitTask()") { + // Regression test for SPARK-10381 + failAfter(Span(60, Seconds)) { + val tempDir = Utils.createTempDir() + try { + sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + val ctx = TaskContext.get() + if (ctx.attemptNumber < 1) { + throw new java.io.FileNotFoundException("Intentional exception") + } + super.commitTask(context) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e5ecd4b7c2610..6d08d7c5b7d2a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -63,6 +63,9 @@ import scala.language.postfixOps * was not in SparkHadoopWriter, the tests would still pass because only one of the * increments would be captured even though the commit in both tasks was executed * erroneously. + * + * See also: [[OutputCommitCoordinatorIntegrationSuite]] for integration tests that do + * not use mocks. */ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { @@ -164,27 +167,28 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 - val partition: Long = 2 - val authorizedCommitter: Long = 3 - val nonAuthorizedCommitter: Long = 100 + val partition: Int = 2 + val authorizedCommitter: Int = 3 + val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage) - assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + + assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) // New tasks should still not be able to commit because the authorized committer has not failed assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) // A new task should now be allowed to become the authorized committer assert( - outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) // There can only be one authorized committer assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 47e548ef0d442..143c1b901df11 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -499,7 +499,7 @@ class JsonProtocolSuite extends SparkFunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) - assert(info1.attempt === info2.attempt) + assert(info1.attemptNumber === info2.attemptNumber) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46026c1e90ea0..1c96b0958586f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,7 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticCostFun.this"), @@ -53,6 +53,23 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.count") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.5") => Seq( @@ -213,6 +230,23 @@ object MimaExcludes { // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.4") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index f8ef674ed29c1..cfd64c1d9eb34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -198,8 +198,7 @@ private[sql] abstract class BaseWriterContainer( } def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId) } def abortTask(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 2bbb41ca777b7..7a46c69a056b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -54,9 +54,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { details = "" ) - private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( taskId = taskId, - attempt = attempt, + attemptNumber = attemptNumber, // The following fields are not used in tests index = 0, launchTime = 0, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 4ca8042d22367..c8d6b718045a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -121,7 +121,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { From 35a19f3357d2ec017cfefb90f1018403e9617de4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 17:24:32 -0700 Subject: [PATCH 0106/2089] [SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and reduces a lot of code duplicated from `SparkPlanTest`. This also fixes an additional issue [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624) where the output of `TakeOrderedAndProjectNode` is not actually ordered. Author: Andrew Or Closes #8764 from andrewor14/sql-local-tests-cleanup. --- .../spark/sql/execution/local/LocalNode.scala | 8 +- .../sql/execution/local/SampleNode.scala | 16 +- .../local/TakeOrderedAndProjectNode.scala | 2 +- .../spark/sql/execution/SparkPlanTest.scala | 2 +- .../spark/sql/execution/local/DummyNode.scala | 68 ++++ .../sql/execution/local/ExpandNodeSuite.scala | 54 ++- .../sql/execution/local/FilterNodeSuite.scala | 34 +- .../execution/local/HashJoinNodeSuite.scala | 141 ++++---- .../execution/local/IntersectNodeSuite.scala | 24 +- .../sql/execution/local/LimitNodeSuite.scala | 28 +- .../sql/execution/local/LocalNodeSuite.scala | 73 +--- .../sql/execution/local/LocalNodeTest.scala | 165 ++------- .../local/NestedLoopJoinNodeSuite.scala | 316 ++++++------------ .../execution/local/ProjectNodeSuite.scala | 39 ++- .../sql/execution/local/SampleNodeSuite.scala | 35 +- .../TakeOrderedAndProjectNodeSuite.scala | 50 ++- .../sql/execution/local/UnionNodeSuite.scala | 49 +-- 17 files changed, 468 insertions(+), 636 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 569cff565c092..f96b62a67a254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. */ -abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { +abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { protected val codegenEnabled: Boolean = conf.codegenEnabled protected val unsafeEnabled: Boolean = conf.unsafeEnabled - lazy val schema: StructType = StructType.fromAttributes(output) - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - def output: Seq[Attribute] - /** * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume * any input data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index abf3df1c0c2af..793700803f216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.local -import java.util.Random - import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + /** * Sample the dataset. * @@ -51,18 +50,15 @@ case class SampleNode( override def open(): Unit = { child.open() - val (sampler, _seed) = if (withReplacement) { - val random = new Random(seed) + val sampler = + if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, // requiring us to copy the row, which is more expensive than the random number generator. - (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), - // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result - // of DataFrame - random.nextLong()) + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) } else { - (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + new BernoulliCellSampler[InternalRow](lowerBound, upperBound) } - sampler.setSeed(_seed) + sampler.setSeed(seed) iterator = sampler.sample(child.asIterator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index 53f1dcc65d8cf..ae672fbca8d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode( } // Close it eagerly since we don't need it. child.close() - iterator = queue.iterator + iterator = queue.toArray.sorted(ord).iterator } override def next(): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index de45ae4635fb7..3d218f01c9ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -238,7 +238,7 @@ object SparkPlanTest { outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 0000000000000..efc3227dd60d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala index cfa7f3f6dcb97..bbd94d8da2d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -17,35 +17,33 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.catalyst.dsl.expressions._ + + class ExpandNodeSuite extends LocalNodeTest { - import testImplicits._ - - test("expand") { - val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") - checkAnswer( - input, - node => - ExpandNode(conf, Seq( - Seq( - input.col("key") + input.col("value"), input.col("key") - input.col("value") - ).map(_.expr), - Seq( - input.col("key") * input.col("value"), input.col("key") / input.col("value") - ).map(_.expr) - ), node.output, node), - Seq( - (2, 0), - (1, 1), - (4, 0), - (4, 1), - (6, 0), - (9, 1), - (8, 0), - (16, 1), - (10, 0), - (25, 1) - ).toDF().collect() - ) + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index a12670e347c25..4eadce646d379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -17,25 +17,29 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val condition = (testData.col("key") % 2) === 0 - checkAnswer( - testData, - node => FilterNode(conf, condition.expr, node), - testData.filter(condition).collect() - ) +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val condition = (emptyTestData.col("key") % 2) === 0 - checkAnswer( - emptyTestData, - node => FilterNode(conf, condition.expr, node), - emptyTestData.filter(condition).collect() - ) + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 78d891351f4a9..5c1bdb088eeed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -18,99 +18,80 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.execution.joins +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class HashJoinNodeSuite extends LocalNodeTest { - import testImplicits._ + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { - test(s"$suiteName: inner join with one match per row") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(upperCaseData.col("N").expr), - Seq(lowerCaseData.col("n").expr), - joins.BuildLeft, - node1, - node2) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N").collect() - ) + /** + * Test inner hash join with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions(new HashJoinNode( + conf, Seq('id1), Seq('id2), buildSide, node1, node2)) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) } + assert(actualOutput === expectedOutput) } - test(s"$suiteName: inner join with multiple matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 1).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - x.join(y).where($"x.a" === $"y.a").collect() - ) - } + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) } - test(s"$suiteName: inner join, no matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 2).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - Nil - ) - } + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) } - test(s"$suiteName: big inner join, 4 matches per row") { - withSQLConf(confPairs: _*) { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as("x") - val bigDataY = bigData.as("y") + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } - checkAnswer2( - bigDataX, - bigDataY, - wrapForUnsafe( - (node1, node2) => - HashJoinNode( - conf, - Seq(bigDataX.col("key").expr), - Seq(bigDataY.col("key").expr), - joins.BuildLeft, - node1, - node2) - ), - bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) - } + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) } } - joinSuite( - "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala index 7deaa375fcfc2..c0ad2021b204a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.execution.local -class IntersectNodeSuite extends LocalNodeTest { - import testImplicits._ +class IntersectNodeSuite extends LocalNodeTest { test("basic") { - val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") - - checkAnswer2( - input1, - input2, - (node1, node2) => IntersectNode(conf, node1, node2), - input1.intersect(input2).collect() - ) + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 3b183902007e4..fb790636a3689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { +class LimitNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer( - testData, - node => LimitNode(conf, 10, node), - testData.limit(10).collect() - ) + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer( - emptyTestData, - node => LimitNode(conf, 10, node), - emptyTestData.limit(10).collect() - ) + testLimit() } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index b89fa46f8b3b4..0d1ed99eec6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -17,28 +17,24 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.IntegerType -class LocalNodeSuite extends SparkFunSuite { - private val data = (1 to 100).toArray +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) - data.foreach { i => + data.foreach { case (k, v) => assert(node.next()) // fetch should be idempotent val fetched = node.fetch() assert(node.fetch() === fetched) assert(node.fetch() === fetched) - assert(node.fetch().numFields === 1) - assert(node.fetch().getInt(0) === i) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) } assert(!node.next()) node.close() @@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite { } test("asIterator") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) val iter = node.asIterator node.open() - data.foreach { i => + data.foreach { case (k, v) => // hasNext should be idempotent assert(iter.hasNext) assert(iter.hasNext) val item = iter.next() - assert(item.numFields === 1) - assert(item.getInt(0) === i) + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) } intercept[NoSuchElementException] { iter.next() @@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite { } test("collect") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) - assert(collected.forall(_.size === 1)) - assert(collected.map(_.getInt(0)) === data) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) node.close() } } - -/** - * A dummy [[LocalNode]] that just returns one row per integer in the input. - */ -private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { - private var index = Int.MinValue - - def this(input: Array[Int]) { - this(new SQLConf, input) - } - - def isOpen: Boolean = { - index != Int.MinValue - } - - override def output: Seq[Attribute] = { - Seq(AttributeReference("something", IntegerType)()) - } - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - val values = Array(input(index).asInstanceOf[Any]) - new GenericInternalRow(values) - } - - override def close(): Unit = { - index = Int.MinValue - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 86dd28064cc6a..098050bcd2236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -17,147 +17,54 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLConf} -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.{IntegerType, StringType} -class LocalNodeTest extends SparkFunSuite with SharedSQLContext { - def conf: SQLConf = sqlContext.conf +class LocalNodeTest extends SparkFunSuite { - protected def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer( - input: DataFrame, - nodeFunction: LocalNode => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - input :: Nil, - nodes => nodeFunction(nodes.head), - expectedAnswer, - sortAnswers) - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer2( - left: DataFrame, - right: DataFrame, - nodeFunction: (LocalNode, LocalNode) => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - left :: right :: Nil, - nodes => nodeFunction(nodes(0), nodes(1)), - expectedAnswer, - sortAnswers) - } + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows */ - protected def doCheckAnswer( - input: Seq[DataFrame], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - LocalNodeTest.checkAnswer( - input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { - case Some(errorMessage) => fail(errorMessage) - case None => + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) } } - protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { - new SeqScanNode( - conf, - df.queryExecution.sparkPlan.output, - df.queryExecution.toRdd.map(_.copy()).collect()) - } - -} - -/** - * Helper methods for writing tests of individual local physical operators. - */ -object LocalNodeTest { - /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. */ - def checkAnswer( - input: Seq[SeqScanNode], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { - - val outputNode = nodeFunction(input) - - val outputResult: Seq[Row] = try { - outputNode.collect() - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing local plan: - | $outputNode - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => - s""" - | Results do not match for local plan: - | $outputNode - | $errorMessage - """.stripMargin + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } } } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index b1ef26ba82f16..40299d9d5ee37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -18,222 +18,128 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class NestedLoopJoinNodeSuite extends LocalNodeTest { - import testImplicits._ - - private def joinSuite( - suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { - test(s"$suiteName: left outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("n") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - upperCaseData.col("N") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect()) + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) } } + } - test(s"$suiteName: right outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + /** + * Test outer nested loop joins with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) } - test(s"$suiteName: full outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) - } + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. + */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") } } - joinSuite( - "general-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "general-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "tungsten-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") - joinSuite( - "tungsten-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index 38e0a230c46d8..02ecb23d34b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -17,28 +17,33 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} -class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val output = testData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - testData, - node => ProjectNode(conf, columns, node), - testData.select("value", "key").collect() - ) +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case (id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val output = emptyTestData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - emptyTestData, - node => ProjectNode(conf, columns, node), - emptyTestData.select("value", "key").collect() - ) + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index 87a7da453999c..a3e83bbd51457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -17,21 +17,32 @@ package org.apache.spark.sql.execution.local -class SampleNodeSuite extends LocalNodeTest { +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + - import testImplicits._ +class SampleNodeSuite extends LocalNodeTest { private def testSample(withReplacement: Boolean): Unit = { - test(s"withReplacement: $withReplacement") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), - input.sample(withReplacement, 0.3, seed).collect() - ) + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index ff28b24eeff14..42ebc7bfcaadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -17,38 +17,34 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} +import scala.util.Random -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder - import testImplicits._ - private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - sortOrder - } +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { - val testCaseName = if (desc) "desc" else "asc" - test(testCaseName) { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val sortColumn = if (desc) input.col("key").desc else input.col("key") - checkAnswer( - input, - node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), - input.sort(sortColumn).limit(5).collect() - ) + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) } } - testTakeOrderedAndProjectNode(desc = false) - testTakeOrderedAndProjectNode(desc = true) + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index eedd7320900f9..666b0235c061d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -17,36 +17,39 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { +class UnionNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer2( - testData, - testData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - testData.unionAll(testData).collect() - ) + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer2( - emptyTestData, - emptyTestData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - emptyTestData.unionAll(emptyTestData).collect() - ) + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) } - test("complicated union") { - val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, - emptyTestData, emptyTestData, testData, emptyTestData) - doCheckAnswer( - dfs, - nodes => UnionNode(conf, nodes), - dfs.reduce(_.unionAll(_)).collect() - ) + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i => (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) } } From 64c29afcb787d9f176a197c25314295108ba0471 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 15 Sep 2015 19:41:38 -0700 Subject: [PATCH 0107/2089] [SPARK-9078] [SQL] Allow jdbc dialects to override the query used to check the table. Current implementation uses query with a LIMIT clause to find if table already exists. This syntax works only in some database systems. This patch changes the default query to the one that is likely to work on most databases, and adds a new method to the JdbcDialect abstract class to allow dialects to override the default query. I looked at using the JDBC meta data calls, it turns out there is no common way to find the current schema, catalog..etc. There is a new method Connection.getSchema() , but that is available only starting jdk1.7 , and existing jdbc drivers may not have implemented it. Other option was to use jdbc escape syntax clause for LIMIT, not sure on how well this supported in all the databases also. After looking at all the jdbc metadata options my conclusion was most common way is to use the simple select query with 'where 1 =0' , and allow dialects to customize as needed Author: sureshthalamati Closes #8676 from sureshthalamati/table_exists_spark-9078. --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 9 ++++++--- .../apache/spark/sql/jdbc/JdbcDialects.scala | 20 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 14 +++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b2a66dd417b4c..745bb4ec9cf1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { val conn = JdbcUtils.createConnection(url, props) try { - var tableExists = JdbcUtils.tableExists(conn, table) + var tableExists = JdbcUtils.tableExists(conn, url, table) if (mode == SaveMode.Ignore && tableExists) { return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 26788b2a4fd69..f89d55b20e212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -42,10 +42,13 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, table: String): Boolean = { + def tableExists(conn: Connection, url: String, table: String): Boolean = { + val dialect = JdbcDialects.get(url) + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + // SQL database systems using JDBC meta data calls, considering "table" could also include + // the database name. Query used to find table exists can be overriden by the dialects. + Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c6d05c9b83b98..68ebaaca6c53d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect { override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ed710689cc670..5ab9381de4d66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + } } From b921fe4dc0442aa133ab7d55fba24bc798d59aa2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 15 Sep 2015 19:43:26 -0700 Subject: [PATCH 0108/2089] [SPARK-10595] [ML] [MLLIB] [DOCS] Various ML guide cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Various ML guide cleanups. * ml-guide.md: Make it easier to access the algorithm-specific guides. * LDA user guide: EM often begins with useless topics, but running longer generally improves them dramatically. E.g., 10 iterations on a Wikipedia dataset produces useless topics, but 50 iterations produces very meaningful topics. * mllib-feature-extraction.html#elementwiseproduct: “w†parameter should be “scalingVec†* Clean up Binarizer user guide a little. * Document in Pipeline that users should not put an instance into the Pipeline in more than 1 place. * spark.ml Word2Vec user guide: clean up grammar/writing * Chi Sq Feature Selector docs: Improve text in doc. CC: mengxr feynmanliang Author: Joseph K. Bradley Closes #8752 from jkbradley/mlguide-fixes-1.5. --- docs/ml-features.md | 34 +++++++++++++++++--- docs/ml-guide.md | 31 ++++++++++++------- docs/mllib-clustering.md | 4 +++ docs/mllib-feature-extraction.md | 53 +++++++++++++++++++++----------- docs/mllib-guide.md | 4 +-- 5 files changed, 91 insertions(+), 35 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a414c21b5c280..b70da4ac63845 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -123,12 +123,21 @@ for features_label in rescaledData.select("features", "label").take(3): ## Word2Vec -`Word2Vec` is an `Estimator` which takes sequences of words that represents documents and trains a `Word2VecModel`. The model is a `Map(String, Vector)` essentially, which maps each word to an unique fix-sized vector. The `Word2VecModel` transforms each documents into a vector using the average of all words in the document, which aims to other computations of documents such as similarity calculation consequencely. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a +`Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` +transforms each document into a vector using the average of all words in the document; this vector +can then be used for as features for prediction, document similarity calculations, etc. +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +details. -Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of them is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. +In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
+ +Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) +for more details on the API. + {% highlight scala %} import org.apache.spark.ml.feature.Word2Vec @@ -152,6 +161,10 @@ result.select("result").take(3).foreach(println)
+ +Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) +for more details on the API. + {% highlight java %} import java.util.Arrays; @@ -192,6 +205,10 @@ for (Row r: result.select("result").take(3)) {
+ +Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) +for more details on the API. + {% highlight python %} from pyspark.ml.feature import Word2Vec @@ -621,12 +638,15 @@ for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): ## Binarizer -Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. +Binarization is the process of thresholding numerical features to binary (0/1) features. -A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0.
+ +Refer to the [Binarizer API doc](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details. + {% highlight scala %} import org.apache.spark.ml.feature.Binarizer import org.apache.spark.sql.DataFrame @@ -650,6 +670,9 @@ binarizedFeatures.collect().foreach(println)
+ +Refer to the [Binarizer API doc](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details. + {% highlight java %} import java.util.Arrays; @@ -687,6 +710,9 @@ for (Row r : binarizedFeatures.collect()) {
+ +Refer to the [Binarizer API doc](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details. + {% highlight python %} from pyspark.ml.feature import Binarizer diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 78c93a95c7807..c5d7f990021f1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -32,7 +32,21 @@ See the [algorithm guides](#algorithm-guides) section below for guides on sub-pa * This will become a table of contents (this text will be scraped). {:toc} -# Main concepts +# Algorithm guides + +We provide several algorithm guides specific to the Pipelines API. +Several of these algorithms, such as certain feature transformers, are not in the `spark.mllib` API. +Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., random forests +provide class probabilities, and linear models provide model summaries. + +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) + + +# Main concepts in Pipelines Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. @@ -166,6 +180,11 @@ compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. @@ -184,16 +203,6 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm guides - -There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -* [Feature extraction, transformation, and selection](ml-features.html) -* [Decision Trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) - # Code examples This section gives code examples illustrating the functionality discussed above. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 3fb35d3c50b06..c2711cf82deb4 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -507,6 +507,10 @@ must also be $> 1.0$. Providing `Vector(-1)` results in default behavior $> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. * `maxIterations`: The maximum number of EM iterations. +*Note*: It is important to do enough iterations. In early iterations, EM often has useless topics, +but those topics improve dramatically after more iterations. Using at least 20 and possibly +50-100 iterations is often reasonable, depending on your dataset. + `EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only the inferred topics but also the full training corpus and topic distributions for each document in the training corpus. A diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index de86aba2ae627..7e417ed5f37a9 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -380,35 +380,43 @@ data2 = labels.zip(normalizer2.transform(features))
-## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. +## ChiSqSelector -### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) tries to identify relevant +features for use in model construction. It reduces the size of the feature space, which can improve +both speed and statistical learning behavior. -#### Model Fitting +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements +Chi-Squared feature selection. It operates on labeled data with categorical features. +`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, +and then filters (selects) the top features which the class label depends on the most. +This is akin to yielding the features with the most predictive power. -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the -following parameters in the constructor: +The number of features to select can be tuned using a held-out validation set. -* `numTopFeatures` number of top features that the selector will select (filter). +### Model Fitting -We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in -`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that +the selector will select. -This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) -which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes +an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then +returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +The `ChiSqSelectorModel` can be applied either to a `Vector` to produce a reduced `Vector`, or to an `RDD[Vector]` to produce a reduced `RDD[Vector]`. Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). -#### Example +### Example The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
-
+
+ +Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) +for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors @@ -434,7 +442,11 @@ val filteredData = discretizedData.map { lp => {% endhighlight %}
-
+
+ +Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) +for details on the API. + {% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -486,7 +498,12 @@ sc.stop(); ## ElementwiseProduct -ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. +`ElementwiseProduct` multiplies each input vector by a provided "weight" vector, using element-wise +multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This +represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) +between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. +Qu8T948*1# +Denoting the `scalingVec` as "`w`," this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ @@ -506,7 +523,7 @@ v_N [`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: -* `w`: the transforming vector. +* `scalingVec`: the transforming vector. `ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 257f7cc7603fa..91e50ccfecec4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -13,9 +13,9 @@ primitives and higher-level pipeline APIs. It divides into two packages: -* [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API +* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). -* [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API +* [`spark.ml`](ml-guide.html) provides higher-level API built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. From 95b6a8103fb527f501ca26b1d6e3a5859970a1e2 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 15 Sep 2015 23:25:51 -0700 Subject: [PATCH 0109/2089] [SPARK-10516] [ MLLIB] Added values property in DenseVector Author: Vinod K C Closes #8682 from vinodkc/fix_SPARK-10516. --- python/pyspark/mllib/linalg/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 380f86e9b44f8..4829acb16ed8a 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -399,6 +399,10 @@ def squared_distance(self, other): def toArray(self): return self.array + @property + def values(self): + return self.array + def __getitem__(self, item): return self.array[item] From 1894653edce718e874d1ddc9ba442bce43cbc082 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Wed, 16 Sep 2015 10:47:30 +0100 Subject: [PATCH 0110/2089] [SPARK-10511] [BUILD] Reset git repository before packaging source distro The calculation of Spark version is downloading Scala and Zinc in the build directory which is inflating the size of the source distribution. Reseting the repo before packaging the source distribution fix this issue. Author: Luciano Resende Closes #8774 from lresende/spark-10511. --- dev/create-release/release-build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index d0b3a54dde1dc..9dac43ce54425 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -99,6 +99,7 @@ fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" USER_HOST="$ASF_USERNAME@people.apache.org" +git clean -d -f -x rm .gitignore rm -rf .git cd .. From d9b7f3e4dbceb91ea4d1a1fed3ab847335f8588b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Sep 2015 04:34:14 -0700 Subject: [PATCH 0111/2089] [SPARK-10276] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.recommendation Author: Yu ISHIKAWA Closes #8677 from yu-iskw/SPARK-10276. --- python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 506ca2151cce7..95047b5b7b4b7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,7 +18,7 @@ import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable @@ -36,6 +36,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): (1, 2, 5.0) >>> (r[0], r[1], r[2]) (1, 2, 5.0) + + .. versionadded:: 1.2.0 """ def __reduce__(self): @@ -111,13 +113,17 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, user, product): """ Predicts rating for the given user and product. """ return self._java_model.predict(int(user), int(product)) + @since("0.9.0") def predictAll(self, user_product): """ Returns a list of predicted ratings for input user and product pairs. @@ -128,6 +134,7 @@ def predictAll(self, user_product): user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) + @since("1.2.0") def userFeatures(self): """ Returns a paired RDD, where the first element is the user and the @@ -135,6 +142,7 @@ def userFeatures(self): """ return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.2.0") def productFeatures(self): """ Returns a paired RDD, where the first element is the product and the @@ -142,6 +150,7 @@ def productFeatures(self): """ return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.4.0") def recommendUsers(self, product, num): """ Recommends the top "num" number of users for a given product and returns a list @@ -149,6 +158,7 @@ def recommendUsers(self, product, num): """ return list(self.call("recommendUsers", product, num)) + @since("1.4.0") def recommendProducts(self, user, num): """ Recommends the top "num" number of products for a given user and returns a list @@ -157,17 +167,25 @@ def recommendProducts(self, user, num): return list(self.call("recommendProducts", user, num)) @property + @since("1.4.0") def rank(self): + """Rank for the features in this model""" return self.call("rank") @classmethod + @since("1.3.1") def load(cls, sc, path): + """Load a model from the given path""" model = cls._load_java(sc, path) wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) return MatrixFactorizationModel(wrapper) class ALS(object): + """Alternating Least Squares matrix factorization + + .. versionadded:: 0.9.0 + """ @classmethod def _prepare(cls, ratings): @@ -188,15 +206,31 @@ def _prepare(cls, ratings): return ratings @classmethod + @since("0.9.0") def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of ratings given by users to some products, + in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + product of two lower-rank matrices of a given rank (number of features). To solve for these + features, we run a given number of iterations of ALS. This is done using a level of + parallelism given by `blocks`. + """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod + @since("0.9.0") def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of 'implicit preferences' given by users + to some products, in the form of (userID, productID, preference) pairs. We approximate the + ratings matrix as the product of two lower-rank matrices of a given rank (number of + features). To solve for these features, we run a given number of iterations of ALS. + This is done using a level of parallelism given by `blocks`. + """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) From 5dbaf3d3911bbfa003bc75459aaad66b4f6e0c67 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 16 Sep 2015 19:19:23 +0100 Subject: [PATCH 0112/2089] [SPARK-10589] [WEBUI] Add defense against external site framing Set `X-Frame-Options: SAMEORIGIN` to protect against frame-related vulnerability Author: Sean Owen Closes #8745 from srowen/SPARK-10589. --- .../spark/deploy/worker/ui/WorkerWebUI.scala | 7 ++++--- .../org/apache/spark/metrics/MetricsSystem.scala | 2 +- .../spark/metrics/sink/MetricsServlet.scala | 6 +++--- .../scala/org/apache/spark/ui/JettyUtils.scala | 16 ++++++++++++++-- .../main/scala/org/apache/spark/ui/WebUI.scala | 4 ++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 709a27233598c..1a0598e50dcf1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -49,7 +48,9 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4517f465ebd3b..48afe3ae3511f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -88,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a33074..4193e1d21d3c1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 779c0ba083596..b796a44fe01ac 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. + // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,6 +78,7 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) // scalastyle:on println @@ -97,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 61449847add3d..81a121fd441bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -76,9 +76,9 @@ private[spark] abstract class WebUI( def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath) + (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath) attachHandler(renderHandler) attachHandler(renderJsonHandler) pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) From 896edb51ab7a88bbb31259e565311a9be6f2ca6d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 16 Sep 2015 13:20:39 -0700 Subject: [PATCH 0113/2089] [SPARK-10050] [SPARKR] Support collecting data of MapType in DataFrame. 1. Support collecting data of MapType from DataFrame. 2. Support data of MapType in createDataFrame. Author: Sun Rui Closes #8711 from sun-rui/SPARK-10050. --- R/pkg/R/SQLContext.R | 5 +- R/pkg/R/deserialize.R | 14 +++++ R/pkg/R/schema.R | 34 ++++++++--- R/pkg/inst/tests/test_sparkSQL.R | 56 +++++++++++++++---- .../scala/org/apache/spark/api/r/SerDe.scala | 31 ++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 6 ++ 6 files changed, 123 insertions(+), 23 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 4ac057d0f2d83..1c58fd96d750a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -41,10 +41,7 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) names <- names(x) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d1858ec227b56..ce88d0b071b72 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -50,6 +50,7 @@ readTypedObject <- function(con, type) { "t" = readTime(con), "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -121,6 +122,19 @@ readList <- function(con) { } } +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 62d4f73878d29..8df1563f8ebc0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -131,13 +131,33 @@ checkType <- function(type) { if (type %in% primtiveTypes) { return() } else { - m <- regexec("^array<(.*)>$", type) - matchedStrings <- regmatches(type, m) - if (length(matchedStrings[[1]]) >= 2) { - elemType <- matchedStrings[[1]][2] - checkType(elemType) - return() - } + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.*),(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }) } stop(paste("Unsupported type for Dataframe:", type)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 98d4402d368e1..e159a69584274 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,7 +57,7 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) -test_that("infer types", { +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") @@ -72,9 +72,9 @@ test_that("infer types", { checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { @@ -242,7 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and map", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -253,21 +253,35 @@ test_that("create DataFrame with nested array and struct", { # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) + # ArrayType and MapType + e <- new.env() + assign("n", 3L, envir = e) - # ArrayType only for now - l <- list(as.list(1:10), list("a", "b")) - df <- createDataFrame(sqlContext, list(l), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + l <- list(as.list(1:10), list("a", "b"), e) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"))) expect_equal(count(df), 1) ldf <- collect(df) - expect_equal(names(ldf), c("a", "b")) + expect_equal(names(ldf), c("a", "b", "c")) expect_equal(ldf[1, 1][[1]], l[[1]]) expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) }) +# For test map type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("Collect DataFrame with complex types", { - # only ArrayType now - # TODO: tests for StructType and MapType after they are supported + # ArrayType df <- jsonFile(sqlContext, complexTypeJsonPath) ldf <- collect(df) @@ -277,6 +291,24 @@ test_that("Collect DataFrame with complex types", { expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # TODO: tests for StructType after it is supported }) test_that("jsonFile() on a local file returns a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c92bb7a1c73c..0c78613e406e1 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -209,11 +209,23 @@ private[spark] object SerDe { case "array" => dos.writeByte('a') // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") } } + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + def writeObject(dos: DataOutputStream, obj: Object): Unit = { if (obj == null) { writeType(dos, "void") @@ -306,6 +318,25 @@ private[spark] object SerDe { writeInt(dos, v.length) v.foreach(elem => writeObject(dos, elem)) + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case _ => writeType(dos, "jobj") writeJObj(dos, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d4b834adb6e39..f45d119c8cfdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -64,6 +64,12 @@ private[r] object SQLUtils { case r"\Aarray<(.*)${elemType}>\Z" => { org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) } + case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => { + if (keyType != "string" && keyType != "character") { + throw new IllegalArgumentException("Key type of a map must be string or character") + } + org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } From d39f15ea2b8bed5342d2f8e3c1936f915c470783 Mon Sep 17 00:00:00 2001 From: Kevin Cox Date: Wed, 16 Sep 2015 15:30:17 -0700 Subject: [PATCH 0114/2089] [SPARK-9794] [SQL] Fix datetime parsing in SparkSQL. This fixes https://issues.apache.org/jira/browse/SPARK-9794 by using a real ISO8601 parser. (courtesy of the xml component of the standard java library) cc: angelini Author: Kevin Cox Closes #8396 from kevincox/kevincox-sql-time-parsing. --- .../sql/catalyst/util/DateTimeUtils.scala | 27 ++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 687ca000d12bb..400c4327be1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} +import javax.xml.bind.DatatypeConverter; import org.apache.spark.unsafe.types.UTF8String @@ -109,30 +110,22 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { + var indexOfGMT = s.indexOf("GMT"); + if (indexOfGMT != -1) { + // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) + val s0 = s.substring(0, indexOfGMT) + val s1 = s.substring(indexOfGMT + 3) + // Mapped to 2000-01-01T00:00+01:00 + stringToTime(s0 + s1) + } else if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { Timestamp.valueOf(s) } else { Date.valueOf(s) } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) + DatatypeConverter.parseDateTime(s).getTime() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 6b9a11f0ff743..46335941b62d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -136,6 +136,38 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } + test("string to time") { + // Tests with UTC. + var c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-00:00") === c.getTime()) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30T10:00:00Z") === c.getTime()) + + // Tests with set time zone. + c.setTimeZone(TimeZone.getTimeZone("GMT-04:00")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00-04:00") === c.getTime()) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-04:00") === c.getTime()) + + // Tests with local time zone. + c.setTimeZone(TimeZone.getDefault()) + c.set(Calendar.MILLISECOND, 0) + + c.set(2000, 11, 30, 0, 0, 0) + assert(stringToTime("2000-12-30") === new Date(c.getTimeInMillis())) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30 10:00:00") === new Timestamp(c.getTimeInMillis())) + } + test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) From 49c649fa0b6affed108dbae85373b4b7247b338c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Sep 2015 15:32:01 -0700 Subject: [PATCH 0115/2089] Tiny style fix for d39f15ea2b8bed5342d2f8e3c1936f915c470783. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 400c4327be1c7..781ed1688a327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} -import javax.xml.bind.DatatypeConverter; +import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String From 69c9830d288d5b8d7f0abe7c8a65a4c966580a49 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 17 Sep 2015 00:48:57 -0700 Subject: [PATCH 0116/2089] [MINOR] [CORE] Fixes minor variable name typo Author: Cheng Lian Closes #8784 from liancheng/typo-fix. --- .../apache/spark/serializer/GenericAvroSerializerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index bc9f3708ed69d..87f25e7245e1f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -76,9 +76,9 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) - val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) - assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } } From c633ed3260140f1288f326acc4d7a10dcd2e27d5 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:43:59 -0700 Subject: [PATCH 0117/2089] [SPARK-10284] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.tuning Author: Yu ISHIKAWA Closes #8694 from yu-iskw/SPARK-10284. --- python/pyspark/ml/tuning.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index cae778869e9c5..ab5621f45c72c 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,6 +18,7 @@ import itertools import numpy as np +from pyspark import since from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model from pyspark.ml.util import keyword_only @@ -47,11 +48,14 @@ class ParamGridBuilder(object): True >>> all([m in expected for m in output]) True + + .. versionadded:: 1.4.0 """ def __init__(self): self._param_grid = {} + @since("1.4.0") def addGrid(self, param, values): """ Sets the given parameters in this grid to fixed values. @@ -60,6 +64,7 @@ def addGrid(self, param, values): return self + @since("1.4.0") def baseOn(self, *args): """ Sets the given parameters in this grid to fixed values. @@ -73,6 +78,7 @@ def baseOn(self, *args): return self + @since("1.4.0") def build(self): """ Builds and returns all combinations of parameters specified @@ -104,6 +110,8 @@ class CrossValidator(Estimator): >>> cvModel = cv.fit(dataset) >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -142,6 +150,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self._set(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): @@ -150,6 +159,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. @@ -157,12 +167,14 @@ def setEstimator(self, value): self._paramMap[self.estimator] = value return self + @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) + @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. @@ -170,12 +182,14 @@ def setEstimatorParamMaps(self, value): self._paramMap[self.estimatorParamMaps] = value return self + @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) + @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. @@ -183,12 +197,14 @@ def setEvaluator(self, value): self._paramMap[self.evaluator] = value return self + @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @since("1.4.0") def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. @@ -196,6 +212,7 @@ def setNumFolds(self, value): self._paramMap[self.numFolds] = value return self + @since("1.4.0") def getNumFolds(self): """ Gets the value of numFolds or its default value. @@ -231,7 +248,15 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ if extra is None: extra = dict() newCV = Params.copy(self, extra) @@ -246,6 +271,8 @@ def copy(self, extra=None): class CrossValidatorModel(Model): """ Model from k-fold cross validation. + + .. versionadded:: 1.4.0 """ def __init__(self, bestModel): @@ -256,6 +283,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) + @since("1.4.0") def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid From 29bf8aa5a51fdd8c2600533297f991e14fa27c03 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:45:20 -0700 Subject: [PATCH 0118/2089] [SPARK-10283] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.regression Author: Yu ISHIKAWA Closes #8693 from yu-iskw/SPARK-10283. --- python/pyspark/ml/regression.py | 65 +++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a9503608b7f25..21d454f9003bb 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -62,6 +63,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ @keyword_only @@ -81,6 +84,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True): @@ -96,13 +100,31 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) + @since("1.4.0") + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + @since("1.4.0") + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + class LinearRegressionModel(JavaModel): """ Model fitted by LinearRegression. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -110,6 +132,7 @@ def weights(self): return self._call_java("weights") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -162,6 +185,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -193,6 +218,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -209,6 +235,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -216,6 +243,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -225,13 +253,19 @@ def getImpurity(self): @inherit_doc class DecisionTreeModel(JavaModel): + """Abstraction for Decision Tree models. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def numNodes(self): """Return number of nodes of the decision tree.""" return self._call_java("numNodes") @property + @since("1.5.0") def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") @@ -242,8 +276,13 @@ def __repr__(self): @inherit_doc class TreeEnsembleModels(JavaModel): + """Represents a tree ensemble model. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) @@ -256,6 +295,8 @@ def __repr__(self): class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. + + .. versionadded:: 1.4.0 """ @@ -282,6 +323,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -336,6 +379,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, @@ -353,6 +397,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -360,12 +405,14 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -373,12 +420,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. @@ -386,12 +435,14 @@ def setNumTrees(self, value): self._paramMap[self.numTrees] = value return self + @since("1.4.0") def getNumTrees(self): """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) + @since("1.4.0") def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. @@ -399,6 +450,7 @@ def setFeatureSubsetStrategy(self, value): self._paramMap[self.featureSubsetStrategy] = value return self + @since("1.4.0") def getFeatureSubsetStrategy(self): """ Gets the value of featureSubsetStrategy or its default value. @@ -409,6 +461,8 @@ def getFeatureSubsetStrategy(self): class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. + + .. versionadded:: 1.4.0 """ @@ -435,6 +489,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -481,6 +537,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -498,6 +555,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -505,12 +563,14 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -518,12 +578,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. @@ -531,6 +593,7 @@ def setStepSize(self, value): self._paramMap[self.stepSize] = value return self + @since("1.4.0") def getStepSize(self): """ Gets the value of stepSize or its default value. @@ -541,6 +604,8 @@ def getStepSize(self): class GBTRegressionModel(TreeEnsembleModels): """ Model fitted by GBTRegressor. + + .. versionadded:: 1.4.0 """ From 0ded87a4d49d4484e202bd2ec781821b57b5882c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:47:21 -0700 Subject: [PATCH 0119/2089] [SPARK-10281] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.clustering Author: Yu ISHIKAWA Closes #8691 from yu-iskw/SPARK-10281. --- python/pyspark/ml/clustering.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index cb4c16e25a7a3..7bb8ab94e17df 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -26,8 +27,11 @@ class KMeansModel(JavaModel): """ Model fitted by KMeans. + + .. versionadded:: 1.5.0 """ + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @@ -55,6 +59,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -88,6 +94,7 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ @@ -99,6 +106,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -110,12 +118,14 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of `k` """ return self.getOrDefault(self.k) + @since("1.5.0") def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. @@ -130,12 +140,14 @@ def setInitMode(self, value): self._paramMap[self.initMode] = value return self + @since("1.5.0") def getInitMode(self): """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) + @since("1.5.0") def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. @@ -147,6 +159,7 @@ def setInitSteps(self, value): self._paramMap[self.initSteps] = value return self + @since("1.5.0") def getInitSteps(self): """ Gets the value of `initSteps` From 39b44cb52eb225469eb4ccdf696f0bc6405b9184 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:48:45 -0700 Subject: [PATCH 0120/2089] [SPARK-10278] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.tree Author: Yu ISHIKAWA Closes #8685 from yu-iskw/SPARK-10278. --- python/pyspark/mllib/tree.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 372b86a7c95d9..0001b60093a69 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ import random -from pyspark import SparkContext, RDD +from pyspark import SparkContext, RDD, since from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -30,6 +30,11 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + """TreeEnsembleModel + + .. versionadded:: 1.3.0 + """ + @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -45,12 +50,14 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.3.0") def numTrees(self): """ Get number of trees in ensemble. """ return self.call("numTrees") + @since("1.3.0") def totalNumNodes(self): """ Get total number of nodes, summed over all trees in the @@ -62,6 +69,7 @@ def __repr__(self): """ Summary of model """ return self._java_model.toString() + @since("1.3.0") def toDebugString(self): """ Full model """ return self._java_model.toDebugString() @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. note:: Experimental A decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ + @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. @@ -90,16 +101,23 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.1.0") def numNodes(self): + """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() + @since("1.1.0") def depth(self): + """Get depth of tree. + E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ return self._java_model.depth() def __repr__(self): """ summary of model. """ return self._java_model.toString() + @since("1.2.0") def toDebugString(self): """ full model. """ return self._java_model.toDebugString() @@ -115,6 +133,8 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ @classmethod @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod + @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @classmethod + @since("1.1.0") def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a random forest model. + + .. versionadded:: 1.2.0 """ @classmethod @@ -252,6 +276,8 @@ class RandomForest(object): Learning algorithm for a random forest model for classification or regression. + + .. versionadded:: 1.2.0 """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod + @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod + @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a gradient-boosted tree model. + + .. versionadded:: 1.3.0 """ @classmethod @@ -431,6 +461,8 @@ class GradientBoostedTrees(object): Learning algorithm for a gradient boosted trees model for classification or regression. + + .. versionadded:: 1.3.0 """ @classmethod @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod + @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod + @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): From 4a0b56e8dbb3713b16e58738201d838ffc4b258b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:00 -0700 Subject: [PATCH 0121/2089] [SPARK-10279] [MLLIB] [PYSPARK] [DOCS] Add @since annotation to pyspark.mllib.util Author: Yu ISHIKAWA Closes #8689 from yu-iskw/SPARK-10279. --- python/pyspark/mllib/util.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 10a1e4b3eb0fc..39bc6586dd582 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -23,7 +23,7 @@ xrange = range basestring = str -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -32,6 +32,8 @@ class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. + + .. versionadded:: 1.0.0 """ @staticmethod @@ -69,6 +71,7 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod + @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -123,6 +126,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod + @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. @@ -147,6 +151,7 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod + @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. @@ -172,6 +177,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) @staticmethod + @since("1.5.0") def appendBias(data): """ Returns a new vector with `1.0` (bias) appended to @@ -186,6 +192,7 @@ def appendBias(data): return _convert_to_vector(np.append(vec.toArray(), 1.0)) @staticmethod + @since("1.5.0") def loadVectors(sc, path): """ Loads vectors saved using `RDD[Vector].saveAsTextFile` @@ -197,6 +204,8 @@ def loadVectors(sc, path): class Saveable(object): """ Mixin for models and transformers which may be saved as files. + + .. versionadded:: 1.3.0 """ def save(self, sc, path): @@ -222,9 +231,13 @@ class JavaSaveable(Saveable): """ Mixin for models that provide save() through their Scala implementation. + + .. versionadded:: 1.3.0 """ + @since("1.3.0") def save(self, sc, path): + """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): @@ -235,6 +248,8 @@ def save(self, sc, path): class Loader(object): """ Mixin for classes which can load saved models from files. + + .. versionadded:: 1.3.0 """ @classmethod @@ -256,6 +271,8 @@ class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala implementation. + + .. versionadded:: 1.3.0 """ @classmethod @@ -280,15 +297,21 @@ def _load_java(cls, sc, path): return java_obj.load(sc._jsc.sc(), path) @classmethod + @since("1.3.0") def load(cls, sc, path): + """Load a model from the given path.""" java_model = cls._load_java(sc, path) return cls(java_model) class LinearDataGenerator(object): - """Utils for generating linear data""" + """Utils for generating linear data. + + .. versionadded:: 1.5.0 + """ @staticmethod + @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ @@ -311,6 +334,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, xVariance, int(nPoints), int(seed), float(eps))) @staticmethod + @since("1.5.0") def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ From c74d38fd8faf8cba981cf934341d24b9a3167025 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:46 -0700 Subject: [PATCH 0122/2089] [SPARK-10274] [MLLIB] Add @since annotation to pyspark.mllib.fpm Author: Yu ISHIKAWA Closes #8665 from yu-iskw/SPARK-10274. --- python/pyspark/mllib/fpm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b18..bdabba9602a8c 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -19,7 +19,7 @@ from numpy import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -41,8 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + + .. versionadded:: 1.4.0 """ + @since("1.4.0") def freqItemsets(self): """ Returns the frequent itemsets of this model. @@ -55,9 +58,12 @@ class FPGrowth(object): .. note:: Experimental A Parallel FP-growth algorithm to mine frequent itemsets. + + .. versionadded:: 1.4.0 """ @classmethod + @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. @@ -74,6 +80,8 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ Represents an (items, freq) tuple. + + .. versionadded:: 1.4.0 """ From 268088b899e6e165e746aed87840d47bfaf50c43 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:51:19 -0700 Subject: [PATCH 0123/2089] [SPARK-10282] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.recommendation Author: Yu ISHIKAWA Closes #8692 from yu-iskw/SPARK-10282. --- python/pyspark/ml/recommendation.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b06099ac0aee6..ec5748a1cfe94 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -80,6 +81,8 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=3.19...) >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -122,6 +125,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): @@ -137,6 +141,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem def _create_model(self, java_model): return ALSModel(java_model) + @since("1.4.0") def setRank(self, value): """ Sets the value of :py:attr:`rank`. @@ -144,12 +149,14 @@ def setRank(self, value): self._paramMap[self.rank] = value return self + @since("1.4.0") def getRank(self): """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) + @since("1.4.0") def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. @@ -157,12 +164,14 @@ def setNumUserBlocks(self, value): self._paramMap[self.numUserBlocks] = value return self + @since("1.4.0") def getNumUserBlocks(self): """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) + @since("1.4.0") def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. @@ -170,12 +179,14 @@ def setNumItemBlocks(self, value): self._paramMap[self.numItemBlocks] = value return self + @since("1.4.0") def getNumItemBlocks(self): """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) + @since("1.4.0") def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. @@ -183,6 +194,7 @@ def setNumBlocks(self, value): self._paramMap[self.numUserBlocks] = value self._paramMap[self.numItemBlocks] = value + @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. @@ -190,12 +202,14 @@ def setImplicitPrefs(self, value): self._paramMap[self.implicitPrefs] = value return self + @since("1.4.0") def getImplicitPrefs(self): """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) + @since("1.4.0") def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. @@ -203,12 +217,14 @@ def setAlpha(self, value): self._paramMap[self.alpha] = value return self + @since("1.4.0") def getAlpha(self): """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) + @since("1.4.0") def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. @@ -216,12 +232,14 @@ def setUserCol(self, value): self._paramMap[self.userCol] = value return self + @since("1.4.0") def getUserCol(self): """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) + @since("1.4.0") def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. @@ -229,12 +247,14 @@ def setItemCol(self, value): self._paramMap[self.itemCol] = value return self + @since("1.4.0") def getItemCol(self): """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) + @since("1.4.0") def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. @@ -242,12 +262,14 @@ def setRatingCol(self, value): self._paramMap[self.ratingCol] = value return self + @since("1.4.0") def getRatingCol(self): """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) + @since("1.4.0") def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. @@ -255,6 +277,7 @@ def setNonnegative(self, value): self._paramMap[self.nonnegative] = value return self + @since("1.4.0") def getNonnegative(self): """ Gets the value of nonnegative or its default value. @@ -265,14 +288,18 @@ def getNonnegative(self): class ALSModel(JavaModel): """ Model fitted by ALS. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def rank(self): """rank of the matrix factorization model""" return self._call_java("rank") @property + @since("1.4.0") def userFactors(self): """ a DataFrame that stores user factors in two columns: `id` and @@ -281,6 +308,7 @@ def userFactors(self): return self._call_java("userFactors") @property + @since("1.4.0") def itemFactors(self): """ a DataFrame that stores item factors in two columns: `id` and From e51345e1e04e439827a07c95887d14ba38333057 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 17 Sep 2015 09:17:43 -0700 Subject: [PATCH 0124/2089] [SPARK-10077] [DOCS] [ML] Add package info for java of ml/feature Should be the same as SPARK-7808 but use Java for the code example. It would be great to add package doc for `spark.ml.feature`. Author: Holden Karau Closes #8740 from holdenk/SPARK-10077-JAVA-PACKAGE-DOC-FOR-SPARK.ML.FEATURE. --- .../apache/spark/ml/feature/package-info.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java new file mode 100644 index 0000000000000..c22d2e0cd2d90 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * Feature transformers + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which + * transforms one {@link org.apache.spark.sql.DataFrame} into another, e.g., + * {@link org.apache.spark.feature.HashingTF}. + * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}}s, because the + * transformation requires some aggregated information of the dataset, e.g., document + * frequencies in {@link org.apache.spark.ml.feature.IDF}. + * For those feature transformers, calling {@link org.apache.spark.ml.Estimator#fit} is required to + * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply + * transformation. + * The transformation is usually done by appending new columns to the input + * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * {@link org.apache.spark.ml.Pipeline} can be used to chain feature transformers, and + * {@link org.apache.spark.ml.feature.VectorAssembler} can be used to combine multiple feature + * transformations, for example: + * + *
+ * 
+ *   import java.util.Arrays;
+ *
+ *   import org.apache.spark.api.java.JavaRDD;
+ *   import static org.apache.spark.sql.types.DataTypes.*;
+ *   import org.apache.spark.sql.types.StructType;
+ *   import org.apache.spark.sql.DataFrame;
+ *   import org.apache.spark.sql.RowFactory;
+ *   import org.apache.spark.sql.Row;
+ *
+ *   import org.apache.spark.ml.feature.*;
+ *   import org.apache.spark.ml.Pipeline;
+ *   import org.apache.spark.ml.PipelineStage;
+ *   import org.apache.spark.ml.PipelineModel;
+ *
+ *  // a DataFrame with three columns: id (integer), text (string), and rating (double).
+ *  StructType schema = createStructType(
+ *    Arrays.asList(
+ *      createStructField("id", IntegerType, false),
+ *      createStructField("text", StringType, false),
+ *      createStructField("rating", DoubleType, false)));
+ *  JavaRDD rowRDD = jsc.parallelize(
+ *    Arrays.asList(
+ *      RowFactory.create(0, "Hi I heard about Spark", 3.0),
+ *      RowFactory.create(1, "I wish Java could use case classes", 4.0),
+ *      RowFactory.create(2, "Logistic regression models are neat", 4.0)));
+ *  DataFrame df = jsql.createDataFrame(rowRDD, schema);
+ *  // define feature transformers
+ *  RegexTokenizer tok = new RegexTokenizer()
+ *    .setInputCol("text")
+ *    .setOutputCol("words");
+ *  StopWordsRemover sw = new StopWordsRemover()
+ *    .setInputCol("words")
+ *    .setOutputCol("filtered_words");
+ *  HashingTF tf = new HashingTF()
+ *    .setInputCol("filtered_words")
+ *    .setOutputCol("tf")
+ *    .setNumFeatures(10000);
+ *  IDF idf = new IDF()
+ *    .setInputCol("tf")
+ *    .setOutputCol("tf_idf");
+ *  VectorAssembler assembler = new VectorAssembler()
+ *    .setInputCols(new String[] {"tf_idf", "rating"})
+ *    .setOutputCol("features");
+ *
+ *  // assemble and fit the feature transformation pipeline
+ *  Pipeline pipeline = new Pipeline()
+ *    .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler});
+ *  PipelineModel model = pipeline.fit(df);
+ *
+ *  // save transformed features with raw data
+ *  model.transform(df)
+ *    .select("id", "text", "rating", "features")
+ *    .write().format("parquet").save("/output/path");
+ * 
+ * 
+ * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see + * scikit-learn.preprocessing + */ +package org.apache.spark.ml.feature; From 2a508df20d03b3d4a3c05b65fb02d849bc080ef9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 09:21:21 -0700 Subject: [PATCH 0125/2089] [SPARK-10459] [SQL] Do not need to have ConvertToSafe for PythonUDF JIRA: https://issues.apache.org/jira/browse/SPARK-10459 As mentioned in the JIRA, `PythonUDF` actually could process `UnsafeRow`. Specially, the rows in `childResults` in `BatchPythonEvaluation` will be projected to a `MutableRow`. So I think we can enable `canProcessUnsafeRows` for `BatchPythonEvaluation` and get rid of redundant `ConvertToSafe`. Author: Liang-Chi Hsieh Closes #8616 from viirya/pyudf-unsafe. --- .../scala/org/apache/spark/sql/execution/pythonUDFs.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80b..d0411da6fdf5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -337,6 +337,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) From c88bb5df94f9696677c3a429472114bc66f32a52 Mon Sep 17 00:00:00 2001 From: "yangping.wu" Date: Thu, 17 Sep 2015 09:52:40 -0700 Subject: [PATCH 0126/2089] [SPARK-10660] Doc describe error in the "Running Spark on YARN" page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the Configuration section, the **spark.yarn.driver.memoryOverhead** and **spark.yarn.am.memoryOverhead**‘s default value should be "driverMemory * 0.10, with minimum of 384" and "AM memory * 0.10, with minimum of 384" respectively. Because from Spark 1.4.0, the **MEMORY_OVERHEAD_FACTOR** is set to 0.1.0, not 0.07. Author: yangping.wu Closes #8797 from 397090770/SparkOnYarnDocError. --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d1244323edfff..3a961d245f3de 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,14 +211,14 @@ If you need a reference to the proper location to put log files in the YARN so t
- + - + From 136c77d8bbf48f7c45dd7c3fbe261a0476f455fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 10:02:15 -0700 Subject: [PATCH 0127/2089] [SPARK-10642] [PYSPARK] Fix crash when calling rdd.lookup() on tuple keys JIRA: https://issues.apache.org/jira/browse/SPARK-10642 When calling `rdd.lookup()` on a RDD with tuple keys, `portable_hash` will return a long. That causes `DAGScheduler.submitJob` to throw `java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Integer`. Author: Liang-Chi Hsieh Closes #8796 from viirya/fix-pyrdd-lookup. --- python/pyspark/rdd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9ef60a7e2c84b..ab5aab1e115f7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -84,7 +84,7 @@ def portable_hash(x): h ^= len(x) if h == -1: h = -2 - return h + return int(h) return hash(x) @@ -2192,6 +2192,9 @@ def lookup(self, key): [42] >>> sorted.lookup(1024) [] + >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey() + >>> list(rdd2.lookup(('a', 'b'))[0]) + ['c'] """ values = self.filter(lambda kv: kv[0] == key).values() From 81b4db374dd61b6f1c30511c70b6ab2a52c68faa Mon Sep 17 00:00:00 2001 From: Josiah Samuel Date: Thu, 17 Sep 2015 10:18:21 -0700 Subject: [PATCH 0128/2089] [SPARK-10172] [CORE] disable sort in HistoryServer webUI This pull request is to address the JIRA SPARK-10172 (History Server web UI gets messed up when sorting on any column). The content of the table gets messed up due to the rowspan attribute of the table data(cell) during sorting. The current table sort library used in SparkUI (sorttable.js) doesn't support/handle cells(td) with rowspans. The fix will disable the table sort in the web UI, when there are jobs listed with multiple attempts. Author: Josiah Samuel Closes #8506 from josiahsams/SPARK-10172. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba1245..b347cb3be69f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } From 36d8b278d82e788bf583e8438fac524d0023311d Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 17 Sep 2015 10:25:18 -0700 Subject: [PATCH 0129/2089] [SPARK-10531] [CORE] AppId is set as AppName in status rest api Verify it manually. Author: Jeff Zhang Closes #8688 from zjffdu/SPARK-10531. --- .../main/scala/org/apache/spark/SparkContext.scala | 1 + .../spark/deploy/history/FsHistoryProvider.scala | 9 ++++----- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 11 ++++++----- .../scala/org/apache/spark/ui/UISeleniumSuite.scala | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a2f34eafa2c38..9c3218719f7fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -521,6 +521,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a5755eac36396..8eb2ba1e8683b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -146,16 +146,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 26904d39a9bec..d518e92133aad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -944,7 +944,7 @@ private[deploy] class Master( val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { replayBus.replay(logInput, eventLogFile, maybeTruncated) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index d8b90568b7b9a..99085ada9f0af 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -56,6 +56,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -75,9 +77,8 @@ private[spark] class SparkUI private ( def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). */ @@ -94,12 +95,12 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 22e30ecaf0533..18eec7da9763e 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -658,6 +658,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/test/" + path) + new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } From e0dc2bc232206d2f4da4278502c1f88babc8b55a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 17 Sep 2015 11:05:30 -0700 Subject: [PATCH 0130/2089] [SPARK-10650] Clean before building docs The [published docs for 1.5.0](http://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/streaming/) have a bunch of test classes in them. The only way I can reproduce this is to `test:compile` before running `unidoc`. To prevent this from happening again, I've added a clean before doc generation. Author: Michael Armbrust Closes #8787 from marmbrus/testsInDocs. --- docs/_plugins/copy_api_dirs.rb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 15ceda11a8a80..01718d98dffe0 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,12 +26,15 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl clean compile unidoc` puts "Moving back into docs dir." cd("docs") + puts "Removing old docs" + puts `rm -rf api` + # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. source = "../target/scala-2.10/unidoc" From aad644fbe29151aec9004817d42e4928bdb326f3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 17 Sep 2015 11:14:52 -0700 Subject: [PATCH 0131/2089] [SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type https://issues.apache.org/jira/browse/SPARK-10639 Author: Yin Huai Closes #8788 from yhuai/udafConversion. --- .../sql/catalyst/CatalystTypeConverters.scala | 7 +- .../spark/sql/RandomDataGenerator.scala | 16 ++- .../spark/sql/execution/aggregate/udaf.scala | 37 +++++- .../org/apache/spark/sql/QueryTest.scala | 21 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 11 ++ .../execution/AggregationQuerySuite.scala | 108 +++++++++++++++++- 6 files changed, 188 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 966623ed017ba..f25591794abdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -138,8 +138,13 @@ object CatalystTypeConverters { private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + + override def toScala(catalystValue: Any): Any = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 4025cbcec1019..e48395028e399 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -108,7 +108,21 @@ object RandomDataGenerator { arr }) case BooleanType => Some(() => rand.nextBoolean()) - case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case DateType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L + } + DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt) + } + Some(generator) case TimestampType => val generator = () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d43d3dd9ffaae..1114fe6552bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < getters.length) { getters(i) = dataTypes(i) match { + case NullType => + (row: InternalRow, ordinal: Int) => null + case BooleanType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) @@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + case DateType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) @@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { + case NullType => + (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + case b: BooleanType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { @@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + // To make it work with UnsafeRow, we cannot use setNullAt. + // Please see the comment of UnsafeRow's setDecimal. + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + + case DateType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { - row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case TimestampType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) } else { row.setNullAt(ordinal) } @@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } @@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF( } } + private[this] lazy val outputToCatalystConverter: Any => Any = { + CatalystTypeConverters.createToCatalystConverter(dataType) + } + // This buffer is only used at executor side. private[this] var inputAggregateBuffer: InputAggregationBuffer = null @@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF( override def eval(buffer: InternalRow): Any = { evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(evalAggregateBuffer) + outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer)) } override def toString: String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index cada03e9ac6bb..e3c5a426671d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -115,19 +115,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 46d87843dfa4d..7992fd59ff4ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { assert(new MyDenseVectorUDT().typeName === "mydensevector") assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } + + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index a73b1bd52c09f..24b1846923c77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,13 +17,55 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { + + def inputSchema: StructType = schema + + def bufferSchema: StructType = schema + + def dataType: DataType = schema + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + (0 until schema.length).foreach { i => + buffer.update(i, null) + } + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0) && input.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer.update(i, input.get(i)) + } + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer1.update(i, buffer2.get(i)) + } + } + } + + def evaluate(buffer: Row): Any = { + Row.fromSeq(buffer.toSeq) + } +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } } + + test("udaf with all data types") { + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + // Right now, we will use SortBasedAggregate to handle UDAFs. + // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // UnsafeRow as the aggregation buffer. While, dataTypes will trigger + // SortBasedAggregate to use a safe row as the aggregation buffer. + Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + // The schema used for data generator. + val schemaForGenerator = StructType(fields) + // The schema used for the DataFrame df. + val schema = StructType(StructField("id", IntegerType) +: fields) + + logInfo(s"Testing schema: ${schema.treeString}") + + val udaf = new ScalaAggregateFunction(schema) + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. + val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) + val data = (1 to 50).map { i => + dataGenerator.apply() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 1) + val df = sqlContext.createDataFrame(rdd, schema) + + val allColumns = df.schema.fields.map(f => col(f.name)) + val expectedAnaswer = + data + .find(r => r.getInt(0) == 50) + .getOrElse(fail("A row with id 50 should be the expected answer.")) + checkAnswer( + df.groupBy().agg(udaf(allColumns: _*)), + // udaf returns a Row as the output value. + Row(expectedAnaswer) + ) + } + } } class SortBasedAggregationQuerySuite extends AggregationQuerySuite { From 64743870f23bffb8d96dcc8a0181c1452782a151 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 11:24:38 -0700 Subject: [PATCH 0132/2089] [SPARK-10394] [ML] Make GBTParams use shared stepSize ```GBTParams``` has ```stepSize``` as learning rate currently. ML has shared param class ```HasStepSize```, ```GBTParams``` can extend from it rather than duplicated implementation. Author: Yanbo Liang Closes #8552 from yanboliang/spark-10394. --- .../org/apache/spark/ml/tree/treeParams.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d29f5253c9c3f..42e74ce6d2c69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -365,17 +365,7 @@ private[ml] object RandomForestParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** @group setParam */ + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group setParam + */ def setStepSize(value: Double): this.type = set(stepSize, value) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) + override def validateParams(): Unit = { + require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( + getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + + s"but it given invalid value $getStepSize.") + } /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( From f1c911552cf5d0d60831c79c1881016293aec66c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Sep 2015 11:40:24 -0700 Subject: [PATCH 0133/2089] [SPARK-10657] Remove SCP-based Jenkins log archiving As of https://issues.apache.org/jira/browse/SPARK-7561, we no longer need to use our custom SCP-based mechanism for archiving Jenkins logs on the master machine; this has been superseded by the use of a Jenkins plugin which archives the logs and provides public links to view them. Per shaneknapp, we should remove this log syncing mechanism if it is no longer necessary; removing the need to SCP from the Jenkins workers to the masters is a desired step as part of some larger Jenkins infra refactoring. Author: Josh Rosen Closes #8793 from JoshRosen/remove-jenkins-ssh-to-master. --- dev/run-tests-jenkins | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3be78575e70f1..d3b05fa6df0ce 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -116,39 +116,6 @@ function post_message () { fi } -function send_archived_logs () { - echo "Archiving unit tests logs..." - - local log_files=$( - find .\ - -name "unit-tests.log" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.wrong" - ) - - if [ -z "$log_files" ]; then - echo "> No log files found." >&2 - else - local log_archive="unit-tests-logs.tar.gz" - echo "$log_files" | xargs tar czf ${log_archive} - - local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} - local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) - local scp_status="$?" - - if [ "$scp_status" -ne 0 ]; then - echo "Failed to send archived unit tests logs to Jenkins master." >&2 - echo "> scp_status: ${scp_status}" >&2 - echo "> scp_output: ${scp_output}" >&2 - else - echo "> Send successful." - fi - - rm -f ${log_archive} - fi -} - # post start message { start_message="\ @@ -244,8 +211,6 @@ done test_result_note=" * This patch **fails $failing_test**." fi - - send_archived_logs } # post end message From 4fbf3328692e876f39ea78494510f9d9c5a53f15 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 14:09:06 -0700 Subject: [PATCH 0134/2089] [SPARK-9698] [ML] Add RInteraction transformer for supporting R-style feature interactions This is a pre-req for supporting the ":" operator in the RFormula feature transformer. Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7987 from ericl/interaction. --- .../apache/spark/ml/feature/Interaction.scala | 278 ++++++++++++++++++ .../spark/ml/feature/InteractionSuite.scala | 165 +++++++++++ 2 files changed, 443 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 0000000000000..9194763fb32f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is + * produced. + * + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. + */ +@Experimental +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } + + override def transform(dataset: DataFrame): DataFrame = { + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) + + def interactFunc = udf { row: Row => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var featureIndex = row.length - 1 + while (featureIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentEncoder = featureEncoders(featureIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + featureIndex -= 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + } + } + dataset.select( + col("*"), + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + } + + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) + case _ => + 1 // numeric feature + } + } + features.map { f => + val numFeatures = f.dataType match { + case _: NumericType | BooleanType => + Array(getNumFeatures(Attribute.fromStructField(f))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( + throw new SparkException("Vector attributes must be defined for interaction.")) + attrs.map(getNumFeatures).toArray + } + new FeatureEncoder(numFeatures) + }.toArray + } + + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`. + * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(f) + encodedFeatureAttrs(Seq(attr), None) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) + } + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs + } else { + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new AttributeGroup($(outputCol), featureAttrs.toArray) + } + + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). + */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) + parts.flatten.mkString("_") + } + + inputAttrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + override def validateParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. + * + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. + */ +private[ml] class FeatureEncoder(numFeatures: Array[Int]) { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + + /** The size of the output vector. */ + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + var i = 1 + while (i < arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 + } + arr + } + + /** + * Given an input row of features, invokes the specific function for every non-zero output. + * + * @param value The row value to encode, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. + */ + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { + case d: Double => + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { + assert( + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) + } else { + f(outputOffsets(i), v) + } + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 100644 index 0000000000000..2beb62ca08233 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("feature encoder") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed + } + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } + } + + test("numeric interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) + } + + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } +} From 0f5ef6dfa67a068606aff8ea9d1addfce73446eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 17 Sep 2015 19:16:34 -0700 Subject: [PATCH 0135/2089] [SPARK-10674] [TESTS] Increase timeouts in SaslIntegrationSuite. 1s seems to trigger too many times on the jenkins build boxes, so increase the timeout and cross fingers. Author: Marcelo Vanzin Closes #8802 from vanzin/SPARK-10674 and squashes the following commits: 3c93117 [Marcelo Vanzin] Use java 7 syntax. d667d1b [Marcelo Vanzin] [SPARK-10674] [tests] Increase timeouts in SaslIntegrationSuite. --- .../spark/network/sasl/SaslIntegrationSuite.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 5cb0e4d4a6458..c393a5e1e6810 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -56,6 +56,11 @@ import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private final static long TIMEOUT_MS = 10_000; + static TransportServer server; static TransportConf conf; static TransportContext context; @@ -102,7 +107,7 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg } @@ -131,7 +136,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], 1000); + client.sendRpcSync(new byte[13], TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -139,7 +144,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -217,12 +222,12 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), 10000); + client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000); + byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); long streamId = stream.streamId; From 98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 21:37:10 -0700 Subject: [PATCH 0136/2089] [SPARK-8518] [ML] Log-linear models for survival analysis [Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time. Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R. Author: Yanbo Liang Closes #8611 from yanboliang/spark-8518. --- .../ml/regression/AFTSurvivalRegression.scala | 449 ++++++++++++++++++ .../AFTSurvivalRegressionSuite.scala | 311 ++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala new file mode 100644 index 0000000000000..5b25db651f56c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Params for accelerated failure time (AFT) regression. + */ +private[regression] trait AFTSurvivalRegressionParams extends Params + with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter + with HasTol with HasFitIntercept { + + /** + * Param for censor column name. + * The value of this column could be 0 or 1. + * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored. + * @group param + */ + @Since("1.6.0") + final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name") + + /** @group getParam */ + @Since("1.6.0") + def getCensorCol: String = $(censorCol) + setDefault(censorCol -> "censor") + + /** + * Param for quantile probabilities array. + * Values of the quantile probabilities array should be in the range [0, 1]. + * @group param + */ + @Since("1.6.0") + final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, + "quantileProbabilities", "quantile probabilities array", + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + + /** @group getParam */ + @Since("1.6.0") + def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + + /** Checks whether the input has quantile probabilities array. */ + protected[regression] def hasQuantileProbabilities: Boolean = { + isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + } + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Fit a parametric survival regression model named accelerated failure time (AFT) model + * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]]) + * based on the Weibull distribution of the survival time. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("aftSurvReg")) + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setCensorCol(value: String): this.type = set(censorCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + @Since("1.6.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("1.6.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, + * and put it in an RDD with strong types. + */ + protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } + } + + @Since("1.6.0") + override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + val instances = extractAFTPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val costFun = new AFTCostFun(instances, $(fitIntercept)) + val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + + val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + /* + The weights vector has three parts: + the first element: Double, log(sigma), the log of scale parameter + the second element: Double, intercept of the beta parameter + the third to the end elements: Doubles, regression coefficients vector of the beta parameter + */ + val initialWeights = Vectors.zeros(numFeatures + 2) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeights.toBreeze.toDenseVector) + + val weights = { + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + throw new SparkException(msg) + } + + state.x.toArray.clone() + } + + if (handlePersistence) instances.unpersist() + + val coefficients = Vectors.dense(weights.slice(2, weights.length)) + val intercept = weights(1) + val scale = math.exp(weights(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + copyValues(model.setParent(this)) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model produced by [[AFTSurvivalRegression]]. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegressionModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val intercept: Double, + @Since("1.6.0") val scale: Double) + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + @Since("1.6.0") + def predictQuantiles(features: Vector): Vector = { + require(hasQuantileProbabilities, + "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") + // scale parameter for the Weibull distribution of lifetime + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + // shape parameter for the Weibull distribution of lifetime + val k = 1 / scale + val quantiles = $(quantileProbabilities).map { + q => lambda * math.exp(math.log(-math.log(1 - q)) / k) + } + Vectors.dense(quantiles) + } + + @Since("1.6.0") + def predict(features: Vector): Double = { + math.exp(BLAS.dot(coefficients, features) + intercept) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val predictUDF = udf { features: Vector => predict(features) } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + .setParent(parent) + } +} + +/** + * AFTAggregator computes the gradient and loss for a AFT loss function, + * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. + * + * The loss function and likelihood function under the AFT model based on: + * Lawless, J. F., Statistical Models and Methods for Lifetime Data, + * New York: John Wiley & Sons, Inc. 2003. + * + * Two AFTAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n, + * with possible right-censoring, the likelihood function under the AFT model is given as + * {{{ + * L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} + * }}} + * Where \delta_{i} is the indicator of the event has occurred i.e. uncensored or not. + * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function + * assumes the form + * {{{ + * \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+ + * \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] + * }}} + * Where S_{0}(\epsilon_{i}) is the baseline survivor function, + * and f_{0}(\epsilon_{i}) is corresponding density function. + * + * The most commonly used log-linear survival regression method is based on the Weibull + * distribution of the survival time. The Weibull distribution for lifetime corresponding + * to extreme value distribution for log of the lifetime, + * and the S_{0}(\epsilon) function is + * {{{ + * S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) + * }}} + * the f_{0}(\epsilon_{i}) function is + * {{{ + * f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) + * }}} + * The log-likelihood function for Weibull distribution of lifetime is + * {{{ + * \iota(\beta,\sigma)= + * -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] + * }}} + * Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, + * the loss function we use to optimize is -\iota(\beta,\sigma). + * The gradient functions for \beta and \log\sigma respectively are + * {{{ + * \frac{\partial (-\iota)}{\partial \beta}= + * \sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} + * }}} + * {{{ + * \frac{\partial (-\iota)}{\partial (\log\sigma)}= + * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] + * }}} + * @param weights The log of scale parameter, the intercept and + * regression coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean) + extends Serializable { + + // beta is the intercept and regression coefficients to the covariates + private val beta = weights.slice(1, weights.length) + // sigma is the scale parameter of the AFT model + private val sigma = math.exp(weights(0)) + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientLogSigmaSum = 0.0 + + def count: Long = totalCnt + + def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt + + // Here we optimize loss function over beta and log(sigma) + def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), + gradientBetaSum/totalCnt.toDouble) + + /** + * Add a new training data to this AFTAggregator, and update the loss and gradient + * of the objective function. + * + * @param data The AFTPoint representation for one data point to be added into this aggregator. + * @return This AFTAggregator object. + */ + def add(data: AFTPoint): this.type = { + + // TODO: Don't create a new xi vector each time. + val xi = if (fitIntercept) { + Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze + } else { + Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze + } + val ti = data.label + val delta = data.censor + val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + + lossSum += math.log(sigma) * delta + lossSum += (math.exp(epsilon) - delta * epsilon) + + // Sanity check (should never occur): + assert(!lossSum.isInfinity, + s"AFTAggregator loss sum is infinity. Error for unknown reason.") + + gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma + gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + + totalCnt += 1 + this + } + + /** + * Merge another AFTAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other AFTAggregator to be merged. + * @return This AFTAggregator object. + */ + def merge(other: AFTAggregator): this.type = { + if (totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + gradientBetaSum += other.gradientBetaSum + gradientLogSigmaSum += other.gradientLogSigmaSum + } + this + } +} + +/** + * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. + * It returns the loss and gradient at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ +private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + + val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + seqOp = (c, v) => (c, v) match { + case (aggregator, instance) => aggregator.add(instance) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + (aftAggregator.loss, aftAggregator.gradient) + } +} + +/** + * Class that represents the (features, label, censor) of a data point. + * + * @param features List of features for this data point. + * @param label Label for this data point. + * @param censor Indicator of the event has occurred or not. If the value is 1, it means + * the event has occurred i.e. uncensored; otherwise censored. + */ +private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { + require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 0000000000000..ca7140a45ea65 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, DataFrame} + +class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } +} From d009da2f5c803f3b7344c96abbfcf3ecef2f5ad2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:05:20 -0700 Subject: [PATCH 0137/2089] [SPARK-10682] [GRAPHX] Remove Bagel test suites. Bagel has been deprecated and we haven't done any changes to it. There is no need to run those tests. This should speed up tests by 1 min. Author: Reynold Xin Closes #8807 from rxin/SPARK-10682. --- bagel/src/test/resources/log4j.properties | 27 ----- .../org/apache/spark/bagel/BagelSuite.scala | 113 ------------------ 2 files changed, 140 deletions(-) delete mode 100644 bagel/src/test/resources/log4j.properties delete mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae92096..0000000000000 --- a/bagel/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74b..0000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} From 93c7650ab60a839a9cbe8b4ea1d5eda93e53ebe0 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Thu, 17 Sep 2015 22:25:24 -0700 Subject: [PATCH 0138/2089] [SPARK-9522] [SQL] SparkSubmit process can not exit if kill application when HiveThriftServer was starting When we start HiveThriftServer, we will start SparkContext first, then start HiveServer2, if we kill application while HiveServer2 is starting then SparkContext will stop successfully, but SparkSubmit process can not exit. Author: linweizhong Closes #7853 from Sephiroth-Lin/SPARK-9522. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9c3218719f7fc..ebd8e946ee7a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -97,7 +97,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index dd9fef9206d0b..a0643cec0fb7c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -93,6 +93,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if SparkContext was stopped. + if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) From 9a56dcdf7f19c9f7f913a2ce9bc981cb43a113c5 Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Thu, 17 Sep 2015 22:42:46 -0700 Subject: [PATCH 0139/2089] docs/running-on-mesos.md: state default values in default column This PR simply uses the default value column for defaults. Author: Felix Bechstein Closes #8810 from felixb/fix_mesos_doc. --- docs/running-on-mesos.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 247e6ecfbdb86..1814fb32ed8a5 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -332,21 +332,21 @@ See the [configuration page](configuration.html) for information on Spark config - + - + - + - + - + - - - - - diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 79beec4429a99..c5f93bb47f55c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -50,9 +50,6 @@ * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); @@ -254,7 +251,6 @@ private void deleteExecutorDirs(String[] dirs) { * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockResolver. */ - // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1c96b0958586f..814a11e588ceb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,6 +70,10 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ + Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") ) case v if v.startsWith("1.5") => Seq( From 8074208fa47fa654c1055c48cfa0d923edeeb04f Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Fri, 18 Sep 2015 15:40:58 -0700 Subject: [PATCH 0148/2089] [SPARK-10611] Clone Configuration for each task for NewHadoopRDD This patch attempts to fix the Hadoop Configuration thread safety issue for NewHadoopRDD in the same way SPARK-2546 fixed the issue for HadoopRDD. Author: Mingyu Kim Closes #8763 from mingyukim/mkim/SPARK-10611. --- .../org/apache/spark/rdd/BinaryFileRDD.scala | 5 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 6fec00dcd0d85..aedced7408cde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -34,12 +34,13 @@ private[spark] class BinaryFileRDD[T]( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 174979aaeb231..2872b93b8730e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -44,7 +44,6 @@ private[spark] class NewHadoopPartition( extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -84,6 +83,27 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. + NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { @@ -104,7 +124,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -268,12 +292,13 @@ private[spark] class WholeTextFileRDD( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) From c8149ef2c57f5c47ab97ee8d8d58a216d4bc4156 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Sep 2015 16:23:05 -0700 Subject: [PATCH 0149/2089] [MINOR] [ML] override toString of AttributeGroup This makes equality test failures much more readable. mengxr Author: Eric Liang Author: Eric Liang Closes #8826 from ericl/attrgroupstr. --- .../scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 457c15830fd38..2c29eeb01a921 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -183,6 +183,8 @@ class AttributeGroup private ( sum = 37 * sum + attributes.map(_.toSeq).hashCode sum } + + override def toString: String = toMetadata.toString } /** From 22be2ae147a111e88896f6fb42ed46bbf108a99b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 18:42:20 -0700 Subject: [PATCH 0150/2089] [SPARK-10623] [SQL] Fixes ORC predicate push-down When pushing down a leaf predicate, ORC `SearchArgument` builder requires an extra "parent" predicate (any one among `AND`/`OR`/`NOT`) to wrap the leaf predicate. E.g., to push down `a < 1`, we must build `AND(a < 1)` instead. Fortunately, when actually constructing the `SearchArgument`, the builder will eliminate all those unnecessary wrappers. This PR is based on #8783 authored by zhzhan. I also took the chance to simply `OrcFilters` a little bit to improve readability. Author: Cheng Lian Closes #8799 from liancheng/spark-10623/fix-orc-ppd. --- .../spark/sql/hive/orc/OrcFilters.scala | 56 ++++++++----------- .../spark/sql/hive/orc/OrcQuerySuite.scala | 30 ++++++++++ 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index b3d9f7f71a27d..27193f54d3a91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.sources._ * and cannot be used anymore. */ private[orc] object OrcFilters extends Logging { - def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgumentFactory.newBuilder() - buildSearchArgument(conjunction, builder).map(_.build()) - } + def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + for { + // Combines all filters with `And`s to produce a single conjunction predicate + conjunction <- filters.reduceOption(And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { @@ -102,46 +104,32 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() - case EqualTo(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.equals(attribute, _)) + case EqualTo(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().equals(attribute, value).end()) - case EqualNullSafe(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.nullSafeEquals(attribute, _)) + case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThan(attribute, _)) + case LessThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThanEquals(attribute, _)) + case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThanEquals(attribute, _).end()) + case GreaterThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThan(attribute, _).end()) + case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThan(attribute, value).end()) case IsNull(attribute) => - Some(builder.isNull(attribute)) + Some(builder.startAnd().isNull(attribute).end()) case IsNotNull(attribute) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) => - Option(values) - .filter(_.forall(isSearchableLiteral)) - .map(builder.in(attribute, _)) + case In(attribute, values) if values.forall(isSearchableLiteral) => + Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8bc33fcf5d906..5eb39b1129701 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -344,4 +344,34 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-10623 Enable ORC PPD") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + import testImplicits._ + + val path = dir.getCanonicalPath + sqlContext.range(10).coalesce(1).write.orc(path) + val df = sqlContext.read.orc(path) + + def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { + checkAnswer(df.where(pred), answer.map(Row(_))) + } + + checkPredicate('id === 5, Seq(5L)) + checkPredicate('id <=> 5, Seq(5L)) + checkPredicate('id < 5, 0L to 4L) + checkPredicate('id <= 5, 0L to 5L) + checkPredicate('id > 5, 6L to 9L) + checkPredicate('id >= 5, 5L to 9L) + checkPredicate('id.isNull, Seq.empty[Long]) + checkPredicate('id.isNotNull, 0L to 9L) + checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) + checkPredicate('id > 0 && 'id < 3, 1L to 2L) + checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) + checkPredicate(!('id > 3), 0L to 3L) + checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + } + } + } } From 7ff8d68cc19299e16dedfd819b9e96480fa6cf44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Sep 2015 23:58:25 -0700 Subject: [PATCH 0151/2089] [SPARK-10474] [SQL] Aggregation fails to allocate memory for pointer array When `TungstenAggregation` hits memory pressure, it switches from hash-based to sort-based aggregation in-place. However, in the process we try to allocate the pointer array for writing to the new `UnsafeExternalSorter` *before* actually freeing the memory from the hash map. This lead to the following exception: ``` java.io.IOException: Could not acquire 65536 bytes of memory at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.initializeForWriting(UnsafeExternalSorter.java:169) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:220) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:126) at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:257) at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.switchToSortBasedAggregation(TungstenAggregationIterator.scala:435) ``` Author: Andrew Or Closes #8827 from andrewor14/allocate-pointer-array. --- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++- .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../UnsafeFixedWidthAggregationMapSuite.scala | 49 ++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index fc364e0a895b1..14b6aafdea7df 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -159,7 +159,7 @@ public BoxedUnit apply() { /** * Allocates new sort data structures. Called when creating the sorter and after each spill. */ - private void initializeForWriting() throws IOException { + public void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); final long pointerArrayMemory = UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); @@ -187,6 +187,14 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + spill(true); + } + + /** + * Sort and spill the current records in response to memory pressure. + * @param shouldInitializeForWriting whether to allocate memory for writing after the spill + */ + public void spill(boolean shouldInitializeForWriting) throws IOException { assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), @@ -217,7 +225,9 @@ public void spill() throws IOException { // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - initializeForWriting(); + if (shouldInitializeForWriting) { + initializeForWriting(); + } } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 7db6b7ff50f22..b81f67a16b815 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,6 +85,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, // We will use the number of elements in the map as the initialSize of the // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, // we will use 1 as its initial size if the map is empty. + // TODO: track pointer array memory used by this in-memory sorter! final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); @@ -123,8 +124,13 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes, inMemSorter); - sorter.spill(); + // Note: This spill doesn't actually release any memory, so if we try to allocate a new + // pointer array immediately after the spill then we may fail to acquire sufficient space + // for it (SPARK-10474). For this reason, we must initialize for writing explicitly *after* + // we have actually freed memory from our map. + sorter.spill(false /* initialize for writing */); map.free(); + sorter.initializeForWriting(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d1f0b2b1fc52f..ada4d42f991ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,9 +23,10 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -325,7 +326,7 @@ class UnsafeFixedWidthAggregationMapSuite // At here, we also test if copy is correct. iter.getKey.copy() iter.getValue.copy() - count += 1; + count += 1 } // 1 record was from the map and 4096 records were explicitly inserted. @@ -333,4 +334,48 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val smm = ShuffleMemoryManager.createForTesting(65536) + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + smm, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + // Insert into the map until we've run out of space + val rand = new Random(42) + var hasSpace = true + while (hasSpace) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + if (buf == null) { + hasSpace = false + } else { + buf.setInt(0, str.length) + } + } + + // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte + assert(smm.tryToAcquire(1) === 0) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. + var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } From d507f9c0b7f7a524137a694ed6443747aaf90463 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 19 Sep 2015 01:59:36 -0700 Subject: [PATCH 0152/2089] [SPARK-10584] [SQL] [DOC] Documentation about the compatible Hive version is wrong. In Spark 1.5.0, Spark SQL is compatible with Hive 0.12.0 through 1.2.1 but the documentation is wrong. /CC yhuai Author: Kousuke Saruta Closes #8776 from sarutak/SPARK-10584-2. --- docs/sql-programming-guide.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0b911d207243..82d4243cc6b27 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1954,7 +1954,7 @@ without the need to write any code. ## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. +in Hive 1.2.1 You can test the JDBC server with the beeline script that comes with either Spark or Hive 1.2.1. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -2260,8 +2260,10 @@ Several caching related features are not supported yet: ## Compatibility with Apache Hive -Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark -SQL is based on Hive 0.12.0 and 0.13.1. +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. +Currently Hive SerDes and UDFs are based on Hive 1.2.1, +and Spark SQL can be connected to different versions of Hive Metastore +(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). #### Deploying in Existing Hive Warehouses From d83b6aae8b4357c56779cc98804eb350ab8af62d Mon Sep 17 00:00:00 2001 From: Alexis Seigneurin Date: Sat, 19 Sep 2015 12:01:22 +0100 Subject: [PATCH 0153/2089] Fixed links to the API Submitting this change on the master branch as requested in https://github.com/apache/spark/pull/8819#issuecomment-141505941 Author: Alexis Seigneurin Closes #8838 from aseigneurin/patch-2. --- docs/ml-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5d7f990021f1..0427ac6695aa1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -619,13 +619,13 @@ for row in selected.collect(): An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. -The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. From e789000b88a6bd840f821c53f42c08b97dc02496 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 18:22:43 -0700 Subject: [PATCH 0154/2089] [SPARK-10155] [SQL] Change SqlParser to object to avoid memory leak Since `scala.util.parsing.combinator.Parsers` is thread-safe since Scala 2.10 (See [SI-4929](https://issues.scala-lang.org/browse/SI-4929)), we can change SqlParser to object to avoid memory leak. I didn't change other subclasses of `scala.util.parsing.combinator.Parsers` because there is only one instance in one SQLContext, which should not be an issue. Author: zsxwing Closes #8357 from zsxwing/sql-memory-leak. --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++--- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 5898a5f93f381..2bac08eac4fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def parse(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = synchronized { // Initialize the Keywords. initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 554fb4eb25eb1..e21d3c05464b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -61,7 +61,7 @@ abstract class ParserDialect { */ private[spark] class DefaultParserDialect extends ParserDialect { @transient - protected val sqlParser = new SqlParser + protected val sqlParser = SqlParser override def parse(sqlText: String): LogicalPlan = { sqlParser.parse(sqlText) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c9573..dfab2398857e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -37,9 +37,9 @@ import org.apache.spark.unsafe.types.CalendarInterval * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser with DataTypeParser { +object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - def parseExpression(input: String): Expression = { + def parseExpression(input: String): Expression = synchronized { // Initialize the Keywords. initLexical phrase(projection)(new lexical.Scanner(input)) match { @@ -48,7 +48,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } - def parseTableIdentifier(input: String): TableIdentifier = { + def parseTableIdentifier(input: String): TableIdentifier = synchronized { // Initialize the Keywords. initLexical phrase(tableIdentifier)(new lexical.Scanner(input)) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3e61123c145cd..8f737c2023931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -720,7 +720,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(new SqlParser().parseExpression(expr)) + Column(SqlParser.parseExpression(expr)) }: _*) } @@ -745,7 +745,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -769,7 +769,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 745bb4ec9cf1c..03e973666e888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -163,7 +163,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(new SqlParser().parseTableIdentifier(tableName)) + insertInto(SqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -197,7 +197,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + saveAsTable(SqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e3fdd782e6ff6..f099940800cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -590,7 +590,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -636,7 +636,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -732,7 +732,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(new SqlParser().parseTableIdentifier(tableName)) + table(SqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 60d9c509104d5..2467b4e48415b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -823,7 +823,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d37ba5ddc2d80..c12a734863326 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -291,12 +291,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -311,7 +311,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a5569b0a4446..0c1b41e3377e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -199,7 +199,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive options: Map[String, String], isExternal: Boolean): Unit = { createDataSourceTable( - new SqlParser().parseTableIdentifier(tableName), + SqlParser.parseTableIdentifier(tableName), userSpecifiedSchema, partitionColumns, provider, @@ -375,7 +375,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { - hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + hiveDefaultTableFilePath(SqlParser.parseTableIdentifier(tableName)) } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { From 2117eea71ece825fbc3797c8b38184ae221f5223 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 19 Sep 2015 21:40:21 -0700 Subject: [PATCH 0155/2089] [SPARK-10710] Remove ability to disable spilling in core and SQL It does not make much sense to set `spark.shuffle.spill` or `spark.sql.planner.externalSort` to false: I believe that these configurations were initially added as "escape hatches" to guard against bugs in the external operators, but these operators are now mature and well-tested. In addition, these configurations are not handled in a consistent way anymore: SQL's Tungsten codepath ignores these configurations and will continue to use spilling operators. Similarly, Spark Core's `tungsten-sort` shuffle manager does not respect `spark.shuffle.spill=false`. This pull request removes these configurations, adds warnings at the appropriate places, and deletes a large amount of code which was only used in code paths that did not support spilling. Author: Josh Rosen Closes #8831 from JoshRosen/remove-ability-to-disable-spilling. --- .../scala/org/apache/spark/Aggregator.scala | 59 +++++-------------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 40 ++++--------- .../shuffle/hash/HashShuffleManager.scala | 8 ++- .../shuffle/sort/SortShuffleManager.scala | 10 +++- .../util/collection/ExternalSorter.scala | 6 -- .../spark/deploy/SparkSubmitSuite.scala | 22 +++---- docs/configuration.md | 14 +---- docs/sql-programming-guide.md | 7 --- python/pyspark/rdd.py | 25 +++----- python/pyspark/shuffle.py | 30 ---------- python/pyspark/tests.py | 13 +--- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +-- .../spark/sql/execution/SparkStrategies.scala | 2 - .../apache/spark/sql/execution/commands.scala | 9 +++ .../org/apache/spark/sql/execution/sort.scala | 30 +--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++------ .../execution/RowFormatConvertersSuite.scala | 2 +- .../spark/sql/execution/SortSuite.scala | 4 +- 18 files changed, 81 insertions(+), 234 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 289aab9bd9e51..7196e57d5d2e2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,59 +34,30 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. - private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } /** Update task metrics after populating the external map. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7bad749d58327..935c3babd8ea1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -128,8 +128,6 @@ class CoGroupedRDD[K: ClassTag]( override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -150,34 +148,16 @@ class CoGroupedRDD[K: ClassTag]( rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409dd..0b46634b8b466 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,13 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). */ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b8..476cc1f303da7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,11 +19,17 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 31230d5978b2a..2a30f751ff03d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -116,8 +116,6 @@ private[spark] class ExternalSorter[K, V, C]( private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -229,10 +227,6 @@ private[spark] class ExternalSorter[K, V, C]( * @param usingMap whether we're using a map or buffer as our current in-memory collection */ private def maybeSpillCollection(usingMap: Boolean): Unit = { - if (!spillingEnabled) { - return - } - var estimatedSize = 0L if (usingMap) { estimatedSize = map.estimateSize() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1110ca6051a40..1fd470cd3b01d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -147,7 +147,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -166,7 +166,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") sysProps.keys should not contain ("spark.jars") } @@ -185,7 +185,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -206,7 +206,7 @@ class SparkSubmitSuite sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone cluster mode") { @@ -229,7 +229,7 @@ class SparkSubmitSuite "--supervise", "--driver-memory", "4g", "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -253,9 +253,9 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.memory") sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.ui.enabled") sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -266,7 +266,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -277,7 +277,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -288,7 +288,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -299,7 +299,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { diff --git a/docs/configuration.md b/docs/configuration.md index 3700051efb448..5ec097c78aa38 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -69,7 +69,7 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false +./bin/spark-submit --name "My app" --master local[4] --conf spark.eventLog.enabled=false --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} @@ -449,8 +449,8 @@ Apart from these, the following properties are also available, and may be useful - - - - - diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 82d4243cc6b27..7ae9244c271e3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1936,13 +1936,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - - - - -
Property NameDefaultMeaning
spark.streaming.backpressure.enabledfalse + Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). + This enables the Spark Streaming to control the receiving rate based on the + current batch scheduling delays and processing times so that the system receives + only as fast as the system can process. Internally, this dynamically sets the + maximum receiving rate of receivers. This rate is upper bounded by the values + `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + if they are set (see below). +
spark.streaming.blockInterval 200msspark.shuffle.manager sort - Implementation to use for shuffling data. There are two implementations available: - sort and hash. Sort-based shuffle is more memory-efficient and is - the default option starting in 1.2. + Implementation to use for shuffling data. There are three implementations available: + sort, hash and the new (1.5+) tungsten-sort. + Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. + Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly + implementation with a fall back to regular sort based shuffle if its requirements are not + met.
spark.scheduler.minRegisteredResourcesRatio0.8 for YARN mode; 0.0 otherwise0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode, CPU cores in standalone mode) + (resources are executors in yarn mode, CPU cores in standalone mode and Mesos coarsed-grained + mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config From d88abb7e212fb55f9b0398a0f76a753c86b85cf1 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 10 Sep 2015 12:06:49 -0700 Subject: [PATCH 0040/2089] [SPARK-9990] [SQL] Create local hash join operator This PR includes the following changes: - Add SQLConf to LocalNode - Add HashJoinNode - Add ConvertToUnsafeNode and ConvertToSafeNode.scala to test unsafe hash join. Author: zsxwing Closes #8535 from zsxwing/SPARK-9990. --- .../sql/execution/joins/HashedRelation.scala | 4 +- .../execution/local/ConvertToSafeNode.scala | 40 +++++ .../execution/local/ConvertToUnsafeNode.scala | 40 +++++ .../sql/execution/local/FilterNode.scala | 4 +- .../sql/execution/local/HashJoinNode.scala | 137 ++++++++++++++++++ .../spark/sql/execution/local/LimitNode.scala | 3 +- .../spark/sql/execution/local/LocalNode.scala | 83 ++++++++++- .../sql/execution/local/ProjectNode.scala | 4 +- .../sql/execution/local/SeqScanNode.scala | 4 +- .../spark/sql/execution/local/UnionNode.scala | 3 +- .../sql/execution/local/FilterNodeSuite.scala | 4 +- .../execution/local/HashJoinNodeSuite.scala | 130 +++++++++++++++++ .../sql/execution/local/LimitNodeSuite.scala | 4 +- .../sql/execution/local/LocalNodeTest.scala | 9 +- .../execution/local/ProjectNodeSuite.scala | 4 +- .../sql/execution/local/UnionNodeSuite.scala | 6 +- 16 files changed, 455 insertions(+), 24 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6c0196c21a0d1..0cff21ca618b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -38,7 +38,7 @@ import org.apache.spark.{SparkConf, SparkEnv} * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[joins] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation { def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by @@ -111,7 +111,7 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. -private[joins] object HashedRelation { +private[execution] object HashedRelation { def apply( input: Iterator[InternalRow], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala new file mode 100644 index 0000000000000..b31c5a863832e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} + +case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToSafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToSafe = FromUnsafeProjection(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToSafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala new file mode 100644 index 0000000000000..de2f4e661ab44 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} + +case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToUnsafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToUnsafe = UnsafeProjection.create(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToUnsafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index 81dd37c7da733..dd1113b6726cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode { +case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) + extends UnaryLocalNode(conf) { private[this] var predicate: (InternalRow) => Boolean = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala new file mode 100644 index 0000000000000..a3e68d6a7c341 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]]. + */ +case class HashJoinNode( + conf: SQLConf, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: LocalNode, + right: LocalNode) extends BinaryLocalNode(conf) { + + private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { + case BuildLeft => (left, leftKeys, right, rightKeys) + case BuildRight => (right, rightKeys, left, leftKeys) + } + + private[this] var currentStreamedRow: InternalRow = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ + private[this] var currentMatchPosition: Int = -1 + + private[this] var joinRow: JoinedRow = _ + private[this] var resultProjection: (InternalRow) => InternalRow = _ + + private[this] var hashed: HashedRelation = _ + private[this] var joinKeys: Projection = _ + + override def output: Seq[Attribute] = left.output ++ right.output + + private[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(schema)) + } + + private[this] def buildSideKeyGenerator: Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildNode.output) + } else { + newMutableProjection(buildKeys, buildNode.output)() + } + } + + private[this] def streamSideKeyGenerator: Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(streamedKeys, streamedNode.output) + } else { + newMutableProjection(streamedKeys, streamedNode.output)() + } + } + + override def open(): Unit = { + buildNode.open() + hashed = HashedRelation.apply( + new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator) + streamedNode.open() + joinRow = new JoinedRow + resultProjection = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + joinKeys = streamSideKeyGenerator + } + + override def next(): Boolean = { + currentMatchPosition += 1 + if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) { + fetchNextMatch() + } else { + true + } + } + + /** + * Populate `currentHashMatches` with build-side rows matching the next streamed row. + * @return whether matches are found such that subsequent calls to `fetch` are valid. + */ + private def fetchNextMatch(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamedNode.next()) { + currentStreamedRow = streamedNode.fetch() + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashed.get(key) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + + override def fetch(): InternalRow = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + resultProjection(ret) + } + + override def close(): Unit = { + left.close() + right.close() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala index fffc52abf6dd5..401b10a5ed307 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode { +case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { private[this] var count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 1c4469acbf264..c4f8ae304db39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Row +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.StructType @@ -29,7 +33,15 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. */ -abstract class LocalNode extends TreeNode[LocalNode] { +abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { + + protected val codegenEnabled: Boolean = conf.codegenEnabled + + protected val unsafeEnabled: Boolean = conf.unsafeEnabled + + lazy val schema: StructType = StructType.fromAttributes(output) + + private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") def output: Seq[Attribute] @@ -73,17 +85,78 @@ abstract class LocalNode extends TreeNode[LocalNode] { } result } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } -abstract class LeafLocalNode extends LocalNode { +abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) { override def children: Seq[LocalNode] = Seq.empty } -abstract class UnaryLocalNode extends LocalNode { +abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) { def child: LocalNode override def children: Seq[LocalNode] = Seq(child) } + +abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { + + def left: LocalNode + + def right: LocalNode + + override def children: Seq[LocalNode] = Seq(left, right) +} + +/** + * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. + */ +private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { + private var nextRow: InternalRow = _ + + override def hasNext: Boolean = { + if (nextRow == null) { + val res = localNode.next() + if (res) { + nextRow = localNode.fetch() + } + res + } else { + true + } + } + + override def next(): InternalRow = { + if (hasNext) { + val res = nextRow + nextRow = null + res + } else { + throw new NoSuchElementException + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index 9b8a4fe493026..11529d6dd9b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} -case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode { +case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) + extends UnaryLocalNode(conf) { private[this] var project: UnsafeProjection = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index 242cb66e07b7f..b8467f6ae58e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute /** * An operator that scans some local data collection in the form of Scala Seq. */ -case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode { +case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow]) + extends LeafLocalNode(conf) { private[this] var iterator: Iterator[InternalRow] = _ private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala index ba4aa7671aebd..0f2b8303e7372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -case class UnionNode(children: Seq[LocalNode]) extends LocalNode { +case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { override def output: Seq[Attribute] = children.head.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index 07209f3779248..a12670e347c25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { val condition = (testData.col("key") % 2) === 0 checkAnswer( testData, - node => FilterNode(condition.expr, node), + node => FilterNode(conf, condition.expr, node), testData.filter(condition).collect() ) } @@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { val condition = (emptyTestData.col("key") % 2) === 0 checkAnswer( emptyTestData, - node => FilterNode(condition.expr, node), + node => FilterNode(conf, condition.expr, node), emptyTestData.filter(condition).collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala new file mode 100644 index 0000000000000..43b6f06aead88 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -0,0 +1,130 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.execution.joins + +class HashJoinNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + if (conf.unsafeEnabled) { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } else { + f + } + } + + def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { + test(s"$suiteName: inner join with one match per row") { + withSQLConf(confPairs: _*) { + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(upperCaseData.col("N").expr), + Seq(lowerCaseData.col("n").expr), + joins.BuildLeft, + node1, + node2) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N").collect() + ) + } + } + + test(s"$suiteName: inner join with multiple matches") { + withSQLConf(confPairs: _*) { + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 1).as("y") + checkAnswer2( + x, + y, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(x.col("a").expr), + Seq(y.col("a").expr), + joins.BuildLeft, + node1, + node2) + ), + x.join(y).where($"x.a" === $"y.a").collect() + ) + } + } + + test(s"$suiteName: inner join, no matches") { + withSQLConf(confPairs: _*) { + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 2).as("y") + checkAnswer2( + x, + y, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(x.col("a").expr), + Seq(y.col("a").expr), + joins.BuildLeft, + node1, + node2) + ), + Nil + ) + } + } + + test(s"$suiteName: big inner join, 4 matches per row") { + withSQLConf(confPairs: _*) { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") + + checkAnswer2( + bigDataX, + bigDataY, + wrapForUnsafe( + (node1, node2) => + HashJoinNode( + conf, + Seq(bigDataX.col("key").expr), + Seq(bigDataY.col("key").expr), + joins.BuildLeft, + node1, + node2) + ), + bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) + } + } + } + + joinSuite( + "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 523c02f4a6014..3b183902007e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { test("basic") { checkAnswer( testData, - node => LimitNode(10, node), + node => LimitNode(conf, 10, node), testData.limit(10).collect() ) } @@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { test("empty") { checkAnswer( emptyTestData, - node => LimitNode(10, node), + node => LimitNode(conf, 10, node), emptyTestData.limit(10).collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 95f06081bd0a8..b95d4ea7f8f2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.local import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} -class LocalNodeTest extends SparkFunSuite { +class LocalNodeTest extends SparkFunSuite with SharedSQLContext { + + def conf: SQLConf = sqlContext.conf /** * Runs the LocalNode and makes sure the answer matches the expected result. @@ -92,6 +94,7 @@ class LocalNodeTest extends SparkFunSuite { protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { new SeqScanNode( + conf, df.queryExecution.sparkPlan.output, df.queryExecution.toRdd.map(_.copy()).collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index ffcf092e2c66a..38e0a230c46d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { val columns = Seq(output(1), output(0)) checkAnswer( testData, - node => ProjectNode(columns, node), + node => ProjectNode(conf, columns, node), testData.select("value", "key").collect() ) } @@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { val columns = Seq(output(1), output(0)) checkAnswer( emptyTestData, - node => ProjectNode(columns, node), + node => ProjectNode(conf, columns, node), emptyTestData.select("value", "key").collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index 34670287c3e1d..eedd7320900f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { checkAnswer2( testData, testData, - (node1, node2) => UnionNode(Seq(node1, node2)), + (node1, node2) => UnionNode(conf, Seq(node1, node2)), testData.unionAll(testData).collect() ) } @@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { checkAnswer2( emptyTestData, emptyTestData, - (node1, node2) => UnionNode(Seq(node1, node2)), + (node1, node2) => UnionNode(conf, Seq(node1, node2)), emptyTestData.unionAll(emptyTestData).collect() ) } @@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { emptyTestData, emptyTestData, testData, emptyTestData) doCheckAnswer( dfs, - nodes => UnionNode(nodes), + nodes => UnionNode(conf, nodes), dfs.reduce(_.unionAll(_)).collect() ) } From 45e3be5c138d983f40f619735d60bf7eb78c9bf0 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 10 Sep 2015 12:21:13 -0700 Subject: [PATCH 0041/2089] [SPARK-10049] [SPARKR] Support collecting data of ArraryType in DataFrame. this PR : 1. Enhance reflection in RBackend. Automatically matching a Java array to Scala Seq when finding methods. Util functions like seq(), listToSeq() in R side can be removed, as they will conflict with the Serde logic that transferrs a Scala seq to R side. 2. Enhance the SerDe to support transferring a Scala seq to R side. Data of ArrayType in DataFrame after collection is observed to be of Scala Seq type. 3. Support ArrayType in createDataFrame(). Author: Sun Rui Closes #8458 from sun-rui/SPARK-10049. --- R/pkg/R/DataFrame.R | 26 ++-- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/column.R | 3 +- R/pkg/R/functions.R | 12 +- R/pkg/R/group.R | 4 +- R/pkg/R/schema.R | 54 +++++--- R/pkg/R/utils.R | 10 -- R/pkg/inst/tests/test_sparkSQL.R | 44 ++++++- .../apache/spark/api/r/RBackendHandler.scala | 121 +++++++++++++----- .../scala/org/apache/spark/api/r/SerDe.scala | 109 ++++++++-------- .../org/apache/spark/sql/api/r/SQLUtils.scala | 14 +- 11 files changed, 250 insertions(+), 151 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a00238b41d60..c3c1893487334 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -271,7 +271,7 @@ setMethod("names<-", signature(x = "DataFrame"), function(x, value) { if (!is.null(value)) { - sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) dataFrame(sdf) } }) @@ -843,10 +843,10 @@ setMethod("groupBy", function(x, ...) { cols <- list(...) if (length(cols) >= 1 && class(cols[[1]]) == "character") { - sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1]) } else { jcol <- lapply(cols, function(c) { c@jc }) - sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + sgd <- callJMethod(x@sdf, "groupBy", jcol) } groupedData(sgd) }) @@ -1079,7 +1079,7 @@ setMethod("subset", signature(x = "DataFrame"), #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { - sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "select", col, list(...)) dataFrame(sdf) }) @@ -1090,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "select", jcols) dataFrame(sdf) }) @@ -1106,7 +1106,7 @@ setMethod("select", col(c)@jc } }) - sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + sdf <- callJMethod(x@sdf, "select", cols) dataFrame(sdf) }) @@ -1133,7 +1133,7 @@ setMethod("selectExpr", signature(x = "DataFrame", expr = "character"), function(x, expr, ...) { exprList <- list(expr, ...) - sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + sdf <- callJMethod(x@sdf, "selectExpr", exprList) dataFrame(sdf) }) @@ -1311,12 +1311,12 @@ setMethod("arrange", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col, ...) { if (class(col) == "character") { - sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "sort", col, list(...)) } else if (class(col) == "Column") { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "sort", jcols) } dataFrame(sdf) }) @@ -1664,7 +1664,7 @@ setMethod("describe", signature(x = "DataFrame", col = "character"), function(x, col, ...) { colList <- list(col, ...) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1674,7 +1674,7 @@ setMethod("describe", signature(x = "DataFrame"), function(x) { colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1731,7 +1731,7 @@ setMethod("dropna", naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", - as.integer(minNonNulls), listToSeq(as.list(cols))) + as.integer(minNonNulls), as.list(cols)) dataFrame(sdf) }) @@ -1815,7 +1815,7 @@ setMethod("fillna", sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) } else { - callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + callJMethod(naFunctions, "fill", value, as.list(cols)) } dataFrame(sdf) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1bc6445311473..4ac057d0f2d83 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -49,7 +49,7 @@ infer_type <- function(x) { stopifnot(length(x) > 0) names <- names(x) if (is.null(names)) { - list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { # StructType types <- lapply(x, infer_type) @@ -59,7 +59,7 @@ infer_type <- function(x) { do.call(structType, fields) } } else if (length(x) > 1) { - list(type = "array", elementType = type, containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { type } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 4805096f3f9c5..42e9d12179db7 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -211,8 +211,7 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - table <- listToSeq(as.list(table)) - jc <- callJMethod(x@jc, "in", table) + jc <- callJMethod(x@jc, "in", as.list(table)) return(column(jc)) }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d848730e70433..94687edb05442 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1331,7 +1331,7 @@ setMethod("countDistinct", x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) + jcol) column(jc) }) @@ -1348,7 +1348,7 @@ setMethod("concat", signature(x = "Column"), function(x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1366,7 +1366,7 @@ setMethod("greatest", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1384,7 +1384,7 @@ setMethod("least", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @export setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jcols <- lapply(list(x, ...), function(x) { x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) column(jc) }) @@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"), #' @export setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jcols <- lapply(list(x, ...), function(arg) { arg@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "format_string", format, jcols) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 576ac72f40fc0..4cab1a69f601a 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -102,7 +102,7 @@ setMethod("agg", } } jcols <- lapply(cols, function(c) { c@jc }) - sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1]) } else { stop("agg can only support Column or character") } @@ -124,7 +124,7 @@ createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), function(x, ...) { - sdf <- callJMethod(x@sgd, name, toSeq(...)) + sdf <- callJMethod(x@sgd, name, list(...)) dataFrame(sdf) }) } diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 79c744ef29c23..62d4f73878d29 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) { }) stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructType", - listToSeq(sfObjList)) + sfObjList) structType(stObj) } @@ -114,6 +114,35 @@ structField.jobj <- function(x) { obj } +checkType <- function(type) { + primtiveTypes <- c("byte", + "integer", + "float", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + if (type %in% primtiveTypes) { + return() + } else { + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + } + + stop(paste("Unsupported type for Dataframe:", type)) +} + structField.character <- function(x, type, nullable = TRUE) { if (class(x) != "character") { stop("Field name must be a string.") @@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) { if (class(nullable) != "logical") { stop("nullable must be either TRUE or FALSE") } - options <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - dataType <- if (type %in% options) { - type - } else { - stop(paste("Unsupported type for Dataframe:", type)) - } + + checkType(type) + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructField", x, - dataType, + type, nullable) structField(sfObj) } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3babcb519378e..69a2bc728f842 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -361,16 +361,6 @@ numToInt <- function(num) { as.integer(num) } -# create a Seq in JVM -toSeq <- function(...) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) -} - -# create a Seq in JVM from a list -listToSeq <- function(l) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) -} - # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a # user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6d331f9883d55..1ccfde59176f5 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -49,6 +49,14 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesNa, jsonPathNa) +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -56,10 +64,8 @@ test_that("infer types", { expect_equal(infer_type(TRUE), "boolean") expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") testStruct <- infer_type(list(a = 1L, b = "2")) expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) @@ -236,8 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -# TODO: enable this test after fix serialization for nested object -#test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and struct", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -247,7 +252,32 @@ test_that("create DataFrame with different data types", { # expect_equal(count(df), 1) # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) -#}) + + + # ArrayType only for now + l <- list(as.list(1:10), list("a", "b")) + df <- createDataFrame(sqlContext, list(l), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) +}) + +test_that("Collect DataFrame with complex types", { + # only ArrayType now + # TODO: tests for StructType and MapType after they are supported + df <- jsonFile(sqlContext, complexTypeJsonPath) + + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) +}) test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index bb82f3285f1d9..2a792d81994fd 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend) val methods = cls.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val methods = selectedMethods.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - } - if (methods.isEmpty) { + val index = findMatchedSignature( + selectedMethods.map(_.getParameterTypes), + args) + + if (index.isEmpty) { logWarning(s"cannot find matching method ${cls}.$methodName. " + s"Candidates are:") selectedMethods.foreach { method => @@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args : _*) + + val ret = selectedMethods(index.get).invoke(obj, args : _*) // Write status bit writeInt(dos, 0) writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head + val ctors = cls.getConstructors + val index = findMatchedSignature( + ctors.map(_.getParameterTypes), + args) - val obj = ctor.newInstance(args : _*) + if (index.isEmpty) { + logWarning(s"cannot find matching constructor for ${cls}. " + + s"Candidates are:") + ctors.foreach { ctor => + logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched constructor found for $cls") + } + + val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) @@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => + (0 until numArgs).map { _ => readObject(dis) }.toArray } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- 0 until parameterTypesOfMethods.length) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } - for (i <- 0 to numArgs - 1) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Integer] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. + + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + (0 until numArgs).map { i => + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) } - } - if (!parameterWrapperType.isInstance(args(i))) { - return false } } - true + None } } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 190e193427af8..3c92bb7a1c73c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray /** * Utility functions to serialize, deserialize objects to / from R @@ -213,89 +214,97 @@ private[spark] object SerDe { } } - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null) { + def writeObject(dos: DataOutputStream, obj: Object): Unit = { + if (obj == null) { writeType(dos, "void") } else { - value.getClass.getName match { - case "java.lang.Character" => + // Convert ArrayType collected from DataFrame to Java array + // Collected data of ArrayType from a DataFrame is observed to be of + // type "scala.collection.mutable.WrappedArray" + val value = + if (obj.isInstanceOf[WrappedArray[_]]) { + obj.asInstanceOf[WrappedArray[_]].toArray + } else { + obj + } + + value match { + case v: java.lang.Character => writeType(dos, "character") - writeString(dos, value.asInstanceOf[Character].toString) - case "java.lang.String" => + writeString(dos, v.toString) + case v: java.lang.String => writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "java.lang.Long" => + writeString(dos, v) + case v: java.lang.Long => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "java.lang.Float" => + writeDouble(dos, v.toDouble) + case v: java.lang.Float => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "java.math.BigDecimal" => + writeDouble(dos, v.toDouble) + case v: java.math.BigDecimal => writeType(dos, "double") - val javaDecimal = value.asInstanceOf[java.math.BigDecimal] - writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) - case "java.lang.Double" => + writeDouble(dos, scala.math.BigDecimal(v).toDouble) + case v: java.lang.Double => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "java.lang.Byte" => + writeDouble(dos, v) + case v: java.lang.Byte => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Byte].toInt) - case "java.lang.Short" => + writeInt(dos, v.toInt) + case v: java.lang.Short => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Short].toInt) - case "java.lang.Integer" => + writeInt(dos, v.toInt) + case v: java.lang.Integer => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "java.lang.Boolean" => + writeInt(dos, v) + case v: java.lang.Boolean => writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => + writeBoolean(dos, v) + case v: java.sql.Date => writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => + writeDate(dos, v) + case v: java.sql.Time => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - case "java.sql.Timestamp" => + writeTime(dos, v) + case v: java.sql.Timestamp => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) + writeTime(dos, v) // Handle arrays // Array of primitive types // Special handling for byte array - case "[B" => + case v: Array[Byte] => writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) + writeBytes(dos, v) - case "[C" => + case v: Array[Char] => writeType(dos, "array") - writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) - case "[S" => + writeStringArr(dos, v.map(_.toString)) + case v: Array[Short] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) - case "[I" => + writeIntArr(dos, v.map(_.toInt)) + case v: Array[Int] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => + writeIntArr(dos, v) + case v: Array[Long] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) - case "[F" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Float] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble)) - case "[D" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Double] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[Z" => + writeDoubleArr(dos, v) + case v: Array[Boolean] => writeType(dos, "array") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + writeBooleanArr(dos, v) // Array of objects, null objects use "void" type - case c if c.startsWith("[") => + case v: Array[Object] => writeType(dos, "list") - val array = value.asInstanceOf[Array[Object]] - writeInt(dos, array.length) - array.foreach(elem => writeObject(dos, elem)) + writeInt(dos, v.length) + v.foreach(elem => writeObject(dos, elem)) case _ => writeType(dos, "jobj") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 7f3defec3d42e..d4b834adb6e39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpres import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} +import scala.util.matching.Regex + private[r] object SQLUtils { def createSQLContext(jsc: JavaSparkContext): SQLContext = { new SQLContext(jsc) @@ -35,14 +37,15 @@ private[r] object SQLUtils { new JavaSparkContext(sqlCtx.sparkContext) } - def toSeq[T](arr: Array[T]): Seq[T] = { - arr.toSeq - } - def createStructType(fields : Seq[StructField]): StructType = { StructType(fields) } + // Support using regex in string interpolation + private[this] implicit class RegexContext(sc: StringContext) { + def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) + } + def getSQLDataType(dataType: String): DataType = { dataType match { case "byte" => org.apache.spark.sql.types.ByteType @@ -58,6 +61,9 @@ private[r] object SQLUtils { case "boolean" => org.apache.spark.sql.types.BooleanType case "timestamp" => org.apache.spark.sql.types.TimestampType case "date" => org.apache.spark.sql.types.DateType + case r"\Aarray<(.*)${elemType}>\Z" => { + org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } From 3db72554be3f13478ccd5915e746491982163298 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 10 Sep 2015 13:22:35 -0700 Subject: [PATCH 0042/2089] [SPARK-10443] [SQL] Refactor SortMergeOuterJoin to reduce duplication `LeftOutputIterator` and `RightOutputIterator` are symmetrically identical and can share a lot of code. If someone makes a change in one but forgets to do the same thing in the other we'll end up with inconsistent behavior. This patch also adds inline comments to clarify the intention of the code. Author: Andrew Or Closes #8596 from andrewor14/smoj-cleanup. --- .../execution/joins/SortMergeOuterJoin.scala | 138 ++++++++++-------- 1 file changed, 77 insertions(+), 61 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index ab20ee573ab5d..c117dff9c8b1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -199,99 +199,115 @@ case class SortMergeOuterJoin( } } - +/** + * An iterator for outputting rows in left outer join. + */ private class LeftOuterIterator( smjScanner: SortMergeJoinScanner, rightNullRow: InternalRow, boundCondition: InternalRow => Boolean, resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var rightIdx: Int = 0 - assert(smjScanner.getBufferedMatches.length == 0) - - private def advanceLeft(): Boolean = { - rightIdx = 0 - if (smjScanner.findNextOuterJoinRows()) { - joinedRow.withLeft(smjScanner.getStreamedRow) - if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching right rows, so return nulls for the right row - joinedRow.withRight(rightNullRow) - } else { - // Find the next row from the right input that satisfied the bound condition - if (!advanceRightUntilBoundConditionSatisfied()) { - joinedRow.withRight(rightNullRow) - } - } - true - } else { - // Left input has been exhausted - false - } - } - - private def advanceRightUntilBoundConditionSatisfied(): Boolean = { - var foundMatch: Boolean = false - while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) - rightIdx += 1 - } - foundMatch - } - - override def advanceNext(): Boolean = { - val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() - if (r) numRows += 1 - r - } + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { - override def getRow: InternalRow = resultProj(joinedRow) + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) } +/** + * An iterator for outputting rows in right outer join. + */ private class RightOuterIterator( smjScanner: SortMergeJoinScanner, leftNullRow: InternalRow, boundCondition: InternalRow => Boolean, resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var leftIdx: Int = 0 + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) +} + +/** + * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. + * + * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the + * streamed side will output 0 or many rows, one for each matching row on the buffered side. + * If there are no matches, then the buffered side of the joined output will be a null row. + * + * In left outer join, the left is the streamed side and the right is the buffered side. + * In right outer join, the right is the streamed side and the left is the buffered side. + * + * @param smjScanner a scanner that streams rows and buffers any matching rows + * @param bufferedSideNullRow the default row to return when a streamed row has no matches + * @param boundCondition an additional filter condition for buffered rows + * @param resultProj how the output should be projected + * @param numOutputRows an accumulator metric for the number of rows output + */ +private abstract class OneSideOuterIterator( + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { + + // A row to store the joined result, reused many times + protected[this] val joinedRow: JoinedRow = new JoinedRow() + + // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row + private[this] var bufferIndex: Int = 0 + + // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) - private def advanceRight(): Boolean = { - leftIdx = 0 + // Set output methods to be overridden by subclasses + protected def setStreamSideOutput(row: InternalRow): Unit + protected def setBufferedSideOutput(row: InternalRow): Unit + + /** + * Advance to the next row on the stream side and populate the buffer with matches. + * @return whether there are more rows in the stream to consume. + */ + private def advanceStream(): Boolean = { + bufferIndex = 0 if (smjScanner.findNextOuterJoinRows()) { - joinedRow.withRight(smjScanner.getStreamedRow) + setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching left rows, so return nulls for the left row - joinedRow.withLeft(leftNullRow) + // There are no matching rows in the buffer, so return the null row + setBufferedSideOutput(bufferedSideNullRow) } else { - // Find the next row from the left input that satisfied the bound condition - if (!advanceLeftUntilBoundConditionSatisfied()) { - joinedRow.withLeft(leftNullRow) + // Find the next row in the buffer that satisfied the bound condition + if (!advanceBufferUntilBoundConditionSatisfied()) { + setBufferedSideOutput(bufferedSideNullRow) } } true } else { - // Right input has been exhausted + // Stream has been exhausted false } } - private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + /** + * Advance to the next row in the buffer that satisfies the bound condition. + * @return whether there is such a row in the current buffer. + */ + private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) - leftIdx += 1 + while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { + setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + foundMatch = boundCondition(joinedRow) + bufferIndex += 1 } foundMatch } override def advanceNext(): Boolean = { - val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() - if (r) numRows += 1 + val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() + if (r) numOutputRows += 1 r } From 4204757714b36364611fb63bc008cf90fc53d8df Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 10 Sep 2015 13:43:13 -0700 Subject: [PATCH 0043/2089] Add 1.5 to master branch EC2 scripts This change brings it to par with `branch-1.5` (and 1.5.0 release) Author: Shivaram Venkataraman Closes #8704 from shivaram/ec2-1.5-update. --- ec2/spark_ec2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 11fd7ee0ec8df..3a2361c6d6d2b 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.4.0" +SPARK_EC2_VERSION = "1.5.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -71,6 +71,8 @@ "1.3.0", "1.3.1", "1.4.0", + "1.4.1", + "1.5.0" ]) SPARK_TACHYON_MAP = { @@ -84,6 +86,8 @@ "1.3.0": "0.5.0", "1.3.1": "0.5.0", "1.4.0": "0.6.4", + "1.4.1": "0.6.4", + "1.5.0": "0.7.1" } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -91,7 +95,7 @@ # Default location to get the spark-ec2 scripts (and ami-list) from DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" def setup_external_libs(libs): From 89562a172fd3efa032f60714d600407c6cfe2c2f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Sep 2015 13:54:20 -0700 Subject: [PATCH 0044/2089] [SPARK-7544] [SQL] [PySpark] pyspark.sql.types.Row implements __getitem__ pyspark.sql.types.Row implements ```__getitem__``` Author: Yanbo Liang Closes #8333 from yanboliang/spark-7544. --- python/pyspark/sql/types.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8bd58d69eeecd..1f86894855cbe 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1176,6 +1176,8 @@ class Row(tuple): >>> row = Row(name="Alice", age=11) >>> row Row(age=11, name='Alice') + >>> row['name'], row['age'] + ('Alice', 11) >>> row.name, row.age ('Alice', 11) @@ -1243,6 +1245,19 @@ def __call__(self, *args): """create new Row object""" return _create_row(self, args) + def __getitem__(self, item): + if isinstance(item, (int, slice)): + return super(Row, self).__getitem__(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__fields__.index(item) + return super(Row, self).__getitem__(idx) + except IndexError: + raise KeyError(item) + except ValueError: + raise ValueError(item) + def __getattr__(self, item): if item.startswith("__"): raise AttributeError(item) From 0eabea8a058ad60411c1384930ba12c1c638f5f1 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Thu, 10 Sep 2015 17:24:33 -0700 Subject: [PATCH 0045/2089] [SPARK-9043] Serialize key, value and combiner classes in ShuffleDependency ShuffleManager implementations are currently not given type information for the key, value and combiner classes. Serialization of shuffle objects relies on objects being JavaSerializable, with methods defined for reading/writing the object or, alternatively, serialization via Kryo which uses reflection. Serialization systems like Avro, Thrift and Protobuf generate classes with zero argument constructors and explicit schema information (e.g. IndexedRecords in Avro have get, put and getSchema methods). By serializing the key, value and combiner class names in ShuffleDependency, shuffle implementations will have access to schema information when registerShuffle() is called. Author: Matt Massie Closes #7403 from massie/shuffle-classtags. --- .../scala/org/apache/spark/bagel/Bagel.scala | 2 +- .../scala/org/apache/spark/Dependency.scala | 11 ++- .../apache/spark/api/java/JavaPairRDD.scala | 2 +- .../org/apache/spark/rdd/CoGroupedRDD.scala | 5 +- .../apache/spark/rdd/PairRDDFunctions.scala | 90 ++++++++++++++++--- .../org/apache/spark/rdd/ShuffledRDD.scala | 4 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 8 +- .../org/apache/spark/CheckpointSuite.scala | 2 +- .../shuffle/ShuffleDependencySuite.scala | 67 ++++++++++++++ 9 files changed, 168 insertions(+), 23 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index ef0bb2ac13f08..4e6b7686f771d 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -78,7 +78,7 @@ object Bagel extends Logging { val startTime = System.currentTimeMillis val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( + val combinedMsgs = msgs.combineByKeyWithClassTag( combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) val grouped = combinedMsgs.groupWith(verts) val superstep_ = superstep // Create a read-only copy of superstep for capture in closure diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index cfeeb3902c033..9aafc9eb1cde7 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -65,7 +67,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) */ @DeveloperApi -class ShuffleDependency[K, V, C]( +class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( @transient private val _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, @@ -76,6 +78,13 @@ class ShuffleDependency[K, V, C]( override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] + private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName + private[spark] val valueClassName: String = reflect.classTag[V].runtimeClass.getName + // Note: It's possible that the combiner class tag is null, if the combineByKey + // methods in PairRDDFunctions are used instead of combineByKeyWithClassTag. + private[spark] val combinerClassName: Option[String] = + Option(reflect.classTag[C]).map(_.runtimeClass.getName) + val shuffleId: Int = _rdd.context.newShuffleId() val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index fb787979c1820..8344f6368ac48 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -239,7 +239,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapSideCombine: Boolean, serializer: Serializer): JavaPairRDD[K, C] = { implicit val ctag: ClassTag[C] = fakeClassTag - fromRDD(rdd.combineByKey( + fromRDD(rdd.combineByKeyWithClassTag( createCombiner, mergeValue, mergeCombiners, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9c617fc719cb5..7bad749d58327 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -22,6 +22,7 @@ import scala.language.existentials import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -74,7 +75,9 @@ private[spark] class CoGroupPartition( * @param part partitioner used to partition the shuffle output */ @DeveloperApi -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) +class CoGroupedRDD[K: ClassTag]( + @transient var rdds: Seq[RDD[_ <: Product2[K, _]]], + part: Partitioner) extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs). diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 4e5f2e8a5d467..c59f0d4aa75a0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -57,7 +57,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) with SparkHadoopMapReduceUtil with Serializable { + /** + * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type @@ -70,12 +72,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). */ - def combineByKey[C](createCombiner: V => C, + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = self.withScope { + serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -103,13 +107,50 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] */ - def combineByKey[C](createCombiner: V => C, + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializer: Serializer = null): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + partitioner, mapSideCombine, serializer)(null) + } + + /** + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + * This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, numPartitions: Int): RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, numPartitions)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + */ + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + new HashPartitioner(numPartitions)) } /** @@ -133,7 +174,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // We will clean the combiner closure later in `combineByKey` val cleanedSeqOp = self.context.clean(seqOp) - combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner) + combineByKeyWithClassTag[U]((v: V) => cleanedSeqOp(createZero(), v), + cleanedSeqOp, combOp, partitioner) } /** @@ -182,7 +224,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) val cleanedFunc = self.context.clean(func) - combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner) + combineByKeyWithClassTag[V]((v: V) => cleanedFunc(createZero(), v), + cleanedFunc, cleanedFunc, partitioner) } /** @@ -268,7 +311,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { - combineByKey[V]((v: V) => v, func, func, partitioner) + combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** @@ -392,7 +435,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) h1 } - combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.cardinality()) + combineByKeyWithClassTag(createHLL, mergeValueHLL, mergeHLL, partitioner) + .mapValues(_.cardinality()) } /** @@ -466,7 +510,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createCombiner = (v: V) => CompactBuffer(v) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKey[CompactBuffer[V]]( + val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -565,12 +609,30 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. This method is here for backward compatibility. It + * does not provide combiner classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. */ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 2dc47f95937cb..cb15d912bbfb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -37,7 +39,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { */ // TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C]( +class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[(K, C)](prev.context, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a4fa301b06e3..25ec685eff5ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -63,15 +63,17 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => + def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]]) + : Dependency[_] = { if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializer) + new ShuffleDependency[T1, T2, Any](rdd, part, serializer) } } + Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2)) } override def getPartitions: Array[Partition] = { @@ -105,7 +107,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = { dependencies(depNum) match { case oneToOneDependency: OneToOneDependency[_] => val dependencyPartition = partition.narrowDeps(depNum).get.split diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d343bb95cb68c..4d70bfed909b6 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -483,7 +483,7 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + def cogroup[K: ClassTag, V: ClassTag](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala new file mode 100644 index 0000000000000..4d5f599fb12ab --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark._ + +case class KeyClass() + +case class ValueClass() + +case class CombinerClass() + +class ShuffleDependencySuite extends SparkFunSuite with LocalSparkContext { + + val conf = new SparkConf(loadDefaults = false) + + test("key, value, and combiner classes correct in shuffle dependency without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .groupByKey() + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!dep.mapSideCombine, "Test requires that no map-side aggregator is defined") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + } + + test("key, value, and combiner classes available in shuffle dependency with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .aggregateByKey(CombinerClass())({ case (a, b) => a }, { case (a, b) => a }) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.mapSideCombine && dep.aggregator.isDefined, "Test requires map-side aggregation") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == Some(classOf[CombinerClass].getName)) + } + + test("combineByKey null combiner class tag handled correctly") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .combineByKey((v: ValueClass) => v, + (c: AnyRef, v: ValueClass) => c, + (c1: AnyRef, c2: AnyRef) => c1) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == None) + } + +} From 339a527141984bfb182862b0987d3c4690c9ede1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Sep 2015 20:34:00 -0700 Subject: [PATCH 0046/2089] [SPARK-10023] [ML] [PySpark] Unified DecisionTreeParams checkpointInterval between Scala and Python API. "checkpointInterval" is member of DecisionTreeParams in Scala API which is inconsistency with Python API, we should unified them. ``` member of DecisionTreeParams <-> Scala API shared param for all ML Transformer/Estimator <-> Python API ``` Proposal: "checkpointInterval" is also used by ALS, so we make it shared params at Scala. Author: Yanbo Liang Closes #8528 from yanboliang/spark-10023. --- .../DecisionTreeClassifier.scala | 1 + .../ml/param/shared/SharedParamsCodeGen.scala | 3 +- .../spark/ml/param/shared/sharedParams.scala | 4 +-- .../org/apache/spark/ml/tree/treeParams.scala | 32 +++++++------------ 4 files changed, 16 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 6f70b96b17ec6..0a75d5d22280f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasCheckpointInterval import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8c16c6149b40d..e9e99ed1db40e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -56,7 +56,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), - ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", + ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " + + "the cache will get checkpointed every 10 iterations.", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index c26768953e3db..30092170863ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params { private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval (>= 1). + * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1)) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index dbd8d31571d2e..d29f5253c9c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams { +private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { /** * Maximum depth of the tree (>= 0). @@ -96,21 +96,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams { " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - /** - * Specifies how often to checkpoint the cached node IDs. - * E.g. 10 means that the cache will get checkpointed every 10 iterations. - * This is only used if cacheNodeIds is true and if the checkpoint directory is set in - * [[org.apache.spark.SparkContext]]. - * Must be >= 1. - * (default = 10) - * @group expertParam - */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + - " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + - " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext. Must be >= 1.", - ParamValidators.gtEq(1)) - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) @@ -150,12 +135,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** @group expertSetParam */ + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertSetParam + */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** @group expertGetParam */ - final def getCheckpointInterval: Int = $(checkpointInterval) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], From a140dd77c62255d6f7f6817a2517d47feb8540d4 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Sep 2015 20:43:38 -0700 Subject: [PATCH 0047/2089] [SPARK-10027] [ML] [PySpark] Add Python API missing methods for ml.feature Missing method of ml.feature are listed here: ```StringIndexer``` lacks of parameter ```handleInvalid```. ```StringIndexerModel``` lacks of method ```labels```. ```VectorIndexerModel``` lacks of methods ```numFeatures``` and ```categoryMaps```. Author: Yanbo Liang Closes #8313 from yanboliang/spark-10027. --- python/pyspark/ml/feature.py | 31 ++++++++++++++++--- .../ml/param/_shared_params_code_gen.py | 5 ++- python/pyspark/ml/param/shared.py | 31 +++++++++++++++++-- 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1c423486be8d9..71dc636b83eac 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -920,7 +920,7 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): """ .. note:: Experimental @@ -943,19 +943,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): """ @keyword_only - def __init__(self, inputCol=None, outputCol=None): + def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCol=None, outputCol=None) + __init__(self, inputCol=None, outputCol=None, handleInvalid="error") """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) + self._setDefault(handleInvalid="error") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol=None, outputCol=None): + def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCol=None, outputCol=None) + setParams(self, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this StringIndexer. """ kwargs = self.setParams._input_kwargs @@ -1235,6 +1236,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> model = indexer.fit(df) >>> model.transform(df).head().indexed DenseVector([1.0, 0.0]) + >>> model.numFeatures + 2 + >>> model.categoryMaps + {0: {0.0: 0, -1.0: 1}} >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test DenseVector([0.0, 1.0]) >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"} @@ -1297,6 +1302,22 @@ class VectorIndexerModel(JavaModel): Model fitted by VectorIndexer. """ + @property + def numFeatures(self): + """ + Number of features, i.e., length of Vectors which this transforms. + """ + return self._call_java("numFeatures") + + @property + def categoryMaps(self): + """ + Feature value index. Keys are categorical feature indices (column indices). + Values are maps from original features values to 0-based category indices. + If a feature is not in this map, it is treated as continuous. + """ + return self._call_java("javaCategoryMaps") + @inherit_doc class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 69efc424ec4ef..926375e44871d 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -121,7 +121,10 @@ def get$Name(self): ("checkpointInterval", "checkpoint interval (>= 1)", None), ("seed", "random seed", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms", None), - ("stepSize", "Step size to be used for each iteration of optimization.", None)] + ("stepSize", "Step size to be used for each iteration of optimization.", None), + ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + + "out rows with bad values), or error (which will throw an errror). More options may be " + + "added later.", None)] code = [] for name, doc, defaultValueStr in shared: param_code = _gen_param_header(name, doc, defaultValueStr) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 595124726366d..682170aee85fb 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -432,6 +432,33 @@ def getStepSize(self): return self.getOrDefault(self.stepSize) +class HasHandleInvalid(Params): + """ + Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + """ + + # a placeholder to make it appear in the generated doc + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + + def __init__(self): + super(HasHandleInvalid, self).__init__() + #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + self._paramMap[self.handleInvalid] = value + return self + + def getHandleInvalid(self): + """ + Gets the value of handleInvalid or its default value. + """ + return self.getOrDefault(self.handleInvalid) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. @@ -444,7 +471,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -460,7 +487,7 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. From e1d7f64296d8164cc4c935aa524af664f63bf9c1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 11 Sep 2015 18:26:56 +0800 Subject: [PATCH 0048/2089] [SPARK-10472] [SQL] Fixes DataType.typeName for UDT Before this fix, `MyDenseVectorUDT.typeName` gives `mydensevecto`, which is not desirable. Author: Cheng Lian Closes #8640 from liancheng/spark-10472/udt-type-name. --- .../main/scala/org/apache/spark/sql/types/DataType.scala | 4 +++- .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7bcd623b3f33e..4b54c31dcc27a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -51,7 +51,9 @@ abstract class DataType extends AbstractDataType { def defaultSize: Int /** Name of the type used in JSON serialization. */ - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + def typeName: String = { + this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + } private[sql] def jsonValue: JValue = typeName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fa8f9c8e00089..46d87843dfa4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -157,4 +157,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { Nil ) } + + test("SPARK-10472 UserDefinedType.typeName") { + assert(IntegerType.typeName === "integer") + assert(new MyDenseVectorUDT().typeName === "mydensevector") + assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") + } } From 9bbe33f318c866c0b13088917542715062f0787f Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Fri, 11 Sep 2015 13:06:14 +0100 Subject: [PATCH 0049/2089] [SPARK-10556] Remove explicit Scala version for sbt project build files Previously, project/plugins.sbt explicitly set scalaVersion to 2.10.4. This can cause issues when using a version of sbt that is compiled against a different version of Scala (for example sbt 0.13.9 uses 2.10.5). Removing this explicit setting will cause build files to be compiled and run against the same version of Scala that sbt is compiled against. Note that this only applies to the project build files (items in project/), it is distinct from the version of Scala we target for the actual spark compilation. Author: Ahir Reddy Closes #8709 from ahirreddy/sbt-scala-version-fix. --- project/plugins.sbt | 2 -- 1 file changed, 2 deletions(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 51820460ca1a0..c06687d8f197b 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,5 +1,3 @@ -scalaVersion := "2.10.4" - resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" From c268ca4ddde2f5213b2e3985dcaaac5900aea71c Mon Sep 17 00:00:00 2001 From: y-shimizu Date: Fri, 11 Sep 2015 08:27:30 -0700 Subject: [PATCH 0050/2089] [SPARK-10518] [DOCS] Update code examples in spark.ml user guide to use LIBSVM data source instead of MLUtils I fixed to use LIBSVM data source in the example code in spark.ml instead of MLUtils Author: y-shimizu Closes #8697 from y-shimizu/SPARK-10518. --- docs/ml-ensembles.md | 65 ++++++++++++--------------------------- docs/ml-features.md | 64 +++++++++++++------------------------- docs/ml-linear-methods.md | 22 ++++--------- 3 files changed, 47 insertions(+), 104 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 62749909e01dc..58f566c9b4b55 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -121,10 +121,9 @@ import org.apache.spark.ml.classification.RandomForestClassifier import org.apache.spark.ml.classification.RandomForestClassificationModel import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -193,14 +192,11 @@ import org.apache.spark.ml.classification.RandomForestClassifier; import org.apache.spark.ml.classification.RandomForestClassificationModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -268,10 +264,9 @@ from pyspark.ml import Pipeline from pyspark.ml.classification import RandomForestClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. @@ -327,10 +322,9 @@ import org.apache.spark.ml.regression.RandomForestRegressor import org.apache.spark.ml.regression.RandomForestRegressionModel import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -387,14 +381,11 @@ import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.RandomForestRegressionModel; import org.apache.spark.ml.regression.RandomForestRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -450,10 +441,9 @@ from pyspark.ml import Pipeline from pyspark.ml.regression import RandomForestRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -576,10 +566,9 @@ import org.apache.spark.ml.classification.GBTClassifier import org.apache.spark.ml.classification.GBTClassificationModel import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -648,14 +637,10 @@ import org.apache.spark.ml.classification.GBTClassifier; import org.apache.spark.ml.classification.GBTClassificationModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -724,10 +709,9 @@ from pyspark.ml import Pipeline from pyspark.ml.classification import GBTClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. @@ -783,10 +767,9 @@ import org.apache.spark.ml.regression.GBTRegressor import org.apache.spark.ml.regression.GBTRegressionModel import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -844,14 +827,10 @@ import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.GBTRegressionModel; import org.apache.spark.ml.regression.GBTRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -908,10 +887,9 @@ from pyspark.ml import Pipeline from pyspark.ml.regression import GBTRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -970,15 +948,14 @@ Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifie {% highlight scala %} import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Row, SQLContext} val sqlContext = new SQLContext(sc) // parse data into dataframe -val data = MLUtils.loadLibSVMFile(sc, - "data/mllib/sample_multiclass_classification_data.txt") -val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") +val Array(train, test) = data.randomSplit(Array(0.7, 0.3)) // instantiate multiclass learner and train val ovr = new OneVsRest().setClassifier(new LogisticRegression) @@ -1016,9 +993,6 @@ import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -1026,10 +1000,9 @@ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); -RDD data = MLUtils.loadLibSVMFile(jsc.sc(), - "data/mllib/sample_multiclass_classification_data.txt"); +DataFrame dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); DataFrame train = splits[0]; DataFrame test = splits[1]; diff --git a/docs/ml-features.md b/docs/ml-features.md index 58b31a5a5cc47..a414c21b5c280 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1179,9 +1179,9 @@ In the example below, we read in a dataset of labeled points and then use `Vecto
{% highlight scala %} import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -1200,16 +1200,12 @@ val indexedData = indexerModel.transform(data) {% highlight java %} import java.util.Map; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD rdd = MLUtils.loadLibSVMFile(sc.sc(), - "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -1230,9 +1226,9 @@ DataFrame indexedData = indexerModel.transform(data);
{% highlight python %} from pyspark.ml.feature import VectorIndexer -from pyspark.mllib.util import MLUtils -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) indexerModel = indexer.fit(data) @@ -1253,10 +1249,9 @@ The following example demonstrates how to load a dataset in libsvm format and th
{% highlight scala %} import org.apache.spark.ml.feature.Normalizer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") // Normalize each Vector using $L^1$ norm. val normalizer = new Normalizer() @@ -1272,15 +1267,11 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi
{% highlight java %} -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Normalize each Vector using $L^1$ norm. Normalizer normalizer = new Normalizer() @@ -1297,11 +1288,10 @@ DataFrame lInfNormData =
{% highlight python %} -from pyspark.mllib.util import MLUtils from pyspark.ml.feature import Normalizer -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") # Normalize each Vector using $L^1$ norm. normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) @@ -1335,10 +1325,9 @@ The following example demonstrates how to load a dataset in libsvm format and th
{% highlight scala %} import org.apache.spark.ml.feature.StandardScaler -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -1355,16 +1344,12 @@ val scaledData = scalerModel.transform(dataFrame)
{% highlight java %} -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StandardScaler; import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); StandardScaler scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -1381,11 +1366,10 @@ DataFrame scaledData = scalerModel.transform(dataFrame);
{% highlight python %} -from pyspark.mllib.util import MLUtils from pyspark.ml.feature import StandardScaler -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False) @@ -1424,10 +1408,9 @@ More details can be found in the API docs for [MinMaxScalerModel](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel). {% highlight scala %} import org.apache.spark.ml.feature.MinMaxScaler -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -1448,13 +1431,10 @@ More details can be found in the API docs for import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.MinMaxScaler; import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); MinMaxScaler scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaledFeatures"); diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index cdd9d4999fa1b..4e94e2f9c708d 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -59,10 +59,9 @@ $\alpha$ and `regParam` corresponds to $\lambda$.
{% highlight scala %} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.mllib.util.MLUtils // Load training data -val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val lr = new LogisticRegression() .setMaxIter(10) @@ -81,8 +80,6 @@ println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") {% highlight java %} import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.sql.DataFrame; @@ -98,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample { String path = "data/mllib/sample_libsvm_data.txt"; // Load training data - DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + DataFrame training = sqlContext.read.format("libsvm").load(path); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) @@ -118,11 +115,9 @@ public class LogisticRegressionWithElasticNetExample {
{% highlight python %} from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils # Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) @@ -251,10 +246,9 @@ regression model and extracting model summary statistics.
{% highlight scala %} import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.mllib.util.MLUtils // Load training data -val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val lr = new LinearRegression() .setMaxIter(10) @@ -283,8 +277,6 @@ import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.regression.LinearRegressionModel; import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.sql.DataFrame; @@ -300,7 +292,7 @@ public class LinearRegressionWithElasticNetExample { String path = "data/mllib/sample_libsvm_data.txt"; // Load training data - DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + DataFrame training = sqlContext.read.format("libsvm").load(path); LinearRegression lr = new LinearRegression() .setMaxIter(10) @@ -329,11 +321,9 @@ public class LinearRegressionWithElasticNetExample { {% highlight python %} from pyspark.ml.regression import LinearRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils # Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) From b656e6134fc5cd27e1fe6b6ab30fd7633cab0b14 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Sep 2015 08:50:35 -0700 Subject: [PATCH 0051/2089] [SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark LinearRegression and LogisticRegression lack of some Params for Python, and some Params are not shared classes which lead we need to write them for each class. These kinds of Params are list here: ```scala HasElasticNetParam HasFitIntercept HasStandardization HasThresholds ``` Here we implement them in shared params at Python side and make LinearRegression/LogisticRegression parameters peer with Scala one. Author: Yanbo Liang Closes #8508 from yanboliang/spark-10026. --- python/pyspark/ml/classification.py | 75 ++---------- .../ml/param/_shared_params_code_gen.py | 11 +- python/pyspark/ml/param/shared.py | 111 ++++++++++++++++++ python/pyspark/ml/regression.py | 42 ++----- 4 files changed, 143 insertions(+), 96 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83f808efc3bf0..22bdd1b322aca 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -31,7 +31,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): """ Logistic regression. Currently, this class only supports binary classification. @@ -65,17 +66,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") - thresholds = Param(Params._dummy(), "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") @@ -83,40 +73,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification, in range [0, 1]. self.threshold = Param(self, "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") - #: param for thresholds or cutoffs in binary or multiclass classification - self.thresholds = \ - Param(self, "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") - self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -124,13 +97,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -142,32 +115,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - self._paramMap[self.fitIntercept] = value - return self - - def getFitIntercept(self): - """ - Gets the value of fitIntercept or its default value. - """ - return self.getOrDefault(self.fitIntercept) - def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 926375e44871d..5b39e5dd4e25b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -124,7 +124,16 @@ def get$Name(self): ("stepSize", "Step size to be used for each iteration of optimization.", None), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None)] + "added later.", None), + ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), + ("fitIntercept", "whether to fit an intercept term.", "True"), + ("standardization", "whether to standardize the training features before fitting the " + + "model.", "True"), + ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + + "predicting each class. Array must have length equal to the number of classes, with " + + "values >= 0. The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class' threshold.", None)] code = [] for name, doc, defaultValueStr in shared: param_code = _gen_param_header(name, doc, defaultValueStr) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 682170aee85fb..af1218128602b 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -459,6 +459,117 @@ def getHandleInvalid(self): return self.getOrDefault(self.handleInvalid) +class HasElasticNetParam(Params): + """ + Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + """ + + # a placeholder to make it appear in the generated doc + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + + def __init__(self): + super(HasElasticNetParam, self).__init__() + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self._setDefault(elasticNetParam=0.0) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + +class HasFitIntercept(Params): + """ + Mixin for param fitIntercept: whether to fit an intercept term.. + """ + + # a placeholder to make it appear in the generated doc + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + + def __init__(self): + super(HasFitIntercept, self).__init__() + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self._setDefault(fitIntercept=True) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + +class HasStandardization(Params): + """ + Mixin for param standardization: whether to standardize the training features before fitting the model.. + """ + + # a placeholder to make it appear in the generated doc + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + + def __init__(self): + super(HasStandardization, self).__init__() + #: param for whether to standardize the training features before fitting the model. + self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self._setDefault(standardization=True) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + self._paramMap[self.standardization] = value + return self + + def getStandardization(self): + """ + Gets the value of standardization or its default value. + """ + return self.getOrDefault(self.standardization) + + +class HasThresholds(Params): + """ + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + """ + + # a placeholder to make it appear in the generated doc + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def __init__(self): + super(HasThresholds, self).__init__() + #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + """ + self._paramMap[self.thresholds] = value + return self + + def getThresholds(self): + """ + Gets the value of thresholds or its default value. + """ + return self.getOrDefault(self.thresholds) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 44f60a769566d..a9503608b7f25 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -28,7 +28,8 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol): + HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, + HasStandardization): """ Linear regression. @@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction TypeError: Method setParams forces keyword arguments. """ - # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") - self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ From b01b26260625f0ba14e5f3010207666d62d93864 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Sep 2015 08:52:28 -0700 Subject: [PATCH 0052/2089] [SPARK-9773] [ML] [PySpark] Add Python API for MultilayerPerceptronClassifier Add Python API for ```MultilayerPerceptronClassifier```. Author: Yanbo Liang Closes #8067 from yanboliang/SPARK-9773. --- .../MultilayerPerceptronClassifier.scala | 9 ++ python/pyspark/ml/classification.py | 132 +++++++++++++++++- 2 files changed, 140 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 82fc80c58054f..5f60dea91fcfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} @@ -181,6 +183,13 @@ class MultilayerPerceptronClassificationModel private[ml] ( private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + /** + * Returns layers in a Java List. + */ + private[ml] def javaLayers: java.util.List[Int] = { + layers.toList.asJava + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 22bdd1b322aca..88815e561f572 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,7 +26,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel'] + 'NaiveBayesModel', 'MultilayerPerceptronClassifier', + 'MultilayerPerceptronClassificationModel'] @inherit_doc @@ -755,6 +756,135 @@ def theta(self): return self._call_java("theta") +@inherit_doc +class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasSeed): + """ + Classifier trainer based on the Multilayer Perceptron. + Each layer has sigmoid activation function, output layer has softmax. + Number of inputs has to be equal to the size of feature vectors. + Number of outputs has to be equal to the total number of labels. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (0.0, Vectors.dense([0.0, 0.0])), + ... (1.0, Vectors.dense([0.0, 1.0])), + ... (1.0, Vectors.dense([1.0, 0.0])), + ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> model = mlp.fit(df) + >>> model.layers + [2, 5, 2] + >>> model.weights.size + 27 + >>> testDF = sqlContext.createDataFrame([ + ... (Vectors.dense([1.0, 0.0]),), + ... (Vectors.dense([0.0, 0.0]),)], ["features"]) + >>> model.transform(testDF).show() + +---------+----------+ + | features|prediction| + +---------+----------+ + |[1.0,0.0]| 1.0| + |[0.0,0.0]| 0.0| + +---------+----------+ + ... + """ + + # a placeholder to make it appear in the generated doc + layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + + "neurons and output layer of 10 neurons, default is [1, 1].") + blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is more than " + + "remaining data in a partition then it is adjusted to the size of this " + + "data. Recommended size is between 10 and 1000, default is 128.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + """ + super(MultilayerPerceptronClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) + self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + + "100 neurons and output layer of 10 neurons, default is [1, 1].") + self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is " + + "more than remaining data in a partition then it is adjusted to " + + "the size of this data. Recommended size is between 10 and 1000, " + + "default is 128.") + self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + Sets params for MultilayerPerceptronClassifier. + """ + kwargs = self.setParams._input_kwargs + if layers is None: + return self._set(**kwargs).setLayers([1, 1]) + else: + return self._set(**kwargs) + + def _create_model(self, java_model): + return MultilayerPerceptronClassificationModel(java_model) + + def setLayers(self, value): + """ + Sets the value of :py:attr:`layers`. + """ + self._paramMap[self.layers] = value + return self + + def getLayers(self): + """ + Gets the value of layers or its default value. + """ + return self.getOrDefault(self.layers) + + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + self._paramMap[self.blockSize] = value + return self + + def getBlockSize(self): + """ + Gets the value of blockSize or its default value. + """ + return self.getOrDefault(self.blockSize) + + +class MultilayerPerceptronClassificationModel(JavaModel): + """ + Model fitted by MultilayerPerceptronClassifier. + """ + + @property + def layers(self): + """ + array of layer sizes including input and output layers. + """ + return self._call_java("javaLayers") + + @property + def weights(self): + """ + vector of initial weights for the model that consists of the weights of layers. + """ + return self._call_java("weights") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 960d2d0ac6b5a22242a922f87f745f7d1f736181 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 11 Sep 2015 08:53:40 -0700 Subject: [PATCH 0053/2089] [SPARK-10537] [ML] document LIBSVM source options in public API doc and some minor improvements We should document options in public API doc. Otherwise, it is hard to find out the options without looking at the code. I tried to make `DefaultSource` private and put the documentation to package doc. However, since then there exists no public class under `source.libsvm`, the Java package doc doesn't show up in the generated html file (http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4492654). So I put the doc to `DefaultSource` instead. There are several minor updates in this PR: 1. Do `vectorType == "sparse"` only once. 2. Update `hashCode` and `equals`. 3. Remove inherited doc. 4. Delete temp dir in `afterAll`. Lewuathe Author: Xiangrui Meng Closes #8699 from mengxr/SPARK-10537. --- .../ml/source/libsvm/LibSVMRelation.scala | 71 ++++++++++++------- .../{ => libsvm}/JavaLibSVMRelationSuite.java | 24 +++---- .../{ => libsvm}/LibSVMRelationSuite.scala | 14 ++-- 3 files changed, 66 insertions(+), 43 deletions(-) rename mllib/src/test/java/org/apache/spark/ml/source/{ => libsvm}/JavaLibSVMRelationSuite.java (79%) rename mllib/src/test/scala/org/apache/spark/ml/source/{ => libsvm}/LibSVMRelationSuite.scala (88%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index b12cb62a4ef15..1f627777fc68d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -21,12 +21,12 @@ import com.google.common.base.Objects import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{StructType, StructField, DoubleType} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. @@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._ * @param vectorType The type of vector. It can be 'sparse' or 'dense' * @param sqlContext The Spark SQLContext */ -private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) +private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with Logging with Serializable { @@ -47,27 +47,56 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec override def buildScan(): RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - + val sparse = vectorType == "sparse" baseRdd.map { pt => - val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse + val features = if (sparse) pt.features.toSparse else pt.features.toDense Row(pt.label, features) } } override def hashCode(): Int = { - Objects.hashCode(path, schema) + Objects.hashCode(path, Double.box(numFeatures), vectorType) } override def equals(other: Any): Boolean = other match { - case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema) - case _ => false + case that: LibSVMRelation => + path == that.path && + numFeatures == that.numFeatures && + vectorType == that.vectorType + case _ => + false } - } /** - * This is used for creating DataFrame from LibSVM format file. - * The LibSVM file path must be specified to DefaultSource. + * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and + * `features` containing feature vectors stored as [[Vector]]s. + * + * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and + * optionally specify options, for example: + * {{{ + * // Scala + * val df = sqlContext.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt") + * + * // Java + * DataFrame df = sqlContext.read.format("libsvm") + * .option("numFeatures, "780") + * .load("data/mllib/sample_libsvm_data.txt"); + * }}} + * + * LIBSVM data source supports the following options: + * - "numFeatures": number of features. + * If unspecified or nonpositive, the number of features will be determined automatically at the + * cost of one additional pass. + * This is also useful when the dataset is already split into multiple files and you want to load + * them separately, because some features may not present in certain files, which leads to + * inconsistent feature dimensions. + * - "vectorType": feature vector type, "sparse" (default) or "dense". + * + * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") class DefaultSource extends RelationProvider with DataSourceRegister { @@ -75,24 +104,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - private def checkPath(parameters: Map[String, String]): String = { - require(parameters.contains("path"), "'path' must be specified") - parameters.get("path").get - } - - /** - * Returns a new base relation with the given parameters. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced - * by the Map that is passed to the function. - */ + @Since("1.6.0") override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) : BaseRelation = { - val path = checkPath(parameters) + val path = parameters.getOrElse("path", + throw new IllegalArgumentException("'path' must be specified")) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt - /** - * featuresType can be selected "dense" or "sparse". - * This parameter decides the type of returned feature vector. - */ val vectorType = parameters.getOrElse("vectorType", "sparse") new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) } diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java similarity index 79% rename from mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java rename to mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 11fa4eec0ccf0..2976b38e45031 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.source; +package org.apache.spark.ml.source.libsvm; import java.io.File; import java.io.IOException; @@ -42,34 +42,34 @@ */ public class JavaLibSVMRelationSuite { private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient DataFrame dataset; + private transient SQLContext sqlContext; - private File tmpDir; - private File path; + private File tempDir; + private String path; @Before public void setUp() throws IOException { jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); - jsql = new SQLContext(jsc); - - tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); - path = new File(tmpDir.getPath(), "part-00000"); + sqlContext = new SQLContext(jsc); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, path, Charsets.US_ASCII); + Files.write(s, file, Charsets.US_ASCII); + path = tempDir.toURI().toString(); } @After public void tearDown() { jsc.stop(); jsc = null; - Utils.deleteRecursively(tmpDir); + Utils.deleteRecursively(tempDir); } @Test public void verifyLibSVMDF() { - dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath()); + DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 8ed134128c8d2..997f574e51f6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.source +package org.apache.spark.ml.source.libsvm import java.io.File @@ -23,11 +23,12 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var tempDir: File = _ var path: String = _ override def beforeAll(): Unit = { @@ -38,12 +39,17 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Utils.createTempDir() - val file = new File(tempDir.getPath, "part-00000") + tempDir = Utils.createTempDir() + val file = new File(tempDir, "part-00000") Files.write(lines, file, Charsets.US_ASCII) path = tempDir.toURI.toString } + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + test("select as sparse vector") { val df = sqlContext.read.format("libsvm").load(path) assert(df.columns(0) == "label") From 2e3a280754a28dc36a71b9ff988e34cbf457f6c3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 11 Sep 2015 08:55:35 -0700 Subject: [PATCH 0054/2089] [MINOR] [MLLIB] [ML] [DOC] Minor doc fixes for StringIndexer and MetadataUtils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: * Make Scala doc for StringIndexerInverse clearer. Also remove Scala doc from transformSchema, so that the doc is inherited. * MetadataUtils.scala: “ Helper utilities for tree-based algorithms†—> not just trees anymore CC: holdenk mengxr Author: Joseph K. Bradley Closes #8679 from jkbradley/doc-fixes-1.5. --- .../spark/ml/feature/StringIndexer.scala | 31 +++++++------------ .../apache/spark/ml/util/MetadataUtils.scala | 2 +- python/pyspark/ml/feature.py | 16 +++++----- 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index b6482ffe0b2ee..3a4ab9a857648 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -181,10 +181,10 @@ class StringIndexerModel ( /** * :: Experimental :: - * A [[Transformer]] that maps a column of string indices back to a new column of corresponding - * string values using either the ML attributes of the input column, or if provided using the labels - * supplied by the user. - * All original columns are kept during transformation. + * A [[Transformer]] that maps a column of indices back to a new column of corresponding + * string values. + * The index-string mapping is either from the ML attributes of the input column, + * or from user-supplied labels (which take precedence over ML attributes). * * @see [[StringIndexer]] for converting strings into indices */ @@ -202,32 +202,23 @@ class IndexToString private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. The default value is an empty array, - * but the empty array is ignored and column metadata used instead. - * @group setParam - */ + /** @group setParam */ def setLabels(value: Array[String]): this.type = set(labels, value) /** - * Param for array of labels. - * Optional labels to be provided by the user. - * Default: Empty array, in which case column metadata is used for labels. + * Optional param for array of labels specifying index-string mapping. + * + * Default: Empty array, in which case [[inputCol]] metadata is used for labels. * @group param */ final val labels: StringArrayParam = new StringArrayParam(this, "labels", - "array of labels, if not provided metadata from inputCol is used instead.") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") setDefault(labels, Array.empty[String]) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. - * @group getParam - */ + /** @group getParam */ final def getLabels: Array[String] = $(labels) - /** Transform the schema for the inverse transformation */ override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index fcb517b5f735e..96a38a3bde960 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructField /** - * Helper utilities for tree-based algorithms + * Helper utilities for algorithms using ML metadata */ private[spark] object MetadataUtils { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 71dc636b83eac..97cbee73a05ed 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -985,17 +985,17 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): """ .. note:: Experimental - A :py:class:`Transformer` that maps a column of string indices back to a new column of - corresponding string values using either the ML attributes of the input column, or if - provided using the labels supplied by the user. - All original columns are kept during transformation. + A :py:class:`Transformer` that maps a column of indices back to a new column of + corresponding string values. + The index-string mapping is either from the ML attributes of the input column, + or from user-supplied labels (which take precedence over ML attributes). See L{StringIndexer} for converting strings into indices. """ # a placeholder to make the labels show up in generated doc labels = Param(Params._dummy(), "labels", - "Optional array of labels to be provided by the user, if not supplied or " + - "empty, column metadata is read for labels") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") @keyword_only def __init__(self, inputCol=None, outputCol=None, labels=None): @@ -1006,8 +1006,8 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) self.labels = Param(self, "labels", - "Optional array of labels to be provided by the user, if not " + - "supplied or empty, column metadata is read for labels") + "Optional array of labels specifying index-string mapping. If not" + + " provided or if empty, then metadata from inputCol is used instead.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) From 6ce0886eb0916a985db142c0b6d2c2b14db5063d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Sep 2015 09:42:53 -0700 Subject: [PATCH 0055/2089] [SPARK-10540] [SQL] Ignore HadoopFsRelationTest's "test all data types" if it is too flaky If hadoopFsRelationSuites's "test all data types" is too flaky we can disable it for now. https://issues.apache.org/jira/browse/SPARK-10540 Author: Yin Huai Closes #8705 from yhuai/SPARK-10540-ignore. --- .../org/apache/spark/sql/sources/hadoopFsRelationSuites.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 24f43cf7c15ca..13223c61584b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,7 +100,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - test("test all data types") { + ignore("test all data types") { withTempPath { file => // Create the schema. val struct = From 5f46444765a377696af76af6e2c77ab14bfdab8e Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 11 Sep 2015 10:32:35 -0700 Subject: [PATCH 0056/2089] [SPARK-8530] [ML] add python API for MinMaxScaler jira: https://issues.apache.org/jira/browse/SPARK-8530 add python API for MinMaxScaler jira for MinMaxScaler: https://issues.apache.org/jira/browse/SPARK-7514 Author: Yuhao Yang Closes #7150 from hhbyyh/pythonMinMax. --- python/pyspark/ml/feature.py | 104 +++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 97cbee73a05ed..92db8df80280b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,11 +27,11 @@ from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', - 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', - 'Word2Vec', 'Word2VecModel'] + 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', + 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -406,6 +406,100 @@ class IDFModel(JavaModel): """ +@inherit_doc +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to a common range [min, max] linearly using column summary + statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + feature E is calculated as, + + Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min + + For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) + + Note that since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") + >>> model = mmScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[0.0]| [0.0]| + |[2.0]| [1.0]| + +-----+------+ + ... + """ + + # a placeholder to make it appear in the generated doc + min = Param(Params._dummy(), "min", "Lower bound of the output feature range") + max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + + @keyword_only + def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + """ + super(MinMaxScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) + self.min = Param(self, "min", "Lower bound of the output feature range") + self.max = Param(self, "max", "Upper bound of the output feature range") + self._setDefault(min=0.0, max=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + Sets params for this MinMaxScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMin(self, value): + """ + Sets the value of :py:attr:`min`. + """ + self._paramMap[self.min] = value + return self + + def getMin(self): + """ + Gets the value of min or its default value. + """ + return self.getOrDefault(self.min) + + def setMax(self, value): + """ + Sets the value of :py:attr:`max`. + """ + self._paramMap[self.max] = value + return self + + def getMax(self): + """ + Gets the value of max or its default value. + """ + return self.getOrDefault(self.max) + + def _create_model(self, java_model): + return MinMaxScalerModel(java_model) + + +class MinMaxScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MinMaxScaler`. + """ + + @inherit_doc @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol): From b231ab8938ae3c4fc2089cfc69c0d8164807d533 Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 11 Sep 2015 21:45:45 +0100 Subject: [PATCH 0057/2089] [SPARK-10546] Check partitionId's range in ExternalSorter#spill() See this thread for background: http://search-hadoop.com/m/q3RTt0rWvIkHAE81 We should check the range of partition Id and provide meaningful message through exception. Alternatively, we can use abs() and modulo to force the partition Id into legitimate range. However, expectation is that user should correct the logic error in his / her code. Author: tedyu Closes #8703 from tedyu/master. --- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 138c05dff19e4..31230d5978b2a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -297,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C]( val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { val partitionId = it.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 From c373866774c082885a50daaf7c83f3a14b0cd714 Mon Sep 17 00:00:00 2001 From: Icaro Medeiros Date: Fri, 11 Sep 2015 21:46:52 +0100 Subject: [PATCH 0058/2089] [PYTHON] Fixed typo in exception message Just fixing a typo in exception message, raised when attempting to pickle SparkContext. Author: Icaro Medeiros Closes #8724 from icaromedeiros/master. --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1b2a52ad64114..a0a1ccbeefb09 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -255,7 +255,7 @@ def __getnewargs__(self): # This method is called when attempting to pickle SparkContext, which is always an error: raise Exception( "It appears that you are attempting to reference SparkContext from a broadcast " - "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "variable, action, or transformation. SparkContext can only be used on the driver, " "not in code that it run on workers. For more information, see SPARK-5063." ) From d5d647380f93f4773f9cb85ea6544892d409b5a1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 Sep 2015 14:15:16 -0700 Subject: [PATCH 0059/2089] [SPARK-10442] [SQL] fix string to boolean cast When we cast string to boolean in hive, it returns `true` if the length of string is > 0, and spark SQL follows this behavior. However, this behavior is very different from other SQL systems: 1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', throw exception for others. 2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', throw exception for others. 4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throw exception when try to cast string to boolean. 6. mysql, oracle, sqlserver don't have boolean type Whether we should change the cast behavior according to other SQL system or not is not decided yet, this PR is a test to see if we changed, how many compatibility tests will fail. Author: Wenchen Fan Closes #8698 from cloud-fan/string2boolean. --- .../spark/sql/catalyst/expressions/Cast.scala | 24 +++++++- .../spark/sql/catalyst/util/StringUtils.scala | 8 +++ .../sql/catalyst/expressions/CastSuite.scala | 61 ++++++++++++------- .../sql/sources/hadoopFsRelationSuites.scala | 13 ++++ 4 files changed, 82 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2db954257be35..f0bce388d959a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType) // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, _.numBytes() != 0) + buildCast[UTF8String](_, s => { + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + null + } + }) case TimestampType => buildCast[Long](_, t => t != 0) case DateType => @@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + (c, evPrim, evNull) => + s""" + if ($stringUtils.isTrueString($c)) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c)) { + $evPrim = false; + } else { + $evNull = true; + } + """ case TimestampType => (c, evPrim, evNull) => s"$evPrim = $c != 0;" case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 9ddfb3a0d3759..c2eeb3c5650ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern +import org.apache.spark.unsafe.types.UTF8String + object StringUtils { // replace the _ with .{1} exactly match 1 time of any character @@ -44,4 +46,10 @@ object StringUtils { v } } + + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1ad70733eae03..f4db4da7646f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from array") { - val array = Literal.create(Seq("123", "abc", "", null), + val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), + val array_notNull = Literal.create(Seq("123", "true", "f"), ArrayType(StringType, containsNull = false)) checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) @@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false, null)) + checkEvaluation(ret, Seq(null, true, false, null)) } { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) @@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { @@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from map") { val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, valueContainsNull = false)) checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) @@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null)) } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) @@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString(""), + UTF8String.fromString("true"), + UTF8String.fromString("f"), null), StructType(Seq( StructField("a", StringType, nullable = true), @@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct_notNull = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString("")), + UTF8String.fromString("true"), + UTF8String.fromString("f")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false, null)) + checkEvaluation(ret, InternalRow(null, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( Row( - Seq("123", "abc", ""), - Map("a" ->"123", "b" -> "abc", "c" -> ""), + Seq("123", "true", "f"), + Map("a" ->"123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, Row( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map("a" -> null, "b" -> true, "c" -> false), Row(0L))) } - test("case between string and interval") { + test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), @@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } + + test("cast string to boolean") { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 13223c61584b2..8ffcef85668d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) From 1eede3b254ee3793841c92971707094ac8afee35 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 11 Sep 2015 14:55:15 -0700 Subject: [PATCH 0060/2089] [SPARK-7142] [SQL] Minor enhancement to BooleanSimplification Optimizer rule. Incorporate review comments Adding changes suggested by cloud-fan in #5700 cc marmbrus Author: Yash Datta Closes #8716 from saucam/bool_simp. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d9b50f3c97da0..0f4caec7451a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -435,10 +435,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // a && a => a case (l, r) if l fastEquals r => l // a && (not(a) || b) => a && b - case (l, Or(l1, r)) if (Not(l) fastEquals l1) => And(l, r) - case (l, Or(r, l1)) if (Not(l) fastEquals l1) => And(l, r) - case (Or(l, l1), r) if (l1 fastEquals Not(r)) => And(l, r) - case (Or(l1, l), r) if (l1 fastEquals Not(r)) => And(l, r) + case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) + case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) + case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) + case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, From e626ac5f5c27dcc74113070f2fec03682bcd12bd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 11 Sep 2015 15:00:13 -0700 Subject: [PATCH 0061/2089] [SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK, sample and intersect operators This PR is in conflict with #8535. I will update this one when #8535 gets merged. Author: zsxwing Closes #8573 from zsxwing/more-local-operators. --- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/local/IntersectNode.scala | 63 ++++++++++++++ .../spark/sql/execution/local/LocalNode.scala | 5 ++ .../sql/execution/local/SampleNode.scala | 82 +++++++++++++++++++ .../local/TakeOrderedAndProjectNode.scala | 73 +++++++++++++++++ .../execution/local/IntersectNodeSuite.scala | 35 ++++++++ .../sql/execution/local/SampleNodeSuite.scala | 40 +++++++++ .../TakeOrderedAndProjectNodeSuite.scala | 54 ++++++++++++ 8 files changed, 353 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3f68b05a24f44..bf6d44c098ee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed - * @param child the QueryPlan + * @param child the SparkPlan */ @DeveloperApi case class Sample( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala new file mode 100644 index 0000000000000..740d485f8d9e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -0,0 +1,63 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import scala.collection.mutable + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) + extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = left.output + + private[this] var leftRows: mutable.HashSet[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + left.open() + leftRows = mutable.HashSet[InternalRow]() + while (left.next()) { + leftRows += left.fetch().copy() + } + left.close() + right.open() + } + + override def next(): Boolean = { + currentRow = null + while (currentRow == null && right.next()) { + currentRow = right.fetch() + if (!leftRows.contains(currentRow)) { + currentRow = null + } + } + currentRow != null + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index c4f8ae304db39..a2c275db9b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** + * Returns the content through the [[Iterator]] interface. + */ + final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) + /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala new file mode 100644 index 0000000000000..abf3df1c0c2af --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import java.util.Random + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + +/** + * Sample the dataset. + * + * @param conf the SQLConf + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the LocalNode + */ +case class SampleNode( + conf: SQLConf, + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var iterator: Iterator[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + child.open() + val (sampler, _seed) = if (withReplacement) { + val random = new Random(seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result + // of DataFrame + random.nextLong()) + } else { + (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + } + sampler.setSeed(_seed) + iterator = sampler.sample(child.asIterator) + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala new file mode 100644 index 0000000000000..53f1dcc65d8cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.util.BoundedPriorityQueue + +case class TakeOrderedAndProjectNode( + conf: SQLConf, + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: LocalNode) extends UnaryLocalNode(conf) { + + private[this] var projection: Option[Projection] = _ + private[this] var ord: InterpretedOrdering = _ + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def open(): Unit = { + child.open() + projection = projectList.map(new InterpretedProjection(_, child.output)) + ord = new InterpretedOrdering(sortOrder, child.output) + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) + while (child.next()) { + queue += child.fetch() + } + // Close it eagerly since we don't need it. + child.close() + iterator = queue.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + val _currentRow = iterator.next() + currentRow = projection match { + case Some(p) => p(_currentRow) + case None => _currentRow + } + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 0000000000000..7deaa375fcfc2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,35 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +class IntersectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("basic") { + val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") + + checkAnswer2( + input1, + input2, + (node1, node2) => IntersectNode(conf, node1, node2), + input1.intersect(input2).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 0000000000000..87a7da453999c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +class SampleNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def testSample(withReplacement: Boolean): Unit = { + test(s"withReplacement: $withReplacement") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), + input.sample(withReplacement, 0.3, seed).collect() + ) + } + } + + testSample(withReplacement = true) + testSample(withReplacement = false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 0000000000000..ff28b24eeff14 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + sortOrder + } + + private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { + val testCaseName = if (desc) "desc" else "asc" + test(testCaseName) { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val sortColumn = if (desc) input.col("key").desc else input.col("key") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), + input.sort(sortColumn).limit(5).collect() + ) + } + } + + testTakeOrderedAndProjectNode(desc = false) + testTakeOrderedAndProjectNode(desc = true) +} From c2af42b5f32287ff595ad027a8191d4b75702d8d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 11 Sep 2015 15:01:37 -0700 Subject: [PATCH 0062/2089] [SPARK-9990] [SQL] Local hash join follow-ups 1. Hide `LocalNodeIterator` behind the `LocalNode#asIterator` method 2. Add tests for this Author: Andrew Or Closes #8708 from andrewor14/local-hash-join-follow-up. --- .../sql/execution/joins/HashedRelation.scala | 7 +- .../sql/execution/local/HashJoinNode.scala | 3 +- .../spark/sql/execution/local/LocalNode.scala | 4 +- .../sql/execution/local/LocalNodeSuite.scala | 116 ++++++++++++++++++ 4 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0cff21ca618b4..bc255b27502b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.local.LocalNode +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -113,6 +114,10 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR private[execution] object HashedRelation { + def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { + apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) + } + def apply( input: Iterator[InternalRow], numInputRows: LongSQLMetric, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index a3e68d6a7c341..e7b24e3fca2b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -75,8 +75,7 @@ case class HashJoinNode( override def open(): Unit = { buildNode.open() - hashed = HashedRelation.apply( - new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator) + hashed = HashedRelation(buildNode, buildSideKeyGenerator) streamedNode.open() joinRow = new JoinedRow resultProjection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index a2c275db9b35d..e540ef8555eb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -77,7 +77,7 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ - def collect(): Seq[Row] = { + final def collect(): Seq[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) val result = new scala.collection.mutable.ArrayBuffer[Row] open() @@ -140,7 +140,7 @@ abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { /** * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. */ -private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { +private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { private var nextRow: InternalRow = _ override def hasNext: Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 0000000000000..b89fa46f8b3b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class LocalNodeSuite extends SparkFunSuite { + private val data = (1 to 100).toArray + + test("basic open, next, fetch, close") { + val node = new DummyLocalNode(data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { i => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 1) + assert(node.fetch().getInt(0) === i) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyLocalNode(data) + val iter = node.asIterator + node.open() + data.foreach { i => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 1) + assert(item.getInt(0) === i) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyLocalNode(data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 1)) + assert(collected.map(_.getInt(0)) === data) + node.close() + } + +} + +/** + * A dummy [[LocalNode]] that just returns one row per integer in the input. + */ +private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { + private var index = Int.MinValue + + def this(input: Array[Int]) { + this(new SQLConf, input) + } + + def isOpen: Boolean = { + index != Int.MinValue + } + + override def output: Seq[Attribute] = { + Seq(AttributeReference("something", IntegerType)()) + } + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + val values = Array(input(index).asInstanceOf[Any]) + new GenericInternalRow(values) + } + + override def close(): Unit = { + index = Int.MinValue + } +} From d74c6a143cbd060c25bf14a8d306841b3ec55d03 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 11 Sep 2015 15:02:59 -0700 Subject: [PATCH 0063/2089] [SPARK-10564] ThreadingSuite: assertion failures in threads don't fail the test This commit ensures if an assertion fails within a thread, it will ultimately fail the test. Otherwise we end up potentially masking real bugs by not propagating assertion failures properly. Author: Andrew Or Closes #8723 from andrewor14/fix-threading-suite. --- .../org/apache/spark/ThreadingSuite.scala | 68 ++++++++++++------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 48509f0759a3b..cda2b245526f7 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -119,23 +119,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() + var throwable: Option[Throwable] = None for (i <- 0 until 2) { new Thread { override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() + try { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } }.start() } @@ -145,18 +152,25 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } + throwable.foreach { t => throw t } } test("set local properties in different thread") { sc = new SparkContext("local", "test") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -165,20 +179,27 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === null) + throwable.foreach { t => throw t } } test("set and get local properties in parent-children thread") { sc = new SparkContext("local", "test") sc.setLocalProperty("test", "parent") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - assert(sc.getLocalProperty("test") === "parent") - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + assert(sc.getLocalProperty("test") === "parent") + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -188,6 +209,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) + throwable.foreach { t => throw t } } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { From c34fc19765bdf55365cdce78d9ba11b220b73bb6 Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Fri, 11 Sep 2015 15:19:04 -0700 Subject: [PATCH 0064/2089] [SPARK-9014] [SQL] Allow Python spark API to use built-in exponential operator This PR addresses (SPARK-9014)[https://issues.apache.org/jira/browse/SPARK-9014] Added functionality: `Column` object in Python now supports exponential operator `**` Example: ``` from pyspark.sql import * df = sqlContext.createDataFrame([Row(a=2)]) df.select(3**df.a,df.a**3,df.a**df.a).collect() ``` Outputs: ``` [Row(POWER(3.0, a)=9.0, POWER(a, 3.0)=8.0, POWER(a, a)=4.0)] ``` Author: 0x0FFF Closes #8658 from 0x0FFF/SPARK-9014. --- python/pyspark/sql/column.py | 13 +++++++++++++ python/pyspark/sql/tests.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 573f65f5bf096..9ca8e1f264cfa 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -91,6 +91,17 @@ def _(self): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -151,6 +162,8 @@ def __init__(self, jc): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb449e8679fa0..f2172b7a27d88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -568,7 +568,7 @@ def test_column_operators(self): cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) From 6d8367807cb62c2cb139cee1d039dc8b12c63385 Mon Sep 17 00:00:00 2001 From: Daniel Imfeld Date: Sat, 12 Sep 2015 09:19:59 +0100 Subject: [PATCH 0065/2089] [SPARK-10566] [CORE] SnappyCompressionCodec init exception handling masks important error information When throwing an IllegalArgumentException in SnappyCompressionCodec.init, chain the existing exception. This allows potentially important debugging info to be passed to the user. Manual testing shows the exception chained properly, and the test suite still looks fine as well. This contribution is my original work and I license the work to the project under the project's open source license. Author: Daniel Imfeld Closes #8725 from dimfeld/dimfeld-patch-1. --- core/src/main/scala/org/apache/spark/io/CompressionCodec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 607d5a321efca..9dc36704a676d 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -148,7 +148,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { From 8285e3b0d3dc0eff669eba993742dfe0401116f9 Mon Sep 17 00:00:00 2001 From: Nithin Asokan Date: Sat, 12 Sep 2015 09:50:49 +0100 Subject: [PATCH 0066/2089] [SPARK-10554] [CORE] Fix NPE with ShutdownHook https://issues.apache.org/jira/browse/SPARK-10554 Fixes NPE when ShutdownHook tries to cleanup temporary folders Author: Nithin Asokan Closes #8720 from nasokan/SPARK-10554. --- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3f8d26e1d4cab..f7e84a2c2e14c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -164,7 +164,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. - if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. + if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { From 22730ad54d681ad30e63fe910e8d89360853177d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 Sep 2015 10:40:10 +0100 Subject: [PATCH 0067/2089] [SPARK-10547] [TEST] Streamline / improve style of Java API tests Fix a few Java API test style issues: unused generic types, exceptions, wrong assert argument order Author: Sean Owen Closes #8706 from srowen/SPARK-10547. --- .../java/org/apache/spark/JavaAPISuite.java | 451 ++++++----- .../kafka/JavaDirectKafkaStreamSuite.java | 24 +- .../streaming/kafka/JavaKafkaRDDSuite.java | 17 +- .../streaming/kafka/JavaKafkaStreamSuite.java | 14 +- .../twitter/JavaTwitterStreamSuite.java | 4 +- .../java/org/apache/spark/Java8APISuite.java | 46 +- .../spark/sql/JavaApplySchemaSuite.java | 39 +- .../apache/spark/sql/JavaDataFrameSuite.java | 29 +- .../org/apache/spark/sql/JavaRowSuite.java | 15 +- .../org/apache/spark/sql/JavaUDFSuite.java | 9 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 10 +- .../spark/sql/hive/JavaDataFrameSuite.java | 8 +- .../hive/JavaMetastoreDataSourcesSuite.java | 12 +- .../apache/spark/streaming/JavaAPISuite.java | 752 +++++++++--------- .../spark/streaming/JavaReceiverAPISuite.java | 86 +- 15 files changed, 755 insertions(+), 761 deletions(-) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ebd3d61ae7324..fd8f7f39b7cc8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -90,7 +90,7 @@ public void sparkContextUnion() { JavaRDD sUnion = sc.union(s1, s2); Assert.assertEquals(4, sUnion.count()); // List - List> list = new ArrayList>(); + List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); Assert.assertEquals(4, sUnion.count()); @@ -103,9 +103,9 @@ public void sparkContextUnion() { Assert.assertEquals(4, dUnion.count()); // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); @@ -133,9 +133,9 @@ public void intersection() { JavaDoubleRDD dIntersection = d1.intersection(d2); Assert.assertEquals(2, dIntersection.count()); - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); @@ -165,47 +165,49 @@ public void randomSplit() { @Test public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaPairRDD rdd = sc.parallelizePairs(pairs); // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 5)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(2, 6)); - pairs.add(new Tuple2(0, 8)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(1, 3)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 5)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(2, 6)); + pairs.add(new Tuple2<>(0, 8)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(1, 3)); JavaPairRDD rdd = sc.parallelizePairs(pairs); Partitioner partitioner = new Partitioner() { + @Override public int numPartitions() { return 2; } + @Override public int getPartition(Object key) { - return ((Integer)key).intValue() % 2; + return (Integer) key % 2; } }; @@ -214,10 +216,10 @@ public int getPartition(Object key) { Assert.assertTrue(repartitioned.partitioner().isPresent()); Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), - new Tuple2(0, 8), new Tuple2(2, 6))); - Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), - new Tuple2(3, 8), new Tuple2(3, 8))); + Assert.assertEquals(partitions.get(0), + Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); + Assert.assertEquals(partitions.get(1), + Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test @@ -228,35 +230,37 @@ public void emptyRDD() { @Test public void sortBy() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaRDD> rdd = sc.parallelize(pairs); // compare on first value JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -265,7 +269,7 @@ public void foreach() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) throws IOException { + public void call(String s) { accum.add(1); } }); @@ -278,7 +282,7 @@ public void foreachPartition() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override - public void call(Iterator iter) throws IOException { + public void call(Iterator iter) { while (iter.hasNext()) { iter.next(); accum.add(1); @@ -301,7 +305,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet(indexes.collect()).size()); + Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -317,10 +321,10 @@ public void zipWithIndex() { @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") + )); Assert.assertEquals(2, categories.lookup("Oranges").size()); Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @@ -390,18 +394,17 @@ public String call(Tuple2 x) { @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -411,23 +414,22 @@ public void cogroup() { @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); @@ -439,27 +441,26 @@ public void cogroup3() { @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", "BR"), - new Tuple2("Apples", "US") + new Tuple2<>("Oranges", "BR"), + new Tuple2<>("Apples", "US") )); JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); @@ -471,16 +472,16 @@ public void cogroup4() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); @@ -548,11 +549,11 @@ public Integer call(Integer a, Integer b) { public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(5, 1), - new Tuple2(5, 3)), 2); + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(5, 1), + new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), new Function2, Integer, Set>() { @@ -570,20 +571,20 @@ public Set call(Set a, Set b) { } }).collectAsMap(); Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, @@ -602,11 +603,11 @@ public Integer call(Integer a, Integer b) { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey( @@ -690,7 +691,7 @@ public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } @Test @@ -743,6 +744,7 @@ public void javaDoubleRDDHistoGram() { } private static class DoubleComparator implements Comparator, Serializable { + @Override public int compare(Double o1, Double o2) { return o1.compareTo(o2); } @@ -766,14 +768,14 @@ public void min() { public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertTrue(4.0 == max); + Assert.assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertTrue(1.0 == max); + Assert.assertEquals(1.0, max, 0.0); } @Test @@ -809,7 +811,7 @@ public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double sum = rdd.reduce(new Function2() { @Override - public Double call(Double v1, Double v2) throws Exception { + public Double call(Double v1, Double v2) { return v1 + v2; } }); @@ -844,7 +846,7 @@ public double call(Integer x) { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, x); + return new Tuple2<>(x, x); } }).cache(); pairs.collect(); @@ -870,26 +872,25 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMapToPair( + JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { - @Override public Iterable> call(String s) { - List> pairs = new LinkedList>(); + List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { - pairs.add(new Tuple2(word, word)); + pairs.add(new Tuple2<>(word, word)); } return pairs; } } ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + Assert.assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } @@ -897,36 +898,36 @@ public Iterable call(String s) { } }); Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); - } + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMapToPair( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) { + return Collections.singletonList(item.swap()); + } }); - swapped.collect(); + swapped.collect(); - // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + // There was never a bug here, but it's worth testing: + pairRDD.mapToPair(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) { + return item.swap(); + } + }).collect(); } @Test @@ -953,7 +954,7 @@ public void mapPartitionsWithIndex() { JavaRDD partitionSums = rdd.mapPartitionsWithIndex( new Function2, Iterator>() { @Override - public Iterator call(Integer index, Iterator iter) throws Exception { + public Iterator call(Integer index, Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); @@ -972,8 +973,8 @@ public void repartition() { JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); Assert.assertEquals(4, result1.size()); - for (List l: result1) { - Assert.assertTrue(l.size() > 0); + for (List l : result1) { + Assert.assertFalse(l.isEmpty()); } // Growing number of partitions @@ -982,7 +983,7 @@ public void repartition() { List> result2 = repartitioned2.glom().collect(); Assert.assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertTrue(l.size() > 0); + Assert.assertFalse(l.isEmpty()); } } @@ -994,9 +995,9 @@ public void persist() { Assert.assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); @@ -1046,7 +1047,7 @@ public void wholeTextFiles() throws Exception { Files.write(content1, new File(tempDirName + "/part-00000")); Files.write(content2, new File(tempDirName + "/part-00001")); - Map container = new HashMap(); + Map container = new HashMap<>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -1075,16 +1076,16 @@ public void textFilesCompressed() throws IOException { public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); @@ -1093,7 +1094,7 @@ public Tuple2 call(Tuple2 pair) { Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); + return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); Assert.assertEquals(pairs, readRDD.collect()); @@ -1110,7 +1111,7 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); @@ -1131,14 +1132,14 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(new VoidFunction>() { @Override - public void call(Tuple2 pair) throws Exception { + public void call(Tuple2 pair) { pair._2().toArray(); // force the file to read } }); @@ -1162,7 +1163,7 @@ public void binaryRecords() throws Exception { FileChannel channel1 = fos1.getChannel(); for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); } channel1.close(); @@ -1180,24 +1181,23 @@ public void binaryRecords() throws Exception { public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + }).saveAsNewAPIHadoopFile( + outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1210,24 +1210,23 @@ public String call(Tuple2 x) { public void readWithNewAPIHadoopFile() throws IOException { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1251,9 +1250,9 @@ public void objectFilesOfInts() { public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.saveAsObjectFile(outputDir); @@ -1267,23 +1266,22 @@ public void objectFilesOfComplexTypes() { public void hadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1296,16 +1294,16 @@ public String call(Tuple2 x) { public void hadoopFileCompressed() { String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); @@ -1313,8 +1311,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1414,8 +1411,8 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test @@ -1448,20 +1445,20 @@ public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); Function keyFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1 % 3; } }; Function createCombinerFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1; } }; Function2 mergeValueFunction = new Function2() { @Override - public Integer call(Integer v1, Integer v2) throws Exception { + public Integer call(Integer v1, Integer v2) { return v1 + v2; } }; @@ -1496,21 +1493,21 @@ public void mapOnPairRDD() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2(in._2(), in._1()); - } - }); + @Override + public Tuple2 call(Tuple2 in) { + return new Tuple2<>(in._2(), in._1()); + } + }); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @@ -1523,7 +1520,7 @@ public void collectPartitions() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); @@ -1534,23 +1531,23 @@ public Tuple2 call(Integer i) { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), + new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), parts2[1]); } @Test public void countApproxDistinct() { - List arrayData = new ArrayList(); + List arrayData = new ArrayList<>(); int size = 100; for (int i = 0; i < 100000; i++) { arrayData.add(i % size); @@ -1561,15 +1558,15 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { - List> arrayData = new ArrayList>(); + List> arrayData = new ArrayList<>(); for (int i = 10; i < 100; i++) { for (int j = 0; j < i; j++) { - arrayData.add(new Tuple2(i, j)); + arrayData.add(new Tuple2<>(i, j)); } } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); for (Tuple2 resItem : res) { double count = (double)resItem._1(); Long resCount = (Long)resItem._2(); @@ -1587,7 +1584,7 @@ public void collectAsMapWithIntArrayValues() { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, new int[] { x }); + return new Tuple2<>(x, new int[]{x}); } }); pairRDD.collect(); // Works fine @@ -1598,7 +1595,7 @@ public Tuple2 call(Integer x) { @Test public void collectAsMapAndSerialize() throws Exception { JavaPairRDD rdd = - sc.parallelizePairs(Arrays.asList(new Tuple2("foo", 1))); + sc.parallelizePairs(Arrays.asList(new Tuple2<>("foo", 1))); Map map = rdd.collectAsMap(); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); new ObjectOutputStream(bytes).writeObject(map); @@ -1615,7 +1612,7 @@ public void sampleByKey() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1623,12 +1620,12 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertTrue(wrCounts.size() == 2); + Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); Map worCounts = (Map) (Object) wor.countByKey(); - Assert.assertTrue(worCounts.size() == 2); + Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); } @@ -1641,7 +1638,7 @@ public void sampleByKeyExact() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1649,25 +1646,25 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { - public SomeCustomClass() { + SomeCustomClass() { // Intentionally left blank } } @Test public void collectUnderlyingScalaRDD() { - List data = new ArrayList(); + List data = new ArrayList<>(); for (int i = 0; i < 100; i++) { data.add(new SomeCustomClass()); } @@ -1679,7 +1676,7 @@ public void collectUnderlyingScalaRDD() { private static final class BuggyMapFunction implements Function { @Override - public T call(T x) throws Exception { + public T call(T x) { throw new IllegalStateException("Custom exception!"); } } @@ -1716,7 +1713,7 @@ public void foreachAsync() throws Exception { JavaFutureAction future = rdd.foreachAsync( new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) { // intentionally left blank. } } @@ -1745,7 +1742,7 @@ public void testAsyncActionCancellation() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) throws InterruptedException { Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. } }); diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 9db07d0507fea..fbdfbf7e509b3 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -75,11 +75,11 @@ public void testKafkaStream() throws InterruptedException { String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); - HashSet sent = new HashSet(); + Set sent = new HashSet<>(); sent.addAll(Arrays.asList(topic1data)); sent.addAll(Arrays.asList(topic2data)); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); @@ -95,17 +95,17 @@ public void testKafkaStream() throws InterruptedException { // Make sure you can get offset ranges from the rdd new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd) { OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); offsetRanges.set(offsets); - Assert.assertEquals(offsets[0].topic(), topic1); + Assert.assertEquals(topic1, offsets[0].topic()); return rdd; } } ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -119,10 +119,10 @@ public String call(Tuple2 kv) throws Exception { StringDecoder.class, String.class, kafkaParams, - topicOffsetToMap(topic2, (long) 0), + topicOffsetToMap(topic2, 0L), new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -133,7 +133,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception unifiedStream.foreachRDD( new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public Void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( @@ -155,14 +155,14 @@ public Void call(JavaRDD rdd) throws Exception { ssc.stop(); } - private HashSet topicToSet(String topic) { - HashSet topicSet = new HashSet(); + private static Set topicToSet(String topic) { + Set topicSet = new HashSet<>(); topicSet.add(topic); return topicSet; } - private HashMap topicOffsetToMap(String topic, Long offsetToStart) { - HashMap topicMap = new HashMap(); + private static Map topicOffsetToMap(String topic, Long offsetToStart) { + Map topicMap = new HashMap<>(); topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); return topicMap; } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613ca..afcc6cfccd39a 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -66,10 +67,10 @@ public void testKafkaRDD() throws InterruptedException { String topic1 = "topic1"; String topic2 = "topic2"; - String[] topic1data = createTopicAndSendData(topic1); - String[] topic2data = createTopicAndSendData(topic2); + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { @@ -77,8 +78,8 @@ public void testKafkaRDD() throws InterruptedException { OffsetRange.create(topic2, 0, 0, 1) }; - HashMap emptyLeaders = new HashMap(); - HashMap leaders = new HashMap(); + Map emptyLeaders = new HashMap<>(); + Map leaders = new HashMap<>(); String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); @@ -95,7 +96,7 @@ public void testKafkaRDD() throws InterruptedException { ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -113,7 +114,7 @@ public String call(Tuple2 kv) throws Exception { emptyLeaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -131,7 +132,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception leaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index e4c659215b767..1e69de46cd35d 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -67,10 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { String topic = "topic1"; - HashMap topics = new HashMap(); + Map topics = new HashMap<>(); topics.put(topic, 1); - HashMap sent = new HashMap(); + Map sent = new HashMap<>(); sent.put("a", 5); sent.put("b", 3); sent.put("c", 10); @@ -78,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { kafkaTestUtils.createTopic(topic); kafkaTestUtils.sendMessages(topic, sent); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +97,7 @@ public void testKafkaStream() throws InterruptedException { JavaDStream words = stream.map( new Function, String>() { @Override - public String call(Tuple2 tuple2) throws Exception { + public String call(Tuple2 tuple2) { return tuple2._2(); } } @@ -106,7 +106,7 @@ public String call(Tuple2 tuple2) throws Exception { words.countByValue().foreachRDD( new Function, Void>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public Void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -130,8 +130,8 @@ public Void call(JavaPairRDD rdd) throws Exception { Thread.sleep(200); } Assert.assertEquals(sent.size(), result.size()); - for (String k : sent.keySet()) { - Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); } } } diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java index e46b4e5c7531d..26ec8af455bcf 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.streaming.twitter; -import java.util.Arrays; - import org.junit.Test; import twitter4j.Status; import twitter4j.auth.Authorization; @@ -30,7 +28,7 @@ public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { @Test public void testTwitterStream() { - String[] filters = (String[])Arrays.asList("filter1", "filter2").toArray(); + String[] filters = { "filter1", "filter2" }; Authorization auth = NullAuthorization.getInstance(); // tests the API, does not actually test data receiving diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 729bc0459ce52..14975265ab2ce 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -77,7 +77,7 @@ public void call(String s) { public void foreach() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach((x) -> foreachCalls++); + rdd.foreach(x -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } @@ -180,7 +180,7 @@ public void map() { JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); - JavaRDD strings = rdd.map(x -> x.toString()).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -195,7 +195,9 @@ public void flatMap() { JavaPairRDD pairs = rdd.flatMapToPair(s -> { List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } return pairs2; }); @@ -204,11 +206,12 @@ public void flatMap() { JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { List lengths = new LinkedList<>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; }); - Double x = doubles.first(); Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @@ -228,7 +231,7 @@ public void mapsFromPairsToPairs() { swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(item -> item.swap()).collect(); + pairRDD.map(Tuple2::swap).collect(); } @Test @@ -282,11 +285,11 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = (Iterator i, Iterator s) -> { int sizeI = 0; - int sizeS = 0; while (i.hasNext()) { sizeI += 1; i.next(); } + int sizeS = 0; while (s.hasNext()) { sizeS += 1; s.next(); @@ -301,30 +304,31 @@ public void zipPartitions() { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(x -> intAccum.add(x)); + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(x -> doubleAccum.add((double) x)); Assert.assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } - + @Override public Float addAccumulator(Float r, Float t) { return r + t; } - + @Override public Float zero(Float initialValue) { return 0.0f; } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -336,7 +340,7 @@ public Float zero(Float initialValue) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(x -> x.toString()).collect(); + List> s = rdd.keyBy(Object::toString).collect(); Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -349,7 +353,7 @@ public void mapOnPairRDD() { JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), + new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), new Tuple2<>(0, 4)), rdd3.collect()); @@ -361,7 +365,7 @@ public void collectPartitions() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); + List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[]{1, 2}); @@ -371,19 +375,19 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); - parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts[1]); + parts2[1]); } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf693c7c393f6..7b50aad4ad498 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -83,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -95,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = javaCtx.parallelize(personList).map( new Function() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List fields = new ArrayList(2); + List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -118,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -129,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -165,7 +168,7 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); + List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); @@ -175,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); + List expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -187,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4867cebf5328c..d981ce947f435 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -61,7 +61,7 @@ public void tearDown() { @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -119,7 +119,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -161,7 +161,7 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); @@ -180,7 +180,8 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -193,16 +194,16 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } @@ -210,7 +211,7 @@ public void testCrosstab() { @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -219,14 +220,14 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test @@ -234,7 +235,7 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26a..3ab4db2a035d3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map simpleMap = new HashMap(); + Map simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map, Row> complexMap = new HashMap, Row>(); + Map, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bb02b58cca9be..4a78dca7fea66 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,6 +20,7 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -61,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1() { @Override - public Integer call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -81,12 +82,12 @@ public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6f9e7f68dc39c..9e241f20987c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -44,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -64,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -82,7 +82,7 @@ public void tearDown() { @Test public void saveAndLoad() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -91,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 019d8a30266e2..b4bf9eef8fca5 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -40,7 +40,7 @@ public class JavaDataFrameSuite { DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -52,7 +52,7 @@ public void setUp() throws IOException { hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } @@ -71,7 +71,7 @@ public void tearDown() throws IOException { @Test public void saveTableAndQueryIt() { checkAnswer( - df.select(functions.avg("key").over( + df.select(avg("key").over( Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), hc.sql("SELECT avg(key) " + "OVER (PARTITION BY value " + @@ -95,7 +95,7 @@ public void testUDAF() { registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); - List expectedResult = new ArrayList(); + List expectedResult = new ArrayList<>(); expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); checkAnswer( aggregatedDF, diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 4192155975c47..c8d272794d10b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -53,7 +53,7 @@ public class JavaMetastoreDataSourcesSuite { FileSystem fs; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -77,7 +77,7 @@ public void setUp() throws IOException { fs.delete(hiveManagedPath, true); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -97,7 +97,7 @@ public void tearDown() throws IOException { @Test public void saveExternalTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -120,7 +120,7 @@ public void saveExternalTableAndQueryIt() { @Test public void saveExternalTableWithSchemaAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -132,7 +132,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = @@ -148,7 +148,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); df.write() .format("org.apache.spark.sql.json") .mode(SaveMode.Append) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e0718f73aa13f..c5217149224e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,24 +18,22 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -54,14 +52,14 @@ // see http://stackoverflow.com/questions/758570/. public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +72,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +116,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +178,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +187,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +241,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +260,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +284,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +345,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -378,11 +376,11 @@ public void testQueueStream() { Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -410,10 +408,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -435,70 +433,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -511,32 +509,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -552,14 +550,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -567,9 +563,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -587,89 +583,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -690,13 +686,13 @@ public void testStreamingContextTransform(){ ); List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -707,7 +703,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -733,8 +729,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -763,7 +759,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -782,39 +778,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new Tuple2<>(in.length(), letter)); } return out; } @@ -859,13 +855,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -877,25 +873,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -906,28 +902,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -936,22 +932,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -969,23 +965,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1014,7 +1010,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1030,23 +1026,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1054,10 +1050,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1075,11 +1071,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1111,11 +1107,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1136,20 +1132,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1170,13 +1166,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1193,16 +1189,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1220,16 +1216,16 @@ public void testGroupByKeyAndWindow() { } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1238,12 +1234,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1262,12 +1258,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1278,10 +1274,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1298,19 +1294,19 @@ public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1321,10 +1317,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1341,19 +1337,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1370,15 +1366,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1386,7 +1382,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1399,27 +1395,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1428,7 +1424,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ -1444,15 +1440,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1465,11 +1461,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1487,14 +1483,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1502,8 +1498,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1519,22 +1515,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", "islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1545,7 +1541,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1562,29 +1558,29 @@ public Iterable call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1620,29 +1616,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1664,13 +1660,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1713,7 +1709,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1752,6 +1748,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1765,20 +1762,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1800,7 +1797,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1818,29 +1815,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. @Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1870,7 +1864,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1879,7 +1873,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1892,19 +1886,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69dec..ec2bffd6a5b97 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -36,7 +36,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +63,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; } }); @@ -83,7 +82,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +94,49 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = new Socket(host, port); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + in.close(); + socket.close(); + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} From f4a22808e03fa12bfe1bfc82cf713cfda7e063a9 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Sat, 12 Sep 2015 10:17:15 -0700 Subject: [PATCH 0068/2089] [SPARK-6548] Adding stddev to DataFrame functions Adding STDDEV support for DataFrame using 1-pass online /parallel algorithm to compute variance. Please review the code change. Author: JihongMa Author: Jihong MA Author: Jihong MA Author: Jihong MA Closes #6297 from JihongMA/SPARK-SQL. --- R/pkg/inst/tests/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 36 +-- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../catalyst/analysis/HiveTypeCoercion.scala | 3 + .../spark/sql/catalyst/dsl/package.scala | 3 + .../expressions/aggregate/functions.scala | 143 ++++++++++ .../expressions/aggregate/utils.scala | 18 ++ .../sql/catalyst/expressions/aggregates.scala | 245 ++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/GroupedData.scala | 39 +++ .../org/apache/spark/sql/functions.scala | 27 ++ .../apache/spark/sql/JavaDataFrameSuite.java | 1 + .../spark/sql/DataFrameAggregateSuite.scala | 33 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 42 ++- .../execution/AggregationQuerySuite.scala | 35 --- 16 files changed, 574 insertions(+), 64 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1ccfde59176f5..98d4402d368e1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1147,7 +1147,7 @@ test_that("describe() and summarize() on a DataFrame", { stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c5bf55791240b..fb995fa3a76b5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -653,25 +653,25 @@ def describe(self, *cols): guarantee about the backward compatibility of the schema of the resulting DataFrame. >>> df.describe().show() - +-------+---+ - |summary|age| - +-------+---+ - | count| 2| - | mean|3.5| - | stddev|1.5| - | min| 2| - | max| 5| - +-------+---+ + +-------+------------------+ + |summary| age| + +-------+------------------+ + | count| 2| + | mean| 3.5| + | stddev|2.1213203435596424| + | min| 2| + | max| 5| + +-------+------------------+ >>> df.describe(['age', 'name']).show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | mean|3.5| null| - | stddev|1.5| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+---+-----+ + +-------+------------------+-----+ + |summary| age| name| + +-------+------------------+-----+ + | count| 2| 2| + | mean| 3.5| null| + | stddev|2.1213203435596424| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+------------------+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5a90d788151..11b4866bf264b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,6 +168,9 @@ object FunctionRegistry { expression[Last]("last"), expression[Max]("max"), expression[Min]("min"), + expression[Stddev]("stddev"), + expression[StddevPop]("stddev_pop"), + expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), // string functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 87c11abbad490..87a3845b2d9e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,6 +297,9 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a7e3a49327655..699c4cc63d09a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,6 +159,9 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def stddev(e: Expression): Expression = Stddev(e) + def stddev_pop(e: Expression): Expression = StddevPop(e) + def stddev_samp(e: Expression): Expression = StddevSamp(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba1..02cd0ac0db118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = min } +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev" +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = false + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev_samp" +} + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + def isSample: Boolean + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select stddev(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val preCount = AttributeReference("preCount", resultType)() + private val currentCount = AttributeReference("currentCount", resultType)() + private val preAvg = AttributeReference("preAvg", resultType)() + private val currentAvg = AttributeReference("currentAvg", resultType)() + private val currentMk = AttributeReference("currentMk", resultType)() + + override val bufferAttributes = preCount :: currentCount :: preAvg :: + currentAvg :: currentMk :: Nil + + override val initialValues = Seq( + /* preCount = */ Cast(Literal(0), resultType), + /* currentCount = */ Cast(Literal(0), resultType), + /* preAvg = */ Cast(Literal(0), resultType), + /* currentAvg = */ Cast(Literal(0), resultType), + /* currentMk = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions = { + + // update average + // avg = avg + (value - avg)/count + def avgAdd: Expression = { + currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) + } + + // update sum of square of difference from mean + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + def mkAdd: Expression = { + val delta1 = Cast(child, resultType) - preAvg + val delta2 = Cast(child, resultType) - currentAvg + currentMk + (delta1 * delta2) + } + + Seq( + /* preCount = */ If(IsNull(child), preCount, currentCount), + /* currentCount = */ If(IsNull(child), currentCount, + Add(currentCount, Cast(Literal(1), resultType))), + /* preAvg = */ If(IsNull(child), preAvg, currentAvg), + /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), + /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + ) + } + + override val mergeExpressions = { + + // count merge + def countMerge: Expression = { + currentCount.left + currentCount.right + } + + // average merge + def avgMerge: Expression = { + ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / + (preCount + currentCount.right) + } + + // update sum of square differences + def mkMerge: Expression = { + val avgDelta = currentAvg.right - preAvg + val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / + (preCount + currentCount.right) + + currentMk.left + currentMk.right + mkDelta + } + + Seq( + /* preCount = */ If(IsNull(currentCount.left), + Cast(Literal(0), resultType), currentCount.left), + /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, + If(IsNull(currentCount.right), currentCount.left, countMerge)), + /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), + /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, + If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), + /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, + If(IsNull(currentMk.right), currentMk.left, mkMerge)) + ) + } + + override val evaluateExpression = { + // when currentCount == 0, return null + // when currentCount == 1, return 0 + // when currentCount >1 + // stddev_samp = sqrt (currentMk/(currentCount -1)) + // stddev_pop = sqrt (currentMk/currentCount) + val varCol = { + if (isSample) { + currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) + } + else { + currentMk / currentCount + } + } + + If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(varCol), resultType))) + } +} + case class Sum(child: Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a95490..ce3dddad87f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -85,6 +85,24 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Stddev(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Stddev(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevPop(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevPop(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevSamp(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevSamp(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Sum(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Sum(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5e8298aaaa9cb..f1c47f39043c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag result } } + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { + override def nullable: Boolean = true + override def dataType: DataType = DoubleType + + def isSample: Boolean + + override def asPartial: SplitEvaluation = { + val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() + SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) + } + + override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") + +} + +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV($child)" + override def isSample: Boolean = true +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_POP($child)" + override def isSample: Boolean = false +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_SAMP($child)" + override def isSample: Boolean = true +} + +case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = false + override def dataType: DataType = ArrayType(DoubleType) + override def toString: String = s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) +} + +case class ComputePartialStdFunction ( + expr: Expression, + base: AggregateExpression1 +) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private var partialCount: Long = 0L + + // the mean of data processed so far + private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update average based on this formula: + // avg = avg + (value - avg)/count + private def avgAddFunction (value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), partialAvg) + Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) + } + + // the sum of squares of difference from mean + private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update sum of square of difference from mean based on following formula: + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), prePartialAvg) + val delta2 = Subtract(Cast(value, computeType), partialAvg) + Add(partialMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + val prePartialAvg = partialAvg.copy() + partialCount += 1 + partialAvg.update(avgAddFunction(exprValue), input) + partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), + partialAvg.eval(null), + partialMk.eval(null))) + } +} + +case class MergePartialStd( + child: Expression, + isSample: Boolean +) extends UnaryExpression with AggregateExpression1 { + def this() = this(null, false) // required for serialization + + override def children: Seq[Expression] = child:: Nil + override def nullable: Boolean = false + override def dataType: DataType = DoubleType + override def toString: String = s"MergePartialStd($child)" + override def newInstance(): MergePartialStdFunction = { + new MergePartialStdFunction(child, this, isSample) + } +} + +case class MergePartialStdFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + def this() = this (null, null, false) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private val combineCount = MutableLiteral(zero.eval(null), computeType) + private val combineAvg = MutableLiteral(zero.eval(null), computeType) + private val combineMk = MutableLiteral(zero.eval(null), computeType) + + private def avgUpdateFunction(preCount: Expression, + partialCount: Expression, + partialAvg: Expression): Expression = { + Divide(Add(Multiply(combineAvg, preCount), + Multiply(partialAvg, partialCount)), + Add(preCount, partialCount)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] + + if (evaluatedExpr != null) { + val exprValue = evaluatedExpr.toArray(computeType) + val (partialCount, partialAvg, partialMk) = + (Literal.create(exprValue(0), computeType), + Literal.create(exprValue(1), computeType), + Literal.create(exprValue(2), computeType)) + + if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { + val preCount = combineCount.copy() + combineCount.update(Add(combineCount, partialCount), input) + + val preAvg = combineAvg.copy() + val avgDelta = Subtract(partialAvg, preAvg) + val mkDelta = Multiply(Multiply(avgDelta, avgDelta), + Divide(Multiply(preCount, partialCount), + combineCount)) + + // update average based on following formula + // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) + combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) + + // update sum of square differences from mean based on following formula + // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) + combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) + } + } + } + + override def eval(input: InternalRow): Any = { + val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] + + if (count == 0) null + else if (count < 2) zero.eval(null) + else { + // when total count > 2 + // stddev_samp = sqrt (combineMk/(combineCount -1)) + // stddev_pop = sqrt (combineMk/combineCount) + val varCol = { + if (isSample) { + Divide(combineMk, Cast(Literal(count - 1), computeType)) + } + else { + Divide(combineMk, Cast(Literal(count), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} + +case class StddevFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + + def this() = this(null, null, false) // Required for serialization + + private val computeType = DoubleType + private var curCount: Long = 0L + private val zero = Cast(Literal(0), computeType) + private val curAvg = MutableLiteral(zero.eval(null), computeType) + private val curMk = MutableLiteral(zero.eval(null), computeType) + + private def curAvgAddFunction(value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), curAvg) + Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) + } + private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), preAvg) + val delta2 = Subtract(Cast(value, computeType), curAvg) + Add(curMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val preAvg: MutableLiteral = curAvg.copy() + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + curCount += 1L + curAvg.update(curAvgAddFunction(exprValue), input) + curMk.update(curMkAddFunction(exprValue, preAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + if (curCount == 0) null + else if (curCount < 2) zero.eval(null) + else { + // when total count > 2, + // stddev_samp = sqrt(curMk/(curCount - 1)) + // stddev_pop = sqrt(curMk/curCount) + val varCol = { + if (isSample) { + Divide(curMk, Cast(Literal(curCount - 1), computeType)) + } + else { + Divide(curMk, Cast(Literal(curCount), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 791c10c3d7ce7..1a687b2374f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1288,15 +1288,11 @@ class DataFrame private[sql]( @scala.annotation.varargs def describe(cols: String*): DataFrame = { - // TODO: Add stddev as an expression, and remove it from here. - def stddevExpr(expr: Expression): Expression = - Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) - // The list of summary statistics to compute, in the form of expressions. val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> stddevExpr, + "stddev" -> Stddev, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee31d83cce42c..102b802ad0a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -124,6 +124,9 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min + case "stddev" => Stddev + case "stddev_pop" => StddevPop + case "stddev_samp" => StddevSamp case "sum" => Sum case "count" | "size" => // Turn count(*) into count(1) @@ -283,6 +286,42 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Min) } + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Stddev) + } + + /** + * Compute the population standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_pop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevPop) + } + + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_samp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevSamp) + } + /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 435e6319a64c4..60d9c509104d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -294,6 +294,33 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the unbiased sample standard deviation + * of the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = Stddev(e.expr) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = StddevPop(e.expr) + + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = StddevSamp(e.expr) + /** * Aggregate function: returns the sum of all values in the expression. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index d981ce947f435..5f9abd4999ce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -90,6 +90,7 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); + df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c0950b09b14ad..f5ef9ffd7f4f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4/5.0) + + checkAnswer( + testData2.agg(stddev('a)), + Row(testData2ADev)) + + checkAnswer( + testData2.agg(stddev_pop('a)), + Row(math.sqrt(4/6.0))) + + checkAnswer( + testData2.agg(stddev_samp('a)), + Row(testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() == 0) + + checkAnswer( + emptyTableData.agg(stddev('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_pop('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_samp('a)), + Row(null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dbed4fc247140..c167999af580e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -436,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 664b7a1512faf..962b100b532c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. testCodeGen( """ @@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(1, 3, 2, 1, 6, 3) ) } @@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("stddev") { + checkAnswer( + sql("SELECT STDDEV(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev_pop") { + checkAnswer( + sql("SELECT STDDEV_POP(a) FROM testData2"), + Row(math.sqrt(4/6.0)) + ) + } + + test("stddev_samp") { + checkAnswer( + sql("SELECT STDDEV_SAMP(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev agg") { + checkAnswer( + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index b126ec455fc69..a73b1bd52c09f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -507,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te }.getMessage assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } - - // TODO: once we support Hive UDAF in the new interface, - // we can remove the following two tests. - withSQLConf("spark.sql.useAggregate2" -> "true") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | mydoublesum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.SortBasedAggregate => agg - case agg: aggregate.TungstenAggregate => agg - } - val message = - "We should fallback to the old aggregation code path if " + - "there is any aggregate function that cannot be converted to the new interface." - assert(newAggregateOperators.isEmpty, message) - } } } From b3a7480ab0821ab38f710de96e3ac4a13f62dbca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 12 Sep 2015 16:23:55 -0700 Subject: [PATCH 0069/2089] [SPARK-10330] Add Scalastyle rule to require use of SparkHadoopUtil JobContext methods This is a followup to #8499 which adds a Scalastyle rule to mandate the use of SparkHadoopUtil's JobContext accessor methods and fixes the existing violations. Author: Josh Rosen Closes #8521 from JoshRosen/SPARK-10330-part2. --- .../src/main/scala/org/apache/spark/SparkContext.scala | 6 +++--- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 4 ++++ .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 8 +++++--- .../scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala | 2 +- core/src/test/scala/org/apache/spark/FileSuite.scala | 6 ++++-- .../org/apache/spark/examples/CassandraCQLTest.scala | 3 +++ .../org/apache/spark/examples/CassandraTest.scala | 2 ++ scalastyle-config.xml | 8 ++++++++ .../sql/execution/datasources/WriterContainer.scala | 8 ++++++-- .../sql/execution/datasources/json/JSONRelation.scala | 2 +- .../datasources/parquet/CatalystReadSupport.scala | 6 +++++- .../parquet/DirectParquetOutputCommitter.scala | 6 +++++- .../datasources/parquet/ParquetRelation.scala | 10 +++++++--- .../datasources/parquet/ParquetTypesConverter.scala | 6 +++++- .../org/apache/spark/sql/hive/orc/OrcRelation.scala | 4 ++-- 15 files changed, 61 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cbfe8bf31c3d6..e27b3c4962221 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -858,7 +858,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -910,7 +910,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1092,7 +1092,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f7723ef5bde4c..a0b7365df900a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -192,7 +192,9 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. */ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } @@ -204,7 +206,9 @@ class SparkHadoopUtil extends Logging { */ def getTaskAttemptIDFromTaskAttemptContext( context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c59f0d4aa75a0..199d79b811d65 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -996,8 +996,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -1064,7 +1065,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 9babe56267e08..0228c54e0511c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -86,7 +86,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 418763f4e5ffa..fdb00aafc4a48 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel @@ -506,8 +507,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - job.getConfiguration.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") - randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index fa07c1e5017cd..d1b9b8d398dd8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println + // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -81,6 +82,7 @@ object CassandraCQLTest { val job = new Job() job.setInputFormatClass(classOf[CqlPagingInputFormat]) + val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) @@ -135,3 +137,4 @@ object CassandraCQLTest { } } // scalastyle:on println +// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 2e56d24c60c33..1e679bfb55343 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println +// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { } } // scalastyle:on println +// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68fdb4141cf27..64a0c71bbef2a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -168,6 +168,14 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + ^getConfiguration$|^getTaskAttemptID$ + Instead of calling .getConfiguration() or .getTaskAttemptID() directly, + use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9a573db0c023a..f8ef674ed29c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -47,7 +47,8 @@ private[sql] abstract class BaseWriterContainer( protected val dataSchema = relation.dataSchema - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + protected val serializableConf = + new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -89,7 +90,8 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + SparkHadoopUtil.get.getConfigurationFromJobContext(job). + set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -182,7 +184,9 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) + // scalastyle:off jobcontext this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + // scalastyle:on jobcontext } private def setupConf(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 7a49157d9e72c..8ee0127c3bde8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -81,7 +81,7 @@ private[sql] class JSONRelation( private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) val paths = inputPaths.map(_.getPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 5a8166fac5418..8c819f1a48cd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -72,7 +72,11 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // Called before `prepareForRead()` when initializing Parquet record reader. override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration + val conf = { + // scalastyle:off jobcontext + context.getConfiguration + // scalastyle:on jobcontext + } // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst // schema of this file from its metadata. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 2c6b914328b60..de1fd0166ac5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -53,7 +53,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = ContextUtil.getConfiguration(jobContext) + val configuration = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(jobContext) + // scalastyle:on jobcontext + } val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index c6bbc392cad4c..953fcab126970 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -211,7 +211,11 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) + val conf = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(job) + // scalastyle:on jobcontext + } // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -528,7 +532,7 @@ private[sql] object ParquetRelation extends Logging { assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean, followParquetFormatSpec: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. @@ -572,7 +576,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, job.getConfiguration) + overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) } private[parquet] def readSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 142301fe87cb6..b647bb6116afa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -123,7 +123,11 @@ private[parquet] object ParquetTypesConverter extends Logging { throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") } val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val conf = { + // scalastyle:off jobcontext + configuration.getOrElse(ContextUtil.getConfiguration(job)) + // scalastyle:on jobcontext + } val fs: FileSystem = origPath.getFileSystem(conf) if (fs == null) { throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7e89109259955..d1f30e188eafb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -208,7 +208,7 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - job.getConfiguration match { + SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -289,7 +289,7 @@ private[orc] case class OrcTableScan( def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { From 1dc614b874badde0eee60def46fb47f608bc4759 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 13 Sep 2015 08:36:46 +0100 Subject: [PATCH 0070/2089] [SPARK-10222] [GRAPHX] [DOCS] More thoroughly deprecate Bagel in favor of GraphX Finish deprecating Bagel; remove reference to nonexistent example Author: Sean Owen Closes #8731 from srowen/SPARK-10222. --- .../src/main/scala/org/apache/spark/bagel/Bagel.scala | 6 ++++++ docs/bagel-programming-guide.md | 10 +--------- docs/index.md | 1 - pom.xml | 2 +- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 4e6b7686f771d..8399033ac61ec 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index c2fe6b0e286ce..347ca4a7af989 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -4,7 +4,7 @@ displayTitle: Bagel Programming Guide title: Bagel --- -**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** +**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. @@ -157,11 +157,3 @@ trait Message[K] { def targetId: K } {% endhighlight %} - -# Where to Go from Here - -Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/org/apache/spark/examples/bagel`. You can run them by passing the class name to the `bin/run-example` script included in Spark; e.g.: - - ./bin/run-example org.apache.spark.examples.bagel.WikipediaPageRank - -Each example program prints usage help when run without any arguments. diff --git a/docs/index.md b/docs/index.md index d85cf12defefd..c0dc2b8d7412a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -90,7 +90,6 @@ options for deployment: * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing - * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model **API Docs:** diff --git a/pom.xml b/pom.xml index 88ebceca769e9..421357e141572 100644 --- a/pom.xml +++ b/pom.xml @@ -87,7 +87,7 @@ core - bagel + bagel graphx mllib tools From d81565465cc6d4f38b4ed78036cded630c700388 Mon Sep 17 00:00:00 2001 From: Bertrand Dechoux Date: Mon, 14 Sep 2015 09:18:46 +0100 Subject: [PATCH 0071/2089] [SPARK-9720] [ML] Identifiable types need UID in toString methods A few Identifiable types did override their toString method but without using the parent implementation. As a consequence, the uid was not present anymore in the toString result. It is the default behaviour. This patch is a quick fix. The question of enforcement is still up. No tests have been written to verify the toString method behaviour. That would be long to do because all types should be tested and not only those which have a regression now. It is possible to enforce the condition using the compiler by making the toString method final but that would introduce unwanted potential API breaking changes (see jira). Author: Bertrand Dechoux Closes #8062 from BertrandDechoux/SPARK-9720. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../org/apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4 ++-- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 0a75d5d22280f..b8eb49f9bdb48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -146,7 +146,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def toString: String = { - s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } /** (private[ml]) Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3073a2a61ce83..ad8683648b975 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -200,7 +200,7 @@ final class GBTClassificationModel( } override def toString: String = { - s"GBTClassificationModel with $numTrees trees" + s"GBTClassificationModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 69cb88a7e6718..082ea1ffad58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -198,7 +198,7 @@ class NaiveBayesModel private[ml] ( } override def toString: String = { - s"NaiveBayesModel with ${pi.size} classes" + s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 11a6d72468333..a6ebee1bb10af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -193,7 +193,7 @@ final class RandomForestClassificationModel private[ml] ( } override def toString: String = { - s"RandomForestClassificationModel with $numTrees trees" + s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index a7fa50444209b..dcd6fe3c406a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -129,7 +129,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R override def copy(extra: ParamMap): RFormula = defaultCopy(extra) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" } /** @@ -171,7 +171,7 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" private def transformLabel(dataset: DataFrame): DataFrame = { val labelName = resolvedFormula.label diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index a2bcd67401d08..d9a244bea28d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -118,7 +118,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def toString: String = { - s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } /** Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b66e61f37dd5e..d841ecb9e58d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -189,7 +189,7 @@ final class GBTRegressionModel( } override def toString: String = { - s"GBTRegressionModel with $numTrees trees" + s"GBTRegressionModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2f36da371f577..ddb7214416a69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -155,7 +155,7 @@ final class RandomForestRegressionModel private[ml] ( } override def toString: String = { - s"RandomForestRegressionModel with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" } /** From 32407bfd2bdbf84d65cacfa7554dae6a2332bc37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Sep 2015 11:51:39 -0700 Subject: [PATCH 0072/2089] [SPARK-9899] [SQL] log warning for direct output committer with speculation enabled This is a follow-up of https://github.com/apache/spark/pull/8317. When speculation is enabled, there may be multiply tasks writing to the same path. Generally it's OK as we will write to a temporary directory first and only one task can commit the temporary directory to target path. However, when we use direct output committer, tasks will write data to target path directly without temporary directory. This causes problems like corrupted data. Please see [PR comment](https://github.com/apache/spark/pull/8191#issuecomment-131598385) for more details. Unfortunately, we don't have a simple flag to tell if a output committer will write to temporary directory or not, so for safety, we have to disable any customized output committer when `speculation` is true. Author: Wenchen Fan Closes #8687 from cloud-fan/direct-committer. --- .../apache/spark/rdd/PairRDDFunctions.scala | 44 ++++++++++++++++--- .../hive/execution/InsertIntoHiveTable.scala | 17 ++++++- .../spark/sql/hive/hiveWriterContainers.scala | 1 - 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 199d79b811d65..a981b63942e6d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1018,6 +1018,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsHadoopFile( path: String, @@ -1030,10 +1035,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) hadoopConf.setOutputValueClass(valueClass) - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) hadoopConf.set("mapred.output.compress", "true") @@ -1047,6 +1049,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) } + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = hadoopConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) @@ -1057,6 +1072,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Configuration object for that storage system. The Conf should set an OutputFormat and any * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). @@ -1115,6 +1135,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobCommitter.getClass.getSimpleName + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + jobCommitter.setupJob(jobTaskContext) self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) @@ -1129,7 +1163,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1157,7 +1190,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 58f7fa640e8a9..0c700bdb370ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -62,7 +62,7 @@ case class InsertIntoHiveTable( def output: Seq[Attribute] = Seq.empty - def saveAsHiveFile( + private def saveAsHiveFile( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, @@ -178,6 +178,19 @@ case class InsertIntoHiveTable( val jobConf = new JobConf(sc.hiveconf) val jobConfSer = new SerializableJobConf(jobConf) + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 29a6f08f40728..4ca8042d22367 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils From cf2821ef5fd9965eb6256e8e8b3f1e00c0788098 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 14 Sep 2015 12:06:23 -0700 Subject: [PATCH 0073/2089] [SPARK-10584] [DOC] [SQL] Documentation about spark.sql.hive.metastore.version is wrong. The default value of hive metastore version is 1.2.1 but the documentation says the value of `spark.sql.hive.metastore.version` is 0.13.1. Also, we cannot get the default value by `sqlContext.getConf("spark.sql.hive.metastore.version")`. Author: Kousuke Saruta Closes #8739 from sarutak/SPARK-10584. --- docs/sql-programming-guide.md | 2 +- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6a1b0fbfa1eb3..a0b911d207243 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1687,7 +1687,7 @@ The following options can be used to configure the version of Hive that is used
Property NameDefaultMeaning
spark.sql.hive.metastore.version0.13.11.2.1 Version of the Hive metastore. Available options are 0.12.0 through 1.2.1. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e791cea96b41..d37ba5ddc2d80 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -111,8 +111,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * this does not necessarily need to be the same version of Hive that is used internally by * Spark SQL for execution. */ - protected[hive] def hiveMetastoreVersion: String = - getConf(HIVE_METASTORE_VERSION, hiveExecutionVersion) + protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) /** * The location of the jars that should be used to instantiate the HiveMetastoreClient. This @@ -202,7 +201,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. " + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") } // We recursively find all jars in the class loader chain, @@ -606,7 +605,11 @@ private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. */ val hiveExecutionVersion: String = "1.2.1" - val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" + val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", + defaultValue = Some(hiveExecutionVersion), + doc = "Version of the Hive metastore. Available options are " + + s"0.12.0 through $hiveExecutionVersion.") + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), doc = s""" From ce6f3f163bc667cb5da9ab4331c8bad10cc0d701 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Sep 2015 12:08:52 -0700 Subject: [PATCH 0074/2089] [SPARK-10194] [MLLIB] [PYSPARK] SGD algorithms need convergenceTol parameter in Python [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382) added a ```convergenceTol``` parameter for GradientDescent-based methods in Scala. We need that parameter in Python; otherwise, Python users will not be able to adjust that behavior (or even reproduce behavior from previous releases since the default changed). Author: Yanbo Liang Closes #8457 from yanboliang/spark-10194. --- .../mllib/api/python/PythonMLLibAPI.scala | 20 +++++++++--- python/pyspark/mllib/classification.py | 17 +++++++--- python/pyspark/mllib/regression.py | 32 ++++++++++++------- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f585aacd452e0..69ce7f50709a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -132,7 +132,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) .setValidateData(validateData) @@ -141,6 +142,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, @@ -159,7 +161,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.setIntercept(intercept) .setValidateData(validateData) @@ -168,6 +171,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( lassoAlg, data, @@ -185,7 +189,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.setIntercept(intercept) .setValidateData(validateData) @@ -194,6 +199,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( ridgeAlg, data, @@ -212,7 +218,8 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) .setValidateData(validateData) @@ -221,6 +228,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, @@ -240,7 +248,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) .setValidateData(validateData) @@ -249,6 +258,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 8f27c446a66e8..cb4ee83678081 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -241,7 +241,7 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType="l2", intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a logistic regression model on the given data. @@ -274,11 +274,13 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -439,7 +441,7 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType="l2", - intercept=False, validateData=True): + intercept=False, validateData=True, convergenceTol=0.001): """ Train a support vector machine on the given data. @@ -472,11 +474,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) @@ -600,12 +604,15 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. :param regParam: L2 Regularization parameter. + :param convergenceTol: A condition which decides iteration termination. """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.regParam = regParam self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLogisticRegressionWithSGD, self).__init__( model=self._model) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 41946e3674fbe..256b7537fef6b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -28,7 +28,8 @@ 'LinearRegressionModel', 'LinearRegressionWithSGD', 'RidgeRegressionModel', 'RidgeRegressionWithSGD', 'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel', - 'IsotonicRegression'] + 'IsotonicRegression', 'StreamingLinearAlgorithm', + 'StreamingLinearRegressionWithSGD'] class LabeledPoint(object): @@ -202,7 +203,7 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient Descent (SGD). @@ -244,11 +245,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept), bool(validateData)) + regType, bool(intercept), bool(validateData), + float(convergenceTol)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -330,7 +334,7 @@ class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L1-regularization using Stochastic Gradient Descent. @@ -362,11 +366,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -449,7 +455,7 @@ class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L2-regularization using Stochastic Gradient Descent. @@ -481,11 +487,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) @@ -636,15 +644,17 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): After training on a batch of data, the weights obtained at the end of training are used as initial weights for the next batch. - :param: stepSize Step size for each iteration of gradient descent. - :param: numIterations Total number of iterations run. - :param: miniBatchFraction Fraction of data on which SGD is run for each + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Total number of iterations run. + :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. + :param convergenceTol: A condition which decides iteration termination. """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLinearRegressionWithSGD, self).__init__( model=self._model) From 8a634e9bcc671167613fb575c6c0c054fb4b3479 Mon Sep 17 00:00:00 2001 From: Nick Pritchard Date: Mon, 14 Sep 2015 13:27:45 -0700 Subject: [PATCH 0075/2089] [SPARK-10573] [ML] IndexToString output schema should be StringType Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard Closes #8751 from pnpritchard/SPARK-10573. --- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 5 ++--- .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3a4ab9a857648..2b1592930e77b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap /** @@ -229,8 +229,7 @@ class IndexToString private[ml] ( val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val outputFields = inputFields :+ StructField($(outputCol), StringType) StructType(outputFields) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 05e05bdc64bb1..ddcdb5f4212be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(a === b) } } + + test("IndexToString.transformSchema (SPARK-10573)") { + val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output") + val inSchema = StructType(Seq(StructField("input", DoubleType))) + val outSchema = idxToStr.transformSchema(inSchema) + assert(outSchema("output").dataType === StringType) + } } From 7e32387ae6303fd1cd32389d47df87170b841c67 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Sep 2015 14:10:54 -0700 Subject: [PATCH 0076/2089] [SPARK-10522] [SQL] Nanoseconds of Timestamp in Parquet should be positive Or Hive can't read it back correctly. Thanks vanzin for report this. Author: Davies Liu Closes #8674 from davies/positive_nano. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 12 +++++++----- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 17 ++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d652fce3fd9b6..687ca000d12bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -42,6 +42,7 @@ object DateTimeUtils { final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L @@ -190,13 +191,14 @@ object DateTimeUtils { /** * Returns Julian day and nanoseconds in a day from the number of microseconds + * + * Note: support timestamp since 4717 BC (without negative nanoseconds, compatible with Hive). */ def toJulianDay(us: SQLTimestamp): (Int, Long) = { - val seconds = us / MICROS_PER_SECOND - val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH - val secondsInDay = seconds % SECONDS_PER_DAY - val nanos = (us % MICROS_PER_SECOND) * 1000L - (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) + val julian_us = us + JULIAN_DAY_OF_EPOCH * MICROS_PER_DAY + val day = julian_us / MICROS_PER_DAY + val micros = julian_us % MICROS_PER_DAY + (day.toInt, micros * 1000L) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 1596bb79fa94b..6b9a11f0ff743 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -52,15 +52,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(ns === 0) assert(fromJulianDay(d, ns) == 0L) - val t = Timestamp.valueOf("2015-06-11 10:10:10.100") - val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) - val t1 = toJavaTimestamp(fromJulianDay(d1, ns1)) - assert(t.equals(t1)) - - val t2 = Timestamp.valueOf("2015-06-11 20:10:10.100") - val (d2, ns2) = toJulianDay(fromJavaTimestamp(t2)) - val t22 = toJavaTimestamp(fromJulianDay(d2, ns2)) - assert(t2.equals(t22)) + Seq(Timestamp.valueOf("2015-06-11 10:10:10.100"), + Timestamp.valueOf("2015-06-11 20:10:10.100"), + Timestamp.valueOf("1900-06-11 20:10:10.100")).foreach { t => + val (d, ns) = toJulianDay(fromJavaTimestamp(t)) + assert(ns > 0) + val t1 = toJavaTimestamp(fromJulianDay(d, ns)) + assert(t.equals(t1)) + } } test("SPARK-6785: java date conversion before and after epoch") { From 64f04154e3078ec7340da97e3c2b07cf24e89098 Mon Sep 17 00:00:00 2001 From: Edoardo Vacchi Date: Mon, 14 Sep 2015 14:56:04 -0700 Subject: [PATCH 0077/2089] [SPARK-6981] [SQL] Factor out SparkPlanner and QueryExecution from SQLContext Alternative to PR #6122; in this case the refactored out classes are replaced by inner classes with the same name for backwards binary compatibility * process in a lighter-weight, backwards-compatible way Author: Edoardo Vacchi Closes #6356 from evacchi/sqlctx-refactoring-lite. --- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 138 ++---------------- .../spark/sql/execution/QueryExecution.scala | 85 +++++++++++ .../spark/sql/execution/SQLExecution.scala | 2 +- .../spark/sql/execution/SparkPlanner.scala | 92 ++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 2 +- 6 files changed, 195 insertions(+), 128 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1a687b2374f14..3e61123c145cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -114,7 +114,7 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4e8414af50b44..e3fdd782e6ff6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -38,6 +38,10 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} +import org.apache.spark.sql.execution.{Filter, _} +import org.apache.spark.sql.{execution => sparkexecution} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.sources._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} @@ -188,9 +192,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) - protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): + org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) + protected[sql] def executePlan(plan: LogicalPlan) = + new sparkexecution.QueryExecution(this, plan) @transient protected[sql] val tlSession = new ThreadLocal[SQLSession]() { @@ -781,77 +787,11 @@ class SQLContext(@transient val sparkContext: SparkContext) }.toArray } - protected[sql] class SparkPlanner extends SparkStrategies { - val sparkContext: SparkContext = self.sparkContext - - val sqlContext: SQLContext = self - - def codegenEnabled: Boolean = self.conf.codegenEnabled - - def unsafeEnabled: Boolean = self.conf.unsafeEnabled - - def numPartitions: Int = self.conf.numShufflePartitions - - def strategies: Seq[Strategy] = - experimental.extraStrategies ++ ( - DataSourceStrategy :: - DDLStrategy :: - TakeOrderedAndProject :: - HashAggregation :: - Aggregation :: - LeftSemiJoin :: - EquiJoinSelection :: - InMemoryScans :: - BasicOperators :: - CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) - - /** - * Used to build table scan operators where complex projection and filtering are done using - * separate physical operators. This function returns the given scan operator with Project and - * Filter nodes added only when needed. For example, a Project operator is only used when the - * final desired output requires complex expressions to be evaluated or when columns can be - * further eliminated out after filtering has been done. - * - * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized - * away by the filter pushdown optimization. - * - * The required attributes for both filtering and expression evaluation are passed to the - * provided `scanBuilder` function so that it can avoid unnecessary column materialization. - */ - def pruneFilterProject( - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - prunePushedDownFilters: Seq[Expression] => Seq[Expression], - scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - - val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = - prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) - - // Right now we still use a projection even if the only evaluation is applying an alias - // to a column. Since this is a no-op, it could be avoided. However, using this - // optimization with the current implementation would change the output schema. - // TODO: Decouple final output schema from expression evaluation so this copy can be - // avoided safely. - - if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && - filterSet.subsetOf(projectSet)) { - // When it is possible to just use column pruning to get the right projection and - // when the columns of this projection are enough to evaluate all filter conditions, - // just do a scan followed by a filter, with no extra project. - val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) - filterCondition.map(Filter(_, scan)).getOrElse(scan) - } else { - val scan = scanBuilder((projectSet ++ filterSet).toSeq) - Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) - } - } - } + @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") + protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) @transient - protected[sql] val planner = new SparkPlanner + protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) @@ -898,59 +838,9 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val conf: SQLConf = new SQLConf } - /** - * :: DeveloperApi :: - * The primary workflow for executing relational queries using Spark. Designed to allow easy - * access to the intermediate phases of query execution for developers. - */ - @DeveloperApi - protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - - lazy val analyzed: LogicalPlan = analyzer.execute(logical) - lazy val withCachedData: LogicalPlan = { - assertAnalyzed() - cacheManager.useCachedData(analyzed) - } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) - - // TODO: Don't just pick the first one... - lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(self) - planner.plan(optimizedPlan).next() - } - // executedPlan should not be used to initialize any SparkPlan. It should be - // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) - - /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[InternalRow] = executedPlan.execute() - - protected def stringOrError[A](f: => A): String = - try f.toString catch { case e: Throwable => e.toString } - - def simpleString: String = - s"""== Physical Plan == - |${stringOrError(executedPlan)} - """.stripMargin.trim - - override def toString: String = { - def output = - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") - - s"""== Parsed Logical Plan == - |${stringOrError(logical)} - |== Analyzed Logical Plan == - |${stringOrError(output)} - |${stringOrError(analyzed)} - |== Optimized Logical Plan == - |${stringOrError(optimizedPlan)} - |== Physical Plan == - |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} - """.stripMargin.trim - } - } + @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") + protected[sql] class QueryExecution(logical: LogicalPlan) + extends sparkexecution.QueryExecution(this, logical) /** * Parses the data type in our internal string representation. The data type string should diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala new file mode 100644 index 0000000000000..7bb4133a29059 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.{Experimental, DeveloperApi} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{InternalRow, optimizer} +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * :: DeveloperApi :: + * The primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. + */ +@DeveloperApi +class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + val analyzer = sqlContext.analyzer + val optimizer = sqlContext.optimizer + val planner = sqlContext.planner + val cacheManager = sqlContext.cacheManager + val prepareForExecution = sqlContext.prepareForExecution + + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) + + lazy val analyzed: LogicalPlan = analyzer.execute(logical) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed() + cacheManager.useCachedData(analyzed) + } + lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) + + // TODO: Don't just pick the first one... + lazy val sparkPlan: SparkPlan = { + SparkPlan.currentContext.set(sqlContext) + planner.plan(optimizedPlan).next() + } + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. + lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) + + /** Internal version of the RDD. Avoids copies and has no schema */ + lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + def simpleString: String = + s"""== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + + + override def toString: String = { + def output = + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") + + s"""== Parsed Logical Plan == + |${stringOrError(logical)} + |== Analyzed Logical Plan == + |${stringOrError(output)} + |${stringOrError(analyzed)} + |== Optimized Logical Plan == + |${stringOrError(optimizedPlan)} + |== Physical Plan == + |${stringOrError(executedPlan)} + |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} + """.stripMargin.trim + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index cee58218a885b..1422e15549c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -37,7 +37,7 @@ private[sql] object SQLExecution { * we can connect them with an execution. */ def withNewExecutionId[T]( - sqlContext: SQLContext, queryExecution: SQLContext#QueryExecution)(body: => T): T = { + sqlContext: SQLContext, queryExecution: QueryExecution)(body: => T): T = { val sc = sqlContext.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) if (oldExecutionId == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala new file mode 100644 index 0000000000000..b346f43faebe2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.DataSourceStrategy + +@Experimental +class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { + val sparkContext: SparkContext = sqlContext.sparkContext + + def codegenEnabled: Boolean = sqlContext.conf.codegenEnabled + + def unsafeEnabled: Boolean = sqlContext.conf.unsafeEnabled + + def numPartitions: Int = sqlContext.conf.numShufflePartitions + + def strategies: Seq[Strategy] = + sqlContext.experimental.extraStrategies ++ ( + DataSourceStrategy :: + DDLStrategy :: + TakeOrderedAndProject :: + HashAggregation :: + Aggregation :: + LeftSemiJoin :: + EquiJoinSelection :: + InMemoryScans :: + BasicOperators :: + CartesianProduct :: + BroadcastNestedLoopJoin :: Nil) + + /** + * Used to build table scan operators where complex projection and filtering are done using + * separate physical operators. This function returns the given scan operator with Project and + * Filter nodes added only when needed. For example, a Project operator is only used when the + * final desired output requires complex expressions to be evaluated or when columns can be + * further eliminated out after filtering has been done. + * + * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized + * away by the filter pushdown optimization. + * + * The required attributes for both filtering and expression evaluation are passed to the + * provided `scanBuilder` function so that it can avoid unnecessary column materialization. + */ + def pruneFilterProject( + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + prunePushedDownFilters: Seq[Expression] => Seq[Expression], + scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = + prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) + + // Right now we still use a projection even if the only evaluation is applying an alias + // to a column. Since this is a no-op, it could be avoided. However, using this + // optimization with the current implementation would change the output schema. + // TODO: Decouple final output schema from expression evaluation so this copy can be + // avoided safely. + + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. + val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) + filterCondition.map(Filter(_, scan)).getOrElse(scan) + } else { + val scan = scanBuilder((projectSet ++ filterSet).toSeq) + Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4572d5efc92bb..5e40d77689045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { - self: SQLContext#SparkPlanner => + self: SparkPlanner => object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { From 217e4964444f4e07b894b1bca768a0cbbe799ea0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 14 Sep 2015 15:00:27 -0700 Subject: [PATCH 0078/2089] [SPARK-9996] [SPARK-9997] [SQL] Add local expand and NestedLoopJoin operators This PR is in conflict with #8535 and #8573. Will update this one when they are merged. Author: zsxwing Closes #8642 from zsxwing/expand-nest-join. --- .../sql/execution/local/ExpandNode.scala | 60 +++++ .../spark/sql/execution/local/LocalNode.scala | 55 +++- .../execution/local/NestedLoopJoinNode.scala | 156 ++++++++++++ .../sql/execution/local/ExpandNodeSuite.scala | 51 ++++ .../execution/local/HashJoinNodeSuite.scala | 14 - .../sql/execution/local/LocalNodeTest.scala | 14 + .../local/NestedLoopJoinNodeSuite.scala | 239 ++++++++++++++++++ 7 files changed, 574 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala new file mode 100644 index 0000000000000..2aff156d18b54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -0,0 +1,60 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} + +case class ExpandNode( + conf: SQLConf, + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LocalNode) extends UnaryLocalNode(conf) { + + assert(projections.size > 0) + + private[this] var result: InternalRow = _ + private[this] var idx: Int = _ + private[this] var input: InternalRow = _ + private[this] var groups: Array[Projection] = _ + + override def open(): Unit = { + child.open() + groups = projections.map(ee => newProjection(ee, child.output)).toArray + idx = groups.length + } + + override def next(): Boolean = { + if (idx >= groups.length) { + if (child.next()) { + input = child.fetch() + idx = 0 + } else { + return false + } + } + result = groups(idx)(input) + idx += 1 + true + } + + override def fetch(): InternalRow = result + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index e540ef8555eb6..9840080e16953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.StructType @@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + /** * Returns the content through the [[Iterator]] interface. */ @@ -91,6 +103,28 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging result } + protected def newProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } + } + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { @@ -113,6 +147,25 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging } } + protected def newPredicate( + expression: Expression, + inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + if (codegenEnabled) { + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } + } else { + InterpretedPredicate.create(expression, inputSchema) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala new file mode 100644 index 0000000000000..7321fc66b4dde --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.util.collection.{BitSet, CompactBuffer} + +case class NestedLoopJoinNode( + conf: SQLConf, + left: LocalNode, + right: LocalNode, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"NestedLoopJoin should not take $x as the JoinType") + } + } + + private[this] def genResultProjection: InternalRow => InternalRow = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + private[this] var currentRow: InternalRow = _ + + private[this] var iterator: Iterator[InternalRow] = _ + + override def open(): Unit = { + val (streamed, build) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + build.open() + val buildRelation = new CompactBuffer[InternalRow] + while (build.next()) { + buildRelation += build.fetch().copy() + } + build.close() + + val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + val joinedRow = new JoinedRow + val matchedBuildTuples = new BitSet(buildRelation.size) + val resultProj = genResultProjection + streamed.open() + + // streamedRowMatches also contains null rows if using outer join + val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow => + val matchedRows = new CompactBuffer[InternalRow] + + var i = 0 + var streamRowMatched = false + + // Scan the build relation to look for matches for each streamed row + while (i < buildRelation.size) { + val buildRow = buildRelation(i) + buildSide match { + case BuildRight => joinedRow(streamedRow, buildRow) + case BuildLeft => joinedRow(buildRow, streamedRow) + } + if (boundCondition(joinedRow)) { + matchedRows += resultProj(joinedRow).copy() + streamRowMatched = true + matchedBuildTuples.set(i) + } + i += 1 + } + + // If this row had no matches and we're using outer join, join it with the null rows + if (!streamRowMatched) { + (joinType, buildSide) match { + case (LeftOuter | FullOuter, BuildRight) => + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() + case (RightOuter | FullOuter, BuildLeft) => + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() + case _ => + } + } + + matchedRows.iterator + } + + // If we're using outer join, find rows on the build side that didn't match anything + // and join them with the null row + lazy val unmatchedBuildRows: Iterator[InternalRow] = { + var i = 0 + buildRelation.filter { row => + val r = !matchedBuildTuples.get(i) + i += 1 + r + }.iterator + } + iterator = (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) } + case (LeftOuter | FullOuter, BuildLeft) => + streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) } + case _ => streamedRowMatches + } + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala new file mode 100644 index 0000000000000..cfa7f3f6dcb97 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +class ExpandNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("expand") { + val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") + checkAnswer( + input, + node => + ExpandNode(conf, Seq( + Seq( + input.col("key") + input.col("value"), input.col("key") - input.col("value") + ).map(_.expr), + Seq( + input.col("key") * input.col("value"), input.col("key") / input.col("value") + ).map(_.expr) + ), node.output, node), + Seq( + (2, 0), + (1, 1), + (4, 0), + (4, 1), + (6, 0), + (9, 1), + (8, 0), + (16, 1), + (10, 0), + (25, 1) + ).toDF().collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 43b6f06aead88..78d891351f4a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -24,20 +24,6 @@ class HashJoinNodeSuite extends LocalNodeTest { import testImplicits._ - private def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { test(s"$suiteName: inner join with one match per row") { withSQLConf(confPairs: _*) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index b95d4ea7f8f2a..86dd28064cc6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -27,6 +27,20 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext { def conf: SQLConf = sqlContext.conf + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + if (conf.unsafeEnabled) { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } else { + f + } + } + /** * Runs the LocalNode and makes sure the answer matches the expected result. * @param input the input data to be used. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala new file mode 100644 index 0000000000000..b1ef26ba82f16 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -0,0 +1,239 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + +class NestedLoopJoinNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def joinSuite( + suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { + test(s"$suiteName: left outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + lowerCaseData.col("n") > 1).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + upperCaseData.col("N") > 1).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect()) + } + } + + test(s"$suiteName: right outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + } + } + + test(s"$suiteName: full outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) + } + } + } + + joinSuite( + "general-build-left", + BuildLeft, + SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite( + "general-build-right", + BuildRight, + SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite( + "tungsten-build-left", + BuildLeft, + SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") + joinSuite( + "tungsten-build-right", + BuildRight, + SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} From 16b6d18613e150c7038c613992d80a7828413e66 Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Mon, 14 Sep 2015 15:02:38 -0700 Subject: [PATCH 0079/2089] [SPARK-10594] [YARN] Remove reference to --num-executors, add --properties-file `ApplicationMaster` no longer has the `--num-executors` flag, and had an undocumented `--properties-file` configuration option. cc srowen Author: Erick Tryzelaar Closes #8754 from erickt/master. --- .../apache/spark/deploy/yarn/ApplicationMasterArguments.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index b08412414aa1c..17d9943c795e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -105,9 +105,9 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --properties-file FILE Path to a custom Spark properties file. """.stripMargin) // scalastyle:on println System.exit(exitCode) From 4e2242bb41dda922573046c00c5142745632f95f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 14 Sep 2015 15:03:51 -0700 Subject: [PATCH 0080/2089] [SPARK-10576] [BUILD] Move .java files out of src/main/scala Move .java files in `src/main/scala` to `src/main/java` root, except for `package-info.java` (to stay next to package.scala) Author: Sean Owen Closes #8736 from srowen/SPARK-10576. --- .../org/apache/spark/annotation/AlphaComponent.java | 0 .../{scala => java}/org/apache/spark/annotation/DeveloperApi.java | 0 .../{scala => java}/org/apache/spark/annotation/Experimental.java | 0 .../main/{scala => java}/org/apache/spark/annotation/Private.java | 0 .../{scala => java}/org/apache/spark/graphx/TripletFields.java | 0 .../org/apache/spark/graphx/impl/EdgeActiveness.java | 0 .../org/apache/spark/sql/types/SQLUserDefinedType.java | 0 .../org/apache/spark/streaming/StreamingContextState.java | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename core/src/main/{scala => java}/org/apache/spark/annotation/AlphaComponent.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/DeveloperApi.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/Experimental.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/Private.java (100%) rename graphx/src/main/{scala => java}/org/apache/spark/graphx/TripletFields.java (100%) rename graphx/src/main/{scala => java}/org/apache/spark/graphx/impl/EdgeActiveness.java (100%) rename sql/catalyst/src/main/{scala => java}/org/apache/spark/sql/types/SQLUserDefinedType.java (100%) rename streaming/src/main/{scala => java}/org/apache/spark/streaming/StreamingContextState.java (100%) diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java rename to core/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java rename to core/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Experimental.java rename to core/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Private.java rename to core/src/main/java/org/apache/spark/annotation/Private.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/java/org/apache/spark/graphx/TripletFields.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java rename to graphx/src/main/java/org/apache/spark/graphx/TripletFields.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java rename to graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java similarity index 100% rename from streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java rename to streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java From ffbbc2c58b9bf1e2abc2ea797feada6821ab4de8 Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Mon, 14 Sep 2015 15:05:19 -0700 Subject: [PATCH 0081/2089] [SPARK-10549] scala 2.11 spark on yarn with security - Repl doesn't work Make this lazy so that it can set the yarn mode before creating the securityManager. Author: Tom Graves Author: Thomas Graves Closes #8719 from tgravescs/SPARK-10549. --- .../scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index be31eb2eda546..627148df80c11 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -35,7 +35,8 @@ object Main extends Logging { s.processArguments(List("-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-classpath", getAddedJars.mkString(File.pathSeparator)), true) - val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) + // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed + lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. From fd1e8cddf2635c55fec2ac6e1f1c221c9685af0f Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Mon, 14 Sep 2015 15:07:13 -0700 Subject: [PATCH 0082/2089] [SPARK-10543] [CORE] Peak Execution Memory Quantile should be Per-task Basis Read `PEAK_EXECUTION_MEMORY` using `update` to get per task partial value instead of cumulative value. I tested with this workload: ```scala val size = 1000 val repetitions = 10 val data = sc.parallelize(1 to size, 5).map(x => (util.Random.nextInt(size / repetitions),util.Random.nextDouble)).toDF("key", "value") val res = data.toDF.groupBy("key").agg(sum("value")).count ``` Before: ![image](https://cloud.githubusercontent.com/assets/4317392/9828197/07dd6874-58b8-11e5-9bd9-6ba927c38b26.png) After: ![image](https://cloud.githubusercontent.com/assets/4317392/9828151/a5ddff30-58b7-11e5-8d31-eda5dc4eae79.png) Tasks view: ![image](https://cloud.githubusercontent.com/assets/4317392/9828199/17dc2b84-58b8-11e5-92a8-be89ce4d29d1.png) cc andrewor14 I appreciate if you can give feedback on this since I think you introduced display of this metric. Author: Forest Fang Closes #8726 from saurfang/stagepage. --- .../org/apache/spark/ui/jobs/StagePage.scala | 2 +- .../org/apache/spark/ui/StagePageSuite.scala | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4adc6596ba21c..2b71f55b7bb4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -368,7 +368,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => info.accumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) .toDouble } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 3388c6dca81f1..86699e7f56953 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} @@ -47,6 +47,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { assert(html3.contains(targetString)) } + test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + // verify min/25/50/75/max show task value not cumulative values + assert(html.contains("10.0 b
spark.yarn.driver.memoryOverheaddriverMemory * 0.07, with minimum of 384 driverMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
spark.yarn.am.memoryOverheadAM memory * 0.07, with minimum of 384 AM memory * 0.10, with minimum of 384 Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode.
spark.mesos.principalFramework principal to authenticate to Mesos(none) Set the principal with which Spark framework will use to authenticate with Mesos.
spark.mesos.secretFramework secret to authenticate to Mesos(none)/td> Set the secret with which Spark framework will use to authenticate with Mesos.
spark.mesos.roleRole for the Spark framework* Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. @@ -354,7 +354,7 @@ See the [configuration page](configuration.html) for information on Spark config
spark.mesos.constraintsAttribute based constraints to be matched against when accepting resource offers.(none) Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes.
    From 74d8f7dda82c3a16348f3ff22da83203e0b7f708 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:46:13 -0700 Subject: [PATCH 0140/2089] Added tag to documentation. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 1814fb32ed8a5..330c159c67bca 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -346,7 +346,7 @@ See the [configuration page](configuration.html) for information on Spark config
spark.mesos.role** Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. From e3b5d6cb29e0f983fcc55920619e6433298955f5 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 18 Sep 2015 00:43:02 -0700 Subject: [PATCH 0141/2089] [SPARK-10684] [SQL] StructType.interpretedOrdering need not to be serialized Kryo fails with buffer overflow even with max value (2G). {noformat} org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 1 Serialization trace: containsChild (org.apache.spark.sql.catalyst.expressions.BoundReference) child (org.apache.spark.sql.catalyst.expressions.SortOrder) array (scala.collection.mutable.ArraySeq) ordering (org.apache.spark.sql.catalyst.expressions.InterpretedOrdering) interpretedOrdering (org.apache.spark.sql.types.StructType) schema (org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema). To avoid this, increase spark.kryoserializer.buffer.max value. at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:263) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:240) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) {noformat} Author: navis.ryu Closes #8808 from navis/SPARK-10684. --- .../main/scala/org/apache/spark/sql/types/StructType.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d8968ef806390..b29cf22dcb582 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -305,7 +305,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } - private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) + @transient + private[sql] lazy val interpretedOrdering = + InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { From 20fd35dfd1ac402b622604e7bbedcc53a580b0a2 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 18 Sep 2015 08:22:38 -0700 Subject: [PATCH 0142/2089] [SPARK-10451] [SQL] Prevent unnecessary serializations in InMemoryColumnarTableScan Many of the fields in InMemoryColumnar scan and InMemoryRelation can be made transient. This reduces my 1000ms job to abt 700 ms . The task size reduces from 2.8 mb to ~1300kb Author: Yash Datta Closes #8604 from saucam/serde. --- .../columnar/InMemoryColumnarTableScan.scala | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 66d429bc06198..d7e145f9c2bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -48,10 +48,10 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan, + @transient child: SparkPlan, tableName: Option[String])( - private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null, + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private var _statistics: Statistics = null, private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { @@ -62,7 +62,7 @@ private[sql] case class InMemoryRelation( _batchStats } - val partitionStatistics = new PartitionStatistics(output) + @transient val partitionStatistics = new PartitionStatistics(output) private def computeSizeInBytes = { val sizeOfRow: Expression = @@ -196,7 +196,7 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], predicates: Seq[Expression], - relation: InMemoryRelation) + @transient relation: InMemoryRelation) extends LeafNode { override def output: Seq[Attribute] = attributes @@ -205,7 +205,7 @@ private[sql] case class InMemoryColumnarTableScan( // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - val buildFilter: PartialFunction[Expression, Expression] = { + @transient val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -268,16 +268,23 @@ private[sql] case class InMemoryColumnarTableScan( readBatches.setValue(0) } - relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => - val partitionFilter = newPredicate( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), - relation.partitionStatistics.schema) + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val relOutput = relation.output + val buffers = relation.cachedColumnBuffers + + buffers.mapPartitions { cachedBatchIterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + schema) // Find the ordinals and data types of the requested columns. If none are requested, use the // narrowest (the field with minimum default element size). val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { val (narrowestOrdinal, narrowestDataType) = - relation.output.zipWithIndex.map { case (a, ordinal) => + relOutput.zipWithIndex.map { case (a, ordinal) => ordinal -> a.dataType } minBy { case (_, dataType) => ColumnType(dataType).defaultSize @@ -285,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { attributes.map { a => - relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip } @@ -296,7 +303,7 @@ private[sql] case class InMemoryColumnarTableScan( // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => ColumnAccessor( - relation.output(batchColumnIndex).dataType, + relOutput(batchColumnIndex).dataType, ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) } @@ -328,7 +335,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" From 35e8ab939000d4a1a01c1af4015c25ff6f4013a3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Sep 2015 09:53:52 -0700 Subject: [PATCH 0143/2089] [SPARK-10615] [PYSPARK] change assertEquals to assertEqual As ```assertEquals``` is deprecated, so we need to change ```assertEquals``` to ```assertEqual``` for existing python unit tests. Author: Yanbo Liang Closes #8814 from yanboliang/spark-10615. --- python/pyspark/ml/tests.py | 16 +-- python/pyspark/mllib/tests.py | 162 +++++++++++++++--------------- python/pyspark/sql/tests.py | 18 ++-- python/pyspark/streaming/tests.py | 2 +- 4 files changed, 99 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b892318f50bd9..648fa8858fba3 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -182,7 +182,7 @@ def test_params(self): self.assertEqual(testParams.getMaxIter(), 10) testParams.setMaxIter(100) self.assertTrue(testParams.isSet(maxIter)) - self.assertEquals(testParams.getMaxIter(), 100) + self.assertEqual(testParams.getMaxIter(), 100) self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -195,7 +195,7 @@ def test_params(self): testParams._setDefault(seed=41) testParams.setSeed(43) - self.assertEquals( + self.assertEqual( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", @@ -264,23 +264,23 @@ def test_ngram(self): self.assertEqual(ngram0.getInputCol(), "input") self.assertEqual(ngram0.getOutputCol(), "output") transformedDF = ngram0.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default - self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getInputCol(), "input") transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["panda"]) + self.assertEqual(transformedDF.head().output, ["panda"]) # Custom stopwords = ["panda"] stopWordRemover.setStopWords(stopwords) - self.assertEquals(stopWordRemover.getInputCol(), "input") - self.assertEquals(stopWordRemover.getStopWords(), stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a"]) + self.assertEqual(transformedDF.head().output, ["a"]) class HasInducedError(Params): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 636f9a06cab7b..96cf13495aa95 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -166,13 +166,13 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.]]) arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEquals(10.0, sv.dot(dv)) + self.assertEqual(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEquals(30.0, dv.dot(dv)) + self.assertEqual(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEquals(30.0, lst.dot(dv)) + self.assertEqual(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEquals(7.0, sv.dot(arr)) + self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -181,27 +181,27 @@ def test_squared_distance(self): lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEquals(15.0, _squared_distance(sv, dv)) - self.assertEquals(25.0, _squared_distance(sv, lst)) - self.assertEquals(20.0, _squared_distance(dv, lst)) - self.assertEquals(15.0, _squared_distance(dv, sv)) - self.assertEquals(25.0, _squared_distance(lst, sv)) - self.assertEquals(20.0, _squared_distance(lst, dv)) - self.assertEquals(0.0, _squared_distance(sv, sv)) - self.assertEquals(0.0, _squared_distance(dv, dv)) - self.assertEquals(0.0, _squared_distance(lst, lst)) - self.assertEquals(25.0, _squared_distance(sv, lst1)) - self.assertEquals(3.0, _squared_distance(sv, arr)) - self.assertEquals(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(hash(v1), hash(v2)) - self.assertEquals(hash(v1), hash(v3)) - self.assertEquals(hash(v2), hash(v3)) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) self.assertFalse(hash(v1) == hash(v4)) self.assertFalse(hash(v2) == hash(v4)) @@ -212,8 +212,8 @@ def test_eq(self): v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(v1, v2) - self.assertEquals(v1, v3) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) self.assertFalse(v2 == v4) self.assertFalse(v1 == v5) self.assertFalse(v1 == v6) @@ -238,13 +238,13 @@ def test_conversion(self): def test_sparse_vector_indexing(self): sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[-1], 2) + self.assertEqual(sv[-2], 0) + self.assertEqual(sv[-4], 0) for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: @@ -255,7 +255,7 @@ def test_matrix_indexing(self): expected = [[0, 6], [1, 8], [4, 10]] for i in range(3): for j in range(2): - self.assertEquals(mat[i, j], expected[i][j]) + self.assertEqual(mat[i, j], expected[i][j]) def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -308,11 +308,11 @@ def test_sparse_matrix(self): # Test sparse matrix creation. sm1 = SparseMatrix( 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEquals(sm1.numRows, 3) - self.assertEquals(sm1.numCols, 4) - self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) self.assertTrue( repr(sm1), 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') @@ -325,13 +325,13 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1[i, j]) + self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) # Test conversion to dense and sparse. smnew = sm1.toDense().toSparse() - self.assertEquals(sm1.numRows, smnew.numRows) - self.assertEquals(sm1.numCols, smnew.numCols) + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) self.assertTrue(array_equal(sm1.values, smnew.values)) @@ -339,11 +339,11 @@ def test_sparse_matrix(self): sm1t = SparseMatrix( 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], isTransposed=True) - self.assertEquals(sm1t.numRows, 3) - self.assertEquals(sm1t.numCols, 4) - self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) expected = [ [3, 2, 0, 0], @@ -352,18 +352,18 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertEqual(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) def test_dense_matrix_is_transposed(self): mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEquals(mat1, mat) + self.assertEqual(mat1, mat) expected = [[0, 4], [1, 6], [3, 9]] for i in range(3): for j in range(2): - self.assertEquals(mat1[i, j], expected[i][j]) + self.assertEqual(mat1[i, j], expected[i][j]) self.assertTrue(array_equal(mat1.toArray(), expected)) sm = mat1.toSparse() @@ -412,8 +412,8 @@ def test_kmeans(self): ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", initializationSteps=7, epsilon=1e-4) - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_kmeans_deterministic(self): from pyspark.mllib.clustering import KMeans @@ -443,8 +443,8 @@ def test_gmm(self): clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, maxIterations=10, seed=56) labels = clusters.predict(data).collect() - self.assertEquals(labels[0], labels[1]) - self.assertEquals(labels[2], labels[3]) + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture @@ -456,7 +456,7 @@ def test_gmm_deterministic(self): clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEquals(round(c1, 7), round(c2, 7)) + self.assertEqual(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -711,18 +711,18 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_to_vector(lil)) - self.assertEquals(sv, _convert_to_vector(lil.tocsc())) - self.assertEquals(sv, _convert_to_vector(lil.tocoo())) - self.assertEquals(sv, _convert_to_vector(lil.tocsr())) - self.assertEquals(sv, _convert_to_vector(lil.todok())) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) def serialize(l): return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEquals(sv, serialize(lil)) - self.assertEquals(sv, serialize(lil.tocsc())) - self.assertEquals(sv, serialize(lil.tocsr())) - self.assertEquals(sv, serialize(lil.todok())) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix @@ -730,7 +730,7 @@ def test_dot(self): lil[1, 0] = 1 lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEquals(10.0, dv.dot(lil)) + self.assertEqual(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix @@ -739,8 +739,8 @@ def test_squared_distance(self): lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, dv.squared_distance(lil)) - self.assertEquals(15.0, sv.squared_distance(lil)) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" @@ -759,8 +759,8 @@ def test_clustering(self): self.scipy_matrix(3, {2: 1.1}) ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -984,12 +984,12 @@ def test_word2vec_setters(self): .setNumIterations(10) \ .setSeed(1024) \ .setMinCount(3) - self.assertEquals(model.vectorSize, 2) + self.assertEqual(model.vectorSize, 2) self.assertTrue(model.learningRate < 0.02) - self.assertEquals(model.numPartitions, 2) - self.assertEquals(model.numIterations, 10) - self.assertEquals(model.seed, 1024) - self.assertEquals(model.minCount, 3) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) def test_word2vec_get_vectors(self): data = [ @@ -1002,7 +1002,7 @@ def test_word2vec_get_vectors(self): ["a"] ] model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEquals(len(model.getVectors()), 3) + self.assertEqual(len(model.getVectors()), 3) class StandardScalerTests(MLlibTestCase): @@ -1044,8 +1044,8 @@ def test_model_params(self): """Test that the model params are set correctly""" stkm = StreamingKMeans() stkm.setK(5).setDecayFactor(0.0) - self.assertEquals(stkm._k, 5) - self.assertEquals(stkm._decayFactor, 0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) # Model not set yet. self.assertIsNone(stkm.latestModel()) @@ -1053,9 +1053,9 @@ def test_model_params(self): stkm.setInitialCenters( centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEquals( + self.assertEqual( stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) def test_accuracy_for_single_center(self): """Test that parameters obtained are correct for a single center.""" @@ -1070,7 +1070,7 @@ def test_accuracy_for_single_center(self): self.ssc.start() def condition(): - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) return True self._eventually(condition, catch_assertions=True) @@ -1114,7 +1114,7 @@ def test_trainOn_model(self): def condition(): finalModel = stkm.latestModel() self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) return True self._eventually(condition, catch_assertions=True) @@ -1141,7 +1141,7 @@ def update(rdd): self.ssc.start() def condition(): - self.assertEquals(result, [[0], [1], [2], [3]]) + self.assertEqual(result, [[0], [1], [2], [3]]) return True self._eventually(condition, catch_assertions=True) @@ -1263,7 +1263,7 @@ def test_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(models), len(input_batches)) + self.assertEqual(len(models), len(input_batches)) return True # We want all batches to finish for this test. @@ -1297,7 +1297,7 @@ def test_predictions(self): self.ssc.start() def condition(): - self.assertEquals(len(true_predicted), len(input_batches)) + self.assertEqual(len(true_predicted), len(input_batches)) return True self._eventually(condition, catch_assertions=True) @@ -1400,7 +1400,7 @@ def test_parameter_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(model_weights), len(batches)) + self.assertEqual(len(model_weights), len(batches)) return True # We want all batches to finish for this test. @@ -1433,7 +1433,7 @@ def test_prediction(self): self.ssc.start() def condition(): - self.assertEquals(len(samples), len(batches)) + self.assertEqual(len(samples), len(batches)) return True # We want all batches to finish for this test. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f2172b7a27d88..3e680f1030a71 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,7 +157,7 @@ class DataTypeTests(unittest.TestCase): def test_data_type_eq(self): lt = LongType() lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEquals(lt, lt2) + self.assertEqual(lt, lt2) # regression test for SPARK-7978 def test_decimal_type(self): @@ -393,7 +393,7 @@ def test_infer_nested_schema(self): CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) df = self.sqlCtx.inferSchema(rdd) - self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] @@ -403,7 +403,7 @@ def test_create_dataframe_from_objects(self): def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") - self.assertEquals(Row(col=None), df.first()) + self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): from datetime import date, datetime @@ -519,14 +519,14 @@ def test_apply_schema_with_udt(self): StructField("point", ExamplePointUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT @@ -554,14 +554,14 @@ def test_parquet_with_udt(self): df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_column_operators(self): ci = self.df.key @@ -826,8 +826,8 @@ def test_infer_long_type(self): output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) - self.assertEquals('a', df1.first().f1) - self.assertEquals(100000000000000, df1.first().f2) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) self.assertEqual(_infer_type(1), LongType()) self.assertEqual(_infer_type(2**10), LongType()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index cfea95b0dec71..e4e56fff3b3fc 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -693,7 +693,7 @@ def check_output(n): # Verify that getActiveOrCreate() returns active context self.setupCalled = False - self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) self.assertFalse(self.setupCalled) # Verify that getActiveOrCreate() uses existing SparkContext From 00a2911c5bea67a1a4796fb1d6fd5d0a8ee79001 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 12:19:08 -0700 Subject: [PATCH 0144/2089] [SPARK-10540] Fixes flaky all-data-type test This PR breaks the original test case into multiple ones (one test case for each data type). In this way, test failure output can be much more readable. Within each test case, we build a table with two columns, one of them is for the data type to test, the other is an "index" column, which is used to sort the DataFrame and workaround [SPARK-10591] [1] [1]: https://issues.apache.org/jira/browse/SPARK-10591 Author: Cheng Lian Closes #8768 from liancheng/spark-10540/test-all-data-types. --- .../sql/sources/hadoopFsRelationSuites.scala | 109 +++++++----------- 1 file changed, 43 insertions(+), 66 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8ffcef85668d6..d7504936d90e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,80 +100,57 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - ignore("test all data types") { - withTempPath { file => - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) - // TODO: add CalendarIntervalType to here once we can save it out. - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) - - // Generate data at the driver side. We need to materialize the data first and then - // create RDD. - val maybeDataGenerator = - RandomDataGenerator.forType( - dataType = schema, + private val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new MyDenseVectorUDT() + ).filter(supportsDataType) + + for (dataType <- supportedDataTypes) { + test(s"test all data types - $dataType") { + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, nullable = true, - seed = Some(System.nanoTime())) - val dataGenerator = - maybeDataGenerator - .getOrElse(fail(s"Failed to create data generator for schema $schema")) - val data = (1 to 10).map { i => - dataGenerator.apply() match { - case row: Row => row - case null => Row.fromSeq(Seq.fill(schema.length)(null)) - case other => - fail(s"Row or null is expected to be generated, " + - s"but a ${other.getClass.getCanonicalName} is generated.") + seed = Some(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") } - } - // Create a DF for the schema with random data. - val rdd = sqlContext.sparkContext.parallelize(data, 10) - val df = sqlContext.createDataFrame(rdd, schema) + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - // All columns that have supported data types of this source. - val supportedColumns = schema.fields.collect { - case StructField(name, dataType, _, _) if supportsDataType(dataType) => name - } - val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) - - val dfToBeSaved = df.selectExpr(selectedColumns: _*) - - // Save the data out. - dfToBeSaved - .write - .format(dataSourceName) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .save(file.getCanonicalPath) + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) - val loadedDF = - sqlContext + val loadedDF = sqlContext .read .format(dataSourceName) - .schema(dfToBeSaved.schema) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .load(file.getCanonicalPath) - .selectExpr(selectedColumns: _*) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") - // Read the data back. - checkAnswer( - loadedDF, - dfToBeSaved - ) + checkAnswer(loadedDF, df) + } } } From c6f8135ee52202bd86adb090ab631e80330ea4df Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 18 Sep 2015 13:20:13 -0700 Subject: [PATCH 0145/2089] [SPARK-10539] [SQL] Project should not be pushed down through Intersect or Except #8742 Intersect and Except are both set operators and they use the all the columns to compare equality between rows. When pushing their Project parent down, the relations they based on would change, therefore not an equivalent transformation. JIRA: https://issues.apache.org/jira/browse/SPARK-10539 I added some comments based on the fix of https://github.com/apache/spark/pull/8742. Author: Yijie Shen Author: Yin Huai Closes #8823 from yhuai/fix_set_optimization. --- .../sql/catalyst/optimizer/Optimizer.scala | 37 ++++++++++--------- .../optimizer/SetOperationPushDownSuite.scala | 23 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++ 3 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 648a65e7c0eb3..324f40a051c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,7 +85,22 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. + * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. */ object SetOperationPushDown extends Rule[LogicalPlan] { @@ -122,40 +137,26 @@ object SetOperationPushDown extends Rule[LogicalPlan] { Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into union + // Push down projection through UNION ALL case Project(projectList, u @ Union(left, right)) => val rewrites = buildRewrites(u) Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => val rewrites = buildRewrites(i) Intersect( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => val rewrites = buildRewrites(e) Except( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72c..3fca47a023dc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -60,23 +60,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c167999af580e..1370713975f2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -907,4 +907,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } } From 3a22b1004f527d54d399dd0225cd7f2f8ffad9c5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 13:47:14 -0700 Subject: [PATCH 0146/2089] [SPARK-10449] [SQL] Don't merge decimal types with incompatable precision or scales From JIRA: Schema merging should only handle struct fields. But currently we also reconcile decimal precision and scale information. Author: Holden Karau Closes #8634 from holdenk/SPARK-10449-dont-merge-different-precision. --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b29cf22dcb582..d6b436724b2a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -373,10 +373,19 @@ object StructType extends AbstractDataType { StructType(newFields) case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) + DecimalType.Fixed(rightPrecision, rightScale)) => + if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { + DecimalType(leftPrecision, leftScale) + } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } else if (leftPrecision != rightPrecision) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } else { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"scala $leftScale and $rightScale") + } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt From 348d7c9a93dd00d3d1859342a8eb0aea2e77f597 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Sep 2015 13:48:41 -0700 Subject: [PATCH 0147/2089] [SPARK-9808] Remove hash shuffle file consolidation. Author: Reynold Xin Closes #8812 from rxin/SPARK-9808-1. --- .../shuffle/FileShuffleBlockResolver.scala | 178 ++---------------- .../apache/spark/storage/BlockManager.scala | 9 - .../org/apache/spark/storage/DiskStore.scala | 3 - .../hash/HashShuffleManagerSuite.scala | 110 ----------- docs/configuration.md | 10 - .../shuffle/ExternalShuffleBlockResolver.java | 4 - project/MimaExcludes.scala | 4 + 7 files changed, 17 insertions(+), 301 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index c057de9b3f4df..d9902f96dfd4e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -28,10 +26,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,24 +39,7 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). @@ -71,26 +50,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. */ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,24 +72,16 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. @@ -142,58 +102,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. */ @@ -209,17 +125,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups.asScala; - file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks.asScala; - reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() } logInfo("Deleted all files for shuffle " + shuffleId) true @@ -229,10 +137,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -241,59 +145,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. */ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d31aa68eb6954..bca3942f8c555 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -106,15 +106,6 @@ private[spark] class BlockManager( } } - // Check that we're not using external shuffle service with consolidated shuffle files. - if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") - } - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f45956282166..feb9533604ffb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -154,9 +154,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { file.delete() } else { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala deleted file mode 100644 index 491dc3659e184..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.{File, FileWriter} - -import scala.language.reflectiveCalls - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} - -class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { - private val testConf = new SparkConf(false) - - private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) - val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) - assert(expected.offset === segment.getOffset) - assert(expected.length === segment.getLength) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val conf = new SparkConf(false) - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - sc = new SparkContext("local", "test", conf) - - val shuffleBlockResolver = - SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockResolver] - - val shuffle1 = shuffleBlockResolver.forMapTask(1, 1, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1", "value") - writer.write("test2", "value") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleBlockResolver.forMapTask(1, 2, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3", "value") - writer.write("test4", "vlue") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is - // concurrent read and writes happening to the same shuffle group. - - val shuffle3 = shuffleBlockResolver.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3", "value") - writer.write("test4", "value") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffleBlockResolver.removeShuffle(1) - } - - def writeToFile(file: File, numBytes: Int) { - val writer = new FileWriter(file, true) - for (i <- 0 until numBytes) writer.write(i) - writer.close() - } -} diff --git a/docs/configuration.md b/docs/configuration.md index 1a701f18881fe..3700051efb448 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -390,16 +390,6 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec.
spark.shuffle.consolidateFilesfalse - If set to "true", consolidates intermediate files created during a shuffle. Creating fewer - files can improve filesystem performance for shuffles with large numbers of reduce tasks. It - is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option - might degrade performance on machines with many (>8) cores due to filesystem limitations. -
spark.shuffle.file.buffer 32kspark.shuffle.memoryFraction 0.2 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of + Fraction of Java heap to use for aggregation and cogroups during shuffles. + At any given time, the collective size of all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will begin to spill to disk. If spills are often, consider increasing this value at the expense of spark.storage.memoryFraction. @@ -483,14 +483,6 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions.
spark.shuffle.spilltrue - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. -
spark.shuffle.spill.compress true
spark.sql.planner.externalSorttrue - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. -
# Distributed SQL Engine diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ab5aab1e115f7..73d7d9a5692a9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,7 +48,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ +from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync @@ -580,12 +580,11 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -610,12 +609,11 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: @@ -1770,13 +1768,11 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = self._can_spill() memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1784,8 +1780,7 @@ def combineLocally(iterator): shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() @@ -1824,9 +1819,6 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - def _can_spill(self): - return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" - def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) @@ -1857,14 +1849,12 @@ def mergeCombiners(a, b): a.extend(b) return a - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combine(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1872,8 +1862,7 @@ def combine(iterator): shuffled = locally_combined.partitionBy(numPartitions) def groupByKey(it): - merger = ExternalGroupBy(agg, memory, serializer)\ - if spill else InMemoryMerger(agg) + merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b8118bdb7ca76..e974cda9fc3e1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -131,36 +131,6 @@ def items(self): raise NotImplementedError -class InMemoryMerger(Merger): - - """ - In memory merger based on in-memory dict. - """ - - def __init__(self, aggregator): - Merger.__init__(self, aggregator) - self.data = {} - - def mergeValues(self, iterator): - """ Combine the items by creator and combiner """ - # speed up attributes lookup - d, creator = self.data, self.agg.createCombiner - comb = self.agg.mergeValue - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else creator(v) - - def mergeCombiners(self, iterator): - """ Merge the combined items by mergeCombiner """ - # speed up attributes lookup - d, comb = self.data, self.agg.mergeCombiners - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - - def items(self): - """ Return the merged items ad iterator """ - return iter(self.data.items()) - - def _compressed_serializer(self, serializer=None): # always use PickleSerializer to simplify implementation ser = PickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 647504c32f156..f11aaf001c8df 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,7 +62,7 @@ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -95,17 +95,6 @@ def setUp(self): lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) - def test_in_memory(self): - m = InMemoryMerger(self.agg) - m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9de75f4c4d084..b9fb90d964206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -330,11 +330,6 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. - val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", - defaultValue = Some(true), - doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + - " memory.") - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") @@ -422,6 +417,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" } } @@ -476,8 +472,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5e40d77689045..41b215c79296a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && TungstenSort.supportsSchema(child.schema)) { execution.TungstenSort(sortExprs, global, child) - } else if (sqlContext.conf.externalSortEnabled) { - execution.ExternalSort(sortExprs, global, child) } else { execution.Sort(sortExprs, global, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 95209e6634519..af28e2dfa4186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -105,6 +105,15 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 40ef7c3b53530..27f26245a5ef0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -31,38 +31,12 @@ import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} // This file defines various sort operators. //////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - /** * Performs a sort, spilling to disk as needed. * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. */ -case class ExternalSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan) @@ -93,7 +67,7 @@ case class ExternalSort( } /** - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Optimized version of [[Sort]] that operates on binary data (implemented as part of * Project Tungsten). * * @param global when true performs a global sort of all partitions by shuffling the data first diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f9981356f364f..05b4127cbcaff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -581,28 +581,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() } } @@ -1731,10 +1715,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("external sorting updates peak execution memory") { - withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 4492e37ad01ff..5dc37e5c3c238 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,7 +32,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 3073d492e613b..847c188a30333 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -36,13 +36,13 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } From 1aa9e50256988533fa54584b49dbc408a14438ee Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 20 Sep 2015 16:05:12 -0700 Subject: [PATCH 0156/2089] [SPARK-5905] [MLLIB] Note requirements for certain RowMatrix methods in docs Note methods that fail for cols > 65535; note that SVD does not require n >= m CC mengxr Author: Sean Owen Closes #8839 from srowen/SPARK-5905. --- .../spark/mllib/linalg/distributed/RowMatrix.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e55ef26858adb..7c7d900af3d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -109,7 +109,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. + * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with + * more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -150,7 +151,8 @@ class RowMatrix @Since("1.0.0") ( * - s is a Vector of size k, holding the singular values in descending order, * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * We assume n is smaller than m. The singular values and the right singular vectors are derived + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined @@ -320,7 +322,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. + * Computes the covariance matrix, treating each row as an observation. Note that this cannot + * be computed on matrices with more than 65535 columns. * @return a local dense matrix of size n x n */ @Since("1.0.0") @@ -374,6 +377,8 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * + * Note that this cannot be computed on matrices with more than 65535 columns. + * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components */ From 0c498717ba9622b6c889e701e8eed5ef9215c030 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 20 Sep 2015 16:16:31 -0700 Subject: [PATCH 0157/2089] [SPARK-10715] [ML] Duplicate initialization flag in WeightedLeastSquare There are duplicate set of initialization flag in `WeightedLeastSquares#add`. `initialized` is already set in `init(Int)`. Author: lewuathe Closes #8837 from Lewuathe/duplicate-initialization-flag. --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 0ff8931b0bab4..4374e99631560 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -193,7 +193,6 @@ private[ml] object WeightedLeastSquares { val ak = a.size if (!initialized) { init(ak) - initialized = true } assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") count += 1L From 01440395176bdbb2662480f03b27851cb860f385 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Sun, 20 Sep 2015 22:55:24 -0700 Subject: [PATCH 0158/2089] [SPARK-10631] [DOCUMENTATION, MLLIB, PYSPARK] Added documentation for few APIs There are some missing API docs in pyspark.mllib.linalg.Vector (including DenseVector and SparseVector). We should add them based on their Scala counterparts. Author: vinodkc Closes #8834 from vinodkc/fix_SPARK-10631. --- python/pyspark/mllib/linalg/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 4829acb16ed8a..f929e3e96fbe2 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -301,11 +301,14 @@ def __reduce__(self): return DenseVector, (self.array.tostring(),) def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros + """ return np.count_nonzero(self.array) def norm(self, p): """ - Calculte the norm of a DenseVector. + Calculates the norm of a DenseVector. >>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) @@ -397,10 +400,16 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): + """ + Returns an numpy.ndarray + """ return self.array @property def values(self): + """ + Returns a list of values + """ return self.array def __getitem__(self, item): @@ -479,8 +488,8 @@ def __init__(self, size, *args): :param size: Size of the vector. :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly i - ncreasing indices and a list of corresponding values [index, ...], + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) @@ -521,11 +530,14 @@ def __init__(self, size, *args): raise TypeError("indices array must be sorted") def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros. + """ return np.count_nonzero(self.values) def norm(self, p): """ - Calculte the norm of a SparseVector. + Calculates the norm of a SparseVector. >>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) @@ -797,7 +809,7 @@ def sparse(size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, + :param args: Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values. >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) From 20a61dbd9b57957fcc5b58ef8935533914172b07 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 18:53:28 +0100 Subject: [PATCH 0159/2089] [SPARK-10626] [MLLIB] create java friendly method for random rdd SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people that want to use custom RandomDataGenerators we should make a Java friendly method. Author: Holden Karau Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD. --- .../spark/mllib/random/RandomRDDs.scala | 52 ++++++++++++++++++- .../mllib/random/JavaRandomRDDsSuite.java | 30 +++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 4dd5ea214d678..f8ff26b5795be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD @@ -381,7 +382,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi @Since("1.1.0") @@ -394,6 +395,55 @@ object RandomRDDs { new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * :: DeveloperApi :: + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. + * + * @param jsc JavaSparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int, + seed: Long): JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed) + JavaRDD.fromRDD(rdd) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong()) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, 0); + } + // TODO Generate RDD[Vector] from multivariate distributions. /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 33d81b1e9592b..fce5f6712f462 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.random; +import java.io.Serializable; import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; @@ -231,4 +232,33 @@ public void testGammaVectorRDD() { } } + @Test + public void testArbitrary() { + long size = 10; + long seed = 1L; + int numPartitions = 0; + StringGenerator gen = new StringGenerator(); + JavaRDD rdd1 = randomJavaRDD(sc, gen, size); + JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(size, rdd.count()); + Assert.assertEquals(2, rdd.first().length()); + } + } +} + +// This is just a test generator, it always returns a string of 42 +class StringGenerator implements RandomDataGenerator, Serializable { + @Override + public String nextValue() { + return "42"; + } + @Override + public StringGenerator copy() { + return new StringGenerator(); + } + @Override + public void setSeed(long seed) { + } } From ebbf85f07bb8de0d566f1ae4b41f26421180bebe Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 21 Sep 2015 11:39:04 -0700 Subject: [PATCH 0160/2089] [SPARK-7989] [SPARK-10651] [CORE] [TESTS] Increase timeout to fix flaky tests I noticed only one block manager registered with master in an unsuccessful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/3534/) ``` 15/09/16 13:02:30.981 pool-1-thread-1-ScalaTest-running-BroadcastSuite INFO SparkContext: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 13:02:38.133 sparkDriver-akka.actor.default-dispatcher-19 INFO BlockManagerMasterEndpoint: Registering block manager localhost:48196 with 530.3 MB RAM, BlockManagerId(0, localhost, 48196) ``` In addition, the first block manager needed 7+ seconds to start. But the test expected 2 block managers so it failed. However, there was no exception in this log file. So I checked a successful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3536/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/) and it needed 4-5 seconds to set up the local cluster: ``` 15/09/16 18:11:27.738 sparkWorker1-akka.actor.default-dispatcher-5 INFO Worker: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 18:11:30.838 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:54202 with 530.3 MB RAM, BlockManagerId(1, localhost, 54202) 15/09/16 18:11:32.112 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:32955 with 530.3 MB RAM, BlockManagerId(0, localhost, 32955) ``` In this build, the first block manager needed only 3+ seconds to start. Comparing these two builds, I guess it's possible that the local cluster in `BroadcastSuite` cannot be ready in 10 seconds if the Jenkins worker is busy. So I just increased the timeout to 60 seconds to see if this can fix the issue. Author: zsxwing Closes #8813 from zsxwing/fix-BroadcastSuite. --- .../scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 2 +- .../test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 2 +- .../apache/spark/scheduler/SparkListenerWithClusterSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index e846a72c888c6..231f4631e0a47 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -61,7 +61,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index fb7a8ae3f9d41..ba21075ce6be5 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -311,7 +311,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d1e23ed527ff1..9fa8859382911 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) From ca9fe540fe04e2e230d1e76526b5502bab152914 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 21 Sep 2015 19:46:39 +0100 Subject: [PATCH 0161/2089] [SPARK-10662] [DOCS] Code snippets are not properly formatted in tables * Backticks are processed properly in Spark Properties table * Removed unnecessary spaces * See http://people.apache.org/~pwendell/spark-nightly/spark-master-docs/latest/running-on-yarn.html Author: Jacek Laskowski Closes #8795 from jaceklaskowski/docs-yarn-formatting. --- docs/configuration.md | 97 +++++++++++++++-------------- docs/programming-guide.md | 100 +++++++++++++++--------------- docs/running-on-mesos.md | 14 ++--- docs/running-on-yarn.md | 106 ++++++++++++++++---------------- docs/sql-programming-guide.md | 16 ++--- docs/submitting-applications.md | 8 +-- 6 files changed, 171 insertions(+), 170 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5ec097c78aa38..b22587c70316b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,20 +34,20 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may actually require one to prevent any sort of starvation issues. -Properties that specify some time duration should be configured with a unit of time. +Properties that specify some time duration should be configured with a unit of time. The following format is accepted: - + 25ms (milliseconds) 5s (seconds) 10m or 10min (minutes) 3h (hours) 5d (days) 1y (years) - - + + Properties that specify a byte size should be configured with a unit of size. The following format is accepted: @@ -140,7 +140,7 @@ of the most common options to set are: Amount of memory to use for the driver process, i.e. where SparkContext is initialized. (e.g. 1g, 2g). - +
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option @@ -207,7 +207,7 @@ Apart from these, the following properties are also available, and may be useful
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-class-path command line option or in + Instead, please set this through the --driver-class-path command line option or in your default properties file. @@ -216,10 +216,10 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. - +
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-java-options command line option or in + Instead, please set this through the --driver-java-options command line option or in your default properties file. @@ -228,10 +228,10 @@ Apart from these, the following properties are also available, and may be useful (none) Set a special library path to use when launching the driver JVM. - +
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-library-path command line option or in + Instead, please set this through the --driver-library-path command line option or in your default properties file. @@ -242,7 +242,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in the the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. - + This is used in cluster mode only. @@ -250,8 +250,8 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to prepend to the classpath of executors. This exists primarily for - backwards-compatibility with older versions of Spark. Users typically should not need to set + Extra classpath entries to prepend to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -259,9 +259,9 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Heap size settings can be set with spark.executor.memory. @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by - passing a profiler class in as a parameter to the `SparkContext` constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. @@ -460,11 +460,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.enabled false - Enables the external shuffle service. This service preserves the shuffle files written by - executors so the executors can be safely removed. This must be enabled if + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if spark.dynamicAllocation.enabled is "true". The external shuffle service must be set up in order to enable it. See - dynamic allocation + dynamic allocation configuration and setup documentation for more information. @@ -747,9 +747,9 @@ Apart from these, the following properties are also available, and may be useful 1 in YARN mode, all the available cores on the worker in standalone mode. The number of cores to use on each executor. For YARN and standalone mode only. - - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one executor per application will run on each worker. @@ -893,14 +893,14 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval 1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. @@ -909,9 +909,9 @@ Apart from these, the following properties are also available, and may be useful 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune - this along with `spark.akka.heartbeat.interval` if you need to. + this along with spark.akka.heartbeat.interval if you need to. @@ -978,7 +978,7 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout 120s - Default timeout for all network interactions. This config will be used in place of + Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or @@ -991,8 +991,8 @@ Apart from these, the following properties are also available, and may be useful Maximum number of retries when binding to a port before giving up. When a port is given a specific value (non 0), each subsequent retry will - increment the port used in the previous attempt by 1 before retrying. This - essentially allows it to try a range of ports from the start port specified + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified to port + maxRetries. @@ -1191,7 +1191,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout 60s - If dynamic allocation is enabled and an executor has been idle for more than this duration, + If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this description. @@ -1424,11 +1424,11 @@ Apart from these, the following properties are also available, and may be useful false Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). - This enables the Spark Streaming to control the receiving rate based on the + This enables the Spark Streaming to control the receiving rate based on the current batch scheduling delays and processing times so that the system receives - only as fast as the system can process. Internally, this dynamically sets the + only as fast as the system can process. Internally, this dynamically sets the maximum receiving rate of receivers. This rate is upper bounded by the values - `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + spark.streaming.receiver.maxRate and spark.streaming.kafka.maxRatePerPartition if they are set (see below). @@ -1542,15 +1542,15 @@ The following variables can be set in `spark-env.sh`: Environment VariableMeaning JAVA_HOME - Location where Java is installed (if it's not on your default `PATH`). + Location where Java is installed (if it's not on your default PATH). PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is `python`). + Python binary executable to use for PySpark in both driver and workers (default is python). PYSPARK_DRIVER_PYTHON - Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). SPARK_LOCAL_IP @@ -1580,4 +1580,3 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config To specify a different configuration directory other than the default "SPARK_HOME/conf", you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. - diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4cf83bb392636..8ad238315f12c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -182,8 +182,8 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. You can also add dependencies -(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -217,7 +217,7 @@ context connects to using the `--master` argument, and you can add Python .zip, to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: @@ -249,8 +249,8 @@ the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support $ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} -After the IPython Notebook server is launched, you can create a new "Python 2" notebook from -the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the IPython notebook.
@@ -418,9 +418,9 @@ Apart from text files, Spark's Python API also supports several other data forma **Writable Support** -PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the -resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, -PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following Writables are automatically converted: @@ -435,9 +435,9 @@ Writables are automatically converted:
MapWritabledict
-Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, -users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default -converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get Python `array.array` for arrays of primitive types, users need to specify custom converters. **Saving and Loading SequenceFiles** @@ -454,7 +454,7 @@ classes can be specified, but for standard Writables this is not required. **Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: @@ -474,15 +474,15 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. -A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided -for this. Simply extend this trait and implement your transformation code in the ```convert``` -method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark +A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided +for this. Simply extend this trait and implement your transformation code in the ```convert``` +method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark classpath. -See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
@@ -758,7 +758,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
@@ -777,7 +777,7 @@ println("Counter value: " + counter)
{% highlight java %} int counter = 0; -JavaRDD rdd = sc.parallelize(data); +JavaRDD rdd = sc.parallelize(data); // Wrong: Don't do this!! rdd.foreach(x -> counter += x); @@ -803,7 +803,7 @@ print("Counter value: " + counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. @@ -813,9 +813,9 @@ To ensure well-defined behavior in these sorts of scenarios one should use an [` In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. -#### Printing elements of an RDD +#### Printing elements of an RDD Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. - + ### Working with Key-Value Pairs
@@ -859,7 +859,7 @@ only available on RDDs of key-value pairs. The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. -In Java, key-value pairs are represented using the +In Java, key-value pairs are represented using the [scala.Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) class from the Scala standard library. You can simply call `new Tuple2(a, b)` to create a tuple, and access its fields later with `tuple._1()` and `tuple._2()`. @@ -974,7 +974,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or aggregateByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -1025,7 +1025,7 @@ for details. repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, - sort records by their keys. This is more efficient than calling repartition and then sorting within + sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1038,7 +1038,7 @@ RDD API doc [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), [Python](api/python/pyspark.html#pyspark.RDD), [R](api/R/index.html)) - + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1094,7 +1094,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1118,13 +1118,13 @@ co-located to compute the result. In Spark, data is generally not distributed across partitions to be in the necessary place for a specific operation. During computations, a single task will operate on a single partition - thus, to organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an -all-to-all operation. It must read from all partitions to find all the values for all keys, -and then bring together values across partitions to compute the final result for each key - +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - this is called the **shuffle**. Although the set of elements in each partition of newly shuffled data will be deterministic, and so -is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably -ordered data following shuffle then it's possible to use: +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: * `mapPartitions` to sort each partition using, for example, `.sorted` * `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning @@ -1141,26 +1141,26 @@ network I/O. To organize data for the shuffle, Spark generates sets of tasks - * organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from MapReduce and does not directly relate to Spark's `map` and `reduce` operations. -Internally, results from individual map tasks are kept in memory until they can't fit. Then, these -are sorted based on the target partition and written to a single file. On the reduce side, tasks +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks read the relevant sorted blocks. - -Certain shuffle operations can consume significant amounts of heap memory since they employ -in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations -generate these on the reduce side. When data does not fit in memory Spark will spill these tables + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are preserved until the corresponding RDDs are no longer used and are garbage collected. -This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references -to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the -'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). ## RDD Persistence @@ -1246,7 +1246,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1345,7 +1345,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks @@ -1474,8 +1474,8 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
-For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator -will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: @@ -1486,7 +1486,7 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being {% highlight scala %} val accum = sc.accumulator(0) data.map { x => accum += x; f(x) } -// Here, accum is still 0 because no actions have caused the `map` to be computed. +// Here, accum is still 0 because no actions have caused the map to be computed. {% endhighlight %}
@@ -1553,7 +1553,7 @@ Several changes were made to the Java API: code that `extends Function` should `implement Function` instead. * New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning +* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning `(Key, List)` pairs to `(Key, Iterable)`.
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 330c159c67bca..460a66f37dd64 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -245,7 +245,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - If set to "true", runs over Mesos clusters in + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use @@ -254,16 +254,16 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores - 0 + 0 Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. spark.mesos.mesosExecutor.cores - 1.0 + 1.0 (Fine-grained mode only) Number of cores to give each Mesos executor. This does not include the cores used to run the Spark tasks. In other words, even if no Spark task @@ -287,7 +287,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the list of volumes which will be mounted into the Docker image, which was set using spark.mesos.executor.docker.image. The format of this property is a comma-separated list of - mappings following the form passed to docker run -v. That is they take the form: + mappings following the form passed to docker run -v. That is they take the form:
[host_path:]container_path[:ro|:rw]
@@ -318,7 +318,7 @@ See the [configuration page](configuration.html) for information on Spark config executor memory * 0.10, with minimum of 384 The amount of additional memory, specified in MB, to be allocated per executor. By default, - the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value. @@ -339,7 +339,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.secret - (none)/td> + (none) Set the secret with which Spark framework will use to authenticate with Mesos. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3a961d245f3de..0e25ccf512c02 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -23,7 +23,7 @@ Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.ht To launch a Spark application in `yarn-cluster` mode: $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - + For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ @@ -43,7 +43,7 @@ To launch a Spark application in `yarn-client` mode, do the same, but replace `y ## Adding Other JARs -In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ --master yarn-cluster \ @@ -64,16 +64,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Debugging your Application -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the `yarn logs` command. yarn logs -applicationId - + will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +large value (e.g. `36000`), and then access the application cache through `yarn.nodemanager.local-dirs` on the nodes on which containers are launched. This directory contains the launch script, JARs, and all environment variables used for launching each container. This process is useful for debugging classpath problems in particular. (Note that enabling this requires admin privileges on cluster @@ -92,7 +92,7 @@ Note that for the first option, both executors and the application master will s log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. #### Spark Properties @@ -100,24 +100,26 @@ If you need a reference to the proper location to put log files in the YARN so t Property NameDefaultMeaning spark.yarn.am.memory - 512m + 512m Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g). In cluster mode, use spark.driver.memory instead. +

+ Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. spark.driver.cores - 1 + 1 Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. spark.yarn.am.cores - 1 + 1 Number of cores to use for the YARN Application Master in client mode. In cluster mode, use spark.driver.cores instead. @@ -125,39 +127,39 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.am.waitTime - 100s + 100s - In `yarn-cluster` mode, time for the application master to wait for the - SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait + In yarn-cluster mode, time for the YARN Application Master to wait for the + SparkContext to be initialized. In yarn-client mode, time for the YARN Application Master to wait for the driver to connect to it. spark.yarn.submit.file.replication - The default HDFS replication (usually 3) + The default HDFS replication (usually 3) HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. spark.yarn.preserve.staging.files - false + false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. spark.yarn.scheduler.heartbeat.interval-ms - 3000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. - The value is capped at half the value of YARN's configuration for the expiry interval - (yarn.am.liveness-monitor.expiry-interval-ms). + The value is capped at half the value of YARN's configuration for the expiry interval, i.e. + yarn.am.liveness-monitor.expiry-interval-ms. spark.yarn.scheduler.initial-allocation.interval - 200ms + 200ms The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than @@ -177,8 +179,8 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. - For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. + The address of the Spark history server, e.g. host.com:18080. The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For example, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to ${hadoopconf-yarn.resourcemanager.hostname}:18080. @@ -197,42 +199,42 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances - 2 + 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. spark.yarn.am.port (random) - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. spark.yarn.queue - default + default The name of the YARN queue to which the application is submitted. @@ -245,18 +247,18 @@ If you need a reference to the proper location to put log files in the YARN so t By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to "hdfs:///some/path". + set this configuration to hdfs:///some/path. spark.yarn.access.namenodes (none) - A list of secure HDFS namenodes your Spark application is going to access. For - example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. - The Spark application must have acess to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + A comma-separated list of secure HDFS namenodes your Spark application is going to access. For + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. + The Spark application must have access to the namenodes listed and Kerberos must + be properly configured to be able to access them (either in the same realm or in + a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -264,18 +266,18 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) - Add the environment variable specified by EnvironmentVariableName to the - Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In `yarn-cluster` mode this controls - the environment of the SPARK driver and in `yarn-client` mode it only controls - the environment of the executor launcher. + Add the environment variable specified by EnvironmentVariableName to the + Application Master process launched on YARN. The user can specify multiple of + these and to set multiple environment variables. In yarn-cluster mode this controls + the environment of the Spark driver and in yarn-client mode it only controls + the environment of the executor launcher. spark.yarn.containerLauncherMaxThreads - 25 + 25 - The maximum number of threads to use in the application master for launching executor containers. + The maximum number of threads to use in the YARN Application Master for launching executor containers. @@ -283,19 +285,19 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use `spark.driver.extraJavaOptions` instead. + In cluster mode, use spark.driver.extraJavaOptions instead. spark.yarn.am.extraLibraryPath (none) - Set a special library path to use when launching the application master in client mode. + Set a special library path to use when launching the YARN Application Master in client mode. spark.yarn.maxAppAttempts - yarn.resourcemanager.am.max-attempts in YARN + yarn.resourcemanager.am.max-attempts in YARN The maximum number of attempts that will be made to submit the application. It should be no larger than the global number of max attempts in the YARN configuration. @@ -303,10 +305,10 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.submit.waitAppCompletion - true + true In YARN cluster mode, controls whether the client waits to exit until the application completes. - If set to true, the client process will stay alive reporting the application's status. + If set to true, the client process will stay alive reporting the application's status. Otherwise, the client process will exit after submission. @@ -332,7 +334,7 @@ If you need a reference to the proper location to put log files in the YARN so t (none) The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, for renewing the login tickets and the delegation tokens periodically. @@ -371,14 +373,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.security.tokens.${service}.enabled - true + true Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. By default, delegation tokens for all supported services are retrieved when those services are configured, but it's possible to disable that behavior if it somehow conflicts with the application being run.

- Currently supported services are: hive, hbase + Currently supported services are: hive, hbase @@ -387,5 +389,5 @@ If you need a reference to the proper location to put log files in the YARN so t - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9244c271e3..a1cbc7de97c65 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). @@ -1706,8 +1706,8 @@ The following options can be used to configure the version of Hive that is used either 1.2.1 or not defined.

  • maven
  • Use Hive jars of specified version downloaded from Maven repositories. This configuration - is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + is not generally recommended for production deployments. +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
  • @@ -1806,7 +1806,7 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( +val jdbcDF = sqlContext.read.format("jdbc").options( Map("url" -> "jdbc:postgresql:dbserver", "dbtable" -> "schema.tablename")).load() {% endhighlight %} @@ -2023,11 +2023,11 @@ options. - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with code generation for expression evaluation. These features can both be disabled by setting - `spark.sql.tungsten.enabled` to `false. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.tungsten.enabled` to `false`. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 7ea4d6f1a3f8f..915be0f479157 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -103,7 +103,7 @@ run it with `--help`. Here are a few examples of common options: export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ # can also be `yarn-client` for client mode + --master yarn-cluster \ # can also be yarn-client for client mode --executor-memory 20G \ --num-executors 50 \ /path/to/examples.jar \ @@ -174,9 +174,9 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. -Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates -with `--packages`. All transitive dependencies will be handled when using this command. Additional -repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries From 331f0b10f78a37d96d3e573d211d74a0935265db Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Mon, 21 Sep 2015 12:09:00 -0700 Subject: [PATCH 0162/2089] [SPARK-9642] [ML] LinearRegression should supported weighted data In many modeling application, data points are not necessarily sampled with equal probabilities. Linear regression should support weighting which account the over or under sampling. work in progress. Author: Meihua Wu Closes #8631 from rotationsymmetry/SPARK-9642. --- .../ml/regression/LinearRegression.scala | 164 +++++++++++------- .../ml/regression/LinearRegressionSuite.scala | 88 ++++++++++ project/MimaExcludes.scala | 8 +- 3 files changed, 191 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e4602d36ccc87..78a67c5fdab20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,21 +31,29 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.functions.{col, udf, lit} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization + with HasFitIntercept with HasStandardization with HasWeightCol + +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[regression] case class Instance(label: Double, weight: Double, features: Vector) /** * :: Experimental :: @@ -123,30 +131,43 @@ class LinearRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist instances. - val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, statCounter) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new StatCounter))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), - (label: Double, features: Vector)) => - (summarizer.add(features), statCounter.merge(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), - (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => - (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) - }) - - val numFeatures = summarizer.mean.size - val yMean = statCounter.mean - val yStd = math.sqrt(statCounter.variance) + val (featuresSummarizer, ySummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val numFeatures = featuresSummarizer.mean.size + val yMean = ySummarizer.mean(0) + val yStd = math.sqrt(ySummarizer.variance(0)) // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -167,8 +188,8 @@ class LinearRegression(override val uid: String) return copyValues(model.setSummary(trainingSummary)) } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. @@ -318,7 +339,8 @@ class LinearRegressionModel private[ml] ( /** * :: Experimental :: - * Linear regression training results. + * Linear regression training results. Currently, the training summary ignores the + * training weights except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -477,7 +499,7 @@ class LinearRegressionSummary private[regression] ( * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param labelStd The standard deviation value of the label. * @param labelMean The mean value of the label. * @param fitIntercept Whether to fit an intercept term. @@ -485,7 +507,7 @@ class LinearRegressionSummary private[regression] ( * @param featuresMean The mean values of the features. */ private class LeastSquaresAggregator( - weights: Vector, + coefficients: Vector, labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -493,26 +515,28 @@ private class LeastSquaresAggregator( featuresMean: Array[Double]) extends Serializable { private var totalCnt: Long = 0L + private var weightSum: Double = 0.0 private var lossSum = 0.0 - private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { - val weightsArray = weights.toArray.clone() + private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = { + val coefficientsArray = coefficients.toArray.clone() var sum = 0.0 var i = 0 - val len = weightsArray.length + val len = coefficientsArray.length while (i < len) { if (featuresStd(i) != 0.0) { - weightsArray(i) /= featuresStd(i) - sum += weightsArray(i) * featuresMean(i) + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) } else { - weightsArray(i) = 0.0 + coefficientsArray(i) = 0.0 } i += 1 } - (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (coefficientsArray, offset, coefficientsArray.length) } - private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray) private val gradientSumArray = Array.ofDim[Double](dim) @@ -520,30 +544,33 @@ private class LeastSquaresAggregator( * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The data point instance to be added. * @return This LeastSquaresAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") + def add(instance: Instance): this.type = + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + if (weight == 0.0) return this - if (diff != 0) { - val localGradientSumArray = gradientSumArray - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += diff * value / featuresStd(index) + val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / featuresStd(index) + } } + lossSum += weight * diff * diff / 2.0 } - lossSum += diff * diff / 2.0 - } - totalCnt += 1 - this - } + totalCnt += 1 + weightSum += weight + this + } /** * Merge another LeastSquaresAggregator, and update the loss and gradient @@ -557,8 +584,9 @@ private class LeastSquaresAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { + if (other.weightSum != 0) { totalCnt += other.totalCnt + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -574,11 +602,17 @@ private class LeastSquaresAggregator( def count: Long = totalCnt - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -589,7 +623,7 @@ private class LeastSquaresAggregator( * It's used in Breeze's convex optimization routines. */ private class LeastSquaresCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -598,17 +632,13 @@ private class LeastSquaresCostFun( featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { - val w = Vectors.fromBreeze(weights) + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coeff = Vectors.fromBreeze(coefficients) - val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd, labelMean, fitIntercept, featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + seqOp = (aggregator, instance) => aggregator.add(instance), + combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) val totalGradientArray = leastSquaresAggregator.gradient.toArray @@ -616,7 +646,7 @@ private class LeastSquaresCostFun( 0.0 } else { var sum = 0.0 - w.foreachActive { (index, value) => + coeff.foreachActive { (index, value) => // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. sum += { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc734..8428f4f00b370 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .zip(testSummary.residuals.select("residuals").collect()) .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } } + + test("linear regression with weighted samples"){ + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 814a11e588ceb..b2e6be706637b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,10 +70,14 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ - Seq( + ) ++ Seq( ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this") ) case v if v.startsWith("1.5") => Seq( From b78c65b03ae87a3ba348c9d29ff4c296349eb49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?hushan=5B=E8=83=A1=E7=8F=8A=5D?= Date: Mon, 21 Sep 2015 14:26:15 -0500 Subject: [PATCH 0163/2089] [SPARK-5259] [CORE] don't submit stage until its dependencies map outputs are registered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track pending tasks by partition ID instead of Task objects. Before this change, failure & retry could result in a case where a stage got submitted before the map output from its dependencies get registered. This was due to an error in the condition for registering map outputs. Author: hushan[胡çŠ] Author: Imran Rashid Closes #7699 from squito/SPARK-5259. --- .../apache/spark/scheduler/DAGScheduler.scala | 12 +- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 197 ++++++++++++++++-- 4 files changed, 191 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c9a66e504403..394228b2728d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -944,7 +944,7 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { @@ -1060,8 +1060,8 @@ class DAGScheduler( if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -1152,7 +1152,7 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1198,7 +1198,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1242,7 +1242,7 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b37eccbd0f7b8..a3829c319c48d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -66,7 +66,7 @@ private[scheduler] abstract class Stage( /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 62af9031b9f8b..c02597c4365c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,8 +487,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1c55f90ad9b44..6b5bcf0574de6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -479,8 +479,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -490,7 +490,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -782,8 +782,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -1035,6 +1035,173 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * This test runs a three stage job, with a fetch failure in stage 1. but during the retry, we + * have completions from both the first & second attempt of stage 1. So all the map output is + * available before we finish any task set for stage 1. We want to make sure that we don't + * submit stage 2 until the map output for stage 1 is registered + */ + test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + // things start out smoothly, stage 0 completes with no issues + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length)) + )) + + // then one executor dies, and a task fails in stage 1 + runEvent(ExecutorLost("exec-hostA")) + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + null, + null, + createFakeTaskInfo(), + null)) + + // so we resubmit stage 0, which completes happily + scheduler.resubmitFailedStages() + val stage0Resubmit = taskSets(2) + assert(stage0Resubmit.stageId == 0) + assert(stage0Resubmit.stageAttemptId === 1) + val task = stage0Resubmit.tasks(0) + assert(task.partitionId === 2) + runEvent(CompletionEvent( + task, + Success, + makeMapStatus("hostC", shuffleMapRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // now here is where things get tricky : we will now have a task set representing + // the second attempt for stage 1, but we *also* have some tasks for the first attempt for + // stage 1 still going + val stage1Resubmit = taskSets(3) + assert(stage1Resubmit.stageId == 1) + assert(stage1Resubmit.stageAttemptId === 1) + assert(stage1Resubmit.tasks.length === 3) + + // we'll have some tasks finish from the first attempt, and some finish from the second attempt, + // so that we actually have all stage outputs, though no attempt has completed all its + // tasks + runEvent(CompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + runEvent(CompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + // late task finish from the first attempt + runEvent(CompletionEvent( + taskSets(1).tasks(2), + Success, + makeMapStatus("hostB", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // What should happen now is that we submit stage 2. However, we might not see an error + // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But + // we can check some conditions. + // Note that the really important thing here is not so much that we submit stage 2 *immediately* + // but that we don't end up with some error from these interleaved completions. It would also + // be OK (though sub-optimal) if stage 2 simply waited until the resubmission of stage 1 had + // all its tasks complete + + // check that we have all the map output for stage 0 (it should have been there even before + // the last round of completions from stage 1, but just to double check it hasn't been messed + // up) and also the newly available stage 1 + val stageToReduceIdxs = Seq( + 0 -> (0 until 3), + 1 -> (0 until 1) + ) + for { + (stage, reduceIdxs) <- stageToReduceIdxs + reduceIdx <- reduceIdxs + } { + // this would throw an exception if the map status hadn't been registered + val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 2 has been submitted + assert(taskSets.size == 5) + val stage2TaskSet = taskSets(4) + assert(stage2TaskSet.stageId == 2) + assert(stage2TaskSet.stageAttemptId == 0) + } + + /** + * We lose an executor after completing some shuffle map tasks on it. Those tasks get + * resubmitted, and when they finish the job completes normally + */ + test("register map outputs correctly after ExecutorLost and task Resubmitted") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) + submit(reduceRdd, Array(0)) + + // complete some of the tasks from the first stage, on one host + runEvent(CompletionEvent( + taskSets(0).tasks(0), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + + // now that host goes down + runEvent(ExecutorLost("exec-hostA")) + + // so we resubmit those tasks + runEvent(CompletionEvent( + taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + + // now complete everything on a different host + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)) + )) + + // now we should submit stage 1, and the map output from stage 0 should be registered + + // check that we have all the map output for stage 0 + (0 until reduceRdd.partitions.length).foreach { reduceIdx => + val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 1 has been submitted + assert(taskSets.size == 2) + val stage1TaskSet = taskSets(1) + assert(stage1TaskSet.stageId == 1) + assert(stage1TaskSet.stageAttemptId == 0) + } + /** * Makes sure that failures of stage used by multiple jobs are correctly handled. * @@ -1393,8 +1560,8 @@ class DAGSchedulerSuite // Submit a map stage by itself submitMapStage(shuffleDep) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) assert(results.size === 1) results.clear() assertDataStructuresEmpty() @@ -1407,7 +1574,7 @@ class DAGSchedulerSuite // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) results.clear() @@ -1452,8 +1619,8 @@ class DAGSchedulerSuite // Complete the first stage assert(taskSets(0).stageId === 0) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", rdd1.partitions.size)), - (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -1461,7 +1628,7 @@ class DAGSchedulerSuite // When attempting the second stage, show a fetch failure assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (Success, makeMapStatus("hostA", rdd2.partitions.length)), (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -1469,7 +1636,7 @@ class DAGSchedulerSuite // Stage 0 should now be running as task set 2; make its task succeed assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( - (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -1477,8 +1644,8 @@ class DAGSchedulerSuite // Stage 1 should now be running as task set 3; make its first task succeed assert(taskSets(3).stageId === 1) complete(taskSets(3), Seq( - (Success, makeMapStatus("hostB", rdd2.partitions.size)), - (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + (Success, makeMapStatus("hostB", rdd2.partitions.length)), + (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) assert(listener2.results.size === 1) @@ -1494,7 +1661,7 @@ class DAGSchedulerSuite // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 assert(taskSets(5).stageId === 1) complete(taskSets(5), Seq( - (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + (Success, makeMapStatus("hostE", rdd2.partitions.length)))) complete(taskSets(6), Seq( (Success, 53))) assert(listener3.results === Map(0 -> 52, 1 -> 53)) From ba882db6f43dd2bc05675133158e4664ed07030a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:06:23 -0700 Subject: [PATCH 0164/2089] [SPARK-9769] [ML] [PY] add python api for countvectorizermodel From JIRA: Add Python API, user guide and example for ml.feature.CountVectorizerModel Author: Holden Karau Closes #8561 from holdenk/SPARK-9769-add-python-api-for-countvectorizermodel. --- python/pyspark/ml/feature.py | 148 +++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92db8df80280b..f41d72f877256 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,12 +26,13 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', - 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', - 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] +__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', + 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -171,6 +172,141 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. + >>> df = sqlContext.createDataFrame( + ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ... ["label", "raw"]) + >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> model = cv.fit(df) + >>> model.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... + >>> sorted(map(str, model.vocabulary)) + ['a', 'b', 'c'] + """ + + # a placeholder to make it appear in the generated doc + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0") + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0") + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + + @keyword_only + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + """ + super(CountVectorizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", + self.uid) + self.minTF = Param( + self, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then " + + "this specifies a fraction (out of the document's token count). Note that the " + + "parameter is only used in transform of CountVectorizerModel and does not affect" + + "fitting. Default 1.0") + self.minDF = Param( + self, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of " + + "documents. Default 1.0") + self.vocabSize = Param( + self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + Set the params for the CountVectorizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + self._paramMap[self.minTF] = value + return self + + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + def setMinDF(self, value): + """ + Sets the value of :py:attr:`minDF`. + """ + self._paramMap[self.minDF] = value + return self + + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + def setVocabSize(self, value): + """ + Sets the value of :py:attr:`vocabSize`. + """ + self._paramMap[self.vocabSize] = value + return self + + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + def _create_model(self, java_model): + return CountVectorizerModel(java_model) + + +class CountVectorizerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by CountVectorizer. + """ + + @property + def vocabulary(self): + """ + An array of terms in the vocabulary. + """ + return self._call_java("vocabulary") + + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ From aeef44a3e32b53f7adecc8e9cfd684fb4598e87d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 21 Sep 2015 13:11:28 -0700 Subject: [PATCH 0165/2089] [SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing Implementation of significance testing using Streaming API. Author: Feynman Liang Author: Feynman Liang Closes #4716 from feynmanliang/ab_testing. --- .../examples/mllib/StreamingTestExample.scala | 90 +++++++ .../spark/mllib/stat/test/StreamingTest.scala | 145 +++++++++++ .../mllib/stat/test/StreamingTestMethod.scala | 167 ++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 22 ++ .../spark/mllib/stat/StreamingTestSuite.scala | 243 ++++++++++++++++++ 5 files changed, 667 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala new file mode 100644 index 0000000000000..ab29f90254d34 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * StreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +object StreamingTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: StreamingTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => (label.toBoolean, value.toDouble) + }) + + val streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = streamingTest.registerStream(data) + out.print() + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala new file mode 100644 index 0000000000000..75c6a51d09571 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. + * + * To address novelty affects, the `peacePeriod` specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The `windowSize` sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. + * + * Use a builder pattern to construct a streaming test in an application, for example: + * {{{ + * val model = new StreamingTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * }}} + */ +@Experimental +@Since("1.6.0") +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest + + /** Set the number of initial batches to ignore. Default: 0. */ + @Since("1.6.0") + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. Default: 0. + * A value of 0 will use all batches seen so far. + */ + @Since("1.6.0") + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. Default: "welch" */ + @Since("1.6.0") + def setTestMethod(method: String): this.type = { + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of (key,value) pairs where the key denotes group membership (true = + * experiment, false = control) and the value is the numerical metric to test for + * significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + + testMethod.doTest(pairedSummaries) + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + data.context.sparkContext.parallelize(Seq()) + } + } + } + + /** Compute summary statistics over each key and the specified test window size. */ + private[stat] def summarizeByKeyAndWindow( + data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + if (this.windowSize == 0) { + data.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + data + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala new file mode 100644 index 0000000000000..a7eaed51b4d55 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] + */ +private[stat] sealed trait StreamingTestMethod extends Serializable { + + val methodName: String + val nullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform streaming 2-sample statistical significance testing. + * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] + + /** + * Implicit adapter to convert between streaming summary statistics type and the type required by + * the t-testing libraries. + */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new StreamingTestResult( + tTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * @see http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new StreamingTestResult( + tTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. + * + * Currently supported tests: `welch`, `student`. + */ +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. + private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch"->WelchTTest, + "student"->StudentTTest) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index d01b3707be944..b0916d3e84651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] ( "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for streaming testing. + */ +@Experimental +@Since("1.6.0") +private[stat] class StreamingTestResult @Since("1.6.0") ( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "Streaming test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala new file mode 100644 index 0000000000000..d3e9ef4ff079c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter +import org.apache.spark.util.random.XORShiftRandom + +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.methodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.methodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_._1)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + (false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} From 97a99dde6e8d69a4c4c135dc1d9b1520b2548b5b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 21 Sep 2015 13:15:44 -0700 Subject: [PATCH 0166/2089] [SPARK-10676] [DOCS] Add documentation for SASL encryption options. Author: Marcelo Vanzin Closes #8803 from vanzin/SPARK-10676. --- docs/configuration.md | 16 ++++++++++++++++ docs/security.md | 22 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index b22587c70316b..284f97ad09ec3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1285,6 +1285,22 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + spark.authenticate.enableSaslEncryption + false + + Enable encrypted communication when authentication is enabled. This option is currently + only supported by the block transfer service. + + + + spark.network.sasl.serverAlwaysEncrypt + false + + Disable unencrypted connections for services that support SASL authentication. This is + currently supported by the external shuffle service. + + spark.core.connection.ack.wait.timeout 60s diff --git a/docs/security.md b/docs/security.md index d4ffa60e59a33..177109415180b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,9 +23,16 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +supported for the block transfer service. Encryption is not yet supported for the WebUI. -Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle +files, cached data, and other application files. If encrypting this data is desired, a workaround is +to configure your cluster manager to store application data on encrypted disks. + +### SSL Configuration + +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. @@ -47,6 +54,17 @@ follows: * Import all exported public keys into a single trust-store * Distribute the trust-store over the nodes +### Configuring SASL Encryption + +SASL encryption is currently supported for the block transfer service when authentication +(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set +`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + +When using an external shuffle service, it's possible to disable unencrypted connections by setting +`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that +option is enabled, applications that are not set up to use SASL encryption will fail to connect to +the shuffle service. + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight From 362539f8d97f6bb67f0d0983f7dea36b77cc9d18 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:33:10 -0700 Subject: [PATCH 0167/2089] [SPARK-10630] [SQL] Add a createDataFrame API that takes in a java list It would be nice to support creating a DataFrame directly from a Java List of Row. Author: Holden Karau Closes #8779 from holdenk/SPARK-10630-create-DataFrame-from-Java-List. --- .../scala/org/apache/spark/sql/SQLContext.scala | 14 ++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f099940800cc0..1bd4e26fb3162 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -476,6 +476,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 1.6.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + /** * Applies a schema to an RDD of Java Beans. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 5f9abd4999ce0..250ac2e1092d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -37,6 +37,7 @@ import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -181,6 +182,15 @@ public void testCreateDataFrameFromJavaBeans() { } } + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + private static final Comparator crosstabRowComparator = new Comparator() { @Override public int compare(Row row1, Row row2) { From 7c4f852bfc39537840f56cd8121457a0dc1ad7c1 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 21 Sep 2015 14:24:19 -0700 Subject: [PATCH 0168/2089] [DOC] [PYSPARK] [MLLIB] Added newlines to docstrings to fix parameter formatting Added newlines before `:param ...:` and `:return:` markup. Without these, parameter lists aren't formatted correctly in the API docs. I.e: ![screen shot 2015-09-21 at 21 49 26](https://cloud.githubusercontent.com/assets/11915197/10004686/de3c41d4-60aa-11e5-9c50-a46dcb51243f.png) .. looks like this once newline is added: ![screen shot 2015-09-21 at 21 50 14](https://cloud.githubusercontent.com/assets/11915197/10004706/f86bfb08-60aa-11e5-8524-ae4436713502.png) Author: noelsmith Closes #8851 from noel-smith/docstring-missing-newline-fix. --- python/pyspark/ml/param/__init__.py | 4 ++++ python/pyspark/ml/pipeline.py | 1 + python/pyspark/ml/tuning.py | 2 ++ python/pyspark/ml/wrapper.py | 2 ++ python/pyspark/mllib/evaluation.py | 2 +- python/pyspark/mllib/linalg/__init__.py | 1 + python/pyspark/streaming/context.py | 2 ++ python/pyspark/streaming/mqtt.py | 1 + 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index eeeac49b21980..2e0c63cb47b17 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -164,6 +164,7 @@ def extractParamMap(self, extra=None): a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra. + :param extra: extra param values :return: merged param map """ @@ -182,6 +183,7 @@ def copy(self, extra=None): embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -201,6 +203,7 @@ def _shouldOwn(self, param): def _resolveParam(self, param): """ Resolves a param and validates the ownership. + :param param: param name or the param instance, which must belong to this Params instance :return: resolved param instance @@ -243,6 +246,7 @@ def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. + :param to: the target instance :param extra: extra params to be copied :return: the target instance with param values copied diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 13cf2b0f7bbd9..312a8502b3a2c 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -154,6 +154,7 @@ def __init__(self, stages=None): def setStages(self, value): """ Set pipeline stages. + :param value: a list of transformers or estimators :return: the pipeline instance """ diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ab5621f45c72c..705ee53685752 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -254,6 +254,7 @@ def copy(self, extra=None): Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -290,6 +291,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 8218c7c5f801c..4bcb4aaec89de 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -119,6 +119,7 @@ def _create_model(self, java_model): def _fit_java(self, dataset): """ Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) @@ -173,6 +174,7 @@ def copy(self, extra=None): extra params. This implementation first calls Params.copy and then make a copy of the companion Java model with extra params. So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4398ca86f2ec2..a90e5c50e54b9 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -147,7 +147,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels an RDD of (prediction, label) pairs. + :param predictionAndLabels: an RDD of (prediction, label) pairs. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index f929e3e96fbe2..ea42127f1651f 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -240,6 +240,7 @@ class Vector(object): def toArray(self): """ Convert the vector into an numpy.ndarray + :return: numpy.ndarray """ raise NotImplementedError diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4069d7a149986..a8c9ffc235b9e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -240,6 +240,7 @@ def start(self): def awaitTermination(self, timeout=None): """ Wait for the execution to stop. + @param timeout: time to wait in seconds """ if timeout is None: @@ -252,6 +253,7 @@ def awaitTerminationOrTimeout(self, timeout): Wait for the execution to stop. Return `true` if it's stopped; or throw the reported error during the execution; or `false` if the waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds """ self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index f06598971c548..fa83006c36db6 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -31,6 +31,7 @@ def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object :param brokerUrl: Url of remote mqtt publisher :param topic: topic name to subscribe to From 72869883f12b6e0a4e5aad79c0ac2cfdb4d83f09 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Sep 2015 16:47:52 -0700 Subject: [PATCH 0169/2089] [SPARK-10649] [STREAMING] Prevent inheriting job group and irrelevant job description in streaming jobs The job group, and job descriptions information is passed through thread local properties, and get inherited by child threads. In case of spark streaming, the streaming jobs inherit these properties from the thread that called streamingContext.start(). This may not make sense. 1. Job group: This is mainly used for cancelling a group of jobs together. It does not make sense to cancel streaming jobs like this, as the effect will be unpredictable. And its not a valid usecase any way, to cancel a streaming context, call streamingContext.stop() 2. Job description: This is used to pass on nice text descriptions for jobs to show up in the UI. The job description of the thread that calls streamingContext.start() is not useful for all the streaming jobs, as it does not make sense for all of the streaming jobs to have the same description, and the description may or may not be related to streaming. The solution in this PR is meant for the Spark master branch, where local properties are inherited by cloning the properties. The job group and job description in the thread that starts the streaming scheduler are explicitly removed, so that all the subsequent child threads does not inherit them. Also, the starting is done in a new child thread, so that setting the job group and description for streaming, does not change those properties in the thread that called streamingContext.start(). Author: Tathagata Das Closes #8781 from tdas/SPARK-10649. --- .../org/apache/spark/util/ThreadUtils.scala | 59 +++++++++++++++++++ .../apache/spark/util/ThreadUtilsSuite.scala | 24 +++++++- .../spark/streaming/StreamingContext.scala | 15 ++++- .../streaming/StreamingContextSuite.scala | 32 ++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3d..22e291a2b48d6 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -21,6 +21,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -86,4 +87,62 @@ private[spark] object ThreadUtils { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() Executors.newSingleThreadScheduledExecutor(threadFactory) } + + /** + * Run a piece of code in a new thread and return the result. Exception in the new thread is + * thrown in the caller thread with an adjusted stack trace that removes references to this + * method for clarity. The exception stack traces will be like the following + * + * SomeException: exception-message + * at CallerClass.body-method (sourcefile.scala) + * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... () + * at CallerClass.caller-method (sourcefile.scala) + * ... + */ + def runInNewThread[T]( + threadName: String, + isDaemon: Boolean = true)(body: => T): T = { + @volatile var exception: Option[Throwable] = None + @volatile var result: T = null.asInstanceOf[T] + + val thread = new Thread(threadName) { + override def run(): Unit = { + try { + result = body + } catch { + case NonFatal(e) => + exception = Some(e) + } + } + } + thread.setDaemon(isDaemon) + thread.start() + thread.join() + + exception match { + case Some(realException) => + // Remove the part of the stack that shows method calls into this helper method + // This means drop everything from the top until the stack element + // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`). + val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1) + + // Remove the part of the new thread stack that shows methods call from this helper method + val extraStackTrace = realException.getStackTrace.takeWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)) + + // Combine the two stack traces, with a place holder just specifying that there + // was a helper method used, without any further details of the helper + val placeHolderStackElem = new StackTraceElement( + s"... run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..", + " ", "", -1) + val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace + + // Update the stack trace and rethrow the exception in the caller thread + realException.setStackTrace(finalStackTrace) + throw realException + case None => + result + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 8c51e6b14b7fc..620e4debf4e08 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.Random import org.apache.spark.SparkFunSuite @@ -66,4 +67,25 @@ class ThreadUtilsSuite extends SparkFunSuite { val futureThreadName = Await.result(f, 10.seconds) assert(futureThreadName === callerThreadName) } + + test("runInNewThread") { + import ThreadUtils._ + assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name") + assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true) + assert( + runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false + ) + val uniqueExceptionMessage = "test" + Random.nextInt() + val exception = intercept[IllegalArgumentException] { + runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) } + } + assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage) + assert(exception.getStackTrace.mkString("\n").contains( + "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true, + "stack trace does not contain expected place holder" + ) + assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false, + "stack trace contains unexpected references to ThreadUtils" + ) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b496d1f341a0b..6720ba4f72cf3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -588,12 +588,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. + ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -618,6 +626,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index d26894e88fc26..3b9d0d15ea04c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound === null) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) From 0494c80ef54f6f3a8c6f2d92abfe1a77a91df8b0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 21 Sep 2015 18:06:45 -0700 Subject: [PATCH 0170/2089] [SPARK-10495] [SQL] Read date values in JSON data stored by Spark 1.5.0. https://issues.apache.org/jira/browse/SPARK-10681 Author: Yin Huai Closes #8806 from yhuai/SPARK-10495. --- .../datasources/json/JacksonGenerator.scala | 36 ++++++ .../datasources/json/JacksonParser.scala | 15 ++- .../datasources/json/JsonSuite.scala | 103 +++++++++++++++++- 3 files changed, 152 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index f65c7bbd6e29d..23bada1ddd92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -73,6 +73,38 @@ private[sql] object JacksonGenerator { valWriter(field.dataType, v) } gen.writeEndObject() + + // For UDT, udt.serialize will produce SQL types. So, we need the following three cases. + case (ArrayType(ty, _), v: ArrayData) => + gen.writeStartArray() + v.foreach(ty, (_, value) => valWriter(ty, value)) + gen.writeEndArray() + + case (MapType(kt, vt, _), v: MapData) => + gen.writeStartObject() + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) + gen.writeEndObject() + + case (StructType(ty), v: InternalRow) => + gen.writeStartObject() + var i = 0 + while (i < ty.length) { + val field = ty(i) + val value = v.get(i, field.dataType) + if (value != null) { + gen.writeFieldName(field.name) + valWriter(field.dataType, value) + } + i += 1 + } + gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) @@ -133,6 +165,10 @@ private[sql] object JacksonGenerator { i += 1 } gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index ff4d8c04e8eaf..c51140749c8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -62,10 +62,23 @@ private[sql] object JacksonParser { // guard the non string type null + case (VALUE_STRING, BinaryType) => + parser.getBinaryValue + case (VALUE_STRING, DateType) => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + val stringValue = parser.getText + if (stringValue.contains("-")) { + // The format of this string will probably be "yyyy-mm-dd". + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + } else { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } case (VALUE_STRING, TimestampType) => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6a18cc6d27138..b614e6c4148fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType @@ -1159,4 +1159,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. + // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. + val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } } From c986e933a900602af47966bd41edb2116c421a39 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 21 Sep 2015 21:09:59 -0700 Subject: [PATCH 0171/2089] [SPARK-10711] [SPARKR] Do not assume spark.submit.deployMode is always set In ```RUtils.sparkRPackagePath()``` we 1. Call ``` sys.props("spark.submit.deployMode")``` which returns null if ```spark.submit.deployMode``` is not suet 2. Call ``` sparkConf.get("spark.submit.deployMode")``` which throws ```NoSuchElementException``` if ```spark.submit.deployMode``` is not set. This patch simply passes a default value ("cluster") for ```spark.submit.deployMode```. cc rxin Author: Hossein Closes #8832 from falaki/SPARK-10711. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 9e807cc52f18c..fd5646b5b6372 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -44,7 +44,7 @@ private[spark] object RUtils { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" From 1cd67415728e660a90e4dbe136272b5d6b8f1142 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 23:21:24 -0700 Subject: [PATCH 0172/2089] [SPARK-9821] [PYSPARK] pyspark-reduceByKey-should-take-a-custom-partitioner from the issue: In Scala, I can supply a custom partitioner to reduceByKey (and other aggregation/repartitioning methods like aggregateByKey and combinedByKey), but as far as I can tell from the Pyspark API, there's no way to do the same in Python. Here's an example of my code in Scala: weblogs.map(s => (getFileType(s), 1)).reduceByKey(new FileTypePartitioner(),_+_) But I can't figure out how to do the same in Python. The closest I can get is to call repartition before reduceByKey like so: weblogs.map(lambda s: (getFileType(s), 1)).partitionBy(3,hash_filetype).reduceByKey(lambda v1,v2: v1+v2).collect() But that defeats the purpose, because I'm shuffling twice instead of once, so my performance is worse instead of better. Author: Holden Karau Closes #8569 from holdenk/SPARK-9821-pyspark-reduceByKey-should-take-a-custom-partitioner. --- python/pyspark/rdd.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 73d7d9a5692a9..56e892243c79c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -686,7 +686,7 @@ def cartesian(self, other): other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None): + def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ Return an RDD of grouped items. @@ -695,7 +695,7 @@ def groupBy(self, f, numPartitions=None): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) @ignore_unicode_prefix def pipe(self, command, env=None, checkCode=False): @@ -1539,22 +1539,23 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numPartitions} partitions, or + Output will be partitioned with C{numPartitions} partitions, or the default parallelism level if C{numPartitions} is not specified. + Default partitioner is hash-partition. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numPartitions) + return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): """ @@ -1739,7 +1740,7 @@ def add_shuffle_key(split, iterator): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numPartitions=None): + numPartitions=None, partitionFunc=portable_hash): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -1777,7 +1778,7 @@ def combineLocally(iterator): return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) @@ -1786,7 +1787,8 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) - def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, + partitionFunc=portable_hash): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". This function can return a different result type, U, than the type @@ -1800,9 +1802,9 @@ def createZero(): return copy.deepcopy(zeroValue) return self.combineByKey( - lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc) - def foldByKey(self, zeroValue, func, numPartitions=None): + def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -1817,13 +1819,14 @@ def foldByKey(self, zeroValue, func, numPartitions=None): def createZero(): return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions, + partitionFunc) def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None): + def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. @@ -1859,7 +1862,7 @@ def combine(iterator): return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer) From bf20d6c9f9e478a5de24b45bbafd4dd89666c4cf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Sep 2015 23:29:59 -0700 Subject: [PATCH 0173/2089] [SPARK-10716] [BUILD] spark-1.5.0-bin-hadoop2.6.tgz file doesn't uncompress on OS X due to hidden file Remove ._SUCCESS.crc hidden file that may cause problems in distribution tar archive, and is not used Author: Sean Owen Closes #8846 from srowen/SPARK-10716. --- .../test_support/sql/orc_partitioned/._SUCCESS.crc | Bin 8 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 python/test_support/sql/orc_partitioned/._SUCCESS.crc diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc deleted file mode 100644 index 3b7b044936a890cd8d651d349a752d819d71d22c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8 PcmYc;N@ieSU}69O2$TUk From 0180b849dbaf191826231eda7dfaaf146a19602b Mon Sep 17 00:00:00 2001 From: Jian Feng Date: Mon, 21 Sep 2015 23:36:41 -0700 Subject: [PATCH 0174/2089] [SPARK-10577] [PYSPARK] DataFrame hint for broadcast join https://issues.apache.org/jira/browse/SPARK-10577 Author: Jian Feng Closes #8801 from Jianfeng-chs/master. --- python/pyspark/sql/functions.py | 9 +++++++++ python/pyspark/sql/tests.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 26b8662718a60..fa04f4cd83b6f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -29,6 +29,7 @@ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.dataframe import DataFrame def _create_function(name, doc=""): @@ -189,6 +190,14 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@since(1.6) +def broadcast(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" + + sc = SparkContext._active_spark_context + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3e680f1030a71..645133b2b2d84 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1075,6 +1075,24 @@ def foo(): self.assertRaises(TypeError, foo) + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() + class HiveContextSQLTests(ReusedPySparkTestCase): From 781b21ba2a873ed29394c8dbc74fc700e3e0d17e Mon Sep 17 00:00:00 2001 From: Ewan Leith Date: Mon, 21 Sep 2015 23:43:20 -0700 Subject: [PATCH 0175/2089] [SPARK-10419] [SQL] Adding SQLServer support for datetimeoffset types to JdbcDialects Reading from Microsoft SQL Server over jdbc fails when the table contains datetimeoffset types. This patch registers a SQLServer JDBC Dialect that maps datetimeoffset to a String, as Microsoft suggest. Author: Ewan Leith Closes #8575 from realitymine-coordinator/sqlserver. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 68ebaaca6c53d..c70fea1c3f50e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -137,6 +137,8 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -260,3 +262,19 @@ case object DB2Dialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. + */ +@DeveloperApi +case object MsSqlServerDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Some(StringType) + } else None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5ab9381de4d66..c4b039a9c5359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -408,6 +408,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } From 1fcefef06950e2f03477282368ca835bbf40ff24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Sep 2015 23:46:00 -0700 Subject: [PATCH 0176/2089] [SPARK-10446][SQL] Support to specify join type when calling join with usingColumns JIRA: https://issues.apache.org/jira/browse/SPARK-10446 Currently the method `join(right: DataFrame, usingColumns: Seq[String])` only supports inner join. It is more convenient to have it support other join types. Author: Liang-Chi Hsieh Closes #8600 from viirya/usingcolumns_df. --- python/pyspark/sql/dataframe.py | 6 ++++- .../org/apache/spark/sql/DataFrame.scala | 22 ++++++++++++++++++- .../apache/spark/sql/DataFrameJoinSuite.scala | 13 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fb995fa3a76b5..80f8d8a0eb37d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -567,7 +567,11 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) elif isinstance(on[0], basestring): - jdf = self._jdf.join(other._jdf, self._jseq(on)) + if how is None: + jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8f737c2023931..ba94d77b2e60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -484,6 +484,26 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @group dfops + * @since 1.6.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( @@ -502,7 +522,7 @@ class DataFrame private[sql]( Join( joined.left, joined.right, - joinType = Inner, + joinType = JoinType(joinType), condition) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e2716d7841d85..56ad71ea4f487 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") From f24316e6d928c263cbf3872edd97982059c3db22 Mon Sep 17 00:00:00 2001 From: Madhusudanan Kandasamy Date: Tue, 22 Sep 2015 00:03:48 -0700 Subject: [PATCH 0177/2089] [SPARK-10458] [SPARK CORE] Added isStopped() method in SparkContext Added isStopped() method in SparkContext Author: Madhusudanan Kandasamy Closes #8749 from kmadhugit/SPARK-10458. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ebd8e946ee7a2..967fec9f42bcf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. + */ + def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus From fd61b004877ba4d51c95cd0e08f53bffdf106395 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 00:05:30 -0700 Subject: [PATCH 0178/2089] [Minor] style fix for previous commit f24316e --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 967fec9f42bcf..bf3aeb488d597 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** * @return true if context is stopped or in the midst of stopping. */ From 4da32bc0e747fefe847bffe493785d4d16069c04 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 00:07:30 -0700 Subject: [PATCH 0179/2089] [SPARK-8567] [SQL] Increase the timeout of o.a.s.sql.hive.HiveSparkSubmitSuite to 5 minutes. https://issues.apache.org/jira/browse/SPARK-8567 Looks like "SPARK-8368: includes jars passed in through --jars" is pretty flaky now. Based on some history runs, the time spent on a successful run may be from 1.5 minutes to almost 3 minutes. Let's try to increase the timeout and see if we can fix this test. https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.5-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.0,label=spark-test/385/testReport/junit/org.apache.spark.sql.hive/HiveSparkSubmitSuite/SPARK_8368__includes_jars_passed_in_through___jars/history/?start=25 Author: Yin Huai Closes #8850 from yhuai/SPARK-8567-anotherTry. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 97df249bdb6d6..5f1660b62d418 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -139,7 +139,7 @@ class HiveSparkSubmitSuite new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180.seconds) { process.waitFor() } + val exitCode = failAfter(300.seconds) { process.waitFor() } if (exitCode != 0) { // include logs in output. Note that logging is async and may not have completed // at the time this exception is raised From f3b727c801408b1cd50e5d9463f2fe0fce654a16 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 00:09:29 -0700 Subject: [PATCH 0180/2089] [SQL] [MINOR] map -> foreach. DataFrame.explain should use foreach to print the explain content. Author: Reynold Xin Closes #8862 from rxin/map-foreach. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ba94d77b2e60e..a11140b717360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -320,9 +320,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ def explain(extended: Boolean): Unit = { - ExplainCommand( - queryExecution.logical, - extended = extended).queryExecution.executedPlan.executeCollect().map { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + explain.queryExecution.executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println From 0bd0e5bed2176b119b3ada590993e153757ea09b Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 22 Sep 2015 00:14:27 -0700 Subject: [PATCH 0181/2089] =?UTF-8?q?[SPARK-10695]=20[DOCUMENTATION]=20[ME?= =?UTF-8?q?SOS]=20Fixing=20incorrect=20value=20informati=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …on for spark.mesos.constraints parameter. Author: Akash Mishra Closes #8816 from SleepyThread/constraint-fix. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 460a66f37dd64..ec5a44d79212b 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -189,10 +189,10 @@ using `conf.set("spark.cores.max", "10")` (for example). You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support From 7278f792a73bbcf8d68f38dc2d87cf722693c4cf Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Tue, 22 Sep 2015 11:03:21 +0100 Subject: [PATCH 0182/2089] [SPARK-10718] [BUILD] Update License on conf files and corresponding excludes file update Update License on conf files and corresponding excludes file update Author: Rekha Joshi Author: Joshi Closes #8842 from rekhajoshm/SPARK-10718. --- .rat-excludes | 12 ------------ conf/docker.properties.template | 17 +++++++++++++++++ conf/fairscheduler.xml.template | 18 ++++++++++++++++++ conf/log4j.properties.template | 17 +++++++++++++++++ conf/metrics.properties.template | 17 +++++++++++++++++ conf/slaves.template | 17 +++++++++++++++++ conf/spark-defaults.conf.template | 17 +++++++++++++++++ conf/spark-env.sh.template | 17 +++++++++++++++++ .../spark/log4j-defaults-repl.properties | 17 +++++++++++++++++ .../org/apache/spark/log4j-defaults.properties | 17 +++++++++++++++++ 10 files changed, 154 insertions(+), 12 deletions(-) diff --git a/.rat-excludes b/.rat-excludes index 9165872b9fb27..08fba6d351d6a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9b..55cb094b4af46 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a35986..385b2e772d2c8 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 74c5cea94403a..f3046be54d7c6 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f5..d6962e0da2f30 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20a..be42a638230b7 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e1363..19cba6e71ed19 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c05fe381a36a7..990ded420be72 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index 689afea64f8db..c85abc35b93bf 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=WARN, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 27006e45e932b..d44cc85dcbd82 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender From 870b8a2edd44c9274c43ca0db4ef5b0998e16fd8 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Tue, 22 Sep 2015 11:05:24 +0100 Subject: [PATCH 0183/2089] [SPARK-10706] [MLLIB] Add java wrapper for random vector rdd Add java wrapper for random vector rdd holdenk srowen Author: Meihua Wu Closes #8841 from rotationsymmetry/SPARK-10706. --- .../spark/mllib/random/RandomRDDs.scala | 42 +++++++++++++++++++ .../mllib/random/JavaRandomRDDsSuite.java | 17 ++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index f8ff26b5795be..41d7c4d355f61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -855,6 +855,48 @@ object RandomRDDs { sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD() + } + /** * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. */ diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fce5f6712f462..5728df5aeebdc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -246,6 +246,23 @@ public void testArbitrary() { Assert.assertEquals(2, rdd.first().length()); } } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } } // This is just a test generator, it always returns a string of 42 From f4a3c4e34ce93bcaf29c0a35573932880a8b792b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Sep 2015 10:19:08 -0700 Subject: [PATCH 0184/2089] [SPARK-9962] [ML] Decision Tree training: prevNodeIdsForInstances.unpersist() at end of training NodeIdCache: prevNodeIdsForInstances.unpersist() needs to be called at end of training. Author: Holden Karau Closes #8541 from holdenk/SPARK-9962-decission-tree-training-prevNodeIdsForiNstances-unpersist-at-end-of-training. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++---- .../org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 488e8e4fb5dcd..c5ad8df73fac9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -164,10 +164,10 @@ private[spark] class NodeIdCache( } } } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 8f9eb24b57b55..0abed5411143d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -166,6 +166,10 @@ private[spark] class NodeIdCache( fs.delete(new Path(old.getCheckpointFile.get), true) } } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } From 7104ee0e5dc1290b8b845a0a8ddcdb1875cfd060 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Sep 2015 11:00:33 -0700 Subject: [PATCH 0185/2089] [SPARK-10750] [ML] ML Param validate should print better error information Currently when you set illegal value for params of array type (such as IntArrayParam, DoubleArrayParam, StringArrayParam), it will throw IllegalArgumentException but with incomprehensible error information. Take ```VectorSlicer.setNames``` as an example: ```scala val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") // The value of setNames must be contain distinct elements, so the next line will throw exception. vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4", "f1")) ``` It will throw IllegalArgumentException as: ``` vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. java.lang.IllegalArgumentException: vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. ``` We should distinguish the value of array type from primitive type at Param.validate(value: T), and we will get better error information. ``` vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. java.lang.IllegalArgumentException: vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. ``` Author: Yanbo Liang Closes #8863 from yanboliang/spark-10750. --- .../src/main/scala/org/apache/spark/ml/param/params.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index de32b7218c277..48f6269e57e98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -65,7 +65,12 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + val valueToString = value match { + case v: Array[_] => v.mkString("[", ",", "]") + case _ => value.toString + } + throw new IllegalArgumentException( + s"$parent parameter $name given invalid value $valueToString.") } } From 2ea0f2e11b82ef4817c7e6a162ea23da7860b893 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 22 Sep 2015 11:01:32 -0700 Subject: [PATCH 0186/2089] [SPARK-9585] Delete the input format caching because some input format are non thread safe If we cache the InputFormat, all tasks on the same executor will share it. Some InputFormat is thread safety, but some are not, such as HiveHBaseTableInputFormat. If tasks share a non thread safe InputFormat, unexpected error may be occurs. To avoid it, I think we should delete the input format caching. Author: xutingjun Author: meiyoula <1039320815@qq.com> Author: Xutingjun Closes #7918 from XuTingjun/cached_inputFormat. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8f2655d63b797..77b57132b9f1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -182,17 +182,11 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process. val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] if (newInputFormat.isInstanceOf[Configurable]) { newInputFormat.asInstanceOf[Configurable].setConf(conf) } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } From 22d40159e60dd27a428e4051ef607292cbffbff3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 22 Sep 2015 11:07:01 -0700 Subject: [PATCH 0187/2089] [SPARK-10593] [SQL] fix resolve output of Generate The output of Generate should not be resolved as Reference. Author: Davies Liu Closes #8755 from davies/view. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 1 - .../catalyst/plans/logical/basicOperators.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 14 ++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02f34cbf58ad0..bf72d47ce1ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -378,6 +378,22 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by + // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. + case g @ Generate(generator, join, outer, qualifier, output, child) + if child.resolved && !generator.resolved => + val newG = generator transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldExpr) => + ExtractValue(child, fieldExpr, resolver) + } + if (newG.fastEquals(generator)) { + g + } else { + Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 55286f9f2fc5c..0ec9f08571082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 722f69cdca827..ae9482c10f126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -68,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - !generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8126d02335217..bb02473dd17ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1170,4 +1170,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sqlContext.table("`db.t`"), df) } } + + test("SPARK-10593 same column names in lateral view") { + val df = sqlContext.sql( + """ + |select + |insideLayer2.json as a2 + |from (select '{"layer1": {"layer2": "text inside layer 2"}}' json) test + |lateral view json_tuple(json, 'layer1') insideLayer1 as json + |lateral view json_tuple(insideLayer1.json, 'layer2') insideLayer2 as json + """.stripMargin + ) + + checkAnswer(df, Row("text inside layer 2") :: Nil) + } } From 1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 11:50:22 -0700 Subject: [PATCH 0188/2089] [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. This naming is confusing, since the same read path code is used for both sort- and hash-based shuffle. This patch addresses this by renaming HashShuffleReader to BlockStoreShuffleReader. Author: Josh Rosen Closes #8825 from JoshRosen/shuffle-reader-cleanup. --- ...shShuffleReader.scala => BlockStoreShuffleReader.scala} | 5 ++--- .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- ...eaderSuite.scala => BlockStoreShuffleReaderSuite.scala} | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{hash/HashShuffleReader.scala => BlockStoreShuffleReader.scala} (97%) rename core/src/test/scala/org/apache/spark/shuffle/{hash/HashShuffleReaderSuite.scala => BlockStoreShuffleReaderSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1b..6dc9a16e58531 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b466..d2e2fc4c110a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303da7..9df4e551669cc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 96% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 05b3afef5b839..a5eafb1b5529e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -134,7 +133,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, From 5017c685f484ec256101d1d33bad11d9e0c0f641 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Sep 2015 12:14:15 -0700 Subject: [PATCH 0189/2089] [SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations https://issues.apache.org/jira/browse/SPARK-10740 Author: Wenchen Fan Closes #8858 from cloud-fan/non-deter. --- .../sql/catalyst/optimizer/Optimizer.scala | 69 ++++++++++++++----- .../optimizer/SetOperationPushDownSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 41 +++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 324f40a051c38..63602eaa8ccd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] { * Intersect: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. */ -object SetOperationPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * Splits the condition expression into small conditions by `And`, and partition them by + * deterministic, and finally recombine them by `And`. It returns an expression containing + * all deterministic expressions (the first field of the returned Tuple2) and an expression + * containing all non-deterministic expressions (the second field of the returned Tuple2). + */ + private def partitionByDeterministic(condition: Expression): (Expression, Expression) = { + val andConditions = splitConjunctivePredicates(condition) + andConditions.partition(_.deterministic) match { + case (deterministic, nondeterministic) => + deterministic.reduceOption(And).getOrElse(Literal(true)) -> + nondeterministic.reduceOption(And).getOrElse(Literal(true)) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union case Filter(condition, u @ Union(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(u) - Union( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection through UNION ALL - case Project(projectList, u @ Union(left, right)) => - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Union( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, u @ Union(left, right)) => + if (projectList.forall(_.deterministic)) { + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } else { + p + } // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(i) - Intersect( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Intersect( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(e) - Except( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Except( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 3fca47a023dc6..1595ad9327423 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - SetOperationPushDown) :: Nil + SetOperationPushDown, + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1370713975f2f..d919877746c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(intersect.count() === 30) assert(except.count() === 70) } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } } From 2204cdb28483b249616068085d4e88554fe6acef Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:29:39 -0700 Subject: [PATCH 0190/2089] [SPARK-10672] [SQL] Do not fail when we cannot save the metadata of a data source table in a hive compatible way https://issues.apache.org/jira/browse/SPARK-10672 With changes in this PR, we will fallback to same the metadata of a table in Spark SQL specific way if we fail to save it in a hive compatible way (Hive throws an exception because of its internal restrictions, e.g. binary and decimal types cannot be saved to parquet if the metastore is running Hive 0.13). I manually tested the fix with the following test in `DataSourceWithHiveMetastoreCatalogSuite` (`spark.sql.hive.metastore.version=0.13` and `spark.sql.hive.metastore.jars`=`maven`). ``` test(s"fail to save metadata of a parquet table in hive 0.13") { withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, cast("val_1" as binary) AS d2 """.stripMargin) sql( s"""describe formatted t """.stripMargin).collect.foreach(println) sqlContext.table("t").show } } } } ``` Without this fix, we will fail with the following error. ``` org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:619) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:576) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply$mcV$sp(ClientWrapper.scala:359) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:256) at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:211) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:248) at org.apache.spark.sql.hive.client.ClientWrapper.createTable(ClientWrapper.scala:357) at org.apache.spark.sql.hive.HiveMetastoreCatalog.createDataSourceTable(HiveMetastoreCatalog.scala:358) at org.apache.spark.sql.hive.execution.CreateMetastoreDataSourceAsSelect.run(commands.scala:285) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult$lzycompute(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.doExecute(commands.scala:69) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:140) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:138) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:138) at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:58) at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:58) at org.apache.spark.sql.DataFrame.(DataFrame.scala:144) at org.apache.spark.sql.DataFrame.(DataFrame.scala:129) at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:725) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2$$anonfun$apply$2.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:165) at org.apache.spark.sql.test.SQLTestUtils$class.withTable(SQLTestUtils.scala:150) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:162) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.test.SQLTestUtils$class.withTempPath(SQLTestUtils.scala:125) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTempPath(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.scalatest.FunSuite.runTest(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:318) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.org$scalatest$BeforeAndAfterAll$$super$run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:462) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:671) at sbt.ForkMain$Run$2.call(ForkMain.java:294) at sbt.ForkMain$Run$2.call(ForkMain.java:284) at java.util.concurrent.FutureTask.run(FutureTask.java:262) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.getObjectInspector(ArrayWritableObjectInspector.java:108) at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.(ArrayWritableObjectInspector.java:60) at org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe.initialize(ParquetHiveSerDe.java:113) at org.apache.hadoop.hive.metastore.MetaStoreUtils.getDeserializer(MetaStoreUtils.java:339) at org.apache.hadoop.hive.ql.metadata.Table.getDeserializerFromMetaStore(Table.java:288) at org.apache.hadoop.hive.ql.metadata.Table.checkValidity(Table.java:194) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:597) ... 76 more ``` Author: Yin Huai Closes #8824 from yhuai/datasourceMetadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 101 +++++++++--------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0c1b41e3377e3..012634cb5aeb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -309,69 +309,68 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } // TODO: Support persisting partitioned data source relations in Hive compatible format - val hiveTable = (maybeSerDe, dataSource.relation) match { + val qualifiedTableName = tableIdent.quotedString + val (hiveCompitiableTable, logMessage) = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) - if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - // Hive ParquetSerDe doesn't support decimal type until 1.2.0. - val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) - val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) - - val hiveParquetSupportsDecimal = client.version match { - case org.apache.spark.sql.hive.client.hive.v1_2 => true - case _ => false - } - - if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { - // If Hive version is below 1.2.0, we cannot save Hive compatible schema to - // metastore when the file format is Parquet and the schema has DecimalType. - logWarning { - "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + - "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + - s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." - } - newSparkSQLSpecificMetastoreTable() - } else { - logInfo { - "Persisting data source relation with a single input path into Hive metastore in " + - s"Hive compatible format. Input path: ${relation.paths.head}" - } - newHiveCompatibleMetastoreTable(relation, serde) - } + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) + val message = + s"Persisting data source relation $qualifiedTableName with a single input path " + + s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + (Some(hiveTable), message) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => - logWarning { - "Persisting partitioned data source relation into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + "Input path(s): " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), relation: HadoopFsRelation) => - logWarning { - "Persisting data source relation with multiple input paths into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting data source relation $qualifiedTableName with multiple input paths into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + s"Input paths: " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), _) => - logWarning { - s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + - "Persisting it into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Data source relation $qualifiedTableName is not a " + + s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " + + "in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) case _ => - logWarning { + val message = s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - "Persisting data source relation into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + s"Persisting data source relation $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) } - client.createTable(hiveTable) + (hiveCompitiableTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatiable way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + client.createTable(table) + } catch { + case throwable: Throwable => + val warningMessage = + s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " + + s"it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, throwable) + val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() + client.createTable(sparkSqlSpecificTable) + } + + case (None, message) => + logWarning(message) + val hiveTable = newSparkSQLSpecificMetastoreTable() + client.createTable(hiveTable) + } } def hiveDefaultTableFilePath(tableName: String): String = { From 5aea987c904b281d7952ad8db40a32561b4ec5cf Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:31:35 -0700 Subject: [PATCH 0191/2089] [SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results https://issues.apache.org/jira/browse/SPARK-10737 Author: Yin Huai Closes #8854 from yhuai/SMJBug. --- .../codegen/GenerateProjection.scala | 2 ++ .../apache/spark/sql/execution/Window.scala | 9 ++++-- .../sql/execution/joins/SortMergeJoin.scala | 25 +++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 28 +++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2164ddf03d1b2..75524b568d685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { + // GenerateProjection does not work with UnsafeRows. + assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 0269d6d4b7a1c..f8929530c5036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -253,7 +253,11 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(partitionSpec, child.output) + val grouping = if (child.outputsUnsafeRows) { + UnsafeProjection.create(partitionSpec, child.output) + } else { + newProjection(partitionSpec, child.output) + } // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow @@ -277,7 +281,8 @@ case class Window( val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - val currentGroup = nextGroup + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() rows = new CompactBuffer while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 906f20d2a7289..70a1af6a7063a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,9 +56,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - protected[this] def isUnsafeMode: Boolean = { (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(leftKeys) @@ -82,6 +79,28 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05b4127cbcaff..eca6f1073889a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1), Row(1))) } } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. + val confs = + SQLConf.SORTMERGE_JOIN.key -> "true" :: + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: + SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } } From a96ba40f7ee1352288ea676d8844e1c8174202eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 14:11:46 -0700 Subject: [PATCH 0192/2089] [SPARK-10714] [SPARK-8632] [SPARK-10685] [SQL] Refactor Python UDF handling This patch refactors Python UDF handling: 1. Extract the per-partition Python UDF calling logic from PythonRDD into a PythonRunner. PythonRunner itself expects iterator as input/output, and thus has no dependency on RDD. This way, we can use PythonRunner directly in a mapPartitions call, or in the future in an environment without RDDs. 2. Use PythonRunner in Spark SQL's BatchPythonEvaluation. 3. Updated BatchPythonEvaluation to only use its input once, rather than twice. This should fix Python UDF performance regression in Spark 1.5. There are a number of small cleanups I wanted to do when I looked at the code, but I kept most of those out so the diff looks small. This basically implements the approach in https://github.com/apache/spark/pull/8833, but with some code moving around so the correctness doesn't depend on the inner workings of Spark serialization and task execution. Author: Reynold Xin Closes #8835 from rxin/python-iter-refactor. --- .../apache/spark/api/python/PythonRDD.scala | 54 ++++++++++--- .../spark/sql/execution/pythonUDFs.scala | 80 +++++++++++-------- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 69da180593bb5..3788d1829758a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( parent: RDD[_], @@ -61,11 +61,39 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") @@ -75,7 +103,7 @@ private[spark] class PythonRDD( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. */ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,11 +242,11 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size()) for (include <- pythonIncludes.asScala) { @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index d0411da6fdf5a..c35c726bfc503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -329,7 +329,13 @@ case class EvaluatePython( /** * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) @@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val parent = childResults.mapPartitions { iter => + inputRDD.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => + queue.add(row) EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } - } - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - } - }.mapPartitions { iter => + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } + val joined = new JoinedRow - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + joined(queue.poll(), row) } } } From 61d4c07f4becb42f054e588be56ed13239644410 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 16:35:43 -0700 Subject: [PATCH 0193/2089] [SPARK-10640] History server fails to parse TaskCommitDenied ... simply because the code is missing! Author: Andrew Or Closes #8828 from andrewor14/task-end-reason-json. --- .../scala/org/apache/spark/TaskEndReason.scala | 6 +++++- .../org/apache/spark/util/JsonProtocol.scala | 13 +++++++++++++ .../apache/spark/util/JsonProtocolSuite.scala | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7137246bc34f2..9335c5f4160bf 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,13 +17,17 @@ package org.apache.spark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 99614a786bd93..40729fa5a4ffe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,6 +362,10 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) + case taskCommitDenied: TaskCommitDenied => + ("Job ID" -> taskCommitDenied.jobID) ~ + ("Partition ID" -> taskCommitDenied.partitionID) ~ + ("Attempt Number" -> taskCommitDenied.attemptNumber) case ExecutorLostFailure(executorId, isNormalExit) => ("Executor ID" -> executorId) ~ ("Normal Exit" -> isNormalExit) @@ -770,6 +774,7 @@ private[spark] object JsonProtocol { val exceptionFailure = Utils.getFormattedClassName(ExceptionFailure) val taskResultLost = Utils.getFormattedClassName(TaskResultLost) val taskKilled = Utils.getFormattedClassName(TaskKilled) + val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) @@ -794,6 +799,14 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled + case `taskCommitDenied` => + // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON + // de/serialization logic was not added until 1.5.1. To provide backward compatibility + // for reading those logs, we need to provide default values for all the fields. + val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). map(_.extract[Boolean]) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 143c1b901df11..a24bf2931cca0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,6 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) + testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) @@ -352,6 +353,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } + // `TaskCommitDenied` was added in 1.3.0 but JSON de/serialization logic was added in 1.5.1 + test("TaskCommitDenied backward compatibility") { + val denied = TaskCommitDenied(1, 2, 3) + val oldDenied = JsonProtocol.taskEndReasonToJson(denied) + .removeField({ _._1 == "Job ID" }) + .removeField({ _._1 == "Partition ID" }) + .removeField({ _._1 == "Attempt Number" }) + val expectedDenied = TaskCommitDenied(-1, -1, -1) + assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -577,6 +589,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => + case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), + TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => + assert(jobId1 === jobId2) + assert(partitionId1 === partitionId2) + assert(attemptNumber1 === attemptNumber2) case (ExecutorLostFailure(execId1, isNormalExit1), ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) From 84f81e035e1dab1b42c36563041df6ba16e7b287 Mon Sep 17 00:00:00 2001 From: Zhichao Li Date: Tue, 22 Sep 2015 19:41:57 -0700 Subject: [PATCH 0194/2089] [SPARK-10310] [SQL] Fixes script transformation field/line delimiters **Please attribute this PR to `Zhichao Li `.** This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding field delimiter SerDe property to the default `LazySimpleSerDe`, and enabling default record reader/writer classes. Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and `TextRecordWriter`, and don't support customizing record reader/writer using `RECORDREADER`/`RECORDWRITER` clauses. This should be addressed in separate PR(s). Author: Cheng Lian Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters. --- .../org/apache/spark/sql/hive/HiveQl.scala | 52 ++++++++++--- .../hive/execution/ScriptTransformation.scala | 75 +++++++++++++++---- .../resources/data/scripts/test_transform.py | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 39 ++++++++++ .../execution/ScriptTransformationSuite.scala | 2 + 5 files changed, 152 insertions(+), 22 deletions(-) create mode 100755 sql/hive/src/test/resources/data/scripts/test_transform.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d5cd7e98b5267..256440a9a2e97 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException @@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, None, Nil) + (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (BaseSemanticAnalyzer.unescapeSQLString(name), BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) - val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) + + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) + } else { + None + } + + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None + } + val schema = HiveScriptIOSchema( inRowFormat, outRowFormat, inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, schemaLess) + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) Some( logical.ScriptTransformation( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 32bddbaeaeaf9..b30117f0de997 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -24,20 +24,22 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} import org.apache.spark.{Logging, TaskContext} /** @@ -58,6 +60,8 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) @@ -67,6 +71,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream + val localHiveConf = serializedHiveConf.value // In order to avoid deadlocks, we need to consume the error output of the child process. // To avoid issues caused by large error output, we use a circular buffer to limit the amount @@ -96,7 +101,8 @@ case class ScriptTransformation( outputStream, proc, stderrBuffer, - TaskContext.get() + TaskContext.get(), + localHiveConf ) // This nullability is a performance optimization in order to avoid an Option.foreach() call @@ -109,6 +115,10 @@ case class ScriptTransformation( val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, localHiveConf).orNull + var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { outputSerde.getSerializedClass().newInstance @@ -134,15 +144,25 @@ case class ScriptTransformation( } } else if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject - try { - scriptOutputWritable.readFields(scriptOutputStream) - true - } catch { - case _: EOFException => - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } + + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + writerThread.exception.foreach(throw _) false + } else { + true + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } } else { true @@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread( outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, - taskContext: TaskContext + taskContext: TaskContext, + conf: Configuration ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { setDaemon(true) @@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread( TaskContext.setTaskContext(taskContext) val dataOutputStream = new DataOutputStream(outputStream) + @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception @@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } } } outputStream.close() @@ -290,6 +317,8 @@ case class HiveScriptIOSchema ( outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { private val defaultFormat = Map( @@ -347,4 +376,24 @@ case class HiveScriptIOSchema ( serde } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } } diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py new file mode 100755 index 0000000000000..ac6d11d8b919c --- /dev/null +++ b/sql/hive/src/test/resources/data/scripts/test_transform.py @@ -0,0 +1,6 @@ +import sys + +delim = sys.argv[1] + +for row in sys.stdin: + print(delim.join([w + '#' for w in row[:-1].split(delim)])) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bb02473dd17ca..71823e32ad389 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(df, Row("text inside layer 2") :: Nil) } + + test("SPARK-10310: " + + "script transformation using default input/output SerDe and record reader/writer") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + checkAnswer( + sql( + """FROM( + | FROM test SELECT TRANSFORM(a, b) + | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | AS (c STRING, d STRING) + |) t + |SELECT c + """.stripMargin), + (0 until 5).map(i => Row(i + "#"))) + } + + test("SPARK-10310: script transformation using LazySimpleSerDe") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + val df = sql( + """FROM test + |SELECT TRANSFORM(a, b) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |AS (c STRING, d STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + """.stripMargin) + + checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index cb8d0fca8e693..7cfdb886b585d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { outputSerdeClass = None, inputSerdeProps = Seq.empty, outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, schemaLess = false ) From 558e9c7e60a7c0d85ba26634e97562ad2163e91d Mon Sep 17 00:00:00 2001 From: Matt Hagen Date: Tue, 22 Sep 2015 21:14:25 -0700 Subject: [PATCH 0195/2089] [SPARK-10663] Removed unnecessary invocation of DataFrame.toDF method. The Scala example under the "Example: Pipeline" heading in this document initializes the "test" variable to a DataFrame. Because test is already a DF, there is not need to call test.toDF as the example does in a subsequent line: model.transform(test.toDF). So, I removed the extraneous toDF invocation. Author: Matt Hagen Closes #8875 from hagenhaus/SPARK-10663. --- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 0427ac6695aa1..fd3a6167bc65e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -475,7 +475,7 @@ val test = sqlContext.createDataFrame(Seq( )).toDF("id", "text") // Make predictions on test documents. -model.transform(test.toDF) +model.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => From 5548a254755bb84edae2768b94ab1816e1b49b91 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Sep 2015 22:44:09 -0700 Subject: [PATCH 0196/2089] [SPARK-10652] [SPARK-10742] [STREAMING] Set meaningful job descriptions for all streaming jobs Here is the screenshot after adding the job descriptions to threads that run receivers and the scheduler thread running the batch jobs. ## All jobs page * Added job descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9924165/cda4a372-5cb1-11e5-91ca-d43a32c699e9.png) ## All stages page * Added stage descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9923814/2cce266a-5cae-11e5-8a3f-dad84d06c50e.png) ## Streaming batch details page * Added the +details link ![image](https://cloud.githubusercontent.com/assets/663212/9921977/24014a32-5c98-11e5-958e-457b6c38065b.png) Author: Tathagata Das Closes #8791 from tdas/SPARK-10652. --- .../scala/org/apache/spark/ui/UIUtils.scala | 62 ++++++++++++++++- .../apache/spark/ui/jobs/AllJobsPage.scala | 14 ++-- .../org/apache/spark/ui/jobs/StageTable.scala | 7 +- .../org/apache/spark/ui/UIUtilsSuite.scala | 66 +++++++++++++++++++ .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/scheduler/JobScheduler.scala | 15 ++++- .../streaming/scheduler/ReceiverTracker.scala | 5 +- .../apache/spark/streaming/ui/BatchPage.scala | 33 ++++++---- .../streaming/StreamingContextSuite.scala | 2 +- 9 files changed, 179 insertions(+), 29 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f2da417724104..21dc8f0b65485 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -395,4 +397,60 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. So any + * attempts to embed links outside Spark UI, or other tags like } private def createExecutorTable() : Seq[Node] = { From 9631ca35275b0ce8a5219f975907ac36ed11f528 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Nov 2015 08:59:20 +0000 Subject: [PATCH 0873/2089] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Update to Commons Collections 3.2.2 to avoid any potential remote code execution vulnerability Author: Sean Owen Closes #9731 from srowen/SPARK-11652. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 940e2d8740bf1..ad849112ce76c 100644 --- a/pom.xml +++ b/pom.xml @@ -162,6 +162,8 @@ 3.1 3.4.1 + + 3.2.2 2.10.5 2.10 ${scala.version} @@ -475,6 +477,11 @@ commons-math3 ${commons.math3.version} + + org.apache.commons + commons-collections + ${commons.collections.version} + org.apache.ivy ivy From 1429e0a2b562469146b6fa06051c85a00092e5b8 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Wed, 18 Nov 2015 09:10:15 +0000 Subject: [PATCH 0874/2089] rmse was wrongly calculated It was multiplying with U instaed of dividing by U Author: Viveka Kulharia Closes #9771 from vivkul/patch-1. --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e94..205ca02962bee 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): From 3a6807fdf07b0e73d76502a6bd91ad979fde8b61 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 18 Nov 2015 08:18:54 -0800 Subject: [PATCH 0875/2089] =?UTF-8?q?[SPARK-11804]=20[PYSPARK]=20Exception?= =?UTF-8?q?=20raise=20when=20using=20Jdbc=20predicates=20opt=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion in PySpark Author: Jeff Zhang Closes #9791 from zjffdu/SPARK-11804. --- python/pyspark/sql/readwriter.py | 10 +++++----- python/pyspark/sql/utils.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b8ddb9feba34..e8f0d7ec77035 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -26,6 +26,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -131,9 +132,7 @@ def load(self, path=None, format=None, schema=None, **options): if type(path) == list: paths = path gateway = self._sqlContext._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] + jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) return self._df(self._jreader.load(jpaths)) else: return self._df(self._jreader.load(path)) @@ -269,8 +268,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index c4fda8bd3b891..b0a0373372d20 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -71,3 +71,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr From a97d6f3a5861e9f2bbe36957e3b39f835f3e214c Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 18 Nov 2015 08:32:03 -0800 Subject: [PATCH 0876/2089] [SPARK-11281][SPARKR] Add tests covering the issue. The goal of this PR is to add tests covering the issue to ensure that is was resolved by [SPARK-11086](https://issues.apache.org/jira/browse/SPARK-11086). Author: zero323 Closes #9743 from zero323/SPARK-11281-tests. --- R/pkg/inst/tests/test_sparkSQL.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8ff06276599e2..87ab33f6384b1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -229,7 +229,7 @@ test_that("create DataFrame from list or data.frame", { df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) - l <- list(list(a=1, b=2), list(a=3, b=4)) + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) @@ -292,11 +292,15 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - ldf <- data.frame(row.names=1:2) + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) - expect_equivalent(ldf, collect(sdf)) + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) }) # For test map type and struct type in DataFrame From 224723e6a8b198ef45d6c5ca5d2f9c61188ada8f Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 18 Nov 2015 08:41:45 -0800 Subject: [PATCH 0877/2089] [SPARK-11773][SPARKR] Implement collection functions in SparkR. Author: Sun Rui Closes #9764 from sun-rui/SPARK-11773. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 109 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 10 ++- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 10 +++ 6 files changed, 100 insertions(+), 35 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2ee7d6f94f1bc..260c9edce62e0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -98,6 +98,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "array_contains", "asc", "ascii", "asin", @@ -215,6 +216,7 @@ exportMethods("%in%", "sinh", "size", "skewness", + "sort_array", "soundex", "stddev", "stddev_pop", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fd105ba5bc9bb..34177e3cdd94f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2198,4 +2198,4 @@ setMethod("coltypes", rTypes[naIndices] <- types[naIndices] rTypes - }) \ No newline at end of file + }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d0255a62f155..ff0f438045c14 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -373,22 +373,6 @@ setMethod("exp", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @export -#' @examples \dontrun{explode(df$c)} -setMethod("explode", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) - column(jc) - }) - #' expm1 #' #' Computes the exponential of the given value minus one. @@ -980,22 +964,6 @@ setMethod("sinh", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @rdname size -#' @name size -#' @family collection_funcs -#' @export -#' @examples \dontrun{size(df$c)} -setMethod("size", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) - column(jc) - }) - #' skewness #' #' Aggregate function: returns the skewness of the values in a group. @@ -2365,3 +2333,80 @@ setMethod("rowNumber", jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") column(jc) }) + +###################### Collection functions###################### + +#' array_contains +#' +#' Returns true if the array contain the value. +#' +#' @param x A Column +#' @param value A value to be checked if contained in the column +#' @rdname array_contains +#' @name array_contains +#' @family collection_funcs +#' @export +#' @examples \dontrun{array_contains(df$c, 1)} +setMethod("array_contains", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' sort_array +#' +#' Sorts the input array for the given column in ascending order, +#' according to the natural ordering of the array elements. +#' +#' @param x A Column to sort +#' @param asc A logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index afdeffc2abd83..0dcd05438222b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -644,6 +644,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @export setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + #' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -961,6 +965,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @export setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1076,4 +1084,4 @@ setGeneric("with") #' @rdname coltypes #' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index db3b2c4bbd799..45c77a86c9582 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -635,4 +635,4 @@ assignNewEnv <- function(data) { assign(x = cols[i], value = data[, cols[i]], envir = env) } env -} \ No newline at end of file +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 87ab33f6384b1..d9a94faff7ac0 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -878,6 +878,16 @@ test_that("column functions", { df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) }) # test_that("column binary mathfunctions", { From 3cca5ffb3d60d5de9a54bc71cf0b8279898936d2 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Wed, 18 Nov 2015 09:28:59 -0800 Subject: [PATCH 0878/2089] [SPARK-11195][CORE] Use correct classloader for TaskResultGetter Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader which results in `ClassNotFoundException`. Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization instead of returning `UnknownReason`. See #9367 for previous comments See SPARK-11195 for a full repro Author: Hurshal Patel Closes #9779 from choochootrain/spark-11195-master. --- .../scala/org/apache/spark/TestUtils.scala | 11 ++-- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 65 ++++++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index acfe751f6c746..43c89b258f2fa 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.nio.charset.StandardCharsets +import java.nio.file.Paths import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} @@ -83,15 +84,15 @@ private[spark] object TestUtils { } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar. */ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -123,7 +124,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2ee..f4965994d8277 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff529..bc72c3685e8c1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. + val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } From cffb899c4397ecccedbcc41e7cf3da91f953435a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:15:50 -0800 Subject: [PATCH 0879/2089] [SPARK-11803][SQL] fix Dataset self-join When we resolve the join operator, we may change the output of right side if self-join is detected. So in `Dataset.joinWith`, we should resolve the join operator first, and then get the left output and right output from it, instead of using `left.output` and `right.output` directly. Author: Wenchen Fan Closes #9806 from cloud-fan/self-join. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 14 +++++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 817c20fdbb9f3..b644f6ad3096d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -498,13 +498,17 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan + val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(left.output.head, "_1")() - case _ => Alias(CreateStruct(left.output), "_1")() + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() } val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(right.output.head, "_2")() - case _ => Alias(CreateStruct(right.output), "_2")() + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() } @@ -513,7 +517,7 @@ class Dataset[T] private[sql]( withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, - Join(left, right, Inner, Some(condition.expr))) + joined.analyzed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a522894c374f9..198962b8fb750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, ("2", 2)) } - ignore("self join") { + test("self join") { val ds = Seq("1", "2").toDS().as("a") val joined = ds.joinWith(ds, lit(true)) checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) @@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.groupBy(p => p).count().collect().toSeq == Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - ignore("kryo encoder self join") { + test("kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == Set( (KryoData(1), KryoData(1)), From 33b837333435ceb0c04d1f361a5383c4fe6a5a75 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:23:12 -0800 Subject: [PATCH 0880/2089] [SPARK-11725][SQL] correctly handle null inputs for UDF If user use primitive parameters in UDF, there is no way for him to do the null-check for primitive inputs, so we are assuming the primitive input is null-propagatable for this case and return null if the input is null. Author: Wenchen Fan Closes #9770 from cloud-fan/udf. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 ++++ .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +++ .../sql/catalyst/ScalaReflectionSuite.scala | 17 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 44 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++ 6 files changed, 121 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b3dd351e38e8..38828e59a2152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -719,6 +719,15 @@ trait ScalaReflection { } } + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes + } + def typeOfObject: PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. case obj: Boolean => BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdba..f00c451b5981a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -85,6 +85,8 @@ class Analyzer( extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), + Batch("UDF", Once, + HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -1063,6 +1065,34 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. + */ + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case plan => plan transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3388cc20a9803..03b89221ef2d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user defined scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF. */ case class ScalaUDF( function: AnyRef, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737f..4ea410d492b01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 65f09b46afae1..08586a97411ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest { ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } + + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. + val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 35cdab50bdec9..5a7f24684d1b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(df("*")), Row(1, "a")) checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) null else i * 2 + } + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } } From dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:33:17 -0800 Subject: [PATCH 0881/2089] [SPARK-11795][SQL] combine grouping attributes into a single NamedExpression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit we use `ExpressionEncoder.tuple` to build the result encoder, which assumes the input encoder should point to a struct type field if it’s non-flat. However, our keyEncoder always point to a flat field/fields: `groupingAttributes`, we should combine them into a single `NamedExpression`. Author: Wenchen Fan Closes #9792 from cloud-fan/agg. --- .../main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 +++++++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 ++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c66162ee2148a..3f84e22a1025b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedTEncoder, dataAttributes).named) - val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val keyColumn = if (groupingAttributes.length > 1) { + Alias(CreateStruct(groupingAttributes), "key")() + } else { + groupingAttributes.head + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) new Dataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 198962b8fb750..b6db583dfe01f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } - ignore("Dataset should set the resolved encoders internally for maps") { - // TODO: Enable this once we fix SPARK-11793. + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() @@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds, - (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { From 90a7519daaa7f4ee3be7c5a9aa244120811ff6eb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 18 Nov 2015 11:35:41 -0800 Subject: [PATCH 0882/2089] [MINOR][BUILD] Ignore ensime cache Using ENSIME, I often have `.ensime_cache` polluting my source tree. This PR simply adds the cache directory to `.gitignore` Author: Jakob Odersky Closes #9708 from jodersky/master. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 08f2d8f7543f0..07524bc429e92 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ spark-tests.log streaming-tests.log dependency-reduced-pom.xml .ensime +.ensime_cache/ .ensime_lucene checkpoint derby.log From 6f99522d13d8db9fcc767f7c3189557b9a53d283 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 11:49:12 -0800 Subject: [PATCH 0883/2089] [SPARK-11792] [SQL] [FOLLOW-UP] Change SizeEstimation to KnownSizeEstimation and make estimatedSize return Long instead of Option[Long] https://issues.apache.org/jira/browse/SPARK-11792 The main changes include: * Renaming `SizeEstimation` to `KnownSizeEstimation`. Hopefully this new name has more information. * Making `estimatedSize` return `Long` instead of `Option[Long]`. * In `UnsaveHashedRelation`, `estimatedSize` will delegate the work to `SizeEstimator` if we have not created a `BytesToBytesMap`. Since we will put `UnsaveHashedRelation` to `BlockManager`, it is generally good to let it provide a more accurate size estimation. Also, if we do not put `BytesToBytesMap` directly into `BlockerManager`, I feel it is not really necessary to make `BytesToBytesMap` extends `KnownSizeEstimation`. Author: Yin Huai Closes #9813 from yhuai/SPARK-11792-followup. --- .../org/apache/spark/util/SizeEstimator.scala | 30 ++++++++++--------- .../spark/util/SizeEstimatorSuite.scala | 14 ++------- .../sql/execution/joins/HashedRelation.scala | 12 +++++--- 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index c3a2675ee5f45..09864e3f8392d 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. */ -private[spark] trait SizeEstimation { - def estimatedSize: Option[Long] +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long } /** @@ -209,18 +214,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. } else { - val estimatedSize = obj match { - case s: SizeEstimation => s.estimatedSize - case _ => None - } - if (estimatedSize.isDefined) { - state.size += estimatedSize.get - } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 9b6261af123e6..101610e38014e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } -class DummyClass8 extends SizeEstimation { +class DummyClass8 extends KnownSizeEstimation { val x: Int = 0 - override def estimatedSize: Option[Long] = Some(2015) -} - -class DummyClass9 extends SizeEstimation { - val x: Int = 0 - - override def estimatedSize: Option[Long] = None + override def estimatedSize: Long = 2015 } class SizeEstimatorSuite @@ -231,9 +225,5 @@ class SizeEstimatorSuite // DummyClass8 provides its size estimation. assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) - - // DummyClass9 does not provide its size estimation. - assertResult(16)(SizeEstimator.estimate(new DummyClass9)) - assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 49ae09bf53782..aebfea5832402 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimation, Utils} +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -190,7 +190,7 @@ private[execution] object HashedRelation { private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation - with SizeEstimation + with KnownSizeEstimation with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation( } } - override def estimatedSize: Option[Long] = { - Option(binaryMap).map(_.getTotalMemoryConsumption) + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } } override def get(key: InternalRow): Seq[InternalRow] = { From 94624eacb0fdbbe210894151a956f8150cdf527e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 18 Nov 2015 11:53:28 -0800 Subject: [PATCH 0884/2089] [SPARK-11739][SQL] clear the instantiated SQLContext Currently, if the first SQLContext is not removed after stopping SparkContext, a SQLContext could set there forever. This patch make this more robust. Author: Davies Liu Closes #9706 from davies/clear_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 17 +++++++++++------ .../spark/sql/MultiSQLContextsSuite.scala | 5 ++--- .../execution/ExchangeCoordinatorSuite.scala | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd1fdc4edb39d..39471d2fb79a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1229,7 +1229,7 @@ class SQLContext private[sql]( // construction of the instance. sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext(self) + SQLContext.clearInstantiatedContext() } }) @@ -1270,13 +1270,13 @@ object SQLContext { */ def getOrCreate(sparkContext: SparkContext): SQLContext = { val ctx = activeContext.get() - if (ctx != null) { + if (ctx != null && !ctx.sparkContext.isStopped) { return ctx } synchronized { val ctx = instantiatedContext.get() - if (ctx == null) { + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) } else { ctx @@ -1284,12 +1284,17 @@ object SQLContext { } } - private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(sqlContext, null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) } private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(null, sqlContext) + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } + } } private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 0e8fcb6a858b1..34c5c68fd1c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { testNewSession(rootSQLContext) testNewSession(rootSQLContext) testCreatingNewSQLContext(allowMultipleSQLContexts) - - SQLContext.clearInstantiatedContext(rootSQLContext) } finally { sc.stop() + SQLContext.clearInstantiatedContext() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 25f2f5caeed15..b96d50a70b85c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() } override protected def afterAll(): Unit = { From 31921e0f0bd559d042148d1ea32f865fb3068f38 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 18 Nov 2015 12:09:54 -0800 Subject: [PATCH 0885/2089] [SPARK-4557][STREAMING] Spark Streaming foreachRDD Java API method should accept a VoidFunction<...> Currently streaming foreachRDD Java API uses a function prototype requiring a return value of null. This PR deprecates the old method and uses VoidFunction to allow for more concise declaration. Also added VoidFunction2 to Java API in order to use in Streaming methods. Unit test is added for using foreachRDD with VoidFunction, and changes have been tested with Java 7 and Java 8 using lambdas. Author: Bryan Cutler Closes #9488 from BryanCutler/foreachRDD-VoidFunction-SPARK-4557. --- .../api/java/function/VoidFunction2.java | 27 ++++++++++++ .../apache/spark/streaming/Java8APISuite.java | 26 ++++++++++++ project/MimaExcludes.scala | 4 ++ .../streaming/api/java/JavaDStreamLike.scala | 24 ++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 41 ++++++++++++++++++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 0000000000000..6c576ab678455 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. + */ +public interface VoidFunction2 extends Serializable { + public void call(T1 v1, T2 v2) throws Exception; +} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 163ae92c12c6d..4eee97bc89613 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -28,6 +28,7 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -360,6 +361,31 @@ public void testFlatMap() { assertOrderInvariantEquals(expected, result); } + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> null); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @Test public void testPairFlatMap() { List> inputData = Arrays.asList( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eb70d27c34c20..bb45d1bb12146 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -142,6 +142,10 @@ object MimaExcludes { "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") ) case v if v.startsWith("1.5") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index edfa474677f15..84acec7d8e330 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -308,7 +308,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") def foreachRDD(foreachFunc: JFunction[R, Void]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -316,11 +319,30 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction[R]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e4..609bb4413b6b1 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -37,7 +37,9 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -45,7 +47,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -768,6 +769,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD(new VoidFunction2, Time>() { + @Override + public void call(JavaRDD rdd, Time time) { + } + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { From a416e41e285700f861559d710dbf857405bfddf6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 12:50:29 -0800 Subject: [PATCH 0886/2089] [SPARK-11809] Switch the default Mesos mode to coarse-grained mode Based on my conversions with people, I believe the consensus is that the coarse-grained mode is more stable and easier to reason about. It is best to use that as the default rather than the more flaky fine-grained mode. Author: Reynold Xin Closes #9795 from rxin/SPARK-11809. --- .../scala/org/apache/spark/SparkContext.scala | 2 +- docs/job-scheduling.md | 2 +- docs/running-on-mesos.md | 27 ++++++++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b5645b08f92d4..ab374cb71286a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2710,7 +2710,7 @@ object SparkContext extends Logging { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index a3c34cb6796fa..36327c6efeaf3 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks is useful when you expect large numbers of not overly active applications, such as shell sessions from separate users. However, it comes with a risk of less predictable latency, because it may take a while for an application to gain back cores on one node when it has work to do. To use this mode, simply use a -`mesos://` URL without setting `spark.mesos.coarse` to true. +`mesos://` URL and set `spark.mesos.coarse` to false. Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5be208cf3461e..a197d0e373027 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -161,21 +161,15 @@ Note that jars or python files that are passed to spark-submit should be URIs re # Mesos Run Modes -Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". +Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained". -In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows -multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, -where each application gets more or fewer machines as it ramps up and down, but it comes with an -additional overhead in launching each task. This mode may be inappropriate for low-latency -requirements like interactive queries or serving web requests. - -The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos +The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup overhead, but at the cost of reserving the Mesos resources for the complete duration of the application. -To run in coarse-grained mode, set the `spark.mesos.coarse` property in your -[SparkConf](configuration.html#spark-properties): +Coarse-grained is the default mode. You can also set `spark.mesos.coarse` property to true +to turn it on explictly in [SparkConf](configuration.html#spark-properties): {% highlight scala %} conf.set("spark.mesos.coarse", "true") @@ -186,6 +180,19 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows +multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, +where each application gets more or fewer machines as it ramps up and down, but it comes with an +additional overhead in launching each task. This mode may be inappropriate for low-latency +requirements like interactive queries or serving web requests. + +To run in coarse-grained mode, set the `spark.mesos.coarse` property to false in your +[SparkConf](configuration.html#spark-properties): + +{% highlight scala %} +conf.set("spark.mesos.coarse", "false") +{% endhighlight %} + You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} From 7c5b641808740ba5eed05ba8204cdbaf3fc579f5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 18 Nov 2015 12:53:22 -0800 Subject: [PATCH 0887/2089] [SPARK-10745][CORE] Separate configs between shuffle and RPC [SPARK-6028](https://issues.apache.org/jira/browse/SPARK-6028) uses network module to implement RPC. However, there are some configurations named with `spark.shuffle` prefix in the network module. This PR refactors them to make sure the user can control them in shuffle and RPC separately. The user can use `spark.rpc.*` to set the configuration for netty RPC. Author: Shixiong Zhu Closes #9481 from zsxwing/SPARK-10745. --- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- .../network/netty/SparkTransportConf.scala | 12 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 8 +-- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../spark/network/util/TransportConf.java | 65 ++++++++++++++----- .../network/ChunkFetchIntegrationSuite.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 2 +- .../org/apache/spark/network/StreamSuite.java | 2 +- .../network/TransportClientFactorySuite.java | 6 +- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockResolverSuite.java | 2 +- .../shuffle/ExternalShuffleCleanupSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../shuffle/RetryingBlockFetcherSuite.java | 2 +- .../network/yarn/YarnShuffleService.java | 2 +- 23 files changed, 84 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index a039d543c35e7..e8a1e35c3fc48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -45,7 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val transportConf = + SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler, true) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70a42f9045e6b..b0694e3c6c8af 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -41,7 +41,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d685..84833f59d7afe 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22c..3e0c497969502 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,16 +22,13 @@ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -49,7 +46,8 @@ private[netty] class NettyRpcEnv( securityManager: SecurityManager) extends RpcEnv(conf) with Logging { private val transportConf = SparkTransportConf.fromSparkConf( - conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2de9b6a651692..7d08eae0b4871 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] class CoarseMesosSchedulerBackend( private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), + SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled())) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 39fadd8783518..cc5f933393adf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -46,7 +46,7 @@ private[spark] trait ShuffleWriterGroup { private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 05b1eed7f3bef..fadb8fe7ed0ab 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -47,7 +47,7 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 661c706af32b1..ab0007fb78993 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -122,7 +122,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 231f4631e0a47..1c775bcb3d9c1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,7 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3b2eff377955a..115135d44adbd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,18 +23,53 @@ * A central location that tracks all the settings we expose to users. */ public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + private final ConfigProvider conf; - public TransportConf(ConfigProvider conf) { + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } /** Connect timeout in milliseconds. Default 120 secs. */ @@ -42,23 +77,23 @@ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); } /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } /** * Receive buffer size (SO_RCVBUF). @@ -67,28 +102,28 @@ public int numConnectionsPerPeer() { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. * If set to 0, we will not do any retries. */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; } /** @@ -101,11 +136,11 @@ public int memoryMapBytes() { } /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are * created only when data is going to be transferred. This can reduce the number of open files. */ public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed0..dc5fa1cee69bc 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -83,7 +83,7 @@ public static void setUp() throws Exception { fp.write(fileContent); fp.close(); - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 84ebb337e6d54..42955ef69235a 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -60,7 +60,7 @@ public class RequestTimeoutIntegrationSuite { public void setUp() throws Exception { Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); defaultManager = new StreamManager() { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f01..8eb56bdd9846f 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -49,7 +49,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 6dcec831dec71..00158fd081626 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -89,7 +89,7 @@ public static void setUp() throws Exception { fp.close(); } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f447137419306..dac7d4a5b0a09 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -52,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -76,7 +76,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -182,7 +182,7 @@ public void closeBlockClientsWithFactory() throws IOException { @Test public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { @Override public String get(String name) { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 3469e84e7f4da..b146899670180 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -207,7 +207,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -242,7 +242,7 @@ public void testFileRegionEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -368,7 +368,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e6810..1c2fa4d0d462c 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -70,7 +70,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 3c6cb367dea46..a9958232a1d28 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -42,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 2f4f1d0df478b..532d7ab8d01bd 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -35,7 +35,7 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a3f9a38b1aeb9..2095f41d79c16 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -91,7 +91,7 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aa99efda94948..08ddb3755bd08 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,7 +39,7 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 06e46f9241094..3a6ef0d3f8476 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -254,7 +254,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 11ea7f3fd3cfe..ba6d30a74c673 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -120,7 +120,7 @@ protected void serviceInit(Configuration conf) { registeredExecutorFile = findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); From 09ad9533d5760652de59fa4830c24cb8667958ac Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 18 Nov 2015 13:03:37 -0800 Subject: [PATCH 0888/2089] [SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function return Double.NaN for mean/average when count == 0 for all numeric types that is converted to Double, Decimal type continue to return null. Author: JihongMa Closes #9705 from JihongMA/SPARK-11720. --- python/pyspark/sql/dataframe.py | 2 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Kurtosis.scala | 9 +++++---- .../expressions/aggregate/Skewness.scala | 9 +++++---- .../expressions/aggregate/Stddev.scala | 18 ++++++++++++++---- .../expressions/aggregate/Variance.scala | 18 ++++++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 18 ++++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad6ad0235a904..0dd75ba7ca820 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ def describe(self, *cols): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| NaN| + | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index de5872ab11eb1..d07d4c338cdfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat. */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 8fa3aac9f1a51..c2bf2cb94116c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -37,16 +37,17 @@ case class Kurtosis(child: Expression, override protected val momentOrder = 4 // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index e1c01a5b82781..9411bcea2539a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -36,16 +36,17 @@ case class Skewness(child: Expression, override protected val momentOrder = 3 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 05dd5e3b22543..eec79a9033e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -36,11 +36,17 @@ case class StddevSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } } } @@ -62,10 +68,14 @@ case class StddevPop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ede2da2805966..cf3a740305391 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } } } @@ -62,10 +68,14 @@ case class VariancePop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else moments(2) / n + if (n == 0.0) { + null + } else { + moments(2) / n + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 432e8d17623a4..71adf2148a403 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null)) } test("zero sum") { @@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } test("null moments") { @@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5a7f24684d1b7..6399b0165c4c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", "NaN", "NaN"), + Row("stddev", null, null), Row("min", null, null), Row("max", null, null)) From 045a4f045821dcf60442f0600c2df1b79bddb536 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 13:06:25 -0800 Subject: [PATCH 0889/2089] [SPARK-6790][ML] Add spark.ml LinearRegression import/export This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang Author: Joseph K. Bradley Closes #9814 from jkbradley/fayeshine-patch-6790. --- .../ml/regression/LinearRegression.scala | 77 ++++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 34 +++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e581983..ca55d5915e688 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object LinearRegression extends Readable[LinearRegression] { + + @Since("1.6.0") + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] + + @Since("1.6.0") + override def load(path: String): LinearRegression = read.load(path) } /** @@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + + private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8fedad..2bdc0e184d734 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } From 2acdf10b1f3bb1242dba64efa798c672fde9f0d2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 13:16:31 -0800 Subject: [PATCH 0890/2089] [SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9786 from jkbradley/als-io. --- .../apache/spark/ml/recommendation/ALS.scala | 75 ++++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 14 +++- .../spark/ml/recommendation/ALSSuite.scala | 78 ++++++++++++++++--- 3 files changed, 150 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 535f266b9a944..d92514d2e239e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s.{DefaultFormats, JValue} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -182,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with Writable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -220,8 +223,60 @@ class ALSModel private[ml] ( val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends Readable[ALSModel] { + + @Since("1.6.0") + override def read: Reader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = read.load(path) + + private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = render("rank" -> instance.rank) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private[recommendation] class ALSModelReader extends Reader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.recommendation.ALSModel" + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank: Int = metadata.extraMetadata match { + case Some(m: JValue) => + (m \ "rank").extract[Int] + case None => + throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + + s" ${metadata.metadataStr}") + } + + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: Experimental :: @@ -254,7 +309,7 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def copy(extra: ParamMap): ALS = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends Readable[ALS] with Logging { /** * :: DeveloperApi :: @@ -356,6 +415,12 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def read: Reader[ALS] = new DefaultParamsReader[ALS] + + @Since("1.6.0") + override def load(path: String): ALS = read.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dddb72af5ba78..d8ce907af5323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter { * - uid * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ - def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter { ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("extraMetadata" -> extraMetadata) val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] + * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] * @param metadataStr Full metadata file String (for debugging) */ case class Metadata( @@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + extraMetadata: Option[JValue], metadataStr: String) /** @@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" + val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index eadc80e0e62b1..2c3fb84160dcb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,28 +25,26 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext = this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } From e391abdf2cb6098a35347bd123b815ee9ac5b689 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 13:25:15 -0800 Subject: [PATCH 0891/2089] [SPARK-11813][MLLIB] Avoid serialization of vocab in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11813 I found the problem during training a large corpus. Avoid serialization of vocab in Word2Vec has 2 benefits. 1. Performance improvement for less serialization. 2. Increase the capacity of Word2Vec a lot. Currently in the fit of word2vec, the closure mainly includes serialization of Word2Vec and 2 global table. the main part of Word2vec is the vocab of size: vocab * 40 * 2 * 4 = 320 vocab 2 global table: vocab * vectorSize * 8. If vectorSize = 20, that's 160 vocab. Their sum cannot exceed Int.max due to the restriction of ByteArrayOutputStream. In any case, avoiding serialization of vocab helps decrease the size of the closure serialization, especially when vectorSize is small, thus to allow larger vocabulary. Actually there's another possible fix, make local copy of fields to avoid including Word2Vec in the closure. Let me know if that's preferred. Author: Yuhao Yang Closes #9803 from hhbyyh/w2vVocab. --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f3e4d346e358a..7ab0d89d23a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -145,8 +145,8 @@ class Word2Vec extends Serializable with Logging { private var trainWordsCount = 0 private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) From e222d758499ad2609046cc1a2cc8afb45c5bccbb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:30:29 -0800 Subject: [PATCH 0892/2089] [SPARK-11684][R][ML][DOC] Update SparkR glm API doc, user guide and example codes This PR includes: * Update SparkR:::glm, SparkR:::summary API docs. * Update SparkR machine learning user guide and example codes to show: * supporting feature interaction in R formula. * summary for gaussian GLM model. * coefficients for binomial GLM model. mengxr Author: Yanbo Liang Closes #9727 from yanboliang/spark-11684. --- R/pkg/R/mllib.R | 18 +++++-- docs/sparkr.md | 50 ++++++++++++++++--- .../ml/regression/LinearRegression.scala | 3 ++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f23e1c7f1fce4..8d3b4388ae575 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -32,6 +32,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @param standardize Whether to standardize features before training +#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and +#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory +#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an +#' analytical solution to the linear regression problem. The default value is "auto" +#' which means that the solver algorithm is selected automatically. #' @return a fitted MLlib model #' @rdname glm #' @export @@ -79,9 +85,15 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param x A fitted MLlib model -#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See -#' summary.glm for more information. +#' @param object A fitted MLlib model +#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family +#' or a list with 'coefficients' component for binomial family. \cr +#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals +#' of the estimation, the 'coefficients' gives the estimated coefficients and their +#' estimated standard errors, t values and p-values. (It only available when model +#' fitted by normal solver.) \cr +#' For binomial family: the 'coefficients' gives the estimated coefficients. +#' See summary.glm for more information. \cr #' @rdname summary #' @export #' @examples diff --git a/docs/sparkr.md b/docs/sparkr.md index 437bd4756c276..a744b76be7466 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,24 +286,37 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) +* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. + +The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR. + +## Gaussian GLM model
    {% highlight r %} # Create the DataFrame df <- createDataFrame(sqlContext, iris) -# Fit a linear model over the dataset. +# Fit a gaussian GLM model over the dataset. model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") -# Model coefficients are returned in a similar format to R's native glm(). +# Model summary are returned in a similar format to R's native glm(). summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## ##$coefficients -## Estimate -##(Intercept) 2.2513930 -##Sepal_Width 0.8035609 -##Species_versicolor 1.4587432 -##Species_virginica 1.9468169 +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 # Make predictions based on the model. predictions <- predict(model, newData = df) @@ -317,3 +330,24 @@ head(select(predictions, "Sepal_Length", "prediction")) ##6 5.4 5.385281 {% endhighlight %}
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ca55d5915e688..f7c44f0a51b8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -145,6 +145,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. * The default value is "auto" which means that the solver algorithm is * selected automatically. * @group setParam From 603a721c21488e17c15c45ce1de893e6b3d02274 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:32:06 -0800 Subject: [PATCH 0893/2089] [SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol [SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr Author: Yanbo Liang Closes #9811 from yanboliang/spark-11820. --- python/pyspark/ml/classification.py | 17 +++++++++-------- python/pyspark/ml/regression.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 603f2c7f798dc..4a2982e2047ff 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -36,7 +36,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, - HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol): """ Logistic regression. Currently, this class only supports binary classification. @@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) >>> model.weights DenseVector([5.5...]) @@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() @@ -105,12 +106,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7648bf13266bf..944e648ec8801 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver): + HasStandardization, HasSolver, HasWeightCol): """ Linear regression. @@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") + ... (1.0, 2.0, Vectors.dense(1.0)), + ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 @@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -92,11 +92,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs From 54db79702513e11335c33bcf3a03c59e965e6f16 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 18 Nov 2015 14:05:18 -0800 Subject: [PATCH 0894/2089] [SPARK-11544][SQL] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9652 from dilipbiswal/spark-11544. --- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++++--- .../datasources/json/JsonSuite.scala | 36 +++++++++++++++++-- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df63..f9465157c936d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178affe..f09b61e838159 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 15:42:07 -0800 Subject: [PATCH 0895/2089] [SPARK-11810][SQL] Java-based encoder for opaque types in Datasets. This patch refactors the existing Kryo encoder expressions and adds support for Java serialization. Author: Reynold Xin Closes #9802 from rxin/SPARK-11810. --- .../scala/org/apache/spark/sql/Encoder.scala | 41 +++++++++--- .../sql/catalyst/expressions/objects.scala | 67 ++++++++++++------- .../catalyst/encoders/FlatEncoderSuite.scala | 27 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++++- 4 files changed, 130 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 79c2255641c06..1ed5111440c80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo} +import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -43,28 +43,49 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - */ - def kryo[T: ClassTag]: Encoder[T] = { - val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)) - val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T]) + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq(ser), - fromRowExpression = deser, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 489c6126f8cd3..acf0da240051e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer} +import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} @@ -517,29 +517,39 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy } } -/** Serializes an input object using Kryo serializer. */ -case class SerializeWithKryo(child: Expression) extends UnaryExpression { +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $kryo.serialize(${input.value}, null).array(); + ${ev.value} = $serializer.serialize(${input.value}, null).array(); } """ } @@ -548,29 +558,38 @@ case class SerializeWithKryo(child: Expression) extends UnaryExpression { } /** - * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit - * parameter because TreeNode cannot copy implicit parameters. + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression { +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression { override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.javaType(dataType)}) - $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 2729db84897a2..6e0322fb6e019 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -76,17 +76,34 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { // Kryo encoders encodeDecodeTest( "hello", - Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]], + encoderFor(Encoders.kryo[String]), "kryo string") encodeDecodeTest( - new NotJavaSerializable(15), - Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object serialization") + + // Java encoders + encodeDecodeTest( + "hello", + encoderFor(Encoders.javaSerialization[String]), + "java string") + encodeDecodeTest( + new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), + "java object serialization") } +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} -class NotJavaSerializable(val value: Int) { +/** For testing Java serialization based encoder. */ +class JavaSerializable(val value: Int) extends Serializable { override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[NotJavaSerializable].value + this.value == other.asInstanceOf[JavaSerializable].value } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b6db583dfe01f..89d964aa3e469 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -357,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } - test("kryo encoder") { + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -365,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - test("kryo encoder self join") { + test("Kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -375,6 +375,25 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (KryoData(2), KryoData(1)), (KryoData(2), KryoData(2)))) } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + ignore("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } } @@ -406,3 +425,16 @@ class KryoData(val a: Int) { object KryoData { def apply(a: Int): KryoData = new KryoData(a) } + +/** Used to test Java encoder. */ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) +} From 7e987de1770f4ab3d54bc05db8de0a1ef035941d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 15:47:49 -0800 Subject: [PATCH 0896/2089] [SPARK-6787][ML] add read/write to estimators under ml.feature (1) Add read/write support to the following estimators under spark.ml: * CountVectorizer * IDF * MinMaxScaler * StandardScaler (a little awkward because we store some params in spark.mllib model) * StringIndexer Added some necessary method for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility. jkbradley Author: Xiangrui Meng Closes #9798 from mengxr/SPARK-6787. --- .../spark/ml/feature/CountVectorizer.scala | 72 +++++++++++++++-- .../org/apache/spark/ml/feature/IDF.scala | 71 ++++++++++++++++- .../spark/ml/feature/MinMaxScaler.scala | 72 +++++++++++++++-- .../spark/ml/feature/StandardScaler.scala | 78 ++++++++++++++++++- .../spark/ml/feature/StringIndexer.scala | 70 +++++++++++++++-- .../ml/feature/CountVectorizerSuite.scala | 24 +++++- .../apache/spark/ml/feature/IDFSuite.scala | 19 ++++- .../spark/ml/feature/MinMaxScalerSuite.scala | 25 +++++- .../ml/feature/StandardScalerSuite.scala | 64 +++++++++++---- .../spark/ml/feature/StringIndexerSuite.scala | 19 ++++- 10 files changed, 467 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b85064..5ff9bfb7d1119 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.DataFrame import org.apache.spark.util.collection.OpenHashMap /** @@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { def this() = this(Identifiable.randomUID("cntVec")) @@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object CountVectorizer extends Readable[CountVectorizer] { + + @Since("1.6.0") + override def read: Reader[CountVectorizer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) } /** @@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String) */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams { + extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + + import CountVectorizerModel._ def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) @@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } + + @Since("1.6.0") + override def write: Writer = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends Readable[CountVectorizerModel] { + + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + + private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 4c36df75d8aa0..53ad34ef12646 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { def this() = this(Identifiable.randomUID("idf")) @@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IDF extends Readable[IDF] { + + @Since("1.6.0") + override def read: Reader[IDF] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) } /** @@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with Writable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -117,4 +134,50 @@ class IDFModel private[ml] ( val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. */ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: Writer = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends Readable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends Reader[IDFModel] { + + private val className = "org.apache.spark.ml.feature.IDFModel" + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 1b494ec8b1727..24d964fae834e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ @@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object MinMaxScaler extends Readable[MinMaxScaler] { + + @Since("1.6.0") + override def read: Reader[MinMaxScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) } /** @@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] ( val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends Readable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + + private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index f6d0b0c0e9e75..ab04e5418dd4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with Writable { def this() = this(Identifiable.randomUID("stdScal")) @@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StandardScaler extends Readable[StandardScaler] { + + @Since("1.6.0") + override def read: Reader[StandardScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) } /** @@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { + extends Model[StandardScalerModel] with StandardScalerParams with Writable { + + import StandardScalerModel._ /** Standard deviation of the StandardScalerModel */ val std: Vector = scaler.std @@ -112,6 +129,14 @@ class StandardScalerModel private[ml] ( /** Mean of the StandardScalerModel */ val mean: Vector = scaler.mean + /** Whether to scale to unit standard deviation. */ + @Since("1.6.0") + def getWithStd: Boolean = scaler.withStd + + /** Whether to center data with mean. */ + @Since("1.6.0") + def getWithMean: Boolean = scaler.withMean + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -138,4 +163,49 @@ class StandardScalerModel private[ml] ( val copied = new StandardScalerModel(uid, scaler) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends Readable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + + private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends Reader[StandardScalerModel] { + + private val className = "org.apache.spark.ml.feature.StandardScalerModel" + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = + sqlContext.read.parquet(dataPath) + .select("std", "mean", "withStd", "withMean") + .head() + // This is very likely to change in the future because withStd and withMean should be params. + val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) + val model = new StandardScalerModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f782a272d11db..f16f6afc002d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with Writable { def this() = this(Identifiable.randomUID("strIdx")) @@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StringIndexer extends Readable[StringIndexer] { + + @Since("1.6.0") + override def read: Reader[StringIndexer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) } /** @@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod @Experimental class StringIndexerModel ( override val uid: String, - val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with Writable { + + import StringIndexerModel._ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) @@ -176,6 +193,49 @@ class StringIndexerModel ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends Readable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StringIndexerModelReader extends Reader[StringIndexerModel] { + + private val className = "org.apache.spark.ml.feature.StringIndexerModel" + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index e192fa4850af0..9c9999017317d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { test("params") { + ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } @@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429b..bc958c15857ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c04dda41eea34..09183fe65b722 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { val sqlContext = new SQLContext(sc) @@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + } + + test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 879a3ae875004..49a4b2efe0c29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ ) } - def assertResult(dataframe: DataFrame): Unit = { - dataframe.select("standarded_features", "expected").collect().foreach { + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "The vector value is not correct after standardization.") } } + test("params") { + ParamsSuite.checkParams(new StandardScaler) + val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) + ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + } + test("Standardization with default parameter") { val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") - val standardscaler0 = new StandardScaler() + val standardScaler0 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .fit(df0) - assertResult(standardscaler0.transform(df0)) + assertResult(standardScaler0.transform(df0)) } test("Standardization with setter") { @@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") - val standardscaler1 = new StandardScaler() + val standardScaler1 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(true) .fit(df1) - val standardscaler2 = new StandardScaler() + val standardScaler2 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(false) .fit(df2) - val standardscaler3 = new StandardScaler() + val standardScaler3 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(false) .setWithStd(false) .fit(df3) - assertResult(standardscaler1.transform(df1)) - assertResult(standardscaler2.transform(df2)) - assertResult(standardscaler3.transform(df3)) + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val oldModel = new feature.StandardScalerModel( + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) + val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + assert(newInstance.getWithStd === instance.getWithStd) + assert(newInstance.getWithMean === instance.getWithMean) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index be37bfb438833..749bfac747826 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,23 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) + } + test("IndexToString params") { val idxToStr = new IndexToString() ParamsSuite.checkParams(idxToStr) @@ -175,7 +192,7 @@ class StringIndexerSuite assert(outSchema("output").dataType === StringType) } - test("read/write") { + test("IndexToString read/write") { val t = new IndexToString() .setInputCol("myInputCol") .setOutputCol("myOutputCol") From 3a9851936ddfe5bcb6a7f364d535fac977551f5d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 15:55:41 -0800 Subject: [PATCH 0897/2089] [SPARK-11649] Properly set Akka frame size in SparkListenerSuite test SparkListenerSuite's _"onTaskGettingResult() called when result fetched remotely"_ test was extremely slow (1 to 4 minutes to run) and recently became extremely flaky, frequently failing with OutOfMemoryError. The root cause was the fact that this was using `System.setProperty` to set the Akka frame size, which was not actually modifying the frame size. As a result, this test would allocate much more data than necessary. The fix here is to simply use SparkConf in order to configure the frame size. Author: Josh Rosen Closes #9822 from JoshRosen/SPARK-11649. --- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 53102b9f1c936..84e545851f49e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -269,14 +269,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.akka.frameSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + assert(akkaFrameSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } From c07a50b86254578625be777b1890ff95e832ac6e Mon Sep 17 00:00:00 2001 From: Derek Dagit Date: Wed, 18 Nov 2015 15:56:54 -0800 Subject: [PATCH 0898/2089] [SPARK-10930] History "Stages" page "duration" can be confusing Author: Derek Dagit Closes #9051 from d2r/spark-10930-ui-max-task-dur. --- .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ea806d09b6009..2a1c3c1a50ec9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -145,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes From 4b117121900e5f242e7c8f46a69164385f0da7cc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 16:00:35 -0800 Subject: [PATCH 0899/2089] [SPARK-11495] Fix potential socket / file handle leaks that were found via static analysis The HP Fortify Opens Source Review team (https://www.hpfod.com/open-source-review-project) reported a handful of potential resource leaks that were discovered using their static analysis tool. We should fix the issues identified by their scan. Author: Josh Rosen Closes #9455 from JoshRosen/fix-potential-resource-leaks. --- .../spark/unsafe/map/BytesToBytesMap.java | 7 ++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 38 +++++++++++-------- .../streaming/JavaCustomReceiver.java | 31 +++++++-------- .../network/ChunkFetchIntegrationSuite.java | 15 ++++++-- .../shuffle/TestShuffleDataContext.java | 32 ++++++++++------ .../spark/streaming/JavaReceiverAPISuite.java | 20 ++++++---- 6 files changed, 90 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 04694dc54418c..3387f9a4177ce 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -24,6 +24,7 @@ import java.util.LinkedList; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -272,6 +273,7 @@ private void advanceToNextPage() { } } try { + Closeables.close(reader, /* swallowIOException = */ false); reader = spillWriters.getFirst().getReader(blockManager); recordsInPage = -1; } catch (IOException e) { @@ -318,6 +320,11 @@ public Location next() { try { reader.loadNext(); } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } // Scala iterator does not handle exception Platform.throwException(e); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 039e940a357ea..dcb13e6581e54 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -20,8 +20,7 @@ import java.io.*; import com.google.common.io.ByteStreams; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.io.Closeables; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; @@ -31,10 +30,8 @@ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). */ -public final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { - private final File file; private InputStream in; private DataInputStream din; @@ -52,11 +49,15 @@ public UnsafeSorterSpillReader( File file, BlockId blockId) throws IOException { assert (file.length() > 0); - this.file = file; final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); - this.in = blockManager.wrapForCompression(blockId, bs); - this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + try { + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } catch (IOException e) { + Closeables.close(bs, /* swallowIOException = */ true); + throw e; + } } @Override @@ -75,12 +76,7 @@ public void loadNext() throws IOException { ByteStreams.readFully(in, arr, 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { - in.close(); - if (!file.delete() && file.exists()) { - logger.warn("Unable to delete spill file {}", file.getPath()); - } - in = null; - din = null; + close(); } } @@ -103,4 +99,16 @@ public int getRecordLength() { public long getKeyPrefix() { return keyPrefix; } + + @Override + public void close() throws IOException { + if (in != null) { + try { + in.close(); + } finally { + in = null; + din = null; + } + } + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e6..4b50fbf59f80e 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -121,23 +122,23 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dc5fa1cee69bc..50a324e293386 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -78,10 +79,15 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); @@ -117,6 +123,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c7..7ac1ca128aed0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index ec2bffd6a5b97..7a8ef9d14784c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -23,6 +23,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -121,14 +122,19 @@ public void onStop() { private void receive() { try { - Socket socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - in.close(); - socket.close(); } catch(ConnectException ce) { ce.printStackTrace(); restart("Could not connect", ce); From a402c92c92b2e1c85d264f6077aec8f6d6a08270 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Nov 2015 16:08:06 -0800 Subject: [PATCH 0900/2089] [SPARK-11814][STREAMING] Add better default checkpoint duration DStream checkpoint interval is by default set at max(10 second, batch interval). That's bad for large batch intervals where the checkpoint interval = batch interval, and RDDs get checkpointed every batch. This PR is to set the checkpoint interval of trackStateByKey to 10 * batch duration. Author: Tathagata Das Closes #9805 from tdas/SPARK-11814. --- .../streaming/dstream/TrackStateDStream.scala | 13 ++++++ .../streaming/TrackStateByKeySuite.scala | 44 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 98e881e6ae115..0ada1111ce30a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalTrackStateDStream._ /** * :: Experimental :: @@ -120,6 +121,14 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Enable automatic checkpointing */ override val mustCheckpoint = true + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD @@ -141,3 +150,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } } } + +private[streaming] object InternalTrackStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index e3072b4442840..58aef74c0040f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -22,9 +22,10 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag +import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -57,6 +58,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sc = new SparkContext(conf) } + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + test("state - get, exists, update, remove, ") { var state: StateImpl[Int] = null @@ -436,6 +443,41 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) } + test("trackStateByKey - checkpoint durations") { + val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + try { + ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (value: Option[Int], state: State[Int]) => 0 + val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) + val internalTrackStateStream = trackStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + trackStateStream.checkpoint(d) + } + trackStateStream.register() + ssc.start() // should initialize all the checkpoint durations + assert(trackStateStream.checkpointDuration === null) + assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], From 921900fd06362474f8caac675803d526a0986d70 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 18 Nov 2015 16:19:00 -0800 Subject: [PATCH 0901/2089] [SPARK-11791] Fix flaky test in BatchedWriteAheadLogSuite stack trace of failure: ``` org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 62 times over 1.006322071 seconds. Last failure message: Argument(s) are different! Wanted: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.BatchedWriteAheadLogSuite$$anonfun$23$$anonfun$apply$mcV$sp$15.apply(WriteAheadLogSuite.scala:518) Actual invocation has different arguments: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.WriteAheadLogSuite$BlockingWriteAheadLog.write(WriteAheadLogSuite.scala:756) ``` I believe the issue was that due to a race condition, the ordering of the events could be messed up in the final ByteBuffer, therefore the comparison fails. By adding eventually between the requests, we make sure the ordering is preserved. Note that in real life situations, the ordering across threads will not matter. Another solution would be to implement a custom mockito matcher that sorts and then compares the results, but that kind of sounds like overkill to me. Let me know what you think tdas zsxwing Author: Burak Yavuz Closes #9790 from brkyvz/fix-flaky-2. --- .../spark/streaming/util/WriteAheadLogSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7f80d6ecdbbb5..eaa88ea3cd380 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -30,6 +30,7 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor import org.mockito.Matchers.{eq => meq} import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -507,15 +508,18 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } blockingWal.allowWrite() - val buffer1 = wrapArrayArrayByte(Array(event1)) - val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) eventually(timeout(1 second)) { assert(batchedWal.invokePrivate(queueLength()) === 0) - verify(wal, times(1)).write(meq(buffer1), meq(3L)) + verify(wal, times(1)).write(meq(buffer), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. - verify(wal, times(1)).write(meq(buffer2), meq(10L)) + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) } } From 59a501359a267fbdb7689058693aa788703e54b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 16:48:09 -0800 Subject: [PATCH 0902/2089] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders Before this PR there were two things that would blow up if you called `df.as[MyClass]` if `MyClass` was defined in the REPL: - [x] Because `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName` - [x] Because we don't have anything to pass into the constructor for the `$outer` pointer. Note that this PR is just adding the infrastructure for working with inner classes in encoder and is not yet sufficient to make them work in the REPL. Currently, the implementation show in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab is causing a bug that breaks code gen due to some interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR. Author: Michael Armbrust Closes #9602 from marmbrus/dataset-replClasses. --- .../spark/sql/catalyst/ScalaReflection.scala | 81 ++++++++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 26 +++++- .../sql/catalyst/encoders/OuterScopes.scala | 42 ++++++++++ .../catalyst/encoders/ProductEncoder.scala | 6 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateProjection.scala | 10 +-- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 6 +- .../sql/catalyst/expressions/literals.scala | 6 ++ .../sql/catalyst/expressions/objects.scala | 42 ++++++++-- .../encoders/ExpressionEncoderSuite.scala | 7 +- .../encoders/ProductEncoderSuite.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 8 +- .../aggregate/TypedAggregateExpression.scala | 19 ++--- 17 files changed, 193 insertions(+), 82 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 38828e59a2152..59ccf356f2c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror import universe._ @@ -53,30 +42,6 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. Special handling is also used for Arrays including @@ -114,7 +79,9 @@ trait ScalaReflection { } ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) } } @@ -640,6 +607,48 @@ trait ScalaReflection { } } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b977f278c5b5c..456b595008479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.encoders +import java.util.concurrent.ConcurrentMap + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.util.Utils -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ @@ -211,7 +213,9 @@ case class ExpressionEncoder[T]( * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. */ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { val positionToAttribute = AttributeMap.toIndex(schema) val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) @@ -219,7 +223,23 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(fromRowExpression = analyzedPlan.expressions.head.children.head) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. + copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope that this class was defined in. " + "" + + "Try moving this class out of its parent class.") + } + + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 0000000000000..a753b187bcd32 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. + */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 55c4ee11b20f4..2914c6ee790ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag object ProductEncoder { import ScalaReflection.universe._ + import ScalaReflection.mirror import ScalaReflection.localTypeOf import ScalaReflection.dataTypeFor import ScalaReflection.Schema @@ -420,8 +421,7 @@ object ProductEncoder { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -429,7 +429,7 @@ object ProductEncoder { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f34..a31574c251af5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -34,7 +34,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55a..40189f0877764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections // copy all the results into MutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c0d313b2e1301..f229f2000d8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ${initMutableStates(ctx)} } - public Object apply(Object r) { + public java.lang.Object apply(java.lang.Object r) { // GenerateProjection does not work with UnsafeRows. assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); @@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object genericGet(int i) { + public java.lang.Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases } return null; } - public void update(int i, Object value) { + public void update(int i, java.lang.Object value) { if (value == null) { setNullAt(i); return; @@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return result; } - public boolean equals(Object other) { + public boolean equals(java.lang.Object other) { if (other instanceof SpecificRow) { SpecificRow row = (SpecificRow) other; $columnChecks @@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; + java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; ${copyColumns} return new ${classOf[GenericInternalRow].getName}(arr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923f..b7926bda3de19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4c17d02a23725..7b6c9373ebe30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b3..da602d9b4bce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26d..e34fd49be8389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -48,6 +48,12 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. + */ + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index acf0da240051e..f865a9408ef4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -178,6 +179,15 @@ case class Invoke( } } +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + /** * Constructs a new instance of the given class, using the result of evaluating the specified * expressions as arguments. @@ -189,12 +199,15 @@ case class Invoke( * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class the outerPointer must + * for the containing class must be specified. */ case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -209,30 +222,43 @@ case class NewInstance( val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + if (propagateNull) { val objNullCheck = if (ctx.defaultValue(dataType) == "null") { s"${ev.isNull} = ${ev.value} == null;" } else { "" } - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); + $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = ${ev.value} == null; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9fe64b4cf10e4..cde0364f3dd9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.encoders import java.util.Arrays +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.ArrayType abstract class ExpressionEncoderSuite extends SparkFunSuite { + val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + protected def encodeDecodeTest[T]( input: T, encoder: ExpressionEncoder[T], @@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite { test(s"encode/decode for $testName: $input") { val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema, outers).bind(schema) val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index bc539d62c537d..1798514c5c38b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -53,6 +53,10 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) class ProductEncoderSuite extends ExpressionEncoderSuite { + outers.put(getClass.getName, this) + + case class InnerClass(i: Int) + productTest(InnerClass(1)) productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b644f6ad3096d..bdcdc5d47cbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,7 +74,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag @@ -375,7 +375,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 3f84e22a1025b..7e5acbe8517d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896bb8c..6ce41aaf01e27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -52,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -92,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -114,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result From e99d3392068bc929c900a4cc7b50e9e2b437a23a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 18:34:01 -0800 Subject: [PATCH 0903/2089] [SPARK-11839][ML] refactor save/write traits * add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.* * define `DefaultParamsReadable/Writable` and use them to save some code * use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues jkbradley Author: Xiangrui Meng Closes #9827 from mengxr/SPARK-11839. --- .../scala/org/apache/spark/ml/Pipeline.scala | 40 +++++++++---------- .../classification/LogisticRegression.scala | 29 +++++++------- .../apache/spark/ml/feature/Binarizer.scala | 12 ++---- .../apache/spark/ml/feature/Bucketizer.scala | 12 ++---- .../spark/ml/feature/CountVectorizer.scala | 22 ++++------ .../org/apache/spark/ml/feature/DCT.scala | 12 ++---- .../apache/spark/ml/feature/HashingTF.scala | 12 ++---- .../org/apache/spark/ml/feature/IDF.scala | 23 +++++------ .../apache/spark/ml/feature/Interaction.scala | 12 ++---- .../spark/ml/feature/MinMaxScaler.scala | 22 ++++------ .../org/apache/spark/ml/feature/NGram.scala | 12 ++---- .../apache/spark/ml/feature/Normalizer.scala | 12 ++---- .../spark/ml/feature/OneHotEncoder.scala | 12 ++---- .../ml/feature/PolynomialExpansion.scala | 12 ++---- .../ml/feature/QuantileDiscretizer.scala | 12 ++---- .../spark/ml/feature/SQLTransformer.scala | 13 ++---- .../spark/ml/feature/StandardScaler.scala | 22 ++++------ .../spark/ml/feature/StopWordsRemover.scala | 12 ++---- .../spark/ml/feature/StringIndexer.scala | 32 +++++---------- .../apache/spark/ml/feature/Tokenizer.scala | 24 +++-------- .../spark/ml/feature/VectorAssembler.scala | 12 ++---- .../spark/ml/feature/VectorSlicer.scala | 12 ++---- .../apache/spark/ml/recommendation/ALS.scala | 27 +++++-------- .../ml/regression/LinearRegression.scala | 30 ++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 40 ++++++++++++------- .../org/apache/spark/ml/PipelineSuite.scala | 14 +++---- .../spark/ml/util/DefaultReadWriteTest.scala | 17 ++++---- 27 files changed, 190 insertions(+), 321 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 25f0c696f42be..b0f22e042ec56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Reader -import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new Pipeline.PipelineWriter(this) + override def write: MLWriter = new Pipeline.PipelineWriter(this) } -object Pipeline extends Readable[Pipeline] { +object Pipeline extends MLReadable[Pipeline] { - override def read: Reader[Pipeline] = new PipelineReader + override def read: MLReader[Pipeline] = new PipelineReader - override def load(path: String): Pipeline = read.load(path) + override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends Reader[Pipeline] { + private[ml] class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.Pipeline" @@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] { } } - /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] { /** Check that all stages are Writable */ def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { - case stage: Writable => // good + case stage: MLWritable => // good case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + @@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } @@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] { val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } @@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Writable with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -333,18 +333,18 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModel.PipelineModelWriter(this) + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } -object PipelineModel extends Readable[PipelineModel] { +object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite - override def read: Reader[PipelineModel] = new PipelineModelReader + override def read: MLReader[PipelineModel] = new PipelineModelReader - override def load(path: String): PipelineModel = read.load(path) + override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends Reader[PipelineModel] { + private[ml] class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.PipelineModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcbf47..a3cc49f7f018c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { + with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e2be6547d8f00..63c06581482ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Writable with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,17 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Binarizer extends Readable[Binarizer] { - - @Since("1.6.0") - override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] +object Binarizer extends DefaultParamsReadable[Binarizer] { @Since("1.6.0") - override def load(path: String): Binarizer = read.load(path) + override def load(path: String): Binarizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 7095fbd70aa07..324353a96afb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } -object Bucketizer extends Readable[Bucketizer] { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { @@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] { } @Since("1.6.0") - override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] - - @Since("1.6.0") - override def load(path: String): Bucketizer = read.load(path) + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5ff9bfb7d1119..4969cf42450d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object CountVectorizer extends Readable[CountVectorizer] { - - @Since("1.6.0") - override def read: Reader[CountVectorizer] = new DefaultParamsReader +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { @Since("1.6.0") override def load(path: String): CountVectorizer = super.load(path) @@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] { */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ @@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } @Since("1.6.0") - override def write: Writer = new CountVectorizerModelWriter(this) + override def write: MLWriter = new CountVectorizerModelWriter(this) } @Since("1.6.0") -object CountVectorizerModel extends Readable[CountVectorizerModel] { +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private[CountVectorizerModel] - class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { private case class Data(vocabulary: Seq[String]) @@ -270,7 +264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } } - private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { private val className = "org.apache.spark.ml.feature.CountVectorizerModel" @@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } @Since("1.6.0") - override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader @Since("1.6.0") override def load(path: String): CountVectorizerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6ea5a616173ee..6bed72164a1da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] with Writable { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -69,17 +69,11 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object DCT extends Readable[DCT] { - - @Since("1.6.0") - override def read: Reader[DCT] = new DefaultParamsReader[DCT] +object DCT extends DefaultParamsReadable[DCT] { @Since("1.6.0") - override def load(path: String): DCT = read.load(path) + override def load(path: String): DCT = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6d2ea675f5617..9e15835429a38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType} */ @Experimental class HashingTF(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,17 +77,11 @@ class HashingTF(override val uid: String) } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object HashingTF extends Readable[HashingTF] { - - @Since("1.6.0") - override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] +object HashingTF extends DefaultParamsReadable[HashingTF] { @Since("1.6.0") - override def load(path: String): HashingTF = read.load(path) + override def load(path: String): HashingTF = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 53ad34ef12646..0e00ef6f2ee20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IDF extends Readable[IDF] { - - @Since("1.6.0") - override def read: Reader[IDF] = new DefaultParamsReader +object IDF extends DefaultParamsReadable[IDF] { @Since("1.6.0") override def load(path: String): IDF = super.load(path) @@ -110,7 +105,7 @@ object IDF extends Readable[IDF] { class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase with Writable { + extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ @@ -140,13 +135,13 @@ class IDFModel private[ml] ( def idf: Vector = idfModel.idf @Since("1.6.0") - override def write: Writer = new IDFModelWriter(this) + override def write: MLWriter = new IDFModelWriter(this) } @Since("1.6.0") -object IDFModel extends Readable[IDFModel] { +object IDFModel extends MLReadable[IDFModel] { - private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { private case class Data(idf: Vector) @@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] { } } - private class IDFModelReader extends Reader[IDFModel] { + private class IDFModelReader extends MLReader[IDFModel] { private val className = "org.apache.spark.ml.feature.IDFModel" @@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] { } @Since("1.6.0") - override def read: Reader[IDFModel] = new IDFModelReader + override def read: MLReader[IDFModel] = new IDFModelReader @Since("1.6.0") override def load(path: String): IDFModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9df6b311cc9da..2181119f04a5d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.types._ @Since("1.6.0") @Experimental class Interaction @Since("1.6.0") (override val uid: String) extends Transformer - with HasInputCols with HasOutputCol with Writable { + with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) @@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Interaction extends Readable[Interaction] { - - @Since("1.6.0") - override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] +object Interaction extends DefaultParamsReadable[Interaction] { @Since("1.6.0") - override def load(path: String): Interaction = read.load(path) + override def load(path: String): Interaction = super.load(path) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 24d964fae834e..ed24eabb50444 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object MinMaxScaler extends Readable[MinMaxScaler] { - - @Since("1.6.0") - override def read: Reader[MinMaxScaler] = new DefaultParamsReader +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { @Since("1.6.0") override def load(path: String): MinMaxScaler = super.load(path) @@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ @@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new MinMaxScalerModelWriter(this) + override def write: MLWriter = new MinMaxScalerModelWriter(this) } @Since("1.6.0") -object MinMaxScalerModel extends Readable[MinMaxScalerModel] { +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private[MinMaxScalerModel] - class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { private case class Data(originalMin: Vector, originalMax: Vector) @@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } } - private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" @@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } @Since("1.6.0") - override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader @Since("1.6.0") override def load(path: String): MinMaxScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4a17acd95199f..65414ecbefbbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,17 +66,11 @@ class NGram(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object NGram extends Readable[NGram] { - - @Since("1.6.0") - override def read: Reader[NGram] = new DefaultParamsReader[NGram] +object NGram extends DefaultParamsReadable[NGram] { @Since("1.6.0") - override def load(path: String): NGram = read.load(path) + override def load(path: String): NGram = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 9df6a091d5058..c2d514fd9629e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class Normalizer(override val uid: String) - extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,17 +56,11 @@ class Normalizer(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Normalizer extends Readable[Normalizer] { - - @Since("1.6.0") - override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] +object Normalizer extends DefaultParamsReadable[Normalizer] { @Since("1.6.0") - override def load(path: String): Normalizer = read.load(path) + override def load(path: String): Normalizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4e2adfaafa21e..d70164eaf0224 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Writable { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object OneHotEncoder extends Readable[OneHotEncoder] { - - @Since("1.6.0") - override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { @Since("1.6.0") - override def load(path: String): OneHotEncoder = read.load(path) + override def load(path: String): OneHotEncoder = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 49415398325fd..08610593fadda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } /** @@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String) * current index and increment it properly for sparse input. */ @Since("1.6.0") -object PolynomialExpansion extends Readable[PolynomialExpansion] { +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] { } @Since("1.6.0") - override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] - - @Since("1.6.0") - override def load(path: String): PolynomialExpansion = read.load(path) + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2da5c966d2967..7bf67c6325a35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. */ @@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { } @Since("1.6.0") - override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] - - @Since("1.6.0") - override def load(path: String): QuantileDiscretizer = read.load(path) + override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c115064ff301a..3a735017ba836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType */ @Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) @@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object SQLTransformer extends Readable[SQLTransformer] { - - @Since("1.6.0") - override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { @Since("1.6.0") - override def load(path: String): SQLTransformer = read.load(path) + override def load(path: String): SQLTransformer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index ab04e5418dd4f..1f689c1da1ba9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with Writable { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) @@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StandardScaler extends Readable[StandardScaler] { - - @Since("1.6.0") - override def read: Reader[StandardScaler] = new DefaultParamsReader +object StandardScaler extends DefaultParamsReadable[StandardScaler] { @Since("1.6.0") override def load(path: String): StandardScaler = super.load(path) @@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] { class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams with Writable { + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ @@ -165,14 +159,14 @@ class StandardScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new StandardScalerModelWriter(this) + override def write: MLWriter = new StandardScalerModelWriter(this) } @Since("1.6.0") -object StandardScalerModel extends Readable[StandardScalerModel] { +object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] - class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) @@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } } - private class StandardScalerModelReader extends Reader[StandardScalerModel] { + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { private val className = "org.apache.spark.ml.feature.StandardScalerModel" @@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } @Since("1.6.0") - override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader @Since("1.6.0") override def load(path: String): StandardScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f1146988dcc7c..318808596dc6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StopWordsRemover extends Readable[StopWordsRemover] { - - @Since("1.6.0") - override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { @Since("1.6.0") - override def load(path: String): StopWordsRemover = read.load(path) + override def load(path: String): StopWordsRemover = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f16f6afc002d8..97a2e4f6d6ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase with Writable { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StringIndexer extends Readable[StringIndexer] { - - @Since("1.6.0") - override def read: Reader[StringIndexer] = new DefaultParamsReader +object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] { class StringIndexerModel ( override val uid: String, val labels: Array[String]) - extends Model[StringIndexerModel] with StringIndexerBase with Writable { + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -199,10 +193,10 @@ class StringIndexerModel ( } @Since("1.6.0") -object StringIndexerModel extends Readable[StringIndexerModel] { +object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] - class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { private case class Data(labels: Array[String]) @@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } } - private class StringIndexerModelReader extends Reader[StringIndexerModel] { + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { private val className = "org.apache.spark.ml.feature.StringIndexerModel" @@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } @Since("1.6.0") - override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader @Since("1.6.0") override def load(path: String): StringIndexerModel = super.load(path) @@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { */ @Experimental class IndexToString private[ml] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String) override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IndexToString extends Readable[IndexToString] { - - @Since("1.6.0") - override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] +object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") - override def load(path: String): IndexToString = read.load(path) + override def load(path: String): IndexToString = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0e4445d1e2fa7..8ad7bbedaab5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class Tokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -46,19 +46,13 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Tokenizer extends Readable[Tokenizer] { - - @Since("1.6.0") - override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] +object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") - override def load(path: String): Tokenizer = read.load(path) + override def load(path: String): Tokenizer = super.load(path) } /** @@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] { */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object RegexTokenizer extends Readable[RegexTokenizer] { - - @Since("1.6.0") - override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { @Since("1.6.0") - override def load(path: String): RegexTokenizer = read.load(path) + override def load(path: String): RegexTokenizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7e54205292ca2..0feec0549852b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with Writable { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorAssembler extends Readable[VectorAssembler] { - - @Since("1.6.0") - override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { @Since("1.6.0") - override def load(path: String): VectorAssembler = read.load(path) + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 911582b55b574..5410a50bc2e47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorSlicer extends Readable[VectorSlicer] { +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ private[feature] def validIndices(indices: Array[Int]): Boolean = { @@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] { } @Since("1.6.0") - override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] - - @Since("1.6.0") - override def load(path: String): VectorSlicer = read.load(path) + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d92514d2e239e..795b73c4c2121 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -185,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams with Writable { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -225,19 +225,19 @@ class ALSModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new ALSModel.ALSModelWriter(this) + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } @Since("1.6.0") -object ALSModel extends Readable[ALSModel] { +object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") - override def read: Reader[ALSModel] = new ALSModelReader + override def read: MLReader[ALSModel] = new ALSModelReader @Since("1.6.0") - override def load(path: String): ALSModel = read.load(path) + override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] { } } - private[recommendation] class ALSModelReader extends Reader[ALSModel] { + private[recommendation] class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.recommendation.ALSModel" @@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] { * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w } override def copy(extra: ParamMap): ALS = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w * than 2 billion. */ @DeveloperApi -object ALS extends Readable[ALS] with Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging { case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) @Since("1.6.0") - override def read: Reader[ALS] = new DefaultParamsReader[ALS] - - @Since("1.6.0") - override def load(path: String): ALS = read.load(path) + override def load(path: String): ALS = super.load(path) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7c44f0a51b8a..7ba1a60edaf71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Writable with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object LinearRegression extends Readable[LinearRegression] { - - @Since("1.6.0") - override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] +object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") - override def load(path: String): LinearRegression = read.load(path) + override def load(path: String): LinearRegression = super.load(path) } /** @@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with Writable { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) } @Since("1.6.0") -object LinearRegressionModel extends Readable[LinearRegressionModel] { +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") - override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader @Since("1.6.0") - override def load(path: String): LinearRegressionModel = read.load(path) + override def load(path: String): LinearRegressionModel = super.load(path) - /** [[Writer]] instance for [[LinearRegressionModel]] */ + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data(intercept: Double, coefficients: Vector) @@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] { } } - private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.regression.LinearRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d8ce907af5323..ff9322dba122a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Writer]] and [[Reader]]. + * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None @@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite with Logging { +abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging { } /** - * Trait for classes that provide [[Writer]]. + * Trait for classes that provide [[MLWriter]]. */ @Since("1.6.0") -trait Writable { +trait MLWritable { /** - * Returns a [[Writer]] instance for this ML instance. + * Returns an [[MLWriter]] instance for this ML instance. */ @Since("1.6.0") - def write: Writer + def write: MLWriter /** * Saves this ML instance to the input path, a shortcut of `write.save(path)`. @@ -130,13 +130,18 @@ trait Writable { def save(path: String): Unit = write.save(path) } +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + /** * Abstract class for utility classes that can load ML instances. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -abstract class Reader[T] extends BaseReadWrite { +abstract class MLReader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. @@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite { } /** - * Trait for objects that provide [[Reader]]. + * Trait for objects that provide [[MLReader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Readable[T] { +trait MLReadable[T] { /** - * Returns a [[Reader]] instance for this class. + * Returns an [[MLReader]] instance for this class. */ @Since("1.6.0") - def read: Reader[T] + def read: MLReader[T] /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. @@ -171,13 +176,18 @@ trait Readable[T] { def load(path: String): T = read.load(path) } +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + /** - * Default [[Writer]] implementation for transformers and estimators that contain basic + * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer { +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter { } /** - * Default [[Reader]] implementation for transformers and estimators that contain basic + * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type * TODO: Consider adding check for correct class name. */ -private[ml] class DefaultParamsReader[T] extends Reader[T] { +private[ml] class DefaultParamsReader[T] extends MLReader[T] { override def load(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7f5c3895acb0c..12aba6bc6dbeb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends Transformer with Writable { +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) override def transform(dataset: DataFrame): DataFrame = dataset override def transformSchema(schema: StructType): StructType = schema } -object WritableStage extends Readable[WritableStage] { +object WritableStage extends MLReadable[WritableStage] { - override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] - override def load(path: String): WritableStage = read.load(path) + override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dd1e8acce9418..84d06b43d6224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.save(path) } instance.write.overwrite().save(path) - val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) @@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam E Type of [[Estimator]] * @tparam M Type of [[Model]] produced by estimator */ - def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: DataFrame, testParams: Map[String, Any], @@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Writable { +class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) } -object MyParams extends Readable[MyParams] { +object MyParams extends MLReadable[MyParams] { - override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] - override def load(path: String): MyParams = read.load(path) + override def load(path: String): MyParams = super.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From e61367b9f9bfc8e123369d55d7ca5925568b98a7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 18:34:36 -0800 Subject: [PATCH 0904/2089] [SPARK-11833][SQL] Add Java tests for Kryo/Java Dataset encoders Also added some nicer error messages for incompatible types (private types and primitive types) for Kryo/Java encoder. Author: Reynold Xin Closes #9823 from rxin/SPARK-11833. --- .../scala/org/apache/spark/sql/Encoder.scala | 69 +++++++++++------ .../encoders/EncoderErrorMessageSuite.scala | 40 ++++++++++ .../catalyst/encoders/FlatEncoderSuite.scala | 22 ++---- .../apache/spark/sql/JavaDatasetSuite.java | 75 ++++++++++++++++++- 4 files changed, 166 insertions(+), 40 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ed5111440c80..d54f2854fb33f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.lang.reflect.Modifier + import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} @@ -43,30 +45,28 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** A way to construct encoders using generic serializers. */ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - toRowExpressions = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -75,6 +75,8 @@ object Encoders { * serialization. This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -83,17 +85,40 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } def tuple[T1, T2]( e1: Encoder[T1], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 0000000000000..0b2a10bb04c10 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. + + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 6e0322fb6e019..07523d49f4266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -74,24 +74,14 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { FlatEncoder[Map[Int, Map[String, Int]]], "map of map") // Kryo encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.kryo[String]), - "kryo string") - encodeDecodeTest( - new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), - "kryo object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") + encodeDecodeTest(new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") // Java encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.javaSerialization[String]), - "java string") - encodeDecodeTest( - new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), - "java object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") + encodeDecodeTest(new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") } /** For testing Kryo serialization based encoder. */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d9b22506fbd3b..ce40dd856f679 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -24,6 +24,7 @@ import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; + import org.junit.*; import org.apache.spark.Accumulator; @@ -410,8 +411,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); Assert.assertEquals( Arrays.asList( - new Tuple4("a", 3, 3L, 2L), - new Tuple4("b", 3, 3L, 1L)), + new Tuple4<>("a", 3, 3L, 2L), + new Tuple4<>("b", 3, 3L, 1L)), agged2.collectAsList()); } @@ -437,4 +438,74 @@ public Integer finish(Integer reduction) { return reduction; } } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } } From 6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 18 Nov 2015 18:38:45 -0800 Subject: [PATCH 0905/2089] [SPARK-11787][SQL] Improve Parquet scan performance when using flat schemas. This patch adds an alternate to the Parquet RecordReader from the parquet-mr project that is much faster for flat schemas. Instead of using the general converter mechanism from parquet-mr, this directly uses the lower level APIs from parquet-columnar and a customer RecordReader that directly assembles into UnsafeRows. This is optionally disabled and only used for supported schemas. Using the tpcds store sales table and doing a sum of increasingly more columns, the results are: For 1 Column: Before: 11.3M rows/second After: 18.2M rows/second For 2 Columns: Before: 7.2M rows/second After: 11.2M rows/second For 5 Columns: Before: 2.9M rows/second After: 4.5M rows/second Author: Nong Li Closes #9774 from nongli/parquet. --- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 41 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 + .../expressions/codegen/BufferHolder.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 20 +- .../SpecificParquetRecordReaderBase.java | 240 +++++++ .../parquet/UnsafeRowParquetRecordReader.java | 593 ++++++++++++++++++ .../parquet/CatalystRowConverter.scala | 48 +- .../parquet/ParquetFilterSuite.scala | 4 +- 8 files changed, 944 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 264dae7f39085..4d176332b69ce 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -20,8 +20,6 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import scala.reflect.ClassTag - import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + +import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( @@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? + */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + // TODO: move this class to sql.execution and remove this. + reader = Utils.classForName( + "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") + .newInstance().asInstanceOf[RecordReader[Void, V]] + try { + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } catch { + case e: Exception => reader = null + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb62a4..33769363a0ed5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -178,6 +178,15 @@ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c9468678065d..d26b1b187c27b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,19 +17,28 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. - * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * A helper class to manage the row buffer when construct unsafe rows. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; - public void grow(int neededSize) { + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. @@ -41,12 +50,23 @@ public void grow(int neededSize) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } } } + public void grow(int neededSize) { + grow(neededSize, null); + } + public void reset() { cursor = Platform.BYTE_ARRAY_OFFSET; } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } public int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8fb4..e227c0dec9748 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -35,6 +35,7 @@ public class UnsafeRowWriter { // The offset of the global buffer where we start to write this row. private int startingOffset; private int nullBitsSize; + private UnsafeRow row; public void initialize(BufferHolder holder, int numFields) { this.holder = holder; @@ -43,7 +44,7 @@ public void initialize(BufferHolder holder, int numFields) { // grow the global buffer to make sure it has enough space to write fixed-length data. final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize); + holder.grow(fixedSize, row); holder.cursor += fixedSize; // zero-out the null bits region @@ -52,12 +53,19 @@ public void initialize(BufferHolder holder, int numFields) { } } + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } @@ -90,7 +98,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); + holder.grow(paddingBytes, row); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -153,7 +161,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -185,7 +193,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -206,7 +214,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -222,7 +230,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 0000000000000..2ed30c1f5a8d9 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. + * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? + */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 0000000000000..8a92e489ccb7c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each request column, the reader to read this column. + * columnsReaders[i] populated the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, the annotated original type. + */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. + */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. + */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. + rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.limit() - bytes.position(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.position(), len); + } + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. + */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? + page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3cb1..94298fae2d69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. + val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 458786f77af3f..c8028a5ef5528 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does + // not do row by row filtering (and we probably don't want to push that). + ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => From 9c0654d36c6d171dd273850c2cc2f415cc2a5a6b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 18:41:40 -0800 Subject: [PATCH 0906/2089] Revert "[SPARK-11544][SQL] sqlContext doesn't use PathFilter" This reverts commit 54db79702513e11335c33bcf3a03c59e965e6f16. --- .../apache/spark/sql/sources/interfaces.scala | 25 +++---------- .../datasources/json/JsonSuite.scala | 36 ++----------------- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936d..b3d3bdf50df63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,8 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -448,15 +447,9 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + logInfo(s"Listing $qualified on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) - } else { - Try(fs.listStatus(qualified)).getOrElse(Array.empty) - } + Try(fs.listStatus(qualified)).getOrElse(Array.empty) }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -854,16 +847,8 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f09b61e838159..6042b1178affe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,27 +19,19 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TestFileFilter extends PathFilter { - override def accept(path: Path): Boolean = path.getParent.getName != "p=2" -} - class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1398,28 +1390,4 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } - - test("SPARK-11544 test pathfilter") { - withTempPath { dir => - val path = dir.getCanonicalPath - - val df = sqlContext.range(2) - df.write.json(path + "/p=1") - df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) - - val clonedConf = new Configuration(hadoopConfiguration) - try { - hadoopConfiguration.setClass( - "mapreduce.input.pathFilter.class", - classOf[TestFileFilter], - classOf[PathFilter]) - assert(sqlContext.read.json(path).count() === 2) - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - } } From 67c75828ff4df2e305bdf5d6be5a11201d1da3f3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 18:49:46 -0800 Subject: [PATCH 0907/2089] [SPARK-11816][ML] fix some style issue in ML/MLlib examples jira: https://issues.apache.org/jira/browse/SPARK-11816 Currently I only fixed some obvious comments issue like // scalastyle:off println on the bottom. Yet the style in examples is not quite consistent, like only half of the examples are with // Example usage: ./bin/run-example mllib.FPGrowthExample \, Author: Yuhao Yang Closes #9808 from hhbyyh/exampleStyle. --- .../java/org/apache/spark/examples/ml/JavaKMeansExample.java | 2 +- .../apache/spark/examples/ml/AFTSurvivalRegressionExample.scala | 2 +- .../spark/examples/ml/DecisionTreeClassificationExample.scala | 1 + .../spark/examples/ml/DecisionTreeRegressionExample.scala | 1 + .../examples/ml/MultilayerPerceptronClassifierExample.scala | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465c..47665ff2b1f3c 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -41,7 +41,7 @@ * An example demonstrating a k-means clustering. * Run with *
    - * bin/run-example ml.JavaSimpleParamsExample  
    + * bin/run-example ml.JavaKMeansExample  
      * 
    */ public class JavaKMeansExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index 5da285e83681f..f4b3613ccb94f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -59,4 +59,4 @@ object AFTSurvivalRegressionExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index ff8a0a90f1e44..db024b5cad935 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -90,3 +90,4 @@ object DecisionTreeClassificationExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index fc402724d2156..ad01f55df72b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -78,3 +78,4 @@ object DecisionTreeRegressionExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 146b83c8be490..9c98076bd24b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,4 +66,4 @@ object MultilayerPerceptronClassifierExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println From fc3f77b42d62ca789d0ee07403795978961991c7 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 18 Nov 2015 19:37:14 -0800 Subject: [PATCH 0908/2089] [SPARK-11614][SQL] serde parameters should be set only when all params are ready see HIVE-7975 and HIVE-12373 With changed semantic of setters in thrift objects in hive, setter should be called only after all parameters are set. It's not problem of current state but will be a problem in some day. Author: navis.ryu Closes #9580 from navis/SPARK-11614. --- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f4d45714fae4e..9a981d02ad67c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -804,12 +804,13 @@ private[hive] case class MetastoreRelation val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) + // maps and lists should be set only after all elements are ready (see HIVE-7975) serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } From d02d5b9295b169c3ebb0967453b2835edb8a121f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 21:44:01 -0800 Subject: [PATCH 0909/2089] [SPARK-11842][ML] Small cleanups to existing Readers and Writers Updates: * Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel. * Strengthen privacy to class and companion object for Writers and Readers * Change LogisticRegressionSuite read/write test to fit intercept * Add Since versions for read/write methods in Pipeline, LogisticRegression * Switch from hand-written class names in Readers to using getClass CC: mengxr CC: yanboliang Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you! Author: Joseph K. Bradley Closes #9829 from jkbradley/ml-io-cleanups. --- .../scala/org/apache/spark/ml/Pipeline.scala | 22 +++++++++++++------ .../classification/LogisticRegression.scala | 19 ++++++++++------ .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 6 ++--- .../ml/regression/LinearRegression.scala | 4 ++-- .../LogisticRegressionSuite.scala | 2 +- 10 files changed, 38 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index b0f22e042ec56..6f15b37abcb30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -27,7 +27,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.MLReader import org.apache.spark.ml.util.MLWriter @@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + @Since("1.6.0") override def write: MLWriter = new Pipeline.PipelineWriter(this) } +@Since("1.6.0") object Pipeline extends MLReadable[Pipeline] { + @Since("1.6.0") override def read: MLReader[Pipeline] = new PipelineReader + @Since("1.6.0") override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends MLReader[Pipeline] { + private class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.Pipeline" + private val className = classOf[Pipeline].getName override def load(path: String): Pipeline = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) @@ -333,18 +337,22 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + @Since("1.6.0") override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } +@Since("1.6.0") object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite + @Since("1.6.0") override def read: MLReader[PipelineModel] = new PipelineModelReader + @Since("1.6.0") override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends MLReader[PipelineModel] { + private class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.PipelineModel" + private val className = classOf[PipelineModel].getName override def load(path: String): PipelineModel = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a3cc49f7f018c..418bbdc9a058f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] ( * * This also does not save the [[parent]] currently. */ + @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } +@Since("1.6.0") object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + @Since("1.6.0") override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + @Since("1.6.0") override def load(path: String): LogisticRegressionModel = super.load(path) /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ - private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends MLWriter with Logging { private case class Data( @@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } - private[classification] class LogisticRegressionModelReader + private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + private val className = classOf[LogisticRegressionModel].getName override def load(path: String): LogisticRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) @@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable { * @return This MultilabelSummarizer */ def add(label: Double, weight: Double = 1.0): this.type = { - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -839,7 +844,7 @@ private class LogisticAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 4969cf42450d2..b9e2144c0ad40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { - private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + private val className = classOf[CountVectorizerModel].getName override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 0e00ef6f2ee20..f7b0f29a27c2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] { private class IDFModelReader extends MLReader[IDFModel] { - private val className = "org.apache.spark.ml.feature.IDFModel" + private val className = classOf[IDFModel].getName override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ed24eabb50444..c2866f5eceff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { - private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + private val className = classOf[MinMaxScalerModel].getName override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1f689c1da1ba9..6d545219ebf49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private class StandardScalerModelReader extends MLReader[StandardScalerModel] { - private val className = "org.apache.spark.ml.feature.StandardScalerModel" + private val className = classOf[StandardScalerModel].getName override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 97a2e4f6d6ca4..5c40c35eeaa48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private class StringIndexerModelReader extends MLReader[StringIndexerModel] { - private val className = "org.apache.spark.ml.feature.StringIndexerModel" + private val className = classOf[StringIndexerModel].getName override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 795b73c4c2121..4d35177ad9b0f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] { } } - private[recommendation] class ALSModelReader extends MLReader[ALSModel] { + private class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.recommendation.ALSModel" + private val className = classOf[ALSModel].getName override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7ba1a60edaf71..70ccec766c471 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + private val className = classOf[LinearRegressionModel].getName override def load(path: String): LinearRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 48ce1bb630685..a9a6ff8a783d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -898,7 +898,7 @@ object LogisticRegressionSuite { "regParam" -> 0.01, "elasticNetParam" -> 0.1, "maxIter" -> 2, // intentionally small - "fitIntercept" -> false, + "fitIntercept" -> true, "tol" -> 0.8, "standardization" -> false, "threshold" -> 0.6 From 1a93323c5bab18ed7e55bf6f7b13aae88cb9721c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 18 Nov 2015 23:32:49 -0800 Subject: [PATCH 0910/2089] [SPARK-11339][SPARKR] Document the list of functions in R base package that are masked by functions with same name in SparkR Added tests for function that are reported as masked, to make sure the base:: or stats:: function can be called. For those we can't call, added them to SparkR programming guide. It would seem to me `table, sample, subset, filter, cov` not working are not actually expected - I investigated/experimented with them but couldn't get them to work. It looks like as they are defined in base or stats they are missing the S3 generic, eg. ``` > methods("transform") [1] transform,ANY-method transform.data.frame [3] transform,DataFrame-method transform.default see '?methods' for accessing help and source code > methods("subset") [1] subset.data.frame subset,DataFrame-method subset.default [4] subset.matrix see '?methods' for accessing help and source code Warning message: In .S3methods(generic.function, class, parent.frame()) : function 'subset' appears not to be S3 generic; found functions that look like S3 methods ``` Any idea? More information on masking: http://www.ats.ucla.edu/stat/r/faq/referencing_objects.htm http://www.sfu.ca/~sweldon/howTo/guide4.pdf This is what the output doc looks like (minus css): ![image](https://cloud.githubusercontent.com/assets/8969467/11229714/2946e5de-8d4d-11e5-94b0-dda9696b6fdd.png) Author: felixcheung Closes #9785 from felixcheung/rmasked. --- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 2 +- R/pkg/R/generics.R | 4 ++-- R/pkg/inst/tests/test_mllib.R | 5 +++++ R/pkg/inst/tests/test_sparkSQL.R | 33 +++++++++++++++++++++++++++- docs/sparkr.md | 37 +++++++++++++++++++++++++++++++- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 34177e3cdd94f..06b0108b1389e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,7 +2152,7 @@ setMethod("with", }) #' Returns the column types of a DataFrame. -#' +#' #' @name coltypes #' @title Get column types of a DataFrame #' @family dataframe_funcs diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ff0f438045c14..25a1f22101494 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2204,7 +2204,7 @@ setMethod("denseRank", #' @export #' @examples \dontrun{lag(df$c)} setMethod("lag", - signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + signature(x = "characterOrColumn"), function(x, offset, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0dcd05438222b..71004a05ba611 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,7 +539,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) # @rdname subset # @export -setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export @@ -790,7 +790,7 @@ setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag #' @export -setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) +setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last #' @export diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index d497ad8c9daa3..e0667e5e22c18 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -31,6 +31,11 @@ test_that("glm and predict", { model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("glm should work with long formula", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d9a94faff7ac0..3f4f319fe745d 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -433,6 +433,10 @@ test_that("table() returns a new DataFrame", { expect_is(tabledf, "DataFrame") expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + + # Test base::table is working + #a <- letters[1:3] + #expect_equal(class(table(a, sample(a))), "table") }) test_that("toRDD() returns an RRDD", { @@ -673,6 +677,9 @@ test_that("sample on a DataFrame", { # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) }) test_that("select operators", { @@ -753,6 +760,9 @@ test_that("subsetting", { df6 <- subset(df, df$age %in% c(30), c(1,2)) expect_equal(count(df6), 1) expect_equal(columns(df6), c("name", "age")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) }) test_that("selectExpr() on a DataFrame", { @@ -888,6 +898,9 @@ test_that("column functions", { expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) }) # test_that("column binary mathfunctions", { @@ -1086,7 +1099,7 @@ test_that("group by, agg functions", { gd3_local <- collect(agg(gd3, var(df8$age))) expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) - # make sure base:: or stats::sd, var are working + # Test stats::sd, stats::var are working expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) @@ -1138,6 +1151,9 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered5), 1) filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) }) test_that("join() and merge() on a DataFrame", { @@ -1284,6 +1300,12 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { expect_is(unioned, "DataFrame") expect_equal(count(intersected), 1) expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) }) test_that("withColumn() and withColumnRenamed()", { @@ -1365,6 +1387,9 @@ test_that("describe() and summarize() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[4, "name"], "Andy") expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + expect_equal(length(summary(attenu, digits = 4)), 35) }) test_that("dropna() and na.omit() on a DataFrame", { @@ -1448,6 +1473,9 @@ test_that("dropna() and na.omit() on a DataFrame", { expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) }) test_that("fillna() on a DataFrame", { @@ -1510,6 +1538,9 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) result <- corr(df, "singles", "doubles", "pearson") expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) }) test_that("freqItems() on a DataFrame", { diff --git a/docs/sparkr.md b/docs/sparkr.md index a744b76be7466..cfb9b41350f45 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,7 +286,7 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). @@ -351,3 +351,38 @@ summary(model) ##Sepal_Width 0.404655 {% endhighlight %}
    + +# R Function Name Conflicts + +When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a +function is masking another function. + +The following functions are masked by the SparkR package: + + + + + + + + + + + + + + + + + + + +
    Masked functionHow to Access
    cov in package:stats
    stats::cov(x, y = NULL, use = "everything",
    +           method = c("pearson", "kendall", "spearman"))
    filter in package:stats
    stats::filter(x, filter, method = c("convolution", "recursive"),
    +              sides = 2, circular = FALSE, init)
    sample in package:basebase::sample(x, size, replace = FALSE, prob = NULL)
    table in package:base
    base::table(...,
    +            exclude = if (useNA == "no") c(NA, NaN),
    +            useNA = c("no", "ifany", "always"),
    +            dnn = list.names(...), deparse.level = 1)
    + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + From f449992009becc8f7c7f06cda522b9beaa1e263c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 10:48:04 -0800 Subject: [PATCH 0911/2089] [SPARK-11849][SQL] Analyzer should replace current_date and current_timestamp with literals We currently rely on the optimizer's constant folding to replace current_timestamp and current_date. However, this can still result in different values for different instances of current_timestamp/current_date if the optimizer is not running fast enough. A better solution is to replace these functions in the analyzer in one shot. Author: Reynold Xin Closes #9833 from rxin/SPARK-11849. --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f00c451b5981a..84781cd57f3dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -65,9 +65,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -84,7 +83,8 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic), + PullOutNondeterministic, + ComputeCurrentTime), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1076,7 +1076,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. - case plan => plan transformExpressionsUp { + case p => p transformExpressionsUp { case udf @ ScalaUDF(func, _, inputs, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) @@ -1162,3 +1162,20 @@ object CleanupAliases extends Rule[LogicalPlan] { } } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 08586a97411ae..e051069951887 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -218,4 +219,41 @@ class AnalysisSuite extends AnalysisTest { udf4) // checkUDF(udf4, expected4) } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = in.analyze.asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = in.analyze.asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } } From 962878843b611fa6229e3ee67bb22e2a4bc283cd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 19 Nov 2015 11:02:17 -0800 Subject: [PATCH 0912/2089] [SPARK-11840][SQL] Restore the 1.5's behavior of planning a single distinct aggregation. The impact of this change is for a query that has a single distinct column and does not have any grouping expression like `SELECT COUNT(DISTINCT a) FROM table` The plan will be changed from ``` AGG-2 (count distinct) Shuffle to a single reducer Partial-AGG-2 (count distinct) AGG-1 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on 1) ``` to the following one (1.5 uses this) ``` AGG-2 AGG-1 (grouping on a) Shuffle to a single reducer Partial-AGG-1(grouping on a) ``` The first plan is more robust. However, to better benchmark the impact of this change, we should use 1.5's plan and use the conf of `spark.sql.specializeSingleDistinctAggPlanning` to control the plan. Author: Yin Huai Closes #9828 from yhuai/distinctRewriter. --- .../sql/catalyst/analysis/DistinctAggregationRewriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index c0c960471a61a..9c78f6d4cc71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,8 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { // When the flag is set to specialize single distinct agg planning, // we will rely on our Aggregation strategy to handle queries with a single - // distinct column and this aggregate operator does have grouping expressions. - distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + // distinct column. + distinctAggGroups.size > 1 } else { distinctAggGroups.size >= 1 } From 72d150c271d2b206148fd0917a0def263445121b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 19 Nov 2015 11:57:50 -0800 Subject: [PATCH 0913/2089] [SPARK-11830][CORE] Make NettyRpcEnv bind to the specified host This PR includes the following change: 1. Bind NettyRpcEnv to the specified host 2. Fix the port information in the log for NettyRpcEnv. 3. Fix the service name of NettyRpcEnv. Author: zsxwing Author: Shixiong Zhu Closes #9821 from zsxwing/SPARK-11830. --- .../src/main/scala/org/apache/spark/SparkEnv.scala | 9 ++++++++- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 7 +++---- .../org/apache/spark/network/TransportContext.java | 8 +++++++- .../spark/network/server/TransportServer.java | 14 ++++++++++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4474a83bedbdb..88df27f733f2a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -258,8 +258,15 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { + val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + AkkaUtils.createActorSystem( + actorSystemName + "ActorSystem", + hostname, + actorSystemPort, + conf, + securityManager + )._1 } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3e0c497969502..3ce359868039b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -102,7 +102,7 @@ private[netty] class NettyRpcEnv( } else { java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(host, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -337,10 +337,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) - (nettyEnv, actualPort) + (nettyEnv, nettyEnv.address.port) } try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 } catch { case NonFatal(e) => nettyEnv.shutdown() @@ -370,7 +370,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { * @param conf Spark configuration. * @param endpointAddress The address where the endpoint is listening. * @param nettyEnv The RpcEnv associated with this ref. - * @param local Whether the referenced endpoint lives in the same process. */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 1b64b863a9fe5..238710d17249a 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -94,7 +94,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8d..baae235e02205 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. + * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? + new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); From 276a7e130252c0e7aba702ae5570b3c4f424b23b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:45:04 -0800 Subject: [PATCH 0914/2089] [SPARK-11633][SQL] LogicalRDD throws TreeNode Exception : Failed to Copy Node When handling self joins, the implementation did not consider the case insensitivity of HiveContext. It could cause an exception as shown in the JIRA: ``` TreeNodeException: Failed to copy node. ``` The fix is low risk. It avoids unnecessary attribute replacement. It should not affect the existing behavior of self joins. Also added the test case to cover this case. Author: gatorsmile Closes #9762 from gatorsmile/joinMakeCopy. --- .../apache/spark/sql/execution/ExistingRDD.scala | 4 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 62620ec642c78..623348f6768a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -74,6 +74,10 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + sqlContext :: Nil + } + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6399b0165c4c3..dd6d06512ff60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1110,6 +1110,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + // This test case is to verify a bug when making a new instance of LogicalRDD. + test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + test("SPARK-10656: completely support special chars") { val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") checkAnswer(df.select(df("*")), Row(1, "a")) From 7d4aba18722727c85893ad8d8f07d4494665dcfc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:46:36 -0800 Subject: [PATCH 0915/2089] [SPARK-11848][SQL] Support EXPLAIN in DataSet APIs When debugging DataSet API, I always need to print the logical and physical plans. I am wondering if we should provide a simple API for EXPLAIN? Author: gatorsmile Closes #9832 from gatorsmile/explainDS. --- .../org/apache/spark/sql/DataFrame.scala | 23 +------------------ .../spark/sql/execution/Queryable.scala | 21 +++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ba4ba18d2122..98358127e2709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -308,27 +308,6 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) // scalastyle:on println - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } - - /** - * Only prints the physical plan to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 9ca383896a09b..e86a52c149a2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -25,6 +26,7 @@ import scala.util.control.NonFatal private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution + def sqlContext: SQLContext override def toString: String = { try { @@ -34,4 +36,23 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(): Unit = explain(extended = false) } From 47d1c2325caaf9ffe31695b6fff529314b8582f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Nov 2015 12:54:25 -0800 Subject: [PATCH 0916/2089] [SPARK-11750][SQL] revert SPARK-11727 and code clean up After some experiment, I found it's not convenient to have separate encoder builders: `FlatEncoder` and `ProductEncoder`. For example, when create encoders for `ScalaUDF`, we have no idea if the type `T` is flat or not. So I revert the splitting change in https://github.com/apache/spark/pull/9693, while still keeping the bug fixes and tests. Author: Wenchen Fan Closes #9726 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/Encoder.scala | 16 +- .../spark/sql/catalyst/ScalaReflection.scala | 354 +++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 19 +- .../sql/catalyst/encoders/FlatEncoder.scala | 50 -- .../catalyst/encoders/ProductEncoder.scala | 452 ------------------ .../sql/catalyst/encoders/RowEncoder.scala | 12 +- .../sql/catalyst/expressions/objects.scala | 7 +- .../sql/catalyst/ScalaReflectionSuite.scala | 68 --- .../encoders/ExpressionEncoderSuite.scala | 218 ++++++++- .../catalyst/encoders/FlatEncoderSuite.scala | 99 ---- .../encoders/ProductEncoderSuite.scala | 156 ------ .../org/apache/spark/sql/GroupedDataset.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 23 +- .../org/apache/spark/sql/functions.scala | 4 +- 14 files changed, 364 insertions(+), 1118 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index d54f2854fb33f..86bb536459035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,14 +45,14 @@ trait Encoder[T] extends Serializable { */ object Encoders { - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 59ccf356f2c48..33ae700706dae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -50,39 +50,29 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor(tpe: `Type`): DataType = tpe match { - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(arrayType)) = tpe - val cls = arrayType match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 1).getClass + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - } - ObjectType(cls) - case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - ObjectType(clazz) - } + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className: String = tpe.erasure.typeSymbol.asClass.fullName + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) + } + } } /** @@ -90,7 +80,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - def arrayClassFor(tpe: `Type`): DataType = { + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -108,6 +98,15 @@ object ScalaReflection extends ScalaReflection { ObjectType(cls) } + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + /** * Returns an expression that can be used to construct an object of type `T` given an input * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes @@ -116,63 +115,33 @@ object ScalaReflection extends ScalaReflection { * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. + * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = - path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = - path - .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) - /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => - getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - val boxedType = optType match { - // For primitive types we must manually box the primitive value. - case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) - case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) - case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) - case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) - case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) - case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) - case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) - case _ => None - } - - boxedType.map { boxedType => - val objectType = ObjectType(boxedType) - WrapOption( - objectType, - NewInstance( - boxedType, - getPath :: Nil, - propagateNull = true, - objectType)) - }.getOrElse { - val className: String = optType.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - val objectType = ObjectType(cls) - - WrapOption(objectType, constructorFor(optType, path)) - } + WrapOption(constructorFor(optType, path)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -231,11 +200,11 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -248,57 +217,52 @@ object ScalaReflection extends ScalaReflection { } primitiveMethod.map { method => - Invoke(getPath, method, dataTypeFor(t)) + Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { - val returnType = dataTypeFor(t) Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), "array", - returnType) + arrayClassFor(elementType)) } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - - val primitiveMethodKey = keyType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } val keyData = Invoke( MapObjects( p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) - val primitiveMethodValue = valueType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - val valueData = Invoke( MapObjects( p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) @@ -308,40 +272,6 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - // Avoid boxing when possible by just wrapping a primitive array. - val primitiveMethod = elementType match { - case _ if nullable => None - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - val arrayData = primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), - "array", - arrayClassFor(elementType)) - } - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -361,8 +291,7 @@ object ScalaReflection extends ScalaReflection { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -370,7 +299,7 @@ object ScalaReflection extends ScalaReflection { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) @@ -388,22 +317,19 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - } } /** Returns expressions for extracting all the fields from the given type. */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe) match { - case s: CreateNamedStruct => s - case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) - } + extractorFor(inputObject, localTypeOf[T]) match { + case s: CreateNamedStruct => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - protected def extractorFor( + private def extractorFor( inputObject: Expression, tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { @@ -491,51 +417,36 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (dataType.isInstanceOf[AtomicType]) { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - val rawMap = inputObject val keys = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + val values = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) NewInstance( classOf[ArrayBasedMapData], - keys :: values :: Nil, + convertedKeys :: convertedValues :: Nil, dataType = MapType(keyDataType, valueDataType, valueNullable)) case t if t <:< localTypeOf[String] => @@ -558,6 +469,7 @@ object ScalaReflection extends ScalaReflection { DateType, "fromJavaDate", inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal, @@ -587,26 +499,24 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if t <:< definitions.IntTpe => - BoundReference(0, IntegerType, false) - case t if t <:< definitions.LongTpe => - BoundReference(0, LongType, false) - case t if t <:< definitions.DoubleTpe => - BoundReference(0, DoubleType, false) - case t if t <:< definitions.FloatTpe => - BoundReference(0, FloatType, false) - case t if t <:< definitions.ShortTpe => - BoundReference(0, ShortType, false) - case t if t <:< definitions.ByteTpe => - BoundReference(0, ByteType, false) - case t if t <:< definitions.BooleanTpe => - BoundReference(0, BooleanType, false) - case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } } } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } } /** @@ -635,8 +545,7 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** * Return the Scala Type for `T` in the current classloader mirror. @@ -736,39 +645,4 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } - - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. - } - - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { - - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence. - */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 456b595008479..6eeba1442c1f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** - * A factory for constructing encoders that convert objects and primitves to and from the + * A factory for constructing encoders that convert objects and primitives to and from the * internal row format using catalyst expressions and code generation. By default, the * expressions used to retrieve values from an input row when producing an object will be created as * follows: @@ -44,20 +44,21 @@ import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) + val fromRowExpression = ScalaReflection.constructorFor[T] new ExpressionEncoder[T]( - extractExpression.dataType, + toRowExpression.dataType, flat, - extractExpression.flatten, - constructExpression, + toRowExpression.flatten, + fromRowExpression, ClassTag[T](cls)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala deleted file mode 100644 index 6d307ab13a9fc..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} -import org.apache.spark.sql.catalyst.ScalaReflection - -object FlatEncoder { - import ScalaReflection.schemaFor - import ScalaReflection.dataTypeFor - - def apply[T : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) - - val input = BoundReference(0, dataTypeFor(tpe), nullable = true) - val toRowExpression = CreateNamedStruct( - Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) - val fromRowExpression = ProductEncoder.constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = true, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 2914c6ee790ce..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import org.apache.spark.util.Utils -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} - -import scala.reflect.ClassTag - -object ProductEncoder { - import ScalaReflection.universe._ - import ScalaReflection.mirror - import ScalaReflection.localTypeOf - import ScalaReflection.dataTypeFor - import ScalaReflection.Schema - import ScalaReflection.schemaFor - import ScalaReflection.arrayClassFor - - def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] - val fromRowExpression = constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = false, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. - case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) - - val unwrapped = UnwrapOption(optionObjectType, inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) - } - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil - }) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case other => - throw new UnsupportedOperationException(s"Encoder for type $other is not supported") - } - } - } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (RowEncoder.isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } - - def constructorFor( - tpe: `Type`, - path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { - - /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) - .getOrElse(BoundReference(ordinal, dataType, false)) - - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) - - tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath - - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - WrapOption(null, constructorFor(optType, path)) - - case t if t <:< localTypeOf[java.lang.Integer] => - val boxedType = classOf[java.lang.Integer] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Long] => - val boxedType = classOf[java.lang.Long] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Double] => - val boxedType = classOf[java.lang.Double] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Float] => - val boxedType = classOf[java.lang.Float] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Short] => - val boxedType = classOf[java.lang.Short] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Byte] => - val boxedType = classOf[java.lang.Byte] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Boolean] => - val boxedType = classOf[java.lang.Boolean] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) - - case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) - } - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keyData = - Invoke( - MapObjects( - p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - MapObjects( - p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType - - // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) - } else { - constructorFor(fieldType, Some(addToPath(fieldName))) - } - } - - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) - - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { - newInstance - } - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 9bb1602494b68..4cda4824acdc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -132,17 +133,8 @@ object RowEncoder { CreateStruct(convertedFields) } - /** - * Returns true if the value of this data type is same between internal and external. - */ - def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true - case _ => false - } - private def externalDataTypeFor(dt: DataType): DataType = dt match { - case _ if isNativeType(dt) => dt + case _ if ScalaReflection.isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f865a9408ef4e..ef7399e0196ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,6 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -300,10 +299,9 @@ case class UnwrapOption( /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. - * @param optionType The datatype to be held inside of the Option. * @param child The expression to evaluate and wrap. */ -case class WrapOption(optionType: DataType, child: Expression) +case class WrapOption(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -316,14 +314,13 @@ case class WrapOption(optionType: DataType, child: Expression) throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val javaType = ctx.javaType(optionType) val inputObject = child.gen(ctx) s""" ${inputObject.code} boolean ${ev.isNull} = false; - scala.Option<$javaType> ${ev.value} = + scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 4ea410d492b01..c2aace1ef238e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) - } - test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index cde0364f3dd9d..76459b34a484f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,24 +17,234 @@ package org.apache.spark.sql.catalyst.encoders +import java.sql.{Timestamp, Date} import java.util.Arrays import java.util.concurrent.ConcurrentMap +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.types.ArrayType -abstract class ExpressionEncoderSuite extends SparkFunSuite { - val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() +case class RepeatedStruct(s: Seq[PrimitiveData]) - protected def encodeDecodeTest[T]( +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} + +/** For testing Java serialization based encoder. */ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} + +class ExpressionEncoderSuite extends SparkFunSuite { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } + + case class InnerClass(i: Int) + productTest(InnerClass(1)) + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, - encoder: ExpressionEncoder[T], testName: String): Unit = { test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema, outers).bind(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala deleted file mode 100644 index 07523d49f4266..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import java.sql.{Date, Timestamp} -import org.apache.spark.sql.Encoders - -class FlatEncoderSuite extends ExpressionEncoderSuite { - encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") - encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") - encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") - encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") - encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") - encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") - encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") - - encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") - - encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") - type JDecimal = java.math.BigDecimal - // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") - - encodeDecodeTest("hello", FlatEncoder[String], "string") - encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") - encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") - encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") - - encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") - encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") - encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") - encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") - encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") - - encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), - FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") - encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), - FlatEncoder[Seq[Seq[String]]], "seq of seq of string") - - encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int") - encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") - encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") - encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") - encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") - - encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), - FlatEncoder[Array[Array[Int]]], "array of array of int") - encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), - FlatEncoder[Array[Array[String]]], "array of array of string") - - encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") - encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") - encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), - FlatEncoder[Map[Int, Map[String, Int]]], "map of map") - - // Kryo encoders - encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") - encodeDecodeTest(new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") - - // Java encoders - encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") - encodeDecodeTest(new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") -} - -/** For testing Kryo serialization based encoder. */ -class KryoSerializable(val value: Int) { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[KryoSerializable].value - } -} - -/** For testing Java serialization based encoder. */ -class JavaSerializable(val value: Int) extends Serializable { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[JavaSerializable].value - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala deleted file mode 100644 index 1798514c5c38b..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) { - override def equals(other: Any): Boolean = other match { - case NestedArray(otherArray) => - java.util.Arrays.deepEquals( - a.asInstanceOf[Array[AnyRef]], - otherArray.asInstanceOf[Array[AnyRef]]) - case _ => false - } -} - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ProductEncoderSuite extends ExpressionEncoderSuite { - outers.put(getClass.getName, this) - - case class InnerClass(i: Int) - productTest(InnerClass(1)) - - productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - productTest( - OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - productTest(OptionalData(None, None, None, None, None, None, None, None)) - - productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - productTest(BoxedData(null, null, null, null, null, null, null)) - - productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) - - productTest(("Seq[(String, String)]", - Seq(("a", "b")))) - productTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - productTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - productTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - productTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - productTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - productTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - productTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - productTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - productTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - productTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - productTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - productTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - productTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - productTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - productTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - productTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTest( - 1 -> 10L, - ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), - "tuple with 2 flat encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), - "tuple with 2 product encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), - "tuple with flat encoder and product encoder") - - encodeDecodeTest( - (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), - ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), - "tuple with product encoder and flat encoder") - - encodeDecodeTest( - (1, (10, 100L)), - { - val intEnc = FlatEncoder[Int] - val longEnc = FlatEncoder[Long] - ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) - }, - "nested tuple encoder") - - private def productTest[T <: Product : TypeTag](input: T): Unit = { - encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7e5acbe8517d1..6de3dd626576a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -242,7 +242,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) /** * Applies the given function to each cogrouped data. For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 8471eea1b7d9c..25ffdcde17717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -28,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.StructField import org.apache.spark.unsafe.types.UTF8String @@ -37,16 +34,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() - implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] - implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] - implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] - implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] - implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] - implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] - implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] - implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** * Creates a [[Dataset]] from an RDD. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 95158de710acf..b27b1340cce46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.FlatEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(FlatEncoder[Long]) + count(Column(columnName)).as(ExpressionEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. From 4700074530d9a398843e13f0ef514be97a237cea Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 19 Nov 2015 13:08:01 -0800 Subject: [PATCH 0917/2089] [SPARK-11778][SQL] parse table name before it is passed to lookupRelation Fix a bug in DataFrameReader.table (table with schema name such as "db_name.table" doesn't work) Use SqlParser.parseTableIdentifier to parse the table name before lookupRelation. Author: Huaxin Gao Closes #9773 from huaxingao/spark-11778. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../spark/sql/hive/HiveDataFrameAnalyticsSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5872fbded3833..dcb3737b70fbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -313,7 +313,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) + DataFrame(sqlContext, + sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf765265..f19a74d4b3724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,10 +34,14 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -74,4 +78,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } + + // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, + // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException + test("table name with schema") { + hiveContext.read.table("usrdb.test") + } } From 599a8c6e2bf7da70b20ef3046f5ce099dfd637f8 Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Thu, 19 Nov 2015 13:57:23 -0800 Subject: [PATCH 0918/2089] [SPARK-11812][PYSPARK] invFunc=None works properly with python's reduceByKeyAndWindow invFunc is optional and can be None. Instead of invFunc (the parameter) invReduceFunc (a local function) was checked for trueness (that is, not None, in this context). A local function is never None, thus the case of invFunc=None (a common one when inverse reduction is not defined) was treated incorrectly, resulting in loss of data. In addition, the docstring used wrong parameter names, also fixed. Author: David Tolpin Closes #9775 from dtolpin/master. --- python/pyspark/streaming/dstream.py | 6 +++--- python/pyspark/streaming/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce18d..acec850f02c2d 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -556,7 +556,7 @@ def invReduceFunc(t, a, b): if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0bcd1f15532b5..3403f6d20d789 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -582,6 +582,17 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): From 014c0f7a9dfdb1686fa9aeacaadb2a17a855a943 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 14:48:18 -0800 Subject: [PATCH 0919/2089] [SPARK-11858][SQL] Move sql.columnar into sql.execution. In addition, tightened visibility of a lot of classes in the columnar package from private[sql] to private[columnar]. Author: Reynold Xin Closes #9842 from rxin/SPARK-11858. --- .../spark/sql/execution/CacheManager.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../columnar/ColumnAccessor.scala | 42 +++++++-------- .../columnar/ColumnBuilder.scala | 51 ++++++++++--------- .../columnar/ColumnStats.scala | 34 ++++++------- .../{ => execution}/columnar/ColumnType.scala | 48 ++++++++--------- .../columnar/GenerateColumnAccessor.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 5 +- .../columnar/NullableColumnAccessor.scala | 4 +- .../columnar/NullableColumnBuilder.scala | 4 +- .../CompressibleColumnAccessor.scala | 6 +-- .../CompressibleColumnBuilder.scala | 6 +-- .../compression/CompressionScheme.scala | 16 +++--- .../compression/compressionSchemes.scala | 16 +++--- .../apache/spark/sql/execution/package.scala | 2 + .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/QueryTest.scala | 2 +- .../columnar/ColumnStatsSuite.scala | 6 +-- .../columnar/ColumnTypeSuite.scala | 4 +- .../columnar/ColumnarTestUtils.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 6 +-- .../compression/DictionaryEncodingSuite.scala | 6 +-- .../compression/IntegralDeltaSuite.scala | 6 +-- .../compression/RunLengthEncodingSuite.scala | 6 +-- .../TestCompressibleColumnBuilder.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 30 files changed, 155 insertions(+), 147 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnAccessor.scala (75%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnBuilder.scala (74%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStats.scala (88%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnType.scala (93%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/GenerateColumnAccessor.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarTableScan.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessor.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnBuilder.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnAccessor.scala (84%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnBuilder.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressionScheme.scala (83%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/compressionSchemes.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStatsSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnTypeSuite.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnarTestUtils.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessorSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnBuilderSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/PartitionBatchPruningSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/BooleanBitSetSuite.scala (94%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/DictionaryEncodingSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/IntegralDeltaSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/RunLengthEncodingSuite.scala (95%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/TestCompressibleColumnBuilder.scala (93%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f85aeb1b02694..293fcfe96e677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3d4ce633c07c9..f67c951bc0663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.{Strategy, execution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 42ec4d3433f16..fee36f6023895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} -import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods * for primitive values provided by [[MutableRow]]. */ -private[sql] trait ColumnAccessor { +private[columnar] trait ColumnAccessor { initialize() protected def initialize() @@ -41,7 +41,7 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[JvmType]( +private[columnar] abstract class BasicColumnAccessor[JvmType]( protected val buffer: ByteBuffer, protected val columnType: ColumnType[JvmType]) extends ColumnAccessor { @@ -61,65 +61,65 @@ private[sql] abstract class BasicColumnAccessor[JvmType]( protected def underlyingBuffer = buffer } -private[sql] class NullColumnAccessor(buffer: ByteBuffer) +private[columnar] class NullColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Any](buffer, NULL) with NullableColumnAccessor -private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( +private[columnar] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor with CompressibleColumnAccessor[T] -private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) +private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) +private[columnar] class ByteColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BYTE) -private[sql] class ShortColumnAccessor(buffer: ByteBuffer) +private[columnar] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) +private[columnar] class IntColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, INT) -private[sql] class LongColumnAccessor(buffer: ByteBuffer) +private[columnar] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class FloatColumnAccessor(buffer: ByteBuffer) +private[columnar] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) +private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, DOUBLE) -private[sql] class StringColumnAccessor(buffer: ByteBuffer) +private[columnar] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) +private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor -private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) with NullableColumnAccessor -private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) +private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) with NullableColumnAccessor -private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) +private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) with NullableColumnAccessor -private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) +private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[sql] object ColumnAccessor { +private[columnar] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 599f30f2d73b4..7e26f19bb7449 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.execution.columnar.ColumnBuilder._ +import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ -private[sql] trait ColumnBuilder { +private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. */ @@ -46,7 +46,7 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[JvmType]( +private[columnar] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, val columnType: ColumnType[JvmType]) extends ColumnBuilder { @@ -84,17 +84,17 @@ private[sql] class BasicColumnBuilder[JvmType]( } } -private[sql] class NullColumnBuilder +private[columnar] class NullColumnBuilder extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) with NullableColumnBuilder -private[sql] abstract class ComplexColumnBuilder[JvmType]( +private[columnar] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( +private[columnar] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T#InternalType](columnStats, columnType) @@ -102,40 +102,45 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[columnar] +class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[columnar] +class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) +private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) +private[columnar] +class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[columnar] +class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) +private[columnar] +class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) -private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType) +private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnBuilder(dataType: DecimalType) +private[columnar] class DecimalColumnBuilder(dataType: DecimalType) extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) -private[sql] class StructColumnBuilder(dataType: StructType) +private[columnar] class StructColumnBuilder(dataType: StructType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) -private[sql] class ArrayColumnBuilder(dataType: ArrayType) +private[columnar] class ArrayColumnBuilder(dataType: ArrayType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) -private[sql] class MapColumnBuilder(dataType: MapType) +private[columnar] class MapColumnBuilder(dataType: MapType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) -private[sql] object ColumnBuilder { +private[columnar] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 91a05650585cf..c52ee9ffd6d2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { +private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() @@ -32,7 +32,7 @@ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } -private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { +private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { val (forAttribute, schema) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) @@ -45,10 +45,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed trait ColumnStats extends Serializable { +private[columnar] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - private[sql] var sizeInBytes = 0L + private[columnar] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. @@ -72,14 +72,14 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * A no-op ColumnStats only used for testing purposes. */ -private[sql] class NoopColumnStats extends ColumnStats { +private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } -private[sql] class BooleanColumnStats extends ColumnStats { +private[columnar] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true @@ -97,7 +97,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ByteColumnStats extends ColumnStats { +private[columnar] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue @@ -115,7 +115,7 @@ private[sql] class ByteColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ShortColumnStats extends ColumnStats { +private[columnar] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue @@ -133,7 +133,7 @@ private[sql] class ShortColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class IntColumnStats extends ColumnStats { +private[columnar] class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue @@ -151,7 +151,7 @@ private[sql] class IntColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class LongColumnStats extends ColumnStats { +private[columnar] class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue @@ -169,7 +169,7 @@ private[sql] class LongColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class FloatColumnStats extends ColumnStats { +private[columnar] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue @@ -187,7 +187,7 @@ private[sql] class FloatColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class DoubleColumnStats extends ColumnStats { +private[columnar] class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue @@ -205,7 +205,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class StringColumnStats extends ColumnStats { +private[columnar] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null @@ -223,7 +223,7 @@ private[sql] class StringColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class BinaryColumnStats extends ColumnStats { +private[columnar] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { @@ -235,7 +235,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } -private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null @@ -256,7 +256,7 @@ private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends Column new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 68e509eb5047d..c9f2329db4b6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * * WARNNING: This only works with HeapByteBuffer */ -object ByteBufferHelper { +private[columnar] object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -73,7 +73,7 @@ object ByteBufferHelper { * * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[JvmType] { +private[columnar] sealed abstract class ColumnType[JvmType] { // The catalyst data type of this column. def dataType: DataType @@ -142,7 +142,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] object NULL extends ColumnType[Any] { +private[columnar] object NULL extends ColumnType[Any] { override def dataType: DataType = NullType override def defaultSize: Int = 0 @@ -152,7 +152,7 @@ private[sql] object NULL extends ColumnType[Any] { override def getField(row: InternalRow, ordinal: Int): Any = null } -private[sql] abstract class NativeColumnType[T <: AtomicType]( +private[columnar] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) extends ColumnType[T#InternalType] { @@ -163,7 +163,7 @@ private[sql] abstract class NativeColumnType[T <: AtomicType]( def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } -private[sql] object INT extends NativeColumnType(IntegerType, 4) { +private[columnar] object INT extends NativeColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -192,7 +192,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } } -private[sql] object LONG extends NativeColumnType(LongType, 8) { +private[columnar] object LONG extends NativeColumnType(LongType, 8) { override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } @@ -220,7 +220,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } } -private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { +private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } @@ -248,7 +248,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } } -private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { +private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } @@ -276,7 +276,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } } -private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { +private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def append(v: Boolean, buffer: ByteBuffer): Unit = { buffer.put(if (v) 1: Byte else 0: Byte) } @@ -302,7 +302,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { } } -private[sql] object BYTE extends NativeColumnType(ByteType, 1) { +private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } @@ -330,7 +330,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) { } } -private[sql] object SHORT extends NativeColumnType(ShortType, 2) { +private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } @@ -362,7 +362,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper * objects. */ -private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { @@ -387,7 +387,7 @@ private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { } } -private[sql] object STRING +private[columnar] object STRING extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { override def actualSize(row: InternalRow, ordinal: Int): Int = { @@ -425,7 +425,7 @@ private[sql] object STRING override def clone(v: UTF8String): UTF8String = v.clone() } -private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { @@ -467,13 +467,13 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) } } -private[sql] object COMPACT_DECIMAL { +private[columnar] object COMPACT_DECIMAL { def apply(dt: DecimalType): COMPACT_DECIMAL = { COMPACT_DECIMAL(dt.precision, dt.scale) } } -private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] @@ -492,7 +492,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: } } -private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType @@ -512,7 +512,7 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } -private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayColumnType[Decimal](12) { override val dataType: DataType = DecimalType(precision, scale) @@ -539,13 +539,13 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) } } -private[sql] object LARGE_DECIMAL { +private[columnar] object LARGE_DECIMAL { def apply(dt: DecimalType): LARGE_DECIMAL = { LARGE_DECIMAL(dt.precision, dt.scale) } } -private[sql] case class STRUCT(dataType: StructType) +private[columnar] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size @@ -586,7 +586,7 @@ private[sql] case class STRUCT(dataType: StructType) override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) +private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -625,7 +625,7 @@ private[sql] case class ARRAY(dataType: ArrayType) override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) +private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -663,7 +663,7 @@ private[sql] case class MAP(dataType: MapType) override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } -private[sql] object ColumnType { +private[columnar] object ColumnType { def apply(dataType: DataType): ColumnType[_] = { dataType match { case NullType => NULL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff9393b465b7a..eaafc96e4d2e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -121,7 +121,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; - import org.apache.spark.sql.columnar.MutableUnsafeRow; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index ae77298e6da2f..ce701fb3a7f28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer @@ -50,7 +50,8 @@ private[sql] object InMemoryRelation { * @param buffers The buffers for serialized columns * @param stats The stat of columns */ -private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 7eaecfe047c3f..8d99546924de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.expressions.MutableRow -private[sql] trait NullableColumnAccessor extends ColumnAccessor { +private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ private var nullCount: Int = _ private var seenNulls: Int = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala index 76cfddf1cd01a..3a1931bfb5c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow * +---+-----+---------+ * }}} */ -private[sql] trait NullableColumnBuilder extends ColumnBuilder { +private[columnar] trait NullableColumnBuilder extends ColumnBuilder { protected var nulls: ByteBuffer = _ protected var nullCount: Int = _ private var pos: Int = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index cb205defbb1ad..6579b5068e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { +private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 161021ff96154..b0e216feb5595 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.execution.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType /** @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.AtomicType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: AtomicType] +private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 9322b772fd898..920381f9c63d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: AtomicType] { +private[columnar] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int @@ -37,13 +37,13 @@ private[sql] trait Encoder[T <: AtomicType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: AtomicType] { +private[columnar] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean } -private[sql] trait CompressionScheme { +private[columnar] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_]): Boolean @@ -53,15 +53,15 @@ private[sql] trait CompressionScheme { def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } -private[sql] trait WithCompressionSchemes { +private[columnar] trait WithCompressionSchemes { def schemes: Seq[CompressionScheme] } -private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { +private[columnar] trait AllCompressionSchemes extends WithCompressionSchemes { override val schemes: Seq[CompressionScheme] = CompressionScheme.all } -private[sql] object CompressionScheme { +private[columnar] object CompressionScheme { val all: Seq[CompressionScheme] = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 41c9a284e3e4a..941f03b745a07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer @@ -23,11 +23,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ -private[sql] case object PassThrough extends CompressionScheme { +private[columnar] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_]): Boolean = true @@ -64,7 +64,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } -private[sql] case object RunLengthEncoding extends CompressionScheme { +private[columnar] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { @@ -172,7 +172,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } -private[sql] case object DictionaryEncoding extends CompressionScheme { +private[columnar] case object DictionaryEncoding extends CompressionScheme { override val typeId = 2 // 32K unique values allowed @@ -281,7 +281,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } -private[sql] case object BooleanBitSet extends CompressionScheme { +private[columnar] case object BooleanBitSet extends CompressionScheme { override val typeId = 3 val BITS_PER_LONG = 64 @@ -371,7 +371,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } } -private[sql] case object IntDelta extends CompressionScheme { +private[columnar] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -451,7 +451,7 @@ private[sql] case object IntDelta extends CompressionScheme { } } -private[sql] case object LongDelta extends CompressionScheme { +private[columnar] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 28fa231e722d0..c912734bba9e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -19,5 +19,7 @@ package org.apache.spark.sql /** * The physical execution component of Spark SQL. Note that this is a private package. + * All classes in catalyst are considered an internal API to Spark SQL and are subject + * to change between minor releases. */ package object execution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index bce94dafad755..d86df4cfb9b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,7 +27,7 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -280,7 +280,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sql("CACHE TABLE testData") sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => - val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5417b195f396..6ea1fe4ccfd89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 89a664001bdd2..b2d04f7c5a6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -50,7 +50,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -86,7 +86,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 63bc39bfa0307..34dd96929e6c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ import org.apache.spark.{Logging, SparkFunSuite} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index a5882f7870e37..9cae65ef6f5dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 6265e40a0a07b..25afed25c897b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.sql.{Date, Timestamp} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index aa1605fee8c73..35dc9a276cef7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer @@ -38,7 +38,7 @@ object TestNullableColumnAccessor { } class NullableColumnAccessorSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 91404577832a0..93be3e16a5ed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -36,7 +36,7 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 6b7401464f46f..d762f7bfe914c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 9a2948c59ba42..ccbddef0fad3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index acfab6586c0d1..830ca0294e1b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 2111e9fbe62cb..988a577a7b4d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 67ec08f594a43..ce3affba55c71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5268dfe0aa03e..5e078f251375a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types.AtomicType class TestCompressibleColumnBuilder[T <: AtomicType]( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 5c2fc7d82ffbd..99478e82d419f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} From 90d384dcbc1d1a3466cf8bae570a26f23012c102 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 19 Nov 2015 14:49:25 -0800 Subject: [PATCH 0920/2089] [SPARK-11831][CORE][TESTS] Use port 0 to avoid port conflicts in tests Use port 0 to fix port-contention-related flakiness Author: Shixiong Zhu Closes #9841 from zsxwing/SPARK-11831. --- .../org/apache/spark/rpc/RpcEnvSuite.scala | 24 +++++++++---------- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 834e4743df866..2f55006420ce1 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -39,7 +39,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() - env = createRpcEnv(conf, "local", 12345) + env = createRpcEnv(conf, "local", 0) } override def afterAll(): Unit = { @@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -497,7 +497,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") @@ -543,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") @@ -571,8 +571,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { @volatile var message: String = null @@ -602,8 +602,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 6478ab51c4da2..7aac02775e1bf 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -40,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) + RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -59,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) + RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) From 3bd77b213a9cd177c3ea3c61d37e5098e55f75a5 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 19 Nov 2015 14:51:40 -0800 Subject: [PATCH 0921/2089] =?UTF-8?q?[SPARK-11799][CORE]=20Make=20it=20exp?= =?UTF-8?q?licit=20in=20executor=20logs=20that=20uncaught=20e=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …xceptions are thrown during executor shutdown This commit will make sure that when uncaught exceptions are prepended with [Container in shutdown] when JVM is shutting down. Author: Srinivasa Reddy Vundela Closes #9809 from vundela/master_11799. --- .../apache/spark/util/SparkUncaughtExceptionHandler.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index 7248187247330..5e322557e9649 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler override def uncaughtException(thread: Thread, exception: Throwable) { try { - logError("Uncaught exception in thread " + thread, exception) + // Make it explicit that uncaught exceptions are thrown when container is shutting down. + // It will help users when they analyze the executor logs + val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else "" + val errMsg = "Uncaught exception in thread " + logError(inShutdownMsg + errMsg + thread, exception) // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) From f7135ed7194d4f936f6f58e14f02b1ed93f68ad1 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 14:53:58 -0800 Subject: [PATCH 0922/2089] [SPARK-11828][CORE] Register DAGScheduler metrics source after app id is known. Author: Marcelo Vanzin Closes #9820 from vanzin/SPARK-11828. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ab374cb71286a..af4456c05b0a1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -581,6 +581,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(_dagScheduler.metricsSource) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a9518fff4e7b..ae725b467d8c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -130,7 +130,7 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() @@ -1580,8 +1580,6 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread and register the metrics source at the end of the constructor - env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } From 01403aa97b6aaab9b86ae806b5ea9e82690a741f Mon Sep 17 00:00:00 2001 From: hushan Date: Thu, 19 Nov 2015 14:56:00 -0800 Subject: [PATCH 0923/2089] [SPARK-11746][CORE] Use cache-aware method dependencies a small change Author: hushan Closes #9691 from suyanNone/unify-getDependency. --- .../main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index d6a37e8cc5dac..0c6ddda52cee9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag]( } override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + dependencies.head.asInstanceOf[PruneDependency[T]].partitions } From 37cff1b1a79cad11277612cb9bc8bc2365cf5ff2 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 19 Nov 2015 15:11:30 -0800 Subject: [PATCH 0924/2089] [SPARK-11275][SQL] Incorrect results when using rollup/cube Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result. Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer. Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite: ``` build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite' ``` This is an alternative to pr https://github.com/apache/spark/pull/9419 but I think its better as it simplifies the analyzer rule instead of adding another special case to it. Author: Andrew Ray Closes #9815 from aray/groupingset-agg-fix. --- .../sql/catalyst/analysis/Analyzer.scala | 58 +++++++---------- .../plans/logical/basicOperators.scala | 4 ++ .../spark/sql/DataFrameAggregateSuite.scala | 62 +++++++++++++++++++ 3 files changed, 90 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 84781cd57f3dc..47962ebe6ef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -213,45 +213,35 @@ class Analyzer( GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - // We will insert another Projection if the GROUP BY keys contains the - // non-attribute expressions. And the top operators can references those - // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a - - // find all of the non-attribute expressions in the GROUP BY keys - val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() - - // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { - case e: NamedExpression => (e, e.toAttribute) - case other => { - val alias = Alias(other, other.toString)() - nonAttributeGroupByExpressions += alias // add the non-attributes expression alias - (other, alias.toAttribute) - } - }) - - // substitute the non-attribute expressions for aggregations. - val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) - }.asInstanceOf[NamedExpression]) - // substitute the group by expressions. - val newGroupByExprs = groupByExprPairs.map(_._2) + // Expand works by setting grouping expressions to null as determined by the bitmasks. To + // prevent these null values from being used in an aggregate instead of the original value + // we need to create new aliases for all group by expressions that will only be used for + // the intended purpose. + val groupByAliases: Seq[Alias] = x.groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } - val child = if (nonAttributeGroupByExpressions.length > 0) { - // insert additional projection if contains the - // non-attribute expressions in the GROUP BY keys - Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) - } else { - x.child + val aggregations: Seq[NamedExpression] = x.aggregations.map { + // If an expression is an aggregate (contains a AggregateExpression) then we dont change + // it so that the aggregation is computed on the unmodified value of its argument + // expressions. + case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr + // If not then its a grouping expression and we need to use the modified (with nulls from + // Expand) value of the expression. + case expr => expr.transformDown { + case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + }.asInstanceOf[NamedExpression] } + val child = Project(x.child.output ++ groupByAliases, x.child) + val groupByAttributes = groupByAliases.map(_.toAttribute) + Aggregate( - newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + groupByAttributes :+ VirtualColumn.groupingIdAttribute, + aggregations, + Expand(x.bitmasks, groupByAttributes, gid, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 45630a591d349..0c444482c5e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode { override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. + override lazy val resolved: Boolean = false + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 71adf2148a403..9c42f65bb6f52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("rollup") { + checkAnswer( + courseSales.rollup("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + courseSales.cube("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("rollup overlapping columns") { + checkAnswer( + testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.rollup("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, null, 9) :: Nil + ) + } + + test("cube overlapping columns") { + checkAnswer( + testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, 1, 3) :: Row(null, 2, 0) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.cube("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, 1, 3) :: Row(null, 2, 6) + :: Row(null, null, 9) :: Nil + ) + } + test("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), From 880128f37e1bc0b9d98d1786670be62a06c648f2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 16:49:18 -0800 Subject: [PATCH 0925/2089] [SPARK-4134][CORE] Lower severity of some executor loss logs. Don't log ERROR messages when executors are explicitly killed or when the exit reason is not yet known. Author: Marcelo Vanzin Closes #9780 from vanzin/SPARK-11789. --- .../spark/scheduler/ExecutorLossReason.scala | 2 + .../spark/scheduler/TaskSchedulerImpl.scala | 44 ++++++++++++------- .../spark/scheduler/TaskSetManager.scala | 1 + .../CoarseGrainedSchedulerBackend.scala | 18 +++++--- .../spark/deploy/yarn/YarnAllocator.scala | 4 +- 5 files changed, 45 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 47a5cbff4930b..7e1197d742802 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -40,6 +40,8 @@ private[spark] object ExecutorExited { } } +private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.") + /** * A loss reason that means we don't yet know why the executor exited. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bf0419db1f75e..bdf19f9f277d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -470,25 +470,25 @@ private[spark] class TaskSchedulerImpl( synchronized { if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - executorIdToHost.get(executorId) match { - case Some(_) => - // If the host mapping still exists, it means we don't know the loss reason for the - // executor. So call removeExecutor() to update tasks running on that executor when - // the real loss reason is finally known. - logError(s"Actual reason for lost executor $executorId: ${reason.message}") - removeExecutor(executorId, reason) - - case None => - // We may get multiple executorLost() calls with different loss reasons. For example, - // one may be triggered by a dropped connection from the slave while another may be a - // report of executor termination from Mesos. We produce log messages for both so we - // eventually report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } + executorIdToHost.get(executorId) match { + case Some(hostPort) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. + logError(s"Lost an executor $executorId (already removed): $reason") + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -498,6 +498,18 @@ private[spark] class TaskSchedulerImpl( } } + private def logExecutorLoss( + executorId: String, + hostPort: String, + reason: ExecutorLossReason): Unit = reason match { + case LossReasonPending => + logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") + case ExecutorKilled => + logInfo(s"Executor $executorId on $hostPort killed by driver.") + case _ => + logError(s"Lost executor $executorId on $hostPort: $reason") + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 114468c48c44c..a02f3017cb6e9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -800,6 +800,7 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp + case ExecutorKilled => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6f0c910c009a5..505c161141c88 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -64,8 +64,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val listenerBus = scheduler.sc.listenerBus - // Executors we have requested the cluster manager to kill that have not died yet - private val executorsPendingToRemove = new HashSet[String] + // Executors we have requested the cluster manager to kill that have not died yet; maps + // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't + // be considered an app-related failure). + private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it protected var hostToLocalTaskCount: Map[String, Int] = Map.empty @@ -250,15 +252,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { + val killed = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId - executorsPendingToRemove -= executorId executorsPendingLossReason -= executorId + executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, reason) + scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") @@ -459,6 +461,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones * @param force whether to force kill busy executors @@ -479,7 +485,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsPendingToRemove ++= executorsToKill + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7e39c3ea56af3..73cd9031f0250 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -481,7 +481,7 @@ private[yarn] class YarnAllocator( (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) - case unknown => + case _ => numExecutorsFailed += 1 (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + @@ -493,7 +493,7 @@ private[yarn] class YarnAllocator( } else { logInfo(containerExitReason) } - ExecutorExited(0, exitCausedByApp, containerExitReason) + ExecutorExited(exitStatus, exitCausedByApp, containerExitReason) } else { // If we have already released this container, then it must mean // that the driver has explicitly requested it to be killed From b2cecb80ece59a1c086d4ae7aeebef445a4e7299 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 19 Nov 2015 16:50:08 -0800 Subject: [PATCH 0926/2089] [SPARK-11845][STREAMING][TEST] Added unit test to verify TrackStateRDD is correctly checkpointed To make sure that all lineage is correctly truncated for TrackStateRDD when checkpointed. Author: Tathagata Das Closes #9831 from tdas/SPARK-11845. --- .../org/apache/spark/CheckpointSuite.scala | 411 +++++++++--------- .../streaming/rdd/TrackStateRDDSuite.scala | 60 ++- 2 files changed, 267 insertions(+), 204 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 119e5fc28e412..ab23326c6c25d 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,17 +21,223 @@ import java.io.File import scala.reflect.ClassTag +import org.apache.spark.CheckpointSuite._ import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +trait RDDCheckpointTester { self: SparkFunSuite => + + protected val partitioner = new HashPartitioner(2) + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + + /** Implementations of this trait must implement this method */ + protected def sparkContext: SparkContext + + /** + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduce after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). + * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + + // Find serialized sizes before and after the checkpoint + logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + checkpoint(operatedRDD, reliableCheckpoint) + val result = collectFunc(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the checkpoint file has been created + if (reliableCheckpoint) { + assert( + collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + } + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only those RDD whose partitions refer to parent RDD's + * partitions (i.e., do not call it on simple RDD like MappedRDD). + * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDDs = operatedRDD.dependencies.map(_.rdd) + val rddType = operatedRDD.getClass.getSimpleName + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + // Find serialized sizes before and after the checkpoint + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } + val result = collectFunc(operatedRDD) // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) + } + + /** + * Serialize and deserialize an object. This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + protected def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } + + /** + * Recursively force the initialization of the all members of an RDD and it parents. + */ + private def initializeRdd(rdd: RDD[_]): Unit = { + rdd.partitions // forces the initialization of the partitions + rdd.dependencies.map(_.rdd).foreach(initializeRdd) + } + + /** Checkpoint the RDD either locally or reliably. */ + protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + protected def runTest(name: String)(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + test(name + " [local checkpoint]")(body(false)) + } + + /** + * Generate an RDD such that both the RDD and its partitions have large size. + */ + protected def generateFatRDD(): RDD[Int] = { + new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x) + } + + /** + * Generate an pair RDD (with partitioner) such that both the RDD and its partitions + * have large size. + */ + protected def generateFatPairRDD(): RDD[(Int, Int)] = { + new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + } +} + /** * Test suite for end-to-end checkpointing functionality. * This tests both reliable checkpoints and local checkpoints. */ -class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext { private var checkpointDir: File = _ - private val partitioner = new HashPartitioner(2) override def beforeEach(): Unit = { super.beforeEach() @@ -46,6 +252,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) @@ -250,204 +458,6 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } - - // Utility test methods - - /** Checkpoint the RDD either locally or reliably. */ - private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { - if (reliableCheckpoint) { - rdd.checkpoint() - } else { - rdd.localCheckpoint() - } - } - - /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - private def runTest(name: String)(body: Boolean => Unit): Unit = { - test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) - } - - private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() - - /** - * Test checkpointing of the RDD generated by the given operation. It tests whether the - * serialized size of the RDD is reduce after checkpointing or not. This function should be called - * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). - * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDD[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - val partitionsBeforeCheckpoint = operatedRDD.partitions - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - checkpoint(operatedRDD, reliableCheckpoint) - val result = collectFunc(operatedRDD) - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the checkpoint file has been created - if (reliableCheckpoint) { - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) - } - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed from its earlier partitions - assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the RDD has reduced. - logInfo("Size of " + rddType + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * the generated RDD will remember the partitions and therefore potentially the whole lineage. - * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). - * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDDPartitions[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDDs = operatedRDD.dependencies.map(_.rdd) - val rddType = operatedRDD.getClass.getSimpleName - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - // checkpoint the parent RDD, not the generated one - parentRDDs.foreach { rdd => - checkpoint(rdd, reliableCheckpoint) - } - val result = collectFunc(operatedRDD) // force checkpointing - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the partitions has reduced - logInfo("Size of partitions of " + rddType + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") - assert( - partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" - ) - } - - /** - * Generate an RDD such that both the RDD and its partitions have large size. - */ - private def generateFatRDD(): RDD[Int] = { - new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) - } - - /** - * Generate an pair RDD (with partitioner) such that both the RDD and its partitions - * have large size. - */ - private def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. - */ - private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - val rddSize = Utils.serialize(rdd).size - val rddCpDataSize = Utils.serialize(rdd.checkpointData).size - val rddPartitionSize = Utils.serialize(rdd.partitions).size - val rddDependenciesSize = Utils.serialize(rdd.dependencies).size - - // Print detailed size, helps in debugging - logInfo("Serialized sizes of " + rdd + - ": RDD = " + rddSize + - ", RDD checkpoint data = " + rddCpDataSize + - ", RDD partitions = " + rddPartitionSize + - ", RDD dependencies = " + rddDependenciesSize - ) - // this makes sure that serializing the RDD's checkpoint data does not - // serialize the whole RDD as well - assert( - rddSize > rddCpDataSize, - "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + - "whole RDD with checkpoint data (" + rddSize + ")" - ) - (rddSize - rddCpDataSize, rddPartitionSize) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - private def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } - - /** - * Recursively force the initialization of the all members of an RDD and it parents. - */ - private def initializeRdd(rdd: RDD[_]): Unit = { - rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd) - } - } /** RDD partition that has large serialized size. */ @@ -494,5 +504,4 @@ object CheckpointSuite { part ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } - } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 19ef5a14f8ab4..0feb3af1abb0f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -17,31 +17,40 @@ package org.apache.spark.streaming.rdd +import java.io.File + import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.util.OpenHashMapBasedStateMap -import org.apache.spark.streaming.{Time, State} -import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.util.Utils -class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { +class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { private var sc: SparkContext = null + private var checkpointDir: File = _ override def beforeAll(): Unit = { sc = new SparkContext( new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) } override def afterAll(): Unit = { if (sc != null) { sc.stop() } + Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + test("creation from pair RDD") { val data = Seq((1, "1"), (2, "2"), (3, "3")) val partitioner = new HashPartitioner(10) @@ -278,6 +287,51 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { rdd7, Seq(("k3", 2)), Set()) } + test("checkpointing") { + /** + * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs - + * the data RDD and the parent TrackStateRDD. + */ + def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) } + .collect.toSet + } + + /** Generate TrackStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : TrackStateRDD[Int, Int, Int, Int] = { + TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate TrackStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = { + + // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent + new TrackStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From ee21407747fb00db2f26d1119446ccbb20c19232 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 19 Nov 2015 17:14:10 -0800 Subject: [PATCH 0927/2089] [SPARK-11864][SQL] Improve performance of max/min This PR has the following optimization: 1) The greatest/least already does the null-check, so the `If` and `IsNull` are not necessary. 2) In greatest/least, it should initialize the result using the first child (removing one block). 3) For primitive types, the generated greater expression is too complicated (`a > b ? 1 : (a < b) ? -1 : 0) > 0`), should be as simple as `a > b` Combine these optimization, this could improve the performance of `ss_max` query by 30%. Author: Davies Liu Closes #9846 from davies/improve_max. --- .../catalyst/expressions/aggregate/Max.scala | 5 +-- .../catalyst/expressions/aggregate/Min.scala | 5 +-- .../expressions/codegen/CodeGenerator.scala | 12 ++++++ .../expressions/conditionalExpressions.scala | 38 +++++++++++-------- .../expressions/nullExpressions.scala | 10 +++-- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 61cae44cd0f5b..906003188d4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -46,13 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + /* max = */ Greatest(Seq(max, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val greatest = Greatest(Seq(max.left, max.right)) Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + /* max = */ Greatest(Seq(max.left, max.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 242456d9e2e18..39f7afbd081cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -47,13 +47,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + /* min = */ Least(Seq(min, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val least = Least(Seq(min.left, min.right)) Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + /* min = */ Least(Seq(min.left, min.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1718cfbd35332..1b7260cdfe515 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -329,6 +329,18 @@ class CodeGenContext { throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } + /** + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ + def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" + case _ => s"(${genComp(dataType, c1, c2)}) > 0" + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 0d4af43978ea1..694a2a7c54a90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -348,19 +348,22 @@ case class Least(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} < 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, ev.value, eval.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } @@ -403,19 +406,22 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} > 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, eval.value, ev.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 94deafb75b69c..df4747d4e6f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -62,11 +62,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val first = children(0) + val rest = children.drop(1) + val firstEval = first.gen(ctx) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${firstEval.code} + boolean ${ev.isNull} = ${firstEval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value}; """ + - children.map { e => + rest.map { e => val eval = e.gen(ctx) s""" if (${ev.isNull}) { From 7ee7d5a3c4ff77d2cee2afce36ff41f6302e6315 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Nov 2015 19:46:10 -0800 Subject: [PATCH 0928/2089] [SPARK-11544][SQL][TEST-HADOOP1.0] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9830 from dilipbiswal/spark-11544. --- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++--- .../datasources/json/JsonSuite.scala | 41 ++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df63..f9465157c936d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178affe..ba7718c864637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,33 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Setting it twice as the name of the propery has changed between hadoop versions. + hadoopConfiguration.setClass( + "mapred.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 4114ce20fbe820f111e55e891ae3889b0e6e0006 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 22:01:02 -0800 Subject: [PATCH 0929/2089] [SPARK-11846] Add save/load for AFTSurvivalRegression and IsotonicRegression https://issues.apache.org/jira/browse/SPARK-11846 mengxr Author: Xusen Yin Closes #9836 from yinxusen/SPARK-11846. --- .../ml/regression/AFTSurvivalRegression.scala | 78 +++++++++++++++-- .../ml/regression/IsotonicRegression.scala | 83 +++++++++++++++++-- .../AFTSurvivalRegressionSuite.scala | 37 ++++++++- .../regression/IsotonicRegressionSuite.scala | 34 +++++++- 4 files changed, 210 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index b7d095872ffa5..aedfb48058dc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -21,20 +21,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} +import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. @@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params @Experimental @Since("1.6.0") class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams + with DefaultParamsWritable with Logging { @Since("1.6.0") def this() = this(Identifiable.randomUID("aftSurvReg")) @@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) } +@Since("1.6.0") +object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] { + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[AFTSurvivalRegression]]. @@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) - extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { /** @group setParam */ @Since("1.6.0") @@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] ( copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = + new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this) +} + +@Since("1.6.0") +object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */ + private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter ( + instance: AFTSurvivalRegressionModel + ) extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double, scale: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: coefficients, intercept, scale + val data = Data(instance.coefficients, instance.intercept, instance.scale) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[AFTSurvivalRegressionModel].getName + + override def load(path: String): AFTSurvivalRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("coefficients", "intercept", "scale").head() + val coefficients = data.getAs[Vector](0) + val intercept = data.getDouble(1) + val scale = data.getDouble(2) + val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a1fe01b047108..bbb1c7ac0a51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures @Since("1.5.0") @Experimental class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) @@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri } } +@Since("1.6.0") +object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { + + @Since("1.6.0") + override def load(path: String): IsotonicRegression = super.load(path) +} + /** * :: Experimental :: * Model fitted by IsotonicRegression. @@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri class IsotonicRegressionModel private[ml] ( override val uid: String, private val oldModel: MLlibIsotonicRegressionModel) - extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable { /** @group setParam */ @Since("1.5.0") @@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false) } + + @Since("1.6.0") + override def write: MLWriter = + new IsotonicRegressionModelWriter(this) +} + +@Since("1.6.0") +object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader + + @Since("1.6.0") + override def load(path: String): IsotonicRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */ + private[IsotonicRegressionModel] class IsotonicRegressionModelWriter ( + instance: IsotonicRegressionModel + ) extends MLWriter with Logging { + + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: boundaries, predictions, isotonic + val data = Data( + instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[IsotonicRegressionModel].getName + + override def load(path: String): IsotonicRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("boundaries", "predictions", "isotonic").head() + val boundaries = data.getAs[Seq[Double]](0).toArray + val predictions = data.getAs[Seq[Double]](1).toArray + val isotonic = data.getBoolean(2) + val model = new IsotonicRegressionModel( + metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 359f31027172b..d718ef63b531a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -21,14 +21,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} -class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class AFTSurvivalRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(prediction ~== model.predict(features) relTol 1E-5) } } + + test("read/write") { + def checkModelData( + model: AFTSurvivalRegressionModel, + model2: AFTSurvivalRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + assert(model.scale === model2.scale) + } + val aft = new AFTSurvivalRegression() + testEstimatorAndModelReadWrite(aft, datasetMultivariate, + AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + } +} + +object AFTSurvivalRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "fitIntercept" -> true, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 59f4193abc8f0..f067c29d27a7d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class IsotonicRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { sqlContext.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } @@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) } + + test("read/write") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + + def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = { + assert(model.boundaries === model2.boundaries) + assert(model.predictions === model2.predictions) + assert(model.isotonic === model2.isotonic) + } + + val ir = new IsotonicRegression() + testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, + checkModelData) + } +} + +object IsotonicRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "isotonic" -> true, + "featureIndex" -> 0 + ) } From 3b7f056da87a23f3a96f0311b3a947a9b698f38b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:02:17 -0800 Subject: [PATCH 0930/2089] [SPARK-11829][ML] Add read/write to estimators under ml.feature (II) Add read/write support to the following estimators under spark.ml: * ChiSqSelector * PCA * VectorIndexer * Word2Vec Author: Yanbo Liang Closes #9838 from yanboliang/spark-11829. --- .../spark/ml/feature/ChiSqSelector.scala | 65 ++++++++++++++++-- .../org/apache/spark/ml/feature/PCA.scala | 67 +++++++++++++++++-- .../spark/ml/feature/VectorIndexer.scala | 66 ++++++++++++++++-- .../apache/spark/ml/feature/Word2Vec.scala | 67 +++++++++++++++++-- .../apache/spark/mllib/feature/Word2Vec.scala | 6 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 22 +++++- .../apache/spark/ml/feature/PCASuite.scala | 26 ++++++- .../spark/ml/feature/VectorIndexerSuite.scala | 22 +++++- .../spark/ml/feature/Word2VecSuite.scala | 30 ++++++++- 9 files changed, 338 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 5e4061fba5494..dfec03828f4b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.{AttributeGroup, _} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint @@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params */ @Experimental final class ChiSqSelector(override val uid: String) - extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("chiSqSelector")) @@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String) override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) } +@Since("1.6.0") +object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { + + @Since("1.6.0") + override def load(path: String): ChiSqSelector = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[ChiSqSelector]]. @@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String) final class ChiSqSelectorModel private[ml] ( override val uid: String, private val chiSqSelector: feature.ChiSqSelectorModel) - extends Model[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { + + import ChiSqSelectorModel._ + + /** list of indices to select (filter). Must be ordered asc */ + val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures /** @group setParam */ def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] ( val copied = new ChiSqSelectorModel(uid, chiSqSelector) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ChiSqSelectorModelWriter(this) +} + +@Since("1.6.0") +object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + + private[ChiSqSelectorModel] + class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { + + private case class Data(selectedFeatures: Seq[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.selectedFeatures.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] { + + private val className = classOf[ChiSqSelectorModel].getName + + override def load(path: String): ChiSqSelectorModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val selectedFeatures = data.getAs[Seq[Int]](0).toArray + val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) + val model = new ChiSqSelectorModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader + + @Since("1.6.0") + override def load(path: String): ChiSqSelectorModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 539084704b653..32d7afee6e73b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * PCA trains a model to project vectors to a low-dimensional space using PCA. */ @Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("pca")) @@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams override def copy(extra: ParamMap): PCA = defaultCopy(extra) } +@Since("1.6.0") +object PCA extends DefaultParamsReadable[PCA] { + + @Since("1.6.0") + override def load(path: String): PCA = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[PCA]]. @@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams class PCAModel private[ml] ( override val uid: String, pcaModel: feature.PCAModel) - extends Model[PCAModel] with PCAParams { + extends Model[PCAModel] with PCAParams with MLWritable { + + import PCAModel._ + + /** a principal components Matrix. Each column is one principal component. */ + val pc: DenseMatrix = pcaModel.pc /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -127,4 +142,46 @@ class PCAModel private[ml] ( val copied = new PCAModel(uid, pcaModel) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PCAModelWriter(this) +} + +@Since("1.6.0") +object PCAModel extends MLReadable[PCAModel] { + + private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { + + private case class Data(k: Int, pc: DenseMatrix) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.getK, instance.pc) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class PCAModelReader extends MLReader[PCAModel] { + + private val className = classOf[PCAModel].getName + + override def load(path: String): PCAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("k", "pc") + .head() + val oldModel = new feature.PCAModel(k, pc) + val model = new PCAModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[PCAModel] = new PCAModelReader + + @Since("1.6.0") + override def load(path: String): PCAModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 52e0599e38d83..a637a6f2881de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,12 +22,14 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.udf @@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ @Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams { + with VectorIndexerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecIdx")) @@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } -private object VectorIndexer { +@Since("1.6.0") +object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + + @Since("1.6.0") + override def load(path: String): VectorIndexer = super.load(path) /** * Helper class for tracking unique values for each feature. @@ -146,7 +152,7 @@ private object VectorIndexer { * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. * @param maxCategories This class caps the number of unique values collected at maxCategories. */ - class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) extends Serializable { /** featureValueSets[feature index] = set of unique values */ @@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) - extends Model[VectorIndexerModel] with VectorIndexerParams { + extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { + + import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { @@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] ( val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new VectorIndexerModelWriter(this) +} + +@Since("1.6.0") +object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + + private[VectorIndexerModel] + class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { + + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numFeatures, instance.categoryMaps) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] { + + private val className = classOf[VectorIndexerModel].getName + + override def load(path: String): VectorIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("numFeatures", "categoryMaps") + .head() + val numFeatures = data.getAs[Int](0) + val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) + val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader + + @Since("1.6.0") + override def load(path: String): VectorIndexerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 708dbeef84db4..a8d61b6dea00b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,15 +17,17 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -90,7 +92,8 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. */ @Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("w2v")) @@ -139,6 +142,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[Word2Vec]]. @@ -147,7 +157,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] class Word2VecModel private[ml] ( override val uid: String, @transient private val wordVectors: feature.Word2VecModel) - extends Model[Word2VecModel] with Word2VecBase { + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + + import Word2VecModel._ /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and @@ -224,4 +236,49 @@ class Word2VecModel private[ml] ( val copied = new Word2VecModel(uid, wordVectors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private[Word2VecModel] + class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + val oldModel = new feature.Word2VecModel(wordIndex, wordVectors) + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7ab0d89d23a3f..a47f27b0afb14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -432,9 +432,9 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Since("1.1.0") -class Word2VecModel private[mllib] ( - private val wordIndex: Map[String, Int], - private val wordVectors: Array[Float]) extends Serializable with Saveable { +class Word2VecModel private[spark] ( + private[spark] val wordIndex: Map[String, Int], + private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { private val numWords = wordIndex.size // vectorSize: Dimension of each word's vector. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e5a42967bd2c8..7827db2794cf3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + test("Test Chi-Square selector") { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ @@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vec1 ~== vec2 absTol 1e-1) } } + + test("ChiSqSelector read/write") { + val t = new ChiSqSelector() + .setFeaturesCol("myFeaturesCol") + .setLabelCol("myLabelCol") + .setOutputCol("myOutputCol") + .setNumTopFeatures(2) + testDefaultReadWrite(t) + } + + test("ChiSqSelectorModel read/write") { + val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) + val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.selectedFeatures === instance.selectedFeatures) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 30c500f87a769..5a21cd20ceede 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PCA) @@ -65,4 +65,24 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("read/write") { + + def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { + assert(model1.pc === model2.pc) + } + val allParams: Map[String, Any] = Map( + "k" -> 3, + "inputCol" -> "features", + "outputCol" -> "pca_features" + ) + val data = Seq( + (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), + (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + ) + val df = sqlContext.createDataFrame(data).toDF("id", "features") + val pca = new PCA().setK(3) + testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8cb0a2cf14d37..67817fa4baf56 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,13 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest with Logging { import VectorIndexerSuite.FeatureData @@ -251,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L } } } + + test("VectorIndexer read/write") { + val t = new VectorIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxCategories(30) + testDefaultReadWrite(t) + } + + test("VectorIndexerModel read/write") { + val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1, + 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2)) + val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.numFeatures === instance.numFeatures) + assert(newInstance.categoryMaps === instance.categoryMaps) + } } private[feature] object VectorIndexerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 23dfdaa9f8fc6..a773244cd735e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -143,5 +143,31 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("Word2Vec read/write") { + val t = new Word2Vec() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxIter(2) + .setMinCount(8) + .setNumPartitions(1) + .setSeed(42L) + .setStepSize(0.01) + .setVectorSize(100) + testDefaultReadWrite(t) + } + + test("Word2VecModel read/write") { + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val oldModel = new OldWord2VecModel(word2VecMap) + val instance = new Word2VecModel("myWord2VecModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + } } From 7216f405454f6f3557b5b1f72df8f393605faf60 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:14:01 -0800 Subject: [PATCH 0931/2089] [SPARK-11875][ML][PYSPARK] Update doc for PySpark HasCheckpointInterval * Update doc for PySpark ```HasCheckpointInterval``` that users can understand how to disable checkpoint. * Update doc for PySpark ```cacheNodeIds``` of ```DecisionTreeParams``` to notify the relationship between ```cacheNodeIds``` and ```checkpointInterval```. Author: Yanbo Liang Closes #9856 from yanboliang/spark-11875. --- python/pyspark/ml/param/_shared_params_code_gen.py | 6 ++++-- python/pyspark/ml/param/shared.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 070c5db01ae73..0528dc1e3a6b9 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -118,7 +118,8 @@ def get$Name(self): ("inputCols", "input column names.", None), ("outputCol", "output column name.", "self.uid + '__output'"), ("numFeatures", "number of features.", None), - ("checkpointInterval", "checkpoint interval (>= 1).", None), + ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), ("seed", "random seed.", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms.", None), ("stepSize", "Step size to be used for each iteration of optimization.", None), @@ -157,7 +158,8 @@ def get$Name(self): ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + - "Caching can speed up training of deeper trees.")] + "Caching can speed up training of deeper trees. Users can set how often should the " + + "cache be checkpointed or disable it by setting checkpointInterval.")] decisionTreeCode = '''class DecisionTreeParams(Params): """ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4bdf2a8cc563f..4d960801502c2 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -325,16 +325,16 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: checkpoint interval (>= 1). + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1).") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for checkpoint interval (>= 1). - self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1).") + #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def setCheckpointInterval(self, value): """ @@ -636,7 +636,7 @@ class DecisionTreeParams(Params): minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def __init__(self): @@ -651,8 +651,8 @@ def __init__(self): self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") #: param for Maximum memory in MB allocated to histogram aggregation. self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. + self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def setMaxDepth(self, value): """ From 0fff8eb3e476165461658d4e16682ec64269fdfe Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 19 Nov 2015 23:42:24 -0800 Subject: [PATCH 0932/2089] [SPARK-11869][ML] Clean up TempDirectory properly in ML tests Need to remove parent directory (```className```) rather than just tempDir (```className/random_name```) I tested this with IDFSuite, which has 2 read/write tests, and it fixes the problem. CC: mengxr Can you confirm this is fine? I believe it is since the same ```random_name``` is used for all tests in a suite; we basically have an extra unneeded level of nesting. Author: Joseph K. Bradley Closes #9851 from jkbradley/tempdir-cleanup. --- .../src/test/scala/org/apache/spark/ml/util/TempDirectory.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 2742026a69c2e..c8a0bb16247b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -35,7 +35,7 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => override def beforeAll(): Unit = { super.beforeAll() - _tempDir = Utils.createTempDir(this.getClass.getName) + _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) } override def afterAll(): Unit = { From 3e1d120cedb4bd9e1595e95d4d531cf61da6684d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 23:43:18 -0800 Subject: [PATCH 0933/2089] [SPARK-11867] Add save/load for kmeans and naive bayes https://issues.apache.org/jira/browse/SPARK-11867 Author: Xusen Yin Closes #9849 from yinxusen/SPARK-11867. --- .../spark/ml/classification/NaiveBayes.scala | 68 +++++++++++++++++-- .../apache/spark/ml/clustering/KMeans.scala | 67 ++++++++++++++++-- .../ml/classification/NaiveBayesSuite.scala | 47 +++++++++++-- .../spark/ml/clustering/KMeansSuite.scala | 41 ++++++++--- 4 files changed, 195 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a14dcecbaf5b9..c512a2cb8bf3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,12 +17,15 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { @Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("nb")) @@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String) override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** * :: Experimental :: * Model produced by [[NaiveBayes]] @@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] ( s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be63002396..71e968497500f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,10 +17,12 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} - /** * Common params for KMeans and KMeansModel */ @@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -129,6 +131,52 @@ class KMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -141,7 +189,7 @@ class KMeansModel private[ml] ( @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") ( } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 98bc9511163e7..082a6bcd211ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} +import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + } def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { @@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") } + + test("read/write") { + def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { + assert(model.pi === model2.pi) + assert(model.theta === model2.theta) + } + val nb = new NaiveBayes() + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + } +} + +object NaiveBayesSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "smoothing" -> 0.1 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c05f90550d161..2724e51f31aa4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) -object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext - val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) - .map(v => new TestRow(v)) - sql.createDataFrame(rdd) - } -} - -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) } + + test("read/write") { + def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new KMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + } +} + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) } From a66142decee48bf5689fb7f4f33646d7bb1ac08d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 00:46:29 -0800 Subject: [PATCH 0934/2089] [SPARK-11877] Prevent agg. fallback conf. from leaking across test suites This patch fixes an issue where the `spark.sql.TungstenAggregate.testFallbackStartsAt` SQLConf setting was not properly reset / cleared at the end of `TungstenAggregationQueryWithControlledFallbackSuite`. This ended up causing test failures in HiveCompatibilitySuite in Maven builds by causing spilling to occur way too frequently. This configuration leak was inadvertently introduced during test cleanup in #9618. Author: Josh Rosen Closes #9857 from JoshRosen/clear-fallback-prop-in-test-teardown. --- .../execution/AggregationQuerySuite.scala | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6dde79f74d3d8..39c0a2a0de045 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -868,29 +868,27 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => - sqlContext.setConf( - "spark.sql.TungstenAggregate.testFallbackStartsAt", - fallbackStartsAt.toString) - - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = DataFrame(sqlContext, actual.logicalPlan) - - QueryTest.checkAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using TungstenAggregate with - |controlled fallback (it falls back to sort-based aggregation once it has processed - |$fallbackStartsAt input rows). The query is - |${actual.queryExecution} - | - |$errorMessage - """.stripMargin - - fail(newErrorMessage) - case None => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } } } } From 9ace2e5c8d7fbd360a93bc5fc4eace64a697b44f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 20 Nov 2015 09:55:53 -0800 Subject: [PATCH 0935/2089] [SPARK-11852][ML] StandardScaler minor refactor ```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```. Author: Yanbo Liang Closes #9839 from yanboliang/standardScaler-refactor. --- .../spark/ml/feature/StandardScaler.scala | 60 +++++++++---------- .../ml/feature/StandardScalerSuite.scala | 11 ++-- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 6d545219ebf49..d76a9c6275e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType} private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { /** - * Centers the data with mean before scaling. + * Whether to center the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ - val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + val withMean: BooleanParam = new BooleanParam(this, "withMean", + "Whether to center data with mean") + + /** @group getParam */ + def getWithMean: Boolean = $(withMean) /** - * Scales the data to unit standard deviation. + * Whether to scale the data to unit standard deviation. * Default: true * @group param */ - val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") + val withStd: BooleanParam = new BooleanParam(this, "withStd", + "Whether to scale the data to unit standard deviation") + + /** @group getParam */ + def getWithStd: Boolean = $(withStd) + + setDefault(withMean -> false, withStd -> true) } /** @@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM def this() = this(Identifiable.randomUID("stdScal")) - setDefault(withMean -> false, withStd -> true) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) + copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] { /** * :: Experimental :: * Model fitted by [[StandardScaler]]. + * + * @param std Standard deviation of the StandardScalerModel + * @param mean Mean of the StandardScalerModel */ @Experimental class StandardScalerModel private[ml] ( override val uid: String, - scaler: feature.StandardScalerModel) + val std: Vector, + val mean: Vector) extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ - /** Standard deviation of the StandardScalerModel */ - val std: Vector = scaler.std - - /** Mean of the StandardScalerModel */ - val mean: Vector = scaler.mean - - /** Whether to scale to unit standard deviation. */ - @Since("1.6.0") - def getWithStd: Boolean = scaler.withStd - - /** Whether to center data with mean. */ - @Since("1.6.0") - def getWithMean: Boolean = scaler.withMean - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -139,6 +137,7 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } dataset.withColumn($(outputCol), scale(col($(inputCol)))) } @@ -154,7 +153,7 @@ class StandardScalerModel private[ml] ( } override def copy(extra: ParamMap): StandardScalerModel = { - val copied = new StandardScalerModel(uid, scaler) + val copied = new StandardScalerModel(uid, std, mean) copyValues(copied, extra).setParent(parent) } @@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { - private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + private case class Data(std: Vector, mean: Vector) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = - sqlContext.read.parquet(dataPath) - .select("std", "mean", "withStd", "withMean") - .head() - // This is very likely to change in the future because withStd and withMean should be params. - val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) - val model = new StandardScalerModel(metadata.uid, oldModel) + val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + .select("std", "mean") + .head() + val model = new StandardScalerModel(metadata.uid, std, mean) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 49a4b2efe0c29..1eae125a524ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("params") { ParamsSuite.checkParams(new StandardScaler) - val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) - ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + ParamsSuite.checkParams(new StandardScalerModel("empty", + Vectors.dense(1.0), Vectors.dense(2.0))) } test("Standardization with default parameter") { @@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("StandardScalerModel read/write") { - val oldModel = new feature.StandardScalerModel( - Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) - val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val instance = new StandardScalerModel("myStandardScalerModel", + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)) val newInstance = testDefaultReadWrite(instance) assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) - assert(newInstance.getWithStd === instance.getWithStd) - assert(newInstance.getWithMean === instance.getWithMean) } } From e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 20 Nov 2015 09:57:09 -0800 Subject: [PATCH 0936/2089] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Author: Yuhao Yang Closes #9722 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 30 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 94 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 0000000000000..1743ef43a6ddf --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,30 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +
    +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a17..6f35b30c3d4df 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %} - + \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec4..54e35fcbb15af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 0000000000000..b3a7d2eb29780 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 0000000000000..419ce3d87a6ac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. + * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From bef361c589c0a38740232fd8d0a45841e4fc969a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 11:20:47 -0800 Subject: [PATCH 0937/2089] [SPARK-11876][SQL] Support printSchema in DataSet API DataSet APIs look great! However, I am lost when doing multiple level joins. For example, ``` val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a") val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2").printSchema() ``` The printed schema is like ``` root |-- _1: struct (nullable = true) | |-- _1: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) | |-- _2: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) |-- _2: struct (nullable = true) | |-- _1: string (nullable = true) | |-- _2: integer (nullable = true) ``` Personally, I think we need the printSchema function. Sometimes, I do not know how to specify the column, especially when their data types are mixed. For example, if I want to write the following select for the above multi-level join, I have to know the schema: ``` newDS.select(expr("_1._2._2 + 1").as[Int]).collect() ``` marmbrus rxin cloud-fan Do you have the same feeling? Author: gatorsmile Closes #9855 from gatorsmile/printSchemaDataSet. --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 9 --------- .../scala/org/apache/spark/sql/execution/Queryable.scala | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 98358127e2709..7abcecaa2880e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -299,15 +299,6 @@ class DataFrame private[sql]( */ def columns: Array[String] = schema.fields.map(_.name) - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index e86a52c149a2f..321e2c783537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -37,6 +37,15 @@ private[sql] trait Queryable { } } + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + /** * Prints the plans (logical and physical) to the console for debugging purposes. * @since 1.3.0 From 60bfb113325c71491f8dcf98b6036b0caa2144fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Nov 2015 11:43:45 -0800 Subject: [PATCH 0938/2089] [SPARK-11817][SQL] Truncating the fractional seconds to prevent inserting a NULL JIRA: https://issues.apache.org/jira/browse/SPARK-11817 Instead of return None, we should truncate the fractional seconds to prevent inserting NULL. Author: Liang-Chi Hsieh Closes #9834 from viirya/truncate-fractional-sec. --- .../apache/spark/sql/catalyst/util/DateTimeUtils.scala | 5 +++++ .../spark/sql/catalyst/util/DateTimeUtilsSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 17a5527f3fb29..2b93882919487 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -327,6 +327,11 @@ object DateTimeUtils { return None } + // Instead of return None, we truncate the fractional seconds to prevent inserting NULL + if (segments(6) > 999999) { + segments(6) = segments(6).toString.take(6).toInt + } + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index faca128badfd6..0ce5a2fb69505 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -343,6 +343,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + + // Truncating the fractional seconds + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === + c.getTimeInMillis * 1000 + 123456) } test("hours") { From 3b9d2a347f9c796b90852173d84189834e499e25 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 12:04:42 -0800 Subject: [PATCH 0939/2089] [SPARK-11819][SQL] nice error message for missing encoder before this PR, when users try to get an encoder for an un-supported class, they will only get a very simple error message like `Encoder for type xxx is not supported`. After this PR, the error message become more friendly, for example: ``` No Encoder found for abc.xyz.NonEncodable - array element class: "abc.xyz.NonEncodable" - field (class: "scala.Array", name: "arrayField") - root class: "abc.xyz.AnotherClass" ``` Author: Wenchen Fan Closes #9810 from cloud-fan/error-message. --- .../spark/sql/catalyst/ScalaReflection.scala | 90 ++++++++++++++----- .../encoders/EncoderErrorMessageSuite.scala | 62 +++++++++++++ 2 files changed, 129 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 33ae700706dae..918050b531c02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) className match { case "scala.Array" => val TypeRef(_, _, Seq(elementType)) = tpe @@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection { } } - /** Returns expressions for extracting all the fields from the given type. */ + /** + * Returns expressions for extracting all the fields from the given type. + * + * If the given type is not supported, i.e. there is no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. + * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - extractorFor(inputObject, localTypeOf[T]) match { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + extractorFor(inputObject, tpe, walkedTypePath) match { case s: CreateNamedStruct => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } @@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection { /** Helper for extracting internal fields from a case class. */ private def extractorFor( inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + tpe: `Type`, + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = silentSchemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here + // to trigger the type check. + extractorFor(inputObject, elementType, newPath) + + MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + } + } + if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { @@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection { // For non-primitives, we can just extract the object from the Option and then recurse. case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(optType) val classObj = Utils.classForName(className) val optionObjectType = ObjectType(classObj) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(optionObjectType, inputObject) expressions.If( IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + extractorFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => @@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil }) case t if t <:< localTypeOf[Array[_]] => @@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection { Invoke(inputObject, "booleanValue", BooleanType) case other => - throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } } } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } } /** @@ -561,7 +588,7 @@ trait ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -637,6 +664,23 @@ trait ScalaReflection { } } + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. + */ + private def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) + } + + /** Returns the full class name for a type. */ + private def getClassNameFromType(tpe: `Type`): String = { + tpe.erasure.typeSymbol.asClass.fullName + } + /** * Returns classes of input parameters of scala function object. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index 0b2a10bb04c10..8c766ef829923 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -17,9 +17,22 @@ package org.apache.spark.sql.catalyst.encoders +import scala.reflect.ClassTag + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +class NonEncodable(i: Int) + +case class ComplexNonEncodable1(name1: NonEncodable) + +case class ComplexNonEncodable2(name2: ComplexNonEncodable1) + +case class ComplexNonEncodable3(name3: Option[NonEncodable]) + +case class ComplexNonEncodable4(name4: Array[NonEncodable]) + +case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) class EncoderErrorMessageSuite extends SparkFunSuite { @@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite { intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } } + + test("nice error message for missing encoder") { + val errorMsg1 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage + assert(errorMsg1.contains( + s"""root class: "${clsName[ComplexNonEncodable1]}"""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg2 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage + assert(errorMsg2.contains( + s"""root class: "${clsName[ComplexNonEncodable2]}"""")) + assert(errorMsg2.contains( + s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg3 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage + assert(errorMsg3.contains( + s"""root class: "${clsName[ComplexNonEncodable3]}"""")) + assert(errorMsg3.contains( + s"""field (class: "scala.Option", name: "name3")""")) + assert(errorMsg3.contains( + s"""option value class: "${clsName[NonEncodable]}"""")) + + val errorMsg4 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage + assert(errorMsg4.contains( + s"""root class: "${clsName[ComplexNonEncodable4]}"""")) + assert(errorMsg4.contains( + s"""field (class: "scala.Array", name: "name4")""")) + assert(errorMsg4.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + + val errorMsg5 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage + assert(errorMsg5.contains( + s"""root class: "${clsName[ComplexNonEncodable5]}"""")) + assert(errorMsg5.contains( + s"""field (class: "scala.Option", name: "name5")""")) + assert(errorMsg5.contains( + s"""option value class: "scala.Array"""")) + assert(errorMsg5.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + } + + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName } From 652def318e47890bd0a0977dc982cc07f99fb06a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 13:17:35 -0800 Subject: [PATCH 0940/2089] [SPARK-11650] Reduce RPC timeouts to speed up slow AkkaUtilsSuite test This patch reduces some RPC timeouts in order to speed up the slow "AkkaUtilsSuite.remote fetch ssl on - untrusted server", which used to take two minutes to run. Author: Josh Rosen Closes #9869 from JoshRosen/SPARK-11650. --- core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 61601016e005e..0af4b6098bb0a 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -340,10 +340,11 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + .set("spark.rpc.askTimeout", "5s") + .set("spark.rpc.lookupTimeout", "5s") val securityManagerBad = new SecurityManager(slaveConf) val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) try { slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) fail("should receive either ActorNotFound or TimeoutException") From 9ed4ad4265cf9d3135307eb62dae6de0b220fc21 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 14:19:34 -0800 Subject: [PATCH 0941/2089] [SPARK-11724][SQL] Change casting between int and timestamp to consistently treat int in seconds. Hive has since changed this behavior as well. https://issues.apache.org/jira/browse/HIVE-3454 Author: Nong Li Author: Nong Li Author: Yin Huai Closes #9685 from nongli/spark-11724. --- .../spark/sql/catalyst/expressions/Cast.scala | 6 ++-- .../sql/catalyst/expressions/CastSuite.scala | 16 +++++---- .../apache/spark/sql/DateFunctionsSuite.scala | 3 ++ ...esting-0-237a6af90a857da1efcbe98f6bbbf9d6} | 2 +- ... cast #3-0-76ee270337f664b36cacfc6528ac109 | 1 - ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 1 - ...cast #7-0-1d70654217035f8ce5f64344f4c5a80f | 1 - .../sql/hive/execution/HiveQuerySuite.scala | 34 +++++++++++++------ 8 files changed, 39 insertions(+), 25 deletions(-) rename sql/hive/src/test/resources/golden/{constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 => constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6} (52%) delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5564e242b0472..533d17ea5c172 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -204,8 +204,8 @@ case class Cast(child: Expression, dataType: DataType) if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong } - // converting milliseconds to us - private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting seconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000000L // converting us to seconds private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong // converting us to seconds in double @@ -647,7 +647,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def decimalToTimestampCode(d: String): String = s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" private[this] def timestampToIntegerCode(ts: String): String = s"java.lang.Math.floor((double) $ts / 1000000L)" private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index f4db4da7646f8..ab77a764483e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -258,8 +258,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -348,14 +348,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -479,10 +479,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - DateTimeUtils.fromJavaTimestamp(ts)) - checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + DateTimeUtils.fromJavaTimestamp(ts) * 1000) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 241cbd0115070..a61c3aa48a73f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -448,6 +448,9 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer(sql(s"select cast ($now as timestamp)"), Row(new java.util.Date(now * 1000))) } test("to_unix_timestamp") { diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 similarity index 52% rename from sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 rename to sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 index 7c41615f8c184..a01c2622c68e2 100644 --- a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 +++ b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 @@ -1 +1 @@ -1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 deleted file mode 100644 index d00491fd7e5bb..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 +++ /dev/null @@ -1 +0,0 @@ -1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 deleted file mode 100644 index 84a31a5a6970b..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ /dev/null @@ -1 +0,0 @@ --0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f deleted file mode 100644 index 3fbedf693b51d..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f +++ /dev/null @@ -1 +0,0 @@ --2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f0a7a6cc7a1e3..8a5acaf3e10bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.sql.Timestamp import java.util.{Locale, TimeZone} import scala.util.Try @@ -248,12 +249,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, - |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, - |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, - |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, - |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL22, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23 |FROM src LIMIT 1""".stripMargin) + test("constant null testing timestamp") { + val r1 = sql("SELECT IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL20") + .collect().head + assert(new Timestamp(1000) == r1.getTimestamp(0)) + } + createQueryTest("constant array", """ |SELECT sort_array( @@ -603,26 +609,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head - assert(0.001 == res.getDouble(0)) + assert(1 == res.getDouble(0)) } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #3", - "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #3") { + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) + } createQueryTest("timestamp cast #4", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #5", - "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) + } createQueryTest("timestamp cast #6", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #7", - "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #7") { + val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(-1200 == res.getInt(0)) + } createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") From be7a2cfd978143f6f265eca63e9e24f755bc9f22 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 20 Nov 2015 14:23:01 -0800 Subject: [PATCH 0942/2089] [SPARK-11870][STREAMING][PYSPARK] Rethrow the exceptions in TransformFunction and TransformFunctionSerializer TransformFunction and TransformFunctionSerializer don't rethrow the exception, so when any exception happens, it just return None. This will cause some weird NPE and confuse people. Author: Shixiong Zhu Closes #9847 from zsxwing/pyspark-streaming-exception. --- python/pyspark/streaming/tests.py | 16 ++++++++++++++++ python/pyspark/streaming/util.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 3403f6d20d789..a0e0267cafa58 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,22 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_failed_func(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + raise ValueError("failed") + + input_stream.map(failed_func).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + return + + self.fail("a failed func should throw an error") + class StreamingListenerTests(PySparkStreamingTestCase): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index b20613b1283bd..767c732eb90b4 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -64,6 +64,7 @@ def call(self, milliseconds, jrdds): return r._jrdd except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunction(%s)" % self.func @@ -95,6 +96,7 @@ def dumps(self, id): return bytearray(self.serializer.dumps((func.func, func.deserializers))) except Exception: traceback.print_exc() + raise def loads(self, data): try: @@ -102,6 +104,7 @@ def loads(self, data): return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer From 89fd9bd06160fa89dedbf685bfe159ffe4a06ec6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 14:31:26 -0800 Subject: [PATCH 0943/2089] [SPARK-11887] Close PersistenceEngine at the end of PersistenceEngineSuite tests In PersistenceEngineSuite, we do not call `close()` on the PersistenceEngine at the end of the test. For the ZooKeeperPersistenceEngine, this causes us to leak a ZooKeeper client, causing the logs of unrelated tests to be periodically spammed with connection error messages from that client: ``` 15/11/20 05:13:35.789 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) INFO ClientCnxn: Opening socket connection to server localhost/127.0.0.1:15741. Will not attempt to authenticate using SASL (unknown error) 15/11/20 05:13:35.790 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) WARN ClientCnxn: Session 0x15124ff48dd0000 for server null, unexpected error, closing socket connection and attempting reconnect java.net.ConnectException: Connection refused at sun.nio.ch.SocketChannelImpl.checkConnect(Native Method) at sun.nio.ch.SocketChannelImpl.finishConnect(SocketChannelImpl.java:739) at org.apache.zookeeper.ClientCnxnSocketNIO.doTransport(ClientCnxnSocketNIO.java:350) at org.apache.zookeeper.ClientCnxn$SendThread.run(ClientCnxn.java:1068) ``` This patch fixes this by using a `finally` block. Author: Josh Rosen Closes #9864 from JoshRosen/close-zookeeper-client-in-tests. --- .../master/PersistenceEngineSuite.scala | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 34775577de8a3..7a44728675680 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -63,56 +63,60 @@ class PersistenceEngineSuite extends SparkFunSuite { conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { val serializer = new JavaSerializer(conf) val persistenceEngine = persistenceEngineCreator(serializer) - persistenceEngine.persist("test_1", "test_1_value") - assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.persist("test_2", "test_2_value") - assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) - persistenceEngine.unpersist("test_1") - assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.unpersist("test_2") - assert(persistenceEngine.read[String]("test_").isEmpty) - - // Test deserializing objects that contain RpcEndpointRef - val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) try { - // Create a real endpoint so that we can test RpcEndpointRef deserialization - val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { - override val rpcEnv: RpcEnv = testRpcEnv - }) - - val workerToPersist = new WorkerInfo( - id = "test_worker", - host = "127.0.0.1", - port = 10000, - cores = 0, - memory = 0, - endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) - - persistenceEngine.addWorker(workerToPersist) - - val (storedApps, storedDrivers, storedWorkers) = - persistenceEngine.readPersistedData(testRpcEnv) - - assert(storedApps.isEmpty) - assert(storedDrivers.isEmpty) - - // Check deserializing WorkerInfo - assert(storedWorkers.size == 1) - val recoveryWorkerInfo = storedWorkers.head - assert(workerToPersist.id === recoveryWorkerInfo.id) - assert(workerToPersist.host === recoveryWorkerInfo.host) - assert(workerToPersist.port === recoveryWorkerInfo.port) - assert(workerToPersist.cores === recoveryWorkerInfo.cores) - assert(workerToPersist.memory === recoveryWorkerInfo.memory) - assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } } finally { - testRpcEnv.shutdown() - testRpcEnv.awaitTermination() + persistenceEngine.close() } } From 03ba56d78f50747710d01c27d409ba2be42ae557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Fri, 20 Nov 2015 14:45:40 -0800 Subject: [PATCH 0944/2089] [SPARK-11716][SQL] UDFRegistration just drops the input type when re-creating the UserDefinedFunction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-11716 This is one is #9739 and a regression test. When commit it, please make sure the author is jbonofre. You can find the original PR at https://github.com/apache/spark/pull/9739 closes #9739 Author: Jean-Baptiste Onofré Author: Yin Huai Closes #9868 from yhuai/SPARK-11716. --- .../apache/spark/sql/UDFRegistration.scala | 48 +++++++++---------- .../scala/org/apache/spark/sql/UDFSuite.scala | 15 ++++++ 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fc4d0938c533a..051694c0d43a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -88,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try($inputTypes).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) }""") } @@ -120,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -133,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -146,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -159,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -172,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -211,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -224,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -237,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -250,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -263,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -276,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -289,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -302,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -315,7 +315,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -328,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -341,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -367,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -380,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -393,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -406,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 9837fa6bdb357..fd736718af12c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -232,4 +232,19 @@ class UDFSuite extends QueryTest with SharedSQLContext { | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp """.stripMargin).toDF(), complexData.select("m", "a", "b")) } + + test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { + val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + + // Without the fix, this will fail because we fail to cast data type of b to string + // because myUDF does not know its input data type. With the fix, this query should not + // fail. + checkAnswer( + testData2.select(myUDF($"a", $"b").as("t")), + testData2.selectExpr("struct(a, b)")) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(), + testData2) + } } From a6239d587c638691f52eca3eee905c53fbf35a12 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 20 Nov 2015 15:10:55 -0800 Subject: [PATCH 0945/2089] [SPARK-11756][SPARKR] Fix use of aliases - SparkR can not output help information for SparkR:::summary correctly Fix use of aliases and changes uses of rdname and seealso `aliases` is the hint for `?` - it should not be linked to some other name - those should be seealso https://cran.r-project.org/web/packages/roxygen2/vignettes/rd.html Clean up usage on family, as multiple use of family with the same rdname is causing duplicated See Also html blocks (like http://spark.apache.org/docs/latest/api/R/count.html) Also changing some rdname for dplyr-like variant for better R user visibility in R doc, eg. rbind, summary, mutate, summarize shivaram yanboliang Author: felixcheung Closes #9750 from felixcheung/rdocaliases. --- R/pkg/R/DataFrame.R | 96 ++++++++++++--------------------------------- R/pkg/R/broadcast.R | 1 - R/pkg/R/generics.R | 12 +++--- R/pkg/R/group.R | 12 +++--- 4 files changed, 37 insertions(+), 84 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 06b0108b1389e..8a13e7a36766d 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,7 +254,6 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns -#' @aliases names #' @export #' @examples #'\dontrun{ @@ -272,7 +271,6 @@ setMethod("columns", }) }) -#' @family DataFrame functions #' @rdname columns #' @name names setMethod("names", @@ -281,7 +279,6 @@ setMethod("names", columns(x) }) -#' @family DataFrame functions #' @rdname columns #' @name names<- setMethod("names<-", @@ -533,14 +530,8 @@ setMethod("distinct", dataFrame(sdf) }) -#' @title Distinct rows in a DataFrame -# -#' @description Returns a new DataFrame containing distinct rows in this DataFrame -#' -#' @family DataFrame functions -#' @rdname unique +#' @rdname distinct #' @name unique -#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -557,7 +548,7 @@ setMethod("unique", #' #' @family DataFrame functions #' @rdname sample -#' @aliases sample_frac +#' @name sample #' @export #' @examples #'\dontrun{ @@ -579,7 +570,6 @@ setMethod("sample", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname sample #' @name sample_frac setMethod("sample_frac", @@ -589,16 +579,15 @@ setMethod("sample_frac", sample(x, withReplacement, fraction) }) -#' Count +#' nrow #' #' Returns the number of rows in a DataFrame #' #' @param x A SparkSQL DataFrame #' #' @family DataFrame functions -#' @rdname count +#' @rdname nrow #' @name count -#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -614,14 +603,8 @@ setMethod("count", callJMethod(x@sdf, "count") }) -#' @title Number of rows for a DataFrame -#' @description Returns number of rows in a DataFrames -#' #' @name nrow -#' -#' @family DataFrame functions #' @rdname nrow -#' @aliases count setMethod("nrow", signature(x = "DataFrame"), function(x) { @@ -870,7 +853,6 @@ setMethod("toRDD", #' @param x a DataFrame #' @return a GroupedData #' @seealso GroupedData -#' @aliases group_by #' @family DataFrame functions #' @rdname groupBy #' @name groupBy @@ -896,7 +878,6 @@ setMethod("groupBy", groupedData(sgd) }) -#' @family DataFrame functions #' @rdname groupBy #' @name group_by setMethod("group_by", @@ -913,7 +894,6 @@ setMethod("group_by", #' @family DataFrame functions #' @rdname agg #' @name agg -#' @aliases summarize #' @export setMethod("agg", signature(x = "DataFrame"), @@ -921,7 +901,6 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @family DataFrame functions #' @rdname agg #' @name summarize setMethod("summarize", @@ -1092,7 +1071,6 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @family DataFrame functions #' @rdname subset #' @name subset -#' @aliases [ #' @family subsetting functions #' @examples #' \dontrun{ @@ -1216,7 +1194,7 @@ setMethod("selectExpr", #' @family DataFrame functions #' @rdname withColumn #' @name withColumn -#' @aliases mutate transform +#' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1231,7 +1209,6 @@ setMethod("withColumn", function(x, colName, col) { select(x, x$"*", alias(col, colName)) }) - #' Mutate #' #' Return a new DataFrame with the specified columns added. @@ -1240,9 +1217,9 @@ setMethod("withColumn", #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name mutate -#' @aliases withColumn transform +#' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ @@ -1273,17 +1250,15 @@ setMethod("mutate", }) #' @export -#' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name transform -#' @aliases withColumn mutate setMethod("transform", signature(`_data` = "DataFrame"), function(`_data`, ...) { mutate(`_data`, ...) }) -#' WithColumnRenamed +#' rename #' #' Rename an existing column in a DataFrame. #' @@ -1292,8 +1267,9 @@ setMethod("transform", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @family DataFrame functions -#' @rdname withColumnRenamed +#' @rdname rename #' @name withColumnRenamed +#' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1316,17 +1292,9 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' Rename -#' -#' Rename an existing column in a DataFrame. -#' -#' @param x A DataFrame -#' @param newCol A named pair of the form new_column_name = existing_column -#' @return A DataFrame with the column name changed. -#' @family DataFrame functions -#' @rdname withColumnRenamed +#' @param newColPair A named pair of the form new_column_name = existing_column +#' @rdname rename #' @name rename -#' @aliases withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1371,7 +1339,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @family DataFrame functions #' @rdname arrange #' @name arrange -#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1395,8 +1362,8 @@ setMethod("arrange", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname arrange +#' @name arrange #' @export setMethod("arrange", signature(x = "DataFrame", col = "character"), @@ -1427,9 +1394,9 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) -#' @family DataFrame functions #' @rdname arrange -#' @name orderby +#' @name orderBy +#' @export setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1492,6 +1459,7 @@ setMethod("where", #' @family DataFrame functions #' @rdname join #' @name join +#' @seealso \link{merge} #' @export #' @examples #'\dontrun{ @@ -1528,9 +1496,7 @@ setMethod("join", dataFrame(sdf) }) -#' #' @name merge -#' @aliases join #' @title Merges two data frames #' @param x the first data frame to be joined #' @param y the second data frame to be joined @@ -1550,6 +1516,7 @@ setMethod("join", #' outer join will be returned. #' @family DataFrame functions #' @rdname merge +#' @seealso \link{join} #' @export #' @examples #'\dontrun{ @@ -1671,7 +1638,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { cols } -#' UnionAll +#' rbind #' #' Return a new DataFrame containing the union of rows in this DataFrame #' and another DataFrame. This is equivalent to `UNION ALL` in SQL. @@ -1681,7 +1648,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. #' @family DataFrame functions -#' @rdname unionAll +#' @rdname rbind #' @name unionAll #' @export #' @examples @@ -1700,13 +1667,11 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -#' #' @description Returns a new DataFrame containing rows of all parameters. #' -#' @family DataFrame functions #' @rdname rbind #' @name rbind -#' @aliases unionAll +#' @export setMethod("rbind", signature(... = "DataFrame"), function(x, ..., deparse.level = 1) { @@ -1795,7 +1760,6 @@ setMethod("except", #' @family DataFrame functions #' @rdname write.df #' @name write.df -#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1828,7 +1792,6 @@ setMethod("write.df", callJMethod(df@sdf, "save", source, jmode, options) }) -#' @family DataFrame functions #' @rdname write.df #' @name saveDF #' @export @@ -1891,7 +1854,7 @@ setMethod("saveAsTable", callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) }) -#' describe +#' summary #' #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. @@ -1901,9 +1864,8 @@ setMethod("saveAsTable", #' @param ... Additional expressions #' @return A DataFrame #' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe -#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1923,8 +1885,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe setMethod("describe", signature(x = "DataFrame"), @@ -1934,11 +1895,6 @@ setMethod("describe", dataFrame(sdf) }) -#' @title Summary -#' -#' @description Computes statistics for numeric columns of the DataFrame -#' -#' @family DataFrame functions #' @rdname summary #' @name summary setMethod("summary", @@ -1966,7 +1922,6 @@ setMethod("summary", #' @family DataFrame functions #' @rdname nafunctions #' @name dropna -#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1993,7 +1948,6 @@ setMethod("dropna", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname nafunctions #' @name na.omit #' @export @@ -2019,9 +1973,7 @@ setMethod("na.omit", #' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. -#' @return A DataFrame #' -#' @family DataFrame functions #' @rdname nafunctions #' @name fillna #' @export diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 2403925b267c8..38f0eed95e065 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) { # # @param bcast The broadcast variable to get # @rdname broadcast -# @aliases value,Broadcast-method setMethod("value", signature(bcast = "Broadcast"), function(bcast) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 71004a05ba611..1b3f10ea04643 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -397,7 +397,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @export setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) -#' @rdname describe +#' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -459,11 +459,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) -#' rdname merge +#' @rdname merge #' @export setGeneric("merge") -#' @rdname withColumn +#' @rdname mutate #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) @@ -475,7 +475,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) @@ -553,7 +553,7 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname unionAll +#' @rdname rbind #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) @@ -565,7 +565,7 @@ setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index e5f702faee65d..23b49aebda05f 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -68,7 +68,7 @@ setMethod("count", dataFrame(callJMethod(x@sgd, "count")) }) -#' Agg +#' summarize #' #' Aggregates on the entire DataFrame without groups. #' The resulting DataFrame will also contain the grouping columns. @@ -78,12 +78,14 @@ setMethod("count", #' #' @param x a GroupedData #' @return a DataFrame -#' @rdname agg +#' @rdname summarize +#' @name agg #' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' -#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df4 <- summarize(df, ageSum = max(df$age)) #' } setMethod("agg", signature(x = "GroupedData"), @@ -110,8 +112,8 @@ setMethod("agg", dataFrame(sdf) }) -#' @rdname agg -#' @aliases agg +#' @rdname summarize +#' @name summarize setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { From 4b84c72dfbb9ddb415fee35f69305b5d7b280891 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:17:17 -0800 Subject: [PATCH 0946/2089] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders #theScaryParts (i.e. changes to the repl, executor classloaders and codegen)... Author: Michael Armbrust Author: Yin Huai Closes #9825 from marmbrus/dataset-replClasses2. --- .../org/apache/spark/repl/SparkIMain.scala | 14 +++++++---- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++++ .../spark/repl/ExecutorClassLoader.scala | 8 ++++++- .../expressions/codegen/CodeGenerator.scala | 4 ++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11e..829b12269fd2b 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1221,10 +1221,16 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s%s%s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. + | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | ${indentCode(toCompute)} + """.stripMargin val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669bee..081aa03002cc6 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -278,6 +281,27 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 3d2d235a00c93..a976e96809cb8 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -65,7 +65,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader case e: ClassNotFoundException => { val classOption = findClassLocally(name) classOption match { - case None => throw new ClassNotFoundException(name, e) + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). + // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. + throw new ClassNotFoundException(name) case Some(a) => a } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1b7260cdfe515..2f3d6aeb86c5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ - +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -536,7 +536,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From ed47b1e660b830e2d4fac8d6df93f634b260393c Mon Sep 17 00:00:00 2001 From: Vikas Nelamangala Date: Fri, 20 Nov 2015 15:18:41 -0800 Subject: [PATCH 0947/2089] [SPARK-11549][DOCS] Replace example code in mllib-evaluation-metrics.md using include_example Author: Vikas Nelamangala Closes #9689 from vikasnp/master. --- docs/mllib-evaluation-metrics.md | 940 +----------------- ...avaBinaryClassificationMetricsExample.java | 113 +++ ...ultiLabelClassificationMetricsExample.java | 80 ++ ...ulticlassClassificationMetricsExample.java | 97 ++ .../mllib/JavaRankingMetricsExample.java | 176 ++++ .../mllib/JavaRegressionMetricsExample.java | 91 ++ .../binary_classification_metrics_example.py | 55 + .../mllib/multi_class_metrics_example.py | 69 ++ .../mllib/multi_label_metrics_example.py | 61 ++ .../python/mllib/ranking_metrics_example.py | 55 + .../mllib/regression_metrics_example.py | 59 ++ .../BinaryClassificationMetricsExample.scala | 103 ++ .../mllib/MultiLabelMetricsExample.scala | 69 ++ .../mllib/MulticlassMetricsExample.scala | 99 ++ .../mllib/RankingMetricsExample.scala | 110 ++ .../mllib/RegressionMetricsExample.scala | 67 ++ 16 files changed, 1319 insertions(+), 925 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java create mode 100644 examples/src/main/python/mllib/binary_classification_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_class_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_label_metrics_example.py create mode 100644 examples/src/main/python/mllib/ranking_metrics_example.py create mode 100644 examples/src/main/python/mllib/regression_metrics_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index f73eff637dc36..6924037b941f3 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -104,214 +104,21 @@ data, and evaluate the performance of the algorithm by several binary evaluation
    Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training) - -// Clear the prediction threshold so the model will return probabilities -model.clearThreshold - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new BinaryClassificationMetrics(predictionAndLabels) - -// Precision by threshold -val precision = metrics.precisionByThreshold -precision.foreach { case (t, p) => - println(s"Threshold: $t, Precision: $p") -} - -// Recall by threshold -val recall = metrics.recallByThreshold -recall.foreach { case (t, r) => - println(s"Threshold: $t, Recall: $r") -} - -// Precision-Recall Curve -val PRC = metrics.pr - -// F-measure -val f1Score = metrics.fMeasureByThreshold -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 1") -} - -val beta = 0.5 -val fScore = metrics.fMeasureByThreshold(beta) -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 0.5") -} - -// AUPRC -val auPRC = metrics.areaUnderPR -println("Area under precision-recall curve = " + auPRC) - -// Compute thresholds used in ROC and PR curves -val thresholds = precision.map(_._1) - -// ROC Curve -val roc = metrics.roc - -// AUROC -val auROC = metrics.areaUnderROC -println("Area under ROC = " + auROC) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
    Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class BinaryClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_binary_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training.rdd()); - - // Clear the prediction threshold so the model will return probabilities - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); - - // Precision by threshold - JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); - - // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); - - // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); - - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); - - // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); - - // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - public Double call (Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); - - // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); - - // AUPRC - System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); - - // AUROC - System.out.println("Area under ROC = " + metrics.areaUnderROC()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java %}
    Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Several of the methods available in scala are currently missing from pyspark - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = BinaryClassificationMetrics(predictionAndLabels) - -# Area under precision-recall curve -print("Area under PR = %s" % metrics.areaUnderPR) - -# Area under ROC curve -print("Area under ROC = %s" % metrics.areaUnderROC) - -{% endhighlight %} - +{% include_example python/mllib/binary_classification_metrics_example.py %}
    @@ -433,204 +240,21 @@ the data, and evaluate the performance of the algorithm by several multiclass cl
    Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training) - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new MulticlassMetrics(predictionAndLabels) - -// Confusion matrix -println("Confusion matrix:") -println(metrics.confusionMatrix) - -// Overall Statistics -val precision = metrics.precision -val recall = metrics.recall // same as true positive rate -val f1Score = metrics.fMeasure -println("Summary Statistics") -println(s"Precision = $precision") -println(s"Recall = $recall") -println(s"F1 Score = $f1Score") - -// Precision by label -val labels = metrics.labels -labels.foreach { l => - println(s"Precision($l) = " + metrics.precision(l)) -} - -// Recall by label -labels.foreach { l => - println(s"Recall($l) = " + metrics.recall(l)) -} - -// False positive rate by label -labels.foreach { l => - println(s"FPR($l) = " + metrics.falsePositiveRate(l)) -} - -// F-measure by label -labels.foreach { l => - println(s"F1-Score($l) = " + metrics.fMeasure(l)) -} - -// Weighted stats -println(s"Weighted precision: ${metrics.weightedPrecision}") -println(s"Weighted recall: ${metrics.weightedRecall}") -println(s"Weighted F1 score: ${metrics.weightedFMeasure}") -println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
    Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MulticlassClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - - // Confusion matrix - Matrix confusion = metrics.confusionMatrix(); - System.out.println("Confusion matrix: \n" + confusion); - - // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); - } - - //Weighted stats - System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); - System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); - System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); - System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} + {% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
    Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.util import MLUtils -from pyspark.mllib.evaluation import MulticlassMetrics - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training, numClasses=3) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = MulticlassMetrics(predictionAndLabels) - -# Overall statistics -precision = metrics.precision() -recall = metrics.recall() -f1Score = metrics.fMeasure() -print("Summary Stats") -print("Precision = %s" % precision) -print("Recall = %s" % recall) -print("F1 Score = %s" % f1Score) - -# Statistics by class -labels = data.map(lambda lp: lp.label).distinct().collect() -for label in sorted(labels): - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) - -# Weighted stats -print("Weighted recall = %s" % metrics.weightedRecall) -print("Weighted precision = %s" % metrics.weightedPrecision) -print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) -print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) -print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) -{% endhighlight %} +{% include_example python/mllib/multi_class_metrics_example.py %}
    @@ -766,154 +390,21 @@ True classes:
    Refer to the [`MultilabelMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.MultilabelMetrics -import org.apache.spark.rdd.RDD; - -val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( - Seq((Array(0.0, 1.0), Array(0.0, 2.0)), - (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), - (Array(2.0), Array(2.0)), - (Array(2.0, 0.0), Array(2.0, 0.0)), - (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), - (Array(1.0), Array(1.0, 2.0))), 2) - -// Instantiate metrics object -val metrics = new MultilabelMetrics(scoreAndLabels) - -// Summary stats -println(s"Recall = ${metrics.recall}") -println(s"Precision = ${metrics.precision}") -println(s"F1 measure = ${metrics.f1Measure}") -println(s"Accuracy = ${metrics.accuracy}") - -// Individual label stats -metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) -metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) -metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) - -// Micro stats -println(s"Micro recall = ${metrics.microRecall}") -println(s"Micro precision = ${metrics.microPrecision}") -println(s"Micro F1 measure = ${metrics.microF1Measure}") - -// Hamming loss -println(s"Hamming loss = ${metrics.hammingLoss}") - -// Subset accuracy -println(s"Subset accuracy = ${metrics.subsetAccuracy}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala %}
    Refer to the [`MultilabelMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.evaluation.MultilabelMetrics; -import org.apache.spark.SparkConf; -import java.util.Arrays; -import java.util.List; - -public class MultilabelClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - - List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) - ); - JavaRDD> scoreAndLabels = sc.parallelize(data); - - // Instantiate metrics object - MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); - - // Summary stats - System.out.format("Recall = %f\n", metrics.recall()); - System.out.format("Precision = %f\n", metrics.precision()); - System.out.format("F1 measure = %f\n", metrics.f1Measure()); - System.out.format("Accuracy = %f\n", metrics.accuracy()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); - } - - // Micro stats - System.out.format("Micro recall = %f\n", metrics.microRecall()); - System.out.format("Micro precision = %f\n", metrics.microPrecision()); - System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); - - // Hamming loss - System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); - - // Subset accuracy - System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); - - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java %}
    Refer to the [`MultilabelMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.evaluation import MultilabelMetrics - -scoreAndLabels = sc.parallelize([ - ([0.0, 1.0], [0.0, 2.0]), - ([0.0, 2.0], [0.0, 1.0]), - ([], [0.0]), - ([2.0], [2.0]), - ([2.0, 0.0], [2.0, 0.0]), - ([0.0, 1.0, 2.0], [0.0, 1.0]), - ([1.0], [1.0, 2.0])]) - -# Instantiate metrics object -metrics = MultilabelMetrics(scoreAndLabels) - -# Summary stats -print("Recall = %s" % metrics.recall()) -print("Precision = %s" % metrics.precision()) -print("F1 measure = %s" % metrics.f1Measure()) -print("Accuracy = %s" % metrics.accuracy) - -# Individual label stats -labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() -for label in labels: - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) - -# Micro stats -print("Micro precision = %s" % metrics.microPrecision) -print("Micro recall = %s" % metrics.microRecall) -print("Micro F1 measure = %s" % metrics.microF1Measure) - -# Hamming loss -print("Hamming loss = %s" % metrics.hammingLoss) - -# Subset accuracy -print("Subset accuracy = %s" % metrics.subsetAccuracy) - -{% endhighlight %} +{% include_example python/mllib/multi_label_metrics_example.py %}
    @@ -1027,280 +518,21 @@ expanded world of non-positive weights are "the same as never having interacted
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} -import org.apache.spark.mllib.recommendation.{ALS, Rating} - -// Read in the ratings data -val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => - val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) -}.cache() - -// Map ratings to 1 or 0, 1 indicating a movie that should be recommended -val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() - -// Summarize ratings -val numRatings = ratings.count() -val numUsers = ratings.map(_.user).distinct().count() -val numMovies = ratings.map(_.product).distinct().count() -println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - -// Build the model -val numIterations = 10 -val rank = 10 -val lambda = 0.01 -val model = ALS.train(ratings, rank, numIterations, lambda) - -// Define a function to scale ratings from 0 to 1 -def scaledRating(r: Rating): Rating = { - val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) - Rating(r.user, r.product, scaledRating) -} - -// Get sorted top ten predictions for each user and then scale from [0, 1] -val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => - (user, recs.map(scaledRating)) -} - -// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document -// Compare with top ten most relevant documents -val userMovies = binarizedRatings.groupBy(_.user) -val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => - (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) -} - -// Instantiate metrics object -val metrics = new RankingMetrics(relevantDocuments) - -// Precision at K -Array(1, 3, 5).foreach{ k => - println(s"Precision at $k = ${metrics.precisionAt(k)}") -} - -// Mean average precision -println(s"Mean average precision = ${metrics.meanAveragePrecision}") - -// Normalized discounted cumulative gain -Array(1, 3, 5).foreach{ k => - println(s"NDCG at $k = ${metrics.ndcgAt(k)}") -} - -// Get predictions for each data point -val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) -val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) -val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => - (predicted, actual) -} - -// Get the RMSE using regression metrics -val regressionMetrics = new RegressionMetrics(predictionsAndLabels) -println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${regressionMetrics.r2}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import java.util.*; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.mllib.evaluation.RankingMetrics; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.Rating; - -// Read in the ratings data -public class Ranking { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - String path = "data/mllib/sample_movielens_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); - } - } - ); - ratings.cache(); - - // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); - - // Get top 10 recommendations for every user and scale ratings from 0 to 1 - JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2(t._1(), scaledRatings); - } - } - ); - JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); - - // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } - else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); - } - } - ); - - // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - public Object call(Rating r) { - return r.user(); - } - } - ); - - // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - public List call(Iterable docs) { - List products = new ArrayList(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } - } - return products; - } - } - ); - - // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - public List call(Rating[] docs) { - List products = new ArrayList(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; - } - } - ); - JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); - - // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); - - // Precision and NDCG at k - Integer[] kVector = {1, 3, 5}; - for (Integer k : kVector) { - System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); - System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); - } - - // Mean average precision - System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); - - // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - - // Create regression metrics object - RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); - - // Root mean squared error - System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R-squared = %f\n", regressionMetrics.r2()); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating -from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics - -# Read in the ratings data -lines = sc.textFile("data/mllib/sample_movielens_data.txt") - -def parseLine(line): - fields = line.split("::") - return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) -ratings = lines.map(lambda r: parseLine(r)) - -# Train a model on to predict user-product ratings -model = ALS.train(ratings, 10, 10, 0.01) - -# Get predicted ratings on all existing user-product pairs -testData = ratings.map(lambda p: (p.user, p.product)) -predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) - -ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) -scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) - -# Instantiate regression metrics to compare predicted and actual ratings -metrics = RegressionMetrics(scoreAndLabels) - -# Root mean sqaured error -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -{% endhighlight %} +{% include_example python/mllib/ranking_metrics_example.py %}
    @@ -1350,163 +582,21 @@ and evaluate the performance of the algorithm by several regression metrics.
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load the data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() - -// Build the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(data, numIterations) - -// Get predictions -val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) -} - -// Instantiate metrics object -val metrics = new RegressionMetrics(valuesAndPreds) - -// Squared error -println(s"MSE = ${metrics.meanSquaredError}") -println(s"RMSE = ${metrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${metrics.r2}") - -// Mean absolute error -println(s"MAE = ${metrics.meanAbsoluteError}") - -// Explained variance -println(s"Explained variance = ${metrics.explainedVariance}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from pyspark.mllib.evaluation import RegressionMetrics -from pyspark.mllib.linalg import DenseVector - -# Load and parse the data -def parsePoint(line): - values = line.split() - return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) - -data = sc.textFile("data/mllib/sample_linear_regression_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Get predictions -valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) - -# Instantiate metrics object -metrics = RegressionMetrics(valuesAndPreds) - -# Squared Error -print("MSE = %s" % metrics.meanSquaredError) -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -# Mean absolute error -print("MAE = %s" % metrics.meanAbsoluteError) - -# Explained variance -print("Explained variance = %s" % metrics.explainedVariance) - -{% endhighlight %} +{% include_example python/mllib/regression_metrics_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java new file mode 100644 index 0000000000000..980a9108af53f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaBinaryClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = + data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call(Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java new file mode 100644 index 0000000000000..b54e1ea3f2bcf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.rdd.RDD; +import org.apache.spark.SparkConf; +// $example off$ +import org.apache.spark.SparkContext; + +public class JavaMultiLabelClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision + (metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure + (metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java new file mode 100644 index 0000000000000..21f628fb51b6e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaMulticlassClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision + (metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure + (metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java new file mode 100644 index 0000000000000..7c4c97e74681f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaRankingMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join + (userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java new file mode 100644 index 0000000000000..d2efc6bf97776 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRegressionMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py new file mode 100644 index 0000000000000..437acb998acc3 --- /dev/null +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Binary Classification Metrics Example. +""" +from __future__ import print_function +import sys +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinaryClassificationMetricsExample") + sqlContext = SQLContext(sc) + # $example on$ + # Several of the methods available in scala are currently missing from pyspark + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = BinaryClassificationMetrics(predictionAndLabels) + + # Area under precision-recall curve + print("Area under PR = %s" % metrics.areaUnderPR) + + # Area under ROC curve + print("Area under ROC = %s" % metrics.areaUnderROC) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py new file mode 100644 index 0000000000000..cd56b3c97c778 --- /dev/null +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiClassMetricsExample") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = MulticlassMetrics(predictionAndLabels) + + # Overall statistics + precision = metrics.precision() + recall = metrics.recall() + f1Score = metrics.fMeasure() + print("Summary Stats") + print("Precision = %s" % precision) + print("Recall = %s" % recall) + print("F1 Score = %s" % f1Score) + + # Statistics by class + labels = data.map(lambda lp: lp.label).distinct().collect() + for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + + # Weighted stats + print("Weighted recall = %s" % metrics.weightedRecall) + print("Weighted precision = %s" % metrics.weightedPrecision) + print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) + print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) + print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_label_metrics_example.py b/examples/src/main/python/mllib/multi_label_metrics_example.py new file mode 100644 index 0000000000000..960ade6597379 --- /dev/null +++ b/examples/src/main/python/mllib/multi_label_metrics_example.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.evaluation import MultilabelMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiLabelMetricsExample") + # $example on$ + scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + + # Instantiate metrics object + metrics = MultilabelMetrics(scoreAndLabels) + + # Summary stats + print("Recall = %s" % metrics.recall()) + print("Precision = %s" % metrics.precision()) + print("F1 measure = %s" % metrics.f1Measure()) + print("Accuracy = %s" % metrics.accuracy) + + # Individual label stats + labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() + for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + + # Micro stats + print("Micro precision = %s" % metrics.microPrecision) + print("Micro recall = %s" % metrics.microRecall) + print("Micro F1 measure = %s" % metrics.microF1Measure) + + # Hamming loss + print("Hamming loss = %s" % metrics.hammingLoss) + + # Subset accuracy + print("Subset accuracy = %s" % metrics.subsetAccuracy) + # $example off$ diff --git a/examples/src/main/python/mllib/ranking_metrics_example.py b/examples/src/main/python/mllib/ranking_metrics_example.py new file mode 100644 index 0000000000000..327791966c901 --- /dev/null +++ b/examples/src/main/python/mllib/ranking_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Ranking Metrics Example") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Read in the ratings data + lines = sc.textFile("data/mllib/sample_movielens_data.txt") + + def parseLine(line): + fields = line.split("::") + return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) + ratings = lines.map(lambda r: parseLine(r)) + + # Train a model on to predict user-product ratings + model = ALS.train(ratings, 10, 10, 0.01) + + # Get predicted ratings on all existing user-product pairs + testData = ratings.map(lambda p: (p.user, p.product)) + predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) + + ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) + scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) + + # Instantiate regression metrics to compare predicted and actual ratings + metrics = RegressionMetrics(scoreAndLabels) + + # Root mean sqaured error + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + # $example off$ diff --git a/examples/src/main/python/mllib/regression_metrics_example.py b/examples/src/main/python/mllib/regression_metrics_example.py new file mode 100644 index 0000000000000..a3a83aafd7a1f --- /dev/null +++ b/examples/src/main/python/mllib/regression_metrics_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Regression Metrics Example") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), + DenseVector([float(x.split(':')[1]) for x in values[1:]])) + + data = sc.textFile("data/mllib/sample_linear_regression_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData) + + # Get predictions + valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + + # Instantiate metrics object + metrics = RegressionMetrics(valuesAndPreds) + + # Squared Error + print("MSE = %s" % metrics.meanSquaredError) + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + + # Mean absolute error + print("MAE = %s" % metrics.meanAbsoluteError) + + # Explained variance + print("Explained variance = %s" % metrics.explainedVariance) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala new file mode 100644 index 0000000000000..13a37827ab935 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object BinaryClassificationMetricsExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("BinaryClassificationMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training) + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new BinaryClassificationMetrics(predictionAndLabels) + + // Precision by threshold + val precision = metrics.precisionByThreshold + precision.foreach { case (t, p) => + println(s"Threshold: $t, Precision: $p") + } + + // Recall by threshold + val recall = metrics.recallByThreshold + recall.foreach { case (t, r) => + println(s"Threshold: $t, Recall: $r") + } + + // Precision-Recall Curve + val PRC = metrics.pr + + // F-measure + val f1Score = metrics.fMeasureByThreshold + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 1") + } + + val beta = 0.5 + val fScore = metrics.fMeasureByThreshold(beta) + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 0.5") + } + + // AUPRC + val auPRC = metrics.areaUnderPR + println("Area under precision-recall curve = " + auPRC) + + // Compute thresholds used in ROC and PR curves + val thresholds = precision.map(_._1) + + // ROC Curve + val roc = metrics.roc + + // AUROC + val auROC = metrics.areaUnderROC + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala new file mode 100644 index 0000000000000..4503c15360adc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MultiLabelMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultiLabelMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + + // Instantiate metrics object + val metrics = new MultilabelMetrics(scoreAndLabels) + + // Summary stats + println(s"Recall = ${metrics.recall}") + println(s"Precision = ${metrics.precision}") + println(s"F1 measure = ${metrics.f1Measure}") + println(s"Accuracy = ${metrics.accuracy}") + + // Individual label stats + metrics.labels.foreach(label => + println(s"Class $label precision = ${metrics.precision(label)}")) + metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) + metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + + // Micro stats + println(s"Micro recall = ${metrics.microRecall}") + println(s"Micro precision = ${metrics.microPrecision}") + println(s"Micro F1 measure = ${metrics.microF1Measure}") + + // Hamming loss + println(s"Hamming loss = ${metrics.hammingLoss}") + + // Subset accuracy + println(s"Subset accuracy = ${metrics.subsetAccuracy}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala new file mode 100644 index 0000000000000..0904449245989 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MulticlassMetricsExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MulticlassMetricsExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new MulticlassMetrics(predictionAndLabels) + + // Confusion matrix + println("Confusion matrix:") + println(metrics.confusionMatrix) + + // Overall Statistics + val precision = metrics.precision + val recall = metrics.recall // same as true positive rate + val f1Score = metrics.fMeasure + println("Summary Statistics") + println(s"Precision = $precision") + println(s"Recall = $recall") + println(s"F1 Score = $f1Score") + + // Precision by label + val labels = metrics.labels + labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) + } + + // Recall by label + labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) + } + + // False positive rate by label + labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) + } + + // F-measure by label + labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) + } + + // Weighted stats + println(s"Weighted precision: ${metrics.weightedPrecision}") + println(s"Weighted recall: ${metrics.weightedRecall}") + println(s"Weighted F1 score: ${metrics.weightedFMeasure}") + println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala new file mode 100644 index 0000000000000..cffa03d5cc9f4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} + +object RankingMetricsExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("RankingMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Read in the ratings data + val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + }.cache() + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + val binarizedRatings = ratings.map(r => Rating(r.user, r.product, + if (r.rating > 0) 1.0 else 0.0)).cache() + + // Summarize ratings + val numRatings = ratings.count() + val numUsers = ratings.map(_.user).distinct().count() + val numMovies = ratings.map(_.product).distinct().count() + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + // Build the model + val numIterations = 10 + val rank = 10 + val lambda = 0.01 + val model = ALS.train(ratings, rank, numIterations, lambda) + + // Define a function to scale ratings from 0 to 1 + def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) + } + + // Get sorted top ten predictions for each user and then scale from [0, 1] + val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) => + (user, recs.map(scaledRating)) + } + + // Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document + // Compare with top ten most relevant documents + val userMovies = binarizedRatings.groupBy(_.user) + val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, + predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) + } + + // Instantiate metrics object + val metrics = new RankingMetrics(relevantDocuments) + + // Precision at K + Array(1, 3, 5).foreach { k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") + } + + // Mean average precision + println(s"Mean average precision = ${metrics.meanAveragePrecision}") + + // Normalized discounted cumulative gain + Array(1, 3, 5).foreach { k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") + } + + // Get predictions for each data point + val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, + r.product), r.rating)) + val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) + val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), + (predicted, actual)) => + (predicted, actual) + } + + // Get the RMSE using regression metrics + val regressionMetrics = new RegressionMetrics(predictionsAndLabels) + println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${regressionMetrics.r2}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala new file mode 100644 index 0000000000000..47d44532521ca --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println + +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RegressionMetricsExample { + def main(args: Array[String]) : Unit = { + val conf = new SparkConf().setAppName("RegressionMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + + // Build the model + val numIterations = 100 + val model = LinearRegressionWithSGD.train(data, numIterations) + + // Get predictions + val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) + } + + // Instantiate metrics object + val metrics = new RegressionMetrics(valuesAndPreds) + + // Squared error + println(s"MSE = ${metrics.meanSquaredError}") + println(s"RMSE = ${metrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${metrics.r2}") + + // Mean absolute error + println(s"MAE = ${metrics.meanAbsoluteError}") + + // Explained variance + println(s"Explained variance = ${metrics.explainedVariance}") + // $example off$ + } +} +// scalastyle:on println + From 58b4e4f88a330135c4cec04a30d24ef91bc61d91 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 15:30:53 -0800 Subject: [PATCH 0948/2089] [SPARK-11787][SPARK-11883][SQL][FOLLOW-UP] Cleanup for this patch. This mainly moves SqlNewHadoopRDD to the sql package. There is some state that is shared between core and I've left that in core. This allows some other associated minor cleanup. Author: Nong Li Closes #9845 from nongli/spark-11787. --- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../spark/rdd/SqlNewHadoopRDDState.scala | 41 +++++++++++++ .../sql/catalyst/expressions/UnsafeRow.java | 59 ++++++++++++++---- .../catalyst/expressions/InputFileName.scala | 6 +- .../parquet/UnsafeRowParquetRecordReader.java | 14 +++++ .../scala/org/apache/spark/sql/SQLConf.scala | 5 ++ .../datasources}/SqlNewHadoopRDD.scala | 60 +++++++------------ .../datasources/parquet/ParquetRelation.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 43 ++++++------- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++ 10 files changed, 175 insertions(+), 80 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala rename {core/src/main/scala/org/apache/spark/rdd => sql/core/src/main/scala/org/apache/spark/sql/execution/datasources}/SqlNewHadoopRDD.scala (86%) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 7db583468792e..f37c95bedc0a5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,8 @@ class HadoopRDD[K, V]( // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -256,7 +256,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala new file mode 100644 index 0000000000000..3f15fff793661 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. + * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD + */ +private[spark] object SqlNewHadoopRDDState { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 33769363a0ed5..b6979d0c82977 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.*; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -26,12 +30,26 @@ import java.util.HashSet; import java.util.Set; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; - -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CalendarIntervalType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.NullType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -39,9 +57,23 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DateType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.NullType; +import static org.apache.spark.sql.types.DataTypes.ShortType; +import static org.apache.spark.sql.types.DataTypes.TimestampType; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -116,11 +148,6 @@ public static boolean isMutable(DataType dt) { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -187,6 +214,12 @@ public void pointTo(byte[] buf, int sizeInBytes) { pointTo(buf, numFields, sizeInBytes); } + + public void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index d809877817a5b..bf215783fc27d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{DataType, StringType} @@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDD.getInputFileName() + SqlNewHadoopRDDState.getInputFileName() } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 8a92e489ccb7c..dade488ca281b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -108,6 +108,19 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private static final int DEFAULT_VAR_LEN_SIZE = 32; + /** + * Tries to initialize the reader for this split. Returns true if this reader supports reading + * this split and false otherwise. + */ + public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) { + try { + initialize(inputSplit, taskAttemptContext); + return true; + } catch (Exception e) { + return false; + } + } + /** * Implementation of RecordReader API. */ @@ -326,6 +339,7 @@ private void decodeBinaryBatch(int col, int num) throws IOException { } else { rowWriters[n].write(col, bytes.array(), bytes.position(), len); } + rows[n].setNotNullAt(col); } else { rows[n].setNullAt(col); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f40e603cd1939..5ef3a48c56a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -323,6 +323,11 @@ private[spark] object SQLConf { "option must be set in Hadoop Configuration. 2. This option overrides " + "\"spark.sql.sources.outputCommitterClass\".") + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf( + key = "spark.sql.parquet.enableUnsafeRowRecordReader", + defaultValue = Some(true), + doc = "Enables using the custom ParquetUnsafeRowRecordReader.") + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "When true, enable filter pushdown for ORC files.") diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 4d176332b69ce..56cb63d9eff2a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -28,13 +30,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import scala.reflect.ClassTag - private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -61,13 +62,13 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sc : SparkContext, + sqlContext: SQLContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sc, Nil) + extends RDD[V](sqlContext.sparkContext, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -99,7 +100,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). protected val enableUnsafeRowParquetReader: Boolean = - sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) @@ -120,8 +121,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { + theSplit: SparkPartition, + context: TaskContext): Iterator[V] = { val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) @@ -132,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -163,15 +164,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( * TODO: plumb this through a different way? */ if (enableUnsafeRowParquetReader && - format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - // TODO: move this class to sql.execution and remove this. - reader = Utils.classForName( - "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") - .newInstance().asInstanceOf[RecordReader[Void, V]] - try { - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - } catch { - case e: Exception => reader = null + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader() + if (!parquetReader.tryInitialize( + split.serializableHadoopSplit.value, hadoopAttemptContext)) { + parquetReader.close() + } else { + reader = parquetReader.asInstanceOf[RecordReader[Void, V]] } } @@ -217,7 +216,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic @@ -235,7 +234,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { @@ -276,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } super.persist(storageLevel) } -} - -private[spark] object SqlNewHadoopRDD { - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index cb0aab8cc0d09..fdd745f48e973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -319,7 +319,7 @@ private[sql] class ParquetRelation( Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, + sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c8028a5ef5528..cc5aae03d5516 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,29 +337,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does - // not do row by row filtering (and we probably don't want to push that). - ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // The unsafe row RecordReader does not support row by row filtering so run it with it disabled. + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") - - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - - // The result should be single row. - // When a filter is pushed to Parquet, Parquet can apply it to every row. - // So, we can check the number of rows returned from the Parquet - // to make sure our filter pushdown work. - assert(childRDD.count == 1) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // This is the source RDD without Spark-side filtering. + val childRDD = + df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + + // The result should be single row. + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(childRDD.count == 1) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 177ab42f7767c..0c5d4887ed799 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -579,6 +579,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("null and non-null strings") { + // Create a dataset where the first values are NULL and then some non-null values. The + // number of non-nulls needs to be bigger than the ParquetReader batch size. + val data = sqlContext.range(200).map { i => + if (i.getLong(0) < 150) Row(None) + else Row("a") + } + val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil)) + assert(df.agg("col" -> "count").collect().head.getLong(0) == 50) + + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/data" + df.write.parquet(path) + + val df2 = sqlContext.read.parquet(path) + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } + } + test("read dictionary encoded decimals written as INT32") { checkAnswer( // Decimal column in this file is encoded using plain dictionary From 968acf3bd9a502fcad15df3e53e359695ae702cc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:36:30 -0800 Subject: [PATCH 0949/2089] [SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` :38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust Closes #9870 from marmbrus/dataset-repl-agg. --- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 27 +++---------------- .../apache/spark/sql/JavaDatasetSuite.java | 8 +++--- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 081aa03002cc6..cbcccb11f14ae 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite { } } + test("Datasets agg type-inference") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |/** An `Aggregator` that adds up any numeric type returned by the given function. */ + |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + | val numeric = implicitly[Numeric[N]] + | override def zero: N = numeric.zero + | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) + | override def finish(reduction: N): N = reduction + |} + | + |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn + |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() + |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 6de3dd626576a..263f049104762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql]( reduce(f.call _) } - /** - * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]]. - * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * We can also use `Aggregator.toColumn` to pass in typed aggregate functions. - * - * @since 1.6.0 - */ + // This is here to prevent us from adding overloads that would be ambiguous. @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = - groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*) + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ce40dd856f679..f7249b8945c49 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -404,11 +404,9 @@ public String call(Tuple2 value) throws Exception { grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()), - expr("sum(_2)"), - count("*")) - .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( new Tuple4<>("a", 3, 3L, 2L), From 68ed046836975b492b594967256d3c7951b568a5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:38:04 -0800 Subject: [PATCH 0950/2089] [SPARK-11890][SQL] Fix compilation for Scala 2.11 Author: Michael Armbrust Closes #9871 from marmbrus/scala211-break. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 918050b531c02..4a4a62ed1a468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - private def silentSchemaFor(tpe: `Type`): Schema = try { + protected def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. */ - private def getClassNameFromType(tpe: `Type`): String = { + protected def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 47815878ad5e47e89bfbd57acb848be2ce67a4a5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 16:02:03 -0800 Subject: [PATCH 0951/2089] [HOTFIX] Fix Java Dataset Tests --- .../test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f7249b8945c49..f32374b4c04df 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -409,8 +409,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( - new Tuple4<>("a", 3, 3L, 2L), - new Tuple4<>("b", 3, 3L, 1L)), + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), agged2.collectAsList()); } From a2dce22e0a25922e2052318d32f32877b7c27ec2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 20 Nov 2015 16:51:47 -0800 Subject: [PATCH 0952/2089] Revert "[SPARK-11689][ML] Add user guide and example code for LDA under spark.ml" This reverts commit e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7. --- docs/ml-clustering.md | 30 ------ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 - .../spark/examples/ml/JavaLDAExample.java | 94 ------------------- .../apache/spark/examples/ml/LDAExample.scala | 77 --------------- 5 files changed, 1 insertion(+), 204 deletions(-) delete mode 100644 docs/ml-clustering.md delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md deleted file mode 100644 index 1743ef43a6ddf..0000000000000 --- a/docs/ml-clustering.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -layout: global -title: Clustering - ML -displayTitle: ML - Clustering ---- - -In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). - -## Latent Dirichlet allocation (LDA) - -`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, -and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by -`EMLDAOptimizer` to a `DistributedLDAModel` if needed. - -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. - -
    -{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. - -{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} -
    - -
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 6f35b30c3d4df..be18a05361a17 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,7 +40,6 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -951,4 +950,4 @@ model.transform(test) {% endhighlight %} - \ No newline at end of file + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15af..91e50ccfecec4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,7 +69,6 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java deleted file mode 100644 index b3a7d2eb29780..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.ml.clustering.LDA; -import org.apache.spark.ml.clustering.LDAModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -/** - * An example demonstrating LDA - * Run with - *
    - * bin/run-example ml.JavaLDAExample
    - * 
    - */ -public class JavaLDAExample { - - private static class ParseVector implements Function { - private static final Pattern separator = Pattern.compile(" "); - - @Override - public Row call(String line) { - String[] tok = separator.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - Vector[] points = {Vectors.dense(point)}; - return new GenericRow(points); - } - } - - public static void main(String[] args) { - - String inputFile = "data/mllib/sample_lda_data.txt"; - - // Parses the arguments - SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // Loads data - JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); - StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; - StructType schema = new StructType(fields); - DataFrame dataset = sqlContext.createDataFrame(points, schema); - - // Trains a LDA model - LDA lda = new LDA() - .setK(10) - .setMaxIter(10); - LDAModel model = lda.fit(dataset); - - System.out.println(model.logLikelihood(dataset)); - System.out.println(model.logPerplexity(dataset)); - - // Shows the result - DataFrame topics = model.describeTopics(3); - topics.show(false); - model.transform(dataset).show(false); - - jsc.stop(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala deleted file mode 100644 index 419ce3d87a6ac..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml - -// scalastyle:off println -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -// $example on$ -import org.apache.spark.ml.clustering.LDA -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} -// $example off$ - -/** - * An example demonstrating a LDA of ML pipeline. - * Run with - * {{{ - * bin/run-example ml.LDAExample - * }}} - */ -object LDAExample { - - final val FEATURES_COL = "features" - - def main(args: Array[String]): Unit = { - - val input = "data/mllib/sample_lda_data.txt" - // Creates a Spark context and a SQL context - val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) - - // Trains a LDA model - val lda = new LDA() - .setK(10) - .setMaxIter(10) - .setFeaturesCol(FEATURES_COL) - val model = lda.fit(dataset) - val transformed = model.transform(dataset) - - val ll = model.logLikelihood(dataset) - val lp = model.logPerplexity(dataset) - - // describeTopics - val topics = model.describeTopics(3) - - // Shows the result - topics.show(false) - transformed.show(false) - - // $example off$ - sc.stop() - } -} -// scalastyle:on println From 7d3f922c4ba76c4193f98234ae662065c39cdfb1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 23:31:19 -0800 Subject: [PATCH 0953/2089] [SPARK-11819][SQL][FOLLOW-UP] fix scala 2.11 build seems scala 2.11 doesn't support: define private methods in `trait xxx` and use it in `object xxx extend xxx`. Author: Wenchen Fan Closes #9879 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4a4a62ed1a468..476becec4dd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - protected def silentSchemaFor(tpe: `Type`): Schema = try { + def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. */ - protected def getClassNameFromType(tpe: `Type`): String = { + def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 54328b6d862fe62ae01bdd87df4798ceb9d506d6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:10:13 -0800 Subject: [PATCH 0954/2089] [SPARK-11900][SQL] Add since version for all encoders Author: Reynold Xin Closes #9881 from rxin/SPARK-11900. --- .../scala/org/apache/spark/sql/Encoder.scala | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 86bb536459035..5cb8edf64e87c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,13 +45,52 @@ trait Encoder[T] extends Serializable { */ object Encoders { + /** + * An encoder for nullable boolean type. + * @since 1.6.0 + */ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * @since 1.6.0 + */ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * @since 1.6.0 + */ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * @since 1.6.0 + */ def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * @since 1.6.0 + */ def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * @since 1.6.0 + */ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * @since 1.6.0 + */ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * @since 1.6.0 + */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** @@ -59,6 +98,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) @@ -67,6 +108,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -77,6 +120,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -87,6 +132,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) @@ -120,12 +167,20 @@ object Encoders { ) } + /** + * An encoder for 2-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2]( e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) } + /** + * An encoder for 3-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3]( e1: Encoder[T1], e2: Encoder[T2], @@ -133,6 +188,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) } + /** + * An encoder for 4-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4]( e1: Encoder[T1], e2: Encoder[T2], @@ -141,6 +200,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) } + /** + * An encoder for 5-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4, T5]( e1: Encoder[T1], e2: Encoder[T2], From 596710268e29e8f624c3ba2fade08b66ec7084eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:54:18 -0800 Subject: [PATCH 0955/2089] [SPARK-11901][SQL] API audit for Aggregator. Author: Reynold Xin Closes #9882 from rxin/SPARK-11901. --- .../scala/org/apache/spark/sql/Dataset.scala | 1 - .../spark/sql/expressions/Aggregator.scala | 39 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bdcdc5d47cbae..07647508421a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 72610e735f782..b0cd32b5f73e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} /** * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] @@ -32,55 +31,65 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} * case class Data(i: Int) * * val customSummer = new Aggregator[Data, Int, Int] { - * def zero = 0 - * def reduce(b: Int, a: Data) = b + a.i - * def present(r: Int) = r + * def zero: Int = 0 + * def reduce(b: Int, a: Data): Int = b + a.i + * def merge(b1: Int, b2: Int): Int = b1 + b2 + * def present(r: Int): Int = r * }.toColumn() * - * val ds: Dataset[Data] + * val ds: Dataset[Data] = ... * val aggregated = ds.select(customSummer) * }}} * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam A The input type for the aggregation. + * @tparam I The input type for the aggregation. * @tparam B The type of the intermediate value of the reduction. - * @tparam C The type of the final result. + * @tparam O The type of the final output result. + * + * @since 1.6.0 */ -abstract class Aggregator[-A, B, C] extends Serializable { +abstract class Aggregator[-I, B, O] extends Serializable { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + /** + * A zero value for this aggregation. Should satisfy the property that any b + zero = b. + * @since 1.6.0 + */ def zero: B /** * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. + * @since 1.6.0 */ - def reduce(b: B, a: A): B + def reduce(b: B, a: I): B /** - * Merge two intermediate values + * Merge two intermediate values. + * @since 1.6.0 */ def merge(b1: B, b2: B): B /** * Transform the output of the reduction. + * @since 1.6.0 */ - def finish(reduction: B): C + def finish(reduction: B): O /** * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] * operations. + * @since 1.6.0 */ def toColumn( implicit bEncoder: Encoder[B], - cEncoder: Encoder[C]): TypedColumn[A, C] = { + cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = new AggregateExpression( TypedAggregateExpression(this), Complete, false) - new TypedColumn[A, C](expr, encoderFor[C]) + new TypedColumn[I, O](expr, encoderFor[O]) } } From ff442bbcffd4f93cfcc2f76d160011e725d2fb3f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 15:00:37 -0800 Subject: [PATCH 0956/2089] [SPARK-11899][SQL] API audit for GroupedDataset. 1. Renamed map to mapGroup, flatMap to flatMapGroup. 2. Renamed asKey -> keyAs. 3. Added more documentation. 4. Changed type parameter T to V on GroupedDataset. 5. Added since versions for all functions. Author: Reynold Xin Closes #9880 from rxin/SPARK-11899. --- .../api/java/function/MapGroupFunction.java | 2 +- .../scala/org/apache/spark/sql/Encoder.scala | 4 + .../sql/catalyst/JavaTypeInference.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 2 + .../org/apache/spark/sql/DataFrame.scala | 1 - .../org/apache/spark/sql/GroupedDataset.scala | 132 ++++++++++++++---- .../apache/spark/sql/JavaDatasetSuite.java | 8 +- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 20 +-- 9 files changed, 131 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java index 2935f9986a560..4f3f222e064bb 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -21,7 +21,7 @@ import java.util.Iterator; /** - * Base interface for a map function used in GroupedDataset's map function. + * Base interface for a map function used in GroupedDataset's mapGroup function. */ public interface MapGroupFunction extends Serializable { R call(K key, Iterator values) throws Exception; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 5cb8edf64e87c..03aa25eda807f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.types._ * * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. + * + * @since 1.6.0 */ trait Encoder[T] extends Serializable { @@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable { /** * Methods for creating encoders. + * + * @since 1.6.0 */ object Encoders { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 88a457f87ce4e..7d4cfbe6faecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ /** * Type-inference utilities for POJOs and Java collections. */ -private [sql] object JavaTypeInference { +object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) @@ -53,7 +53,6 @@ private [sql] object JavaTypeInference { * @return (SQL data type, nullable) */ private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 82e9cd7f50a31..30c554a85e693 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -46,6 +46,8 @@ private[sql] object Column { * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. + * + * @since 1.6.0 */ class TypedColumn[-T, U]( expr: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7abcecaa2880e..5586fc994b98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -110,7 +110,6 @@ private[sql] object DataFrame { * @groupname action Actions * @since 1.3.0 */ -// TODO: Improve documentation. @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 263f049104762..7f43ce16901b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.Aggregator /** * :: Experimental :: @@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution * making this change to the class hierarchy would break some function signatures. As such, this * class should be considered a preview of the final API. Changes will be made to the interface * after Spark 1.6. + * + * @since 1.6.0 */ @Experimental -class GroupedDataset[K, T] private[sql]( +class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[T], + tEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. + * + * @since 1.6.0 */ - def asKey[L : Encoder]: GroupedDataset[L, T] = + def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], unresolvedTEncoder, @@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 */ def keys: Dataset[K] = { new Dataset[K]( @@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql]( * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. + * + * @since 1.6.0 */ - def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( @@ -108,8 +121,25 @@ class GroupedDataset[K, T] private[sql]( logicalPlan)) } - def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -117,32 +147,62 @@ class GroupedDataset[K, T] private[sql]( * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. + * + * @since 1.6.0 */ - def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { - val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) - flatMap(func) + def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroup(func) } - def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - map((key, data) => f.call(key, data.asJava))(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroup((key, data) => f.call(key, data.asJava))(encoder) } /** * Reduces the elements of each group of data using the specified binary function. * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 */ - def reduce(f: (T, T) => T): Dataset[(K, T)] = { - val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) + def reduce(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMap(func) + flatMapGroup(func) } - def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { reduce(f.call _) } @@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql]( /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ def agg[U1, U2, U3]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ def agg[U1, U2, U3, U4]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3], - col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. + * + * @since 1.6.0 */ def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) @@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql]( * be passed the grouping key and 2 iterators containing all elements in the group from * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, @@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql]( other.logicalPlan)) } + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ def cogroup[U, R]( other: GroupedDataset[K, U], - f: CoGroupFunction[K, T, U, R], + f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f32374b4c04df..cf335efdd23b8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.map(new MapGroupFunction() { + Dataset mapped = grouped.mapGroup(new MapGroupFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,7 +183,7 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMap( + Dataset flatMapped = grouped.flatMapGroup( new FlatMapGroupFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { @@ -247,9 +247,9 @@ public void testGroupByColumn() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); GroupedDataset grouped = - ds.groupBy(length(col("value"))).asKey(Encoders.INT()); + ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.map( + Dataset mapped = grouped.mapGroup( new MapGroupFunction() { @Override public String call(Integer key, Iterator data) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 63b00975e4eb1..d387710357be0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.map { case (g, iter) => + val agged = grouped.mapGroup { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 89d964aa3e469..9da02550b39ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + val agged = grouped.flatMapGroup { case (g, iter) => + Iterator(g._1, iter.map(_._2).sum.toString) + } checkAnswer( agged, @@ -253,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -262,8 +264,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -272,8 +274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -282,8 +284,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 426004a9c9a864f90494d08601e6974709091a56 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 22 Nov 2015 10:36:47 -0800 Subject: [PATCH 0957/2089] [SPARK-11908][SQL] Add NullType support to RowEncoder JIRA: https://issues.apache.org/jira/browse/SPARK-11908 We should add NullType support to RowEncoder. Author: Liang-Chi Hsieh Closes #9891 from viirya/rowencoder-nulltype. --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 5 +++-- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 3 +++ .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 4cda4824acdc3..fa553e7c5324c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -48,7 +48,7 @@ object RowEncoder { private def extractorsFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject case udt: UserDefinedType[_] => @@ -143,6 +143,7 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) } private def constructorFor(schema: StructType): Expression = { @@ -158,7 +159,7 @@ object RowEncoder { } private def constructorFor(input: Expression): Expression = input.dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index ef7399e0196ab..82317d3385167 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -369,6 +369,9 @@ case class MapObjects( private lazy val completeFunction = function(loopAttribute) private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case NullType => + val nullTypeClassName = NullType.getClass.getName + ".MODULE$" + (i: String) => s".get($i, $nullTypeClassName)" case IntegerType => (i: String) => s".getInt($i)" case LongType => (i: String) => s".getLong($i)" case FloatType => (i: String) => s".getFloat($i)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 46c6e0d98d349..0ea51ece4bc5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -80,11 +80,13 @@ class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) private val mapOfString = MapType(StringType, StringType) private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() + .add("null", NullType) .add("boolean", BooleanType) .add("byte", ByteType) .add("short", ShortType) @@ -101,6 +103,7 @@ class RowEncoderSuite extends SparkFunSuite { encodeDecodeTest( new StructType() + .add("arrayOfNull", arrayOfNull) .add("arrayOfString", arrayOfString) .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) From fe89c1817d668e46adf70d0896c42c22a547c76a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 22 Nov 2015 21:45:46 -0800 Subject: [PATCH 0958/2089] [SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples We used the name `Dataset` to refer to `SchemaRDD` in 1.2 in ML pipelines and created this example file. Since `Dataset` has a new meaning in Spark 1.6, we should rename it to avoid confusion. This PR also removes support for dense format to simplify the example code. cc: yinxusen Author: Xiangrui Meng Closes #9873 from mengxr/SPARK-11895. --- .../DataFrameExample.scala} | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/{mllib/DatasetExample.scala => ml/DataFrameExample.scala} (51%) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala similarity index 51% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index dc13f82488af7..424f00158c2f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -16,7 +16,7 @@ */ // scalastyle:off println -package org.apache.spark.examples.mllib +package org.apache.spark.examples.ml import java.io.File @@ -24,25 +24,22 @@ import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * ./bin/run-example ml.DataFrameExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object DatasetExample { +object DataFrameExample { - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -52,9 +49,6 @@ object DatasetExample { opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) checkConfig { params => success } @@ -69,55 +63,42 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) + // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() sc.stop() } - } // scalastyle:on println From a6fda0bfc16a13b28b1cecc96f1ff91363089144 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 22 Nov 2015 21:48:48 -0800 Subject: [PATCH 0959/2089] [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9848 from jkbradley/cv-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 38 +-- .../BinaryClassificationEvaluator.scala | 11 +- .../MulticlassClassificationEvaluator.scala | 12 +- .../ml/evaluation/RegressionEvaluator.scala | 11 +- .../apache/spark/ml/recommendation/ALS.scala | 14 +- .../spark/ml/tuning/CrossValidator.scala | 229 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 48 ++-- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../BinaryClassificationEvaluatorSuite.scala | 13 +- ...lticlassClassificationEvaluatorSuite.scala | 13 +- .../evaluation/RegressionEvaluatorSuite.scala | 12 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 202 ++++++++++++++- 12 files changed, 522 insertions(+), 85 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6f15b37abcb30..4b2b3f8489fd0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] { stages: Array[PipelineStage], sc: SparkContext, path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName val stageUids = stages.map(_.uid) val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) // Save stages val stagesDir = new Path(path, "stages").toString @@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] { implicit val format = DefaultFormats val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) } (metadata.uid, stages) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca81c3..bfb70963b151d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) @@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a8d9..c44db0ec595ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d3e0..daaa174a086e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4d35177ad9b0f..b798aa1fab767 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} @@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] { private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - val extraMetadata = render("rank" -> instance.rank) + val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) @@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] { override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats - val rank: Int = metadata.extraMetadata match { - case Some(m: JValue) => - (m \ "rank").extract[Int] - case None => - throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + - s" ${metadata.metadataStr}") - } - + val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString val userFactors = sqlContext.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed86b9..83a9048374267 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,17 +18,24 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.{JObject, DefaultFormats} +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { */ @Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { + with CrossValidatorParams with MLWritable with Logging { def this() = this(Identifiable.randomUID("cv")) @@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. + @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. + // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** @@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ @Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { override def validateParams(): Unit = { bestModel.validateParams() @@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ff9322dba122a..8484b1f801066 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ def saveMetadata( instance: Params, path: String, sc: SparkContext, - extraMetadata: Option[JValue] = None): Unit = { + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = params.map { case ParamPair(p, v) => + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) - }.toList - val metadata = ("class" -> cls) ~ + }.toList)) + val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - ("extraMetadata" -> extraMetadata) + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] - * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] - * @param metadataStr Full metadata file String (for debugging) + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) */ case class Metadata( className: String, @@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, - extraMetadata: Option[JValue], - metadataStr: String) + metadata: JValue, + metadataJson: String) /** * Load metadata from file. @@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" - val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } /** @@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader { } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. + */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 12aba6bc6dbeb..8c86767456368 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.ml -import java.io.File - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe66777..a535c1218ecfa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b3701..7ee65975d22f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da323935..60886bf77d2f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a0337..dd6366050c020 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { + val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] From fc4b792d287095d70379a51f117c225d8d857078 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Sun, 22 Nov 2015 21:51:42 -0800 Subject: [PATCH 0960/2089] [SPARK-11835] Adds a sidebar menu to MLlib's documentation This PR adds a sidebar menu when browsing the user guide of MLlib. It uses a YAML file to describe the structure of the documentation. It should be trivial to adapt this to the other projects. ![screen shot 2015-11-18 at 4 46 12 pm](https://cloud.githubusercontent.com/assets/7594753/11259591/a55173f4-8e17-11e5-9340-0aed79d66262.png) Author: Timothy Hunter Closes #9826 from thunterdb/spark-11835. --- docs/_data/menu-ml.yaml | 10 ++++ docs/_data/menu-mllib.yaml | 75 +++++++++++++++++++++++++ docs/_includes/nav-left-wrapper-ml.html | 8 +++ docs/_includes/nav-left.html | 17 ++++++ docs/_layouts/global.html | 24 +++++--- docs/css/main.css | 37 ++++++++++++ 6 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 docs/_data/menu-ml.yaml create mode 100644 docs/_data/menu-mllib.yaml create mode 100644 docs/_includes/nav-left-wrapper-ml.html create mode 100644 docs/_includes/nav-left.html diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 0000000000000..dff3d33bf4ed1 --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,10 @@ +- text: Feature extraction, transformation, and selection + url: ml-features.html +- text: Decision trees for classification and regression + url: ml-decision-tree.html +- text: Ensembles + url: ml-ensembles.html +- text: Linear methods with elastic-net regularization + url: ml-linear-methods.html +- text: Multilayer perceptron classifier + url: ml-ann.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 0000000000000..12d22abd52826 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 0000000000000..0103e890cc21a --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
    +
    +

    spark.ml package

    + {% include nav-left.html nav=include.nav-ml %} +

    spark.mllib package

    + {% include nav-left.html nav=include.nav-mllib %} +
    +
    \ No newline at end of file diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html new file mode 100644 index 0000000000000..73176f4132554 --- /dev/null +++ b/docs/_includes/nav-left.html @@ -0,0 +1,17 @@ +{% assign navurl = page.url | remove: 'index.html' %} + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 467ff7a03fb70..1b09e2221e173 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -124,16 +124,24 @@ -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    - {{ content }} + {% if page.url contains "/ml" %} + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + {% endif %} -
    + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} + +
    +
    diff --git a/docs/css/main.css b/docs/css/main.css index d770173be1014..356b324d6303b 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,18 @@ margin-left: 10px; } +body .container-wrapper { + position: absolute; + width: 100%; + display: flex; +} + body #content { + position: relative; + line-height: 1.6; /* Inspired by Github's wiki style */ + background-color: white; + padding-left: 15px; } .title { @@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + position: absolute; + height: 100%; + + width: 256px; + margin-top: -20px; + padding-top: 20px; + background-color: #F0F8FC; +} + +.left-menu { + position: fixed; + max-width: 350px; + + padding-right: 10px; + width: 256px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} \ No newline at end of file From d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 22 Nov 2015 21:56:07 -0800 Subject: [PATCH 0961/2089] [SPARK-11912][ML] ml.feature.PCA minor refactor Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is params and we should save it under ```metadata/``` rather than both under ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc``` but construct ```mllib.feature.PCAModel``` inside ```transform```. Author: Yanbo Liang Closes #9897 from yanboliang/spark-11912. --- .../org/apache/spark/ml/feature/PCA.scala | 23 +++++++------- .../apache/spark/ml/feature/PCASuite.scala | 31 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 32d7afee6e73b..aa88cb03d23c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] { /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) + val pc: DenseMatrix) extends Model[PCAModel] with PCAParams with MLWritable { import PCAModel._ - /** a principal components Matrix. Each column is one principal component. */ - val pc: DenseMatrix = pcaModel.pc - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -124,6 +123,7 @@ class PCAModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -139,7 +139,7 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) + val copied = new PCAModel(uid, pc) copyValues(copied, extra).setParent(parent) } @@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] { private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { - private case class Data(k: Int, pc: DenseMatrix) + private case class Data(pc: DenseMatrix) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.getK, instance.pc) + val data = Data(instance.pc) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) - .select("k", "pc") + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("pc") .head() - val oldModel = new feature.PCAModel(k, pc) - val model = new PCAModel(metadata.uid, oldModel) + val model = new PCAModel(metadata.uid, pc) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 5a21cd20ceede..edab21e6c3072 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val model = new PCAModel("pca", mat) ParamsSuite.checkParams(model) } @@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } } - test("read/write") { + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } - def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { - assert(model1.pc === model2.pc) - } - val allParams: Map[String, Any] = Map( - "k" -> 3, - "inputCol" -> "features", - "outputCol" -> "pca_features" - ) - val data = Seq( - (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), - (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - ) - val df = sqlContext.createDataFrame(data).toDF("id", "features") - val pca = new PCA().setK(3) - testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.pc === instance.pc) } } From 4be360d4ee6cdb4d06306feca38ddef5212608cf Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 22 Nov 2015 22:05:01 -0800 Subject: [PATCH 0962/2089] [SPARK-11902][ML] Unhandled case in VectorAssembler#transform There is an unhandled case in the transform method of VectorAssembler if one of the input columns doesn't have one of the supported type DoubleType, NumericType, BooleanType or VectorUDT. So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType". This PR aims to fix this, throwing a SparkException when dealing with an unknown column type. Author: BenFradet Closes #9885 from BenFradet/SPARK-11902. --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 2 ++ .../spark/ml/feature/VectorAssemblerSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0feec0549852b..801096fed27bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index fb21ab6b9bf2c..9c1c00f41ab1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -69,6 +69,17 @@ class VectorAssemblerSuite } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) From 94ce65dfcbba1fe3a1fc9d8002c37d9cd1a11336 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 23 Nov 2015 08:53:40 -0800 Subject: [PATCH 0963/2089] [SPARK-11628][SQL] support column datatype of char(x) to recognize HiveChar Can someone review my code to make sure I'm not missing anything? Thanks! Author: Xiu Guo Author: Xiu Guo Closes #9612 from xguo27/SPARK-11628. --- .../sql/catalyst/util/DataTypeParser.scala | 6 ++++- .../catalyst/util/DataTypeParserSuite.scala | 8 ++++-- .../spark/sql/sources/TableScanSuite.scala | 5 ++++ .../spark/sql/hive/HiveInspectors.scala | 25 ++++++++++++++++--- .../apache/spark/sql/hive/TableReader.scala | 3 +++ .../spark/sql/hive/client/HiveShim.scala | 3 ++- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 2b83651f9086d..515c071c283b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -52,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | - varchar + varchar | + char protected lazy val fixedDecimalType: Parser[DataType] = ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { @@ -60,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { DecimalType(precision.toInt, scale.toInt) } + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + protected lazy val varchar: Parser[DataType] = "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1e3409a9db6eb..bebf708965474 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -49,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) checkDataType("array", ArrayType(DoubleType, true)) @@ -83,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite { |struct< | struct:struct, | MAP:Map, - | arrAy:Array> + | arrAy:Array, + | anotherArray:Array> """.stripMargin, StructType( StructField("struct", @@ -91,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) // A column name can be a reserved word in our DDL parser and SqlParser. checkDataType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 12af8068c398f..26c1ff520406c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -85,6 +85,7 @@ case class AllDataTypesScan( Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 36f0708f9da3d..95b57d6ad124a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} @@ -61,6 +61,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -75,6 +76,7 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -93,7 +95,8 @@ import org.apache.spark.unsafe.types.UTF8String * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -137,6 +140,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -259,6 +263,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -303,11 +309,15 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => @@ -377,6 +387,15 @@ private[hive] trait HiveInspectors { null } + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.size) + } else { + null + } + case _: JavaHiveDecimalObjectInspector => (o: Any) => if (o != null) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 69f481c49a655..70ee02823eeba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -382,6 +382,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 48bbb21e6c1de..346840079b853 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -321,7 +321,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { From 1a5baaa6517872b9a4fd6cd41c4b2cf1e390f6d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:13:59 -0800 Subject: [PATCH 0964/2089] [SPARK-11894][SQL] fix isNull for GetInternalRowField We should use `InternalRow.isNullAt` to check if the field is null before calling `InternalRow.getXXX` Thanks gatorsmile who discovered this bug. Author: Wenchen Fan Closes #9904 from cloud-fan/null. --- .../sql/catalyst/expressions/objects.scala | 23 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 82317d3385167..4a1f419f0ad8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -236,11 +236,6 @@ case class NewInstance( } if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -531,15 +526,15 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val row = child.gen(ctx) - s""" - ${row.code} - final boolean ${ev.isNull} = ${row.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; - } - """ + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9da02550b39ce..cc8e4325fd2f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((JavaData(1), 1L), (JavaData(2), 1L))) } - ignore("Java encoder self join") { + test("Java encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(1)), (JavaData(2), JavaData(2)))) } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkAnswer( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } } From f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:15:40 -0800 Subject: [PATCH 0965/2089] [SPARK-11921][SQL] fix `nullable` of encoder schema Author: Wenchen Fan Closes #9906 from cloud-fan/nullable. --- .../catalyst/encoders/ExpressionEncoder.scala | 15 +++++++- .../encoders/ExpressionEncoderSuite.scala | 38 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c1f3..7bc9aed0b204e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,8 +54,13 @@ object ExpressionEncoder { val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + new ExpressionEncoder[T]( - toRowExpression.dataType, + schema, flat, toRowExpression.flatten, fromRowExpression, @@ -71,7 +76,13 @@ object ExpressionEncoder { encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 76459b34a484f..d6ca138672ef1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( From 946b406519af58c79041217e6f93854b6cf80acd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:39:33 -0800 Subject: [PATCH 0966/2089] [SPARK-11913][SQL] support typed aggregate with complex buffer schema Author: Wenchen Fan Closes #9898 from cloud-fan/agg. --- .../aggregate/TypedAggregateExpression.scala | 25 +++++++---- .../spark/sql/DatasetAggregatorSuite.scala | 41 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6ce41aaf01e27..a9719128a626e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -23,9 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -46,14 +45,12 @@ object TypedAggregateExpression { /** * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has * the following limitations: - * - It assumes the aggregator reduces and returns a single column of type `long`. - * - It might only work when there is a single aggregator in the first column. * - It assumes the aggregator has a zero, `0`. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - bEncoder: ExpressionEncoder[Any], // Should be bound. + unresolvedBEncoder: ExpressionEncoder[Any], cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -80,10 +77,14 @@ case class TypedAggregateExpression( override lazy val inputTypes: Seq[DataType] = Nil - override val aggBufferSchema: StructType = bEncoder.schema + override val aggBufferSchema: StructType = unresolvedBEncoder.schema override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + val bEncoder = unresolvedBEncoder + .resolve(aggBufferAttributes, OuterScopes.outerScopes) + .bind(aggBufferAttributes) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -93,12 +94,18 @@ case class TypedAggregateExpression( lazy val boundA = aEncoder.get private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - // todo: need a more neat way to assign the value. var i = 0 while (i < aggBufferAttributes.length) { + val offset = mutableAggBufferOffset + i aggBufferSchema(i).dataType match { - case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) - case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) + case ByteType => buffer.setByte(offset, value.getByte(i)) + case ShortType => buffer.setShort(offset, value.getShort(i)) + case IntegerType => buffer.setInt(offset, value.getInt(i)) + case LongType => buffer.setLong(offset, value.getLong(i)) + case FloatType => buffer.setFloat(offset, value.getFloat(i)) + case DoubleType => buffer.setDouble(offset, value.getDouble(i)) + case other => buffer.update(offset, value.get(i, other)) } i += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 9377589790011..19dce5d1e2f37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L } case class AggData(a: Int, b: String) -object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { +object ClassInputAgg extends Aggregator[AggData, Int, Int] { /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 @@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { override def merge(b1: Int, b2: Int): Int = b1 + b2 } +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + + /** + * Transform the output of the reduction. + */ + override def finish(reduction: (Int, AggData)): Int = reduction._1 + + /** + * Merge two intermediate values + */ + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.groupBy(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkAnswer( + ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } } From 5fd86e4fc2e06d2403ca538ae417580c93b69e06 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 23 Nov 2015 10:41:17 -0800 Subject: [PATCH 0967/2089] [SPARK-7173][YARN] Add label expression support for application master Add label expression support for AM to restrict it runs on the specific set of nodes. I tested it locally and works fine. sryza and vanzin please help to review, thanks a lot. Author: jerryshao Closes #9800 from jerryshao/SPARK-7173. --- docs/running-on-yarn.md | 9 +++++++ .../org/apache/spark/deploy/yarn/Client.scala | 26 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index db6bfa69ee0fe..925a1e0ba6fcf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -326,6 +326,15 @@ If you need a reference to the proper location to put log files in the YARN so t Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. + + spark.yarn.executor.nodeLabelExpression (none) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ba799884f5689..a77a3e2420e24 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -225,7 +225,31 @@ private[spark] class Client( val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + + if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) { + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression") + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, amLabelExpression) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + } else { + appContext.setResource(capability) + } + appContext } From 5231cd5acaae69d735ba3209531705cc222f3cfb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 10:45:23 -0800 Subject: [PATCH 0968/2089] [SPARK-11762][NETWORK] Account for active streams when couting outstanding requests. This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin Closes #9747 from vanzin/SPARK-11762. --- .../network/client/StreamInterceptor.java | 12 ++++++++- .../client/TransportResponseHandler.java | 15 +++++++++-- .../TransportResponseHandlerSuite.java | 27 +++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 02230a00e69fc..88ba3ccebdf20 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -30,13 +30,19 @@ */ class StreamInterceptor implements TransportFrameDecoder.Interceptor { + private final TransportResponseHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private volatile long bytesRead; - StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; this.streamId = streamId; this.byteCount = byteCount; this.callback = callback; @@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } @@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); + handler.deactivateStream(); throw re; } else if (bytesRead == byteCount) { + handler.deactivateStream(); callback.onComplete(streamId); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index ed3f36af58048..cc88991b588c1 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; private final Queue streamCallbacks; + private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) { } public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(callback); } + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. @@ -177,14 +185,16 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); frameDecoder.setInterceptor(interceptor); + streamActive = true; } catch (Exception e) { logger.error("Error installing stream handler.", e); + deactivateStream(); } } else { logger.error("Could not find callback for StreamResponse."); @@ -208,7 +218,8 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 17a03ebe88a93..30144f4a9fc7a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -28,12 +29,16 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; public class TransportResponseHandlerSuite { @Test @@ -112,4 +117,26 @@ public void handleFailedRPC() { verify(callback, times(1)).onFailure((Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void testActiveStreams() { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, handler.numOutstandingRequests()); + } } From 98d7ec7df4bb115dbd84cb9acd744b6c8abfebd5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Nov 2015 11:51:29 -0800 Subject: [PATCH 0969/2089] [SPARK-11920][ML][DOC] ML LinearRegression should use correct dataset in examples and user guide doc ML ```LinearRegression``` use ```data/mllib/sample_libsvm_data.txt``` as dataset in examples and user guide doc, but it's actually classification dataset rather than regression dataset. We should use ```data/mllib/sample_linear_regression_data.txt``` instead. The deeper causes is that ```LinearRegression``` with "normal" solver can not solve this dataset correctly, may be due to the ill condition and unreasonable label. This issue has been reported at [SPARK-11918](https://issues.apache.org/jira/browse/SPARK-11918). It will confuse users if they run the example code but get exception, so we should make this change which can clearly illustrate the usage of ```LinearRegression``` algorithm. Author: Yanbo Liang Closes #9905 from yanboliang/spark-11920. --- .../examples/ml/JavaLinearRegressionWithElasticNetExample.java | 2 +- .../src/main/python/ml/linear_regression_with_elastic_net.py | 3 ++- .../examples/ml/LinearRegressionWithElasticNetExample.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java index 593f8fb3e9fe9..4ad7676c8d32b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -37,7 +37,7 @@ public static void main(String[] args) { // $example on$ // Load training data DataFrame training = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); + .load("data/mllib/sample_linear_regression_data.txt"); LinearRegression lr = new LinearRegression() .setMaxIter(10) diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index b0278276330c3..a4cd40cf26726 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -29,7 +29,8 @@ # $example on$ # Load training data - training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 5a51ece6f9ba7..22c824cea84d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -33,7 +33,8 @@ object LinearRegressionWithElasticNetExample { // $example on$ // Load training data - val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") val lr = new LinearRegression() .setMaxIter(10) From f6dcc6e96ad3f88563d717d5b6c45394b44db747 Mon Sep 17 00:00:00 2001 From: Mortada Mehyar Date: Mon, 23 Nov 2015 12:03:15 -0800 Subject: [PATCH 0970/2089] [SPARK-11837][EC2] python3 compatibility for launching ec2 m3 instances this currently breaks for python3 because `string` module doesn't have `letters` anymore, instead `ascii_letters` should be used Author: Mortada Mehyar Closes #9797 from mortada/python3_fix. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9327e21e43db7..9fd652a3df4c4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -595,7 +595,7 @@ def launch_cluster(conn, opts, cluster_name): dev = BlockDeviceType() dev.ephemeral_name = 'ephemeral%d' % i # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.letters[i + 1] + name = '/dev/sd' + string.ascii_letters[i + 1] block_map[name] = dev # Launch slaves From 1b6e938be836786bac542fa430580248161e5403 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 13:19:10 -0800 Subject: [PATCH 0971/2089] [SPARK-4424] Remove spark.driver.allowMultipleContexts override in tests This patch removes `spark.driver.allowMultipleContexts=true` from our test configuration. The multiple SparkContexts check was originally disabled because certain tests suites in SQL needed to create multiple contexts. As far as I know, this configuration change is no longer necessary, so we should remove it in order to make it easier to find test cleanup bugs. Author: Josh Rosen Closes #9865 from JoshRosen/SPARK-4424. --- pom.xml | 2 -- project/SparkBuild.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/pom.xml b/pom.xml index ad849112ce76c..234fd5dea1a6e 100644 --- a/pom.xml +++ b/pom.xml @@ -1958,7 +1958,6 @@ false false false - true true src @@ -1997,7 +1996,6 @@ 1 false false - true true __not_used__ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 67724c4e9e411..f575f0012d59e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -632,7 +632,6 @@ object TestSettings { javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", From 1d9120201012213edb1971a09e0849336dbb9415 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 23 Nov 2015 13:44:30 -0800 Subject: [PATCH 0972/2089] [SPARK-11836][SQL] udf/cast should not create new SQLContext They should use the existing SQLContext. Author: Davies Liu Closes #9914 from davies/create_udf. --- python/pyspark/sql/column.py | 7 ++++--- python/pyspark/sql/functions.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 9ca8e1f264cfa..81fd4e782628a 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -346,9 +346,10 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c3da513c13897..a1ca723bbd7ab 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1457,14 +1457,15 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): + from pyspark.sql import SQLContext f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, From 242be7daed9b01d19794bb2cf1ac421fe5ab7262 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 23 Nov 2015 13:46:34 -0800 Subject: [PATCH 0973/2089] [SPARK-11910][STREAMING][DOCS] Update twitter4j dependency version Author: Luciano Resende Closes #9892 from lresende/SPARK-11910. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 96b36b7a73209..ed6b28c282135 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -723,7 +723,7 @@ Some of these advanced sources are as follows. - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by Twitter4J library. You can either get the public stream, or get the filtered stream based on a From 7cfa4c6bc36d97e459d4adee7b03d537d63c337e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:51:43 -0800 Subject: [PATCH 0974/2089] [SPARK-11865][NETWORK] Avoid returning inactive client in TransportClientFactory. There's a very narrow race here where it would be possible for the timeout handler to close a channel after the client factory verified that the channel was still active. This change makes sure the client is marked as being recently in use so that the timeout handler does not close it until a new timeout cycle elapses. Author: Marcelo Vanzin Closes #9853 from vanzin/SPARK-11865. --- .../spark/network/client/TransportClient.java | 9 ++++- .../client/TransportClientFactory.java | 15 ++++++-- .../client/TransportResponseHandler.java | 9 +++-- .../server/TransportChannelHandler.java | 36 ++++++++++++------- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a0ba223e340a2..876fcd846791c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -73,10 +73,12 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; } public Channel getChannel() { @@ -84,7 +86,7 @@ public Channel getChannel() { } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { @@ -263,6 +265,11 @@ public void onFailure(Throwable e) { } } + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 42a4f664e697c..659c47160c7be 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. + TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index cc88991b588c1..be181e0660826 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -71,7 +71,7 @@ public TransportResponseHandler(Channel channel) { } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -80,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -227,4 +227,9 @@ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f8fcd1c3d7d76..29d688a67578c 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -116,20 +116,32 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc // there are outstanding requests, we also do a secondary consistency check to ensure // there's no race between the idle timeout and incrementing the numOutstandingRequests // (see SPARK-7003). - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); - } else if (closeIdleConnections) { - // While CloseIdleConnections is enable, we also close idle connection - ctx.close(); + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. + synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } } } } } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + } From c2467dadae8ce44010a912ee91c429310f8add65 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:54:19 -0800 Subject: [PATCH 0975/2089] [SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv. This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the akka version uses the existing HTTP-based file serving mechanism, while the netty versions uses the new stream support added to the network lib, which makes file transfers benefit from the easier security configuration of the network library, and should also reduce overhead overall. The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers. Author: Marcelo Vanzin Closes #9530 from vanzin/SPARK-11140. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 14 -- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 ++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 60 +++++++- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 138 ++++++++++++++++-- .../spark/rpc/netty/NettyStreamManager.scala | 63 ++++++++ .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 39 ++++- .../rpc/netty/NettyRpcHandlerSuite.scala | 10 +- docs/configuration.md | 2 + docs/security.md | 5 +- .../launcher/AbstractCommandBuilder.java | 2 +- .../client/TransportClientFactory.java | 6 +- .../server/TransportChannelHandler.java | 1 + 14 files changed, 356 insertions(+), 47 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index af4456c05b0a1..b153a7b08e590 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. @@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 88df27f733f2a..84230e32a4462 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -66,7 +66,6 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, @@ -91,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -367,17 +365,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -422,7 +409,6 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, memoryManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf76..3d7d281b0dd66 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,6 +17,9 @@ package org.apache.spark.rpc +import java.io.File +import java.nio.channels.ReadableByteChannel + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} @@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. + */ + def openChannel(uri: String): ReadableByteChannel + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. + * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 059a7e10ec12f..94dbec593c315 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * remove Akka from the dependencies. */ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def openChannel(uri: String): ReadableByteChannel = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3ce359868039b..68701f609f77a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,6 +20,7 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable @@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( + private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool @@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. + conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = error = e + + override def read(dst: ByteBuffer): Int = { + if (error != null) { + throw error + } + source.read(dst) + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef( override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { // TODO: Can we add connection callback (channel registered) to the underlying framework? // A variable to track whether we should dispatch the RemoteProcessConnected message. @@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler( dispatcher.postRemoteMessage(messageToDispatch, callback) } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] @@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 0000000000000..eb1d2604fb235 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. + */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + } + + require(file != null, s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1b3acb8ef7f51..af632349c9cae 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.nio.channels.Channels import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection @@ -535,6 +536,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2f55006420ce1..2b664c6313efa 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.rpc -import java.io.NotSerializableException +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -25,10 +27,14 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. @@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val jarUri = env.fileServer.addJar(jar) + + val destDir = Utils.createTempDir() + val destFile = new File(destDir, file.getName()) + val destJar = new File(destDir, jar.getName()) + + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) + Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) + + assert(Files.equal(file, destFile)) + assert(Files.equal(jar, destJar)) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b66..ccca795683da3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,17 +25,19 @@ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). - thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) @@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) diff --git a/docs/configuration.md b/docs/configuration.md index c496146e3ed63..4de202d7f7631 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1020,6 +1020,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the Akka RPC backend. @@ -1027,6 +1028,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend. diff --git a/docs/security.md b/docs/security.md index 177109415180b..e1af221d446b0 100644 --- a/docs/security.md +++ b/docs/security.md @@ -149,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -157,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. Executor diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47fc..55fe156cf665f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 659c47160c7be..61bafc8380049 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -170,8 +170,10 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } /** - * Create a completely new {@link TransportClient} to the given remote host / port - * But this connection is not pooled. + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 29d688a67578c..3164e00679035 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -138,6 +138,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } } } + ctx.fireUserEventTriggered(evt); } public TransportResponseHandler getResponseHandler() { From 9db5f601facfdaba6e4333a6b2d2e4a9f009c788 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 16:33:26 -0800 Subject: [PATCH 0976/2089] [SPARK-9866][SQL] Speed up VersionsSuite by using persistent Ivy cache This patch attempts to speed up VersionsSuite by storing fetched Hive JARs in an Ivy cache that persists across tests runs. If `SPARK_VERSIONS_SUITE_IVY_PATH` is set, that path will be used for the cache; if it is not set, VersionsSuite will create a temporary Ivy cache which is deleted after the test completes. Author: Josh Rosen Closes #9624 from JoshRosen/SPARK-9866. --- .../apache/spark/sql/hive/client/VersionsSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c6d034a23a1c6..7bc13bc60d30e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -36,10 +36,12 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - // Do not use a temp path here to speed up subsequent executions of the unit test during - // development. - private val ivyPath = Some( - new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } private def buildConf() = { lazy val warehousePath = Utils.createTempDir() From 105745645b12afbbc2a350518cb5853a88944183 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 23 Nov 2015 17:11:51 -0800 Subject: [PATCH 0977/2089] [SPARK-10560][PYSPARK][MLLIB][DOCS] Make StreamingLogisticRegressionWithSGD Python API equal to Scala one This is to bring the API documentation of StreamingLogisticReressionWithSGD and StreamingLinearRegressionWithSGC in line with the Scala versions. -Fixed the algorithm descriptions -Added default values to parameter descriptions -Changed StreamingLogisticRegressionWithSGD regParam to default to 0, as in the Scala version Author: Bryan Cutler Closes #9141 from BryanCutler/StreamingLogisticRegressionWithSGD-python-api-sync. --- python/pyspark/mllib/classification.py | 37 +++++++++++++++++--------- python/pyspark/mllib/regression.py | 32 ++++++++++++++-------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index aab4015ba80f8..9e6f17ef6e942 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -652,21 +652,34 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a batch of data. - - The weights obtained at the end of training a stream are used as initial - weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Number of iterations run for each batch of data. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param regParam: L2 Regularization parameter. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a logistic regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream. + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param regParam: + L2 Regularization parameter. + (default: 0.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 6f00d1df209c0..13b3397501c0b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -734,17 +734,27 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LinearRegression with SGD on a batch of data. - - The problem minimized is (1 / n_samples) * (y - weights'X)**2. - After training on a batch of data, the weights obtained at the end of - training are used as initial weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Total number of iterations run. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a linear regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ From 026ea2eab1f3cde270e8a6391d002915f3e1c6e5 Mon Sep 17 00:00:00 2001 From: Stephen Samuel Date: Mon, 23 Nov 2015 19:52:12 -0800 Subject: [PATCH 0978/2089] Updated sql programming guide to include jdbc fetch size Author: Stephen Samuel Closes #9377 from sksamuel/master. --- docs/sql-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e347754055e79..d7b205c2fa0df 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1820,6 +1820,7 @@ the Data Sources API. The following options are supported: register itself with the JDBC subsystem. + partitionColumn, lowerBound, upperBound, numPartitions @@ -1831,6 +1832,13 @@ the Data Sources API. The following options are supported: partitioned and returned. + + + fetchSize + + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). + +
    From 8d57524662fad4a0760f3bc924e690c2a110e7f7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 23 Nov 2015 22:22:15 -0800 Subject: [PATCH 0979/2089] [SPARK-11933][SQL] Rename mapGroup -> mapGroups and flatMapGroup -> flatMapGroups. Based on feedback from Matei, this is more consistent with mapPartitions in Spark. Also addresses some of the cleanups from a previous commit that renames the type variables. Author: Reynold Xin Closes #9919 from rxin/SPARK-11933. --- ...nction.java => FlatMapGroupsFunction.java} | 2 +- ...upFunction.java => MapGroupsFunction.java} | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 36 +++++++++---------- .../apache/spark/sql/JavaDatasetSuite.java | 10 +++--- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +-- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 6 files changed, 33 insertions(+), 33 deletions(-) rename core/src/main/java/org/apache/spark/api/java/function/{FlatMapGroupFunction.java => FlatMapGroupsFunction.java} (93%) rename core/src/main/java/org/apache/spark/api/java/function/{MapGroupFunction.java => MapGroupsFunction.java} (93%) diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index 18a2d733ca70d..d7a80e7b129b0 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -23,6 +23,6 @@ /** * A function that returns zero or more output records from each grouping key and its values. */ -public interface FlatMapGroupFunction extends Serializable { +public interface FlatMapGroupsFunction extends Serializable { Iterable call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java index 4f3f222e064bb..faa59eabc8b4f 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -23,6 +23,6 @@ /** * Base interface for a map function used in GroupedDataset's mapGroup function. */ -public interface MapGroupFunction extends Serializable { +public interface MapGroupsFunction extends Serializable { R call(K key, Iterator values) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7f43ce16901b9..793a86b132907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.expressions.Aggregator @Experimental class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[V], + vEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -53,12 +53,12 @@ class GroupedDataset[K, V] private[sql]( // queryexecution. private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedTEncoder = encoderFor(tEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedTEncoder = - unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -76,7 +76,7 @@ class GroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], - unresolvedTEncoder, + unresolvedVEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -110,13 +110,13 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( f, resolvedKEncoder, - resolvedTEncoder, + resolvedVEncoder, groupingAttributes, logicalPlan)) } @@ -138,8 +138,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -158,9 +158,9 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) - flatMapGroup(func) + flatMapGroups(func) } /** @@ -179,8 +179,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroup((key, data) => f.call(key, data.asJava))(encoder) + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) } /** @@ -192,8 +192,8 @@ class GroupedDataset[K, V] private[sql]( def reduce(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMapGroup(func) + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) } /** @@ -213,7 +213,7 @@ class GroupedDataset[K, V] private[sql]( private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => - tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes) + tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) case _ => c } @@ -227,7 +227,7 @@ class GroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map( - _.withInputType(resolvedTEncoder, dataAttributes).named) + _.withInputType(resolvedVEncoder, dataAttributes).named) val keyColumn = if (groupingAttributes.length > 1) { Alias(CreateStruct(groupingAttributes), "key")() } else { @@ -304,7 +304,7 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedTEncoder + implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index cf335efdd23b8..67a3190cb7d4f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.mapGroup(new MapGroupFunction() { + Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,8 +183,8 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMapGroup( - new FlatMapGroupFunction() { + Dataset flatMapped = grouped.flatMapGroups( + new FlatMapGroupsFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -249,8 +249,8 @@ public void testGroupByColumn() { GroupedDataset grouped = ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.mapGroup( - new MapGroupFunction() { + Dataset mapped = grouped.mapGroups( + new MapGroupsFunction() { @Override public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index d387710357be0..f75d0961823c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroup { case (g, iter) => + val agged = grouped.mapGroups { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc8e4325fd2f5..dbdd7ba14a5b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMapGroup { case (g, iter) => + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } @@ -255,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -265,7 +265,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -275,7 +275,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -285,7 +285,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 6cf51a7007bd72eb93ade149ca9fc53be5b32a17 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 23 Nov 2015 22:22:50 -0800 Subject: [PATCH 0980/2089] [SPARK-11903] Remove --skip-java-test Per [pwendell's comments on SPARK-11903](https://issues.apache.org/jira/browse/SPARK-11903?focusedCommentId=15021511&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15021511) I'm removing this dead code. If we are concerned about preserving compatibility, I can instead leave the option in and add a warning. For example: ```sh echo "Warning: '--skip-java-test' is deprecated and has no effect." ;; ``` cc pwendell, srowen Author: Nicholas Chammas Closes #9924 from nchammas/make-distribution. --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index d7d27e253f721..7b417fe7cf619 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -69,9 +69,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --skip-java-test) - SKIP_JAVA_TEST=true - ;; --with-tachyon) SPARK_TACHYON=true ;; From 4021a28ac30b65cb61cf1e041253847253a2d89f Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 23 Nov 2015 22:26:08 -0800 Subject: [PATCH 0981/2089] [SPARK-10707][SQL] Fix nullability computation in union output Author: Mikhail Bautin Closes #9308 from mbautin/SPARK-10707. --- .../plans/logical/basicOperators.scala | 11 +++++-- .../spark/sql/execution/basicOperators.scala | 9 ++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 31 +++++++++++++++++++ 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0c444482c5e4c..737e62fd59214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -92,8 +92,10 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - // TODO: These aren't really the same attributes as nullability etc might change. - final override def output: Seq[Attribute] = left.output + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) + } final override lazy val resolved: Boolean = childrenResolved && @@ -115,7 +117,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output +} case class Join( left: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e79092efdaa3e..d57b8e7a9ed61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -130,8 +130,13 @@ case class Sample( * Union two plans, without a distinct. This is UNION ALL in SQL. */ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - // TODO: attributes output by union should be distinct for nullability purposes - override def output: Seq[Attribute] = children.head.output + override def output: Seq[Attribute] = { + children.tail.foldLeft(children.head.output) { case (currentOutput, child) => + currentOutput.zip(child.output).map { case (a1, a2) => + a1.withNullability(a1.nullable || a2.nullable) + } + } + } override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 167aea87de077..bb82b562aaaa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1997,4 +1997,35 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } + + test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { + // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the + // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of + // UNION was incorrectly considered non-nullable: + checkAnswer( + sql("""SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), + Seq(Row(0))) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (2)") { + // This test uses RAND() to stop column pruning for Union and checks the resulting isnull + // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v" + // column of the union was considered non-nullable. + checkAnswer( + sql( + """ + |SELECT a FROM ( + | SELECT ISNULL(v) AS a, RAND() FROM ( + | SELECT 'foo' AS v UNION ALL SELECT null AS v + | ) my_union + |) my_view + """.stripMargin), + Row(false) :: Row(true) :: Nil) + } + } From 12eea834d7382fbaa9c92182b682b8724049d7c1 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Tue, 24 Nov 2015 00:07:40 -0800 Subject: [PATCH 0982/2089] [SPARK-11897][SQL] Add @scala.annotations.varargs to sql functions Author: Xiu Guo Closes #9918 from xguo27/SPARK-11897. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b27b1340cce46..6137ce3a70fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -689,6 +689,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { array((colName +: colNames).map(col) : _*) } @@ -871,6 +872,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { struct((colName +: colNames).map(col) : _*) } From 800bd799acf7f10a469d8d6537279953129eb2c6 Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Tue, 24 Nov 2015 09:03:32 +0000 Subject: [PATCH 0983/2089] [SPARK-11906][WEB UI] Speculation Tasks Cause ProgressBar UI Overflow When there are speculative tasks in the stage, running progress bar could overflow and goes hidden on a new line: ![image](https://cloud.githubusercontent.com/assets/4317392/11326841/5fd3482e-9142-11e5-8ca5-cb2f0c0c8964.png) 3 completed / 2 running (including 1 speculative) out of 4 total tasks This is a simple fix by capping the started tasks at `total - completed` tasks ![image](https://cloud.githubusercontent.com/assets/4317392/11326842/6bb67260-9142-11e5-90f0-37f9174878ec.png) I should note my preferred way to fix it is via css style ```css .progress { display: flex; } ``` which shifts the correction burden from driver to web browser. However I couldn't get selenium test to measure the position/dimension of the progress bar correctly to get this unit tested. It also has the side effect that the width will be calibrated so the running occupies 2 / 5 instead of 1 / 4. ![image](https://cloud.githubusercontent.com/assets/4317392/11326848/7b03e9f0-9142-11e5-89ad-bd99cb0647cf.png) All in all, since this cosmetic bug is minor enough, I suppose the original simple fix should be good enough. Author: Forest Fang Closes #9896 from saurfang/progressbar. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 4 +++- .../test/scala/org/apache/spark/ui/UIUtilsSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 25dcb604d9e5f..84a1116a5c498 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -319,7 +319,9 @@ private[spark] object UIUtils extends Logging { skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) + // started + completed can be > total when there are speculative tasks + val boundedStarted = math.min(started, total - completed) + val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
    diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 2b693c165180f..dd8d5ec27f87e 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -57,6 +57,16 @@ class UIUtilsSuite extends SparkFunSuite { ) } + test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { + val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val expected = Seq( +
    , +
    + ) + assert(generated.sameElements(expected), + s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") + } + private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) From d4a5e6f719079639ffd38470f4d8d1f6fde3228d Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 24 Nov 2015 23:24:49 +0800 Subject: [PATCH 0984/2089] [SPARK-11043][SQL] BugFix:Set the operator log in the thrift server. `SessionManager` will set the `operationLog` if the configuration `hive.server2.logging.operation.enabled` is true in version of hive 1.2.1. But the spark did not adapt to this change, so no matter enabled the configuration or not, spark thrift server will always log the warn message. PS: if `hive.server2.logging.operation.enabled` is false, it should log the warn message (the same as hive thrift server). Author: huangzhaowei Closes #9056 from SaintBacchus/SPARK-11043. --- .../SparkExecuteStatementOperation.scala | 8 ++++---- .../thriftserver/SparkSQLSessionManager.scala | 5 +++++ .../thriftserver/HiveThriftServer2Suites.scala | 16 +++++++++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 82fef92dcb73b..e022ee86a763a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -134,12 +134,12 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = resultSchema - override def run(): Unit = { + override def runInternal(): Unit = { setState(OperationState.PENDING) setHasResultSet(true) // avoid no resultset for async run if (!runInBackground) { - runInternal() + execute() } else { val sparkServiceUGI = Utils.getUGI() @@ -151,7 +151,7 @@ private[hive] class SparkExecuteStatementOperation( val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { try { - runInternal() + execute() } catch { case e: HiveSQLException => setOperationException(e) @@ -188,7 +188,7 @@ private[hive] class SparkExecuteStatementOperation( } } - override def runInternal(): Unit = { + private def execute(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index af4fcdf021bd4..de4e9c62b57a4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -41,6 +41,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) + // Create operation log root directory, if operation logging is enabled + if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + invoke(classOf[SessionManager], this, "initOperationLogRootDir") + } + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) getAncestorField[Log](this, 3, "LOG").info( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1dd898aa38350..139d8e897ba1d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} +import scala.io.Source import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -507,6 +508,12 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(rs2.getInt(2) === 500) } } + + test("SPARK-11043 check operation log root directory") { + val expectedLine = + "Operation log root directory is created: " + operationLogPath.getAbsoluteFile + assert(Source.fromFile(logPath).getLines().exists(_.contains(expectedLine))) + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -642,7 +649,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") - private var logPath: File = _ + protected var logPath: File = _ + protected var operationLogPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -679,6 +687,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -706,6 +715,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl warehousePath.delete() metastorePath = Utils.createTempDir() metastorePath.delete() + operationLogPath = Utils.createTempDir() + operationLogPath.delete() logPath = null logTailingProcess = null @@ -782,6 +793,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() metastorePath = null + operationLogPath.delete() + operationLogPath = null + Option(logPath).foreach(_.delete()) logPath = null From 5889880fbe9628681042036892ef7ebd4f0857b4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 24 Nov 2015 23:32:05 +0800 Subject: [PATCH 0985/2089] [SPARK-11592][SQL] flush spark-sql command line history to history file Currently, `spark-sql` would not flush command history when exiting. Author: Daoyuan Wang Closes #9563 from adrian-wang/jline. --- .../hive/thriftserver/SparkSQLCLIDriver.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6419002a2aa89..4b928e600b355 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -194,6 +194,22 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // add shutdown hook to flush the history to history file + Runtime.getRuntime.addShutdownHook(new Thread(new Runnable() { + override def run() = { + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + })) + // TODO: missing /* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") From be9dd1550c1816559d3d418a19c692e715f1c94e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 24 Nov 2015 09:20:09 -0800 Subject: [PATCH 0986/2089] =?UTF-8?q?[SPARK-11818][REPL]=20Fix=20ExecutorC?= =?UTF-8?q?lassLoader=20to=20lookup=20resources=20from=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …parent class loader Without patch, two additional tests of ExecutorClassLoaderSuite fails. - "resource from parent" - "resources from parent" Detailed explanation is here, https://issues.apache.org/jira/browse/SPARK-11818?focusedCommentId=15011202&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15011202 Author: Jungtaek Lim Closes #9812 from HeartSaVioR/SPARK-11818. --- .../spark/repl/ExecutorClassLoader.scala | 12 +++++++- .../spark/repl/ExecutorClassLoaderSuite.scala | 29 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a976e96809cb8..a8859fcd4584b 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -34,7 +34,9 @@ import org.apache.spark.util.ParentClassLoader /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. - * Allows the user to specify if user class path should be first + * Allows the user to specify if user class path should be first. + * This class loader delegates getting/finding resources to parent loader, + * which makes sense until REPL never provide resource dynamically. */ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { @@ -55,6 +57,14 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) + } + override def findClass(name: String): Class[_] = { userClassPathFirst match { case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index a58eda12b1120..c1211f7596b9c 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -19,8 +19,13 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import java.nio.charset.StandardCharsets +import java.util + +import com.google.common.io.Files import scala.concurrent.duration._ +import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps @@ -41,6 +46,7 @@ class ExecutorClassLoaderSuite val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val parentResourceNames = List("fake-resource.txt") var tempDir1: File = _ var tempDir2: File = _ var url1: String = _ @@ -54,6 +60,9 @@ class ExecutorClassLoaderSuite url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentResourceNames.foreach { x => + Files.write("resource".getBytes(StandardCharsets.UTF_8), new File(tempDir2, x)) + } parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) } @@ -99,6 +108,26 @@ class ExecutorClassLoaderSuite } } + test("resource from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val is = classLoader.getResourceAsStream(resourceName) + assert(is != null, s"Resource $resourceName not found") + val content = Source.fromInputStream(is, "UTF-8").getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } + + test("resources from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) + assert(resources.hasMoreElements, s"Resource $resourceName not found") + val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class // from the driver's class server would leak a HTTP connection, causing the class server's From e5aaae6e1145b8c25c4872b2992ab425da9c6f9b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 09:28:39 -0800 Subject: [PATCH 0987/2089] [SPARK-11942][SQL] fix encoder life cycle for CoGroup we should pass in resolved encodera to logical `CoGroup` and bind them in physical `CoGroup` Author: Wenchen Fan Closes #9928 from cloud-fan/cogroup. --- .../plans/logical/basicOperators.scala | 27 ++++++++++--------- .../org/apache/spark/sql/GroupedDataset.scala | 4 ++- .../spark/sql/execution/basicOperators.scala | 20 +++++++------- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++++++++ 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 737e62fd59214..5665fd7e5f419 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -553,19 +553,22 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], + def apply[Key, Left, Right, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup[K, Left, Right, R] = { + right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { CoGroup( func, - encoderFor[K], - encoderFor[Left], - encoderFor[Right], - encoderFor[R], - encoderFor[R].schema.toAttributes, + keyEnc, + leftEnc, + rightEnc, + encoderFor[Result], + encoderFor[Result].schema.toAttributes, leftGroup, rightGroup, left, @@ -577,12 +580,12 @@ object CoGroup { * A relation produced by applying `func` to each grouping key and associated values from left and * right children. */ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 793a86b132907..a10a89342fb5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( f, + this.resolvedKEncoder, + this.resolvedVEncoder, + other.resolvedVEncoder, this.groupingAttributes, other.groupingAttributes, this.logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d57b8e7a9ed61..a42aea0b96d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -375,12 +375,12 @@ case class MapGroups[K, T, U]( * iterators containing all elements in the group from left and right side. * The result of this function is encoded and flattened before being output. */ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], @@ -397,15 +397,17 @@ case class CoGroup[K, Left, Right, R]( left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val groupKeyEncoder = kEncoder.bind(leftGroup) + val boundKeyEnc = keyEnc.bind(leftGroup) + val boundLeftEnc = leftEnc.bind(left.output) + val boundRightEnc = rightEnc.bind(right.output) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => val result = func( - groupKeyEncoder.fromRow(key), - leftResult.map(leftEnc.fromRow), - rightResult.map(rightEnc.fromRow)) - result.map(rEncoder.toRow) + boundKeyEnc.fromRow(key), + leftResult.map(boundLeftEnc.fromRow), + rightResult.map(boundRightEnc.fromRow)) + result.map(resultEnc.toRow) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dbdd7ba14a5b7..13eede1b17d8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } + test("cogroup with complex data") { + val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() + val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a", 2 -> "bc", 3 -> "d") + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") From 56a0aba0a60326ba026056c9a23f3f6ec7258c19 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 24 Nov 2015 09:52:53 -0800 Subject: [PATCH 0988/2089] [SPARK-11952][ML] Remove duplicate ml examples Remove duplicate ml examples (only for ml). mengxr Author: Yanbo Liang Closes #9933 from yanboliang/SPARK-11685. --- .../main/python/ml/gradient_boosted_trees.py | 82 ----------------- .../src/main/python/ml/logistic_regression.py | 66 -------------- .../main/python/ml/random_forest_example.py | 87 ------------------- 3 files changed, 235 deletions(-) delete mode 100644 examples/src/main/python/ml/gradient_boosted_trees.py delete mode 100644 examples/src/main/python/ml/logistic_regression.py delete mode 100644 examples/src/main/python/ml/random_forest_example.py diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py deleted file mode 100644 index c3bf8aa2eb1e6..0000000000000 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import GBTRegressor -from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. -Note: GBTClassifier only supports binary classification currently -Run with: - bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py -""" - - -def testClassification(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = BinaryClassificationMetrics(predictionAndLabels) - print("AUC %.3f" % metrics.areaUnderROC) - - -def testRegression(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGBTExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py deleted file mode 100644 index 4cd027fdfbe8a..0000000000000 --- a/examples/src/main/python/ml/logistic_regression.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.evaluation import MulticlassMetrics -from pyspark.ml.feature import StringIndexer -from pyspark.sql import SQLContext - -""" -A simple example demonstrating a logistic regression with elastic net regularization Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression.py -""" - -if __name__ == "__main__": - - if len(sys.argv) > 1: - print("Usage: logistic_regression", file=sys.stderr) - exit(-1) - - sc = SparkContext(appName="PythonLogisticRegressionExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [training, test] = td.randomSplit([0.7, 0.3]) - - lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") - lr.setElasticNetParam(0.8) - - # Fit the model - lrModel = lr.fit(training) - - predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py deleted file mode 100644 index dc6a778670193..0000000000000 --- a/examples/src/main/python/ml/random_forest_example.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import RandomForestRegressor -from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a RandomForest Classification/Regression Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/random_forest_example.py -""" - - -def testClassification(train, test): - # Train a RandomForest model. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - # Note: Use larger numTrees in practice. - - rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - -def testRegression(train, test): - # Train a RandomForest model. - # Note: Use larger numTrees in practice. - - rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() From 9e24ba667e43290fbaa3cacb93cf5d9be790f1fd Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 24 Nov 2015 09:54:55 -0800 Subject: [PATCH 0989/2089] [SPARK-11521][ML][DOC] Document that Logistic, Linear Regression summaries ignore weight col Doc for 1.6 that the summaries mostly ignore the weight column. To be corrected for 1.7 CC: mengxr thunterdb Author: Joseph K. Bradley Closes #9927 from jkbradley/linregsummary-doc. --- .../ml/classification/LogisticRegression.scala | 18 ++++++++++++++++++ .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 418bbdc9a058f..d320d64dd90d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -755,23 +755,35 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the receiver operating characteristic (ROC) curve, * which is an Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** * Computes the area under the receiver operating characteristic (ROC) curve. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** * Returns the precision-recall curve, which is an Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") @@ -781,6 +793,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, precision) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") @@ -790,6 +805,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, recall) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70ccec766c471..1db91666f21ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -540,6 +540,9 @@ class LinearRegressionSummary private[regression] ( * Returns the explained variance regression score. * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -547,6 +550,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -554,6 +560,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -561,6 +570,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -568,6 +580,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 From 52bc25c8e26d4be250d8ff7864067528f4f98592 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 24 Nov 2015 09:56:17 -0800 Subject: [PATCH 0990/2089] [SPARK-11847][ML] Model export/import for spark.ml: LDA Add read/write support to LDA, similar to ALS. save/load for ml.LocalLDAModel is done. For DistributedLDAModel, I'm not sure if we can invoke save on the mllib.DistributedLDAModel directly. I'll send update after some test. Author: Yuhao Yang Closes #9894 from hhbyyh/ldaMLsave. --- .../org/apache/spark/ml/clustering/LDA.scala | 110 +++++++++++++++++- .../spark/mllib/clustering/LDAModel.scala | 4 +- .../apache/spark/ml/clustering/LDASuite.scala | 44 ++++++- 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 92e05815d6a3d..830510b1698d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] ( @Since("1.6.0") override val uid: String, @Since("1.6.0") val vocabSize: Int, @Since("1.6.0") @transient protected val sqlContext: SQLContext) - extends Model[LDAModel] with LDAParams with Logging { + extends Model[LDAModel] with LDAParams with Logging with MLWritable { // NOTE to developers: // This abstraction should contain all important functionality for basic LDA usage. @@ -486,6 +487,64 @@ class LocalLDAModel private[ml] ( @Since("1.6.0") override def isDistributed: Boolean = false + + @Since("1.6.0") + override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) +} + + +@Since("1.6.0") +object LocalLDAModel extends MLReadable[LocalLDAModel] { + + private[LocalLDAModel] + class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { + + private case class Data( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val oldModel = instance.oldLocalModel + val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LocalLDAModelReader extends MLReader[LocalLDAModel] { + + private val className = classOf[LocalLDAModel].getName + + override def load(path: String): LocalLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", + "gammaShape") + .head() + val vocabSize = data.getAs[Int](0) + val topicsMatrix = data.getAs[Matrix](1) + val docConcentration = data.getAs[Vector](2) + val topicConcentration = data.getAs[Double](3) + val gammaShape = data.getAs[Double](4) + val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) + val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader + + @Since("1.6.0") + override def load(path: String): LocalLDAModel = super.load(path) } @@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] ( */ @Since("1.6.0") lazy val logPrior: Double = oldDistributedModel.logPrior + + @Since("1.6.0") + override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) +} + + +@Since("1.6.0") +object DistributedLDAModel extends MLReadable[DistributedLDAModel] { + + private[DistributedLDAModel] + class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val modelPath = new Path(path, "oldModel").toString + instance.oldDistributedModel.save(sc, modelPath) + } + } + + private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { + + private val className = classOf[DistributedLDAModel].getName + + override def load(path: String): DistributedLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val modelPath = new Path(path, "oldModel").toString + val oldModel = OldDistributedLDAModel.load(sc, modelPath) + val model = new DistributedLDAModel( + metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader + + @Since("1.6.0") + override def load(path: String): DistributedLDAModel = super.load(path) } @@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] ( @Since("1.6.0") @Experimental class LDA @Since("1.6.0") ( - @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { + @Since("1.6.0") override val uid: String) + extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("lda")) @@ -695,7 +794,7 @@ class LDA @Since("1.6.0") ( } -private[clustering] object LDA { +private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { @@ -706,4 +805,7 @@ private[clustering] object LDA { (docId, features) } } + + @Since("1.6.0") + override def load(path: String): LDA = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index cd520f09bd466..7384d065a2ea8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable { * @param topics Inferred topics (vocabSize x k matrix). */ @Since("1.3.0") -class LocalLDAModel private[clustering] ( +class LocalLDAModel private[spark] ( @Since("1.3.0") val topics: Matrix, @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double = 100) + override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { @Since("1.3.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index b634d31cc34f0..97dbfd9a4314a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -39,10 +40,24 @@ object LDASuite { }.map(v => new TestRow(v)) sql.createDataFrame(rdd) } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "checkpointInterval" -> 30, + "learningOffset" -> 1023.0, + "learningDecay" -> 0.52, + "subsamplingRate" -> 0.051 + ) } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { val k: Int = 5 val vocabSize: Int = 30 @@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val lp = model.logPrior assert(lp <= 0.0 && lp != Double.NegativeInfinity) } + + test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + } + + test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + } } From 19530da6903fa59b051eec69b9c17e231c68454b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 11:09:01 -0800 Subject: [PATCH 0991/2089] [SPARK-11926][SQL] unify GetStructField and GetInternalRowField Author: Wenchen Fan Closes #9909 from cloud-fan/get-struct. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 8 +++---- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/complexTypeExtractors.scala | 18 ++++++++-------- .../expressions/namedExpressions.scala | 4 ++-- .../sql/catalyst/expressions/objects.scala | 21 ------------------- .../expressions/ComplexTypeSuite.scala | 4 ++-- 9 files changed, 21 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 476becec4dd52..d133ad3f0d89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection { /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) + .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) /** Returns the current path or `BoundReference`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6485bdfb30234..1b2a8dc4c7f14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. attribute.get.dataType match { - case s: StructType => { - s.fields.map( f => { - val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) Alias(extract, target.get + "." + f.name)() - }) } + case _ => { throw new AnalysisException("Can only star expand struct data types. Attribute: `" + target.get + "`") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7bc9aed0b204e..0c10a56c555f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -111,7 +111,7 @@ object ExpressionEncoder { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index fa553e7c5324c..67518f52d4a58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -220,7 +220,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetInternalRowField(input, i, f.dataType))) + constructorFor(GetStructField(input, i))) } CreateExternalRow(convertedFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 540ed3500616a..169435a10ea2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyString: String = { transform { - case a: AttributeReference => PrettyAttribute(a.name) + case a: AttributeReference => PrettyAttribute(a.name, a.dataType) case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f871b737fff3a..10ce10aaf6da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -51,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -97,18 +97,18 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. - * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - override def dataType: DataType = child.dataType match { - case s: StructType => s(ordinal).dataType - // This is a hack to avoid breaking existing code until we remove the need for the struct field - case _ => field.dataType - } + private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + + override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, field.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00b7970bd16c6..26b6aca79971e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -273,7 +273,8 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute(name: String, dataType: DataType = NullType) + extends Attribute with Unevaluable { override def toString: String = name @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4a1f419f0ad8d..62d09f0f55105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { } } -case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) - extends UnaryExpression { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ - }) - } -} - /** * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb423f..62fd47234b33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } From 81012546ee5a80d2576740af0dad067b0f5962c5 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 24 Nov 2015 12:22:33 -0800 Subject: [PATCH 0992/2089] [SPARK-11872] Prevent the call to SparkContext#stop() in the listener bus's thread This is continuation of SPARK-11761 Andrew suggested adding this protection. See tail of https://github.com/apache/spark/pull/9741 Author: tedyu Closes #9852 from tedyu/master. --- .../scala/org/apache/spark/SparkContext.scala | 4 +++ .../spark/scheduler/SparkListenerSuite.scala | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b153a7b08e590..e19ba113702c6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1694,6 +1694,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Shut down the SparkContext. def stop() { + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop SparkContext within listener thread of" + + " AsynchronousListenerBus") + } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. if (!stopped.compareAndSet(false, true)) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 84e545851f49e..f20d5be7c0ee0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers +import org.apache.spark.SparkException import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} @@ -36,6 +37,21 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + test("don't call sc.stop in listener") { + sc = new SparkContext("local", "SparkListenerSuite") + val listener = new SparkContextStoppingListener(sc) + val bus = new LiveListenerBus + bus.addListener(listener) + + // Starting listener bus should flush all buffered events + bus.start(sc) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + bus.stop() + assert(listener.sparkExSeen) + } + test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -443,6 +459,21 @@ private class BasicJobCounter extends SparkListener { override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } +/** + * A simple listener that tries to stop SparkContext. + */ +private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener { + @volatile var sparkExSeen = false + override def onJobEnd(job: SparkListenerJobEnd): Unit = { + try { + sc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} + private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 From f3152722791b163fa66597b3684009058195ba33 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 12:54:37 -0800 Subject: [PATCH 0993/2089] [SPARK-11946][SQL] Audit pivot API for 1.6. Currently pivot's signature looks like ```scala scala.annotation.varargs def pivot(pivotColumn: Column, values: Column*): GroupedData scala.annotation.varargs def pivot(pivotColumn: String, values: Any*): GroupedData ``` I think we can remove the one that takes "Column" types, since callers should always be passing in literals. It'd also be more clear if the values are not varargs, but rather Seq or java.util.List. I also made similar changes for Python. Author: Reynold Xin Closes #9929 from rxin/SPARK-11946. --- .../apache/spark/scheduler/DAGScheduler.scala | 1 - python/pyspark/sql/group.py | 12 +- .../sql/catalyst/expressions/literals.scala | 1 + .../org/apache/spark/sql/GroupedData.scala | 154 ++++++++++-------- .../apache/spark/sql/JavaDataFrameSuite.java | 16 ++ .../spark/sql/DataFramePivotSuite.scala | 21 +-- .../apache/spark/sql/test/SQLTestData.scala | 1 + 7 files changed, 125 insertions(+), 81 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ae725b467d8c4..77a184dfe4bee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1574,7 +1574,6 @@ class DAGScheduler( } def stop() { - logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 227f40bc3cf53..d8ed7eb2dda64 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -168,20 +168,24 @@ def sum(self, *cols): """ @since(1.6) - def pivot(self, pivot_col, *values): + def pivot(self, pivot_col, values=None): """Pivots a column of the current DataFrame and preform the specified aggregation. :param pivot_col: Column to pivot :param values: Optional list of values of pivotColumn that will be translated to columns in the output data frame. If values are not provided the method with do an immediate call to .distinct() on the pivot column. - >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect() + + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ - jgd = self._jdf.pivot(_to_java_column(pivot_col), - _to_seq(self.sql_ctx._sc, values, _create_column_from_literal)) + if values is None: + jgd = self._jdf.pivot(pivot_col) + else: + jgd = self._jdf.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e34fd49be8389..68ec688c99f93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -44,6 +44,7 @@ object Literal { case a: Array[Byte] => Literal(a, BinaryType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) + case v: Literal => v case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 63dd7fbcbe9e4..ee7150cbbfbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.NumericType /** @@ -282,74 +282,96 @@ class GroupedData protected[sql]( } /** - * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified - * aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) - * // Or without specifying column values - * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { - case _: GroupedData.PivotType => - throw new UnsupportedOperationException("repeated pivots are not supported") - case GroupedData.GroupByType => - val pivotValues = if (values.nonEmpty) { - values.map { - case Column(literal: Literal) => literal - case other => - throw new UnsupportedOperationException( - s"The values of a pivot must be literals, found $other") - } - } else { - // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) - // Get the distinct values of the column and sort them so its consistent - val values = df.select(pivotColumn) - .distinct() - .sort(pivotColumn) - .map(_.get(0)) - .take(maxValues + 1) - .map(Literal(_)).toSeq - if (values.length > maxValues) { - throw new RuntimeException( - s"The pivot column $pivotColumn has more than $maxValues distinct values, " + - "this could indicate an error. " + - "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + - s"to at least the number of distinct values of the pivot column.") - } - values - } - new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) - case _ => - throw new UnsupportedOperationException("pivot is only supported after a groupBy") + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @since 1.6.0 + */ + def pivot(pivotColumn: String): GroupedData = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") - * // Or without specifying column values - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: String, values: Any*): GroupedData = { - val resolvedPivotColumn = Column(df.resolve(pivotColumn)) - pivot(resolvedPivotColumn, values.map(functions.lit): _*) + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + groupType match { + case GroupedData.GroupByType => + new GroupedData( + df, + groupingExprs, + GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + pivot(pivotColumn, values.asScala) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 567bdddece80e..a12fed3c0c6af 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -282,4 +282,20 @@ public void testSampleBy() { Assert.assertEquals(1, actual[1].getLong(0)); Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); } + + @Test + public void pivot() { + DataFrame df = context.table("courseSales"); + Row[] actual = df.groupBy("year") + .pivot("course", Arrays.asList("dotNET", "Java")) + .agg(sum("earnings")).orderBy("year").collect(); + + Assert.assertEquals(2012, actual[0].getInt(0)); + Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01); + + Assert.assertEquals(2013, actual[1].getInt(0)); + Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 0c23d142670c1..fc53aba68ebb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with literals") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil ) @@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with literals") { checkAnswer( - courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot courses with literals and multiple aggregations") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil @@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with string values (cast)") { checkAnswer( - courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot year with int values") { checkAnswer( - courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } @@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with no values") { // Note Java comes before dotNet in sorted order checkAnswer( - courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil ) } test("pivot year with no values") { checkAnswer( - courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot max values inforced") { + test("pivot max values enforced") { sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) - intercept[RuntimeException]( - courseSales.groupBy($"year").pivot($"course") + intercept[AnalysisException]( + courseSales.groupBy("year").pivot("course") ) sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index abad0d7eaaedf..83c63e04f344a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self => person salary complexData + courseSales } } From e6dd237463d2de8c506f0735dfdb3f43e8122513 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Nov 2015 15:08:02 -0600 Subject: [PATCH 0994/2089] [SPARK-11929][CORE] Make the repl log4j configuration override the root logger. In the default Spark distribution, there are currently two separate log4j config files, with different default values for the root logger, so that when running the shell you have a different default log level. This makes the shell more usable, since the logs don't overwhelm the output. But if you install a custom log4j.properties, you lose that, because then it's going to be used no matter whether you're running a regular app or the shell. With this change, the overriding of the log level is done differently; the log level repl's main class (org.apache.spark.repl.Main) is used to define the root logger's level when running the shell, defaulting to WARN if it's not set explicitly. On a somewhat related change, the shell output about the "sc" variable was changed a bit to contain a little more useful information about the application, since when the root logger's log level is WARN, that information is never shown to the user. Author: Marcelo Vanzin Closes #9816 from vanzin/shell-logging. --- conf/log4j.properties.template | 5 +++ .../spark/log4j-defaults-repl.properties | 33 -------------- .../apache/spark/log4j-defaults.properties | 5 +++ .../main/scala/org/apache/spark/Logging.scala | 45 ++++++++++--------- .../apache/spark/repl/SparkILoopInit.scala | 21 ++++----- .../org/apache/spark/repl/SparkILoop.scala | 25 ++++++----- 6 files changed, 57 insertions(+), 77 deletions(-) delete mode 100644 core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index f3046be54d7c6..9809b0c828487 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties deleted file mode 100644 index c85abc35b93bf..0000000000000 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - -# Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR -log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO -log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO - -# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support -log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL -log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index d44cc85dcbd82..0750488e4adf9 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 69f6e06ee0057..e35e158c7e8a6 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.log4j.{LogManager, PropertyConfigurator} +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder @@ -119,30 +119,31 @@ trait Logging { val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + // scalastyle:off println if (!log4j12Initialized) { - // scalastyle:off println - if (Utils.isInInterpreter) { - val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" - Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") - System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") - case None => - System.err.println(s"Spark was unable to load $replDefaultLogProps") - } - } else { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") - } + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") } - // scalastyle:on println } + + if (Utils.isInInterpreter) { + // Use the repl's main class to define the default log level when running the shell, + // overriding the root logger's config if they're different. + val rootLogger = LogManager.getRootLogger() + val replLogger = LogManager.getLogger("org.apache.spark.repl.Main") + val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + if (replLevel != rootLogger.getEffectiveLevel()) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel).") + rootLogger.setLevel(replLevel) + } + } + // scalastyle:on println } Logging.initialized = true diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index bd3314d94eed6..99e1e1df33fd8 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,18 +123,19 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) command(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 33d262558b1fc..e91139fb29f69 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -37,18 +37,19 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) processLine(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) processLine("import org.apache.spark.SparkContext._") processLine("import sqlContext.implicits._") @@ -85,7 +86,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = sparkStandardCommands - /** + /** * We override `loadFiles` because we need to initialize Spark *before* the REPL * sees any files, so that the Spark context is visible in those files. This is a bit of a * hack, but there isn't another hook available to us at this point. @@ -98,7 +99,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) object SparkILoop { - /** + /** * Creates an interpreter loop with default settings and feeds * the given code to it as input. */ From 58d9b260556a89a3d0832d583acafba1df7c6751 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 Nov 2015 14:33:28 -0800 Subject: [PATCH 0995/2089] [SPARK-11805] free the array in UnsafeExternalSorter during spilling After calling spill() on SortedIterator, the array inside InMemorySorter is not needed, it should be freed during spilling, this could help to join multiple tables with limited memory. Author: Davies Liu Closes #9793 from davies/free_array. --- .../unsafe/sort/UnsafeExternalSorter.java | 10 +++--- .../unsafe/sort/UnsafeInMemorySorter.java | 31 ++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9a7b2ad06cab6..2e40312674737 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -468,6 +468,12 @@ public long spill() throws IOException { } allocatedPages.clear(); } + + // in-memory sorter will not be used after spilling + assert(inMemSorter != null); + released += inMemSorter.getMemoryUsage(); + inMemSorter.free(); + inMemSorter = null; return released; } } @@ -489,10 +495,6 @@ public void loadNext() throws IOException { } upstream = nextUpstream; nextUpstream = null; - - assert(inMemSorter != null); - inMemSorter.free(); - inMemSorter = null; } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index a218ad4623f46..dce1f15a2963c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -108,6 +108,7 @@ public UnsafeInMemorySorter( */ public void free() { consumer.freeArray(array); + array = null; } public void reset() { @@ -160,28 +161,22 @@ public void insertRecord(long recordPointer, long keyPrefix) { pos++; } - public static final class SortedIterator extends UnsafeSorterIterator { + public final class SortedIterator extends UnsafeSorterIterator { - private final TaskMemoryManager memoryManager; - private final int sortBufferInsertPosition; - private final LongArray sortBuffer; - private int position = 0; + private final int numRecords; + private int position; private Object baseObject; private long baseOffset; private long keyPrefix; private int recordLength; - private SortedIterator( - TaskMemoryManager memoryManager, - int sortBufferInsertPosition, - LongArray sortBuffer) { - this.memoryManager = memoryManager; - this.sortBufferInsertPosition = sortBufferInsertPosition; - this.sortBuffer = sortBuffer; + private SortedIterator(int numRecords) { + this.numRecords = numRecords; + this.position = 0; } public SortedIterator clone () { - SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; iter.baseOffset = baseOffset; @@ -192,21 +187,21 @@ public SortedIterator clone () { @Override public boolean hasNext() { - return position < sortBufferInsertPosition; + return position / 2 < numRecords; } public int numRecordsLeft() { - return (sortBufferInsertPosition - position) / 2; + return numRecords - position / 2; } @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer.get(position); + final long recordPointer = array.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer.get(position + 1); + keyPrefix = array.get(position + 1); position += 2; } @@ -229,6 +224,6 @@ public void loadNext() { */ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); - return new SortedIterator(memoryManager, pos, array); + return new SortedIterator(pos / 2); } } From 34ca392da7097a1fbe48cd6c3ebff51453ca26ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 14:51:01 -0800 Subject: [PATCH 0996/2089] Added a line of comment to explain why the extra sort exists in pivot. --- sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee7150cbbfbca..abd531c4ba541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -304,7 +304,7 @@ class GroupedData protected[sql]( // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() - .sort(pivotColumn) + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order .map(_.get(0)) .take(maxValues + 1) .toSeq From c7f95df5c6d8eb2e6f11cf58b704fea34326a5f2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Nov 2015 14:59:14 -0800 Subject: [PATCH 0997/2089] [SPARK-11783][SQL] Fixes execution Hive client when using remote Hive metastore When using remote Hive metastore, `hive.metastore.uris` is set to the metastore URI. However, it overrides `javax.jdo.option.ConnectionURL` unexpectedly, thus the execution Hive client connects to the actual remote Hive metastore instead of the Derby metastore created in the temporary directory. Cleaning this configuration for the execution Hive client fixes this issue. Author: Cheng Lian Closes #9895 from liancheng/spark-11783.clean-remote-metastore-config. --- .../org/apache/spark/sql/hive/HiveContext.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0bb5af7d5c85..8a4264194ae8d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -736,6 +736,21 @@ private[hive] object HiveContext { s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + + // SPARK-11783: When "hive.metastore.uris" is set, the metastore connection mode will be + // remote (https://cwiki.apache.org/confluence/display/Hive/AdminManual+MetastoreAdmin + // mentions that "If hive.metastore.uris is empty local mode is assumed, remote otherwise"). + // Remote means that the metastore server is running in its own process. + // When the mode is remote, configurations like "javax.jdo.option.ConnectionURL" will not be + // used (because they are used by remote metastore server that talks to the database). + // Because execution Hive should always connects to a embedded derby metastore. + // We have to remove the value of hive.metastore.uris. So, the execution Hive client connects + // to the actual embedded derby metastore instead of the remote metastore. + // You can search HiveConf.ConfVars.METASTOREURIS in the code of HiveConf (in Hive's repo). + // Then, you will find that the local metastore mode is only set to true when + // hive.metastore.uris is not set. + propMap.put(ConfVars.METASTOREURIS.varname, "") + propMap.toMap } From 238ae51b66ac12d15fba6aff061804004c5ca6cb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 24 Nov 2015 15:54:10 -0800 Subject: [PATCH 0998/2089] [SPARK-11914][SQL] Support coalesce and repartition in Dataset APIs This PR is to provide two common `coalesce` and `repartition` in Dataset APIs. After reading the comments of SPARK-9999, I am unclear about the plan for supporting re-partitioning in Dataset APIs. Currently, both RDD APIs and Dataframe APIs provide users such a flexibility to control the number of partitions. In most traditional RDBMS, they expose the number of partitions, the partitioning columns, the table partitioning methods to DBAs for performance tuning and storage planning. Normally, these parameters could largely affect the query performance. Since the actual performance depends on the workload types, I think it is almost impossible to automate the discovery of the best partitioning strategy for all the scenarios. I am wondering if Dataset APIs are planning to hide these APIs from users? Feel free to reject my PR if it does not match the plan. Thank you for your answers. marmbrus rxin cloud-fan Author: gatorsmile Closes #9899 from gatorsmile/coalesce. --- .../scala/org/apache/spark/sql/Dataset.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 07647508421a4..17e2611790d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -152,6 +152,25 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * @since 1.6.0 + */ + def repartition(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = true, _) + } + + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @since 1.6.0 + */ + def coalesce(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = false, _) + } + /* *********************** * * Functional Operations * * *********************** */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 13eede1b17d8b..c253fdbb8c99e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -52,6 +52,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.takeAsList(1).get(0) == item) } + test("coalesce, repartition") { + val data = (1 to 100).map(i => ClassData(i.toString, i)) + val ds = data.toDS() + + assert(ds.repartition(10).rdd.partitions.length == 10) + checkAnswer( + ds.repartition(10), + data: _*) + + assert(ds.coalesce(1).rdd.partitions.length == 1) + checkAnswer( + ds.coalesce(1), + data: _*) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") checkAnswer( From 25bbd3c16e8e8be4d2c43000223d54650e9a3696 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:16:07 -0800 Subject: [PATCH 0999/2089] [SPARK-11967][SQL] Consistent use of varargs for multiple paths in DataFrameReader This patch makes it consistent to use varargs in all DataFrameReader methods, including Parquet, JSON, text, and the generic load function. Also added a few more API tests for the Java API. Author: Reynold Xin Closes #9945 from rxin/SPARK-11967. --- python/pyspark/sql/readwriter.py | 19 ++++++---- .../apache/spark/sql/DataFrameReader.scala | 36 +++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 23 ++++++++++++ sql/core/src/test/resources/text-suite2.txt | 1 + .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 5 files changed, 66 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/test/resources/text-suite2.txt diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e8f0d7ec77035..2e75f0c8a1827 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,7 +109,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. + :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options @@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes @@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options): self.options(**options) if path is not None: if type(path) == list: - paths = path - gateway = self._sqlContext._sc._gateway - jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) - return self._df(self._jreader.load(jpaths)) + return self._df( + self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load(path)) else: @@ -175,6 +174,8 @@ def json(self, path, schema=None): self.schema(schema) if isinstance(path, basestring): return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): return self._df(self._jreader.json(path._jrdd)) else: @@ -205,16 +206,20 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, path): + def text(self, paths): """Loads a text file and returns a [[DataFrame]] with a single string column named "text". Each line in the text file is a new row in the resulting DataFrame. + :param paths: string, or list of strings, for input path(s). + >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ - return self._df(self._jreader.text(path)) + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) @since(1.5) def orc(self, path): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dcb3737b70fbf..3ed1e55adec6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,17 +24,17 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.SqlParser import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: @@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def load(path: String): DataFrame = { option("path", path).load() } @@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def load(paths: Array[String]): DataFrame = { + @scala.annotation.varargs + def load(paths: String*): DataFrame = { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } @@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)
  • * - * @param path input path * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def json(path: String): DataFrame = format("json").load(path) + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • + *
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • + *
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • + *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + *
  • + *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers + * (e.g. 00012)
  • + * + * @since 1.6.0 + */ + def json(paths: String*): DataFrame = format("json").load(paths : _*) + /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. @@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * sqlContext.read().text("/path/to/spark/README.md") * }}} * - * @param path input path + * @param paths input path * @since 1.6.0 */ - def text(path: String): DataFrame = format("text").load(path) + @scala.annotation.varargs + def text(paths: String*): DataFrame = format("text").load(paths : _*) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index a12fed3c0c6af..8e0b2dbca4a98 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -298,4 +298,27 @@ public void pivot() { Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + + public void testGenericLoad() { + DataFrame df1 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } + + @Test + public void testTextLoad() { + DataFrame df1 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } } diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt new file mode 100644 index 0000000000000..f9d498c80493c --- /dev/null +++ b/sql/core/src/test/resources/text-suite2.txt @@ -0,0 +1 @@ +This is another file for testing multi path loading. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd6d06512ff60..76e9648aa7533 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + checkAnswer(sqlContext.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) checkAnswer(sqlContext.read.format("json").load(dir1), From 4d6bbbc03ddb6650b00eb638e4876a196014c19c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:58:55 -0800 Subject: [PATCH 1000/2089] [SPARK-11947][SQL] Mark deprecated methods with "This will be removed in Spark 2.0." Also fixed some documentation as I saw them. Author: Reynold Xin Closes #9930 from rxin/SPARK-11947. --- project/MimaExcludes.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 20 +++-- .../org/apache/spark/sql/DataFrame.scala | 72 +++++++++------ .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../org/apache/spark/sql/SQLContext.scala | 88 ++++++++++--------- .../org/apache/spark/sql/SQLImplicits.scala | 25 +++++- .../sql/{ => execution}/SparkSQLParser.scala | 15 ++-- .../org/apache/spark/sql/functions.scala | 52 ++++++----- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- 9 files changed, 172 insertions(+), 110 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/SparkSQLParser.scala (89%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bb45d1bb12146..54a9ad956d119 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -108,7 +108,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") ) ++ Seq( // SPARK-11485 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 30c554a85e693..b3cd9e1eff142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -42,7 +42,8 @@ private[sql] object Column { /** * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. - * @since 1.6.0 + * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. @@ -51,7 +52,8 @@ private[sql] object Column { */ class TypedColumn[-T, U]( expr: Expression, - private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) { + private[sql] val encoder: ExpressionEncoder[U]) + extends Column(expr) { /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -61,12 +63,11 @@ class TypedColumn[-T, U]( inputEncoder: ExpressionEncoder[_], schema: Seq[Attribute]): TypedColumn[T, U] = { val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U] (expr transform { - case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy( - aEncoder = Some(boundEncoder), - children = schema) - }, encoder) + new TypedColumn[T, U]( + expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy(aEncoder = Some(boundEncoder), children = schema) + }, + encoder) } } @@ -691,8 +692,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops * @since 1.3.0 + * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. */ - @deprecated("use isin", "1.5.0") + @deprecated("use isin. This will be removed in Spark 2.0.", "1.5.0") @scala.annotation.varargs def in(list: Any*): Column = isin(list : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5586fc994b98a..5eca1db9525ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1713,9 +1713,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `toDF()`. + * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. */ - @deprecated("use toDF", "1.3.0") + @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") def toSchemaRDD: DataFrame = this /** @@ -1725,9 +1725,9 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. + * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write w.jdbc(url, table, new Properties) @@ -1744,9 +1744,9 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) w.jdbc(url, table, new Properties) @@ -1757,9 +1757,9 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. + * @deprecated As of 1.4.0, replaced by `write().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.parquet(path)", "1.4.0") + @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") def saveAsParquetFile(path: String): Unit = { write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } @@ -1782,8 +1782,9 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String): Unit = { write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } @@ -1805,8 +1806,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(mode).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { write.mode(mode).saveAsTable(tableName) } @@ -1829,8 +1832,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { write.format(source).saveAsTable(tableName) } @@ -1853,8 +1858,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).saveAsTable(tableName) } @@ -1877,9 +1884,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1907,9 +1915,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1923,9 +1932,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. + * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use write.save(path)", "1.4.0") + @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String): Unit = { write.save(path) } @@ -1935,8 +1944,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default. * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).save(path)", "1.4.0") + @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, mode: SaveMode): Unit = { write.mode(mode).save(path) } @@ -1946,8 +1956,9 @@ class DataFrame private[sql]( * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).save(path)", "1.4.0") + @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String): Unit = { write.format(source).save(path) } @@ -1957,8 +1968,10 @@ class DataFrame private[sql]( * [[SaveMode]] specified by mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).save(path). " + + "This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).save(path) } @@ -1969,8 +1982,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1985,8 +2000,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1994,14 +2011,15 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } - /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } @@ -2012,8 +2030,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String): Unit = { write.mode(SaveMode.Append).insertInto(tableName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 17e2611790d5a..dd84b8bc11e2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** + * :: Experimental :: * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel * using functional or relational operations. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 39471d2fb79a7..46bf544fd885f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -942,33 +942,33 @@ class SQLContext private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } @@ -978,9 +978,9 @@ class SQLContext private[sql]( * [[DataFrame]] if no paths are passed in. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. + * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.parquet()", "1.4.0") + @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") @scala.annotation.varargs def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { @@ -995,9 +995,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String): DataFrame = { read.json(path) } @@ -1007,18 +1007,18 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { read.schema(schema).json(path) } /** * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(path) } @@ -1029,9 +1029,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) /** @@ -1040,9 +1040,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) /** @@ -1050,9 +1050,9 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1062,9 +1062,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1074,9 +1074,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1086,9 +1086,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1098,9 +1098,9 @@ class SQLContext private[sql]( * using the default data source configured by spark.sql.sources.default. * * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. + * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use read.load(path)", "1.4.0") + @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String): DataFrame = { read.load(path) } @@ -1110,8 +1110,9 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).load(path)", "1.4.0") + @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String, source: String): DataFrame = { read.format(source).load(path) } @@ -1122,8 +1123,10 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1135,7 +1138,8 @@ class SQLContext private[sql]( * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1148,7 +1152,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() @@ -1162,7 +1167,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() } @@ -1172,9 +1178,9 @@ class SQLContext private[sql]( * url named table. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String): DataFrame = { read.jdbc(url, table, new Properties) } @@ -1190,9 +1196,9 @@ class SQLContext private[sql]( * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc( url: String, table: String, @@ -1210,9 +1216,9 @@ class SQLContext private[sql]( * of the [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { read.jdbc(url, table, theParts, new Properties) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 25ffdcde17717..6735d02954b8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -30,19 +30,38 @@ import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * + * @since 1.6.0 */ abstract class SQLImplicits { + protected def _sqlContext: SQLContext + /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + /** @since 1.6.0 */ + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** @@ -84,9 +103,9 @@ abstract class SQLImplicits { DataFrameHolder(_sqlContext.createDataFrame(data)) } - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. + // Do NOT add more implicit conversions for primitive types. + // They are likely to break source compatibility by making existing implicit conversions + // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]]. /** * Creates a single column DataFrame from an RDD[Int]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index ea8fce6ca9cf2..b3e8d0d84937e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StringType - /** * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. * * @param fallback A function that parses an input string to a logical plan */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { @@ -100,14 +99,14 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case _ ~ dbName => ShowTablesCommand(dbName) } | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => ShowFunctions(f._1, Some(f._2)) - case None => ShowFunctions(None, None) + case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) + case None => logical.ShowFunctions(None, None) } ) private lazy val desc: Parser[LogicalPlan] = DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ { - case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) + case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6137ce3a70fdb..77dd5bc72508b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql - - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -39,11 +37,11 @@ import org.apache.spark.util.Utils * "bridge" methods due to the use of covariant return types. * * {{{ - * In LegacyFunctions: - * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * // In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); * - * In functions: - * public static org.apache.spark.sql.TypedColumn avg(...); + * // In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); * }}} * * This allows us to use the same functions both in typed [[Dataset]] operations and untyped @@ -2528,8 +2526,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } @@ -2541,8 +2540,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } @@ -2554,8 +2554,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } @@ -2567,8 +2568,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } @@ -2580,8 +2582,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } @@ -2593,8 +2596,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } @@ -2606,8 +2610,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } @@ -2619,8 +2624,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } @@ -2632,8 +2638,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } @@ -2644,9 +2651,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } @@ -2657,9 +2665,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2700,9 +2709,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. + * This will be removed in Spark 2.0. */ - @deprecated("Use callUDF", "1.5.0") + @deprecated("Use callUDF. This will be removed in Spark 2.0.", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 81af684ba0bf1..b554d135e4b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -80,7 +80,11 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat private var partitionedDF: DataFrame = _ - private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil) + private val partitionedDataSchema: StructType = + new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("c", StringType) protected override def beforeAll(): Unit = { this.tempPath = Utils.createTempDir() From a5d988763319f63a8e2b58673dd4f9098f17c835 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 24 Nov 2015 20:58:47 -0800 Subject: [PATCH 1001/2089] [STREAMING][FLAKY-TEST] Catch execution context race condition in `FileBasedWriteAheadLog.close()` There is a race condition in `FileBasedWriteAheadLog.close()`, where if delete's of old log files are in progress, the write ahead log may close, and result in a `RejectedExecutionException`. This is okay, and should be handled gracefully. Example test failures: https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.6-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=spark-test/95/testReport/junit/org.apache.spark.streaming.util/BatchedWriteAheadLogWithCloseFileAfterWriteSuite/BatchedWriteAheadLog___clean_old_logs/ The reason the test fails is in `afterEach`, `writeAheadLog.close` is called, and there may still be async deletes in flight. tdas zsxwing Author: Burak Yavuz Closes #9953 from brkyvz/flaky-ss. --- .../streaming/util/FileBasedWriteAheadLog.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 72705f1a9c010..f5165f7c39122 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} import scala.collection.JavaConverters._ @@ -176,10 +176,16 @@ private[streaming] class FileBasedWriteAheadLog( } oldLogFiles.foreach { logInfo => if (!executionContext.isShutdown) { - val f = Future { deleteFile(logInfo) }(executionContext) - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + try { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } catch { + case e: RejectedExecutionException => + logWarning("Execution context shutdown before deleting old WriteAheadLogs. " + + "This would not affect recovery correctness.", e) } } } From 151d7c2baf18403e6e59e97c80c8bcded6148038 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 21:30:53 -0800 Subject: [PATCH 1002/2089] [SPARK-10621][SQL] Consistent naming for functions in SQL, Python, Scala Author: Reynold Xin Closes #9948 from rxin/SPARK-10621. --- python/pyspark/sql/functions.py | 111 +++++++++++++--- .../org/apache/spark/sql/functions.scala | 124 ++++++++++++++---- 2 files changed, 196 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a1ca723bbd7ab..e3786e0fa5fb2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -150,18 +150,18 @@ def _(): _window_functions = { 'rowNumber': - """returns a sequential number starting at 1 within a window partition. - - This is equivalent to the ROW_NUMBER function in SQL.""", + """.. note:: Deprecated in 1.6, use row_number instead.""", + 'row_number': + """returns a sequential number starting at 1 within a window partition.""", 'denseRank': + """.. note:: Deprecated in 1.6, use dense_rank instead.""", + 'dense_rank': """returns the rank of rows within a window partition, without any gaps. The difference between rank and denseRank is that denseRank leaves no gaps in ranking sequence when there are ties. That is, if you were ranking a competition using denseRank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. - - This is equivalent to the DENSE_RANK function in SQL.""", + place and that the next person came in third.""", 'rank': """returns the rank of rows within a window partition. @@ -172,14 +172,14 @@ def _(): This is equivalent to the RANK function in SQL.""", 'cumeDist': + """.. note:: Deprecated in 1.6, use cume_dist instead.""", + 'cume_dist': """returns the cumulative distribution of values within a window partition, - i.e. the fraction of rows that are below the current row. - - This is equivalent to the CUME_DIST function in SQL.""", + i.e. the fraction of rows that are below the current row.""", 'percentRank': - """returns the relative rank (i.e. percentile) of rows within a window partition. - - This is equivalent to the PERCENT_RANK function in SQL.""", + """.. note:: Deprecated in 1.6, use percent_rank instead.""", + 'percent_rank': + """returns the relative rank (i.e. percentile) of rows within a window partition.""", } for _name, _doc in _functions.items(): @@ -189,7 +189,7 @@ def _(): for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): - globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) + globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc @@ -288,6 +288,38 @@ def countDistinct(col, *cols): @since(1.4) def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + +@since(1.6) +def input_file_name(): + """Creates a string column for the file name of the current Spark task. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.input_file_name()) + + +@since(1.6) +def isnan(col): + """An expression that returns true iff the column is NaN. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnan(_to_java_column(col))) + + +@since(1.6) +def isnull(col): + """An expression that returns true iff the column is null. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnull(_to_java_column(col))) + + +@since(1.6) +def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -300,11 +332,21 @@ def monotonicallyIncreasingId(): 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) - >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.monotonicallyIncreasingId()) + return Column(sc._jvm.functions.monotonically_increasing_id()) + + +@since(1.6) +def nanvl(col1, col2): + """Returns col1 if it is not NaN, or col2 if col1 is NaN. + + Both inputs should be floating point columns (DoubleType or FloatType). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @since(1.4) @@ -382,15 +424,23 @@ def shiftRightUnsigned(col, numBits): @since(1.4) def sparkPartitionId(): + """ + .. note:: Deprecated in 1.6, use spark_partition_id instead. + """ + return spark_partition_id() + + +@since(1.6) +def spark_partition_id(): """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. - >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sparkPartitionId()) + return Column(sc._jvm.functions.spark_partition_id()) @since(1.5) @@ -1410,6 +1460,33 @@ def explode(col): return Column(jc) +@since(1.6) +def get_json_object(col, path): + """ + Extracts json object from a json string based on json path specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + :param col: string column in json format + :param path: path to the json object to extract + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) + return Column(jc) + + +@since(1.6) +def json_tuple(col, fields): + """Creates a new row for a json column according to the given field names. + + :param col: string column in json format + :param fields: list of fields to extract + + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77dd5bc72508b..276c5dfc8b062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -472,6 +472,13 @@ object functions extends LegacyFunctions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def cumeDist(): Column = cume_dist() + /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -481,13 +488,17 @@ object functions extends LegacyFunctions { * cumeDist(x) = number of values before (and including) x / N * }}} * - * - * This is equivalent to the CUME_DIST function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") + def denseRank(): Column = dense_rank() /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -497,12 +508,10 @@ object functions extends LegacyFunctions { * and had three people tie for second place, you would say that all three were in second * place and that the next person came in third. * - * This is equivalent to the DENSE_RANK function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -620,6 +629,13 @@ object functions extends LegacyFunctions { */ def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") + def percentRank(): Column = percent_rank() + /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. * @@ -631,9 +647,9 @@ object functions extends LegacyFunctions { * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } /** * Window function: returns the rank of rows within a window partition. @@ -650,15 +666,20 @@ object functions extends LegacyFunctions { */ def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. + */ + @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") + def rowNumber(): Column = row_number() + /** * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the ROW_NUMBER function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -720,20 +741,43 @@ object functions extends LegacyFunctions { @scala.annotation.varargs def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. + */ + @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") + def inputFileName(): Column = input_file_name() + /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs + * @since 1.6.0 */ - def inputFileName(): Column = withExpr { InputFileName() } + def input_file_name(): Column = withExpr { InputFileName() } + + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. + */ + @deprecated("Use isnan. This will be removed in Spark 2.0.", "1.6.0") + def isNaN(e: Column): Column = isnan(e) /** * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.5.0 + * @since 1.6.0 + */ + def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + + /** + * Return true iff the column is null. + * + * @group normal_funcs + * @since 1.6.0 */ - def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) } + def isnull(e: Column): Column = withExpr { IsNull(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -750,7 +794,24 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } + def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.6.0 + */ + def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -825,15 +886,23 @@ object functions extends LegacyFunctions { */ def randn(): Column = randn(Utils.random.nextLong) + /** + * @group normal_funcs + * @since 1.4.0 + * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def sparkPartitionId(): Column = spark_partition_id() + /** * Partition ID of the Spark task. * * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def sparkPartitionId(): Column = withExpr { SparkPartitionID() } + def spark_partition_id(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -2305,6 +2374,17 @@ object functions extends LegacyFunctions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + * + * @group collection_funcs + * @since 1.6.0 + */ + def get_json_object(e: Column, path: String): Column = withExpr { + GetJsonObject(e.expr, lit(path).expr) + } + /** * Creates a new row for a json column according to the given field names. * @@ -2313,7 +2393,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def json_tuple(json: Column, fields: String*): Column = withExpr { - require(fields.length > 0, "at least 1 field name should be given.") + require(fields.nonEmpty, "at least 1 field name should be given.") JsonTuple(json.expr +: fields.map(Literal.apply)) } From 2169886883d33b33acf378ac42a626576b342df1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Nov 2015 23:13:01 -0800 Subject: [PATCH 1003/2089] [SPARK-11979][STREAMING] Empty TrackStateRDD cannot be checkpointed and recovered from checkpoint file This solves the following exception caused when empty state RDD is checkpointed and recovered. The root cause is that an empty OpenHashMapBasedStateMap cannot be deserialized as the initialCapacity is set to zero. ``` Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 20, localhost): java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity at scala.Predef$.require(Predef.scala:233) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:96) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:86) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.readObject(StateMap.scala:291) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1017) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1893) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:1990) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1915) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:370) at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76) at org.apache.spark.serializer.DeserializationStream$$anon$1.getNext(Serializer.scala:181) at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:744) ``` Author: Tathagata Das Closes #9958 from tdas/SPARK-11979. --- .../spark/streaming/util/StateMap.scala | 19 +++++++----- .../spark/streaming/StateMapSuite.scala | 30 ++++++++++++------- .../streaming/rdd/TrackStateRDDSuite.scala | 10 +++++++ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 34287c3e00908..3f139ad138c88 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -59,7 +59,7 @@ private[streaming] object StateMap { def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) } } @@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = 64, + initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD ) extends StateMap[K, S] { self => @@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( deltaChainThreshold = deltaChainThreshold) def this(deltaChainThreshold: Int) = this( - initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - @transient @volatile private var deltaMap = - new OpenHashMap[K, StateInfo[S]](initialCapacity) + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) /** Get the session data if it exists */ override def get(key: K): Option[S] = { @@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the parent map. Keep reading records, until the limiter is reached // First read the approximate number of records to expect and allocate properly size // OpenHashMap - val parentSessionStoreSizeHint = inputStream.readInt() + val parentStateMapSizeHint = inputStream.readInt() + val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY) val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( - initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + initialCapacity = newStateMapInitialCapacity, deltaChainThreshold) // Read the records until the limit marking object has been reached var parentSessionLoopDone = false @@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap { class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 48d3b41b66cbf..c4a01eaea739e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite { test("OpenHashMapBasedStateMap - serializing and deserializing") { val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + map1.put(1, 100, 1) map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") val map2 = map1.copy() + // Do not test compaction + assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + map2.put(3, 300, 3) map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") val map3 = map2.copy() + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) - - // Do not test compaction - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - - val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( - Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") } test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { @@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite { assert(map.deltaChainLength > deltaChainThreshold) assert(map.shouldCompact === true) - val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") assert(deser_map.deltaChainLength < deltaChainThreshold) assert(deser_map.shouldCompact === false) - assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") } test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { @@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } + private def testSerialization[MapType <: StateMap[Int, Int]]( + map: MapType, msg: String): MapType = { + val deserMap = Utils.deserialize[MapType]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assertMap(deserMap, map, 1, msg) + deserMap + } + // Assert whether all the data and operations on a state map matches that of a reference state map private def assertMap( mapToTest: StateMap[Int, Int], diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 0feb3af1abb0f..3b2d43f2ce581 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) } + test("checkpointing empty state RDD") { + val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From 2610e06124c7fc0b2b1cfb2e3050a35ab492fb71 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 01:02:36 -0800 Subject: [PATCH 1004/2089] [SPARK-11970][SQL] Adding JoinType into JoinWith and support Sample in Dataset API Except inner join, maybe the other join types are also useful when users are using the joinWith function. Thus, added the joinType into the existing joinWith call in Dataset APIs. Also providing another joinWith interface for the cartesian-join-like functionality. Please provide your opinions. marmbrus rxin cloud-fan Thank you! Author: gatorsmile Closes #9921 from gatorsmile/joinWith. --- .../scala/org/apache/spark/sql/Dataset.scala | 45 +++++++++++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 36 ++++++++++++--- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dd84b8bc11e2b..97eb5b969280d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -83,7 +83,6 @@ class Dataset[T] private[sql]( /** * Returns the schema of the encoded form of the objects in this [[Dataset]]. - * * @since 1.6.0 */ def schema: StructType = resolvedTEncoder.schema @@ -185,7 +184,6 @@ class Dataset[T] private[sql]( * .transform(featurize) * .transform(...) * }}} - * * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) @@ -453,6 +451,21 @@ class Dataset[T] private[sql]( c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + /** + * Returns a new [[Dataset]] by sampling a fraction of records. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = + withPlan(Sample(0.0, fraction, withReplacement, seed, _)) + + /** + * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + /* **************** * * Set operations * * **************** */ @@ -511,13 +524,17 @@ class Dataset[T] private[sql]( * types as well as working with relational data where either side of the join has column * names in common. * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val left = this.logicalPlan val right = other.logicalPlan - val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) val leftOutput = joined.analyzed.output.take(left.output.length) val rightOutput = joined.analyzed.output.takeRight(right.output.length) @@ -540,6 +557,18 @@ class Dataset[T] private[sql]( } } + /** + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. + * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + /* ************************** * * Gather to Driver Actions * * ************************** */ @@ -584,7 +613,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -594,7 +622,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c253fdbb8c99e..7d539180ded9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -185,17 +185,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(1, 2).toDS().as("b") checkAnswer( - ds1.joinWith(ds2, $"a.value" === $"b.value"), + ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), (1, 1), (2, 2)) } - test("joinWith, expression condition") { - val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() - val ds2 = Seq(("a", 1), ("b", 2)).toDS() + test("joinWith, expression condition, outer join") { + val nullInteger = null.asInstanceOf[Integer] + val nullString = null.asInstanceOf[String] + val ds1 = Seq(ClassNullableData("a", 1), + ClassNullableData("c", 3)).toDS() + val ds2 = Seq(("a", new Integer(1)), + ("b", new Integer(2))).toDS() checkAnswer( - ds1.joinWith(ds2, $"_1" === $"a"), - (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) + ds1.joinWith(ds2, $"_1" === $"a", "outer"), + (ClassNullableData("a", 1), ("a", new Integer(1))), + (ClassNullableData("c", 3), (nullString, nullInteger)), + (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) } test("joinWith tuple with primitive, expression") { @@ -225,7 +231,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), ((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))) - } test("groupBy function, keys") { @@ -367,6 +372,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a", 2 -> "bc", 3 -> "d") } + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + 5, 10, 52, 73) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + 3, 17, 27, 58, 62) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -440,6 +461,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { case class ClassData(a: String, b: Int) +case class ClassNullableData(a: String, b: Integer) /** * A class used to test serialization using encoders. This class throws exceptions when using From a0f1a11837bfffb76582499d36fbaf21a1d628cb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 01:03:18 -0800 Subject: [PATCH 1005/2089] [SPARK-11981][SQL] Move implementations of methods back to DataFrame from Queryable Also added show methods to Dataset. Author: Reynold Xin Closes #9964 from rxin/SPARK-11981. --- .../org/apache/spark/sql/DataFrame.scala | 35 ++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 77 ++++++++++++++++++- .../spark/sql/execution/Queryable.scala | 32 ++------ 3 files changed, 111 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5eca1db9525ec..d8319b9a97fcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -112,8 +112,8 @@ private[sql] object DataFrame { */ @Experimental class DataFrame private[sql]( - @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) + @transient override val sqlContext: SQLContext, + @DeveloperApi @transient override val queryExecution: QueryExecution) extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure @@ -282,6 +282,35 @@ class DataFrame private[sql]( */ def schema: StructType = queryExecution.analyzed.schema + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + override def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @group basic + * @since 1.3.0 + */ + override def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + override def explain(): Unit = explain(extended = false) + /** * Returns all column names and their data types as an array. * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 97eb5b969280d..da4600133290f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -61,8 +61,8 @@ import org.apache.spark.util.Utils */ @Experimental class Dataset[T] private[sql]( - @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution, + @transient override val sqlContext: SQLContext, + @transient override val queryExecution: QueryExecution, tEncoder: Encoder[T]) extends Queryable with Serializable { /** @@ -85,7 +85,25 @@ class Dataset[T] private[sql]( * Returns the schema of the encoded form of the objects in this [[Dataset]]. * @since 1.6.0 */ - def schema: StructType = resolvedTEncoder.schema + override def schema: StructType = resolvedTEncoder.schema + + /** + * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * @since 1.6.0 + */ + override def printSchema(): Unit = toDF().printSchema() + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(extended: Boolean): Unit = toDF().explain(extended) + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(): Unit = toDF().explain() /* ************* * * Conversions * @@ -152,6 +170,59 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @since 1.6.0 + */ + def show(numRows: Int): Unit = show(numRows, truncate = true) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. + * + * @since 1.6.0 + */ + def show(): Unit = show(20) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate) + /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 321e2c783537f..f2f5997d1b7c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution +import scala.util.control.NonFatal + import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType -import scala.util.control.NonFatal - /** A trait that holds shared code between DataFrames and Datasets. */ private[sql] trait Queryable { def schema: StructType @@ -37,31 +37,9 @@ private[sql] trait Queryable { } } - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println + def printSchema(): Unit - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } + def explain(extended: Boolean): Unit - /** - * Only prints the physical plan to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) + def explain(): Unit } From 63850026576b3ea7783f9d4b975171dc3cff6e4c Mon Sep 17 00:00:00 2001 From: Ashwin Swaroop Date: Wed, 25 Nov 2015 13:41:14 +0000 Subject: [PATCH 1006/2089] [SPARK-11686][CORE] Issue WARN when dynamic allocation is disabled due to spark.dynamicAllocation.enabled and spark.executor.instances both set Changed the log type to a 'warning' instead of 'info' as required. Author: Ashwin Swaroop Closes #9926 from ashwinswaroop/master. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e19ba113702c6..2c10779f2b893 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -556,7 +556,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") } _executorAllocationManager = From b9b6fbe89b6d1a890faa02c1a53bb670a6255362 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 13:49:58 +0000 Subject: [PATCH 1007/2089] =?UTF-8?q?[SPARK-11860][PYSAPRK][DOCUMENTATION]?= =?UTF-8?q?=20Invalid=20argument=20specification=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …for registerFunction [Python] Straightforward change on the python doc Author: Jeff Zhang Closes #9901 from zjffdu/SPARK-11860. --- python/pyspark/sql/context.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5a85ac31025e8..a49c1b58d0180 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -195,14 +195,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. + """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param samplingRatio: lambda function + :param f: python function :param returnType: a :class:`DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) From 0a5aef753e70e93d7e56054f354a52e4d4e18932 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 25 Nov 2015 09:34:34 -0600 Subject: [PATCH 1008/2089] [SPARK-10666][SPARK-6880][CORE] Use properties from ActiveJob associated with a Stage This issue was addressed in https://github.com/apache/spark/pull/5494, but the fix in that PR, while safe in the sense that it will prevent the SparkContext from shutting down, misses the actual bug. The intent of `submitMissingTasks` should be understood as "submit the Tasks that are missing for the Stage, and run them as part of the ActiveJob identified by jobId". Because of a long-standing bug, the `jobId` parameter was never being used. Instead, we were trying to use the jobId with which the Stage was created -- which may no longer exist as an ActiveJob, hence the crash reported in SPARK-6880. The correct fix is to use the ActiveJob specified by the supplied jobId parameter, which is guaranteed to exist at the call sites of submitMissingTasks. This fix should be applied to all maintenance branches, since it has existed since 1.0. kayousterhout pankajarora12 Author: Mark Hamstra Author: Imran Rashid Closes #6291 from markhamstra/SPARK-6880. --- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/DAGSchedulerSuite.scala | 107 +++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 77a184dfe4bee..e01a9609b9a0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,9 @@ class DAGScheduler( stage.resetInternalAccumulators() } - val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + // Use the scheduling pool, job group, description, etc. from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -1047,7 +1049,7 @@ class DAGScheduler( stage.pendingPartitions ++= tasks.map(_.partitionId) logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 4d6b25455226f..653d41fc053c9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.Properties + import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal @@ -262,9 +264,10 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - listener: JobListener = jobListener): Int = { + listener: JobListener = jobListener, + properties: Properties = null): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties)) jobId } @@ -1322,6 +1325,106 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + def checkJobPropertiesAndPriority(taskSet: TaskSet, expected: String, priority: Int): Unit = { + assert(taskSet.properties != null) + assert(taskSet.properties.getProperty("testProperty") === expected) + assert(taskSet.priority === priority) + } + + def launchJobsThatShareStageAndCancelFirst(): ShuffleDependency[Int, Int, Nothing] = { + val baseRdd = new MyRDD(sc, 1, Nil) + val shuffleDep1 = new ShuffleDependency(baseRdd, new HashPartitioner(1)) + val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep1)) + val shuffleDep2 = new ShuffleDependency(intermediateRdd, new HashPartitioner(1)) + val finalRdd1 = new MyRDD(sc, 1, List(shuffleDep2)) + val finalRdd2 = new MyRDD(sc, 1, List(shuffleDep2)) + val job1Properties = new Properties() + val job2Properties = new Properties() + job1Properties.setProperty("testProperty", "job1") + job2Properties.setProperty("testProperty", "job2") + + // Run jobs 1 & 2, both referencing the same stage, then cancel job1. + // Note that we have to submit job2 before we cancel job1 to have them actually share + // *Stages*, and not just shuffle dependencies, due to skipped stages (at least until + // we address SPARK-10193.) + val jobId1 = submit(finalRdd1, Array(0), properties = job1Properties) + val jobId2 = submit(finalRdd2, Array(0), properties = job2Properties) + assert(scheduler.activeJobs.nonEmpty) + val testProperty1 = scheduler.jobIdToActiveJob(jobId1).properties.getProperty("testProperty") + + // remove job1 as an ActiveJob + cancel(jobId1) + + // job2 should still be running + assert(scheduler.activeJobs.nonEmpty) + val testProperty2 = scheduler.jobIdToActiveJob(jobId2).properties.getProperty("testProperty") + assert(testProperty1 != testProperty2) + // NB: This next assert isn't necessarily the "desired" behavior; it's just to document + // the current behavior. We've already submitted the TaskSet for stage 0 based on job1, but + // even though we have cancelled that job and are now running it because of job2, we haven't + // updated the TaskSet's properties. Changing the properties to "job2" is likely the more + // correct behavior. + val job1Id = 0 // TaskSet priority for Stages run with "job1" as the ActiveJob + checkJobPropertiesAndPriority(taskSets(0), "job1", job1Id) + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + + shuffleDep1 + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active + */ + test("stage used by two jobs, the first no longer active (SPARK-6880)") { + launchJobsThatShareStageAndCancelFirst() + + // The next check is the key for SPARK-6880. For the stage which was shared by both job1 and + // job2 but never had any tasks submitted for job1, the properties of job2 are now used to run + // the stage. + checkJobPropertiesAndPriority(taskSets(1), "job2", 1) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + assert(taskSets(2).properties != null) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active, even when + * there are fetch failures + */ + test("stage used by two jobs, some fetch failures, and the first job no longer active " + + "(SPARK-6880)") { + val shuffleDep1 = launchJobsThatShareStageAndCancelFirst() + val job2Id = 1 // TaskSet priority for Stages run with "job2" as the ActiveJob + + // lets say there is a fetch failure in this task set, which makes us go back and + // run stage 0, attempt 1 + complete(taskSets(1), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // stage 0, attempt 1 should have the properties of job2 + assert(taskSets(2).stageId === 0) + assert(taskSets(2).stageAttemptId === 1) + checkJobPropertiesAndPriority(taskSets(2), "job2", job2Id) + + // run the rest of the stages normally, checking that they have the correct properties + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(3), "job2", job2Id) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(4), "job2", job2Id) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) From c1f85fc71e71e07534b89c84677d977bb20994f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 09:47:20 -0800 Subject: [PATCH 1009/2089] [SPARK-11956][CORE] Fix a few bugs in network lib-based file transfer. - NettyRpcEnv::openStream() now correctly propagates errors to the read side of the pipe. - NettyStreamManager now throws if the file being transferred does not exist. - The network library now correctly handles zero-sized streams. Author: Marcelo Vanzin Closes #9941 from vanzin/SPARK-11956. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 19 +++++++++---- .../spark/rpc/netty/NettyStreamManager.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 27 +++++++++++++----- .../client/TransportResponseHandler.java | 28 ++++++++++++------- .../org/apache/spark/network/StreamSuite.java | 23 ++++++++++++++- 5 files changed, 75 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68701f609f77a..c8fa870f50e68 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -27,7 +27,7 @@ import javax.annotation.Nullable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} +import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv( @volatile private var error: Throwable = _ - def setError(e: Throwable): Unit = error = e + def setError(e: Throwable): Unit = { + error = e + source.close() + } override def read(dst: ByteBuffer): Int = { - if (error != null) { - throw error + val result = if (error == null) { + Try(source.read(dst)) + } else { + Failure(error) + } + + result match { + case Success(bytesRead) => bytesRead + case Failure(error) => throw error } - source.read(dst) } override def close(): Unit = source.close() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index eb1d2604fb235..a2768b4252dcb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") } - require(file != null, s"File not found: $streamId") + require(file != null && file.isFile(), s"File not found: $streamId") new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2b664c6313efa..6cc958a5f6bc8 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); val jar = new File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) val fileUri = env.fileServer.addFile(file) + val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val destDir = Utils.createTempDir() - val destFile = new File(destDir, file.getName()) - val destJar = new File(destDir, jar.getName()) - val sm = new SecurityManager(conf) val hc = SparkHadoopUtil.get.conf - Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) - Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) - assert(Files.equal(file, destFile)) - assert(Files.equal(jar, destJar)) + val files = Seq( + (file, fileUri), + (empty, emptyUri), + (jar, jarUri)) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be181e0660826..4c15045363b84 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -185,16 +185,24 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream handler.", e); - deactivateStream(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } } } else { logger.error("Could not find callback for StreamResponse."); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 00158fd081626..538f3efe8d6f2 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,13 +51,14 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; private static TransportServer server; private static TransportClientFactory clientFactory; private static File testFile; private static File tempDir; + private static ByteBuffer emptyBuffer; private static ByteBuffer smallBuffer; private static ByteBuffer largeBuffer; @@ -73,6 +74,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); smallBuffer = createBuffer(100); largeBuffer = createBuffer(100000); @@ -103,6 +105,8 @@ public ManagedBuffer openStream(String streamId) { return new NioManagedBuffer(largeBuffer); case "smallBuffer": return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); case "file": return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); default: @@ -138,6 +142,18 @@ public static void tearDown() { } } + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + @Test public void testSingleStream() throws Throwable { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -226,6 +242,11 @@ public void run() { outFile = File.createTempFile("data", ".tmp", tempDir); out = new FileOutputStream(outFile); break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; default: throw new IllegalArgumentException(streamId); } From faabdfa2bd416ae514961535f1953e8e9e8b1f3f Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 Nov 2015 10:36:35 -0800 Subject: [PATCH 1010/2089] [SPARK-11984][SQL][PYTHON] Fix typos in doc for pivot for scala and python Author: felixcheung Closes #9967 from felixcheung/pypivotdoc. --- python/pyspark/sql/group.py | 6 +++--- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index d8ed7eb2dda64..1911588309aff 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,11 +169,11 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and preform the specified aggregation. + """Pivots a column of the current DataFrame and perform the specified aggregation. :param pivot_col: Column to pivot - :param values: Optional list of values of pivotColumn that will be translated to columns in - the output data frame. If values are not provided the method with do an immediate call + :param values: Optional list of values of pivot column that will be translated to columns in + the output DataFrame. If values are not provided the method will do an immediate call to .distinct() on the pivot column. >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index abd531c4ba541..13341a88a6b74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -282,7 +282,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -321,7 +321,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -353,7 +353,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. From 6b781576a15d8d5c5fbed8bef1c5bda95b3d44ac Mon Sep 17 00:00:00 2001 From: Zhongshuai Pei Date: Wed, 25 Nov 2015 10:37:34 -0800 Subject: [PATCH 1011/2089] [SPARK-11974][CORE] Not all the temp dirs had been deleted when the JVM exits deleting the temp dir like that ``` scala> import scala.collection.mutable import scala.collection.mutable scala> val a = mutable.Set(1,2,3,4,7,0,8,98,9) a: scala.collection.mutable.Set[Int] = Set(0, 9, 1, 2, 3, 7, 4, 8, 98) scala> a.foreach(x => {a.remove(x) }) scala> a.foreach(println(_)) 98 ``` You may not modify a collection while traversing or iterating over it.This can not delete all element of the collection Author: Zhongshuai Pei Closes #9951 from DoingDone9/Bug_RemainDir. --- .../scala/org/apache/spark/util/ShutdownHookManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index db4a8b304ec3e..4012dca3ecdf8 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -57,7 +57,9 @@ private[spark] object ShutdownHookManager extends Logging { // Add a shutdown hook to delete the temp dirs when the JVM exits addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => + // we need to materialize the paths to delete because deleteRecursively removes items from + // shutdownDeletePaths as we are traversing through it. + shutdownDeletePaths.toArray.foreach { dirPath => try { logInfo("Deleting directory " + dirPath) Utils.deleteRecursively(new File(dirPath)) From dc1d324fdf83e9f4b1debfb277533b002691d71f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 25 Nov 2015 11:11:39 -0800 Subject: [PATCH 1012/2089] [SPARK-11969] [SQL] [PYSPARK] visualization of SQL query for pyspark Currently, we does not have visualization for SQL query from Python, this PR fix that. cc zsxwing Author: Davies Liu Closes #9949 from davies/pyspark_sql_ui. --- python/pyspark/sql/dataframe.py | 2 +- .../main/scala/org/apache/spark/sql/DataFrame.scala | 7 +++++++ .../org/apache/spark/sql/execution/python.scala | 12 +++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0dd75ba7ca820..746bb55e14f22 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -277,7 +277,7 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d8319b9a97fcf..6197f10813a3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -1735,6 +1736,12 @@ class DataFrame private[sql]( EvaluatePython.javaToPython(rdd) } + protected[sql] def collectToPython(): Int = { + withNewExecutionId { + PythonRDD.collectAndServe(javaToPython.rdd) + } + } + //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// // Deprecated methods diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index d611b0011da16..defcec95fb555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -121,11 +121,13 @@ object EvaluatePython { def takeAndServe(df: DataFrame, n: Int): Int = { registerPicklers() - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") + df.withNewExecutionId { + val iter = new SerDeUtil.AutoBatchedPickler( + df.queryExecution.executedPlan.executeTake(n).iterator.map { row => + EvaluatePython.toJava(row, df.schema) + }) + PythonRDD.serveIterator(iter, s"serve-DataFrame") + } } /** From 0dee44a6646daae0cc03dbc32125e080dff0f4ae Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 25 Nov 2015 11:35:52 -0800 Subject: [PATCH 1013/2089] [MINOR] Remove unnecessary spaces in `include_example.rb` Author: Yu ISHIKAWA Closes #9960 from yu-iskw/minor-remove-spaces. --- docs/_plugins/include_example.rb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 549f81fe1b1bc..564c86680f68e 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -20,12 +20,12 @@ module Jekyll class IncludeExampleTag < Liquid::Tag - + def initialize(tag_name, markup, tokens) @markup = markup super end - + def render(context) site = context.registers[:site] config_dir = '../examples/src/main' @@ -37,7 +37,7 @@ def render(context) code = File.open(@file).read.encode("UTF-8") code = select_lines(code) - + rendered_code = Pygments.highlight(code, :lexer => @lang) hint = "
    Find full example code at " \ @@ -45,7 +45,7 @@ def render(context) rendered_code + hint end - + # Trim the code block so as to have the same indention, regardless of their positions in the # code file. def trim_codeblock(lines) From 67b67320884282ccf3102e2af96f877e9b186517 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 11:37:42 -0800 Subject: [PATCH 1014/2089] [DOCUMENTATION] Fix minor doc error Author: Jeff Zhang Closes #9956 from zjffdu/dev_typo. --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 4de202d7f7631..741d6b2b37a87 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,7 +35,7 @@ val sc = new SparkContext(conf) {% endhighlight %} Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may -actually require one to prevent any sort of starvation issues. +actually require more than 1 thread to prevent any sort of starvation issues. Properties that specify some time duration should be configured with a unit of time. The following format is accepted: From 83653ac5e71996c5a366a42170bed316b208f1b5 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 25 Nov 2015 11:39:00 -0800 Subject: [PATCH 1015/2089] [SPARK-10864][WEB UI] app name is hidden if window is resized Currently the Web UI navbar has a minimum width of 1200px; so if a window is resized smaller than that the app name goes off screen. The 1200px width seems to have been chosen since it fits the longest example app name without wrapping. To work with smaller window widths I made the tabs wrap since it looked better than wrapping the app name. This is a distinct change in how the navbar looks and I'm not sure if it's what we actually want to do. Other notes: - min-width set to 600px to keep the tabs from wrapping individually (will need to be adjusted if tabs are added) - app name will also wrap (making three levels) if a really really long app name is used Author: Alex Bozarth Closes #9874 from ajbozarth/spark10864. --- .../main/resources/org/apache/spark/ui/static/webui.css | 8 ++------ core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 04f3070d25b4a..c628a0c706553 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -16,14 +16,9 @@ */ .navbar { - height: 50px; font-size: 15px; margin-bottom: 15px; - min-width: 1200px -} - -.navbar .navbar-inner { - height: 50px; + min-width: 600px; } .navbar .brand { @@ -46,6 +41,7 @@ .navbar-text { height: 50px; line-height: 3.3; + white-space: nowrap; } table.sortable thead { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 84a1116a5c498..1e8194f57888e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -210,10 +210,10 @@ private[spark] object UIUtils extends Logging { {org.apache.spark.SPARK_VERSION}
    - +
    From 9f3e59a16822fb61d60cf103bd4f7823552939c6 Mon Sep 17 00:00:00 2001 From: wangt Date: Wed, 25 Nov 2015 11:41:05 -0800 Subject: [PATCH 1016/2089] [SPARK-11880][WINDOWS][SPARK SUBMIT] bin/load-spark-env.cmd loads spark-env.cmd from wrong directory * On windows the `bin/load-spark-env.cmd` tries to load `spark-env.cmd` from `%~dp0..\..\conf`, where `~dp0` points to `bin` and `conf` is only one level up. * Updated `bin/load-spark-env.cmd` to load `spark-env.cmd` from `%~dp0..\conf`, instead of `%~dp0..\..\conf` Author: wangt Closes #9863 from toddwan/master. --- bin/load-spark-env.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 36d932c453b6f..59080edd294f2 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -27,7 +27,7 @@ if [%SPARK_ENV_LOADED%] == [] ( if not [%SPARK_CONF_DIR%] == [] ( set user_conf_dir=%SPARK_CONF_DIR% ) else ( - set user_conf_dir=%~dp0..\..\conf + set user_conf_dir=%~dp0..\conf ) call :LoadSparkEnv From 88875d9413ec7d64a88d40857ffcf97b5853a7f2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 25 Nov 2015 11:42:53 -0800 Subject: [PATCH 1017/2089] [SPARK-10558][CORE] Fix wrong executor state in Master `ExecutorAdded` can only be sent to `AppClient` when worker report back the executor state as `LOADING`, otherwise because of concurrency issue, `AppClient` will possibly receive `ExectuorAdded` at first, then `ExecutorStateUpdated` with `LOADING` state. Also Master will change the executor state from `LAUNCHING` to `RUNNING` (`AppClient` report back the state as `RUNNING`), then to `LOADING` (worker report back to state as `LOADING`), it should be `LAUNCHING` -> `LOADING` -> `RUNNING`. Also it is wrongly shown in master UI, the state of executor should be `RUNNING` rather than `LOADING`: ![screen shot 2015-09-11 at 2 30 28 pm](https://cloud.githubusercontent.com/assets/850797/9809254/3155d840-5899-11e5-8cdf-ad06fef75762.png) Author: jerryshao Closes #8714 from jerryshao/SPARK-10558. --- .../org/apache/spark/deploy/ExecutorState.scala | 2 +- .../org/apache/spark/deploy/client/AppClient.scala | 3 --- .../org/apache/spark/deploy/master/Master.scala | 14 +++++++++++--- .../org/apache/spark/deploy/worker/Worker.scala | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index efa88c62e1f5d..69c98e28931d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy private[deploy] object ExecutorState extends Enumeration { - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value + val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value type ExecutorState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index afab362e213b5..df6ba7d669ce9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -178,9 +178,6 @@ private[spark] class AppClient( val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not - // guaranteed), `ExecutorStateChanged` may be sent to a dead master. - sendToMaster(ExecutorStateChanged(appId.get, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b25a487806c7f..9952c97dbdffc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -253,9 +253,17 @@ private[deploy] class Master( execOption match { case Some(exec) => { val appInfo = idToApp(appId) + val oldState = exec.state exec.state = state - if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } + + if (state == ExecutorState.RUNNING) { + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") + appInfo.resetRetryCount() + } + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) + if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -702,8 +710,8 @@ private[deploy] class Master( worker.addExecutor(exec) worker.endpoint.send(LaunchExecutor(masterUrl, exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) - exec.application.driver.send(ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) + exec.application.driver.send( + ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index a45867e7680ec..418faf8fc967f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -469,7 +469,7 @@ private[deploy] class Worker( executorDir, workerUri, conf, - appLocalDirs, ExecutorState.LOADING) + appLocalDirs, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ From d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 25 Nov 2015 11:47:21 -0800 Subject: [PATCH 1018/2089] [SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and TransformFunctionSerializer to Java The Python exception track in TransformFunction and TransformFunctionSerializer is not sent back to Java. Py4j just throws a very general exception, which is hard to debug. This PRs adds `getFailure` method to get the failure message in Java side. Author: Shixiong Zhu Closes #9922 from zsxwing/SPARK-11935. --- python/pyspark/streaming/tests.py | 82 ++++++++++++++++++- python/pyspark/streaming/util.py | 29 +++++-- .../streaming/api/python/PythonDStream.scala | 52 ++++++++++-- 3 files changed, 144 insertions(+), 19 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a0e0267cafa58..d380d697bc51c 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -404,17 +404,69 @@ def func(dstream): self._test_func(input, func, expected) def test_failed_func(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) input = [self.sc.parallelize([d], 1) for d in range(4)] input_stream = self.ssc.queueStream(input) def failed_func(i): - raise ValueError("failed") + raise ValueError("This is a special error") input_stream.map(failed_func).pprint() self.ssc.start() try: self.ssc.awaitTerminationOrTimeout(10) except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func2(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream1 = self.ssc.queueStream(input) + input_stream2 = self.ssc.queueStream(input) + + def failed_func(rdd1, rdd2): + raise ValueError("This is a special error") + + input_stream1.transformWith(failed_func, input_stream2, True).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func_with_reseting_failure(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + if i == 1: + # Make it fail in the second batch + raise ValueError("This is a special error") + else: + return i + + # We should be able to see the results of the 3rd and 4th batches even if the second batch + # fails + expected = [[0], [2], [3]] + self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3)) + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) return self.fail("a failed func should throw an error") @@ -780,6 +832,34 @@ def tearDown(self): if self.cpd is not None: shutil.rmtree(self.cpd) + def test_transform_function_serializer_failure(self): + inputd = tempfile.mkdtemp() + self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure") + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + + # A function that cannot be serialized + def process(time, rdd): + sc.parallelize(range(1, 10)) + + ssc.textFileStream(inputd).foreachRDD(process) + return ssc + + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + try: + self.ssc.start() + except: + import traceback + failure = traceback.format_exc() + self.assertTrue( + "It appears that you are attempting to reference SparkContext" in failure) + return + + self.fail("using SparkContext in process should fail because it's not Serializable") + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 767c732eb90b4..c7f02bca2ae38 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -38,12 +38,15 @@ def __init__(self, ctx, func, *deserializers): self.func = func self.deserializers = deserializers self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.failure = None def rdd_wrapper(self, func): self._rdd_wrapper = func return self def call(self, milliseconds, jrdds): + # Clear the failure + self.failure = None try: if self.ctx is None: self.ctx = SparkContext._active_spark_context @@ -62,9 +65,11 @@ def call(self, milliseconds, jrdds): r = self.func(t, *rdds) if r: return r._jrdd - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunction(%s)" % self.func @@ -89,22 +94,28 @@ def __init__(self, ctx, serializer, gateway=None): self.serializer = serializer self.gateway = gateway or self.ctx._gateway self.gateway.jvm.PythonDStream.registerSerializer(self) + self.failure = None def dumps(self, id): + # Clear the failure + self.failure = None try: func = self.gateway.gateway_property.pool[id] return bytearray(self.serializer.dumps((func.func, func.deserializers))) - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() def loads(self, data): + # Clear the failure + self.failure = None try: f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index dfc569451df86..994309ddd0a3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -26,6 +26,7 @@ import scala.language.existentials import py4j.GatewayServer +import org.apache.spark.SparkException import org.apache.spark.api.java._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -40,6 +41,13 @@ import org.apache.spark.util.Utils */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] + + /** + * Get the failure, if any, in the last call to `call`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -48,6 +56,13 @@ private[python] trait PythonTransformFunction { private[python] trait PythonTransformFunctionSerializer { def dumps(id: String): Array[Byte] def loads(bytes: Array[Byte]): PythonTransformFunction + + /** + * Get the failure, if any, in the last call to `dumps` or `loads`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -59,18 +74,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) - .map(_.rdd) + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava - Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { - pfunc.call(time.milliseconds, rdds) + callPythonTransformFunction(time.milliseconds, rdds) + } + + private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = { + val resultRDD = pfunc.call(time, rdds) + val failure = pfunc.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + resultRDD } private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -103,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer { /* * Register a serializer from Python, should be called during initialization */ - def register(ser: PythonTransformFunctionSerializer): Unit = { + def register(ser: PythonTransformFunctionSerializer): Unit = synchronized { serializer = ser } - def serialize(func: PythonTransformFunction): Array[Byte] = { + def serialize(func: PythonTransformFunction): Array[Byte] = synchronized { require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) val id = f.get(h).asInstanceOf[String] - serializer.dumps(id) + val results = serializer.dumps(id) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + results } - def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized { require(serializer != null, "Serializer has not been registered!") - serializer.loads(bytes) + val pfunc = serializer.loads(bytes) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + pfunc } } From 4e81783e92f464d479baaf93eccc3adb1496989a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 12:58:18 -0800 Subject: [PATCH 1019/2089] [SPARK-11866][NETWORK][CORE] Make sure timed out RPCs are cleaned up. This change does a couple of different things to make sure that the RpcEnv-level code and the network library agree about the status of outstanding RPCs. For RPCs that do not expect a reply ("RpcEnv.send"), support for one way messages (hello CORBA!) was added to the network layer. This is a "fire and forget" message that does not require any state to be kept by the TransportClient; as a result, the RpcEnv 'Ack' message is not needed anymore. For RPCs that do expect a reply ("RpcEnv.ask"), the network library now returns the internal RPC id; if the RpcEnv layer decides to time out the RPC before the network layer does, it now asks the TransportClient to forget about the RPC, so that if the network-level timeout occurs, the client is not killed. As part of implementing the above, I cleaned up some of the code in the netty rpc backend, removing types that were not necessary and factoring out some common code. Of interest is a slight change in the exceptions when posting messages to a stopped RpcEnv; that's mostly to avoid nasty error messages from the local-cluster backend when shutting down, which pollutes the terminal output. Author: Marcelo Vanzin Closes #9917 from vanzin/SPARK-11866. --- .../spark/deploy/worker/ExecutorRunner.scala | 6 +- .../apache/spark/rpc/netty/Dispatcher.scala | 55 +++---- .../org/apache/spark/rpc/netty/Inbox.scala | 28 ++-- .../spark/rpc/netty/NettyRpcCallContext.scala | 35 +--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 153 +++++++----------- .../org/apache/spark/rpc/netty/Outbox.scala | 64 ++++++-- .../apache/spark/rpc/netty/InboxSuite.scala | 6 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 2 +- .../spark/network/client/TransportClient.java | 34 +++- .../spark/network/protocol/Message.java | 4 +- .../network/protocol/MessageDecoder.java | 3 + .../spark/network/protocol/OneWayMessage.java | 75 +++++++++ .../spark/network/sasl/SaslRpcHandler.java | 5 + .../spark/network/server/RpcHandler.java | 36 +++++ .../server/TransportRequestHandler.java | 18 ++- .../apache/spark/network/ProtocolSuite.java | 2 + .../spark/network/RpcIntegrationSuite.java | 31 ++++ .../spark/network/sasl/SparkSaslSuite.java | 9 ++ 18 files changed, 374 insertions(+), 192 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 3aef0515cbf6e..25a17473e4b53 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -92,7 +92,11 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + try { + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + } catch { + case e: IllegalStateException => logWarning(e.getMessage(), e) + } } /** Stop this executor runner, including killing the process it launched */ diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index eb25d6c7b721b..533c9847661b6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -106,44 +106,30 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage( - name, - _ => message, - () => { logWarning(s"Drop $message because $name has been stopped") }) + postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e)) } } /** Posts a message sent by a remote endpoint. */ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new RemoteNettyRpcCallContext( - nettyEnv, sender, callback, message.senderAddress, message.needReply) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - postMessage(message.receiver.name, createMessage, onEndpointStopped) + val rpcCallContext = + new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } /** Posts a message sent by a local endpoint. */ def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - p.tryFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } + val rpcCallContext = + new LocalNettyRpcCallContext(message.senderAddress, p) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e)) + } - postMessage(message.receiver.name, createMessage, onEndpointStopped) + /** Posts a one-way message. */ + def postOneWayMessage(message: RequestMessage): Unit = { + postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), + (e) => throw e) } /** @@ -155,21 +141,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { */ private def postMessage( endpointName: String, - createMessageFn: NettyRpcEndpointRef => InboxMessage, - callbackIfStopped: () => Unit): Unit = { + message: InboxMessage, + callbackIfStopped: (Exception) => Unit): Unit = { val shouldCallOnStop = synchronized { val data = endpoints.get(endpointName) if (stopped || data == null) { true } else { - data.inbox.post(createMessageFn(data.ref)) + data.inbox.post(message) receivers.offer(data) false } } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - callbackIfStopped() + val error = if (stopped) { + new IllegalStateException("RpcEnv already stopped.") + } else { + new SparkException(s"Could not find $endpointName or it has been stopped.") + } + callbackIfStopped(error) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 464027f07cc88..175463cc10319 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -27,10 +27,13 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} private[netty] sealed trait InboxMessage -private[netty] case class ContentMessage( +private[netty] case class OneWayMessage( + senderAddress: RpcAddress, + content: Any) extends InboxMessage + +private[netty] case class RpcMessage( senderAddress: RpcAddress, content: Any, - needReply: Boolean, context: NettyRpcCallContext) extends InboxMessage private[netty] case object OnStart extends InboxMessage @@ -96,29 +99,24 @@ private[netty] class Inbox( while (true) { safelyCall(endpoint) { message match { - case ContentMessage(_sender, content, needReply, context) => - // The partial function to call - val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive + case RpcMessage(_sender, content, context) => try { - pf.applyOrElse[Any, Unit](content, { msg => + endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg => throw new SparkException(s"Unsupported message $message from ${_sender}") }) - if (!needReply) { - context.finish() - } } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - } + context.sendFailure(e) // Throw the exception -- this exception will be caught by the safelyCall function. // The endpoint's onError function will be called. throw e } + case OneWayMessage(_sender, content) => + endpoint.receive.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + case OnStart => endpoint.onStart() if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 21d5bb4923d1b..6637e2321f673 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -23,49 +23,28 @@ import org.apache.spark.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -private[netty] abstract class NettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, - override val senderAddress: RpcAddress, - needReply: Boolean) +private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) extends RpcCallContext with Logging { protected def send(message: Any): Unit override def reply(response: Any): Unit = { - if (needReply) { - send(AskResponse(endpointRef, response)) - } else { - throw new IllegalStateException( - s"Cannot send $response to the sender because the sender does not expect a reply") - } + send(response) } override def sendFailure(e: Throwable): Unit = { - if (needReply) { - send(AskResponse(endpointRef, RpcFailure(e))) - } else { - logError(e.getMessage, e) - throw new IllegalStateException( - "Cannot send reply to the sender because the sender won't handle it") - } + send(RpcFailure(e)) } - def finish(): Unit = { - if (!needReply) { - send(Ack(endpointRef)) - } - } } /** * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. */ private[netty] class LocalNettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, senderAddress: RpcAddress, - needReply: Boolean, p: Promise[Any]) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { p.success(message) @@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext( */ private[netty] class RemoteNettyRpcCallContext( nettyEnv: NettyRpcEnv, - endpointRef: NettyRpcEndpointRef, callback: RpcResponseCallback, - senderAddress: RpcAddress, - needReply: Boolean) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + senderAddress: RpcAddress) + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c8fa870f50e68..c7d74fa1d9195 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -150,7 +150,7 @@ private[netty] class NettyRpcEnv( private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { if (receiver.client != null) { - receiver.client.sendRpc(message.content, message.createCallback(receiver.client)); + message.sendWith(receiver.client) } else { require(receiver.address != null, "Cannot send message to client endpoint with no listen address.") @@ -182,25 +182,10 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { // Message to a local RPC endpoint. - val promise = Promise[Any]() - dispatcher.postLocalMessage(message, promise) - promise.future.onComplete { - case Success(response) => - val ack = response.asInstanceOf[Ack] - logTrace(s"Received ack from ${ack.sender}") - case Failure(e) => - logWarning(s"Exception when sending $message", e) - }(ThreadUtils.sameThread) + dispatcher.postOneWayMessage(message) } else { // Message to a remote RPC endpoint. - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - logWarning(s"Exception when sending $message", e) - }, - (client, response) => { - val ack = deserialize[Ack](client, response) - logDebug(s"Receive ack from ${ack.sender}") - })) + postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) } } @@ -208,46 +193,52 @@ private[netty] class NettyRpcEnv( clientFactory.createClient(address.host, address.port) } - private[netty] def ask(message: RequestMessage): Future[Any] = { + private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() val remoteAddr = message.receiver.address + + def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning(s"Ignored failure: $e") + } + } + + def onSuccess(reply: Any): Unit = reply match { + case RpcFailure(e) => onFailure(e) + case rpcReply => + if (!promise.trySuccess(rpcReply)) { + logWarning(s"Ignored message: $reply") + } + } + if (remoteAddr == address) { val p = Promise[Any]() - dispatcher.postLocalMessage(message, p) p.future.onComplete { - case Success(response) => - val reply = response.asInstanceOf[AskResponse] - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - case Failure(e) => - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) } else { - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - }, - (client, response) => { - val reply = deserialize[AskResponse](client, response) - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - })) + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) } - promise.future + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + promise.tryFailure( + new TimeoutException("Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) + }(ThreadUtils.sameThread) + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } private[netty] def serialize(content: Any): Array[Byte] = { @@ -512,25 +503,12 @@ private[netty] class NettyRpcEndpointRef( override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - val promise = Promise[Any]() - val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) - f.onComplete { v => - timeoutCancelable.cancel(true) - if (!promise.tryComplete(v)) { - logWarning(s"Ignore message $v") - } - }(ThreadUtils.sameThread) - promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) + nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) } override def toString: String = s"NettyRpcEndpointRef(${_address})" @@ -549,24 +527,7 @@ private[netty] class NettyRpcEndpointRef( * The message that is sent from the sender to the receiver. */ private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) - -/** - * The base trait for all messages that are sent back from the receiver to the sender. - */ -private[netty] trait ResponseMessage - -/** - * The reply for `ask` from the receiver side. - */ -private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) - extends ResponseMessage - -/** - * A message to send back to the receiver side. It's necessary because [[TransportClient]] only - * clean the resources when it receives a reply. - */ -private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) /** * A response that indicates some failure happens in the receiver side. @@ -598,6 +559,18 @@ private[netty] class NettyRpcHandler( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postRemoteMessage(messageToDispatch, callback) + } + + override def receive( + client: TransportClient, + message: Array[Byte]): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postOneWayMessage(messageToDispatch) + } + + private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) @@ -605,14 +578,12 @@ private[netty] class NettyRpcHandler( dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) - val messageToDispatch = if (requestMessage.senderAddress == null) { - // Create a new message with the socket address of the client as the sender. - RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, - requestMessage.needReply) - } else { - requestMessage - } - dispatcher.postRemoteMessage(messageToDispatch, callback) + if (requestMessage.senderAddress == null) { + // Create a new message with the socket address of the client as the sender. + RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + } else { + requestMessage + } } override def getStreamManager: StreamManager = streamManager diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 2f6817f2eb935..36fdd00bbc4c2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -22,22 +22,56 @@ import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.rpc.RpcAddress -private[netty] case class OutboxMessage(content: Array[Byte], - _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) { +private[netty] sealed trait OutboxMessage { - def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() { - override def onFailure(e: Throwable): Unit = { - _onFailure(e) - } + def sendWith(client: TransportClient): Unit - override def onSuccess(response: Array[Byte]): Unit = { - _onSuccess(client, response) - } + def onFailure(e: Throwable): Unit + +} + +private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage + with Logging { + + override def sendWith(client: TransportClient): Unit = { + client.send(content) + } + + override def onFailure(e: Throwable): Unit = { + logWarning(s"Failed to send one-way RPC.", e) + } + +} + +private[netty] case class RpcOutboxMessage( + content: Array[Byte], + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, Array[Byte]) => Unit) + extends OutboxMessage with RpcResponseCallback { + + private var client: TransportClient = _ + private var requestId: Long = _ + + override def sendWith(client: TransportClient): Unit = { + this.client = client + this.requestId = client.sendRpc(content, this) + } + + def onTimeout(): Unit = { + require(client != null, "TransportClient has not yet been set.") + client.removeRpcRequest(requestId) + } + + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + _onSuccess(client, response) } } @@ -82,7 +116,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { } } if (dropped) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) } else { drainOutbox() } @@ -122,7 +156,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { try { val _client = synchronized { client } if (_client != null) { - _client.sendRpc(message.content, message.createCallback(_client)) + message.sendWith(_client) } else { assert(stopped == true) } @@ -195,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(e) + message.onFailure(e) message = messages.poll() } assert(messages.isEmpty) @@ -229,7 +263,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) message = messages.poll() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 276c077b3d13e..2136795b18813 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", true, null) + val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite { new Thread { override def run(): Unit = { for (_ <- 0 until 100) { - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) } exitLatch.countDown() diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ccca795683da3..323184cdd9b6e 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -33,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) - .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 876fcd846791c..8a58e7b24585b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -36,6 +37,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; @@ -205,8 +207,12 @@ public void operationComplete(ChannelFuture future) throws Exception { /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. + * + * @param message The message to send. + * @param callback Callback to handle the RPC's reply. + * @return The RPC's id. */ - public void sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(byte[] message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -235,6 +241,8 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + + return requestId; } /** @@ -265,11 +273,35 @@ public void onFailure(Throwable e) { } } + /** + * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the + * message, and no delivery guarantees are made. + * + * @param message The message to send. + */ + public void send(byte[] message) { + channel.writeAndFlush(new OneWayMessage(message)); + } + + /** + * Removes any state associated with the given RPC. + * + * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}. + */ + public void removeRpcRequest(long requestId) { + handler.removeRpcRequest(requestId); + } + /** Mark this channel as having timed out. */ public void timeOut() { this.timedOut = true; } + @VisibleForTesting + public TransportResponseHandler getHandler() { + return handler; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index d01598c20f16f..39afd03db60ee 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -28,7 +28,8 @@ public interface Message extends Encodable { public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8); + StreamRequest(6), StreamResponse(7), StreamFailure(8), + OneWayMessage(9); private final byte id; @@ -55,6 +56,7 @@ public static Type decode(ByteBuf buf) { case 6: return StreamRequest; case 7: return StreamResponse; case 8: return StreamFailure; + case 9: return OneWayMessage; default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3c04048f3821a..074780f2b95ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case RpcFailure: return RpcFailure.decode(in); + case OneWayMessage: + return OneWayMessage.decode(in); + case StreamRequest: return StreamRequest.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java new file mode 100644 index 0000000000000..95a0270be3da9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A RPC that does not expect a reply, which is handled by a remote + * {@link org.apache.spark.network.server.RpcHandler}. + */ +public final class OneWayMessage implements RequestMessage { + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public OneWayMessage(byte[] message) { + this.message = message; + } + + @Override + public Type type() { return Type.OneWayMessage; } + + @Override + public int encodedLength() { + return Encoders.ByteArrays.encodedLength(message); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.ByteArrays.encode(buf, message); + } + + public static OneWayMessage decode(ByteBuf buf) { + byte[] message = Encoders.ByteArrays.decode(buf); + return new OneWayMessage(message); + } + + @Override + public int hashCode() { + return Arrays.hashCode(message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OneWayMessage) { + OneWayMessage o = (OneWayMessage) other; + return Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 7033adb9cae6f..830db94b890c5 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -108,6 +108,11 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, byte[] message) { + delegate.receive(client, message); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index dbb7f95f55bc0..65109ddfe13b9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,9 @@ package org.apache.spark.network.server; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -24,6 +27,9 @@ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ public abstract class RpcHandler { + + private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. @@ -47,6 +53,19 @@ public abstract void receive( */ public abstract StreamManager getStreamManager(); + /** + * Receives an RPC message that does not expect a reply. The default implementation will + * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * any of the callback methods are called. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + */ + public void receive(TransportClient client, byte[] message) { + receive(client, message, ONE_WAY_CALLBACK); + } + /** * Invoked when the connection associated with the given client has been invalidated. * No further requests will come from this client. @@ -54,4 +73,21 @@ public abstract void receive( public void connectionTerminated(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } + + private static class OneWayRpcCallback implements RpcResponseCallback { + + private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); + + @Override + public void onSuccess(byte[] response) { + logger.warn("Response provided for one-way RPC."); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Error response provided for one-way RPC.", e); + } + + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4f67bd573be21..db18ea77d1073 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -27,13 +28,14 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.OneWayMessage; +import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; @@ -95,6 +97,8 @@ public void handle(RequestMessage request) { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof OneWayMessage) { + processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); } else { @@ -156,6 +160,14 @@ public void onFailure(Throwable e) { } } + private void processOneWayMessage(OneWayMessage req) { + try { + rpcHandler.receive(reverseClient, req.message); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } + } + /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 22b451fc0e60e..1aa20900ffe74 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; @@ -84,6 +85,7 @@ public void requests() { testClientToServer(new RpcRequest(12345, new byte[0])); testClientToServer(new RpcRequest(12345, new byte[100])); testClientToServer(new StreamRequest("abcde")); + testClientToServer(new OneWayMessage(new byte[100])); } @Test diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8eb56bdd9846f..88fa2258bb794 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,9 +17,11 @@ package org.apache.spark.network; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -46,6 +48,7 @@ public class RpcIntegrationSuite { static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; + static List oneWayMsgs; @BeforeClass public static void setUp() throws Exception { @@ -64,12 +67,19 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, byte[] message) { + String msg = new String(message, Charsets.UTF_8); + oneWayMsgs.add(msg); + } + @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); + oneWayMsgs = new ArrayList<>(); } @AfterClass @@ -158,6 +168,27 @@ public void sendSuccessAndFailure() throws Exception { assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } + @Test + public void sendOneWayMessage() throws Exception { + final String message = "no reply"; + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.send(message.getBytes(Charsets.UTF_8)); + assertEquals(0, client.getHandler().numOutstandingRequests()); + + // Make sure the message arrives. + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { + TimeUnit.MILLISECONDS.sleep(10); + } + + assertEquals(1, oneWayMsgs.size()); + assertEquals(message, oneWayMsgs.get(0)); + } finally { + client.close(); + } + } + private void assertErrorsContain(Set errors, Set contains) { assertEquals(contains.size(), errors.size()); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index b146899670180..a6f180bc40c9a 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.*; import java.io.File; +import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -353,6 +354,14 @@ public void testRpcHandlerDelegate() throws Exception { verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); } + @Test + public void testDelegates() throws Exception { + Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); + for (Method m : rpcHandlerMethods) { + SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + } + } + private static class SaslTestCtx { final TransportClient client; From ecac2835458bbf73fe59413d5bf921500c5b987d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 13:45:41 -0800 Subject: [PATCH 1020/2089] Fix Aggregator documentation (rename present to finish). --- .../scala/org/apache/spark/sql/expressions/Aggregator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index b0cd32b5f73e6..65117d5824755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} * def zero: Int = 0 * def reduce(b: Int, a: Data): Int = b + a.i * def merge(b1: Int, b2: Int): Int = b1 + b2 - * def present(r: Int): Int = r + * def finish(r: Int): Int = r * }.toColumn() * * val ds: Dataset[Data] = ... From 21e5606419c4b7462d30580c549e9bfa0123ae23 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 25 Nov 2015 13:51:30 -0800 Subject: [PATCH 1021/2089] [SPARK-11983][SQL] remove all unused codegen fallback trait Author: Daoyuan Wang Closes #9966 from adrian-wang/removeFallback. --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 3 +-- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- .../spark/sql/catalyst/expressions/NonFoldableLiteral.scala | 3 +-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 533d17ea5c172..a2c6c39fd8ce2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -104,8 +104,7 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) - extends UnaryExpression with CodegenFallback { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9e484c5ed83bf..adef6050c3565 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -66,7 +66,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -117,7 +117,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 31ecf4a9e810a..118fd695fe2f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.types._ * A literal value that is not foldable. Used in expression codegen testing to test code path * that behave differently based on foldable values. */ -case class NonFoldableLiteral(value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = false override def nullable: Boolean = true From cc243a079b1c039d6e7f0b410d1654d94a090e14 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Wed, 25 Nov 2015 15:13:13 -0800 Subject: [PATCH 1022/2089] [SPARK-11206] Support SQL UI on the history server On the live web UI, there is a SQL tab which provides valuable information for the SQL query. But once the workload is finished, we won't see the SQL tab on the history server. It will be helpful if we support SQL UI on the history server so we can analyze it even after its execution. To support SQL UI on the history server: 1. I added an `onOtherEvent` method to the `SparkListener` trait and post all SQL related events to the same event bus. 2. Two SQL events `SparkListenerSQLExecutionStart` and `SparkListenerSQLExecutionEnd` are defined in the sql module. 3. The new SQL events are written to event log using Jackson. 4. A new trait `SparkHistoryListenerFactory` is added to allow the history server to feed events to the SQL history listener. The SQL implementation is loaded at runtime using `java.util.ServiceLoader`. Author: Carson Wang Closes #9297 from carsonwang/SqlHistoryUI. --- .rat-excludes | 1 + .../org/apache/spark/JavaSparkListener.java | 3 + .../apache/spark/SparkFirehoseListener.java | 4 + .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 24 ++- .../spark/scheduler/SparkListenerBus.scala | 1 + .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 + .../org/apache/spark/sql/SQLContext.scala | 18 ++- .../spark/sql/execution/SQLExecution.scala | 24 +-- .../spark/sql/execution/SparkPlanInfo.scala | 46 ++++++ .../sql/execution/metric/SQLMetricInfo.scala | 30 ++++ .../sql/execution/metric/SQLMetrics.scala | 56 ++++--- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++++++++------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 43 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 + 21 files changed, 327 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 08fba6d351d6a..7262c960ed6bb 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,4 +82,5 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b88..23bc9a2e81727 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba6063..e6b24afd88ad4 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528cf..eaa07acc5132e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f1..075a7f13172de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. + */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aad..95722a07144ec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4608bce202ec8..8da6884a38535 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -154,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c9beeb25e05af..7f5d713ec6505 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -511,6 +516,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 0000000000000..507100be90967 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 46bf544fd885f..1c2ac5f6f11bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1263,6 +1263,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1307,6 +1309,10 @@ object SQLContext { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1355,9 +1361,13 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c94..34971986261c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 0000000000000..486ce34064e43 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 0000000000000..2708219ad3485 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e95..6c0f6f8a52dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,21 +104,39 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +144,25 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e74d6fb396e1c..c74ad40406992 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.Node import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6a..e19a1e3e5851f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -118,7 +141,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) } } @@ -140,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics, + taskEnd.taskMetrics.accumulatorUpdates(), finishTask = true) } @@ -148,15 +172,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Map[Long, Any], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +195,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +206,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +214,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -289,6 +312,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc6..4f50b2ecdc8f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3fe..7af0ff09c5c6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) nodes += node - val childrenNodes = plan.children.map( + val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 5e2b4154dd7ce..ebfa1eaf3e5bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,6 +26,7 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -82,7 +83,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096c..f93d081d0c30e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,7 +82,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +91,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -206,7 +207,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +221,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +251,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +274,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -288,19 +292,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62ed..e7b376548787c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,6 +42,7 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From d1930ec01ab5a9d83f801f8ae8d4f15a38d98b76 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 25 Nov 2015 21:25:20 -0800 Subject: [PATCH 1023/2089] [SPARK-12003] [SQL] remove the prefix for name after expanded star Right now, the expended start will include the name of expression as prefix for column, that's not better than without expending, we should not have the prefix. Author: Davies Liu Closes #9984 from davies/expand_star. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 1b2a8dc4c7f14..4f89b462a6ce3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -204,7 +204,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu case s: StructType => s.zipWithIndex.map { case (f, i) => val extract = GetStructField(attribute.get, i) - Alias(extract, target.get + "." + f.name)() + Alias(extract, f.name)() } case _ => { From 068b6438d6886ce5b4aa698383866f466d913d66 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 23:24:33 -0800 Subject: [PATCH 1024/2089] [SPARK-11980][SPARK-10621][SQL] Fix json_tuple and add test cases for Added Python test cases for the function `isnan`, `isnull`, `nanvl` and `json_tuple`. Fixed a bug in the function `json_tuple` rxin , could you help me review my changes? Please let me know anything is missing. Thank you! Have a good Thanksgiving day! Author: gatorsmile Closes #9977 from gatorsmile/json_tuple. --- python/pyspark/sql/functions.py | 44 +++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e3786e0fa5fb2..90625949f747a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -286,14 +286,6 @@ def countDistinct(col, *cols): return Column(jc) -@since(1.4) -def monotonicallyIncreasingId(): - """ - .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. - """ - return monotonically_increasing_id() - - @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -305,6 +297,10 @@ def input_file_name(): @since(1.6) def isnan(col): """An expression that returns true iff the column is NaN. + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnan(_to_java_column(col))) @@ -313,11 +309,23 @@ def isnan(col): @since(1.6) def isnull(col): """An expression that returns true iff the column is null. + + >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.4) +def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. @@ -344,6 +352,10 @@ def nanvl(col1, col2): """Returns col1 if it is not NaN, or col2 if col1 is NaN. Both inputs should be floating point columns (DoubleType or FloatType). + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() + [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @@ -1460,6 +1472,7 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix @since(1.6) def get_json_object(col, path): """ @@ -1468,22 +1481,33 @@ def get_json_object(col, path): :param col: string column in json format :param path: path to the json object to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ + get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) return Column(jc) +@ignore_unicode_prefix @since(1.6) -def json_tuple(col, fields): +def json_tuple(col, *fields): """Creates a new row for a json column according to the given field names. :param col: string column in json format :param fields: list of fields to extract + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) return Column(jc) From d3ef693325f91a1ed340c9756c81244a80398eb2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 25 Nov 2015 23:31:21 -0800 Subject: [PATCH 1025/2089] [SPARK-11999][CORE] Fix the issue that ThreadUtils.newDaemonCachedThreadPool doesn't cache any task In the previous codes, `newDaemonCachedThreadPool` uses `SynchronousQueue`, which is wrong. `SynchronousQueue` is an empty queue that cannot cache any task. This patch uses `LinkedBlockingQueue` to fix it along with other fixes to make sure `newDaemonCachedThreadPool` can use at most `maxThreadNumber` threads, and after that, cache tasks to `LinkedBlockingQueue`. Author: Shixiong Zhu Closes #9978 from zsxwing/cached-threadpool. --- .../org/apache/spark/util/ThreadUtils.scala | 14 ++++-- .../apache/spark/util/ThreadUtilsSuite.scala | 45 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 53283448c87b1..f9fbe2ff858ce 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -56,10 +56,18 @@ private[spark] object ThreadUtils { * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer. */ - def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = { + def newDaemonCachedThreadPool( + prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = { val threadFactory = namedThreadFactory(prefix) - new ThreadPoolExecutor( - 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory) + val threadPool = new ThreadPoolExecutor( + maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks + maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingDeque, this one is not used + keepAliveSeconds, + TimeUnit.SECONDS, + new LinkedBlockingQueue[Runnable], + threadFactory) + threadPool.allowCoreThreadTimeOut(true) + threadPool } /** diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 620e4debf4e08..92ae038967528 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -24,6 +24,8 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, Future} import scala.util.Random +import org.scalatest.concurrent.Eventually._ + import org.apache.spark.SparkFunSuite class ThreadUtilsSuite extends SparkFunSuite { @@ -59,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite { } } + test("newDaemonCachedThreadPool") { + val maxThreadNumber = 10 + val startThreadsLatch = new CountDownLatch(maxThreadNumber) + val latch = new CountDownLatch(1) + val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "ThreadUtilsSuite-newDaemonCachedThreadPool", + maxThreadNumber, + keepAliveSeconds = 2) + try { + for (_ <- 1 to maxThreadNumber) { + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + startThreadsLatch.countDown() + latch.await(10, TimeUnit.SECONDS) + } + }) + } + startThreadsLatch.await(10, TimeUnit.SECONDS) + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 0) + + // Submit a new task and it should be put into the queue since the thread number reaches the + // limitation + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + latch.await(10, TimeUnit.SECONDS) + } + }) + + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 1) + + latch.countDown() + eventually(timeout(10.seconds)) { + // All threads should be stopped after keepAliveSeconds + assert(cachedThreadPool.getActiveCount === 0) + assert(cachedThreadPool.getPoolSize === 0) + } + } finally { + cachedThreadPool.shutdownNow() + } + } + test("sameThread") { val callerThreadName = Thread.currentThread().getName() val f = Future { From 27d69a0573ed55e916a464e268dcfd5ecc6ed849 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 26 Nov 2015 00:19:42 -0800 Subject: [PATCH 1026/2089] [SPARK-11973] [SQL] push filter through aggregation with alias and literals Currently, filter can't be pushed through aggregation with alias or literals, this patch fix that. After this patch, the time of TPC-DS query 4 go down to 13 seconds from 141 seconds (10x improvements). cc nongli yhuai Author: Davies Liu Closes #9959 from davies/push_filter2. --- .../sql/catalyst/expressions/predicates.scala | 9 ++++ .../sql/catalyst/optimizer/Optimizer.scala | 28 ++++++---- .../optimizer/FilterPushdownSuite.scala | 53 +++++++++++++++++++ 3 files changed, 79 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68557479a9591..304b438c84ba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -65,6 +65,15 @@ trait PredicateHelper { } } + // Substitute any known alias from a map. + protected def replaceAlias( + condition: Expression, + aliases: AttributeMap[Expression]): Expression = { + condition.transform { + case a: Attribute => aliases.getOrElse(a, a) + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when it is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f4dba67f13b54..52f609bc158ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -640,20 +640,14 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe filter } else { // Push down the small conditions without nondeterministic expressions. - val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + val pushedCondition = + deterministic.map(replaceAlias(_, aliasMap)).reduce(And) Filter(nondeterministic.reduce(And), project.copy(child = Filter(pushedCondition, grandChild))) } } } - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { - condition.transform { - case a: Attribute => sourceAliases.getOrElse(a, a) - } - } } /** @@ -690,12 +684,24 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions) + + def hasAggregate(expression: Expression): Boolean = expression match { + case agg: AggregateExpression => true + case other => expression.children.exists(hasAggregate) + } + // Create a map of Alias for expressions that does not have AggregateExpression + val aliasMap = AttributeMap(aggregateExpressions.collect { + case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child) + }) + + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct => + val replaced = replaceAlias(conjunct, aliasMap) + replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild)) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val withPushdown = aggregate.copy(child = Filter(replaced, grandChild)) stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) } else { filter diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0290fafe879f6..0128c220baaca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -697,4 +697,57 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("aggregate: push down filters with alias") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where(('c === 2L || 'aa > 4) && 'aa < 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where('a + 1 < 3) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where('c === 2L || 'aa > 4) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: push down filters with literal") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L && 'd === "s") + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where("s" === "s") + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: don't push down filters which is nondeterministic") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .analyze + + comparePlans(optimized, correctAnswer) + } } From 001f0528a851ac314b390e65eb0583f89e69a949 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Nov 2015 01:15:05 -0800 Subject: [PATCH 1027/2089] [SPARK-12005][SQL] Work around VerifyError in HyperLogLogPlusPlus. Just move the code around a bit; that seems to make the JVM happy. Author: Marcelo Vanzin Closes #9985 from vanzin/SPARK-12005. --- .../expressions/aggregate/HyperLogLogPlusPlus.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8a95c541f1e86..e1fd22e36764e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -63,11 +63,7 @@ case class HyperLogLogPlusPlus( def this(child: Expression, relativeSD: Expression) = { this( child = child, - relativeSD = relativeSD match { - case Literal(d: Double, DoubleType) => d - case _ => - throw new AnalysisException("The second argument should be a double literal.") - }, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), mutableAggBufferOffset = 0, inputAggBufferOffset = 0) } @@ -448,4 +444,11 @@ object HyperLogLogPlusPlus { Array(189083, 185696.913, 182348.774, 179035.946, 175762.762, 172526.444, 169329.754, 166166.099, 163043.269, 159958.91, 156907.912, 153906.845, 150924.199, 147996.568, 145093.457, 142239.233, 139421.475, 136632.27, 133889.588, 131174.2, 128511.619, 125868.621, 123265.385, 120721.061, 118181.769, 115709.456, 113252.446, 110840.198, 108465.099, 106126.164, 103823.469, 101556.618, 99308.004, 97124.508, 94937.803, 92833.731, 90745.061, 88677.627, 86617.47, 84650.442, 82697.833, 80769.132, 78879.629, 77014.432, 75215.626, 73384.587, 71652.482, 69895.93, 68209.301, 66553.669, 64921.981, 63310.323, 61742.115, 60205.018, 58698.658, 57190.657, 55760.865, 54331.169, 52908.167, 51550.273, 50225.254, 48922.421, 47614.533, 46362.049, 45098.569, 43926.083, 42736.03, 41593.473, 40425.26, 39316.237, 38243.651, 37170.617, 36114.609, 35084.19, 34117.233, 33206.509, 32231.505, 31318.728, 30403.404, 29540.0550000001, 28679.236, 27825.862, 26965.216, 26179.148, 25462.08, 24645.952, 23922.523, 23198.144, 22529.128, 21762.4179999999, 21134.779, 20459.117, 19840.818, 19187.04, 18636.3689999999, 17982.831, 17439.7389999999, 16874.547, 16358.2169999999, 15835.684, 15352.914, 14823.681, 14329.313, 13816.897, 13342.874, 12880.882, 12491.648, 12021.254, 11625.392, 11293.7610000001, 10813.697, 10456.209, 10099.074, 9755.39000000001, 9393.18500000006, 9047.57900000003, 8657.98499999999, 8395.85900000005, 8033, 7736.95900000003, 7430.59699999995, 7258.47699999996, 6924.58200000005, 6691.29399999999, 6357.92500000005, 6202.05700000003, 5921.19700000004, 5628.28399999999, 5404.96799999999, 5226.71100000001, 4990.75600000005, 4799.77399999998, 4622.93099999998, 4472.478, 4171.78700000001, 3957.46299999999, 3868.95200000005, 3691.14300000004, 3474.63100000005, 3341.67200000002, 3109.14000000001, 3071.97400000005, 2796.40399999998, 2756.17799999996, 2611.46999999997, 2471.93000000005, 2382.26399999997, 2209.22400000005, 2142.28399999999, 2013.96100000001, 1911.18999999994, 1818.27099999995, 1668.47900000005, 1519.65800000005, 1469.67599999998, 1367.13800000004, 1248.52899999998, 1181.23600000003, 1022.71900000004, 1088.20700000005, 959.03600000008, 876.095999999903, 791.183999999892, 703.337000000058, 731.949999999953, 586.86400000006, 526.024999999907, 323.004999999888, 320.448000000091, 340.672999999952, 309.638999999966, 216.601999999955, 102.922999999952, 19.2399999999907, -0.114000000059605, -32.6240000000689, -89.3179999999702, -153.497999999905, -64.2970000000205, -143.695999999996, -259.497999999905, -253.017999999924, -213.948000000091, -397.590000000084, -434.006000000052, -403.475000000093, -297.958000000101, -404.317000000039, -528.898999999976, -506.621000000043, -513.205000000075, -479.351000000024, -596.139999999898, -527.016999999993, -664.681000000099, -680.306000000099, -704.050000000047, -850.486000000034, -757.43200000003, -713.308999999892) ) // scalastyle:on + + private def validateDoubleLiteral(exp: Expression): Double = exp match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + } + } From bc16a67562560c732833260cbc34825f7e9dcb8f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Nov 2015 11:31:28 -0800 Subject: [PATCH 1028/2089] [SPARK-11863][SQL] Unable to resolve order by if it contains mixture of aliases and real columns this is based on https://github.com/apache/spark/pull/9844, with some bug fix and clean up. The problems is that, normal operator should be resolved based on its child, but `Sort` operator can also be resolved based on its grandchild. So we have 3 rules that can resolve `Sort`: `ResolveReferences`, `ResolveSortReferences`(if grandchild is `Project`) and `ResolveAggregateFunctions`(if grandchild is `Aggregate`). For example, `select c1 as a , c2 as b from tab group by c1, c2 order by a, c2`, we need to resolve `a` and `c2` for `Sort`. Firstly `a` will be resolved in `ResolveReferences` based on its child, and when we reach `ResolveAggregateFunctions`, we will try to resolve both `a` and `c2` based on its grandchild, but failed because `a` is not a legal aggregate expression. whoever merge this PR, please give the credit to dilipbiswal Author: Dilip Biswal Author: Wenchen Fan Closes #9961 from cloud-fan/sort. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47962ebe6ef82..94ffbbb2e5c65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -630,7 +630,8 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -663,13 +664,19 @@ class Analyzer( } } + val sortOrdersMap = unresolvedSortOrders + .map(new TreeNodeRef(_)) + .zip(evaluatedOrderings) + .toMap + val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) + // Since we don't rely on sort.resolved as the stop condition for this rule, // we need to check this and prevent applying this rule multiple times - if (sortOrder == evaluatedOrderings) { + if (sortOrder == finalSortOrders) { sort } else { Project(aggregate.output, - Sort(evaluatedOrderings, global, + Sort(finalSortOrders, global, aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e051069951887..aeeca802d8bb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -220,6 +220,24 @@ class AnalysisSuite extends AnalysisTest { // checkUDF(udf4, expected4) } + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { + val a = testRelation2.output(0) + val c = testRelation2.output(2) + val alias1 = a.as("a1") + val alias2 = c.as("a2") + val alias3 = count(a).as("a3") + + val plan = testRelation2 + .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3")) + .orderBy('a1.asc, 'c.asc) + + val expected = testRelation2 + .groupBy(a, c)(alias1, alias2, alias3) + .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc) + .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute) + checkAnalysis(plan, expected) + } + test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) From ad76562390b81207f8f32491c0bd8ad0e020141f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 26 Nov 2015 16:20:08 -0800 Subject: [PATCH 1029/2089] [SPARK-11998][SQL][TEST-HADOOP2.0] When downloading Hadoop artifacts from maven, we need to try to download the version that is used by Spark If we need to download Hive/Hadoop artifacts, try to download a Hadoop that matches the Hadoop used by Spark. If the Hadoop artifact cannot be resolved (e.g. Hadoop version is a vendor specific version like 2.0.0-cdh4.1.1), we will use Hadoop 2.4.0 (we used to hard code this version as the hadoop that we will download from maven) and we will not share Hadoop classes. I tested this match in my laptop with the following confs (these confs are used by our builds). All tests are good. ``` build/sbt -Phadoop-1 -Dhadoop.version=1.2.1 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Phadoop-1 -Dhadoop.version=2.0.0-mr1-cdh4.1.1 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Pyarn -Phadoop-2.2 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Pkinesis-asl -Phive-thriftserver -Phive ``` Author: Yin Huai Closes #9979 from yhuai/versionsSuite. --- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../hive/client/IsolatedClientLoader.scala | 62 +++++++++++++++---- .../spark/sql/hive/client/VersionsSuite.scala | 23 +++++-- 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8a4264194ae8d..e83941c2ecf66 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.hadoop.util.VersionInfo import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry @@ -288,7 +289,8 @@ class HiveContext private[hive]( logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") IsolatedClientLoader.forVersion( - version = hiveMetastoreVersion, + hiveMetastoreVersion = hiveMetastoreVersion, + hadoopVersion = VersionInfo.getVersion, config = allConfig, barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index e041e0d8e5ae8..010051d255fdc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -34,23 +34,51 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ -private[hive] object IsolatedClientLoader { +private[hive] object IsolatedClientLoader extends Logging { /** * Creates isolated Hive client loaders by downloading the requested version from maven. */ def forVersion( - version: String, + hiveMetastoreVersion: String, + hadoopVersion: String, config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { - val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, - downloadVersion(resolvedVersion, ivyPath)) + val resolvedVersion = hiveVersion(hiveMetastoreVersion) + // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact + // with the given version, we will use Hadoop 2.4.0 and then will not share Hadoop classes. + var sharesHadoopClasses = true + val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { + resolvedVersions((resolvedVersion, hadoopVersion)) + } else { + val (downloadedFiles, actualHadoopVersion) = + try { + (downloadVersion(resolvedVersion, hadoopVersion, ivyPath), hadoopVersion) + } catch { + case e: RuntimeException if e.getMessage.contains("hadoop") => + // If the error message contains hadoop, it is probably because the hadoop + // version cannot be resolved (e.g. it is a vendor specific version like + // 2.0.0-cdh4.1.1). If it is the case, we will try just + // "org.apache.hadoop:hadoop-client:2.4.0". "org.apache.hadoop:hadoop-client:2.4.0" + // is used just because we used to hard code it as the hadoop artifact to download. + logWarning(s"Failed to resolve Hadoop artifacts for the version ${hadoopVersion}. " + + s"We will change the hadoop version from ${hadoopVersion} to 2.4.0 and try again. " + + "Hadoop classes will not be shared between Spark and Hive metastore client. " + + "It is recommended to set jars used by Hive metastore client through " + + "spark.sql.hive.metastore.jars in the production environment.") + sharesHadoopClasses = false + (downloadVersion(resolvedVersion, "2.4.0", ivyPath), "2.4.0") + } + resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) + resolvedVersions((resolvedVersion, actualHadoopVersion)) + } + new IsolatedClientLoader( - version = hiveVersion(version), + version = hiveVersion(hiveMetastoreVersion), execJars = files, config = config, + sharesHadoopClasses = sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -64,12 +92,15 @@ private[hive] object IsolatedClientLoader { case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } - private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + private def downloadVersion( + version: HiveVersion, + hadoopVersion: String, + ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ Seq("com.google.guava:guava:14.0.1", - "org.apache.hadoop:hadoop-client:2.4.0") + s"org.apache.hadoop:hadoop-client:$hadoopVersion") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -86,7 +117,10 @@ private[hive] object IsolatedClientLoader { tempDir.listFiles().map(_.toURI.toURL) } - private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] + // A map from a given pair of HiveVersion and Hadoop version to jar files. + // It is only used by forVersion. + private val resolvedVersions = + new scala.collection.mutable.HashMap[(HiveVersion, String), Seq[URL]] } /** @@ -106,6 +140,7 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. + * @param sharesHadoopClasses When true, we will share Hadoop classes between Spark and * @param rootClassLoader The system root classloader. Must not know about Hive classes. * @param baseClassLoader The spark classloader that is used to load shared classes. */ @@ -114,6 +149,7 @@ private[hive] class IsolatedClientLoader( val execJars: Seq[URL] = Seq.empty, val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, + val sharesHadoopClasses: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, val sharedPrefixes: Seq[String] = Seq.empty, @@ -126,16 +162,20 @@ private[hive] class IsolatedClientLoader( /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray - protected def isSharedClass(name: String): Boolean = + protected def isSharedClass(name: String): Boolean = { + val isHadoopClass = + name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") + name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || - (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) || + (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) + } /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7bc13bc60d30e..502b240f3650f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.client import java.io.File +import org.apache.hadoop.util.VersionInfo + import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} @@ -53,9 +55,11 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, - buildConf(), - ivyPath).createClient() + val badClient = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -85,7 +89,11 @@ class VersionsSuite extends SparkFunSuite with Logging { ignore("failure sanity check") { val e = intercept[Throwable] { val badClient = quietly { - IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient() + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = "13", + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") @@ -99,7 +107,12 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. - client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient() + client = + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } test(s"$version: createDatabase") { From de28e4d4deca385b7c40b3a6a1efcd6e2fec2f9b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Nov 2015 18:47:54 -0800 Subject: [PATCH 1030/2089] [SPARK-11973][SQL] Improve optimizer code readability. This is a followup for https://github.com/apache/spark/pull/9959. I added more documentation and rewrote some monadic code into simpler ifs. Author: Reynold Xin Closes #9995 from rxin/SPARK-11973. --- .../sql/catalyst/optimizer/Optimizer.scala | 50 +++++++++---------- .../optimizer/FilterPushdownSuite.scala | 2 +- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 52f609bc158ca..2901d8f2efddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer { ConstantFolding, LikeSimplification, BooleanSimplification, - RemoveDispensable, + RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions) :: @@ -660,14 +660,14 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf g.child.outputSet + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) } else { filter } @@ -675,34 +675,34 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp } /** - * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference - * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath, - * and the rest should remain above. + * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only + * non-aggregate attributes (typically literals or grouping expressions). */ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - - def hasAggregate(expression: Expression): Boolean = expression match { - case agg: AggregateExpression => true - case other => expression.children.exists(hasAggregate) - } - // Create a map of Alias for expressions that does not have AggregateExpression - val aliasMap = AttributeMap(aggregateExpressions.collect { - case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child) + case filter @ Filter(condition, aggregate: Aggregate) => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) }) - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct => - val replaced = replaceAlias(conjunct, aliasMap) - replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic + // For each filter, expand the alias and check if the filter can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic } + if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) - val withPushdown = aggregate.copy(child = Filter(replaced, grandChild)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) + // If there is no more filter to stay up, just eliminate the filter. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) } else { filter } @@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * And also pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details @@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { /** * Removes nodes that are not necessary. */ -object RemoveDispensable extends Rule[LogicalPlan] { +object RemoveDispensableExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case UnaryPositive(child) => child case PromotePrecision(child) => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0128c220baaca..fba4c5ca77d64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -734,7 +734,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("aggregate: don't push down filters which is nondeterministic") { + test("aggregate: don't push down filters that are nondeterministic") { val originalQuery = testRelation .select('a, 'b) .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) From 4376b5bea8171e4e73b3dbabbfdf84fa1afd140b Mon Sep 17 00:00:00 2001 From: muxator Date: Thu, 26 Nov 2015 18:52:20 -0800 Subject: [PATCH 1031/2089] doc typo: "classificaion" -> "classification" Author: muxator Closes #10008 from muxator/patch-1. --- docs/mllib-linear-methods.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 0c76e6e999465..132f8c354aa9c 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -122,7 +122,7 @@ Under the hood, linear methods use convex optimization methods to optimize the o [Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into categories. The most common classification type is -[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), where there are two categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). From 0c1e72e7f79231e537299b57a1ab7cd843171923 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 26 Nov 2015 18:56:22 -0800 Subject: [PATCH 1032/2089] [SPARK-11996][CORE] Make the executor thread dump work again In the previous implementation, the driver needs to know the executor listening address to send the thread dump request. However, in Netty RPC, the executor doesn't listen to any port, so the executor thread dump feature is broken. This patch makes the driver use the endpointRef stored in BlockManagerMasterEndpoint to send the thread dump request to fix it. Author: Shixiong Zhu Closes #9976 from zsxwing/executor-thread-dump. --- .../scala/org/apache/spark/SparkContext.scala | 10 ++--- .../org/apache/spark/executor/Executor.scala | 5 --- .../spark/executor/ExecutorEndpoint.scala | 43 ------------------- .../spark/storage/BlockManagerMaster.scala | 4 +- .../storage/BlockManagerMasterEndpoint.scala | 12 +++--- .../spark/storage/BlockManagerMessages.scala | 7 ++- .../storage/BlockManagerSlaveEndpoint.scala | 7 ++- project/MimaExcludes.scala | 8 ++++ 8 files changed, 29 insertions(+), 67 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 2c10779f2b893..b030d3c71dc20 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,20 +48,20 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ +import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -619,11 +619,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get - val endpointRef = env.rpcEnv.setupEndpointRef( - SparkEnv.executorActorSystemName, - RpcAddress(host, port), - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9e88d488c0379..6154f06e3ac11 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -85,10 +85,6 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an RpcEndpoint for receiving RPCs from the driver - private val executorEndpoint = env.rpcEnv.setupEndpoint( - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) - // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) @@ -136,7 +132,6 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.rpcEnv.stop(executorEndpoint) heartbeater.shutdown() heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala deleted file mode 100644 index cf362f8464735..0000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils - -/** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. - */ -private[spark] -class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case TriggerThreadDump => - context.reply(Utils.getThreadDump()) - } - -} - -object ExecutorEndpoint { - val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f45bff34d4dbc..440c4c18aadd0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -87,8 +87,8 @@ class BlockManagerMaster( driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { + driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 7db6035553ae6..41892b4ffce5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} @@ -75,8 +74,8 @@ class BlockManagerMasterEndpoint( case GetPeers(blockManagerId) => context.reply(getPeers(blockManagerId)) - case GetRpcHostPortForExecutor(executorId) => - context.reply(getRpcHostPortForExecutor(executorId)) + case GetExecutorEndpointRef(executorId) => + context.reply(getExecutorEndpointRef(executorId)) case GetMemoryStatus => context.reply(memoryStatus) @@ -388,15 +387,14 @@ class BlockManagerMasterEndpoint( } /** - * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its - * [[BlockManagerSlaveEndpoint]]. + * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages. */ - private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId) ) yield { - (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + info.slaveEndpoint } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 376e9eb48843d..f392a4a0cd9be 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,6 +42,11 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + /** + * Driver -> Executor message to trigger a thread dump. + */ + case object TriggerThreadDump extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -90,7 +95,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetExecutorEndpointRef(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index e749631bf6f19..9eca902f7454e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,10 +19,10 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * An RpcEndpoint to take commands from the master to execute options. For example, @@ -70,6 +70,9 @@ class BlockManagerSlaveEndpoint( case GetMatchingBlockIds(filter, _) => context.reply(blockManager.getMatchingBlockIds(filter)) + + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 54a9ad956d119..566bfe8efb7a4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -147,6 +147,14 @@ object MimaExcludes { // SPARK-4557 Changed foreachRDD to use VoidFunction ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") + ) ++ Seq( + // SPARK-11996 Make the executor thread dump work again + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") ) case v if v.startsWith("1.5") => Seq( From 6f6bb0e893c8370cbab4d63a56d74e00cb7f3cf6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 26 Nov 2015 19:00:36 -0800 Subject: [PATCH 1033/2089] [SPARK-12011][SQL] Stddev/Variance etc should support columnName as arguments Spark SQL aggregate function: ```Java stddev stddev_pop stddev_samp variance var_pop var_samp skewness kurtosis collect_list collect_set ``` should support ```columnName``` as arguments like other aggregate function(max/min/count/sum). Author: Yanbo Liang Closes #9994 from yanboliang/SPARK-12011. --- .../org/apache/spark/sql/functions.scala | 86 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 3 + 2 files changed, 89 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 276c5dfc8b062..e79defbbbdeea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -214,6 +214,16 @@ object functions extends LegacyFunctions { */ def collect_list(e: Column): Column = callUDF("collect_list", e) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(columnName: String): Column = collect_list(Column(columnName)) + /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * @@ -224,6 +234,16 @@ object functions extends LegacyFunctions { */ def collect_set(e: Column): Column = callUDF("collect_set", e) + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(columnName: String): Column = collect_set(Column(columnName)) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * @@ -312,6 +332,14 @@ object functions extends LegacyFunctions { */ def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the last value in a group. * @@ -386,6 +414,14 @@ object functions extends LegacyFunctions { */ def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + /** * Aggregate function: alias for [[stddev_samp]]. * @@ -394,6 +430,14 @@ object functions extends LegacyFunctions { */ def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + /** + * Aggregate function: alias for [[stddev_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(columnName: String): Column = stddev(Column(columnName)) + /** * Aggregate function: returns the sample standard deviation of * the expression in a group. @@ -403,6 +447,15 @@ object functions extends LegacyFunctions { */ def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + /** + * Aggregate function: returns the sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) + /** * Aggregate function: returns the population standard deviation of * the expression in a group. @@ -412,6 +465,15 @@ object functions extends LegacyFunctions { */ def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) + /** * Aggregate function: returns the sum of all values in the expression. * @@ -452,6 +514,14 @@ object functions extends LegacyFunctions { */ def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + /** + * Aggregate function: alias for [[var_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + /** * Aggregate function: returns the unbiased variance of the values in a group. * @@ -460,6 +530,14 @@ object functions extends LegacyFunctions { */ def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(columnName: String): Column = var_samp(Column(columnName)) + /** * Aggregate function: returns the population variance of the values in a group. * @@ -468,6 +546,14 @@ object functions extends LegacyFunctions { */ def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9c42f65bb6f52..b5c636d0de1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -261,6 +261,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) + checkAnswer( + testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { From b63938a8b04a30feb6b2255c4d4e530a74855afc Mon Sep 17 00:00:00 2001 From: mariusvniekerk Date: Thu, 26 Nov 2015 19:13:16 -0800 Subject: [PATCH 1034/2089] [SPARK-11881][SQL] Fix for postgresql fetchsize > 0 Reference: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor In order for PostgreSQL to honor the fetchSize non-zero setting, its Connection.autoCommit needs to be set to false. Otherwise, it will just quietly ignore the fetchSize setting. This adds a new side-effecting dialect specific beforeFetch method that will fire before a select query is ran. Author: mariusvniekerk Closes #9861 from mariusvniekerk/SPARK-11881. --- .../execution/datasources/jdbc/JDBCRDD.scala | 12 ++++++++++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 11 +++++++++++ .../apache/spark/sql/jdbc/PostgresDialect.scala | 17 ++++++++++++++++- 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 89c850ce238d7..f9b72597dd2a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -224,6 +224,7 @@ private[sql] object JDBCRDD extends Logging { quotedColumns, filters, parts, + url, properties) } } @@ -241,6 +242,7 @@ private[sql] class JDBCRDD( columns: Array[String], filters: Array[Filter], partitions: Array[Partition], + url: String, properties: Properties) extends RDD[InternalRow](sc, Nil) { @@ -361,6 +363,9 @@ private[sql] class JDBCRDD( context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() + val dialect = JdbcDialects.get(url) + import scala.collection.JavaConverters._ + dialect.beforeFetch(conn, properties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to @@ -489,6 +494,13 @@ private[sql] class JDBCRDD( } try { if (null != conn) { + if (!conn.getAutoCommit && !conn.isClosed) { + try { + conn.commit() + } catch { + case e: Throwable => logWarning("Exception committing transaction", e) + } + } conn.close() } logInfo("closed connection") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index b3b2cb6178c52..13db141f27db6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.jdbc +import java.sql.Connection + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -97,6 +99,15 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. + */ + def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + } + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index ed3faa1268635..3cf80f576e92c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Connection, Types} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types._ @@ -70,4 +70,19 @@ private object PostgresDialect extends JdbcDialect { override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } + + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + super.beforeFetch(connection, properties) + + // According to the postgres jdbc documentation we need to be in autocommit=false if we actually + // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the + // rows inside the driver when fetching. + // + // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor + // + if (properties.getOrElse("fetchsize", "0").toInt > 0) { + connection.setAutoCommit(false) + } + + } } From d8220885c492141dfc61e8ffb92934f2339fe8d3 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 26 Nov 2015 19:15:22 -0800 Subject: [PATCH 1035/2089] [SPARK-11917][PYSPARK] Add SQLContext#dropTempTable to PySpark Author: Jeff Zhang Closes #9903 from zjffdu/SPARK-11917. --- python/pyspark/sql/context.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index a49c1b58d0180..b05aa2f5c4cd7 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -445,6 +445,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") + @since(1.6) + def dropTempTable(self, tableName): + """ Remove the temp table from catalog. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> sqlContext.dropTempTable("table1") + """ + self._ssql_ctx.dropTempTable(tableName) + def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. From 4d4cbc034bef559f47f8b74cecd8196dc8a85348 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Nov 2015 19:17:46 -0800 Subject: [PATCH 1036/2089] [SPARK-11778][SQL] add regression test Fix regression test for SPARK-11778. marmbrus Could you please take a look? Thank you very much!! Author: Huaxin Gao Closes #9890 from huaxingao/spark-11778-regression-test. --- .../hive/HiveDataFrameAnalyticsSuite.scala | 10 ------ .../spark/sql/hive/HiveDataFrameSuite.scala | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index f19a74d4b3724..9864acf765265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,14 +34,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") - hiveContext.sql("create schema usrdb") - hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") - hiveContext.sql("drop table usrdb.test") - hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -78,10 +74,4 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } - - // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, - // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException - test("table name with schema") { - hiveContext.read.table("usrdb.test") - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala new file mode 100644 index 0000000000000..7fdc5d71937ff --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.QueryTest + +class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { + test("table name with schema") { + // regression test for SPARK-11778 + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c int)") + hiveContext.read.table("usrdb.test") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") + } +} From 5eaed4e45c6c57e995ac7438016fad545716e596 Mon Sep 17 00:00:00 2001 From: Jeremy Derr Date: Thu, 26 Nov 2015 19:25:13 -0800 Subject: [PATCH 1037/2089] [SPARK-11991] fixes If `--private-ips` is required but not provided, spark_ec2.py may behave inappropriately, including attempting to ssh to localhost in attempts to verify ssh connectivity to the cluster. This fixes that behavior by raising a `UsageError` exception if `get_dns_name` is unable to determine a hostname as a result. Author: Jeremy Derr Closes #9975 from jcderr/SPARK-11991/ec_spark.py_hostname_check. --- ec2/spark_ec2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9fd652a3df4c4..84a950c9f6529 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -1242,6 +1242,10 @@ def get_ip_address(instance, private_ips=False): def get_dns_name(instance, private_ips=False): dns = instance.public_dns_name if not private_ips else \ instance.private_ip_address + if not dns: + raise UsageError("Failed to determine hostname of {0}.\n" + "Please check that you provided --private-ips if " + "necessary".format(instance)) return dns From 10e315c28c933b967674ae51e1b2f24160c2e8a5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Nov 2015 19:36:43 -0800 Subject: [PATCH 1038/2089] Fix style violation for b63938a8b04 --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index f9b72597dd2a9..57a8a044a37cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils import org.apache.spark.rdd.RDD @@ -498,7 +500,7 @@ private[sql] class JDBCRDD( try { conn.commit() } catch { - case e: Throwable => logWarning("Exception committing transaction", e) + case NonFatal(e) => logWarning("Exception committing transaction", e) } } conn.close() From a374e20b5492c775f20d32e8fbddadbd8098a655 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Nov 2015 21:04:40 -0800 Subject: [PATCH 1039/2089] [SPARK-11997] [SQL] NPE when save a DataFrame as parquet and partitioned by long column Check for partition column null-ability while building the partition spec. Author: Dilip Biswal Closes #10001 from dilipbiswal/spark-11997. --- .../org/apache/spark/sql/sources/interfaces.scala | 2 +- .../datasources/parquet/ParquetQuerySuite.scala | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936d..9ace25dc7d21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -607,7 +607,7 @@ abstract class HadoopFsRelation private[sql]( def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => Cast( - Literal.create(row.getString(i), StringType), + Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType).eval() }: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 70fae32b7e7a1..f777e973052d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -252,6 +252,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-11997 parquet with null partition values") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(1, 3) + .selectExpr("if(id % 2 = 0, null, id) AS n", "id") + .write.partitionBy("n").parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).filter("n is null"), + Row(2, null)) + } + } + // This test case is ignored because of parquet-mr bug PARQUET-370 ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => From ba02f6cb5a40511cefa511d410be93c035d43f23 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 27 Nov 2015 11:48:01 -0800 Subject: [PATCH 1040/2089] [SPARK-12025][SPARKR] Rename some window rank function names for SparkR Change ```cumeDist -> cume_dist, denseRank -> dense_rank, percentRank -> percent_rank, rowNumber -> row_number``` at SparkR side. There are two reasons that we should make this change: * We should follow the [naming convention rule of R](http://www.inside-r.org/node/230645) * Spark DataFrame has deprecated the old convention (such as ```cumeDist```) and will remove it in Spark 2.0. It's better to fix this issue before 1.6 release, otherwise we will make breaking API change. cc shivaram sun-rui Author: Yanbo Liang Closes #10016 from yanboliang/SPARK-12025. --- R/pkg/NAMESPACE | 8 ++--- R/pkg/R/functions.R | 54 ++++++++++++++++---------------- R/pkg/R/generics.R | 16 +++++----- R/pkg/inst/tests/test_sparkSQL.R | 4 +-- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 260c9edce62e0..5d04dd6acaab8 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -123,14 +123,14 @@ exportMethods("%in%", "count", "countDistinct", "crc32", - "cumeDist", + "cume_dist", "date_add", "date_format", "date_sub", "datediff", "dayofmonth", "dayofyear", - "denseRank", + "dense_rank", "desc", "endsWith", "exp", @@ -188,7 +188,7 @@ exportMethods("%in%", "next_day", "ntile", "otherwise", - "percentRank", + "percent_rank", "pmod", "quarter", "rand", @@ -200,7 +200,7 @@ exportMethods("%in%", "rint", "rlike", "round", - "rowNumber", + "row_number", "rpad", "rtrim", "second", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 25a1f22101494..e98e7a0117ca0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2146,47 +2146,47 @@ setMethod("ifelse", ###################### Window functions###################### -#' cumeDist +#' cume_dist #' #' Window function: returns the cumulative distribution of values within a window partition, #' i.e. the fraction of rows that are below the current row. #' #' N = total number of rows in the partition -#' cumeDist(x) = number of values before (and including) x / N +#' cume_dist(x) = number of values before (and including) x / N #' #' This is equivalent to the CUME_DIST function in SQL. #' -#' @rdname cumeDist -#' @name cumeDist +#' @rdname cume_dist +#' @name cume_dist #' @family window_funcs #' @export -#' @examples \dontrun{cumeDist()} -setMethod("cumeDist", +#' @examples \dontrun{cume_dist()} +setMethod("cume_dist", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist") + jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") column(jc) }) -#' denseRank +#' dense_rank #' #' Window function: returns the rank of rows within a window partition, without any gaps. -#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking -#' sequence when there are ties. That is, if you were ranking a competition using denseRank +#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking +#' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. #' #' This is equivalent to the DENSE_RANK function in SQL. #' -#' @rdname denseRank -#' @name denseRank +#' @rdname dense_rank +#' @name dense_rank #' @family window_funcs #' @export -#' @examples \dontrun{denseRank()} -setMethod("denseRank", +#' @examples \dontrun{dense_rank()} +setMethod("dense_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "denseRank") + jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank") column(jc) }) @@ -2264,7 +2264,7 @@ setMethod("ntile", column(jc) }) -#' percentRank +#' percent_rank #' #' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. #' @@ -2274,15 +2274,15 @@ setMethod("ntile", #' #' This is equivalent to the PERCENT_RANK function in SQL. #' -#' @rdname percentRank -#' @name percentRank +#' @rdname percent_rank +#' @name percent_rank #' @family window_funcs #' @export -#' @examples \dontrun{percentRank()} -setMethod("percentRank", +#' @examples \dontrun{percent_rank()} +setMethod("percent_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "percentRank") + jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank") column(jc) }) @@ -2316,21 +2316,21 @@ setMethod("rank", base::rank(x, ...) }) -#' rowNumber +#' row_number #' #' Window function: returns a sequential number starting at 1 within a window partition. #' #' This is equivalent to the ROW_NUMBER function in SQL. #' -#' @rdname rowNumber -#' @name rowNumber +#' @rdname row_number +#' @name row_number #' @family window_funcs #' @export -#' @examples \dontrun{rowNumber()} -setMethod("rowNumber", +#' @examples \dontrun{row_number()} +setMethod("row_number", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") + jc <- callJStatic("org.apache.spark.sql.functions", "row_number") column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1b3f10ea04643..0c305441e043e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -700,9 +700,9 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname cumeDist +#' @rdname cume_dist #' @export -setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") }) +setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) #' @rdname datediff #' @export @@ -728,9 +728,9 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname denseRank +#' @rdname dense_rank #' @export -setGeneric("denseRank", function(x) { standardGeneric("denseRank") }) +setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) #' @rdname explode #' @export @@ -872,9 +872,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @rdname percentRank +#' @rdname percent_rank #' @export -setGeneric("percentRank", function(x) { standardGeneric("percentRank") }) +setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) #' @rdname pmod #' @export @@ -913,9 +913,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @export setGeneric("rint", function(x, ...) { standardGeneric("rint") }) -#' @rdname rowNumber +#' @rdname row_number #' @export -setGeneric("rowNumber", function(x) { standardGeneric("rowNumber") }) +setGeneric("row_number", function(x) { standardGeneric("row_number") }) #' @rdname rpad #' @export diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 3f4f319fe745d..0fbe0658265b2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -861,8 +861,8 @@ test_that("column functions", { c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c14 <- cumeDist() + ntile(1) - c15 <- denseRank() + percentRank() + rank() + rowNumber() + c14 <- cume_dist() + ntile(1) + c15 <- dense_rank() + percent_rank() + rank() + row_number() # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") From f57e6c9effdb9e282fc8ae66dc30fe053fed5272 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 27 Nov 2015 11:50:18 -0800 Subject: [PATCH 1041/2089] [SPARK-12021][STREAMING][TESTS] Fix the potential dead-lock in StreamingListenerSuite In StreamingListenerSuite."don't call ssc.stop in listener", after the main thread calls `ssc.stop()`, `StreamingContextStoppingCollector` may call `ssc.stop()` in the listener bus thread, which is a dead-lock. This PR updated `StreamingContextStoppingCollector` to only call `ssc.stop()` in the first batch to avoid the dead-lock. Author: Shixiong Zhu Closes #10011 from zsxwing/fix-test-deadlock. --- .../streaming/StreamingListenerSuite.scala | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index df4575ab25aad..04cd5bdc26be2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -222,7 +222,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val batchCounter = new BatchCounter(_ssc) _ssc.start() // Make sure running at least one batch - batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + if (!batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)) { + fail("The first batch cannot complete in 10 seconds") + } + // When reaching here, we can make sure `StreamingContextStoppingCollector` won't call + // `ssc.stop()`, so it's safe to call `_ssc.stop()` now. _ssc.stop() assert(contextStoppingCollector.sparkExSeen) } @@ -345,12 +349,21 @@ class FailureReasonsCollector extends StreamingListener { */ class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { @volatile var sparkExSeen = false + + private var isFirstBatch = true + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - try { - ssc.stop() - } catch { - case se: SparkException => - sparkExSeen = true + if (isFirstBatch) { + // We should only call `ssc.stop()` in the first batch. Otherwise, it's possible that the main + // thread is calling `ssc.stop()`, while StreamingContextStoppingCollector is also calling + // `ssc.stop()` in the listener thread, which becomes a dead-lock. + isFirstBatch = false + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } } } } From b9921524d970f9413039967c1f17ae2e736982f0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 27 Nov 2015 15:11:13 -0800 Subject: [PATCH 1042/2089] [SPARK-12020][TESTS][TEST-HADOOP2.0] PR builder cannot trigger hadoop 2.0 test https://issues.apache.org/jira/browse/SPARK-12020 Author: Yin Huai Closes #10010 from yhuai/SPARK-12020. --- dev/run-tests-jenkins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 623004310e189..4f390ef1eaa32 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -164,7 +164,7 @@ def main(): # Switch the Hadoop profile based on the PR title: if "test-hadoop1.0" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" - if "test-hadoop2.2" in ghprb_pull_title: + if "test-hadoop2.0" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" if "test-hadoop2.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" From 149cd692ee2e127d79386fd8e584f4f70a2906ba Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 27 Nov 2015 22:44:08 -0800 Subject: [PATCH 1043/2089] [SPARK-12028] [SQL] get_json_object returns an incorrect result when the value is null literals When calling `get_json_object` for the following two cases, both results are `"null"`: ```scala val tuple: Seq[(String, String)] = ("5", """{"f1": null}""") :: Nil val df: DataFrame = tuple.toDF("key", "jstring") val res = df.select(functions.get_json_object($"jstring", "$.f1")).collect() ``` ```scala val tuple2: Seq[(String, String)] = ("5", """{"f1": "null"}""") :: Nil val df2: DataFrame = tuple2.toDF("key", "jstring") val res3 = df2.select(functions.get_json_object($"jstring", "$.f1")).collect() ``` Fixed the problem and also added a test case. Author: gatorsmile Closes #10018 from gatorsmile/get_json_object. --- .../expressions/jsonExpressions.scala | 7 +++++-- .../apache/spark/sql/JsonFunctionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd73236a7876..4991b9cb54e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -298,8 +298,11 @@ case class GetJsonObject(json: Expression, path: Expression) case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => // exact field match - p.nextToken() - evaluatePath(p, g, style, xs) + if (p.nextToken() != JsonToken.VALUE_NULL) { + evaluatePath(p, g, style, xs) + } else { + false + } case (FIELD_NAME, Wildcard :: xs) => // wildcard field match diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 14fd56fc8c222..1f384edf321b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -39,6 +39,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { ("6", "[invalid JSON string]") :: Nil + test("function get_json_object - null") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: + Nil + + checkAnswer( + df.select($"key", functions.get_json_object($"jstring", "$.f1"), + functions.get_json_object($"jstring", "$.f2"), + functions.get_json_object($"jstring", "$.f3"), + functions.get_json_object($"jstring", "$.f4"), + functions.get_json_object($"jstring", "$.f5")), + expected) + } + test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") val expected = From 28e46ab46368ea3833c8e805163893bbb6f2a265 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 28 Nov 2015 21:02:05 -0800 Subject: [PATCH 1044/2089] [SPARK-12029][SPARKR] Improve column functions signature, param check, tests, fix doc and add examples shivaram sun-rui Author: felixcheung Closes #10019 from felixcheung/rfunctionsdoc. --- R/pkg/R/functions.R | 121 +++++++++++++++++++++++-------- R/pkg/inst/tests/test_sparkSQL.R | 9 ++- 2 files changed, 96 insertions(+), 34 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e98e7a0117ca0..b30331c61c9a7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -878,7 +878,7 @@ setMethod("rtrim", #'} setMethod("sd", signature(x = "Column"), - function(x, na.rm = FALSE) { + function(x) { # In R, sample standard deviation is calculated with the sd() function. stddev_samp(x) }) @@ -1250,7 +1250,7 @@ setMethod("upper", #'} setMethod("var", signature(x = "Column"), - function(x, y = NULL, na.rm = FALSE, use) { + function(x) { # In R, sample variance is calculated with the var() function. var_samp(x) }) @@ -1467,6 +1467,7 @@ setMethod("pmod", signature(y = "Column"), #' @name approxCountDistinct #' @return the approximate number of distinct items in a group. #' @export +#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} setMethod("approxCountDistinct", signature(x = "Column"), function(x, rsd = 0.05) { @@ -1481,14 +1482,16 @@ setMethod("approxCountDistinct", #' @name countDistinct #' @return the number of distinct items in a group. #' @export +#' @examples \dontrun{countDistinct(df$c)} setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcol <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function (x) { + stopifnot(class(x) == "Column") x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - jcol) + jcols) column(jc) }) @@ -1501,10 +1504,14 @@ setMethod("countDistinct", #' @rdname concat #' @name concat #' @export +#' @examples \dontrun{concat(df$strings, df$strings2)} setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1518,11 +1525,15 @@ setMethod("concat", #' @rdname greatest #' @name greatest #' @export +#' @examples \dontrun{greatest(df$c, df$d)} setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1530,17 +1541,21 @@ setMethod("greatest", #' least #' #' Returns the least value of the list of column names, skipping null values. -#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' This function takes at least 2 parameters. It will return null if all parameters are null. #' #' @family normal_funcs #' @rdname least #' @name least #' @export +#' @examples \dontrun{least(df$c, df$d)} setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1549,11 +1564,10 @@ setMethod("least", #' #' Computes the ceiling of the given value. #' -#' @family math_funcs #' @rdname ceil -#' @name ceil -#' @aliases ceil +#' @name ceiling #' @export +#' @examples \dontrun{ceiling(df$c)} setMethod("ceiling", signature(x = "Column"), function(x) { @@ -1564,11 +1578,10 @@ setMethod("ceiling", #' #' Computes the signum of the given value. #' -#' @family math_funcs #' @rdname signum -#' @name signum -#' @aliases signum +#' @name sign #' @export +#' @examples \dontrun{sign(df$c)} setMethod("sign", signature(x = "Column"), function(x) { signum(x) @@ -1578,11 +1591,10 @@ setMethod("sign", signature(x = "Column"), #' #' Aggregate function: returns the number of distinct items in a group. #' -#' @family agg_funcs #' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct +#' @name n_distinct #' @export +#' @examples \dontrun{n_distinct(df$c)} setMethod("n_distinct", signature(x = "Column"), function(x, ...) { countDistinct(x, ...) @@ -1592,11 +1604,10 @@ setMethod("n_distinct", signature(x = "Column"), #' #' Aggregate function: returns the number of items in a group. #' -#' @family agg_funcs #' @rdname count -#' @name count -#' @aliases count +#' @name n #' @export +#' @examples \dontrun{n(df$c)} setMethod("n", signature(x = "Column"), function(x) { count(x) @@ -1617,6 +1628,7 @@ setMethod("n", signature(x = "Column"), #' @rdname date_format #' @name date_format #' @export +#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) @@ -1631,6 +1643,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' @rdname from_utc_timestamp #' @name from_utc_timestamp #' @export +#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) @@ -1649,6 +1662,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' @rdname instr #' @name instr #' @export +#' @examples \dontrun{instr(df$c, 'b')} setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) @@ -1663,13 +1677,18 @@ setMethod("instr", signature(y = "Column", x = "character"), #' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first #' Sunday after 2015-07-27. #' -#' Day of the week parameter is case insensitive, and accepts: +#' Day of the week parameter is case insensitive, and accepts first three or two characters: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' #' @family datetime_funcs #' @rdname next_day #' @name next_day #' @export +#' @examples +#'\dontrun{ +#'next_day(df$d, 'Sun') +#'next_day(df$d, 'Sunday') +#'} setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) @@ -1684,6 +1703,7 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' @rdname to_utc_timestamp #' @name to_utc_timestamp #' @export +#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) @@ -1697,8 +1717,8 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' @name add_months #' @family datetime_funcs #' @rdname add_months -#' @name add_months #' @export +#' @examples \dontrun{add_months(df$d, 1)} setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) @@ -1713,6 +1733,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' @rdname date_add #' @name date_add #' @export +#' @examples \dontrun{date_add(df$d, 1)} setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) @@ -1727,6 +1748,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname date_sub #' @name date_sub #' @export +#' @examples \dontrun{date_sub(df$d, 1)} setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) @@ -1735,16 +1757,19 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, #' and returns the result as a string column. #' -#' If d is 0, the result has no decimal point or fractional part. -#' If d < 0, the result will be null.' +#' If x is 0, the result has no decimal point or fractional part. +#' If x < 0, the result will be null. #' +#' @param y column to format +#' @param x number of decimal place to format to #' @family string_funcs #' @rdname format_number #' @name format_number #' @export +#' @examples \dontrun{format_number(df$n, 4)} setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1764,6 +1789,7 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' @rdname sha2 #' @name sha2 #' @export +#' @examples \dontrun{sha2(df$c, 256)} setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) @@ -1779,6 +1805,7 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' @rdname shiftLeft #' @name shiftLeft #' @export +#' @examples \dontrun{shiftLeft(df$c, 1)} setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1796,6 +1823,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' @rdname shiftRight #' @name shiftRight #' @export +#' @examples \dontrun{shiftRight(df$c, 1)} setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1813,6 +1841,7 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned #' @export +#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1830,6 +1859,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @rdname concat_ws #' @name concat_ws #' @export +#' @examples \dontrun{concat_ws('-', df$s, df$d)} setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) @@ -1845,6 +1875,7 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @rdname conv #' @name conv #' @export +#' @examples \dontrun{conv(df$n, 2, 16)} setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { fromBase <- as.integer(fromBase) @@ -1864,6 +1895,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' @rdname expr #' @name expr #' @export +#' @examples \dontrun{expr('length(name)')} setMethod("expr", signature(x = "character"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) @@ -1878,6 +1910,7 @@ setMethod("expr", signature(x = "character"), #' @rdname format_string #' @name format_string #' @export +#' @examples \dontrun{format_string('%d %s', df$a, df$b)} setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { jcols <- lapply(list(x, ...), function(arg) { arg@jc }) @@ -1897,6 +1930,11 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname from_unixtime #' @name from_unixtime #' @export +#' @examples +#'\dontrun{ +#'from_unixtime(df$t) +#'from_unixtime(df$t, 'yyyy/MM/dd HH') +#'} setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1915,6 +1953,7 @@ setMethod("from_unixtime", signature(x = "Column"), #' @rdname locate #' @name locate #' @export +#' @examples \dontrun{locate('b', df$c, 1)} setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 0) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1931,6 +1970,7 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @rdname lpad #' @name lpad #' @export +#' @examples \dontrun{lpad(df$c, 6, '#')} setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1947,12 +1987,13 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname rand #' @name rand #' @export +#' @examples \dontrun{rand()} setMethod("rand", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand") column(jc) }) -#' @family normal_funcs + #' @rdname rand #' @name rand #' @export @@ -1970,12 +2011,13 @@ setMethod("rand", signature(seed = "numeric"), #' @rdname randn #' @name randn #' @export +#' @examples \dontrun{randn()} setMethod("randn", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn") column(jc) }) -#' @family normal_funcs + #' @rdname randn #' @name randn #' @export @@ -1993,6 +2035,7 @@ setMethod("randn", signature(seed = "numeric"), #' @rdname regexp_extract #' @name regexp_extract #' @export +#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), function(x, pattern, idx) { @@ -2010,6 +2053,7 @@ setMethod("regexp_extract", #' @rdname regexp_replace #' @name regexp_replace #' @export +#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), function(x, pattern, replacement) { @@ -2027,6 +2071,7 @@ setMethod("regexp_replace", #' @rdname rpad #' @name rpad #' @export +#' @examples \dontrun{rpad(df$c, 6, '#')} setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2040,12 +2085,17 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' Returns the substring from string str before count occurrences of the delimiter delim. #' If count is positive, everything the left of the final delimiter (counting from left) is #' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' right) is returned. substring_index performs a case-sensitive match when searching for delim. #' #' @family string_funcs #' @rdname substring_index #' @name substring_index #' @export +#' @examples +#'\dontrun{ +#'substring_index(df$c, '.', 2) +#'substring_index(df$c, '.', -1) +#'} setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), function(x, delim, count) { @@ -2066,6 +2116,7 @@ setMethod("substring_index", #' @rdname translate #' @name translate #' @export +#' @examples \dontrun{translate(df$c, 'rnlt', '123')} setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), function(x, matchingString, replaceString) { @@ -2082,12 +2133,18 @@ setMethod("translate", #' @rdname unix_timestamp #' @name unix_timestamp #' @export +#' @examples +#'\dontrun{ +#'unix_timestamp() +#'unix_timestamp(df$t) +#'unix_timestamp(df$t, 'yyyy-MM-dd HH') +#'} setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -2096,7 +2153,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -2113,7 +2170,9 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @family normal_funcs #' @rdname when #' @name when +#' @seealso \link{ifelse} #' @export +#' @examples \dontrun{when(df$age == 2, df$age + 1)} setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc @@ -2130,7 +2189,9 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @family normal_funcs #' @rdname ifelse #' @name ifelse +#' @seealso \link{when} #' @export +#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0fbe0658265b2..899fc3b977385 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -880,14 +880,15 @@ test_that("column functions", { expect_equal(collect(df3)[[2, 1]], FALSE) expect_equal(collect(df3)[[3, 1]], TRUE) - expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + df4 <- select(df, countDistinct(df$age, df$name)) + expect_equal(collect(df4)[[1, 1]], 2) + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) - expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) - df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) - expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") # Test array_contains() and sort_array() df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) From c793d2d9a1ccc203fc103eb0636958fe8d71f471 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 28 Nov 2015 21:16:21 -0800 Subject: [PATCH 1045/2089] [SPARK-9319][SPARKR] Add support for setting column names, types Add support for for colnames, colnames<-, coltypes<- Also added tests for names, names<- which have no test previously. I merged with PR 8984 (coltypes). Clicked the wrong thing, crewed up the PR. Recreated it here. Was #9218 shivaram sun-rui Author: felixcheung Closes #9654 from felixcheung/colnamescoltypes. --- R/pkg/NAMESPACE | 6 +- R/pkg/R/DataFrame.R | 166 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 20 +++- R/pkg/R/types.R | 8 ++ R/pkg/inst/tests/test_sparkSQL.R | 40 +++++++- 5 files changed, 185 insertions(+), 55 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5d04dd6acaab8..43e5e0119e7fe 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -27,7 +27,10 @@ exportMethods("arrange", "attach", "cache", "collect", + "colnames", + "colnames<-", "coltypes", + "coltypes<-", "columns", "count", "cov", @@ -56,6 +59,7 @@ exportMethods("arrange", "mutate", "na.omit", "names", + "names<-", "ncol", "nrow", "orderBy", @@ -276,4 +280,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") \ No newline at end of file + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a13e7a36766d..f89e2682d9e29 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,6 +254,7 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns + #' @export #' @examples #'\dontrun{ @@ -262,6 +263,7 @@ setMethod("dtypes", #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' columns(df) +#' colnames(df) #'} setMethod("columns", signature(x = "DataFrame"), @@ -290,6 +292,121 @@ setMethod("names<-", } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Get column types of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @return value A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @name coltypes +#' @family DataFrame functions +#' @export +#' @examples +#'\dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#'} +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param value A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. +#' @rdname coltypes +#' @name coltypes<- +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0) { + stop("Cannot set types of an empty DataFrame with no Column") + } + if (length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + if (!is.na(value[i])) { + stype <- rToSQLTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext @@ -2102,52 +2219,3 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) - -#' Returns the column types of a DataFrame. -#' -#' @name coltypes -#' @title Get column types of a DataFrame -#' @family dataframe_funcs -#' @param x (DataFrame) -#' @return value (character) A character vector with the column types of the given DataFrame -#' @rdname coltypes -#' @examples \dontrun{ -#' irisDF <- createDataFrame(sqlContext, iris) -#' coltypes(irisDF) -#' } -setMethod("coltypes", - signature(x = "DataFrame"), - function(x) { - # Get the data types of the DataFrame by invoking dtypes() function - types <- sapply(dtypes(x), function(x) {x[[2]]}) - - # Map Spark data types into R's data types using DATA_TYPES environment - rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { - - # Check for primitive types - type <- PRIMITIVE_TYPES[[x]] - - if (is.null(type)) { - # Check for complex types - for (t in names(COMPLEX_TYPES)) { - if (substring(x, 1, nchar(t)) == t) { - type <- COMPLEX_TYPES[[t]] - break - } - } - - if (is.null(type)) { - stop(paste("Unsupported data type: ", x)) - } - } - type - }) - - # Find which types don't have mapping to R - naIndices <- which(is.na(rTypes)) - - # Assign the original scala data types to the unmatched ones - rTypes[naIndices] <- types[naIndices] - - rTypes - }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0c305441e043e..711ce38f9e104 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,6 +385,22 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) +#' @rdname columns +#' @export +setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) + +#' @rdname columns +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) @@ -1081,7 +1097,3 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") - -#' @rdname coltypes -#' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index 1828c23ab0f6d..dae4fe858bdbc 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -41,3 +41,11 @@ COMPLEX_TYPES <- list( # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) + +# An environment for mapping R to Scala, names are R types and values are Scala types. +rToSQLTypes <- as.environment(list( + "integer" = "integer", # in R, integer is 32bit + "numeric" = "double", # in R, numeric == double which is 64bit + "double" = "double", + "character" = "string", + "logical" = "boolean")) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 899fc3b977385..d3b2f20bf81c5 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -622,6 +622,26 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form expect_equal(testNames[2], "name") }) +test_that("names() colnames() set the column names", { + df <- jsonFile(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") + + # Test base::colnames base::names + m2 <- cbind(1, 1:4) + expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) + colnames(m2) <- c("x","Y") + expect_equal(colnames(m2), c("x", "Y")) + + z <- list(a = 1, b = "c", c = 1:3) + expect_equal(names(z)[3], "c") + names(z)[3] <- "c2" + expect_equal(names(z)[3], "c2") +}) + test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) @@ -1617,7 +1637,7 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) -test_that("Method coltypes() to get R's data types of a DataFrame", { +test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) data <- data.frame(c1=c(1,2,3), @@ -1636,6 +1656,24 @@ test_that("Method coltypes() to get R's data types of a DataFrame", { x <- createDataFrame(sqlContext, list(list(as.environment( list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") + + df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("character", "integer") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + + expect_error(coltypes(df) <- c("character"), + "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") }) unlink(parquetPath) From cc7a1bc9370b163f51230e5ca4be612d133a5086 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sun, 29 Nov 2015 11:08:26 -0800 Subject: [PATCH 1046/2089] [SPARK-11781][SPARKR] SparkR has problem in inferring type of raw type. Author: Sun Rui Closes #9769 from sun-rui/SPARK-11781. --- R/pkg/R/DataFrame.R | 34 ++++++++++++++++------------- R/pkg/R/SQLContext.R | 2 +- R/pkg/R/types.R | 37 ++++++++++++++++++-------------- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f89e2682d9e29..a82ded9c51fac 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -793,8 +793,8 @@ setMethod("dim", setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - names <- columns(x) - ncol <- length(names) + dtypes <- dtypes(x) + ncol <- length(dtypes) if (ncol <= 0) { # empty data.frame with 0 columns and 0 rows data.frame() @@ -817,25 +817,29 @@ setMethod("collect", # data of complex type can be held. But getting a cell from a column # of list type returns a list instead of a vector. So for columns of # non-complex type, append them as vector. + # + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] + colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[names[colIndex]]] <- col + df[[colName]] <- col } else { - # TODO: more robust check on column of primitive types - vec <- do.call(c, col) - if (class(vec) != "list") { - df[[names[colIndex]]] <- vec + colType <- dtypes[[colIndex]][[2]] + # Note that "binary" columns behave like complex types. + if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { + vec <- do.call(c, col) + stopifnot(class(vec) != "list") + df[[colName]] <- vec } else { - # For columns of complex type, be careful to access them. - # Get a column of complex type returns a list. - # Get a cell from a column of complex type returns a list instead of a vector. - df[[names[colIndex]]] <- col - } + df[[colName]] <- col + } + } } + df } - df - } - }) + }) #' Limit #' diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a62b25fde926d..85541c8e22447 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -63,7 +63,7 @@ infer_type <- function(x) { }) type <- Reduce(paste0, type) type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") - } else if (length(x) > 1) { + } else if (length(x) > 1 && type != "binary") { paste0("array<", infer_type(x[[1]]), ">") } else { type diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index dae4fe858bdbc..1f06af7e904fe 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -19,25 +19,30 @@ # values are equivalent R types. This is stored in an environment to allow for # more efficient look up (environments use hashmaps). PRIMITIVE_TYPES <- as.environment(list( - "byte"="integer", - "tinyint"="integer", - "smallint"="integer", - "integer"="integer", - "bigint"="numeric", - "float"="numeric", - "double"="numeric", - "decimal"="numeric", - "string"="character", - "binary"="raw", - "boolean"="logical", - "timestamp"="POSIXct", - "date"="Date")) + "tinyint" = "integer", + "smallint" = "integer", + "int" = "integer", + "bigint" = "numeric", + "float" = "numeric", + "double" = "numeric", + "decimal" = "numeric", + "string" = "character", + "binary" = "raw", + "boolean" = "logical", + "timestamp" = "POSIXct", + "date" = "Date", + # following types are not SQL types returned by dtypes(). They are listed here for usage + # by checkType() in schema.R. + # TODO: refactor checkType() in schema.R. + "byte" = "integer", + "integer" = "integer" + )) # The complex data types. These do not have any direct mapping to R's types. COMPLEX_TYPES <- list( - "map"=NA, - "array"=NA, - "struct"=NA) + "map" = NA, + "array" = NA, + "struct" = NA) # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d3b2f20bf81c5..92ec82096c6df 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -72,6 +72,8 @@ test_that("infer types and check types", { expect_equal(infer_type(e), "map") expect_error(checkType("map"), "Key type in a map must be string or character") + + expect_equal(infer_type(as.raw(c(1, 2, 3))), "binary") }) test_that("structType and structField", { @@ -250,6 +252,10 @@ test_that("create DataFrame from list or data.frame", { mtcarsdf <- createDataFrame(sqlContext, mtcars) expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(sqlContext, list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) }) test_that("create DataFrame with different data types", { From 3d28081e53698ed77e93c04299957c02bcaba9bf Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 29 Nov 2015 14:13:11 -0800 Subject: [PATCH 1047/2089] [SPARK-12024][SQL] More efficient multi-column counting. In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell Closes #10015 from hvanhovell/SPARK-12024. --- .../expressions/aggregate/Count.scala | 21 ++-------- .../expressions/conditionalExpressions.scala | 27 ------------- .../sql/catalyst/optimizer/Optimizer.scala | 14 ++++--- .../ConditionalExpressionSuite.scala | 14 ------- .../spark/sql/execution/aggregate/utils.scala | 39 +++++++++---------- .../spark/sql/expressions/WindowSpec.scala | 4 +- 6 files changed, 33 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 09a1da9200df0..441f52ab5ca58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Count(child: Expression) extends DeclarativeAggregate { - override def children: Seq[Expression] = child :: Nil +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false @@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = LongType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) private lazy val count = AttributeReference("count", LongType)() @@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions = Seq( - /* count = */ If(IsNull(child), count, count + 1L) + /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) ) override lazy val mergeExpressions = Seq( @@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate { } object Count { - def apply(children: Seq[Expression]): Count = { - // This is used to deal with COUNT DISTINCT. When we have multiple - // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). - // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any - // null in the arguments, we will not count that row. So, we use DropAnyNull at here - // to return a null when any field of the created STRUCT is null. - val child = if (children.size > 1) { - DropAnyNull(CreateStruct(children)) - } else { - children.head - } - Count(child) - } + def apply(child: Expression): Count = Count(child :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 694a2a7c54a90..40b1eec63e551 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } -/** Operator that drops a row when it contains any nulls. */ -case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StructType) - - protected override def nullSafeEval(input: Any): InternalRow = { - val row = input.asInstanceOf[InternalRow] - if (row.anyNull) { - null - } else { - row - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.anyNull()) { - ${ev.isNull} = true; - } else { - ${ev.value} = $eval; - } - """ - }) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2901d8f2efddf..06d14fcf8b9c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { + def nonNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => false + case _ => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. case e @ Coalesce(children) => - val newChildren = children.filter { - case Literal(null, _) => false - case _ => true - } + val newChildren = children.filter(nonNullLiteral) if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index c1e3c17b87102..0df673bb9fa02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } - - test("function dropAnyNull") { - val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) - val a = create_row("a", "q") - val nullStr: String = null - checkEvaluation(drop, a, a) - checkEvaluation(drop, null, create_row("b", nullStr)) - checkEvaluation(drop, null, create_row(nullStr, nullStr)) - - val row = 'r.struct( - StructField("a", StringType, false), - StructField("b", StringType, true)).at(0) - checkEvaluation(DropAnyNull(row), null, create_row(null)) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index a70e41436c7aa..76b938cdb694e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -146,20 +146,16 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one - // DISTINCT aggregate function, all of those functions will have the same column expression. + // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpression: Expression = { - val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - assert(allDistinctColumnExpressions.length == 1) - allDistinctColumnExpressions.head - } - val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match { + val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctColumnExpressions = distinctColumnExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute + val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. @@ -170,10 +166,11 @@ object Utils { // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. - val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression + val partialAggregateGroupingExpressions = + groupingExpressions ++ namedDistinctColumnExpressions val partialAggregateResult = groupingAttributes ++ - Seq(distinctColumnAttribute) ++ + distinctColumnAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) if (usesTungstenAggregate) { TungstenAggregate( @@ -208,28 +205,28 @@ object Utils { partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = groupingAttributes ++ - Seq(distinctColumnAttribute) ++ + distinctColumnAttributes ++ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } @@ -244,14 +241,16 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } + val distinctColumnAttributeLookup = + distinctColumnExpressions.zip(distinctColumnAttributes).toMap val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. case agg @ AggregateExpression(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction] + val rewrittenAggregateFunction = aggregateFunction + .transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, @@ -270,7 +269,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = resultExpressions, child = partialMergeAggregate) } else { @@ -281,7 +280,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = resultExpressions, child = partialMergeAggregate) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index fc873c04f88f0..893e800a61438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -152,8 +152,8 @@ class WindowSpec private[sql]( case Sum(child) => WindowExpression( UnresolvedWindowFunction("sum", child :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), + case Count(children) => WindowExpression( + UnresolvedWindowFunction("count", children), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case First(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF first_value From 0ddfe7868948e302858a2b03b50762eaefbeb53e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 29 Nov 2015 19:02:15 -0800 Subject: [PATCH 1048/2089] [SPARK-12039] [SQL] Ignore HiveSparkSubmitSuite's "SPARK-9757 Persist Parquet relation with decimal column". https://issues.apache.org/jira/browse/SPARK-12039 Since it is pretty flaky in hadoop 1 tests, we can disable it while we are investigating the cause. Author: Yin Huai Closes #10035 from yhuai/SPARK-12039-ignore. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 24a3afee148c5..92962193311d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -101,7 +101,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-9757 Persist Parquet relation with decimal column") { + ignore("SPARK-9757 Persist Parquet relation with decimal column") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SPARK_9757.getClass.getName.stripSuffix("$"), From e0749442051d6e29dae4f4cdcb2937c0b015f98f Mon Sep 17 00:00:00 2001 From: toddwan Date: Mon, 30 Nov 2015 09:26:29 +0000 Subject: [PATCH 1049/2089] [SPARK-11859][MESOS] SparkContext accepts invalid Master URLs in the form zk://host:port for a multi-master Mesos cluster using ZooKeeper * According to below doc and validation logic in [SparkSubmit.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L231), master URL for a mesos cluster should always start with `mesos://` http://spark.apache.org/docs/latest/running-on-mesos.html `The Master URLs for Mesos are in the form mesos://host:5050 for a single-master Mesos cluster, or mesos://zk://host:2181 for a multi-master Mesos cluster using ZooKeeper.` * However, [SparkContext.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/SparkContext.scala#L2749) fails the validation and can receive master URL in the form `zk://host:port` * For the master URLs in the form `zk:host:port`, the valid form should be `mesos://zk://host:port` * This PR restrict the validation in `SparkContext.scala`, and now only mesos master URLs prefixed with `mesos://` can be accepted. * This PR also updated corresponding unit test. Author: toddwan Closes #9886 from toddwan/S11859. --- .../scala/org/apache/spark/SparkContext.scala | 16 ++++++++++------ .../SparkContextSchedulerCreationSuite.scala | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b030d3c71dc20..8a62b71c3fa68 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2708,15 +2708,14 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case mesosUrl @ MESOS_REGEX(_) => + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) - val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) + new CoarseMesosSchedulerBackend(scheduler, sc, mesosUrl, sc.env.securityManager) } else { - new MesosSchedulerBackend(scheduler, sc, url) + new MesosSchedulerBackend(scheduler, sc, mesosUrl) } scheduler.initialize(backend) (backend, scheduler) @@ -2727,6 +2726,11 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) + case zkUrl if zkUrl.startsWith("zk://") => + logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + + "in the form mesos://zk://host:port. Current Master URL will stop working in Spark 2.0.") + createTaskScheduler(sc, "mesos://" + zkUrl) + case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") } @@ -2745,8 +2749,8 @@ private object SparkMasterRegex { val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Mesos cluster by mesos:// or mesos://zk:// url + val MESOS_REGEX = """mesos://(.*)""".r // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index e5a14a69ef05f..d18e0782c0392 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -175,6 +175,11 @@ class SparkContextSchedulerCreationSuite } test("mesos with zookeeper") { + testMesos("mesos://zk://localhost:1234,localhost:2345", + classOf[MesosSchedulerBackend], coarse = false) + } + + test("mesos with zookeeper and Master URL starting with zk://") { testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } From 953e8e6dcb32cd88005834e9c3720740e201826c Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 30 Nov 2015 09:30:58 +0000 Subject: [PATCH 1050/2089] [MINOR][BUILD] Changed the comment to reflect the plugin project is there to support SBT pom reader only. Author: Prashant Sharma Closes #10012 from ScrapCodes/minor-build-comment. --- project/project/SparkPluginBuild.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 471d00bd8223f..cbb88dc7dd1dd 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -19,9 +19,8 @@ import sbt._ import sbt.Keys._ /** - * This plugin project is there to define new scala style rules for spark. This is - * a plugin project so that this gets compiled first and is put on the classpath and - * becomes available for scalastyle sbt plugin. + * This plugin project is there because we use our custom fork of sbt-pom-reader plugin. This is + * a plugin project so that this gets compiled first and is available on the classpath for SBT build. */ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) From 26c3581f17f475fab2f3b5301b8f253ff2fa6438 Mon Sep 17 00:00:00 2001 From: Wieland Hoffmann Date: Mon, 30 Nov 2015 09:32:48 +0000 Subject: [PATCH 1051/2089] [DOC] Explicitly state that top maintains the order of elements Top is implemented in terms of takeOrdered, which already maintains the order, so top should, too. Author: Wieland Hoffmann Closes #10013 from mineo/top-order. --- .../main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 4 ++-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 871be0b1f39ea..1e9d4f1803a81 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -556,7 +556,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD as defined by - * the specified Comparator[T]. + * the specified Comparator[T] and maintains the order. * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -567,7 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD using the - * natural ordering for T. + * natural ordering for T and maintains the order. * @param num k, the number of top elements to return * @return an array of top elements */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aeb5eeaad32c..8b3731d935788 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1327,7 +1327,8 @@ abstract class RDD[T: ClassTag]( /** * Returns the top k (largest) elements from this RDD as defined by the specified - * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: + * implicit Ordering[T] and maintains the ordering. This does the opposite of + * [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) * // returns Array(12) From bf0e85a70a54a2d7fd6804b6bd00c63c20e2bb00 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 30 Nov 2015 10:11:27 +0000 Subject: [PATCH 1052/2089] [SPARK-12023][BUILD] Fix warnings while packaging spark with maven. this is a trivial fix, discussed [here](http://stackoverflow.com/questions/28500401/maven-assembly-plugin-warning-the-assembly-descriptor-contains-a-filesystem-roo/). Author: Prashant Sharma Closes #10014 from ScrapCodes/assembly-warning. --- assembly/src/main/assembly/assembly.xml | 8 ++++---- external/mqtt/src/main/assembly/assembly.xml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 711156337b7c3..009d4b92f406c 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -32,7 +32,7 @@ ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/org/apache/spark/ui/static + ui-resources/org/apache/spark/ui/static **/* @@ -41,7 +41,7 @@ ${project.parent.basedir}/sbin/ - /sbin + sbin **/* @@ -50,7 +50,7 @@ ${project.parent.basedir}/bin/ - /bin + bin **/* @@ -59,7 +59,7 @@ ${project.parent.basedir}/assembly/target/${spark.jar.dir} - / + ${spark.jar.basename} diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml index ecab5b360eb3e..c110b01b34e10 100644 --- a/external/mqtt/src/main/assembly/assembly.xml +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -24,7 +24,7 @@ ${project.build.directory}/scala-${scala.binary.version}/test-classes - / + From 2db4662fe2d72749c06ad5f11f641a388343f77c Mon Sep 17 00:00:00 2001 From: CK50 Date: Mon, 30 Nov 2015 20:08:49 +0800 Subject: [PATCH 1053/2089] [SPARK-11989][SQL] Only use commit in JDBC data source if the underlying database supports transactions Fixes [SPARK-11989](https://issues.apache.org/jira/browse/SPARK-11989) Author: CK50 Author: Christian Kurz Closes #9973 from CK50/branch-1.6_non-transactional. (cherry picked from commit a589736a1b237ef2f3bd59fbaeefe143ddcc8f4e) Signed-off-by: Reynold Xin --- .../datasources/jdbc/JdbcUtils.scala | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7375a5c09123f..252f1cfd5d9c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -21,6 +21,7 @@ import java.sql.{Connection, PreparedStatement} import java.util.Properties import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} @@ -125,8 +126,19 @@ object JdbcUtils extends Logging { dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false + val supportsTransactions = try { + conn.getMetaData().supportsDataManipulationTransactionsOnly() || + conn.getMetaData().supportsDataDefinitionAndDataManipulationTransactions() + } catch { + case NonFatal(e) => + logWarning("Exception while detecting transaction support", e) + true + } + try { - conn.setAutoCommit(false) // Everything in the same db transaction. + if (supportsTransactions) { + conn.setAutoCommit(false) // Everything in the same db transaction. + } val stmt = insertStatement(conn, table, rddSchema) try { var rowCount = 0 @@ -175,14 +187,18 @@ object JdbcUtils extends Logging { } finally { stmt.close() } - conn.commit() + if (supportsTransactions) { + conn.commit() + } committed = true } finally { if (!committed) { // The stage must fail. We got here through an exception path, so // let the exception through unless rollback() or close() want to // tell the user about another problem. - conn.rollback() + if (supportsTransactions) { + conn.rollback() + } conn.close() } else { // The stage must succeed. We cannot propagate any exception close() might throw. From 17275fa99c670537c52843df405279a52b5c9594 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 10:32:13 -0800 Subject: [PATCH 1054/2089] [SPARK-11700] [SQL] Remove thread local SQLContext in SparkPlan In 1.6, we introduce a public API to have a SQLContext for current thread, SparkPlan should use that. Author: Davies Liu Closes #9990 from davies/leak_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 10 +++++----- .../spark/sql/execution/QueryExecution.scala | 3 +-- .../org/apache/spark/sql/execution/SparkPlan.scala | 14 ++++---------- .../apache/spark/sql/MultiSQLContextsSuite.scala | 2 +- .../sql/execution/ExchangeCoordinatorSuite.scala | 2 +- .../sql/execution/RowFormatConvertersSuite.scala | 4 ++-- 6 files changed, 14 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1c2ac5f6f11bf..8d2783952532a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,7 +26,6 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SparkException, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -45,9 +44,10 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -401,7 +401,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) @@ -417,7 +417,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -1334,7 +1334,7 @@ object SQLContext { activeContext.remove() } - private[sql] def getActiveContextOption(): Option[SQLContext] = { + private[sql] def getActive(): Option[SQLContext] = { Option(activeContext.get()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5da5aea17e25b..107570f9dbcc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -42,9 +42,8 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) - // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) sqlContext.planner.plan(optimizedPlan).next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 534a3bcb8364d..507641ff8263e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,21 +23,15 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType -object SparkPlan { - protected[sql] val currentContext = new ThreadLocal[SQLContext]() -} - /** * The base class for physical operators. */ @@ -49,7 +43,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SparkPlan.currentContext.get() + protected[spark] final val sqlContext = SQLContext.getActive().get protected def sparkContext = sqlContext.sparkContext @@ -69,7 +63,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propogates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) super.makeCopy(newArgs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 34c5c68fd1c18..162c0b56c6e11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -27,7 +27,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var sparkConf: SparkConf = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b96d50a70b85c..180050bdac00f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -30,7 +30,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalInstantiatedSQLContext: Option[SQLContext] = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 6876ab0f02b10..13d68a103a225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -94,7 +94,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) From 8df584b0200402d8b2ce0a8de24f7a760ced8655 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 11:54:18 -0800 Subject: [PATCH 1055/2089] [SPARK-11982] [SQL] improve performance of cartesian product This PR improve the performance of CartesianProduct by caching the result of right plan. After this patch, the query time of TPC-DS Q65 go down to 4 seconds from 28 minutes (420X faster). cc nongli Author: Davies Liu Closes #9969 from davies/improve_cartesian. --- .../unsafe/sort/UnsafeExternalSorter.java | 63 +++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 7 ++ .../execution/joins/CartesianProduct.scala | 76 +++++++++++++++++-- .../execution/metric/SQLMetricsSuite.scala | 2 +- 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 2e40312674737..5a97f4f11340c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.IOException; import java.util.LinkedList; +import java.util.Queue; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -521,4 +522,66 @@ public long getKeyPrefix() { return upstream.getKeyPrefix(); } } + + /** + * Returns a iterator, which will return the rows in the order as inserted. + * + * It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public UnsafeSorterIterator getIterator() throws IOException { + if (spillWriters.isEmpty()) { + assert(inMemSorter != null); + return inMemSorter.getIterator(); + } else { + LinkedList queue = new LinkedList<>(); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + queue.add(spillWriter.getReader(blockManager)); + } + if (inMemSorter != null) { + queue.add(inMemSorter.getIterator()); + } + return new ChainedIterator(queue); + } + } + + /** + * Chain multiple UnsafeSorterIterator together as single one. + */ + class ChainedIterator extends UnsafeSorterIterator { + + private final Queue iterators; + private UnsafeSorterIterator current; + + public ChainedIterator(Queue iterators) { + assert iterators.size() > 0; + this.iterators = iterators; + this.current = iterators.remove(); + } + + @Override + public boolean hasNext() { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + return current.hasNext(); + } + + @Override + public void loadNext() throws IOException { + current.loadNext(); + } + + @Override + public Object getBaseObject() { return current.getBaseObject(); } + + @Override + public long getBaseOffset() { return current.getBaseOffset(); } + + @Override + public int getRecordLength() { return current.getRecordLength(); } + + @Override + public long getKeyPrefix() { return current.getKeyPrefix(); } + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index dce1f15a2963c..c91e88f31bf9b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -226,4 +226,11 @@ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); return new SortedIterator(pos / 2); } + + /** + * Returns an iterator over record pointers in original order (inserted). + */ + public SortedIterator getIterator() { + return new SortedIterator(pos / 2); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index f467519b802a7..fa2bc7672131c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -17,16 +17,75 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.rdd.RDD +import org.apache.spark._ +import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + + +/** + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ +private[spark] +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) + extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { + + override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { + // We will not sort the rows, so prefixComparator and recordComparator are null. + val sorter = UnsafeExternalSorter.create( + context.taskMemoryManager(), + SparkEnv.get.blockManager, + context, + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + + val partition = split.asInstanceOf[CartesianPartition] + for (y <- rdd2.iterator(partition.s2, context)) { + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) + } + + // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] + def createIter(): Iterator[UnsafeRow] = { + val iter = sorter.getIterator + val unsafeRow = new UnsafeRow + new Iterator[UnsafeRow] { + override def hasNext: Boolean = { + iter.hasNext + } + override def next(): UnsafeRow = { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight, + iter.getRecordLength) + unsafeRow + } + } + } + + val resultIter = + for (x <- rdd1.iterator(partition.s1, context); + y <- createIter()) yield (x, y) + CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( + resultIter, sorter.cleanupResources) + } +} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), @@ -39,18 +98,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod val leftResults = left.execute().map { row => numLeftRows += 1 - row.copy() + row.asInstanceOf[UnsafeRow] } val rightResults = right.execute().map { row => numRightRows += 1 - row.copy() + row.asInstanceOf[UnsafeRow] } - leftResults.cartesian(rightResults).mapPartitionsInternal { iter => - val joinedRow = new JoinedRow + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + pair.mapPartitionsInternal { iter => + val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) iter.map { r => numOutputRows += 1 - joinedRow(r._1, r._2) + joiner.join(r._1, r._2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index ebfa1eaf3e5bc..4f2cad19bfb6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -317,7 +317,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { testSparkPlanMetrics(df, 1, Map( 1L -> ("CartesianProduct", Map( "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 12L, // right is read 6 times + "number of right rows" -> 4L, // right is read twice "number of output rows" -> 12L))) ) } From f2fbfa444f6e8d27953ec2d1c0b3abd603c963f9 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 30 Nov 2015 13:02:08 -0800 Subject: [PATCH 1056/2089] [MINOR][DOCS] fixed list display in ml-ensembles The list in ml-ensembles.md wasn't properly formatted and, as a result, was looking like this: ![old](http://i.imgur.com/2ZhELLR.png) This PR aims to make it look like this: ![new](http://i.imgur.com/0Xriwd2.png) Author: BenFradet Closes #10025 from BenFradet/ml-ensembles-doc. --- docs/ml-ensembles.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index f6c3c30d5334f..14fef76f260ff 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -20,6 +20,7 @@ Both use [MLlib decision trees](ml-decision-tree.html) as their base models. Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + * support for ML Pipelines * separation of classification vs. regression * use of DataFrame metadata to distinguish continuous and categorical features From 2c5dee0fb8e4d1734ea3a0f22e0b5bfd2f6dba46 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 13:41:52 -0800 Subject: [PATCH 1057/2089] Revert "[SPARK-11206] Support SQL UI on the history server" This reverts commit cc243a079b1c039d6e7f0b410d1654d94a090e14 / PR #9297 I'm reverting this because it broke SQLListenerMemoryLeakSuite in the master Maven builds. See #9991 for a discussion of why this broke the tests. --- .rat-excludes | 1 - .../org/apache/spark/JavaSparkListener.java | 3 - .../apache/spark/SparkFirehoseListener.java | 4 - .../scheduler/EventLoggingListener.scala | 4 - .../spark/scheduler/SparkListener.scala | 24 +-- .../spark/scheduler/SparkListenerBus.scala | 1 - .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 - .../org/apache/spark/sql/SQLContext.scala | 18 +-- .../spark/sql/execution/SQLExecution.scala | 24 ++- .../spark/sql/execution/SparkPlanInfo.scala | 46 ------ .../sql/execution/metric/SQLMetricInfo.scala | 30 ---- .../sql/execution/metric/SQLMetrics.scala | 56 +++---- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++------------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 43 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 - 21 files changed, 135 insertions(+), 327 deletions(-) delete mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 7262c960ed6bb..08fba6d351d6a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,5 +82,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister -org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 23bc9a2e81727..fa9acf0a15b88 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,7 +82,4 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } - @Override - public void onOtherEvent(SparkListenerEvent event) { } - } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index e6b24afd88ad4..1214d05ba6063 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,8 +118,4 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } - @Override - public void onOtherEvent(SparkListenerEvent event) { - onEvent(event); - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index eaa07acc5132e..000a021a528cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,10 +207,6 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } - override def onOtherEvent(event: SparkListenerEvent): Unit = { - logEvent(event, flushLogger = true) - } - /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 075a7f13172de..896f1743332f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,19 +22,15 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import com.fasterxml.jackson.annotation.JsonTypeInfo - -import org.apache.spark.{Logging, SparkConf, TaskEndReason} +import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} -import org.apache.spark.ui.SparkUI @DeveloperApi -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") -trait SparkListenerEvent +sealed trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -134,17 +130,6 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent -/** - * Interface for creating history listeners defined in other modules like SQL, which are used to - * rebuild the history UI. - */ -private[spark] trait SparkHistoryListenerFactory { - /** - * Create listeners used to rebuild the history UI. - */ - def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] -} - /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -238,11 +223,6 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } - - /** - * Called when other events like SQL-specific events are posted. - */ - def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 95722a07144ec..04afde33f5aad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,7 +61,6 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata - case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 8da6884a38535..4608bce202ec8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,13 +17,10 @@ package org.apache.spark.ui -import java.util.{Date, ServiceLoader} - -import scala.collection.JavaConverters._ +import java.util.Date import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -157,16 +154,7 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - val sparkUI = create( - None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) - - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, sparkUI) - listeners.foreach(listenerBus.addListener) - } - sparkUI + create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7f5d713ec6505..c9beeb25e05af 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,21 +19,19 @@ package org.apache.spark.util import java.util.{Properties, UUID} +import org.apache.spark.scheduler.cluster.ExecutorInfo + import scala.collection.JavaConverters._ import scala.collection.Map -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -56,8 +54,6 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats - private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) - /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -100,7 +96,6 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this - case _ => parse(mapper.writeValueAsString(event)) } } @@ -516,8 +511,6 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) - case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) - .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory deleted file mode 100644 index 507100be90967..0000000000000 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8d2783952532a..9cc65de19180a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1263,8 +1263,6 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() - @transient private val sqlListener = new AtomicReference[SQLListener]() - /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1309,10 +1307,6 @@ object SQLContext { Option(instantiatedContext.get()) } - private[sql] def clearSqlListener(): Unit = { - sqlListener.set(null) - } - /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1361,13 +1355,9 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - if (sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - sqlListener.get() + val listener = new SQLListener(sc.conf) + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + listener } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 34971986261c2..1422e15549c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,8 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, - SparkListenerSQLExecutionEnd} +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -46,14 +45,25 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + sqlContext.listener.onExecutionStart( + executionId, + callSite.shortForm, + callSite.longForm, + queryExecution.toString, + SparkPlanGraph(queryExecution.executedPlan), + System.currentTimeMillis()) try { body } finally { - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. + // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new + // SQL event types to SparkListener since it's a public API, we cannot guarantee that. + // + // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. + // + // The worst case is onExecutionEnd may happen before onJobStart when the listener thread + // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. + sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala deleted file mode 100644 index 486ce34064e43..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.metric.SQLMetricInfo -import org.apache.spark.util.Utils - -/** - * :: DeveloperApi :: - * Stores information about a SQL SparkPlan. - */ -@DeveloperApi -class SparkPlanInfo( - val nodeName: String, - val simpleString: String, - val children: Seq[SparkPlanInfo], - val metrics: Seq[SQLMetricInfo]) - -private[sql] object SparkPlanInfo { - - def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - new SQLMetricInfo(metric.name.getOrElse(key), metric.id, - Utils.getFormattedClassName(metric.param)) - } - val children = plan.children.map(fromSparkPlan) - - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala deleted file mode 100644 index 2708219ad3485..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.metric - -import org.apache.spark.annotation.DeveloperApi - -/** - * :: DeveloperApi :: - * Stores information about a SQL Metric. - */ -@DeveloperApi -class SQLMetricInfo( - val name: String, - val accumulatorId: Long, - val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 6c0f6f8a52dc5..1c253e3942e95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,39 +104,21 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } -private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) - -private object StaticsLongSQLMetricParam extends LongSQLMetricParam( - (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - }, -1L) - private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - param: LongSQLMetricParam): LongSQLMetric = { + stringValue: Seq[Long] => String, + initialValue: Long): LongSQLMetric = { + val param = new LongSQLMetricParam(stringValue, initialValue) val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, LongSQLMetricParam) + createLongMetric(sc, name, _.sum.toString, 0L) } /** @@ -144,25 +126,31 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { + val stringValue = (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) - } - - def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { - val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) - val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) - val metricParam = metricParamName match { - case `longSQLMetricParam` => LongSQLMetricParam - case `staticsSQLMetricParam` => StaticsLongSQLMetricParam - } - metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] + createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) + val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index c74ad40406992..e74d6fb396e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} + +import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index e19a1e3e5851f..5a072de400b6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,34 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} -import org.apache.spark.ui.SparkUI - -@DeveloperApi -case class SparkListenerSQLExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - sparkPlanInfo: SparkPlanInfo, - time: Long) - extends SparkListenerEvent - -@DeveloperApi -case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) - extends SparkListenerEvent - -private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { - - override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { - List(new SQLHistoryListener(conf, sparkUI)) - } -} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -141,8 +118,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), - finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) } } @@ -164,7 +140,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics.accumulatorUpdates(), + taskEnd.taskMetrics, finishTask = true) } @@ -172,12 +148,15 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - protected def updateTaskAccumulatorValues( + private def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - accumulatorUpdates: Map[Long, Any], + metrics: TaskMetrics, finishTask: Boolean): Unit = { + if (metrics == null) { + return + } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -195,9 +174,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = accumulatorUpdates + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = accumulatorUpdates + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -206,7 +185,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, accumulatorUpdates) + attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) } } case None => @@ -214,40 +193,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) => - val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - val executionUIData = new SQLExecutionUIData( - executionId, - description, - details, - physicalPlanDescription, - physicalPlanGraph, - sqlPlanMetrics.toMap, - time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. - } + def onExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + physicalPlanGraph: SparkPlanGraph, + time: Long): Unit = { + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + + val executionUIData = new SQLExecutionUIData(executionId, description, details, + physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + } + + def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. } } - case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -312,38 +289,6 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } -private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) - extends SQLListener(conf) { - - private var sqlTabAttached = false - - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - // Do nothing - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { acc => - (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) - }.toMap, - finishTask = true) - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case _: SparkListenerSQLExecutionStart => - if (!sqlTabAttached) { - new SQLTab(this, sparkUI) - sqlTabAttached = true - } - super.onOtherEvent(event) - case _ => super.onOtherEvent(event) - } -} - /** * Represent all necessary data for an execution that will be used in Web UI. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 4f50b2ecdc8f8..9c27944d42fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.ui +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, "SQL") with Logging { + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI @@ -33,5 +35,13 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL$nextId" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 7af0ff09c5c6d..f1fce5478a3fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { + def apply(plan: SparkPlan): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - planInfo: SparkPlanInfo, + plan: SparkPlan, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) nodes += node - val childrenNodes = planInfo.children.map( + val childrenNodes = plan.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4f2cad19bfb6b..82867ab4967bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,7 +26,6 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -83,8 +82,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( - df.queryExecution.executedPlan)).nodes.filter { node => + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index f93d081d0c30e..c15aac775096c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,8 +82,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) - .nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -91,13 +90,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) val executionUIData = listener.executionIdToData(0) @@ -207,8 +206,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -221,20 +219,19 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -251,13 +248,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -274,8 +271,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -292,20 +288,19 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e7b376548787c..963d10eed62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,7 +42,6 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { - SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From a8ceec5e8c1572dd3d74783c06c78b7ca0b8a7ce Mon Sep 17 00:00:00 2001 From: Teng Qiu Date: Tue, 1 Dec 2015 07:27:32 +0900 Subject: [PATCH 1058/2089] [SPARK-12053][CORE] EventLoggingListener.getLogPath needs 4 parameters ```EventLoggingListener.getLogPath``` needs 4 input arguments: https://github.com/apache/spark/blob/v1.6.0-preview2/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L276-L280 the 3rd parameter should be appAttemptId, 4th parameter is codec... Author: Teng Qiu Closes #10044 from chutium/SPARK-12053. --- core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9952c97dbdffc..1355e1ad1b523 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -934,7 +934,7 @@ private[deploy] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) From e232720a65dfb9ae6135cbb7674e35eddd88d625 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 30 Nov 2015 14:56:51 -0800 Subject: [PATCH 1059/2089] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Original PR is reverted due to document build error. https://github.com/apache/spark/pull/9722 mengxr feynmanliang yinxusen Sorry for the troubling. Author: Yuhao Yang Closes #9974 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 31 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 97 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 0000000000000..cfefb5dfbde9e --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,31 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
    + +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a17..6f35b30c3d4df 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %}
    - + \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec4..54e35fcbb15af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 0000000000000..3a5d3237c85f6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; +// $example on$ +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + // $example on$ + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } + // $example off$ +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 0000000000000..419ce3d87a6ac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. + * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From de64b65f7cf2ac58c1abc310ba547637fdbb8557 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 30 Nov 2015 15:01:08 -0800 Subject: [PATCH 1060/2089] [SPARK-11975][ML] Remove duplicate mllib example (DT/RF/GBT in Java/Python) Remove duplicate mllib example (DT/RF/GBT in Java/Python). Since we have tutorial code for DT/RF/GBT classification/regression in Scala/Java/Python and example applications for DT/RF/GBT in Scala, so we mark these as duplicated and remove them. mengxr Author: Yanbo Liang Closes #9954 from yanboliang/SPARK-11975. --- .../examples/mllib/JavaDecisionTree.java | 116 -------------- .../mllib/JavaGradientBoostedTreesRunner.java | 126 --------------- .../mllib/JavaRandomForestExample.java | 139 ----------------- .../main/python/mllib/decision_tree_runner.py | 144 ------------------ .../python/mllib/gradient_boosted_trees.py | 77 ---------- .../python/mllib/random_forest_example.py | 90 ----------- 6 files changed, 692 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java delete mode 100755 examples/src/main/python/mllib/decision_tree_runner.py delete mode 100644 examples/src/main/python/mllib/gradient_boosted_trees.py delete mode 100755 examples/src/main/python/mllib/random_forest_example.py diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java deleted file mode 100644 index 1f82e3f4cb18e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.HashMap; - -import scala.Tuple2; - -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -/** - * Classification and regression using decision trees. - */ -public final class JavaDecisionTree { - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - if (args.length == 1) { - datapath = args[0]; - } else if (args.length > 1) { - System.err.println("Usage: JavaDecisionTree "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - - // Set parameters. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; - - // Train a DecisionTree model for classification. - final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - - // Train a DecisionTree model for regression. - impurity = "variance"; - final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD regressorPredictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(regressionModel.predict(p.features()), p.label()); - } - }); - Double trainMSE = - regressorPredictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + regressionModel); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java deleted file mode 100644 index a1844d5d07ad4..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -/** - * Classification and regression using gradient-boosted decision trees. - */ -public final class JavaGradientBoostedTreesRunner { - - private static void usage() { - System.err.println("Usage: JavaGradientBoostedTreesRunner " + - " "); - System.exit(-1); - } - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - String algo = "Classification"; - if (args.length >= 1) { - datapath = args[0]; - } - if (args.length >= 2) { - algo = args[1]; - } - if (args.length > 2) { - usage(); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Set parameters. - // Note: All features are treated as continuous. - BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); - boostingStrategy.setNumIterations(10); - boostingStrategy.treeStrategy().setMaxDepth(5); - - if (algo.equals("Classification")) { - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - boostingStrategy.treeStrategy().setNumClasses(numClasses); - - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - } else if (algo.equals("Regression")) { - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainMSE = - predictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + model); - } else { - usage(); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java deleted file mode 100644 index 89a4e092a5af7..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import java.util.HashMap; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -public final class JavaRandomForestExample { - - /** - * Note: This example illustrates binary classification. - * For information on multiclass classification, please refer to the JavaDecisionTree.java - * example. - */ - private static void testClassification(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "gini"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); - System.out.println("Test Error: " + testErr); - System.out.println("Learned classification forest model:\n" + model.toDebugString()); - } - - private static void testRegression(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "variance"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); - System.out.println("Test Mean Squared Error: " + testMSE); - System.out.println("Learned regression forest model:\n" + model.toDebugString()); - } - - public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - // Load and parse the data file. - String datapath = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); - // Split the data into training and test sets (30% held out for testing) - JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); - JavaRDD trainingData = splits[0]; - JavaRDD testData = splits[1]; - - System.out.println("\nRunning example of classification using RandomForest\n"); - testClassification(trainingData, testData); - - System.out.println("\nRunning example of regression using RandomForest\n"); - testRegression(trainingData, testData); - sc.stop(); - } -} diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py deleted file mode 100755 index 513ed8fd51450..0000000000000 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ /dev/null @@ -1,144 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Decision tree classification and regression using MLlib. - -This example requires NumPy (http://www.numpy.org/). -""" -from __future__ import print_function - -import numpy -import os -import sys - -from operator import add - -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree -from pyspark.mllib.util import MLUtils - - -def getAccuracy(dtModel, data): - """ - Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + (x[0] == x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainCorrect / (0.0 + data.count()) - - -def getMSE(dtModel, data): - """ - Return mean squared error (MSE) of DecisionTreeModel on the given - RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainMSE / (0.0 + data.count()) - - -def reindexClassLabels(data): - """ - Re-index class labels in a dataset to the range {0,...,numClasses-1}. - If all labels in that range already appear at least once, - then the returned RDD is the same one (without a mapping). - Note: If a label simply does not appear in the data, - the index will not include it. - Be aware of this when reindexing subsampled data. - :param data: RDD of LabeledPoint where labels are integer values - denoting labels for a classification problem. - :return: Pair (reindexedData, origToNewLabels) where - reindexedData is an RDD of LabeledPoint with labels in - the range {0,...,numClasses-1}, and - origToNewLabels is a dictionary mapping original labels - to new labels. - """ - # classCounts: class --> # examples in class - classCounts = data.map(lambda x: x.label).countByValue() - numExamples = sum(classCounts.values()) - sortedClasses = sorted(classCounts.keys()) - numClasses = len(classCounts) - # origToNewLabels: class --> index in 0,...,numClasses-1 - if (numClasses < 2): - print("Dataset for classification should have at least 2 classes." - " The given dataset had only %d classes." % numClasses, file=sys.stderr) - exit(1) - origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - - print("numClasses = %d" % numClasses) - print("Per-class example fractions, counts:") - print("Class\tFrac\tCount") - for c in sortedClasses: - frac = classCounts[c] / (numExamples + 0.0) - print("%g\t%g\t%d" % (c, frac, classCounts[c])) - - if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): - return (data, origToNewLabels) - else: - reindexedData = \ - data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) - return (reindexedData, origToNewLabels) - - -def usage(): - print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) - exit(1) - - -if __name__ == "__main__": - if len(sys.argv) > 2: - usage() - sc = SparkContext(appName="PythonDT") - - # Load data. - dataPath = 'data/mllib/sample_libsvm_data.txt' - if len(sys.argv) == 2: - dataPath = sys.argv[1] - if not os.path.isfile(dataPath): - sc.stop() - usage() - points = MLUtils.loadLibSVMFile(sc, dataPath) - - # Re-index class labels if needed. - (reindexedData, origToNewLabels) = reindexClassLabels(points) - numClasses = len(origToNewLabels) - - # Train a classifier. - categoricalFeaturesInfo = {} # no categorical features - model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, - categoricalFeaturesInfo=categoricalFeaturesInfo) - # Print learned tree and stats. - print("Trained DecisionTree for classification:") - print(" Model numNodes: %d" % model.numNodes()) - print(" Model depth: %d" % model.depth()) - print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) - if model.numNodes() < 20: - print(model.toDebugString()) - else: - print(model) - - sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py deleted file mode 100644 index 781bd61c9d2b5..0000000000000 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Gradient boosted Trees classification and regression using MLlib. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import GradientBoostedTrees -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification ensemble model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ - / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression ensemble model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGradientBoostedTrees") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using GradientBoostedTrees\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using GradientBoostedTrees\n') - testRegression(trainingData, testData) - - sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py deleted file mode 100755 index 4cfdad868c66e..0000000000000 --- a/examples/src/main/python/mllib/random_forest_example.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Random Forest classification and regression using MLlib. - -Note: This example illustrates binary classification. - For information on multiclass classification, please refer to the decision_tree_runner.py - example. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import RandomForest -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainClassifier(trainingData, numClasses=2, - categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification forest model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ - .sum() / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression forest model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using RandomForest\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using RandomForest\n') - testRegression(trainingData, testData) - - sc.stop() From 55358889309cf2d856b72e72e0f3081dfdf61cfa Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 30 Nov 2015 15:38:44 -0800 Subject: [PATCH 1061/2089] [SPARK-11960][MLLIB][DOC] User guide for streaming tests CC jkbradley mengxr josepablocam Author: Feynman Liang Closes #10005 from feynmanliang/streaming-test-user-guide. --- docs/mllib-guide.md | 1 + docs/mllib-statistics.md | 25 +++++++++++++++++++ .../examples/mllib/StreamingTestExample.scala | 2 ++ 3 files changed, 28 insertions(+) diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15af..43772adcf26e1 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -34,6 +34,7 @@ We list major functionality from both below, with links to detailed guides. * [correlations](mllib-statistics.html#correlations) * [stratified sampling](mllib-statistics.html#stratified-sampling) * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [streaming significance testing](mllib-statistics.html#streaming-significance-testing) * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ade5b0768aefe..de209f68e19ca 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -521,6 +521,31 @@ print(testResult) # summary of the test including the p-value, test statistic, +### Streaming Significance Testing +MLlib provides online implementations of some tests to support use cases +like A/B testing. These tests may be performed on a Spark Streaming +`DStream[(Boolean,Double)]` where the first element of each tuple +indicates control group (`false`) or treatment group (`true`) and the +second element is the value of an observation. + +Streaming significance testing supports the following parameters: + +* `peacePeriod` - The number of initial data points from the stream to +ignore, used to mitigate novelty effects. +* `windowSize` - The number of past batches to perform hypothesis +testing over. Setting to `0` will perform cumulative processing using +all prior batches. + + +
    +
    +[`StreamingTest`](api/scala/index.html#org.apache.spark.mllib.stat.test.StreamingTest) +provides streaming hypothesis testing. + +{% include_example scala/org/apache/spark/examples/mllib/StreamingTestExample.scala %} +
    +
    + ## Random data generation diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index ab29f90254d34..b6677c6476639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -64,6 +64,7 @@ object StreamingTestExample { dir.toString }) + // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { case Array(label, value) => (label.toBoolean, value.toDouble) }) @@ -75,6 +76,7 @@ object StreamingTestExample { val out = streamingTest.registerStream(data) out.print() + // $example off$ // Stop processing if test becomes significant or we time out var timeoutCounter = numBatchesTimeout From ecc00ec3faf09b0e21bc4b468cb45e45d05cf482 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 15:42:10 -0800 Subject: [PATCH 1062/2089] fix Maven build --- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 507641ff8263e..a78177751c9dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -43,7 +43,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SQLContext.getActive().get + protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null) protected def sparkContext = sqlContext.sparkContext From edb26e7f4e1164645971c9a139eb29ddec8acc5d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 30 Nov 2015 16:31:59 -0800 Subject: [PATCH 1063/2089] [SPARK-12058][HOTFIX] Disable KinesisStreamTests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit KinesisStreamTests in test.py is broken because of #9403. See https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/46896/testReport/(root)/KinesisStreamTests/test_kinesis_stream/ Because Streaming Python didn’t work when merging https://github.com/apache/spark/pull/9403, the PR build didn’t report the Python test failure actually. This PR just disabled the test to unblock #10039 Author: Shixiong Zhu Closes #10047 from zsxwing/disable-python-kinesis-test. --- python/pyspark/streaming/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d380d697bc51c..a647e6bf39581 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1409,6 +1409,7 @@ def test_kinesis_stream_api(self): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") + @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write( From d3ca8cfac286ae19f8bedc736877ea9d0a0a072c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 16:37:27 -0800 Subject: [PATCH 1064/2089] [SPARK-12000] Fix API doc generation issues This pull request fixes multiple issues with API doc generation. - Modify the Jekyll plugin so that the entire doc build fails if API docs cannot be generated. This will make it easy to detect when the doc build breaks, since this will now trigger Jenkins failures. - Change how we handle the `-target` compiler option flag in order to fix `javadoc` generation. - Incorporate doc changes from thunterdb (in #10048). Closes #10048. Author: Josh Rosen Author: Timothy Hunter Closes #10049 from JoshRosen/fix-doc-build. --- docs/_plugins/copy_api_dirs.rb | 6 +++--- .../apache/spark/network/client/StreamCallback.java | 4 ++-- .../org/apache/spark/network/server/RpcHandler.java | 2 +- project/SparkBuild.scala | 11 ++++++++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 01718d98dffe0..f2f3e2e653149 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -27,7 +27,7 @@ cd("..") puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl clean compile unidoc` + system("build/sbt -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") puts "Moving back into docs dir." cd("docs") @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - puts `make html` + system(make html) || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") @@ -131,7 +131,7 @@ # Build SparkR API docs puts "Moving to R directory and building roxygen docs." cd("R") - puts `./create-docs.sh` + system("./create-docs.sh") || raise("R doc generation failed") puts "Moving back into home dir." cd("../") diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java index 093fada320cc3..51d34cac6e636 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -21,8 +21,8 @@ import java.nio.ByteBuffer; /** - * Callback for streaming data. Stream data will be offered to the {@link onData(ByteBuffer)} - * method as it arrives. Once all the stream data is received, {@link onComplete()} will be + * Callback for streaming data. Stream data will be offered to the {@link onData(String, ByteBuffer)} + * method as it arrives. Once all the stream data is received, {@link onComplete(String)} will be * called. *

    * The network library guarantees that a single thread will call these methods at a time, but diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 65109ddfe13b9..1a11f7b3820c6 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -55,7 +55,7 @@ public abstract void receive( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * call "{@link receive(TransportClient, byte[], RpcResponseCallback)}" and log a warning if * any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f575f0012d59e..63290d8a666e6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -160,7 +160,12 @@ object SparkBuild extends PomBuild { javacOptions in Compile ++= Seq( "-encoding", "UTF-8", - "-source", javacJVMVersion.value, + "-source", javacJVMVersion.value + ), + // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't + // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for + // additional discussion and explanation. + javacOptions in (Compile, compile) ++= Seq( "-target", javacJVMVersion.value ), @@ -547,9 +552,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. From e6dc89a33951e9197a77dbcacf022c27469ae41e Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 30 Nov 2015 17:18:44 -0800 Subject: [PATCH 1065/2089] [SPARK-12035] Add more debug information in include_example tag of Jekyll https://issues.apache.org/jira/browse/SPARK-12035 When we debuging lots of example code files, like in https://github.com/apache/spark/pull/10002, it's hard to know which file causes errors due to limited information in `include_example.rb`. With their filenames, we can locate bugs easily. Author: Xusen Yin Closes #10026 from yinxusen/SPARK-12035. --- docs/_plugins/include_example.rb | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 564c86680f68e..f7485826a762d 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -75,10 +75,10 @@ def select_lines(code) .select { |l, i| l.include? "$example off$" } .map { |l, i| i } - raise "Start indices amount is not equal to end indices amount, please check the code." \ + raise "Start indices amount is not equal to end indices amount, see #{@file}." \ unless startIndices.size == endIndices.size - raise "No code is selected by include_example, please check the code." \ + raise "No code is selected by include_example, see #{@file}." \ if startIndices.size == 0 # Select and join code blocks together, with a space line between each of two continuous @@ -86,8 +86,10 @@ def select_lines(code) lastIndex = -1 result = "" startIndices.zip(endIndices).each do |start, endline| - raise "Overlapping between two example code blocks are not allowed." if start <= lastIndex - raise "$example on$ should not be in the same line with $example off$." if start == endline + raise "Overlapping between two example code blocks are not allowed, see #{@file}." \ + if start <= lastIndex + raise "$example on$ should not be in the same line with $example off$, see #{@file}." \ + if start == endline lastIndex = endline range = Range.new(start + 1, endline - 1) result += trim_codeblock(lines[range]).join From 0a46e4377216a1f7de478f220c3b3042a77789e2 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 30 Nov 2015 17:19:26 -0800 Subject: [PATCH 1066/2089] [SPARK-12037][CORE] initialize heartbeatReceiverRef before calling startDriverHeartbeat https://issues.apache.org/jira/browse/SPARK-12037 a simple fix by changing the order of the statements Author: CodingCat Closes #10032 from CodingCat/SPARK-12037. --- .../main/scala/org/apache/spark/executor/Executor.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6154f06e3ac11..7b68dfe5ad06e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -109,6 +109,10 @@ private[spark] class Executor( // Executor for the heartbeat task. private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + // must be initialized before running startDriverHeartbeat() + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + startDriverHeartbeater() def launchTask( @@ -411,9 +415,6 @@ private[spark] class Executor( } } - private val heartbeatReceiverRef = - RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) - /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { // list of (task id, metrics) to send back to the driver From 9bf2120672ae0f620a217ccd96bef189ab75e0d6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 30 Nov 2015 17:22:05 -0800 Subject: [PATCH 1067/2089] [SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer. This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API. The following are parts of the code that actually have meaningful changes: - The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses. - The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. The code now is able to create frames without having to resort to copying bytes except for a few bytes (containing the frame length) in very rare cases. - Some minor changes in the SASL layer to convert things back to `byte[]` since the JDK SASL API operates on those. Author: Marcelo Vanzin Closes #9987 from vanzin/SPARK-12007. --- .../mesos/MesosExternalShuffleService.scala | 3 +- .../network/netty/NettyBlockRpcServer.scala | 8 +- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 16 +- .../org/apache/spark/rpc/netty/Outbox.scala | 9 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 3 +- .../network/client/RpcResponseCallback.java | 4 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportResponseHandler.java | 16 +- .../network/protocol/AbstractMessage.java | 54 +++++++ ...Body.java => AbstractResponseMessage.java} | 16 +- .../network/protocol/ChunkFetchFailure.java | 2 +- .../network/protocol/ChunkFetchRequest.java | 2 +- .../network/protocol/ChunkFetchSuccess.java | 8 +- .../spark/network/protocol/Message.java | 11 +- .../network/protocol/MessageEncoder.java | 29 ++-- .../spark/network/protocol/OneWayMessage.java | 33 ++-- .../spark/network/protocol/RpcFailure.java | 2 +- .../spark/network/protocol/RpcRequest.java | 34 +++-- .../spark/network/protocol/RpcResponse.java | 39 +++-- .../spark/network/protocol/StreamFailure.java | 2 +- .../spark/network/protocol/StreamRequest.java | 2 +- .../network/protocol/StreamResponse.java | 19 +-- .../network/sasl/SaslClientBootstrap.java | 12 +- .../spark/network/sasl/SaslMessage.java | 31 ++-- .../spark/network/sasl/SaslRpcHandler.java | 26 +++- .../spark/network/server/MessageHandler.java | 2 +- .../spark/network/server/NoOpRpcHandler.java | 8 +- .../spark/network/server/RpcHandler.java | 8 +- .../server/TransportChannelHandler.java | 2 +- .../server/TransportRequestHandler.java | 15 +- .../apache/spark/network/util/JavaUtils.java | 48 ++++-- .../network/util/TransportFrameDecoder.java | 142 ++++++++++++------ .../network/ChunkFetchIntegrationSuite.java | 5 +- .../apache/spark/network/ProtocolSuite.java | 10 +- .../RequestTimeoutIntegrationSuite.java | 51 ++++--- .../spark/network/RpcIntegrationSuite.java | 26 ++-- .../org/apache/spark/network/StreamSuite.java | 5 +- .../TransportResponseHandlerSuite.java | 24 +-- .../spark/network/sasl/SparkSaslSuite.java | 43 ++++-- .../util/TransportFrameDecoderSuite.java | 23 ++- .../shuffle/ExternalShuffleBlockHandler.java | 9 +- .../shuffle/ExternalShuffleClient.java | 3 +- .../shuffle/OneForOneBlockFetcher.java | 7 +- .../mesos/MesosExternalShuffleClient.java | 5 +- .../protocol/BlockTransferMessage.java | 8 +- .../network/sasl/SaslIntegrationSuite.java | 18 ++- .../shuffle/BlockTransferMessagesSuite.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 21 +-- .../shuffle/OneForOneBlockFetcherSuite.java | 8 +- 50 files changed, 589 insertions(+), 307 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java rename network/common/src/main/java/org/apache/spark/network/protocol/{ResponseWithBody.java => AbstractResponseMessage.java} (63%) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 12337a940a414..8ffcfc0878a42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.mesos import java.net.SocketAddress +import java.nio.ByteBuffer import scala.collection.mutable @@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo } } connectedApps(address) = appId - callback.onSuccess(new Array[Byte](0)) + callback.onSuccess(ByteBuffer.allocate(0)) case _ => super.handleMessage(message, client, callback) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 76968249fb625..df8c21fb837ed 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -47,9 +47,9 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - messageBytes: Array[Byte], + rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { - val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { @@ -58,7 +58,7 @@ class NettyBlockRpcServer( openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel is serialized as bytes using our JavaSerializer. @@ -66,7 +66,7 @@ class NettyBlockRpcServer( serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) - responseContext.onSuccess(new Array[Byte](0)) + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b0694e3c6c8af..82c16e855b0c0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} @@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") result.success((): Unit) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c7d74fa1d9195..68c5f44145b0d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv( promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } - private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = javaSerializerInstance.serialize(content) - java.util.Arrays.copyOfRange( - buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + private[netty] def serialize(content: Any): ByteBuffer = { + javaSerializerInstance.serialize(content) } - private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + javaSerializerInstance.deserialize[T](bytes) } } } @@ -557,7 +555,7 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte], + message: ByteBuffer, callback: RpcResponseCallback): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postRemoteMessage(messageToDispatch, callback) @@ -565,12 +563,12 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte]): Unit = { + message: ByteBuffer): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { + private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 36fdd00bbc4c2..2316ebe347bb7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc.netty +import java.nio.ByteBuffer import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy @@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage { } -private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage +private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage with Logging { override def sendWith(client: TransportClient): Unit = { @@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb } private[netty] case class RpcOutboxMessage( - content: Array[Byte], + content: ByteBuffer, _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) + _onSuccess: (TransportClient, ByteBuffer) => Unit) extends OutboxMessage with RpcResponseCallback { private var client: TransportClient = _ @@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage( _onFailure(e) } - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { _onSuccess(client, response) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 323184cdd9b6e..ebd6f700710bd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import java.net.InetSocketAddress +import java.nio.ByteBuffer import io.netty.channel.Channel import org.mockito.Mockito._ @@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6ec960d795420..47e93f9846fa6 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,13 +17,15 @@ package org.apache.spark.network.client; +import java.nio.ByteBuffer; + /** * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ public interface RpcResponseCallback { /** Successful serialized result from server. */ - void onSuccess(byte[] response); + void onSuccess(ByteBuffer response); /** Exception either propagated from server or raised on client side. */ void onFailure(Throwable e); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8a58e7b24585b..c49ca4d5ee925 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -20,6 +20,7 @@ import java.io.Closeable; import java.io.IOException; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -36,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; @@ -212,7 +214,7 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Callback to handle the RPC's reply. * @return The RPC's id. */ - public long sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -220,7 +222,7 @@ public long sendRpc(byte[] message, final RpcResponseCallback callback) { final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -249,12 +251,12 @@ public void operationComplete(ChannelFuture future) throws Exception { * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public byte[] sendRpcSync(byte[] message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); + public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { result.set(response); } @@ -279,8 +281,8 @@ public void onFailure(Throwable e) { * * @param message The message to send. */ - public void send(byte[] message) { - channel.writeAndFlush(new OneWayMessage(message)); + public void send(ByteBuffer message) { + channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 4c15045363b84..23a8dba593442 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -136,7 +136,7 @@ public void exceptionCaught(Throwable cause) { } @Override - public void handle(ResponseMessage message) { + public void handle(ResponseMessage message) throws Exception { String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; @@ -144,11 +144,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.body.release(); + resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body); - resp.body.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -166,10 +166,14 @@ public void handle(ResponseMessage message) { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.response.length); + resp.requestId, remoteAddress, resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.response); + try { + listener.onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java new file mode 100644 index 0000000000000..2924218c2f08b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for messages which optionally contain a body kept in a separate buffer. + */ +public abstract class AbstractMessage implements Message { + private final ManagedBuffer body; + private final boolean isBodyInFrame; + + protected AbstractMessage() { + this(null, false); + } + + protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + @Override + public ManagedBuffer body() { + return body; + } + + @Override + public boolean isBodyInFrame() { + return isBodyInFrame; + } + + protected boolean equals(AbstractMessage other) { + return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java similarity index 63% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java rename to network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java index 67be77e39f711..c362c92fc4f52 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java @@ -17,23 +17,15 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; /** - * Abstract class for response messages that contain a large data portion kept in a separate - * buffer. These messages are treated especially by MessageEncoder. + * Abstract class for response messages. */ -public abstract class ResponseWithBody implements ResponseMessage { - public final ManagedBuffer body; - public final boolean isBodyInFrame; +public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { - protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) { - this.body = body; - this.isBodyInFrame = isBodyInFrame; + protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { + super(body, isBodyInFrame); } public abstract ResponseMessage createFailureResponse(String error); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index f0363830b61ac..7b28a9a969486 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -23,7 +23,7 @@ /** * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */ -public final class ChunkFetchFailure implements ResponseMessage { +public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 5a173af54f618..26d063feb5fe3 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -24,7 +24,7 @@ * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements RequestMessage { +public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index e6a7e9a8b4145..94c2ac9b20e43 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,7 +30,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess extends ResponseWithBody { +public final class ChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { @@ -67,14 +67,14 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(streamChunkId, body); + return Objects.hashCode(streamChunkId, body()); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && body.equals(o.body); + return streamChunkId.equals(o.streamChunkId) && super.equals(o); } return false; } @@ -83,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", body) + .add("buffer", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index 39afd03db60ee..66f5b8b3a59c8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,17 +19,25 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + /** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); + /** An optional body for the message. */ + ManagedBuffer body(); + + /** Whether to include the body of the message in the same frame as the message. */ + boolean isBodyInFrame(); + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9); + OneWayMessage(9), User(-1); private final byte id; @@ -57,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 6cce97c807dc0..abca22347b783 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder { * data to 'out', in order to enable zero-copy transfer. */ @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { Object body = null; long bodyLength = 0; boolean isBodyInFrame = false; - // Detect ResponseWithBody messages and get the data buffer out of them. - // The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ResponseWithBody) { - ResponseWithBody resp = (ResponseWithBody) in; + // If the message has a body, take it out to enable zero-copy transfer for the payload. + if (in.body() != null) { try { - bodyLength = resp.body.size(); - body = resp.body.convertToNetty(); - isBodyInFrame = resp.isBodyInFrame; + bodyLength = in.body().size(); + body = in.body().convertToNetty(); + isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { - // Re-encode this message as a failure response. - String error = e.getMessage() != null ? e.getMessage() : "null"; - logger.error(String.format("Error processing %s for client %s", - resp, ctx.channel().remoteAddress()), e); - encode(ctx, resp.createFailureResponse(error), out); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } return; } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index 95a0270be3da9..efe0470f35875 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -17,21 +17,21 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A RPC that does not expect a reply, which is handled by a remote * {@link org.apache.spark.network.server.RpcHandler}. */ -public final class OneWayMessage implements RequestMessage { - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; +public final class OneWayMessage extends AbstractMessage implements RequestMessage { - public OneWayMessage(byte[] message) { - this.message = message; + public OneWayMessage(ManagedBuffer body) { + super(body, true); } @Override @@ -39,29 +39,34 @@ public OneWayMessage(byte[] message) { @Override public int encodedLength() { - return Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 4; } @Override public void encode(ByteBuf buf) { - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static OneWayMessage decode(ByteBuf buf) { - byte[] message = Encoders.ByteArrays.decode(buf); - return new OneWayMessage(message); + // See comment in encodedLength(). + buf.readInt(); + return new OneWayMessage(new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Arrays.hashCode(message); + return Objects.hashCode(body()); } @Override public boolean equals(Object other) { if (other instanceof OneWayMessage) { OneWayMessage o = (OneWayMessage) other; - return Arrays.equals(message, o.message); + return super.equals(o); } return false; } @@ -69,7 +74,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 2dfc7876ba328..a76624ef5dc96 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -21,7 +21,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link RpcRequest} for a failed RPC. */ -public final class RpcFailure implements ResponseMessage { +public final class RpcFailure extends AbstractMessage implements ResponseMessage { public final long requestId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 745039db742fa..96213794a8015 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -17,26 +17,25 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements RequestMessage { +public final class RpcRequest extends AbstractMessage implements RequestMessage { /** Used to link an RPC request with its response. */ public final long requestId; - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; - - public RpcRequest(long requestId, byte[] message) { + public RpcRequest(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.message = message; } @Override @@ -44,31 +43,36 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] message = Encoders.ByteArrays.decode(buf); - return new RpcRequest(requestId, message); + // See comment in encodedLength(). + buf.readInt(); + return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(message)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && Arrays.equals(message, o.message); + return requestId == o.requestId && super.equals(o); } return false; } @@ -77,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 1671cd444f039..bae866e14a1e1 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -17,49 +17,62 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ResponseMessage { +public final class RpcResponse extends AbstractResponseMessage { public final long requestId; - public final byte[] response; - public RpcResponse(long requestId, byte[] response) { + public RpcResponse(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.response = response; } @Override public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, response); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] response = Encoders.ByteArrays.decode(buf); - return new RpcResponse(requestId, response); + // See comment in encodedLength(). + buf.readInt(); + return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(response)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && Arrays.equals(response, o.response); + return requestId == o.requestId && super.equals(o); } return false; } @@ -68,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("response", response) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index e3dade2ebf905..26747ee55b4de 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -26,7 +26,7 @@ /** * Message indicating an error when transferring a stream. */ -public final class StreamFailure implements ResponseMessage { +public final class StreamFailure extends AbstractMessage implements ResponseMessage { public final String streamId; public final String error; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index 821e8f53884d7..35af5a84ba6bd 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -29,7 +29,7 @@ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before * the data can be streamed. */ -public final class StreamRequest implements RequestMessage { +public final class StreamRequest extends AbstractMessage implements RequestMessage { public final String streamId; public StreamRequest(String streamId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index ac5ab9a323a11..51b899930f721 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -30,15 +30,15 @@ * sender. The receiver is expected to set a temporary channel handler that will consume the * number of bytes this message says the stream has. */ -public final class StreamResponse extends ResponseWithBody { - public final String streamId; - public final long byteCount; +public final class StreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } @Override public Type type() { return Type.StreamResponse; } @@ -68,7 +68,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId); + return Objects.hashCode(byteCount, streamId, body()); } @Override @@ -85,6 +85,7 @@ public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) .add("byteCount", byteCount) + .add("body", body()) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 69923769d44b4..68381037d6891 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -28,6 +30,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,11 +73,12 @@ public void doBootstrap(TransportClient client, Channel channel) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + buf.writeBytes(msg.body().nioByteBuffer()); - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); @@ -88,6 +92,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for SASL encryption.", client); } + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index cad76ab7aa54e..e52b526f09c77 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -18,38 +18,50 @@ package org.apache.spark.network.sasl; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; -import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.protocol.AbstractMessage; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -class SaslMessage implements Encodable { +class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ private static final byte TAG_BYTE = (byte) 0xEA; public final String appId; - public final byte[] payload; - public SaslMessage(String appId, byte[] payload) { + public SaslMessage(String appId, byte[] message) { + this(appId, Unpooled.wrappedBuffer(message)); + } + + public SaslMessage(String appId, ByteBuf message) { + super(new NettyManagedBuffer(message), true); this.appId = appId; - this.payload = payload; } + @Override + public Type type() { return Type.User; } + @Override public int encodedLength() { - return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 1 + Encoders.Strings.encodedLength(appId) + 4; } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); - Encoders.ByteArrays.encode(buf, payload); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static SaslMessage decode(ByteBuf buf) { @@ -59,7 +71,8 @@ public static SaslMessage decode(ByteBuf buf) { } String appId = Encoders.Strings.decode(buf); - byte[] payload = Encoders.ByteArrays.decode(buf); - return new SaslMessage(appId, payload); + // See comment in encodedLength(). + buf.readInt(); + return new SaslMessage(appId, buf.retain()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 830db94b890c5..c215bd9d15045 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,8 +17,11 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; } - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } if (saslServer == null) { // First message in the handshake, setup the necessary state. @@ -86,8 +96,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback conf.saslServerAlwaysEncrypt()); } - byte[] response = saslServer.response(saslMessage.payload); - callback.onSuccess(response); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -109,7 +125,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index b80c15106ecbd..3843406b27403 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -26,7 +26,7 @@ */ public abstract class MessageHandler { /** Handles the receipt of a single message. */ - public abstract void handle(T message); + public abstract void handle(T message) throws Exception; /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 1502b7489e864..6ed61da5c7eff 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,5 +1,3 @@ -package org.apache.spark.network.server; - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,6 +15,10 @@ * limitations under the License. */ +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -29,7 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 1a11f7b3820c6..ee1c683699478 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +46,7 @@ public abstract class RpcHandler { */ public abstract void receive( TransportClient client, - byte[] message, + ByteBuffer message, RpcResponseCallback callback); /** @@ -62,7 +64,7 @@ public abstract void receive( * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. */ - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { receive(client, message, ONE_WAY_CALLBACK); } @@ -79,7 +81,7 @@ private static class OneWayRpcCallback implements RpcResponseCallback { private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.warn("Response provided for one-way RPC."); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 3164e00679035..09435bcbab35e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -99,7 +99,7 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) { + public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); } else { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index db18ea77d1073..c864d7ce16bd3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; @@ -26,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.ChunkFetchRequest; @@ -143,10 +146,10 @@ private void processStreamRequest(final StreamRequest req) { private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - respond(new RpcResponse(req.requestId, response)); + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } @Override @@ -157,14 +160,18 @@ public void onFailure(Throwable e) { } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } finally { + req.body().release(); } } private void processOneWayMessage(OneWayMessage req) { try { - rpcHandler.receive(reverseClient, req.message); + rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } finally { + req.body().release(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7d27439cfde7a..b3d8e0cd7cdcd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -132,7 +132,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -164,32 +164,32 @@ private static boolean isSymlink(File file) throws IOException { */ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); - + try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } - + long val = Long.parseLong(m.group(1)); String suffix = m.group(2); - + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); } - + // If suffix is valid use that, otherwise none was provided and use the default passed return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; - + throw new NumberFormatException(timeError + "\n" + e.getMessage()); } } - + /** * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. @@ -205,10 +205,10 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } - + /** * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is + * internal use. If no suffix is provided a direct conversion of the provided default is * attempted. */ private static long parseByteString(String str, ByteUnit unit) { @@ -217,7 +217,7 @@ private static long parseByteString(String str, ByteUnit unit) { try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - + if (m.matches()) { long val = Long.parseLong(m.group(1)); String suffix = m.group(2); @@ -228,14 +228,14 @@ private static long parseByteString(String str, ByteUnit unit) { } // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. Input was: " + throw new NumberFormatException("Fractional values are not supported. Input was: " + fractionMatcher.group(1)); } else { - throw new NumberFormatException("Failed to parse byte string: " + str); + throw new NumberFormatException("Failed to parse byte string: " + str); } - + } catch (NumberFormatException e) { String timeError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + @@ -248,7 +248,7 @@ private static long parseByteString(String str, ByteUnit unit) { /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for * internal use. - * + * * If no suffix is provided, the passed number is assumed to be in bytes. */ public static long byteStringAsBytes(String str) { @@ -264,7 +264,7 @@ public static long byteStringAsBytes(String str) { public static long byteStringAsKb(String str) { return parseByteString(str, ByteUnit.KiB); } - + /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for * internal use. @@ -284,4 +284,20 @@ public static long byteStringAsMb(String str) { public static long byteStringAsGb(String str) { return parseByteString(str, ByteUnit.GiB); } + + /** + * Returns a byte array with the buffer's contents, trying to avoid copying the data if + * possible. + */ + public static byte[] bufferToArray(ByteBuffer buffer) { + if (buffer.hasArray() && buffer.arrayOffset() == 0 && + buffer.array().length == buffer.remaining()) { + return buffer.array(); + } else { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 5889562dd9705..a466c729154aa 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -17,9 +17,13 @@ package org.apache.spark.network.util; +import java.util.Iterator; +import java.util.LinkedList; + import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { public static final String HANDLER_NAME = "frameDecoder"; private static final int LENGTH_SIZE = 8; private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); - private CompositeByteBuf buffer; + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + // First, feed the interceptor, and if it's still, active, try again. + if (interceptor != null) { + ByteBuf first = buffers.getFirst(); + int available = first.readableBytes(); + if (feedInterceptor(first)) { + assert !first.isReadable() : "Interceptor still active but buffer has data."; + } - if (buffer == null) { - buffer = in.alloc().compositeBuffer(); - } - - buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); - - while (buffer.isReadable()) { - discardReadBytes(); - if (!feedInterceptor()) { + int read = available - first.readableBytes(); + if (read == available) { + buffers.removeFirst().release(); + } + totalSize -= read; + } else { + // Interceptor is not active, so try to decode one frame. ByteBuf frame = decodeNext(); if (frame == null) { break; } - ctx.fireChannelRead(frame); } } - - discardReadBytes(); } - private void discardReadBytes() { - // If the buffer's been retained by downstream code, then make a copy of the remaining - // bytes into a new buffer. Otherwise, just discard stale components. - if (buffer.refCnt() > 1) { - CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; + } - if (buffer.readableBytes() > 0) { - ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); - spillBuf.writeBytes(buffer); - newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + // We know there's enough data. If the first buffer contains all the data, great. Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); } + return nextFrameSize; + } - buffer.release(); - buffer = newBuffer; - } else { - buffer.discardReadComponents(); + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } } + + nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; } private ByteBuf decodeNext() throws Exception { - if (buffer.readableBytes() < LENGTH_SIZE) { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { return null; } - int frameLen = (int) buffer.readLong() - LENGTH_SIZE; - if (buffer.readableBytes() < frameLen) { - buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE); - return null; + // Reset size for next frame. + nextFrameSize = UNKNOWN_FRAME_SIZE; + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + // If the first buffer holds the entire frame, return it. + int remaining = (int) frameSize; + if (buffers.getFirst().readableBytes() >= remaining) { + return nextBufferForFrame(remaining); } - Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen); - Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen); + // Otherwise, create a composite buffer. + CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); + while (remaining > 0) { + ByteBuf next = nextBufferForFrame(remaining); + remaining -= next.readableBytes(); + frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + } + assert remaining == 0; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame + * (by taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); + } - ByteBuf frame = buffer.readSlice(frameLen); - frame.retain(); return frame; } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (buffer != null) { - if (buffer.isReadable()) { - feedInterceptor(); - } - buffer.release(); + for (ByteBuf b : buffers) { + b.release(); } if (interceptor != null) { interceptor.channelInactive(); } + frameLenBuf.release(); super.channelInactive(ctx); } @@ -141,8 +199,8 @@ public void setInterceptor(Interceptor interceptor) { /** * @return Whether the interceptor is still active after processing the data. */ - private boolean feedInterceptor() throws Exception { - if (interceptor != null && !interceptor.handle(buffer)) { + private boolean feedInterceptor(ByteBuf buf) throws Exception { + if (interceptor != null && !interceptor.handle(buf)) { interceptor = null; } return interceptor != null; diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 50a324e293386..70c849d60e0a6 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -107,7 +107,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 1aa20900ffe74..6c8dd742f4b64 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -82,10 +82,10 @@ private void testClientToServer(Message msg) { @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new byte[0])); - testClientToServer(new RpcRequest(12345, new byte[100])); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); testClientToServer(new StreamRequest("abcde")); - testClientToServer(new OneWayMessage(new byte[100])); + testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } @Test @@ -94,8 +94,8 @@ public void responses() { testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new byte[0])); - testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 42955ef69235a..f9b5bf96d6215 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.junit.*; +import static org.junit.Assert.*; import java.io.IOException; import java.nio.ByteBuffer; @@ -84,13 +85,16 @@ public void tearDown() { @Test public void timeoutInactiveRequests() throws Exception { final Semaphore semaphore = new Semaphore(1); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -110,15 +114,15 @@ public StreamManager getStreamManager() { // First completes quickly (semaphore starts at 1). TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client.sendRpc(new byte[0], callback0); + client.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); - assert (callback0.success.length == response.length); + assertEquals(responseSize, callback0.successLength); } // Second times out after 2 seconds, with slack. Must be IOException. TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client.sendRpc(new byte[0], callback1); + client.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(4 * 1000); assert (callback1.failure != null); assert (callback1.failure instanceof IOException); @@ -131,13 +135,16 @@ public StreamManager getStreamManager() { @Test public void timeoutCleanlyClosesClient() throws Exception { final Semaphore semaphore = new Semaphore(0); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -158,7 +165,7 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client0.sendRpc(new byte[0], callback0); + client0.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); assert (callback0.failure instanceof IOException); assert (!client0.isActive()); @@ -170,10 +177,10 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client1.sendRpc(new byte[0], callback1); + client1.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(FOREVER); - assert (callback1.success.length == response.length); - assert (callback1.failure == null); + assertEquals(responseSize, callback1.successLength); + assertNull(callback1.failure); } } @@ -191,7 +198,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -218,9 +228,10 @@ public StreamManager getStreamManager() { synchronized (callback0) { // not complete yet, but should complete soon - assert (callback0.success == null && callback0.failure == null); + assertEquals(-1, callback0.successLength); + assertNull(callback0.failure); callback0.wait(2 * 1000); - assert (callback0.failure instanceof IOException); + assertTrue(callback0.failure instanceof IOException); } synchronized (callback1) { @@ -235,13 +246,13 @@ public StreamManager getStreamManager() { */ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - byte[] success; + int successLength = -1; Throwable failure; @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { synchronized(this) { - success = response; + successLength = response.remaining(); this.notifyAll(); } } @@ -258,7 +269,7 @@ public void onFailure(Throwable e) { public void onSuccess(int chunkIndex, ManagedBuffer buffer) { synchronized(this) { try { - success = buffer.nioByteBuffer().array(); + successLength = buffer.nioByteBuffer().remaining(); this.notifyAll(); } catch (IOException e) { // weird diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 88fa2258bb794..9e9be98c140b7 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -26,7 +27,6 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.base.Charsets; import com.google.common.collect.Sets; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -41,6 +41,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -55,11 +56,14 @@ public static void setUp() throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String msg = new String(message, Charsets.UTF_8); + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String msg = JavaUtils.bytesToString(message); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { - callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); } else if (parts[0].equals("return error")) { callback.onFailure(new RuntimeException("Returned: " + parts[1])); } else if (parts[0].equals("throw error")) { @@ -68,9 +72,8 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { - String msg = new String(message, Charsets.UTF_8); - oneWayMsgs.add(msg); + public void receive(TransportClient client, ByteBuffer message) { + oneWayMsgs.add(JavaUtils.bytesToString(message)); } @Override @@ -103,8 +106,9 @@ private RpcResult sendRPC(String ... commands) throws Exception { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(byte[] message) { - res.successMessages.add(new String(message, Charsets.UTF_8)); + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); sem.release(); } @@ -116,7 +120,7 @@ public void onFailure(Throwable e) { }; for (String command : commands) { - client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + client.sendRpc(JavaUtils.stringToBytes(command), callback); } if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { @@ -173,7 +177,7 @@ public void sendOneWayMessage() throws Exception { final String message = "no reply"; TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.send(message.getBytes(Charsets.UTF_8)); + client.send(JavaUtils.stringToBytes(message)); assertEquals(0, client.getHandler().numOutstandingRequests()); // Make sure the message arrives. diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 538f3efe8d6f2..9c49556927f0b 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -116,7 +116,10 @@ public ManagedBuffer openStream(String streamId) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 30144f4a9fc7a..128f7cba74350 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; + import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -27,6 +29,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; @@ -42,7 +45,7 @@ public class TransportResponseHandlerSuite { @Test - public void handleSuccessfulFetch() { + public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); @@ -56,7 +59,7 @@ public void handleSuccessfulFetch() { } @Test - public void handleFailedFetch() { + public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); @@ -69,7 +72,7 @@ public void handleFailedFetch() { } @Test - public void clearAllOutstandingRequests() { + public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); @@ -88,23 +91,24 @@ public void clearAllOutstandingRequests() { } @Test - public void handleSuccessfulRPC() { + public void handleSuccessfulRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); - handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + // This response should be ignored. + handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); assertEquals(1, handler.numOutstandingRequests()); - byte[] arr = new byte[10]; - handler.handle(new RpcResponse(12345, arr)); - verify(callback, times(1)).onSuccess(eq(arr)); + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); + verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); assertEquals(0, handler.numOutstandingRequests()); } @Test - public void handleFailedRPC() { + public void handleFailedRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); @@ -119,7 +123,7 @@ public void handleFailedRPC() { } @Test - public void testActiveStreams() { + public void testActiveStreams() throws Exception { Channel c = new LocalChannel(); c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); TransportResponseHandler handler = new TransportResponseHandler(c); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index a6f180bc40c9a..751516b9d82a1 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -22,7 +22,7 @@ import java.io.File; import java.lang.reflect.Method; -import java.nio.charset.StandardCharsets; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -57,6 +57,7 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -123,39 +124,53 @@ public void testNonMatching() { } @Test - public void testSaslAuthentication() throws Exception { + public void testSaslAuthentication() throws Throwable { testBasicSasl(false); } @Test - public void testSaslEncryption() throws Exception { + public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Exception { + private void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) { - byte[] message = (byte[]) invocation.getArguments()[1]; + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", new String(message, StandardCharsets.UTF_8)); - cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8)); + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; } }) .when(rpcHandler) - .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { - byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); + ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { ctx.close(); // There should be 2 terminated events; one for the client, one for the server. - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + Throwable error = null; + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (deadline > System.nanoTime()) { + try { + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + error = null; + break; + } catch (Throwable t) { + error = t; + TimeUnit.MILLISECONDS.sleep(10); + } + } + if (error != null) { + throw error; + } } } @@ -325,8 +340,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { SaslTestCtx ctx = null; try { ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); } catch (Exception e) { assertFalse(e.getCause() instanceof TimeoutException); diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 19475c21ffce9..d4de4a941d480 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -118,6 +118,27 @@ public Void answer(InvocationOnMock in) { } } + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + 8); + buf.writeLong(frame.length + 8); + buf.writeBytes(frame); + + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); + } + } + @Test(expected = IllegalArgumentException.class) public void testNegativeFrameSize() throws Exception { testInvalidFrame(-1); @@ -183,7 +204,7 @@ private void testInvalidFrame(long size) throws Exception { try { decoder.channelRead(ctx, frame); } finally { - frame.release(); + release(frame); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 3ddf5c3c39189..f22187a01db02 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.annotations.VisibleForTesting; @@ -66,8 +67,8 @@ public ExternalShuffleBlockHandler( } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } @@ -85,13 +86,13 @@ protected void handleMessage( } long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(new byte[0]); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ef3a9dcc8711f..58ca87d9d3b13 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.base.Preconditions; @@ -139,7 +140,7 @@ public void registerWithShuffleServer( checkInit(); TransportClient client = clientFactory.createUnmanagedClient(host, port); try { - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } finally { client.close(); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e653f5cb147ee..1b2ddbf1ed917 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.nio.ByteBuffer; import java.util.Arrays; import org.slf4j.Logger; @@ -89,11 +90,11 @@ public void start() { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 7543b6be4f2a1..675820308bd4c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle.mesos; import java.io.IOException; +import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,11 +55,11 @@ public MesosExternalShuffleClient( public void registerDriverWithShuffleService(String host, int port) throws IOException { checkInit(); - byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); client.sendRpc(registerDriver, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.info("Successfully registered app " + appId + " with external shuffle service."); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index fcb52363e632c..7fbe3384b4d4f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle.protocol; +import java.nio.ByteBuffer; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -53,7 +55,7 @@ private Type(int id) { // NB: Java does not support static methods in interfaces, so we must put this in a static class. public static class Decoder { /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteArray(byte[] msg) { + public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { ByteBuf buf = Unpooled.wrappedBuffer(msg); byte type = buf.readByte(); switch (type) { @@ -68,12 +70,12 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { } /** Serializes the 'type' byte followed by the message itself. */ - public byte[] toByteArray() { + public ByteBuffer toByteBuffer() { // Allow room for encoded message, plus the type byte ByteBuf buf = Unpooled.buffer(encodedLength() + 1); buf.writeByte(type().id); encode(buf); assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.array(); + return buf.nioBuffer(); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 1c2fa4d0d462c..19c870aebb023 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; @@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -107,8 +109,8 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); - assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); } @Test @@ -136,7 +138,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -144,7 +146,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -222,13 +224,13 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); - StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; // Create a second client, authenticated with a different app ID, and try to read from @@ -275,7 +277,7 @@ public synchronized void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index d65de9ca550a3..86c8609e7070b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,7 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e61390cf57061..9379412155e88 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -60,12 +60,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } @SuppressWarnings("unchecked") @@ -77,17 +77,18 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(2, handle.numChunks); @SuppressWarnings("unchecked") @@ -104,7 +105,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -112,7 +113,7 @@ public void testBadMessages() { // pass } - byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -120,7 +121,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index b35a6d685dd02..2590b9ce4c1f1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -134,14 +134,14 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( - (byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } - }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. final AtomicInteger expectedChunkIndex = new AtomicInteger(0); From 96bf468c7860be317c20ccacf259910968d2dc83 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 30 Nov 2015 17:33:09 -0800 Subject: [PATCH 1068/2089] [SPARK-12049][CORE] User JVM shutdown hook can cause deadlock at shutdown Avoid potential deadlock with a user app's shutdown hook thread by more narrowly synchronizing access to 'hooks' Author: Sean Owen Closes #10042 from srowen/SPARK-12049. --- .../spark/util/ShutdownHookManager.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4012dca3ecdf8..620f226a23e15 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -206,7 +206,7 @@ private[spark] object ShutdownHookManager extends Logging { private [util] class SparkShutdownHookManager { private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false + @volatile private var shuttingDown = false /** * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not @@ -232,28 +232,27 @@ private [util] class SparkShutdownHookManager { } } - def runAll(): Unit = synchronized { + def runAll(): Unit = { shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) + var nextHook: SparkShutdownHook = null + while ({ nextHook = hooks.synchronized { hooks.poll() }; nextHook != null }) { + Try(Utils.logUncaughtExceptions(nextHook.run())) } } - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) + def add(priority: Int, hook: () => Unit): AnyRef = { + hooks.synchronized { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } } - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } + def remove(ref: AnyRef): Boolean = { + hooks.synchronized { hooks.remove(ref) } } } From f73379be2b0c286957b678a996cb56afc96015eb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 17:15:47 -0800 Subject: [PATCH 1069/2089] [HOTFIX][SPARK-12000] Add missing quotes in Jekyll API docs plugin. I accidentally omitted these as part of #10049. --- docs/_plugins/copy_api_dirs.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index f2f3e2e653149..174c202e37918 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - system(make html) || raise("Python doc generation failed") + system("make html") || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") From 9693b0d5a55bc1d2da96f04fe2c6de59a8dfcc1b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 30 Nov 2015 20:56:42 -0800 Subject: [PATCH 1070/2089] [SPARK-12018][SQL] Refactor common subexpression elimination code JIRA: https://issues.apache.org/jira/browse/SPARK-12018 The code of common subexpression elimination can be factored and simplified. Some unnecessary variables can be removed. Author: Liang-Chi Hsieh Closes #10009 from viirya/refactor-subexpr-eliminate. --- .../sql/catalyst/expressions/Expression.scala | 10 ++---- .../expressions/codegen/CodeGenerator.scala | 34 +++++-------------- .../codegen/GenerateUnsafeProjection.scala | 4 +-- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 169435a10ea2c..b55d3653a7158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -94,13 +94,9 @@ abstract class Expression extends TreeNode[Expression] { def gen(ctx: CodeGenContext): GeneratedExpressionCode = { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added - // as a function, `subExprState.fnName`. Just call that. - val code = - s""" - |/* $this */ - |${subExprState.fnName}(${ctx.INPUT_ROW}); - """.stripMargin.trim - GeneratedExpressionCode(code, subExprState.code.isNull, subExprState.code.value) + // as a function and called in advance. Just use it. + val code = s"/* $this */" + GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2f3d6aeb86c5b..440c7d2fc1156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -104,16 +104,13 @@ class CodeGenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // State used for subexpression elimination. - case class SubExprEliminationState( - isLoaded: String, - code: GeneratedExpressionCode, - fnName: String) + case class SubExprEliminationState(isNull: String, value: String) // Foreach expression that is participating in subexpression elimination, the state to use. val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] - // The collection of isLoaded variables that need to be reset on each row. - val subExprIsLoadedVariables = mutable.ArrayBuffer.empty[String] + // The collection of sub-exression result resetting methods that need to be called on each row. + val subExprResetVariables = mutable.ArrayBuffer.empty[String] final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" @@ -408,7 +405,6 @@ class CodeGenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { val expr = e.head - val isLoaded = freshName("isLoaded") val isNull = freshName("isNull") val value = freshName("value") val fnName = freshName("evalExpr") @@ -417,18 +413,12 @@ class CodeGenContext { val code = expr.gen(this) val fn = s""" - |private void $fnName(InternalRow ${INPUT_ROW}) { - | if (!$isLoaded) { - | ${code.code.trim} - | $isLoaded = true; - | $isNull = ${code.isNull}; - | $value = ${code.value}; - | } + |private void $fnName(InternalRow $INPUT_ROW) { + | ${code.code.trim} + | $isNull = ${code.isNull}; + | $value = ${code.value}; |} """.stripMargin - code.code = fn - code.isNull = isNull - code.value = value addNewFunction(fnName, fn) @@ -448,18 +438,12 @@ class CodeGenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - - // Maintain the loaded value and isNull as member variables. This is necessary if the codegen - // function is split across multiple functions. - // TODO: maintaining this as a local variable probably allows the compiler to do better - // optimizations. - addMutableState("boolean", isLoaded, s"$isLoaded = false;") addMutableState("boolean", isNull, s"$isNull = false;") addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subExprIsLoadedVariables += isLoaded - val state = SubExprEliminationState(isLoaded, code, fnName) + subExprResetVariables += s"$fnName($INPUT_ROW);" + val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7b6c9373ebe30..68005afb21d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,8 +287,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - // Reset the isLoaded flag for each row. - val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + // Reset the subexpression values for each row. + val subexprReset = ctx.subExprResetVariables.mkString("\n") val code = s""" From a0af0e351e45a8be47a6f65efd132eaa4a00c9e4 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 1 Dec 2015 09:26:58 +0000 Subject: [PATCH 1071/2089] [SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11898 syn0Global and sync1Global in word2vec are quite large objects with size (vocab * vectorSize * 8), yet they are passed to worker using basic task serialization. Use broadcast can greatly improve the performance. My benchmark shows that, for 1M vocabulary and default vectorSize 100, changing to broadcast can help, 1. decrease the worker memory consumption by 45%. 2. decrease running time by 40%. This will also help extend the upper limit for Word2Vec. Author: Yuhao Yang Closes #9878 from hhbyyh/w2vBC. --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a47f27b0afb14..655ac0bb5545b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging { Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = learningRate + for (k <- 1 to numIterations) { + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount @@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging { } i += 1 } + bcSyn0Global.unpersist(false) + bcSyn1Global.unpersist(false) } newSentences.unpersist() From c87531b765f8934a9a6c0f673617e0abfa5e5f0e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 07:42:37 -0800 Subject: [PATCH 1072/2089] [SPARK-11949][SQL] Set field nullable property for GroupingSets to get correct results for null values JIRA: https://issues.apache.org/jira/browse/SPARK-11949 The result of cube plan uses incorrect schema. The schema of cube result should set nullable property to true because the grouping expressions will have null values. Author: Liang-Chi Hsieh Closes #10038 from viirya/fix-cube. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++-- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 94ffbbb2e5c65..b8f212fca7509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -223,6 +223,11 @@ class Analyzer( case other => Alias(other, other.toString)() } + // TODO: We need to use bitmasks to determine which grouping expressions need to be + // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need + // to change the nullability of a. + val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change // it so that the aggregation is computed on the unmodified value of its argument @@ -231,12 +236,13 @@ class Analyzer( // If not then its a grouping expression and we need to use the modified (with nulls from // Expand) value of the expression. case expr => expr.transformDown { - case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + case e => + groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) }.asInstanceOf[NamedExpression] } val child = Project(x.child.output ++ groupByAliases, x.child) - val groupByAttributes = groupByAliases.map(_.toAttribute) + val groupByAttributes = groupByAliases.map(attributeMap(_)) Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b5c636d0de1d1..b1004bc5bc290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType +case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -86,6 +87,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, 2013, 78000.0) :: Row(null, null, 113000.0) :: Nil ) + + val df0 = sqlContext.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6))).toDF() + + val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg")) + assert(cube0.where("date IS NULL").count > 0) } test("rollup overlapping columns") { From 1401166576c7018c5f9c31e0a6703d5fb16ea339 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 1 Dec 2015 09:45:55 -0800 Subject: [PATCH 1073/2089] [SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize `JavaSerializerInstance.serialize` uses `ByteArrayOutputStream.toByteArray` to get the serialized data. `ByteArrayOutputStream.toByteArray` needs to copy the content in the internal array to a new array. However, since the array will be converted to `ByteBuffer` at once, we can avoid the memory copy. This PR added `ByteBufferOutputStream` to access the protected `buf` and convert it to a `ByteBuffer` directly. Author: Shixiong Zhu Closes #10051 from zsxwing/SPARK-12060. --- .../spark/serializer/JavaSerializer.scala | 7 ++--- .../spark/util/ByteBufferOutputStream.scala | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7d..ea718a0edbe71 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 0000000000000..92e45224db81c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} From 69dbe6b40df35d488d4ee343098ac70d00bbdafb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Dec 2015 10:21:31 -0800 Subject: [PATCH 1074/2089] [SPARK-12046][DOC] Fixes various ScalaDoc/JavaDoc issues This PR backports PR #10039 to master Author: Cheng Lian Closes #10063 from liancheng/spark-12046.doc-fix.master. --- .../spark/api/java/function/Function4.java | 2 +- .../spark/api/java/function/VoidFunction.java | 2 +- .../api/java/function/VoidFunction2.java | 2 +- .../apache/spark/api/java/JavaPairRDD.scala | 16 +++---- .../org/apache/spark/memory/package.scala | 14 +++--- .../org/apache/spark/rdd/CoGroupedRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +-- .../org/apache/spark/rdd/ShuffledRDD.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 5 ++- .../serializer/SerializationDebugger.scala | 13 +++--- .../scala/org/apache/spark/util/Vector.scala | 1 + .../util/collection/ExternalSorter.scala | 30 ++++++------- .../WritablePartitionedPairCollection.scala | 7 +-- .../streaming/kinesis/KinesisReceiver.scala | 23 +++++----- .../streaming/kinesis/KinesisUtils.scala | 13 +++--- .../mllib/optimization/GradientDescent.scala | 12 ++--- project/SparkBuild.scala | 2 + .../scala/org/apache/spark/sql/Column.scala | 11 ++--- .../spark/streaming/StreamingContext.scala | 11 ++--- .../streaming/dstream/FileInputDStream.scala | 19 ++++---- .../streaming/receiver/BlockGenerator.scala | 22 ++++----- .../scheduler/ReceiverSchedulingPolicy.scala | 45 ++++++++++--------- .../util/FileBasedWriteAheadLog.scala | 7 +-- .../spark/streaming/util/RecurringTimer.scala | 8 ++-- .../org/apache/spark/deploy/yarn/Client.scala | 10 ++--- 25 files changed, 152 insertions(+), 133 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java index fd727d64863d7..9c35a22ca9d0f 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function4.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -23,5 +23,5 @@ * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. */ public interface Function4 extends Serializable { - public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; + R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index 2a10435b7523a..f30d42ee57966 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -23,5 +23,5 @@ * A function with no return value. */ public interface VoidFunction extends Serializable { - public void call(T t) throws Exception; + void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java index 6c576ab678455..da9ae1c9c5cdc 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -23,5 +23,5 @@ * A two-argument function that takes arguments of type T1 and T2 with no return value. */ public interface VoidFunction2 extends Serializable { - public void call(T1 v1, T2 v2) throws Exception; + void call(T1 v1, T2 v2) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0b0c6e5bb8cc1..87deaf20e2b25 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -215,13 +215,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple @@ -247,13 +247,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala index 564e30d2ffd66..3d00cd9cb6377 100644 --- a/core/src/main/scala/org/apache/spark/memory/package.scala +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -21,13 +21,13 @@ package org.apache.spark * This package implements Spark's memory management system. This system consists of two main * components, a JVM-wide memory manager and a per-task manager: * - * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. - * This component implements the policies for dividing the available memory across tasks and for - * allocating memory between storage (memory used caching and data transfer) and execution (memory - * used by computations, such as shuffles, joins, sorts, and aggregations). - * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual tasks. - * Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide - * MemoryManager. + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution + * (memory used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual + * tasks. Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. * * Internally, each of these components have additional abstractions for memory bookkeeping: * diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 935c3babd8ea1..3a0ca1d813297 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -70,7 +70,7 @@ private[spark] class CoGroupPartition( * * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of * instantiating this directly. - + * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c6181902ace6d..44d195587a081 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -65,9 +65,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index a013c3f66a3a8..3ef506e1562bf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -86,7 +86,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def getPreferredLocations(partition: Partition): Seq[String] = { + override protected def getPreferredLocations(partition: Partition): Seq[String] = { val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] tracker.getPreferredLocationsForShuffle(dep, partition.index) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4fb32ba8cb188..2fcd5aa57d11b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -33,8 +33,9 @@ import org.apache.spark.util.Utils /** * A unit of execution. We have two kinds of Task's in Spark: - * - [[org.apache.spark.scheduler.ShuffleMapTask]] - * - [[org.apache.spark.scheduler.ResultTask]] + * + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] * * A Spark job consists of one or more stages. The very last stage in a job consists of multiple * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index a1b1e1631eafb..e2951d8a3e096 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -53,12 +53,13 @@ private[spark] object SerializationDebugger extends Logging { /** * Find the path leading to a not serializable object. This method is modeled after OpenJDK's * serialization mechanism, and handles the following cases: - * - primitives - * - arrays of primitives - * - arrays of non-primitive objects - * - Serializable objects - * - Externalizable objects - * - writeReplace + * + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index 2ed827eab46df..6b3fa8491904a 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -122,6 +122,7 @@ class Vector(val elements: Array[Double]) extends Serializable { override def toString: String = elements.mkString("(", ", ", ")") } +@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") object Vector { def apply(elements: Array[Double]): Vector = new Vector(elements) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2440139ac95e9..44b1d90667e65 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -67,24 +67,24 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedPairBuffer if we don't. - * Inside these buffers, we sort elements by partition ID and then possibly also by key. - * To avoid calling the partitioner multiple times with each key, we store the partition ID - * alongside each record. + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * - * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first - * by partition ID and possibly second by key or by hash code of the key, if we want to do - * aggregation. For each file, we track how many objects were in each partition in memory, so we - * don't have to write out the partition ID for every element. + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. * - * - When the user requests an iterator or file output, the spilled files are merged, along with - * any remaining in-memory data, using the same sort order defined above (unless both sorting - * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering - * from the ordering parameter, or read the keys with the same hash code and compare them with - * each other for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * - * - Users are expected to call stop() at the end to delete all the intermediate files. + * - Users are expected to call stop() at the end to delete all the intermediate files. */ private[spark] class ExternalSorter[K, V, C]( context: TaskContext, diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 38848e9018c6c..5232c2bd8d6f6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -23,9 +23,10 @@ import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that - * - Have an associated partition for each key-value pair. - * - Support a memory-efficient sorted iterator - * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + * + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. */ private[spark] trait WritablePartitionedPairCollection[K, V] { /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 97dbb918573a3..05080835fc4ad 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -46,17 +46,18 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * * The way this Receiver works is as follows: - * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple - * KinesisRecordProcessor - * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is - * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. - * - When the block generator defines a block, then the recorded sequence number ranges that were - * inserted into the block are recorded separately for being used later. - * - When the block is ready to be pushed, the block is pushed and the ranges are reported as - * metadata of the block. In addition, the ranges are used to find out the latest sequence - * number for each shard that can be checkpointed through the DynamoDB. - * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence - * number for it own shard. + * + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. + * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2849fd8a82102..2de6195716e5c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -226,12 +226,13 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets AWS credentials. - * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. - * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in - * [[org.apache.spark.SparkConf]]. + * + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name + * in [[org.apache.spark.SparkConf]]. * * @param ssc StreamingContext object * @param streamName Kinesis stream name diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 3b663b5defb03..37bb6f6097f67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -81,11 +81,13 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the convergence tolerance. Default 0.001 * convergenceTol is a condition which decides iteration termination. * The end of iteration is decided based on below logic. - * - If the norm of the new solution vector is >1, the diff of solution vectors - * is compared to relative tolerance which means normalizing by the norm of - * the new solution vector. - * - If the norm of the new solution vector is <=1, the diff of solution vectors - * is compared to absolute tolerance which is not normalizing. + * + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * * Must be between 0.0 and 1.0 inclusively. */ def setConvergenceTol(tolerance: Double): this.type = { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 63290d8a666e6..b1dcaedcba75e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -535,6 +535,8 @@ object Unidoc { .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b3cd9e1eff142..ad6af481fadc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -136,11 +136,12 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** * Extracts a value or values from a complex type. * The following types of extraction are supported: - * - Given an Array, an integer ordinal can be used to retrieve a single value. - * - Given a Map, a key of the correct type can be used to retrieve an individual value. - * - Given a Struct, a string fieldName can be used to extract that field. - * - Given an Array of Structs, a string fieldName can be used to extract filed - * of every struct in that array, and return an Array of fields + * + * - Given an Array, an integer ordinal can be used to retrieve a single value. + * - Given a Map, a key of the correct type can be used to retrieve an individual value. + * - Given a Struct, a string fieldName can be used to extract that field. + * - Given an Array of Structs, a string fieldName can be used to extract filed + * of every struct in that array, and return an Array of fields * * @group expr_ops * @since 1.4.0 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index aee172a4f549a..6fb8ad38abcec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -574,11 +574,12 @@ class StreamingContext private[streaming] ( * :: DeveloperApi :: * * Return the current state of the context. The context can be in three possible states - - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. - * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. - * Input DStreams, transformations and output operations cannot be created on the context. - * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + * + * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ @DeveloperApi def getState(): StreamingContextState = synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 40208a64861fb..cb5b1f252e90c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * class remembers the information about the files selected in past batches for * a certain duration (say, "remember window") as shown in the figure below. * + * {{{ * |<----- remember window ----->| * ignore threshold --->| |<--- current batch time * |____.____.____.____.____.____| @@ -49,6 +50,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * ---------------------|----|----|----|----|----|----|-----------------------> Time * |____|____|____|____|____|____| * remembered batches + * }}} * * The trailing end of the window is the "ignore threshold" and all files whose mod times * are less than this threshold are assumed to have already been selected and are therefore @@ -59,14 +61,15 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * `isNewFile` for more details. * * This makes some assumptions from the underlying file system that the system is monitoring. - * - The clock of the file system is assumed to synchronized with the clock of the machine running - * the streaming app. - * - If a file is to be visible in the directory listings, it must be visible within a certain - * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be - * selected as the mod time will be less than the ignore threshold when it becomes visible. - * - Once a file is visible, the mod time cannot change. If it does due to appends, then the - * processing semantics are undefined. + * + * - The clock of the file system is assumed to synchronized with the clock of the machine running + * the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 421d60ae359f8..cc7c04bfc9f63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -84,13 +84,14 @@ private[streaming] class BlockGenerator( /** * The BlockGenerator can be in 5 possible states, in the order as follows. - * - Initialized: Nothing has been started - * - Active: start() has been called, and it is generating blocks on added data. - * - StoppedAddingData: stop() has been called, the adding of data has been stopped, - * but blocks are still being generated and pushed. - * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but - * they are still being pushed. - * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + * + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. */ private object GeneratorState extends Enumeration { type GeneratorState = Value @@ -125,9 +126,10 @@ private[streaming] class BlockGenerator( /** * Stop everything in the right order such that all the data added is pushed out correctly. - * - First, stop adding data to the current buffer. - * - Second, stop generating blocks. - * - Finally, wait for queue of to-be-pushed blocks to be drained. + * + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. */ def stop(): Unit = { // Set the state to stop adding data diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 234bc8660da8a..391a461f08125 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -27,28 +27,29 @@ import org.apache.spark.streaming.receiver.Receiver * A class that tries to schedule receivers with evenly distributed. There are two phases for * scheduling receivers. * - * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule - * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. - * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker should - * update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. - * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list - * that contains the scheduled locations. Then when a receiver is starting, it will send a - * register request and `ReceiverTracker.registerReceiver` will be called. In - * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should check - * if the location of this receiver is one of the scheduled locations, if not, the register will - * be rejected. - * - The second phase is local scheduling when a receiver is restarting. There are two cases of - * receiver restarting: - * - If a receiver is restarting because it's rejected due to the real location and the scheduled - * locations mismatching, in other words, it fails to start in one of the locations that - * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are - * still alive in the list of scheduled locations, then use them to launch the receiver job. - * - If a receiver is restarting without a scheduled locations list, or the executors in the list - * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should - * not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it should clear - * it. Then when this receiver is registering, we can know this is a local scheduling, and - * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching - * location is matching. + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker + * should update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * that contains the scheduled locations. Then when a receiver is starting, it will send a + * register request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should + * check if the location of this receiver is one of the scheduled locations, if not, the register + * will be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * locations mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that + * are still alive in the list of scheduled locations, then use them to launch the receiver + * job. + * - If a receiver is restarting without a scheduled locations list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` + * should not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it + * should clear it. Then when this receiver is registering, we can know this is a local + * scheduling, and `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if + * the launching location is matching. * * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, * otherwise do local scheduling. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index f5165f7c39122..a99b570835831 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -34,9 +34,10 @@ import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. - * - Writes records (bytebuffers) to periodically rotating log files. - * - Recovers the log files and the reads the recovered records upon failures. - * - Cleans up old log files. + * + * - Writes records (bytebuffers) to periodically rotating log files. + * - Recovers the log files and the reads the recovered records upon failures. + * - Cleans up old log files. * * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 0148cb51c6f09..bfb53614050a7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -72,10 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. - * - interruptTimer = true will interrupt the callback - * if it is in progress (not guaranteed to give correct time in this case). - * - interruptTimer = false guarantees that there will be at least one callback after `stop` has - * been called. + * + * @param interruptTimer True will interrupt the callback if it is in progress (not guaranteed to + * give correct time in this case). False guarantees that there will be at + * least one callback after `stop` has been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a77a3e2420e24..f0590d2d222ec 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1336,11 +1336,11 @@ object Client extends Logging { * * This method uses two configuration values: * - * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may - * only be valid in the gateway node. - * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may - * contain, for example, env variable references, which will be expanded by the NMs when - * starting containers. + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. * * If either config is not available, the input path is returned. */ From 8ddc55f1d582cccc3ca135510b2ea776e889e481 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:22:55 -0800 Subject: [PATCH 1075/2089] [SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail The reason is that, for a single culumn `RowEncoder`(or a single field product encoder), when we use it as the encoder for grouping key, we should also combine the grouping attributes, although there is only one grouping attribute. Author: Wenchen Fan Closes #10059 from cloud-fan/bug. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 7 ++++--- .../org/apache/spark/sql/DatasetSuite.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 6 +++--- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index da4600133290f..c357f88a94dd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -70,7 +70,7 @@ class Dataset[T] private[sql]( * implicit so that we can use it when constructing new [[Dataset]] objects that have the same * object type (that will be possibly resolved to a different schema). */ - private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a10a89342fb5c..4bf0b256fcb4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedVEncoder, dataAttributes).named) - val keyColumn = if (groupingAttributes.length > 1) { - Alias(CreateStruct(groupingAttributes), "key")() - } else { + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7d539180ded9e..a2c8d201563e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3 -> "abcxyz", 5 -> "hello") } + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupBy(s => Tuple1(s.length)).count() + + checkAnswer( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) + } + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") @@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } + test("groupBy columns, count") { + val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() + val count = ds.groupBy($"_1").count() + + checkAnswer( + count, + (Row("a"), 2L), (Row("b"), 1L)) + } + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 6ea1fe4ccfd89..8f476dd0f99b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. */ - protected def checkAnswer[T : Encoder]( - ds: => Dataset[T], + protected def checkAnswer[T]( + ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } From 9df24624afedd993a39ab46c8211ae153aedef1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:24:53 -0800 Subject: [PATCH 1076/2089] [SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff and lost the required data type, which may lead to runtime error if the real type doesn't match the encoder's schema. For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type is `[a: int, b: long]`, then we will hit runtime error and say that we can't construct class `Data` with int and long, because we lost the information that `b` should be a string. Author: Wenchen Fan Closes #9840 from cloud-fan/err-msg. --- .../spark/sql/catalyst/ScalaReflection.scala | 93 +++++++-- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++ .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 9 + .../expressions/complexTypeCreator.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 12 ++ .../encoders/EncoderResolutionSuite.scala | 180 ++++++++++++++++++ .../spark/sql/DatasetAggregatorSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 21 +- 10 files changed, 335 insertions(+), 32 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d133ad3f0d89d..9b6b5b8bd1a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) + def constructorFor[T : TypeTag]: Expression = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + constructorFor(tpe, None, walkedTypePath) + } private def constructorFor( tpe: `Type`, - path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + path: Option[Expression], + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal( + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + def getPath: Expression = { + val dataType = schemaFor(tpe).dataType + if (path.isDefined) { + path.get + } else { + upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + } + } + + /** + * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * don't need to cast struct type because there must be `UnresolvedExtractValue` or + * `GetStructField` wrapping it, thus we only need to handle leaf type. + */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _ => UpCast(expr, expected, walkedTypePath) + } tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - WrapOption(constructorFor(optType, path)) + val className = getClassNameFromType(optType) + val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + WrapOption(constructorFor(optType, path, newTypePath)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection { primitiveMethod.map { method => Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val arrayData = Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection { arrayData :: Nil) case t if t <:< localTypeOf[Map[_, _]] => + // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), + p => constructorFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", @@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), + p => constructorFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", @@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val dataType = schemaFor(fieldType).dataType - + val clsName = getClassNameFromType(fieldType) + val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + constructorFor( + fieldType, + Some(addToPathOrdinal(i, dataType, newTypePath)), + newTypePath) } else { - constructorFor(fieldType, Some(addToPath(fieldName))) + constructorFor( + fieldType, + Some(addToPath(fieldName, dataType, newTypePath)), + newTypePath) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b8f212fca7509..765327c474e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( ResolveReferences :: ResolveGroupingAnalytics :: ResolvePivot :: + ResolveUpCast :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. + */ +object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index f90fc3cc12189..29502a59915f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -53,7 +53,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private val numericPrecedence = + private[sql] val numericPrecedence = IndexedSeq( ByteType, ShortType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0c10a56c555f4..06ffe864552fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{StructField, ObjectType, StructType} @@ -235,12 +236,13 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) + val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), // we need an instance of the outer scope. This rule substitues those outer objects into // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` // registry. - copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => val outer = outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a2c6c39fd8ce2..cb60d5958d535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -914,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } } + +/** + * Cast the child expression to the target data type, but will throw error if the cast might + * truncate, e.g. long -> int, timestamp -> data. + */ +case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) + extends UnaryExpression with Unevaluable { + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1854dfaa7db35..72cc89c8be915 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** - * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this * StructType. */ def flatten: Seq[NamedExpression] = valExprs.zip(names).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0cd352d0fa928..ce45245b9f6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { case _ => false } + /** + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` + * can be casted into `other` safely without losing any precision or range. + */ + private[sql] def isTighterThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale + case dt: IntegralType => + isTighterThan(DecimalType.forType(dt)) + case _ => false + } + /** * The default size of a value of the DecimalType is 4096 bytes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala new file mode 100644 index 0000000000000..0289988342e78 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ + +case class StringLongClass(a: String, b: Long) + +case class StringIntClass(a: String, b: Int) + +case class ComplexClass(a: Long, b: StringLongClass) + +class EncoderResolutionSuite extends PlanTest { + test("real type doesn't match encoder schema but they are compatible: product") { + val encoder = ExpressionEncoder[StringLongClass] + val cls = classOf[StringLongClass] + + { + val attrs = Seq('a.string, 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + { + val attrs = Seq('a.int, 'b.long) + val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression + val expected = NewInstance( + cls, + toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + } + + test("real type doesn't match encoder schema but they are compatible: nested product") { + val encoder = ExpressionEncoder[ComplexClass] + val innerCls = classOf[StringLongClass] + val cls = classOf[ComplexClass] + + val structType = new StructType().add("a", IntegerType).add("b", LongType) + val attrs = Seq('a.int, 'b.struct(structType)) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + Seq( + 'a.int.cast(LongType), + If( + 'b.struct(structType).isNull, + Literal.create(null, ObjectType(innerCls)), + NewInstance( + innerCls, + Seq( + toExternalString( + GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), + GetStructField('b.struct(structType), 1, Some("b"))), + false, + ObjectType(innerCls)) + )), + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + test("real type doesn't match encoder schema but they are compatible: tupled encoder") { + val encoder = ExpressionEncoder.tuple( + ExpressionEncoder[StringLongClass], + ExpressionEncoder[Long]) + val cls = classOf[StringLongClass] + + val structType = new StructType().add("a", StringType).add("b", ByteType) + val attrs = Seq('a.struct(structType), 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + classOf[Tuple2[_, _]], + Seq( + NewInstance( + cls, + Seq( + toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), + GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + false, + ObjectType(cls)), + 'b.int.cast(LongType)), + false, + ObjectType(classOf[Tuple2[_, _]])) + compareExpressions(fromRowExpr, expected) + } + + private def toExternalString(e: Expression): Expression = { + Invoke(e, "toString", ObjectType(classOf[String]), Nil) + } + + test("throw exception if real type is not compatible with encoder schema") { + val msg1 = intercept[AnalysisException] { + ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + }.message + assert(msg1 == + s""" + |Cannot up cast `b` from bigint to int as it may truncate + |The type path of the target object is: + |- field (class: "scala.Int", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + + val msg2 = intercept[AnalysisException] { + val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) + ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + }.message + assert(msg2 == + s""" + |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + // test for leaf types + castSuccess[Int, Long] + castSuccess[java.sql.Date, java.sql.Timestamp] + castSuccess[Long, String] + castSuccess[Int, java.math.BigDecimal] + castSuccess[Long, java.math.BigDecimal] + + castFail[Long, Int] + castFail[java.sql.Timestamp, java.sql.Date] + castFail[java.math.BigDecimal, Double] + castFail[Double, java.math.BigDecimal] + castFail[java.math.BigDecimal, Int] + castFail[String, Long] + + + private def castSuccess[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { + to.resolve(from.schema.toAttributes, null) + } + } + + private def castFail[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { + intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 19dce5d1e2f37..c6d2bf07b2803 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( sum(_._2), - expr("sum(_2)").as[Int], + expr("sum(_2)").as[Long], count("*")), - ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L)) } test("typed aggregation: complex case") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a2c8d201563e5..542e4d6c43b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -335,24 +335,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int]), - ("a", 30), ("b", 3), ("c", 1)) + ds.groupBy(_._1).agg(sum("_2").as[Long]), + ("a", 30L), ("b", 3L), ("c", 1L)) } test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]), - ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) } test("typed aggregation: expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]), - ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) } test("typed aggregation: expr, expr, expr, expr") { @@ -360,11 +360,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( - sum("_2").as[Int], + sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*").as[Long], avg("_2").as[Double]), - ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0)) + ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } test("cogroup") { @@ -476,6 +476,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ((nullInt, "1"), (new java.lang.Integer(22), "2")), ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) } + + test("change encoder with compatible schema") { + val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] + assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) + } } From fd95eeaf491809c6bb0f83d46b37b5e2eebbcbca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:35:12 -0800 Subject: [PATCH 1077/2089] [SPARK-11954][SQL] Encoder for JavaBeans create java version of `constructorFor` and `extractorFor` in `JavaTypeInference` Author: Wenchen Fan This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #9937 from cloud-fan/pojo. --- .../scala/org/apache/spark/sql/Encoder.scala | 18 + .../sql/catalyst/JavaTypeInference.scala | 313 +++++++++++++++++- .../catalyst/encoders/ExpressionEncoder.scala | 21 +- .../sql/catalyst/expressions/objects.scala | 42 ++- .../spark/sql/catalyst/trees/TreeNode.scala | 27 +- .../sql/catalyst/util/ArrayBasedMapData.scala | 5 + .../sql/catalyst/util/GenericArrayData.scala | 3 + .../sql/catalyst/trees/TreeNodeSuite.scala | 25 ++ .../apache/spark/sql/JavaDatasetSuite.java | 174 +++++++++- 9 files changed, 608 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 03aa25eda807f..c40061ae0aafd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,24 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7d4cfbe6faecb..c8ee87e8819f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.catalyst -import java.beans.Introspector +import java.beans.{PropertyDescriptor, Introspector} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap} +import java.util.{Iterator => JIterator, Map => JMap, List => JList} import scala.language.existentials import com.google.common.reflect.TypeToken + import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.unsafe.types.UTF8String + /** * Type-inference utilities for POJOs and Java collections. @@ -33,13 +39,14 @@ object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val listType = TypeToken.of(classOf[JList[_]]) private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType /** - * Infers the corresponding SQL data type of a JavaClean class. + * Infers the corresponding SQL data type of a JavaBean class. * @param beanClass Java type * @return (SQL data type, nullable) */ @@ -58,6 +65,8 @@ object JavaTypeInference { (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) @@ -87,15 +96,14 @@ object JavaTypeInference { (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) - val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyType, valueType) = mapKeyValueType(typeToken) val (keyDataType, _) = inferDataType(keyType) val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) case _ => + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") val fields = properties.map { property => @@ -107,11 +115,294 @@ object JavaTypeInference { } } + private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors + .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + } + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSupertype.resolveType(iteratorReturnType) - val itemType = iteratorType.resolveType(nextReturnType) - itemType + val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSuperType.resolveType(iteratorReturnType) + iteratorType.resolveType(nextReturnType) + } + + private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) + val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) + keyType -> valueType + } + + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. + */ + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case _ => ObjectType(cls) + } + + /** + * Returns an expression that can be used to construct an object of java bean `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def constructorFor(beanClass: Class[_]): Expression = { + constructorFor(TypeToken.of(beanClass), None) + } + + private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + + typeToken.getRawType match { + case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + + case c if c == classOf[java.lang.Short] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Integer] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Long] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Double] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Byte] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Float] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Boolean] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case c if c == classOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case c if c.isArray => + val elementType = c.getComponentType + val primitiveMethod = elementType match { + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, ObjectType(c)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(typeToken.getComponentType, Some(p)), + getPath, + inferDataType(elementType)._1), + "array", + ObjectType(c)) + } + + case c if listType.isAssignableFrom(typeToken) => + val et = elementType(typeToken) + val array = + Invoke( + MapObjects( + p => constructorFor(et, Some(p)), + getPath, + inferDataType(et)._1), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + + case _ if mapType.isAssignableFrom(typeToken) => + val (keyType, valueType) = mapKeyValueType(typeToken) + val keyDataType = inferDataType(keyType)._1 + val valueDataType = inferDataType(valueType)._1 + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + val setters = properties.map { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + }.toMap + + val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val result = InitializeJavaBean(newInstance, setters) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(other)), + result + ) + } else { + result + } + } + } + + /** + * Returns expressions for extracting all the fields from the given type. + */ + def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) + extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + } + + private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + + def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { + val (dataType, nullable) = inferDataType(elementType) + if (ScalaReflection.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + typeToken.getRawType match { + case c if c == classOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case c if c == classOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case c if c == classOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + case c if c == classOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case c if c == classOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case c if c == classOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case c if c == classOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + + case _ if typeToken.isArray => + toCatalystArray(inputObject, typeToken.getComponentType) + + case _ if listType.isAssignableFrom(typeToken) => + toCatalystArray(inputObject, elementType(typeToken)) + + case _ if mapType.isAssignableFrom(typeToken) => + // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can + // not guarantee they have same iteration order(which is different from scala map). + // A possible solution is creating a new `MapObjects` that can iterate a map directly. + throw new UnsupportedOperationException("map type is not supported currently") + + case other => + val properties = getJavaBeanProperties(other) + if (properties.length > 0) { + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + } else { + throw new UnsupportedOperationException(s"no encoder found for ${other.getName}") + } + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 06ffe864552fd..3e8420ecb9ccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** @@ -68,6 +67,22 @@ object ExpressionEncoder { ClassTag[T](cls)) } + // TODO: improve error message for java bean encoder. + def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass)._1 + assert(schema.isInstanceOf[StructType]) + + val toRowExpression = JavaTypeInference.extractorsFor(beanClass) + val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + + new ExpressionEncoder[T]( + schema.asInstanceOf[StructType], + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](beanClass)) + } + /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an * N-tuple. Note that these encoders should be unresolved so that information about @@ -216,7 +231,7 @@ case class ExpressionEncoder[T]( */ def assertUnresolved(): Unit = { (fromRowExpression +: toRowExpressions).foreach(_.foreach { - case a: AttributeReference => + case a: AttributeReference if a.name != "loopVar" => sys.error(s"Unresolved encoder expected, but $a was found.") case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 62d09f0f55105..e6ab9a31be59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData + * The following collection ObjectTypes are currently supported: + * Seq, Array, ArrayData, java.util.List * * @param function A function that returns an expression, given an attribute that can be used * to access the current value. This is does as a lambda function so that @@ -386,6 +387,8 @@ case class MapObjects( (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".get($i)", false) case ArrayType(t, _) => val (sqlType, primitiveElement) = t match { case m: MapType => (m, false) @@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override def dataType: DataType = ObjectType(tag.runtimeClass) } + +/** + * Initialize a Java Bean instance by setting its field values via setters. + */ +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) + extends Expression { + + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val instanceGen = beanInstance.gen(ctx) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.gen(ctx) + s""" + ${fieldGen.code} + ${instanceGen.value}.$setterMethod(${fieldGen.value}); + """ + } + + ev.isNull = instanceGen.isNull + ev.value = instanceGen.value + + s""" + ${instanceGen.code} + if (!${instanceGen.isNull}) { + ${initialize.mkString("\n")} + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 35f087baccdee..f1cea07976a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.trees +import scala.collection.Map + import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.types.{StructType, DataType} @@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null } + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.view.force // `mapValues` is lazy and we need to force it to materialize case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { Some(arg) } - case m: Map[_, _] => m + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 70b028d2b3f7c..d85b72ed83def 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -70,4 +70,9 @@ object ArrayBasedMapData { def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { keys.zip(values).toMap } + + def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = { + import scala.collection.JavaConverters._ + keys.zip(values).toMap.asJava + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 96588bb5dc1bc..2b8cdc1e23ab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: Seq[Any]) = this(seq.toArray) + def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 8fff39906b342..965bdb1515e55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } +case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { + override def children: Seq[Expression] = map.values.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite { val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) assert(expected === actual) } + + test("expressions inside a map") { + val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2))) + + { + val actual = expression.transform { + case Literal(i: Int, _) => Literal(i + 1) + } + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + + { + val actual = expression.withNewChildren(Seq(Literal(2), Literal(3))) + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 67a3190cb7d4f..ae47f4fe0e231 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -31,14 +31,15 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.catalyst.encoders.OuterScopes; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { private transient JavaSparkContext jsc; @@ -506,4 +507,169 @@ public void testJavaEncoderErrorMessageForPrivateClass() { public void testKryoEncoderErrorMessageForPrivateClass() { Encoders.kryo(PrivateClassTest.class); } + + public class SimpleJavaBean implements Serializable { + private boolean a; + private int b; + private byte[] c; + private String[] d; + private List e; + private List f; + + public boolean isA() { + return a; + } + + public void setA(boolean a) { + this.a = a; + } + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public byte[] getC() { + return c; + } + + public void setC(byte[] c) { + this.c = c; + } + + public String[] getD() { + return d; + } + + public void setD(String[] d) { + this.d = d; + } + + public List getE() { + return e; + } + + public void setE(List e) { + this.e = e; + } + + public List getF() { + return f; + } + + public void setF(List f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (a != that.a) return false; + if (b != that.b) return false; + if (!Arrays.equals(c, that.c)) return false; + if (!Arrays.equals(d, that.d)) return false; + if (!e.equals(that.e)) return false; + return f.equals(that.f); + } + + @Override + public int hashCode() { + int result = (a ? 1 : 0); + result = 31 * result + b; + result = 31 * result + Arrays.hashCode(c); + result = 31 * result + Arrays.hashCode(d); + result = 31 * result + e.hashCode(); + result = 31 * result + f.hashCode(); + return result; + } + } + + public class NestedJavaBean implements Serializable { + private SimpleJavaBean a; + + public SimpleJavaBean getA() { + return a; + } + + public void setA(SimpleJavaBean a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NestedJavaBean that = (NestedJavaBean) o; + + return a.equals(that.a); + } + + @Override + public int hashCode() { + return a.hashCode(); + } + } + + @Test + public void testJavaBeanEncoder() { + OuterScopes.addOuterScope(this); + SimpleJavaBean obj1 = new SimpleJavaBean(); + obj1.setA(true); + obj1.setB(3); + obj1.setC(new byte[]{1, 2}); + obj1.setD(new String[]{"hello", null}); + obj1.setE(Arrays.asList("a", "b")); + obj1.setF(Arrays.asList(100L, null, 200L)); + SimpleJavaBean obj2 = new SimpleJavaBean(); + obj2.setA(false); + obj2.setB(30); + obj2.setC(new byte[]{3, 4}); + obj2.setD(new String[]{null, "world"}); + obj2.setE(Arrays.asList("x", "y")); + obj2.setF(Arrays.asList(300L, null, 400L)); + + List data = Arrays.asList(obj1, obj2); + Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds.collectAsList()); + + NestedJavaBean obj3 = new NestedJavaBean(); + obj3.setA(obj1); + + List data2 = Arrays.asList(obj3); + Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Assert.assertEquals(data2, ds2.collectAsList()); + + Row row1 = new GenericRow(new Object[]{ + true, + 3, + new byte[]{1, 2}, + new String[]{"hello", null}, + Arrays.asList("a", "b"), + Arrays.asList(100L, null, 200L)}); + Row row2 = new GenericRow(new Object[]{ + false, + 30, + new byte[]{3, 4}, + new String[]{null, "world"}, + Arrays.asList("x", "y"), + Arrays.asList(300L, null, 400L)}); + StructType schema = new StructType() + .add("a", BooleanType, false) + .add("b", IntegerType, false) + .add("c", BinaryType) + .add("d", createArrayType(StringType)) + .add("e", createArrayType(StringType)) + .add("f", createArrayType(LongType)); + Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + .as(Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds3.collectAsList()); + } } From 0a7bca2da04aefff16f2513ec27a92e69ceb77f6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 1 Dec 2015 10:38:59 -0800 Subject: [PATCH 1078/2089] [SPARK-11905][SQL] Support Persist/Cache and Unpersist in Dataset APIs Persist and Unpersist exist in both RDD and Dataframe APIs. I think they are still very critical in Dataset APIs. Not sure if my understanding is correct? If so, could you help me check if the implementation is acceptable? Please provide your opinions. marmbrus rxin cloud-fan Thank you very much! Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #9889 from gatorsmile/persistDS. --- .../org/apache/spark/sql/DataFrame.scala | 9 +++ .../scala/org/apache/spark/sql/Dataset.scala | 50 +++++++++++- .../org/apache/spark/sql/SQLContext.scala | 9 +++ .../spark/sql/execution/CacheManager.scala | 27 ++++--- .../apache/spark/sql/DatasetCacheSuite.scala | 80 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 5 +- 6 files changed, 162 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6197f10813a3b..eb8700369275e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1584,6 +1584,7 @@ class DataFrame private[sql]( def distinct(): DataFrame = dropDuplicates() /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ @@ -1593,12 +1594,17 @@ class DataFrame private[sql]( } /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ def cache(): this.type = persist() /** + * Persist this [[DataFrame]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.3.0 */ @@ -1608,6 +1614,8 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. * @group basic * @since 1.3.0 */ @@ -1617,6 +1625,7 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c357f88a94dd0..d6bb1d2ad8e50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** @@ -565,7 +566,7 @@ class Dataset[T] private[sql]( * combined. * * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) @@ -618,7 +619,6 @@ class Dataset[T] private[sql]( case _ => Alias(CreateStruct(rightOutput), "_2")() } - implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => @@ -697,11 +697,55 @@ class Dataset[T] private[sql]( */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def cache(): this.type = persist() + + /** + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. + * @group basic + * @since 1.6.0 + */ + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. + * @since 1.6.0 + */ + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @since 1.6.0 + */ + def unpersist(): this.type = unpersist(blocking = false) + /* ******************** * * Internal Functions * * ******************** */ - private[sql] def logicalPlan = queryExecution.analyzed + private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9cc65de19180a..4e26250868374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -338,6 +338,15 @@ class SQLContext private[sql]( cacheManager.lookupCachedData(table(tableName)).nonEmpty } + /** + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ + private[sql] def isCached(qName: Queryable): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + /** * Caches the specified table in-memory. * @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 293fcfe96e677..50f6562815c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[Queryable]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Queryable, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Queryable]] from the cache */ + private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -109,9 +108,11 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** Tries to remove the data for the given [[Queryable]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Queryable, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,12 +124,12 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Queryable]] */ + private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. */ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 0000000000000..3a283a4e1f610 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8f476dd0f99b6..bc22fb8b7bdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.Queryable abstract class QueryTest extends PlanTest { @@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Queryable]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached From 6a8cf80cc8ef435ec46138fa57325bda5d68f3ce Mon Sep 17 00:00:00 2001 From: woj-i Date: Tue, 1 Dec 2015 11:05:45 -0800 Subject: [PATCH 1079/2089] [SPARK-11821] Propagate Kerberos keytab for all environments andrewor14 the same PR as in branch 1.5 harishreedharan Author: woj-i Closes #9859 from woj-i/master. --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 4 ++++ docs/running-on-yarn.md | 4 ++-- docs/sql-programming-guide.md | 7 ++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 2e912b59afdb8..52d3ab34c1784 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -545,6 +545,10 @@ object SparkSubmit { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } + } + + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL) { if (args.principal != null) { require(args.keytab != null, "Keytab must be specified when principal is specified") if (!new File(args.keytab).exists()) { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 925a1e0ba6fcf..06413f83c3a71 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -358,14 +358,14 @@ If you need a reference to the proper location to put log files in the YARN so t The full path to the file that contains the keytab for the principal specified above. This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. + for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) spark.yarn.principal (none) - Principal to be used to login to KDC, while running on secure HDFS. + Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d7b205c2fa0df..7b1d97baa3823 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1614,7 +1614,8 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), + `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the @@ -2028,7 +2029,7 @@ Beeline will ask you for a username and password. In non-secure mode, simply ent your machine and a blank password. For secure mode, please follow the instructions given in the [beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may also use the beeline script that comes with Hive. @@ -2053,7 +2054,7 @@ To start the Spark SQL CLI, run the following in the Spark directory: ./bin/spark-sql -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. From 34e7093c1131162b3aa05b65a19a633a0b5b633e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Dec 2015 11:49:20 -0800 Subject: [PATCH 1080/2089] [SPARK-12065] Upgrade Tachyon from 0.8.1 to 0.8.2 This commit upgrades the Tachyon dependency from 0.8.1 to 0.8.2. Author: Josh Rosen Closes #10054 from JoshRosen/upgrade-to-tachyon-0.8.2. --- core/pom.xml | 2 +- make-distribution.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 37e3f168ab374..61744bb5c7bf5 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -270,7 +270,7 @@ org.tachyonproject tachyon-client - 0.8.1 + 0.8.2 org.apache.hadoop diff --git a/make-distribution.sh b/make-distribution.sh index 7b417fe7cf619..e64ceb802464c 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.8.1" +TACHYON_VERSION="0.8.2" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" From 2cef1cdfbb5393270ae83179b6a4e50c3cbf9e93 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 1 Dec 2015 12:59:53 -0800 Subject: [PATCH 1081/2089] [SPARK-12030] Fix Platform.copyMemory to handle overlapping regions. This bug was exposed as memory corruption in Timsort which uses copyMemory to copy large regions that can overlap. The prior implementation did not handle this case half the time and always copied forward, resulting in the data being corrupt. Author: Nong Li Closes #10068 from nongli/spark-12030. --- .../org/apache/spark/unsafe/Platform.java | 27 ++++++-- .../spark/unsafe/PlatformUtilSuite.java | 61 +++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1c16da982923b..0d6b215fe5aaf 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -107,12 +107,27 @@ public static void freeMemory(long address) { public static void copyMemory( Object src, long srcOffset, Object dst, long dstOffset, long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java new file mode 100644 index 0000000000000..693ec6ec58dbd --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import org.junit.Assert; +import org.junit.Test; + +public class PlatformUtilSuite { + + @Test + public void overlappingCopyMemory() { + byte[] data = new byte[3 * 1024 * 1024]; + int size = 2 * 1024 * 1024; + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + + Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, data, Platform.BYTE_ARRAY_OFFSET, size); + for (int i = 0; i < data.length; ++i) { + Assert.assertEquals((byte)i, data[i]); + } + + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET + 1, + data, + Platform.BYTE_ARRAY_OFFSET, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)(i + 1), data[i]); + } + + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET, + data, + Platform.BYTE_ARRAY_OFFSET + 1, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)i, data[i + 1]); + } + } +} From 60b541ee1b97c9e5e84aa2af2ce856f316ad22b3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Dec 2015 14:08:36 -0800 Subject: [PATCH 1082/2089] [SPARK-12004] Preserve the RDD partitioner through RDD checkpointing The solution is the save the RDD partitioner in a separate file in the RDD checkpoint directory. That is, `/_partitioner`. In most cases, whether the RDD partitioner was recovered or not, does not affect the correctness, only reduces performance. So this solution makes a best-effort attempt to save and recover the partitioner. If either fails, the checkpointing is not affected. This makes this patch safe and backward compatible. Author: Tathagata Das Closes #9983 from tdas/SPARK-12004. --- .../spark/rdd/ReliableCheckpointRDD.scala | 122 +++++++++++++++++- .../spark/rdd/ReliableRDDCheckpointData.scala | 21 +-- .../org/apache/spark/CheckpointSuite.scala | 61 ++++++++- 3 files changed, 173 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index a69be6a068bbf..fa71b8c26233d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -20,12 +20,12 @@ package org.apache.spark.rdd import java.io.IOException import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -33,8 +33,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( sc: SparkContext, - val checkpointPath: String) - extends CheckpointRDD[T](sc) { + val checkpointPath: String, + _partitioner: Option[Partitioner] = None + ) extends CheckpointRDD[T](sc) { @transient private val hadoopConf = sc.hadoopConfiguration @transient private val cpath = new Path(checkpointPath) @@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( /** * Return the path of the checkpoint directory this RDD reads data from. */ - override def getCheckpointFile: Option[String] = Some(checkpointPath) + override val getCheckpointFile: Option[String] = Some(checkpointPath) + + override val partitioner: Option[Partitioner] = { + _partitioner.orElse { + ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath) + } + } /** * Return partitions described by the files in the checkpoint directory. @@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging { "part-%05d".format(partitionIndex) } + private def checkpointPartitionerFileName(): String = { + "_partitioner" + } + + /** + * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD. + */ + def writeRDDToCheckpointDirectory[T: ClassTag]( + originalRDD: RDD[T], + checkpointDir: String, + blockSize: Int = -1): ReliableCheckpointRDD[T] = { + + val sc = originalRDD.sparkContext + + // Create the output path for the checkpoint + val checkpointDirPath = new Path(checkpointDir) + val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration) + if (!fs.mkdirs(checkpointDirPath)) { + throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = sc.broadcast( + new SerializableConfiguration(sc.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + sc.runJob(originalRDD, + writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _) + + if (originalRDD.partitioner.nonEmpty) { + writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) + } + + val newRDD = new ReliableCheckpointRDD[T]( + sc, checkpointDirPath.toString, originalRDD.partitioner) + if (newRDD.partitions.length != originalRDD.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})") + } + newRDD + } + /** - * Write this partition's values to a checkpoint file. + * Write a RDD partition's data to a checkpoint file. */ - def writeCheckpointFile[T: ClassTag]( + def writePartitionToCheckpointFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { @@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging { } } + /** + * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort + * basis; any exception while writing the partitioner is caught, logged and ignored. + */ + private def writePartitionerToCheckpointDir( + sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = { + try { + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeObject(partitioner) + } { + serializeStream.close() + } + logDebug(s"Written partitioner to $partitionerFilePath") + } catch { + case NonFatal(e) => + logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath") + } + } + + + /** + * Read a partitioner from the given RDD checkpoint directory, if it exists. + * This is done on a best-effort basis; any exception while reading the partitioner is + * caught, logged and ignored. + */ + private def readCheckpointedPartitionerFile( + sc: SparkContext, + checkpointDirPath: String): Option[Partitioner] = { + try { + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(partitionerFilePath)) { + val fileInputStream = fs.open(partitionerFilePath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + val partitioner = Utils.tryWithSafeFinally[Partitioner] { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } + logDebug(s"Read partitioner from $partitionerFilePath") + Some(partitioner) + } else { + logDebug("No partitioner file") + None + } + } catch { + case NonFatal(e) => + logWarning(s"Error reading partitioner from $checkpointDirPath, " + + s"partitioner will not be recovered which may lead to performance loss", e) + None + } + } + /** * Read the content of the specified checkpoint file. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 91cad6662e4d2..cac6cbe780e91 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v * This is called immediately after the first action invoked on this RDD has completed. */ protected override def doCheckpoint(): CheckpointRDD[T] = { - - // Create the output path for the checkpoint - val path = new Path(cpDir) - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException(s"Failed to create checkpoint path $cpDir") - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) - val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + - s"number of partitions from original RDD $rdd(${rdd.partitions.length})") - } + val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir) // Optionally clean our checkpoint files if the reference is out of scope if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { @@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v } logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") - newRDD } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index ab23326c6c25d..553d46285ac03 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,7 +21,8 @@ import java.io.File import scala.reflect.ClassTag -import org.apache.spark.CheckpointSuite._ +import org.apache.hadoop.fs.Path + import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite => // Test whether the checkpoint file has been created if (reliableCheckpoint) { - assert( - collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + assert(operatedRDD.getCheckpointFile.nonEmpty) + val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get) + assert(collectFunc(recoveredRDD) === result) + assert(recoveredRDD.partitioner === operatedRDD.partitioner) } // Test whether dependencies have been changed from its earlier parent RDD @@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite => } /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - protected def runTest(name: String)(body: Boolean => Unit): Unit = { + protected def runTest( + name: String, + skipLocalCheckpoint: Boolean = false + )(body: Boolean => Unit): Unit = { test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) + if (!skipLocalCheckpoint) { + test(name + " [local checkpoint]")(body(false)) + } } /** @@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(flatMappedRDD.collect() === result) } + runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean => + + def testPartitionerCheckpointing( + partitioner: Partitioner, + corruptPartitionerFile: Boolean = false + ): Unit = { + val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner) + rddWithPartitioner.checkpoint() + rddWithPartitioner.count() + assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty, + "checkpointing was not successful") + + if (corruptPartitionerFile) { + // Overwrite the partitioner file with garbage data + val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get) + val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration) + val partitionerFile = fs.listStatus(checkpointDir) + .find(_.getPath.getName.contains("partitioner")) + .map(_.getPath) + require(partitionerFile.nonEmpty, "could not find the partitioner file for testing") + val output = fs.create(partitionerFile.get, true) + output.write(100) + output.close() + } + + val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get) + assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered") + + if (!corruptPartitionerFile) { + assert(newRDD.partitioner != None, "partitioner not recovered") + assert(newRDD.partitioner === rddWithPartitioner.partitioner, + "recovered partitioner does not match") + } else { + assert(newRDD.partitioner == None, "partitioner unexpectedly recovered") + } + } + + testPartitionerCheckpointing(partitioner) + + // Test that corrupted partitioner file does not prevent recovery of RDD + testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true) + } + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => testRDD(_.map(x => x.toString), reliableCheckpoint) testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) From 328b757d5d4486ea3c2e246780792d7a57ee85e5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 1 Dec 2015 15:13:10 -0800 Subject: [PATCH 1083/2089] Revert "[SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize" This reverts commit 1401166576c7018c5f9c31e0a6703d5fb16ea339. --- .../spark/serializer/JavaSerializer.scala | 7 +++-- .../spark/util/ByteBufferOutputStream.scala | 31 ------------------- 2 files changed, 4 insertions(+), 34 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index ea718a0edbe71..b463a71d5bd7d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,7 +24,8 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -95,11 +96,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteBufferOutputStream() + val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - bos.toByteBuffer + ByteBuffer.wrap(bos.toByteArray) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala deleted file mode 100644 index 92e45224db81c..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.io.ByteArrayOutputStream -import java.nio.ByteBuffer - -/** - * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer - */ -private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { - - def toByteBuffer: ByteBuffer = { - return ByteBuffer.wrap(buf, 0, count) - } -} From e76431f886ae8061545b3216e8e2fb38c4db1f43 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 1 Dec 2015 15:21:53 -0800 Subject: [PATCH 1084/2089] [SPARK-11961][DOC] Add docs of ChiSqSelector https://issues.apache.org/jira/browse/SPARK-11961 Author: Xusen Yin Closes #9965 from yinxusen/SPARK-11961. --- docs/ml-features.md | 50 +++++++++++++ .../examples/ml/JavaChiSqSelectorExample.java | 71 +++++++++++++++++++ .../examples/ml/ChiSqSelectorExample.scala | 57 +++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index cd1838d6d2882..5f888775553f2 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1949,3 +1949,53 @@ output.select("features", "label").show() {% endhighlight %} + +## ChiSqSelector + +`ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with +categorical features. ChiSqSelector orders features based on a +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) +from the class, and then filters (selects) the top features which the class label depends on the +most. This is akin to yielding the features with the most predictive power. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `clicked`, which is used as +our target to be predicted: + +~~~ +id | features | clicked +---|-----------------------|--------- + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 +~~~ + +If we use `ChiSqSelector` with a `numTopFeatures = 1`, then according to our label `clicked` the +last column in our `features` chosen as the most useful feature: + +~~~ +id | features | clicked | selectedFeatures +---|-----------------------|---------|------------------ + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 | [1.0] + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 | [0.0] + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 | [0.1] +~~~ + +
    +
    + +Refer to the [ChiSqSelector Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ChiSqSelector) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala %} +
    + +
    + +Refer to the [ChiSqSelector Java docs](api/java/org/apache/spark/ml/feature/ChiSqSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java %} +
    +
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java new file mode 100644 index 0000000000000..ede05d6e20118 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.ChiSqSelector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaChiSqSelectorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + ChiSqSelector selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures"); + + DataFrame result = selector.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala new file mode 100644 index 0000000000000..a8d2bc4907e80 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ChiSqSelector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ChiSqSelectorExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ChiSqSelectorExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Seq( + (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + ) + + val df = sc.parallelize(data).toDF("id", "features", "clicked") + + val selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From f292018f8e57779debc04998456ec875f628133b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 1 Dec 2015 15:26:10 -0800 Subject: [PATCH 1085/2089] [SPARK-12002][STREAMING][PYSPARK] Fix python direct stream checkpoint recovery issue Fixed a minor race condition in #10017 Closes #10017 Author: jerryshao Author: Shixiong Zhu Closes #10074 from zsxwing/review-pr10017. --- python/pyspark/streaming/tests.py | 49 +++++++++++++++++++++++++++++++ python/pyspark/streaming/util.py | 13 ++++---- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a647e6bf39581..d50c6b8d4a428 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1149,6 +1149,55 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_with_checkpoint(self): + """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + self.ssc.stop(False) + self.ssc = None + tmpdir = "checkpoint-test-%d" % random.randint(0, 10000) + + def setup(): + ssc = StreamingContext(self.sc, 0.5) + ssc.checkpoint(tmpdir) + stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams) + stream.transform(transformWithOffsetRanges).count().pprint() + return ssc + + try: + ssc1 = StreamingContext.getOrCreate(tmpdir, setup) + ssc1.start() + self.wait_for(offsetRanges, 1) + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + # To make sure some checkpoint is written + time.sleep(3) + ssc1.stop(False) + ssc1 = None + + # Restart again to make sure the checkpoint is recovered correctly + ssc2 = StreamingContext.getOrCreate(tmpdir, setup) + ssc2.start() + ssc2.awaitTermination(3) + ssc2.stop(stopSparkContext=False, stopGraceFully=True) + ssc2 = None + finally: + shutil.rmtree(tmpdir) + @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_message_handler(self): """Test Python direct Kafka RDD MessageHandler.""" diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index c7f02bca2ae38..abbbf6eb9394f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,11 +37,11 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers - self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) self.failure = None def rdd_wrapper(self, func): - self._rdd_wrapper = func + self.rdd_wrap_func = func return self def call(self, milliseconds, jrdds): @@ -59,7 +59,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None + rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) @@ -101,7 +101,8 @@ def dumps(self, id): self.failure = None try: func = self.gateway.gateway_property.pool[id] - return bytearray(self.serializer.dumps((func.func, func.deserializers))) + return bytearray(self.serializer.dumps(( + func.func, func.rdd_wrap_func, func.deserializers))) except: self.failure = traceback.format_exc() @@ -109,8 +110,8 @@ def loads(self, data): # Clear the failure self.failure = None try: - f, deserializers = self.serializer.loads(bytes(data)) - return TransformFunction(self.ctx, f, *deserializers) + f, wrap_func, deserializers = self.serializer.loads(bytes(data)) + return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func) except: self.failure = traceback.format_exc() From ef6790fdc3b70b9d6184ec2b3d926e4b0e4b15f6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 2 Dec 2015 07:29:45 +0800 Subject: [PATCH 1086/2089] [SPARK-12075][SQL] Speed up HiveComparisionTest by avoiding / speeding up TestHive.reset() When profiling HiveCompatibilitySuite, I noticed that most of the time seems to be spent in expensive `TestHive.reset()` calls. This patch speeds up suites based on HiveComparisionTest, such as HiveCompatibilitySuite, with the following changes: - Avoid `TestHive.reset()` whenever possible: - Use a simple set of heuristics to guess whether we need to call `reset()` in between tests. - As a safety-net, automatically re-run failed tests by calling `reset()` before the re-attempt. - Speed up the expensive parts of `TestHive.reset()`: loading the `src` and `srcpart` tables took roughly 600ms per test, so we now avoid this by using a simple heuristic which only loads those tables by tests that reference them. This is based on simple string matching over the test queries which errs on the side of loading in more situations than might be strictly necessary. After these changes, HiveCompatibilitySuite seems to run in about 10 minutes. This PR is a revival of #6663, an earlier experimental PR from June, where I played around with several possible speedups for this suite. Author: Josh Rosen Closes #10055 from JoshRosen/speculative-testhive-reset. --- .../apache/spark/sql/hive/test/TestHive.scala | 7 -- .../hive/execution/HiveComparisonTest.scala | 67 +++++++++++++++++-- .../hive/execution/HiveQueryFileTest.scala | 2 +- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6883d305cbead..2e2d201bf254d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -443,13 +443,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { defaultOverrides() runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. - loadTestTable("src") - loadTestTable("srcpart") } catch { case e: Exception => logError("FATAL ERROR: Failed to reset TestDB state.", e) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index aa95ba94fa873..4455430aa727a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -209,7 +209,11 @@ abstract class HiveComparisonTest } val installHooksCommand = "(?i)SET.*hooks".r - def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { + def createQueryTest( + testCaseName: String, + sql: String, + reset: Boolean = true, + tryWithoutResettingFirst: Boolean = false) { // testCaseName must not contain ':', which is not allowed to appear in a filename of Windows assert(!testCaseName.contains(":")) @@ -240,9 +244,6 @@ abstract class HiveComparisonTest test(testCaseName) { logDebug(s"=== HIVE TEST: $testCaseName ===") - // Clear old output for this testcase. - outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val sqlWithoutComment = sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") val allQueries = @@ -269,11 +270,32 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } - try { + def doTest(reset: Boolean, isSpeculative: Boolean = false): Unit = { + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + if (reset) { TestHive.reset() } + // Many tests drop indexes on src and srcpart at the beginning, so we need to load those + // tables here. Since DROP INDEX DDL is just passed to Hive, it bypasses the analyzer and + // thus the tables referenced in those DDL commands cannot be extracted for use by our + // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES + // command expect these tables to exist. + val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + for (table <- Seq("src", "srcpart")) { + val hasMatchingQuery = queryList.exists { query => + val normalizedQuery = query.toLowerCase.stripSuffix(";") + normalizedQuery.endsWith(table) || + normalizedQuery.contains(s"from $table") || + normalizedQuery.contains(s"from default.$table") + } + if (hasShowTableCommand || hasMatchingQuery) { + TestHive.loadTestTable(table) + } + } + val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" @@ -430,12 +452,45 @@ abstract class HiveComparisonTest """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) - fail(errorMessage) + if (isSpeculative && !reset) { + fail("Failed on first run; retrying") + } else { + fail(errorMessage) + } } } // Touch passed file. new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } + + val canSpeculativelyTryWithoutReset: Boolean = { + val excludedSubstrings = Seq( + "into table", + "create table", + "drop index" + ) + !queryList.map(_.toLowerCase).exists { query => + excludedSubstrings.exists(s => query.contains(s)) + } + } + + try { + try { + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + doTest(reset = false, isSpeculative = true) + } else { + doTest(reset) + } + } catch { + case tf: org.scalatest.exceptions.TestFailedException => + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + logWarning("Test failed without reset(); retrying with reset()") + doTest(reset = true) + } else { + throw tf + } + } } catch { case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index f7b37dae0a5f3..f96c989c4614f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -59,7 +59,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) - createQueryTest(testCaseName, queriesString) + createQueryTest(testCaseName, queriesString, reset = true, tryWithoutResettingFirst = true) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. From 47a0abc343550c855e679de12983f43e6fcc0171 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 1 Dec 2015 15:30:21 -0800 Subject: [PATCH 1087/2089] [SPARK-11328][SQL] Improve error message when hitting this issue The issue is that the output commiter is not idempotent and retry attempts will fail because the output file already exists. It is not safe to clean up the file as this output committer is by design not retryable. Currently, the job fails with a confusing file exists error. This patch is a stop gap to tell the user to look at the top of the error log for the proper message. This is difficult to test locally as Spark is hardcoded not to retry. Manually verified by upping the retry attempts. Author: Nong Li Author: Nong Li Closes #10080 from nongli/spark-11328. --- .../datasources/WriterContainer.scala | 22 +++++++++++++++++-- .../DirectParquetOutputCommitter.scala | 3 ++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 1b59b19d9420d..ad55367258890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -124,6 +124,24 @@ private[sql] abstract class BaseWriterContainer( } } + protected def newOutputWriter(path: String): OutputWriter = { + try { + outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + } catch { + case e: org.apache.hadoop.fs.FileAlreadyExistsException => + if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { + // Spark-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry + // attempts, the task will fail because the output file is created from a prior attempt. + // This often means the most visible error to the user is misleading. Augment the error + // to tell the user to look for the actual error. + throw new SparkException("The output file already exists but this could be due to a " + + "failure from an earlier attempt. Look through the earlier logs or stage page for " + + "the first error.\n File exists error: " + e) + } + throw e + } + } + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) @@ -234,7 +252,7 @@ private[sql] class DefaultWriterContainer( executorSideSetup(taskContext) val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set("spark.sql.sources.output.path", outputPath) - val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) + val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) var writerClosed = false @@ -403,7 +421,7 @@ private[sql] class DynamicPartitionWriterContainer( val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) + val newWriter = super.newOutputWriter(path.toString) newWriter.initConverter(dataSchema) newWriter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 300e8677b312f..1a4e99ff10afb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -41,7 +41,8 @@ import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetO * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are * left empty). */ -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) +private[datasources] class DirectParquetOutputCommitter( + outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) From 5a8b5fdd6ffa58f015cdadf3f2c6df78e0a388ad Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 1 Dec 2015 15:32:57 -0800 Subject: [PATCH 1088/2089] [SPARK-11788][SQL] surround timestamp/date value with quotes in JDBC data source When query the Timestamp or Date column like the following val filtered = jdbcdf.where($"TIMESTAMP_COLUMN" >= beg && $"TIMESTAMP_COLUMN" < end) The generated SQL query is "TIMESTAMP_COLUMN >= 2015-01-01 00:00:00.0" It should have quote around the Timestamp/Date value such as "TIMESTAMP_COLUMN >= '2015-01-01 00:00:00.0'" Author: Huaxin Gao Closes #9872 from huaxingao/spark-11788. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 57a8a044a37cd..392d3ed58e3ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties import scala.util.control.NonFatal @@ -267,6 +267,8 @@ private[sql] class JDBCRDD( */ private def compileValue(value: Any): Any = value match { case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" case _ => value } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d530b1a469ce2..8c24aa3151bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -484,4 +484,15 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) } + + test("Test DataFrame.where for Date and Timestamp") { + // Regression test for bug SPARK-11788 + val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); + val date = java.sql.Date.valueOf("1995-01-01") + val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(rows(0).getAs[java.sql.Timestamp](2) + === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) + } } From 5872a9d89fe2720c2bcb1fc7494136947a72581c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 1 Dec 2015 16:24:04 -0800 Subject: [PATCH 1089/2089] [SPARK-11352][SQL] Escape */ in the generated comments. https://issues.apache.org/jira/browse/SPARK-11352 Author: Yin Huai Closes #10072 from yhuai/SPARK-11352. --- .../spark/sql/catalyst/expressions/Expression.scala | 10 ++++++++-- .../catalyst/expressions/codegen/CodegenFallback.scala | 2 +- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 9 +++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b55d3653a7158..4ee6542455a6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -95,7 +95,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. - val code = s"/* $this */" + val code = s"/* ${this.toCommentSafeString} */" GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") @@ -103,7 +103,7 @@ abstract class Expression extends TreeNode[Expression] { val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code.trim) + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) } } @@ -214,6 +214,12 @@ abstract class Expression extends TreeNode[Expression] { } override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + + /** + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. + */ + protected def toCommentSafeString: String = this.toString.replace("*/", "\\*\\/") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a31574c251af5..26fb143d1e45c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -33,7 +33,7 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") s""" - /* expression: ${this} */ + /* expression: ${this.toCommentSafeString} */ java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 002ed16dcfe7e..fe754240dcd67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -98,4 +98,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) assert(internalRow === internalRow2) } + + test("*/ in the data") { + // When */ appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape */ to \*\/. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)), + true, + InternalRow(UTF8String.fromString("*/"))) + } } From e96a70d5ab2e2b43a2df17a550fa9ed2ee0001c4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 1 Dec 2015 17:18:45 -0800 Subject: [PATCH 1090/2089] [SPARK-11596][SQL] In TreeNode's argString, if a TreeNode is not a child of the current TreeNode, we should only return the simpleString. In TreeNode's argString, if a TreeNode is not a child of the current TreeNode, we will only return the simpleString. I tested the [following case provided by Cristian](https://issues.apache.org/jira/browse/SPARK-11596?focusedCommentId=15019241&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15019241). ``` val c = (1 to 20).foldLeft[Option[DataFrame]] (None) { (curr, idx) => println(s"PROCESSING >>>>>>>>>>> $idx") val df = sqlContext.sparkContext.parallelize((0 to 10).zipWithIndex).toDF("A", "B") val union = curr.map(_.unionAll(df)).getOrElse(df) union.cache() Some(union) } c.get.explain(true) ``` Without the change, `c.get.explain(true)` took 100s. With the change, `c.get.explain(true)` took 26ms. https://issues.apache.org/jira/browse/SPARK-11596 Author: Yin Huai Closes #10079 from yhuai/SPARK-11596. --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f1cea07976a37..ad2bd78430a68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -380,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"(${tn.simpleString})" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil From 1ce4adf55b535518c2e63917a827fac1f2df4e8e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Dec 2015 19:36:34 -0800 Subject: [PATCH 1091/2089] [SPARK-8414] Ensure context cleaner periodic cleanups Garbage collection triggers cleanups. If the driver JVM is huge and there is little memory pressure, we may never clean up shuffle files on executors. This is a problem for long-running applications (e.g. streaming). Author: Andrew Or Closes #10070 from andrewor14/periodic-gc. --- .../org/apache/spark/ContextCleaner.scala | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index d23c1533db758..bc732535fed87 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,12 +18,13 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.concurrent.{TimeUnit, ScheduledExecutorService} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Classes that represent cleaning tasks. @@ -66,6 +67,20 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + private val periodicGCService: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("context-cleaner-periodic-gc") + + /** + * How often to trigger a garbage collection in this JVM. + * + * This context cleaner triggers cleanups only when weak references are garbage collected. + * In long-running applications with large driver JVMs, where there is little memory pressure + * on the driver, this may happen very occasionally or not at all. Not cleaning at all may + * lead to executors running out of disk space after a while. + */ + private val periodicGCInterval = + sc.conf.getTimeAsSeconds("spark.cleaner.periodicGC.interval", "30min") + /** * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). @@ -104,6 +119,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() + periodicGCService.scheduleAtFixedRate(new Runnable { + override def run(): Unit = System.gc() + }, periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS) } /** @@ -119,6 +137,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.interrupt() } cleaningThread.join() + periodicGCService.shutdown() } /** Register a RDD for cleanup when it is garbage collected. */ From d96f8c997b9bb5c3d61f513d2c71d67ccf8e85d6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Dec 2015 19:51:12 -0800 Subject: [PATCH 1092/2089] [SPARK-12081] Make unified memory manager work with small heaps The existing `spark.memory.fraction` (default 0.75) gives the system 25% of the space to work with. For small heaps, this is not enough: e.g. default 1GB leaves only 250MB system memory. This is especially a problem in local mode, where the driver and executor are crammed in the same JVM. Members of the community have reported driver OOM's in such cases. **New proposal.** We now reserve 300MB before taking the 75%. For 1GB JVMs, this leaves `(1024 - 300) * 0.75 = 543MB` for execution and storage. This is proposal (1) listed in the [JIRA](https://issues.apache.org/jira/browse/SPARK-12081). Author: Andrew Or Closes #10081 from andrewor14/unified-memory-small-heaps. --- .../spark/memory/UnifiedMemoryManager.scala | 22 +++++++++++++++---- .../memory/UnifiedMemoryManagerSuite.scala | 20 +++++++++++++++++ docs/configuration.md | 4 ++-- docs/tuning.md | 2 +- 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 8be5b05419094..48b4e23433e43 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that * either side can borrow memory from the other. * - * The region shared between execution and storage is a fraction of the total heap space + * The region shared between execution and storage is a fraction of (the total heap space - 300MB) * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary * within this space is further determined by `spark.memory.storageFraction` (default 0.5). * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. @@ -48,7 +48,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} */ private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, - maxMemory: Long, + val maxMemory: Long, private val storageRegionSize: Long, numCores: Int) extends MemoryManager( @@ -130,6 +130,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( object UnifiedMemoryManager { + // Set aside a fixed amount of memory for non-storage, non-execution purposes. + // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve + // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then + // the memory used for execution and storage will be (1024 - 300) * 0.75 = 543MB by default. + private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { val maxMemory = getMaxMemory(conf) new UnifiedMemoryManager( @@ -144,8 +150,16 @@ object UnifiedMemoryManager { * Return the total amount of memory shared between execution and storage, in bytes. */ private def getMaxMemory(conf: SparkConf): Long = { - val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val reservedMemory = conf.getLong("spark.testing.reservedMemory", + if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) + val minSystemMemory = reservedMemory * 1.5 + if (systemMemory < minSystemMemory) { + throw new IllegalArgumentException(s"System memory $systemMemory must " + + s"be at least $minSystemMemory. Please use a larger heap size.") + } + val usableMemory = systemMemory - reservedMemory val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) - (systemMaxMemory * memoryFraction).toLong + (usableMemory * memoryFraction).toLong } } diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 8cebe81c3bfff..e97c898a44783 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -182,4 +182,24 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assertEnsureFreeSpaceCalled(ms, 850L) } + test("small heap") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + val expectedMaxMemory = ((systemMemory - reservedMemory) * memoryFraction).toLong + assert(mm.maxMemory === expectedMaxMemory) + + // Try using a system memory that's too small + val conf2 = conf.clone().set("spark.testing.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("larger heap size")) + } + } diff --git a/docs/configuration.md b/docs/configuration.md index 741d6b2b37a87..c39b4890851bc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -719,8 +719,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.fraction 0.75 - Fraction of the heap space used for execution and storage. The lower this is, the more - frequently spills and cached data eviction occur. The purpose of this config is to set + Fraction of (heap space - 300MB) used for execution and storage. The lower this is, the + more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation in the case of sparse, unusually large records. Leaving this at the default value is recommended. For more detail, see diff --git a/docs/tuning.md b/docs/tuning.md index a8fe7a4532798..e73ed69ffbbf8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -114,7 +114,7 @@ variety of workloads without requiring user expertise of how memory is divided i Although there are two relevant configurations, the typical user should not need to adjust them as the default values are applicable to most workloads: -* `spark.memory.fraction` expresses the size of `M` as a fraction of the total JVM heap space +* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) (default 0.75). The rest of the space (25%) is reserved for user data structures, internal metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually large records. From 96691feae0229fd693c29475620be2c4059dd080 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Dec 2015 20:17:12 -0800 Subject: [PATCH 1093/2089] [SPARK-12077][SQL] change the default plan for single distinct Use try to match the behavior for single distinct aggregation with Spark 1.5, but that's not scalable, we should be robust by default, have a flag to address performance regression for low cardinality aggregation. cc yhuai nongli Author: Davies Liu Closes #10075 from davies/agg_15. --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5ef3a48c56a87..58adf64e49869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -451,7 +451,7 @@ private[spark] object SQLConf { val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = booleanConf("spark.sql.specializeSingleDistinctAggPlanning", - defaultValue = Some(true), + defaultValue = Some(false), isPublic = false, doc = "When true, if a query only has a single distinct column and it has " + "grouping expressions, we will use our planner rule to handle this distinct " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index dfec139985f73..a4626259b2823 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -44,10 +44,10 @@ class PlannerSuite extends SharedSQLContext { fail(s"Could query play aggregation query $query. Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - // For the new aggregation code path, there will be three aggregate operator for + // For the new aggregation code path, there will be four aggregate operator for // distinct aggregations. assert( - aggregations.size == 2 || aggregations.size == 3, + aggregations.size == 2 || aggregations.size == 4, s"The plan of query $query does not have partial aggregations.") } From 8a75a3049539eeef04c0db51736e97070c162b46 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Dec 2015 21:04:52 -0800 Subject: [PATCH 1094/2089] [SPARK-12087][STREAMING] Create new JobConf for every batch in saveAsHadoopFiles The JobConf object created in `DStream.saveAsHadoopFiles` is used concurrently in multiple places: * The JobConf is updated by `RDD.saveAsHadoopFile()` before the job is launched * The JobConf is serialized as part of the DStream checkpoints. These concurrent accesses (updating in one thread, while the another thread is serializing it) can lead to concurrentModidicationException in the underlying Java hashmap using in the internal Hadoop Configuration object. The solution is to create a new JobConf in every batch, that is updated by `RDD.saveAsHadoopFile()`, while the checkpointing serializes the original JobConf. Tests to be added in #9988 will fail reliably without this patch. Keeping this patch really small to make sure that it can be added to previous branches. Author: Tathagata Das Closes #10088 from tdas/SPARK-12087. --- .../apache/spark/streaming/dstream/PairDStreamFunctions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index fb691eed27e32..2762309134eb1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -730,7 +730,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, + new JobConf(serializableConf.value)) } self.foreachRDD(saveFunc) } From 0f37d1d7ed7f6e34f98f2a3c274918de29e7a1d7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 21:51:33 -0800 Subject: [PATCH 1095/2089] [SPARK-11949][SQL] Check bitmasks to set nullable property Following up #10038. We can use bitmasks to determine which grouping expressions need to be set as nullable. cc yhuai Author: Liang-Chi Hsieh Closes #10067 from viirya/fix-cube-following. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 765327c474e69..d3163dcd4db94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -224,10 +224,15 @@ class Analyzer( case other => Alias(other, other.toString)() } - // TODO: We need to use bitmasks to determine which grouping expressions need to be - // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need - // to change the nullability of a. - val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val nonNullBitmask = x.bitmasks.reduce(_ & _) + + val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) => + if ((nonNullBitmask & 1 << idx) == 0) { + (a -> a.toAttribute.withNullability(true)) + } else { + (a -> a.toAttribute) + } + }.toMap val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change From 4375eb3f48fc7ae90caf6c21a0d3ab0b66bf4efa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Dec 2015 22:41:48 -0800 Subject: [PATCH 1096/2089] [SPARK-12090] [PYSPARK] consider shuffle in coalesce() Author: Davies Liu Closes #10090 from davies/fix_coalesce. --- python/pyspark/rdd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b4d59647b2bc..00bb9a62e904a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2015,7 +2015,7 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions) + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): From 128c29035b4e7383cc3a9a6c7a9ab6136205ac6c Mon Sep 17 00:00:00 2001 From: Jeroen Schot Date: Wed, 2 Dec 2015 09:40:07 +0000 Subject: [PATCH 1097/2089] [SPARK-3580][CORE] Add Consistent Method To Get Number of RDD Partitions Across Different Languages I have tried to address all the comments in pull request https://github.com/apache/spark/pull/2447. Note that the second commit (using the new method in all internal code of all components) is quite intrusive and could be omitted. Author: Jeroen Schot Closes #9767 from schot/master. --- .../org/apache/spark/api/java/JavaRDDLike.scala | 5 +++++ core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +++++++- .../test/java/org/apache/spark/JavaAPISuite.java | 13 +++++++++++++ .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 1 + project/MimaExcludes.scala | 4 ++++ 5 files changed, 30 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 1e9d4f1803a81..0e4d7dce0f2f5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -28,6 +28,7 @@ import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap @@ -62,6 +63,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** Set of partitions in this RDD. */ def partitions: JList[Partition] = rdd.partitions.toSeq.asJava + /** Return the number of partitions in this RDD. */ + @Since("1.6.0") + def getNumPartitions: Int = rdd.getNumPartitions + /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8b3731d935788..9fe9d83a705b2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -242,6 +242,12 @@ abstract class RDD[T: ClassTag]( } } + /** + * Returns the number of partitions of this RDD. + */ + @Since("1.6.0") + final def getNumPartitions: Int = partitions.length + /** * Get the preferred locations of a partition, taking into account whether the * RDD is checkpointed. diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4d4e9820500e7..11f1248c24d38 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -973,6 +973,19 @@ public Iterator call(Integer index, Iterator iter) { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + @Test + public void getNumPartitions(){ + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); + JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), 2); + Assert.assertEquals(3, rdd1.getNumPartitions()); + Assert.assertEquals(2, rdd2.getNumPartitions()); + Assert.assertEquals(2, rdd3.getNumPartitions()); + } @Test public void repartition() { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5f718ea9f7be1..46ed5c04f4338 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -34,6 +34,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.getNumPartitions === 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 566bfe8efb7a4..d3a3c0ceb68c8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -155,6 +155,10 @@ object MimaExcludes { "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") + ) ++ Seq( + // SPARK-3580 Add getNumPartitions method to JavaRDD + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") ) case v if v.startsWith("1.5") => Seq( From a1542ce2f33ad365ff437d2d3014b9de2f6670e5 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 2 Dec 2015 09:36:12 -0800 Subject: [PATCH 1098/2089] [SPARK-12094][SQL] Prettier tree string for TreeNode When examining plans of complex queries with multiple joins, a pain point of mine is that, it's hard to immediately see the sibling node of a specific query plan node. This PR adds tree lines for the tree string of a `TreeNode`, so that the result can be visually more intuitive. Author: Cheng Lian Closes #10099 from liancheng/prettier-tree-string. --- .../spark/sql/catalyst/trees/TreeNode.scala | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ad2bd78430a68..dfea583e01465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -393,7 +393,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. @@ -419,12 +419,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - /** Appends the string represent of this node and its children to the given StringBuilder. */ - protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { - builder.append(" " * depth) + /** + * Appends the string represent of this node and its children to the given StringBuilder. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. + */ + protected def generateTreeString( + depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + builder.append(simpleString) builder.append("\n") - children.foreach(_.generateTreeString(depth + 1, builder)) + + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + builder } From 452690ba1cc3c667bdd9f3022c43c9a10267880b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 2 Dec 2015 13:44:01 -0800 Subject: [PATCH 1099/2089] [SPARK-12001] Allow partially-stopped StreamingContext to be completely stopped If `StreamingContext.stop()` is interrupted midway through the call, the context will be marked as stopped but certain state will have not been cleaned up. Because `state = STOPPED` will be set, subsequent `stop()` calls will be unable to finish stopping the context, preventing any new StreamingContexts from being created. This patch addresses this issue by only marking the context as `STOPPED` once the `stop()` has successfully completed which allows `stop()` to be called a second time in order to finish stopping the context in case the original `stop()` call was interrupted. I discovered this issue by examining logs from a failed Jenkins run in which this race condition occurred in `FailureSuite`, leaking an unstoppable context and causing all subsequent tests to fail. Author: Josh Rosen Closes #9982 from JoshRosen/SPARK-12001. --- .../spark/streaming/StreamingContext.scala | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 6fb8ad38abcec..cf843e3e8b8ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -699,28 +699,33 @@ class StreamingContext private[streaming] ( " AsynchronousListenerBus") } synchronized { - try { - state match { - case INITIALIZED => - logWarning("StreamingContext has not been started yet") - case STOPPED => - logWarning("StreamingContext has already been stopped") - case ACTIVE => - scheduler.stop(stopGracefully) - // Removing the streamingSource to de-register the metrics on stop() - env.metricsSystem.removeSource(streamingSource) - uiTab.foreach(_.detach()) - StreamingContext.setActiveContext(null) - waiter.notifyStop() - if (shutdownHookRef != null) { - shutdownHookRefToRemove = shutdownHookRef - shutdownHookRef = null - } - logInfo("StreamingContext stopped successfully") - } - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + state = STOPPED + case STOPPED => + logWarning("StreamingContext has already been stopped") + state = STOPPED + case ACTIVE => + // It's important that we don't set state = STOPPED until the very end of this case, + // since we need to ensure that we're still able to call `stop()` to recover from + // a partially-stopped StreamingContext which resulted from this `stop()` call being + // interrupted. See SPARK-12001 for more details. Because the body of this case can be + // executed twice in the case of a partial stop, all methods called here need to be + // idempotent. + scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) + uiTab.foreach(_.detach()) + StreamingContext.setActiveContext(null) + waiter.notifyStop() + if (shutdownHookRef != null) { + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null + } + logInfo("StreamingContext stopped successfully") + state = STOPPED } } if (shutdownHookRefToRemove != null) { From de07d06abecf3516c95d099b6c01a86e0c8cfd8c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 2 Dec 2015 14:15:54 -0800 Subject: [PATCH 1100/2089] [SPARK-10266][DOCUMENTATION, ML] Fixed @Since annotation for ml.tunning cc mengxr noel-smith I worked on this issues based on https://github.com/apache/spark/pull/8729. ehsanmok thank you for your contricution! Author: Yu ISHIKAWA Author: Ehsan M.Kermani Closes #9338 from yu-iskw/JIRA-10266. --- .../spark/ml/tuning/CrossValidator.scala | 34 ++++++++++++++----- .../spark/ml/tuning/ParamGridBuilder.scala | 14 ++++++-- .../ml/tuning/TrainValidationSplit.scala | 26 +++++++++++--- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 83a9048374267..5c09f1aaff80d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -19,18 +19,18 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.{JObject, DefaultFormats} import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JObject} -import org.apache.spark.ml.classification.OneVsRestParams -import org.apache.spark.ml.feature.RFormulaModel -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ +import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -58,26 +58,34 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { * :: Experimental :: * K-fold cross validation. */ +@Since("1.2.0") @Experimental -class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] +class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) + extends Estimator[CrossValidatorModel] with CrossValidatorParams with MLWritable with Logging { + @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS /** @group setParam */ + @Since("1.2.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.2.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.2.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -116,10 +124,12 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -128,6 +138,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidator = { val copied = defaultCopy(extra).asInstanceOf[CrossValidator] if (copied.isDefined(estimator)) { @@ -308,26 +319,31 @@ object CrossValidator extends MLReadable[CrossValidator] { * @param avgMetrics Average cross-validation metrics for each paramMap in * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ +@Since("1.2.0") @Experimental class CrossValidatorModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val avgMetrics: Array[Double]) + @Since("1.4.0") override val uid: String, + @Since("1.2.0") val bestModel: Model[_], + @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + @Since("1.4.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidatorModel = { val copied = new CrossValidatorModel( uid, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index 98a8f0330ca45..b836d2a2340e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ /** * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ +@Since("1.2.0") @Experimental -class ParamGridBuilder { +class ParamGridBuilder @Since("1.2.0") { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") def baseOn(paramMap: ParamMap): this.type = { baseOn(paramMap.toSeq: _*) this @@ -43,6 +45,7 @@ class ParamGridBuilder { /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") @varargs def baseOn(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => @@ -54,6 +57,7 @@ class ParamGridBuilder { /** * Adds a param with multiple values (overwrites if the input param exists). */ + @Since("1.2.0") def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { paramGrid.put(param, values) this @@ -64,6 +68,7 @@ class ParamGridBuilder { /** * Adds a double param with multiple values. */ + @Since("1.2.0") def addGrid(param: DoubleParam, values: Array[Double]): this.type = { addGrid[Double](param, values) } @@ -71,6 +76,7 @@ class ParamGridBuilder { /** * Adds a int param with multiple values. */ + @Since("1.2.0") def addGrid(param: IntParam, values: Array[Int]): this.type = { addGrid[Int](param, values) } @@ -78,6 +84,7 @@ class ParamGridBuilder { /** * Adds a float param with multiple values. */ + @Since("1.2.0") def addGrid(param: FloatParam, values: Array[Float]): this.type = { addGrid[Float](param, values) } @@ -85,6 +92,7 @@ class ParamGridBuilder { /** * Adds a long param with multiple values. */ + @Since("1.2.0") def addGrid(param: LongParam, values: Array[Long]): this.type = { addGrid[Long](param, values) } @@ -92,6 +100,7 @@ class ParamGridBuilder { /** * Adds a boolean param with true and false. */ + @Since("1.2.0") def addGrid(param: BooleanParam): this.type = { addGrid[Boolean](param, Array(true, false)) } @@ -99,6 +108,7 @@ class ParamGridBuilder { /** * Builds and returns all combinations of parameters specified by the param grid. */ + @Since("1.2.0") def build(): Array[ParamMap] = { var paramMaps = Array(new ParamMap) paramGrid.foreach { case (param, values) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 73a14b8310157..adf06302047a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} @@ -51,24 +51,32 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { * and uses evaluation metric on the validation set to select the best model. * Similar to [[CrossValidator]], but only splits the set once. */ +@Since("1.5.0") @Experimental -class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] +class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[TrainValidationSplitModel] with TrainValidationSplitParams with Logging { + @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) /** @group setParam */ + @Since("1.5.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.5.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.5.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -108,10 +116,12 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.5.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -120,6 +130,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali } } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplit = { val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] if (copied.isDefined(estimator)) { @@ -140,26 +151,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali * @param bestModel Estimator determined best model. * @param validationMetrics Evaluated validation metrics. */ +@Since("1.5.0") @Experimental class TrainValidationSplitModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val validationMetrics: Array[Double]) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val bestModel: Model[_], + @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + @Since("1.5.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplitModel = { val copied = new TrainValidationSplitModel ( uid, From d0d7ec533062151269b300ed455cf150a69098c0 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Thu, 3 Dec 2015 08:48:49 +0800 Subject: [PATCH 1101/2089] [SPARK-12093][SQL] Fix the error of comment in DDLParser Author: Yadong Qi Closes #10096 from watermen/patch-1. --- .../apache/spark/sql/execution/datasources/DDLParser.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 6969b423d01b9..f22508b21090c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -66,15 +66,15 @@ class DDLParser(parseQuery: String => LogicalPlan) protected def start: Parser[LogicalPlan] = ddl /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * AS SELECT ... From 9bb695b7a82d837e2c7a724514ea6b203efb5364 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 2 Dec 2015 17:19:31 -0800 Subject: [PATCH 1102/2089] [SPARK-12000] do not specify arg types when reference a method in ScalaDoc This fixes SPARK-12000, verified on my local with JDK 7. It seems that `scaladoc` try to match method names and messed up with annotations. cc: JoshRosen jkbradley Author: Xiangrui Meng Closes #10114 from mengxr/SPARK-12000.2. --- .../org/apache/spark/mllib/clustering/BisectingKMeans.scala | 2 +- .../apache/spark/mllib/clustering/BisectingKMeansModel.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 29a7aa0bb63f2..82adfa6ffd596 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -214,7 +214,7 @@ class BisectingKMeans private ( } /** - * Java-friendly version of [[run(RDD[Vector])*]] + * Java-friendly version of [[run()]]. */ def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 5015f1540d920..f942e5613ffaf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -64,7 +64,7 @@ class BisectingKMeansModel @Since("1.6.0") ( } /** - * Java-friendly version of [[predict(RDD[Vector])*]] + * Java-friendly version of [[predict()]]. */ @Since("1.6.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = @@ -88,7 +88,7 @@ class BisectingKMeansModel @Since("1.6.0") ( } /** - * Java-friendly version of [[computeCost(RDD[Vector])*]]. + * Java-friendly version of [[computeCost()]]. */ @Since("1.6.0") def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) From ae402533738be06ac802914ed3e48f0d5fa54cbe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 3 Dec 2015 11:12:02 +0800 Subject: [PATCH 1103/2089] [SPARK-12082][FLAKY-TEST] Increase timeouts in NettyBlockTransferSecuritySuite We should try increasing a timeout in NettyBlockTransferSecuritySuite in order to reduce that suite's flakiness in Jenkins. Author: Josh Rosen Closes #10113 from JoshRosen/SPARK-12082. --- .../spark/network/netty/NettyBlockTransferSecuritySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 3940527fb874e..98da94139f7f8 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -148,7 +148,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } }) - Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } From ec2b6c26c9b6bd59d29b5d7af2742aca7e6e0b07 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 3 Dec 2015 11:21:24 +0800 Subject: [PATCH 1104/2089] [SPARK-12109][SQL] Expressions's simpleString should delegate to its toString. https://issues.apache.org/jira/browse/SPARK-12109 The change of https://issues.apache.org/jira/browse/SPARK-11596 exposed the problem. In the sql plan viz, the filter shows ![image](https://cloud.githubusercontent.com/assets/2072857/11547075/1a285230-9906-11e5-8481-2bb451e35ef1.png) After changes in this PR, the viz is back to normal. ![image](https://cloud.githubusercontent.com/assets/2072857/11547080/2bc570f4-9906-11e5-8897-3b3bff173276.png) Author: Yin Huai Closes #10111 from yhuai/SPARK-12109. --- .../org/apache/spark/sql/catalyst/expressions/Expression.scala | 3 ++- .../spark/sql/catalyst/expressions/windowExpressions.scala | 3 --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4ee6542455a6c..614f0c075fd23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -207,12 +207,13 @@ abstract class Expression extends TreeNode[Expression] { }.toString } - private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } + override def simpleString: String = toString + override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 09ec0e333aa44..1680aa8252ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -71,9 +71,6 @@ case class WindowSpecDefinition( childrenResolved && checkInputDataTypes().isSuccess && frameSpecification.isInstanceOf[SpecifiedWindowFrame] - - override def toString: String = simpleString - override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index dfea583e01465..d838d845d20fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -380,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil From 5349851f368a1b5dab8a99c0d51c9638ce7aec56 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 3 Dec 2015 08:42:21 +0000 Subject: [PATCH 1105/2089] =?UTF-8?q?[SPARK-12088][SQL]=20check=20connecti?= =?UTF-8?q?on.isClosed=20before=20calling=20connection=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In Java Spec java.sql.Connection, it has boolean getAutoCommit() throws SQLException Throws: SQLException - if a database access error occurs or this method is called on a closed connection So if conn.getAutoCommit is called on a closed connection, a SQLException will be thrown. Even though the code catch the SQLException and program can continue, I think we should check conn.isClosed before calling conn.getAutoCommit to avoid the unnecessary SQLException. Author: Huaxin Gao Closes #10095 from huaxingao/spark-12088. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 392d3ed58e3ce..b9dd7f6b4099b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -498,7 +498,7 @@ private[sql] class JDBCRDD( } try { if (null != conn) { - if (!conn.getAutoCommit && !conn.isClosed) { + if (!conn.isClosed && !conn.getAutoCommit) { try { conn.commit() } catch { From 7470d9edbb0a45e714c96b5d55eff30724c0653a Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 3 Dec 2015 15:36:28 +0000 Subject: [PATCH 1106/2089] [DOCUMENTATION][MLLIB] typo in mllib doc \cc mengxr Author: Jeff Zhang Closes #10093 from zjffdu/mllib_typo. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 5f888775553f2..05c2c96c5ec5a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1232,7 +1232,7 @@ lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) * `withStd`: True by default. Scales the data to unit standard deviation. * `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -`StandardScaler` is a `Model` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. +`StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. From 95b3cf125b905b4ef8705c46f2ef255377b0a9dc Mon Sep 17 00:00:00 2001 From: microwishing Date: Thu, 3 Dec 2015 16:09:05 +0000 Subject: [PATCH 1107/2089] [DOCUMENTATION][KAFKA] fix typo in kafka/OffsetRange.scala this is to fix some typo in external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala Author: microwishing Closes #10121 from microwishing/master. --- .../scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala | 2 +- .../scala/org/apache/spark/streaming/kafka/OffsetRange.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 86394ea8a685e..45a6982b9afe5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -151,7 +151,7 @@ private[kafka] class KafkaTestUtils extends Logging { } } - /** Create a Kafka topic and wait until it propagated to the whole cluster */ + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { AdminUtils.createTopic(zkClient, topic, 1, 1) // wait until metadata is propagated diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 8a5f371494511..d9b856e4697a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition /** - * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). * {{{ From 43c575cb1766b32c74db17216194a8a74119b759 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 3 Dec 2015 09:22:21 -0800 Subject: [PATCH 1108/2089] [SPARK-12116][SPARKR][DOCS] document how to workaround function name conflicts with dplyr shivaram Author: felixcheung Closes #10119 from felixcheung/rdocdplyrmasked. --- docs/sparkr.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index cfb9b41350f45..01148786b79d7 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -384,5 +384,6 @@ The following functions are masked by the SparkR package: +Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. + You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) - From 8fa3e474a8ba180188361c0ad7e2704c3e2258d3 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 3 Dec 2015 10:33:06 -0800 Subject: [PATCH 1109/2089] [SPARK-11314][YARN] add service API and test service for Yarn Cluster schedulers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is purely the yarn/src/main and yarn/src/test bits of the YARN ATS integration: the extension model to load and run implementations of `SchedulerExtensionService` in the yarn cluster scheduler process —and to stop them afterwards. There's duplication between the two schedulers, yarn-client and yarn-cluster, at least in terms of setting everything up, because the common superclass, `YarnSchedulerBackend` is in spark-core, and the extension services need the YARN app/attempt IDs. If you look at how the the extension services are loaded, the case class `SchedulerExtensionServiceBinding` is used to pass in config info -currently just the spark context and the yarn IDs, of which one, the attemptID, will be null when running client-side. I'm passing in a case class to ensure that it would be possible in future to add extra arguments to the binding class, yet, as the method signature will not have changed, still be able to load existing services. There's no functional extension service here, just one for testing. The real tests come in the bigger pull requests. At the same time, there's no restriction of this extension service purely to the ATS history publisher. Anything else that wants to listen to the spark context and publish events could use this, and I'd also consider writing one for the YARN-913 registry service, so that the URLs of the web UI would be locatable through that (low priority; would make more sense if integrated with a REST client). There's no minicluster test. Given the test execution overhead of setting up minicluster tests, it'd probably be better to add an extension service into one of the existing tests. Author: Steve Loughran Closes #9182 from steveloughran/stevel/feature/SPARK-1537-service. --- .../spark/deploy/yarn/ApplicationMaster.scala | 8 + .../cluster/SchedulerExtensionService.scala | 154 ++++++++++++++++++ .../cluster/YarnClientSchedulerBackend.scala | 22 +-- .../cluster/YarnClusterSchedulerBackend.scala | 20 +-- .../cluster/YarnSchedulerBackend.scala | 70 +++++++- .../ExtensionServiceIntegrationSuite.scala | 71 ++++++++ .../cluster/SimpleExtensionService.scala | 34 ++++ .../cluster/StubApplicationAttemptId.scala | 48 ++++++ .../scheduler/cluster/StubApplicationId.scala | 42 +++++ 9 files changed, 431 insertions(+), 38 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala rename {core => yarn}/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala (81%) create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 50ae7ffeec4c5..13ef4dfd64165 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -117,6 +117,10 @@ private[spark] class ApplicationMaster( private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + def getAttemptId(): ApplicationAttemptId = { + client.getAttemptId() + } + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -662,6 +666,10 @@ object ApplicationMaster extends Logging { master.sparkContextStopped(sc) } + private[spark] def getAttemptId(): ApplicationAttemptId = { + master.getAttemptId + } + } /** diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala new file mode 100644 index 0000000000000..c064521845399 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.util.Utils + +/** + * An extension service that can be loaded into a Spark YARN scheduler. + * A Service that can be started and stopped. + * + * 1. For implementations to be loadable by `SchedulerExtensionServices`, + * they must provide an empty constructor. + * 2. The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ +trait SchedulerExtensionService { + + /** + * Start the extension service. This should be a no-op if + * called more than once. + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit + + /** + * Stop the service + * The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ + def stop(): Unit +} + +/** + * Binding information for a [[SchedulerExtensionService]]. + * + * The attempt ID will be set if the service is started within a YARN application master; + * there is then a different attempt ID for every time that AM is restarted. + * When the service binding is instantiated in client mode, there's no attempt ID, as it lacks + * this information. + * @param sparkContext current spark context + * @param applicationId YARN application ID + * @param attemptId YARN attemptID. This will always be unset in client mode, and always set in + * cluster mode. + */ +case class SchedulerExtensionServiceBinding( + sparkContext: SparkContext, + applicationId: ApplicationId, + attemptId: Option[ApplicationAttemptId] = None) + +/** + * Container for [[SchedulerExtensionService]] instances. + * + * Loads Extension Services from the configuration property + * `"spark.yarn.services"`, instantiates and starts them. + * When stopped, it stops all child entries. + * + * The order in which child extension services are started and stopped + * is undefined. + */ +private[spark] class SchedulerExtensionServices extends SchedulerExtensionService + with Logging { + private var serviceOption: Option[String] = None + private var services: List[SchedulerExtensionService] = Nil + private val started = new AtomicBoolean(false) + private var binding: SchedulerExtensionServiceBinding = _ + + /** + * Binding operation will load the named services and call bind on them too; the + * entire set of services are then ready for `init()` and `start()` calls. + * + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit = { + if (started.getAndSet(true)) { + logWarning("Ignoring re-entrant start operation") + return + } + require(binding.sparkContext != null, "Null context parameter") + require(binding.applicationId != null, "Null appId parameter") + this.binding = binding + val sparkContext = binding.sparkContext + val appId = binding.applicationId + val attemptId = binding.attemptId + logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId") + + serviceOption = sparkContext.getConf.getOption(SchedulerExtensionServices.SPARK_YARN_SERVICES) + services = serviceOption + .map { s => + s.split(",").map(_.trim()).filter(!_.isEmpty) + .map { sClass => + val instance = Utils.classForName(sClass) + .newInstance() + .asInstanceOf[SchedulerExtensionService] + // bind this service + instance.start(binding) + logInfo(s"Service $sClass started") + instance + }.toList + }.getOrElse(Nil) + } + + /** + * Get the list of services. + * + * @return a list of services; Nil until the service is started + */ + def getServices: List[SchedulerExtensionService] = services + + /** + * Stop the services; idempotent. + * + */ + override def stop(): Unit = { + if (started.getAndSet(false)) { + logInfo(s"Stopping $this") + services.foreach { s => + Utils.tryLogNonFatalError(s.stop()) + } + } + } + + override def toString(): String = s"""SchedulerExtensionServices + |(serviceOption=$serviceOption, + | services=$services, + | started=$started)""".stripMargin +} + +private[spark] object SchedulerExtensionServices { + + /** + * A list of comma separated services to instantiate in the scheduler + */ + val SPARK_YARN_SERVICES = "spark.yarn.services" +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 20771f655473c..0e27a2665e939 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} @@ -33,7 +33,6 @@ private[spark] class YarnClientSchedulerBackend( with Logging { private var client: Client = null - private var appId: ApplicationId = null private var monitorThread: MonitorThread = null /** @@ -54,13 +53,12 @@ private[spark] class YarnClientSchedulerBackend( val args = new ClientArguments(argsArrayBuf.toArray, conf) totalExpectedExecutors = args.numExecutors client = new Client(args, conf) - appId = client.submitApplication() + bindToYarn(client.submitApplication(), None) // SPARK-8687: Ensure all necessary properties have already been set before // we initialize our driver scheduler backend, which serves these properties // to the executors super.start() - waitForApplication() // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver @@ -116,8 +114,8 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def waitForApplication(): Unit = { - assert(client != null && appId != null, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking + assert(client != null && appId.isDefined, "Application has not been submitted yet!") + val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true) // blocking if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -125,7 +123,7 @@ private[spark] class YarnClientSchedulerBackend( "It might have been killed or unable to launch application master.") } if (state == YarnApplicationState.RUNNING) { - logInfo(s"Application $appId has started running.") + logInfo(s"Application ${appId.get} has started running.") } } @@ -141,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) logError(s"Yarn application has already exited with state $state!") allowInterrupt = false sc.stop() @@ -163,7 +161,7 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def asyncMonitorApplication(): MonitorThread = { - assert(client != null && appId != null, "Application has not been submitted yet!") + assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) @@ -193,10 +191,4 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } - override def applicationId(): String = { - Option(appId).map(_.toString).getOrElse { - logWarning("Application ID is not initialized yet.") - super.applicationId - } - } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 50b699f11b21c..ced597bed36d9 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -31,26 +31,12 @@ private[spark] class YarnClusterSchedulerBackend( extends YarnSchedulerBackend(scheduler, sc) { override def start() { + val attemptId = ApplicationMaster.getAttemptId + bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) } - override def applicationId(): String = - // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.id").getOrElse { - logError("Application ID is not set.") - super.applicationId - } - - override def applicationAttemptId(): Option[String] = - // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.attemptId").orElse { - logError("Application attempt ID is not set.") - super.applicationAttemptId - } - override def getDriverLogUrls: Option[Map[String, String]] = { var driverLogs: Option[Map[String, String]] = None try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala similarity index 81% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala rename to yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 80da37b09b590..e3dd87798f018 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,17 +17,17 @@ package org.apache.spark.scheduler.cluster -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} +import scala.util.control.NonFatal + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{ThreadUtils, RpcUtils} - -import scala.util.control.NonFatal +import org.apache.spark.util.{RpcUtils, ThreadUtils} /** * Abstract Yarn scheduler backend that contains common logic @@ -51,6 +51,64 @@ private[spark] abstract class YarnSchedulerBackend( private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) + /** Application ID. */ + protected var appId: Option[ApplicationId] = None + + /** Attempt ID. This is unset for client-mode schedulers */ + private var attemptId: Option[ApplicationAttemptId] = None + + /** Scheduler extension services. */ + private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + + /** + * Bind to YARN. This *must* be done before calling [[start()]]. + * + * @param appId YARN application ID + * @param attemptId Optional YARN attempt ID + */ + protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = { + this.appId = Some(appId) + this.attemptId = attemptId + } + + override def start() { + require(appId.isDefined, "application ID unset") + val binding = SchedulerExtensionServiceBinding(sc, appId.get, attemptId) + services.start(binding) + super.start() + } + + override def stop(): Unit = { + try { + super.stop() + } finally { + services.stop() + } + } + + /** + * Get the attempt ID for this run, if the cluster manager supports multiple + * attempts. Applications run in client mode will not have attempt IDs. + * + * @return The application attempt id, if available. + */ + override def applicationAttemptId(): Option[String] = { + attemptId.map(_.toString) + } + + /** + * Get an application ID associated with the job. + * This returns the string value of [[appId]] if set, otherwise + * the locally-generated ID from the superclass. + * @return The application ID + */ + override def applicationId(): String = { + appId.map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } + /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala new file mode 100644 index 0000000000000..b4d1b0a3d22a7 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{LocalSparkContext, Logging, SparkConf, SparkContext, SparkFunSuite} + +/** + * Test the integration with [[SchedulerExtensionServices]] + */ +class ExtensionServiceIntegrationSuite extends SparkFunSuite + with LocalSparkContext with BeforeAndAfter + with Logging { + + val applicationId = new StubApplicationId(0, 1111L) + val attemptId = new StubApplicationAttemptId(applicationId, 1) + + /* + * Setup phase creates the spark context + */ + before { + val sparkConf = new SparkConf() + sparkConf.set(SchedulerExtensionServices.SPARK_YARN_SERVICES, + classOf[SimpleExtensionService].getName()) + sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite") + sc = new SparkContext(sparkConf) + } + + test("Instantiate") { + val services = new SchedulerExtensionServices() + assertResult(Nil, "non-nil service list") { + services.getServices + } + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + services.stop() + } + + test("Contains SimpleExtensionService Service") { + val services = new SchedulerExtensionServices() + try { + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + val serviceList = services.getServices + assert(serviceList.nonEmpty, "empty service list") + val (service :: Nil) = serviceList + val simpleService = service.asInstanceOf[SimpleExtensionService] + assert(simpleService.started.get, "service not started") + services.stop() + assert(!simpleService.started.get, "service not stopped") + } finally { + services.stop() + } + } +} + + diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala new file mode 100644 index 0000000000000..9b8c98cda8da8 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +private[spark] class SimpleExtensionService extends SchedulerExtensionService { + + /** started flag; set in the `start()` call, stopped in `stop()`. */ + val started = new AtomicBoolean(false) + + override def start(binding: SchedulerExtensionServiceBinding): Unit = { + started.set(true) + } + + override def stop(): Unit = { + started.set(false) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala new file mode 100644 index 0000000000000..4b57b9509a655 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +/** + * A stub application ID; can be set in constructor and/or updated later. + * @param applicationId application ID + * @param attempt an attempt counter + */ +class StubApplicationAttemptId(var applicationId: ApplicationId, var attempt: Int) + extends ApplicationAttemptId { + + override def setApplicationId(appID: ApplicationId): Unit = { + applicationId = appID + } + + override def getAttemptId: Int = { + attempt + } + + override def setAttemptId(attemptId: Int): Unit = { + attempt = attemptId + } + + override def getApplicationId: ApplicationId = { + applicationId + } + + override def build(): Unit = { + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala new file mode 100644 index 0000000000000..bffa0e09befd2 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.ApplicationId + +/** + * Simple Testing Application Id; ID and cluster timestamp are set in constructor + * and cannot be updated. + * @param id app id + * @param clusterTimestamp timestamp + */ +private[spark] class StubApplicationId(id: Int, clusterTimestamp: Long) extends ApplicationId { + override def getId: Int = { + id + } + + override def getClusterTimestamp: Long = { + clusterTimestamp + } + + override def setId(id: Int): Unit = {} + + override def setClusterTimestamp(clusterTimestamp: Long): Unit = {} + + override def build(): Unit = {} +} From 7bc9e1db2c47387ee693bcbeb4a8a2cbe11909cf Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 Dec 2015 11:05:12 -0800 Subject: [PATCH 1110/2089] [SPARK-12059][CORE] Avoid assertion error when unexpected state transition met in Master Downgrade to warning log for unexpected state transition. andrewor14 please review, thanks a lot. Author: jerryshao Closes #10091 from jerryshao/SPARK-12059. --- .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1355e1ad1b523..04b20e0d6ab9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -257,8 +257,9 @@ private[deploy] class Master( exec.state = state if (state == ExecutorState.RUNNING) { - assert(oldState == ExecutorState.LAUNCHING, - s"executor $execId state transfer from $oldState to RUNNING is illegal") + if (oldState != ExecutorState.LAUNCHING) { + logWarning(s"Executor $execId state transfer from $oldState to RUNNING is unexpected") + } appInfo.resetRetryCount() } From 649be4fa4532dcd3001df8345f9f7e970a3fbc65 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 3 Dec 2015 11:06:25 -0800 Subject: [PATCH 1111/2089] [SPARK-12101][CORE] Fix thread pools that cannot cache tasks in Worker and AppClient `SynchronousQueue` cannot cache any task. This issue is similar to #9978. It's an easy fix. Just use the fixed `ThreadUtils.newDaemonCachedThreadPool`. Author: Shixiong Zhu Closes #10108 from zsxwing/fix-threadpool. --- .../org/apache/spark/deploy/client/AppClient.scala | 10 ++++------ .../org/apache/spark/deploy/worker/Worker.scala | 10 ++++------ .../apache/spark/deploy/yarn/YarnAllocator.scala | 14 ++++---------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index df6ba7d669ce9..1e2f469214b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -68,12 +68,10 @@ private[spark] class AppClient( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. - private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.length, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "appclient-register-master-threadpool", + masterRpcAddresses.length // Make sure we can register with all masters at the same time + ) // A scheduled executor for scheduling the registration actions private val registrationRetryThread = diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 418faf8fc967f..1afc1ff59f2f9 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -146,12 +146,10 @@ private[deploy] class Worker( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. - private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "worker-register-master-threadpool", + masterRpcAddresses.size // Make sure we can register with all masters at the same time + ) var coresUsed = 0 var memoryUsed = 0 diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 73cd9031f0250..4e044aa4788da 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -25,8 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.JavaConverters._ -import com.google.common.util.concurrent.ThreadFactoryBuilder - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient @@ -40,7 +38,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -117,13 +115,9 @@ private[yarn] class YarnAllocator( // Resource capability requested for each executors private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) - private val launcherPool = new ThreadPoolExecutor( - // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue - sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, - 1, TimeUnit.MINUTES, - new LinkedBlockingQueue[Runnable](), - new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) - launcherPool.allowCoreThreadTimeOut(true) + private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( + "ContainerLauncher", + sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25)) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) From 688e521c2833a00069272a6749153d721a0996f6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 3 Dec 2015 11:09:29 -0800 Subject: [PATCH 1112/2089] [SPARK-12108] Make event logs smaller **Problem.** Event logs in 1.6 were much bigger than 1.5. I ran page rank and the event log size in 1.6 was almost 5x that in 1.5. I did a bisect to find that the RDD callsite added in #9398 is largely responsible for this. **Solution.** This patch removes the long form of the callsite (which is not used!) from the event log. This reduces the size of the event log significantly. *Note on compatibility*: if this patch is to be merged into 1.6.0, then it won't break any compatibility. Otherwise, if it is merged into 1.6.1, then we might need to add more backward compatibility handling logic (currently does not exist yet). Author: Andrew Or Closes #10115 from andrewor14/smaller-event-logs. --- .../org/apache/spark/storage/RDDInfo.scala | 4 +-- .../spark/ui/scope/RDDOperationGraph.scala | 4 +-- .../org/apache/spark/util/JsonProtocol.scala | 17 ++------- .../apache/spark/util/JsonProtocolSuite.scala | 35 ++++++++----------- 4 files changed, 20 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 87c1b981e7e13..94e8559bd2e91 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -28,7 +28,7 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], - val callSite: CallSite = CallSite.empty, + val callSite: String = "", val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { @@ -58,6 +58,6 @@ private[spark] object RDDInfo { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite, rdd.scope) + rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 24274562657b3..e9c8a8e299cd7 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -39,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: CallSite) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -178,7 +178,7 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - val label = s"${node.name} [${node.id}]\n${node.callsite.shortForm}" + val label = s"${node.name} [${node.id}]\n${node.callsite}" s"""${node.id} [label="$label"]""" } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c9beeb25e05af..2d2bd90eb339e 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -398,7 +398,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ - ("Callsite" -> callsiteToJson(rddInfo.callSite)) ~ + ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -408,11 +408,6 @@ private[spark] object JsonProtocol { ("Disk Size" -> rddInfo.diskSize) } - def callsiteToJson(callsite: CallSite): JValue = { - ("Short Form" -> callsite.shortForm) ~ - ("Long Form" -> callsite.longForm) - } - def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ @@ -857,9 +852,7 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite") - .map(callsiteFromJson) - .getOrElse(CallSite.empty) + val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -880,12 +873,6 @@ private[spark] object JsonProtocol { rddInfo } - def callsiteFromJson(json: JValue): CallSite = { - val shortForm = (json \ "Short Form").extract[String] - val longForm = (json \ "Long Form").extract[String] - CallSite(shortForm, longForm) - } - def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3f94ef7041914..1939ce5c743b0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -111,7 +111,6 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) - testCallsite(CallSite("happy", "birthday")) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( @@ -343,13 +342,13 @@ class JsonProtocolSuite extends SparkFunSuite { // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 // "Callsite" was introduced in Spark 1.6.0 val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), - CallSite("short", "long"), Some(new RDDOperationScope("fable"))) + "callsite", Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) .removeField({ _._1 == "Callsite"}) val expectedRddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq.empty, CallSite.empty, scope = None) + 1, "one", 100, StorageLevel.NONE, Seq.empty, "", scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -397,11 +396,6 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } - private def testCallsite(callsite: CallSite): Unit = { - val newCallsite = JsonProtocol.callsiteFromJson(JsonProtocol.callsiteToJson(callsite)) - assert(callsite === newCallsite) - } - private def testStageInfo(info: StageInfo) { val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info)) assertEquals(info, newInfo) @@ -726,8 +720,7 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, - Seq(1, 4, 7), CallSite(a.toString, b.toString)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7), a.toString) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -870,7 +863,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", - | "Callsite": {"Short Form": "101", "Long Form": "201"}, + | "Callsite": "101", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1273,7 +1266,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", - | "Callsite": {"Short Form": "1", "Long Form": "200"}, + | "Callsite": "1", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1317,7 +1310,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", - | "Callsite": {"Short Form": "2", "Long Form": "400"}, + | "Callsite": "2", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1335,7 +1328,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", - | "Callsite": {"Short Form": "3", "Long Form": "401"}, + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1379,7 +1372,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", - | "Callsite": {"Short Form": "3", "Long Form": "600"}, + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1397,7 +1390,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", - | "Callsite": {"Short Form": "4", "Long Form": "601"}, + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1415,7 +1408,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", - | "Callsite": {"Short Form": "5", "Long Form": "602"}, + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1459,7 +1452,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", - | "Callsite": {"Short Form": "4", "Long Form": "800"}, + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1477,7 +1470,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", - | "Callsite": {"Short Form": "5", "Long Form": "801"}, + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1495,7 +1488,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", - | "Callsite": {"Short Form": "6", "Long Form": "802"}, + | "Callsite": "6", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1513,7 +1506,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", - | "Callsite": {"Short Form": "7", "Long Form": "803"}, + | "Callsite": "7", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, From d576e76bbaa818480d31d2b8fbbe4b15718307d9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 3 Dec 2015 11:37:34 -0800 Subject: [PATCH 1113/2089] [MINOR][ML] Use coefficients replace weights Use ```coefficients``` replace ```weights```, I wish they are the last two. mengxr Author: Yanbo Liang Closes #10065 from yanboliang/coefficients. --- python/pyspark/ml/classification.py | 2 +- python/pyspark/ml/regression.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4a2982e2047ff..5599b8f3ecd88 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -49,7 +49,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) - >>> model.weights + >>> model.coefficients DenseVector([5.5...]) >>> model.intercept -2.68... diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 944e648ec8801..a0bb8ceed8861 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -40,7 +40,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ This support multiple types of regularization: - none (a.k.a. ordinary least squares) From ad7cea6f776a39801d6bb5bb829d1800b175b2ab Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Thu, 3 Dec 2015 11:59:10 -0800 Subject: [PATCH 1114/2089] [SPARK-12107][EC2] Update spark-ec2 versions I haven't created a JIRA. If we absolutely need one I'll do it, but I'm fine with not getting mentioned in the release notes if that's the only purpose it'll serve. cc marmbrus - We should include this in 1.6-RC2 if there is one. I can open a second PR against branch-1.6 if necessary. Author: Nicholas Chammas Closes #10109 from nchammas/spark-ec2-versions. --- ec2/spark_ec2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 84a950c9f6529..19d5980560fef 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.5.0" +SPARK_EC2_VERSION = "1.6.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -72,7 +72,10 @@ "1.3.1", "1.4.0", "1.4.1", - "1.5.0" + "1.5.0", + "1.5.1", + "1.5.2", + "1.6.0", ]) SPARK_TACHYON_MAP = { @@ -87,7 +90,10 @@ "1.3.1": "0.5.0", "1.4.0": "0.6.4", "1.4.1": "0.6.4", - "1.5.0": "0.7.1" + "1.5.0": "0.7.1", + "1.5.1": "0.7.1", + "1.5.2": "0.7.1", + "1.6.0": "0.8.2", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION From a02d47277379e1e82d0ee41b2205434f9ffbc3e5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 3 Dec 2015 12:00:09 -0800 Subject: [PATCH 1115/2089] [FLAKY-TEST-FIX][STREAMING][TEST] Make sure StreamingContexts are shutdown after test Author: Tathagata Das Closes #10124 from tdas/InputStreamSuite-flaky-test. --- .../spark/streaming/InputStreamsSuite.scala | 122 +++++++++--------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 047e38ef90998..3a3176b91b1ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -206,28 +206,28 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val numTotalRecords = numThreads * numRecordsPerThread val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false - - // set up the network stream using the test receiver - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.receiverStream[Int](testReceiver) - val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] - val outputStream = new TestOutputStream(countStream, outputBuffer) def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Let the data from the receiver be received - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val startTime = System.currentTimeMillis() - while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && - System.currentTimeMillis() - startTime < 5000) { - Thread.sleep(100) - clock.advance(batchDuration.milliseconds) + + // set up the network stream using the test receiver + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val networkStream = ssc.receiverStream[Int](testReceiver) + val countStream = networkStream.count + + val outputStream = new TestOutputStream(countStream, outputBuffer) + outputStream.register() + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while ((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") @@ -239,30 +239,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = true") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = true) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - val inputIterator = input.toIterator - for (i <- 0 until input.size) { - // Enqueue more than 1 item per tick but they should dequeue one at a time - inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = true) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + val inputIterator = input.toIterator + for (i <- 0 until input.size) { + // Enqueue more than 1 item per tick but they should dequeue one at a time + inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") @@ -282,33 +282,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = false") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) - // Enqueue the first 3 items (one by one), they should be merged in the next batch - val inputIterator = input.toIterator - inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - - // Enqueue the remaining items (again one by one), merged in the final batch - inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = false) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + // Enqueue the first 3 items (one by one), they should be merged in the next batch + val inputIterator = input.toIterator + inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + + // Enqueue the remaining items (again one by one), merged in the final batch + inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + } // Verify whether data received was as expected logInfo("--------------------------------") From 2213441e5e0fba01e05826257604aa427cdf2598 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 3 Dec 2015 13:25:20 -0800 Subject: [PATCH 1116/2089] [SPARK-12019][SPARKR] Support character vector for sparkR.init(), check param and fix doc and add tests. Spark submit expects comma-separated list Author: felixcheung Closes #10034 from felixcheung/sparkrinitdoc. --- R/pkg/R/client.R | 10 ++++-- R/pkg/R/sparkR.R | 56 ++++++++++++++++++++++----------- R/pkg/R/utils.R | 5 +++ R/pkg/inst/tests/test_client.R | 9 ++++++ R/pkg/inst/tests/test_context.R | 20 ++++++++++++ 5 files changed, 79 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index c811d1dac3bd5..25e99390a9c89 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -44,12 +44,16 @@ determineSparkSubmitBin <- function() { } generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + jars <- paste0(jars, collapse = ",") if (jars != "") { - jars <- paste("--jars", jars) + # construct the jars argument with a space between --jars and comma-separated values + jars <- paste0("--jars ", jars) } - if (!identical(packages, "")) { - packages <- paste("--packages", packages) + packages <- paste0(packages, collapse = ",") + if (packages != "") { + # construct the packages argument with a space between --packages and comma-separated values + packages <- paste0("--packages ", packages) } combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 7ff3fa628b9ca..d2bfad553104f 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -86,13 +86,13 @@ sparkR.stop <- function() { #' and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. #' -#' @param master The Spark master URL. +#' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory -#' @param sparkEnvir Named list of environment variables to set on worker nodes. -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. -#' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkPackages Character string vector of packages from spark-packages.org +#' @param sparkEnvir Named list of environment variables to set on worker nodes +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors +#' @param sparkJars Character vector of jar files to pass to the worker nodes +#' @param sparkPackages Character vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -102,7 +102,9 @@ sparkR.stop <- function() { #' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("jarfile1.jar","jarfile2.jar")) +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1", +#' "com.databricks:spark-csv_2.10:1.3.0")) #'} sparkR.init <- function( @@ -120,15 +122,8 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - jars <- suppressWarnings(normalizePath(as.character(sparkJars))) - - # Classpath separator is ";" on Windows - # URI needs four /// as from http://stackoverflow.com/a/18522792 - if (.Platform$OS.type == "unix") { - uriSep <- "//" - } else { - uriSep <- "////" - } + jars <- processSparkJars(sparkJars) + packages <- processSparkPackages(sparkPackages) sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) @@ -145,7 +140,7 @@ sparkR.init <- function( sparkHome = sparkHome, jars = jars, sparkSubmitOpts = submitOps, - packages = sparkPackages) + packages = packages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -195,8 +190,14 @@ sparkR.init <- function( paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } - nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- lapply(nonEmptyJars, + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + uriSep <- "//" + } else { + uriSep <- "////" + } + localJarPaths <- lapply(jars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -366,3 +367,22 @@ getClientModeSparkSubmitOpts <- function(submitOps, sparkEnvirMap) { # --option must be before the application class "sparkr-shell" in submitOps paste0(paste0(envirToOps, collapse = ""), submitOps) } + +# Utility function that handles sparkJars argument, and normalize paths +processSparkJars <- function(jars) { + splittedJars <- splitString(jars) + if (length(splittedJars) > length(jars)) { + warning("sparkJars as a comma-separated string is deprecated, use character vector instead") + } + normalized <- suppressWarnings(normalizePath(splittedJars)) + normalized +} + +# Utility function that handles sparkPackages argument +processSparkPackages <- function(packages) { + splittedPackages <- splitString(packages) + if (length(splittedPackages) > length(packages)) { + warning("sparkPackages as a comma-separated string is deprecated, use character vector instead") + } + splittedPackages +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 45c77a86c9582..43105aaa38424 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -636,3 +636,8 @@ assignNewEnv <- function(data) { } env } + +# Utility function to split by ',' and whitespace, remove empty tokens +splitString <- function(input) { + Filter(nzchar, unlist(strsplit(input, ",|\\s"))) +} diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 8a20991f89af8..a0664f32f31c1 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -34,3 +34,12 @@ test_that("no package specified doesn't add packages flag", { test_that("multiple packages don't produce a warning", { expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) }) + +test_that("sparkJars sparkPackages as character vectors", { + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", + c("com.databricks:spark-avro_2.10:2.0.1", + "com.databricks:spark-csv_2.10:1.3.0")) + expect_match(args, "--jars one.jar,two.jar,three.jar") + expect_match(args, + "--packages com.databricks:spark-avro_2.10:2.0.1,com.databricks:spark-csv_2.10:1.3.0") +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index 80c1b89a4c627..1707e314beff5 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -92,3 +92,23 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli " --driver-memory 4g sparkr-shell2")) # nolint end }) + +test_that("sparkJars sparkPackages as comma-separated strings", { + expect_warning(processSparkJars(" a, b ")) + jars <- suppressWarnings(processSparkJars(" a, b ")) + expect_equal(jars, c("a", "b")) + + jars <- suppressWarnings(processSparkJars(" abc ,, def ")) + expect_equal(jars, c("abc", "def")) + + jars <- suppressWarnings(processSparkJars(c(" abc ,, def ", "", "xyz", " ", "a,b"))) + expect_equal(jars, c("abc", "def", "xyz", "a", "b")) + + p <- processSparkPackages(c("ghi", "lmn")) + expect_equal(p, c("ghi", "lmn")) + + # check normalizePath + f <- dir()[[1]] + expect_that(processSparkJars(f), not(gives_warning())) + expect_match(processSparkJars(f), f) +}) From f434f36d508eb4dcade70871611fc022ae0feb56 Mon Sep 17 00:00:00 2001 From: Anderson de Andrade Date: Thu, 3 Dec 2015 16:37:00 -0800 Subject: [PATCH 1117/2089] [SPARK-12056][CORE] Create a TaskAttemptContext only after calling setConf. TaskAttemptContext's constructor will clone the configuration instead of referencing it. Calling setConf after creating TaskAttemptContext makes any changes to the configuration made inside setConf unperceived by RecordReader instances. As an example, Titan's InputFormat will change conf when calling setConf. They wrap their InputFormat around Cassandra's ColumnFamilyInputFormat, and append Cassandra's configuration. This change fixes the following error when using Titan's CassandraInputFormat with Spark: *java.lang.RuntimeException: org.apache.thrift.protocol.TProtocolException: Required field 'keyspace' was not present! Struct: set_key space_args(keyspace:null)* There's a discussion of this error here: https://groups.google.com/forum/#!topic/aureliusgraphs/4zpwyrYbGAE Author: Anderson de Andrade Closes #10046 from adeandrade/newhadooprdd-fix. --- core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index d1960990da0fe..86f38ae836b2b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -138,14 +138,14 @@ class NewHadoopRDD[K, V]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) From b6e9963ee4bf0ffb62c8e9829a551bcdc31e12e3 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 3 Dec 2015 16:39:12 -0800 Subject: [PATCH 1118/2089] [SPARK-11206] Support SQL UI on the history server (resubmit) Resubmit #9297 and #9991 On the live web UI, there is a SQL tab which provides valuable information for the SQL query. But once the workload is finished, we won't see the SQL tab on the history server. It will be helpful if we support SQL UI on the history server so we can analyze it even after its execution. To support SQL UI on the history server: 1. I added an onOtherEvent method to the SparkListener trait and post all SQL related events to the same event bus. 2. Two SQL events SparkListenerSQLExecutionStart and SparkListenerSQLExecutionEnd are defined in the sql module. 3. The new SQL events are written to event log using Jackson. 4. A new trait SparkHistoryListenerFactory is added to allow the history server to feed events to the SQL history listener. The SQL implementation is loaded at runtime using java.util.ServiceLoader. Author: Carson Wang Closes #10061 from carsonwang/SqlHistoryUI. --- .rat-excludes | 1 + .../org/apache/spark/JavaSparkListener.java | 3 + .../apache/spark/SparkFirehoseListener.java | 4 + .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 24 ++- .../spark/scheduler/SparkListenerBus.scala | 1 + .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 + .../org/apache/spark/sql/SQLContext.scala | 19 ++- .../spark/sql/execution/SQLExecution.scala | 24 +-- .../spark/sql/execution/SparkPlanInfo.scala | 46 ++++++ .../sql/execution/metric/SQLMetricInfo.scala | 30 ++++ .../sql/execution/metric/SQLMetrics.scala | 56 ++++--- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++++++++------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 44 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 + 21 files changed, 329 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 08fba6d351d6a..7262c960ed6bb 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,4 +82,5 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b88..23bc9a2e81727 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba6063..e6b24afd88ad4 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528cf..eaa07acc5132e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f1..075a7f13172de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. + */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aad..95722a07144ec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4608bce202ec8..8da6884a38535 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -154,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2d2bd90eb339e..cb0f1bf79f3d5 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -506,6 +511,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 0000000000000..507100be90967 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4e26250868374..db286ea8700b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1245,6 +1245,7 @@ class SQLContext private[sql]( sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { SQLContext.clearInstantiatedContext() + SQLContext.clearSqlListener() } }) @@ -1272,6 +1273,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1316,6 +1319,10 @@ object SQLContext { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1364,9 +1371,13 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c94..34971986261c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 0000000000000..486ce34064e43 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 0000000000000..2708219ad3485 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e95..6c0f6f8a52dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,21 +104,39 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +144,25 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e74d6fb396e1c..c74ad40406992 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.Node import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6a..e19a1e3e5851f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -118,7 +141,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) } } @@ -140,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics, + taskEnd.taskMetrics.accumulatorUpdates(), finishTask = true) } @@ -148,15 +172,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Map[Long, Any], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +195,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +206,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +214,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -289,6 +312,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc6..4f50b2ecdc8f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3fe..7af0ff09c5c6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) nodes += node - val childrenNodes = plan.children.map( + val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82867ab4967bb..4f2cad19bfb6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,6 +26,7 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -82,7 +83,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096c..12a4e1356fed0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,7 +82,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +91,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -206,7 +207,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +221,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +251,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +274,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -288,19 +292,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -338,6 +343,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { + SQLContext.clearSqlListener() val sqlContext = new SQLContext(sc) import sqlContext.implicits._ // Run 100 successful executions and 100 failed executions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62ed..e7b376548787c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,6 +42,7 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From 5011f264fb53705c528250bd055acbc2eca2baaa Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 3 Dec 2015 21:11:10 -0800 Subject: [PATCH 1119/2089] [SPARK-12104][SPARKR] collect() does not handle multiple columns with same name. Author: Sun Rui Closes #10118 from sun-rui/SPARK-12104. --- R/pkg/R/DataFrame.R | 8 ++++---- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a82ded9c51fac..81b4e6b91d8a2 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -822,21 +822,21 @@ setMethod("collect", # Get a column of complex type returns a list. # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] - colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[colName]] <- col + df[[colIndex]] <- col } else { colType <- dtypes[[colIndex]][[2]] # Note that "binary" columns behave like complex types. if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") - df[[colName]] <- vec + df[[colIndex]] <- vec } else { - df[[colName]] <- col + df[[colIndex]] <- col } } } + names(df) <- names(x) df } }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92ec82096c6df..1e7cb54099703 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -530,6 +530,11 @@ test_that("collect() returns a data.frame", { expect_equal(names(rdf)[1], "age") expect_equal(nrow(rdf), 0) expect_equal(ncol(rdf), 2) + + # collect() correctly handles multiple columns with same name + df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + ldf <- collect(df) + expect_equal(names(ldf), c("name", "name")) }) test_that("limit() returns DataFrame with the correct number of rows", { @@ -1197,6 +1202,7 @@ test_that("join() and merge() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) expect_equal(count(joined), 12) + expect_equal(names(collect(joined)), c("age", "name", "name", "test")) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) From 4106d80fb6a16713a6cd2f15ab9d60f2527d9be5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 Dec 2015 01:42:29 -0800 Subject: [PATCH 1120/2089] [SPARK-12122][STREAMING] Prevent batches from being submitted twice after recovering StreamingContext from checkpoint Author: Tathagata Das Closes #10127 from tdas/SPARK-12122. --- .../org/apache/spark/streaming/scheduler/JobGenerator.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2de035d166e7b..8dfdc1f57b403 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -220,7 +220,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + val timesToReschedule = (pendingTimes ++ downTimes).filter { _ < restartTime } + .distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach { time => From 17e4e021ae7fdf5e4dd05a0473faa529e3e80dbb Mon Sep 17 00:00:00 2001 From: kaklakariada Date: Fri, 4 Dec 2015 14:43:16 +0000 Subject: [PATCH 1121/2089] Add links howto to setup IDEs for developing spark These links make it easier for new developers to work with Spark in their IDE. Author: kaklakariada Closes #10104 from kaklakariada/readme-developing-ide-gettting-started. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c0d6a946035a9..d5804d1a20b43 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). +For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) +and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). ## Interactive Scala Shell From 95296d9b1ad1d9e9396d7dfd0015ef27ce1cf341 Mon Sep 17 00:00:00 2001 From: Nong Date: Fri, 4 Dec 2015 10:01:20 -0800 Subject: [PATCH 1122/2089] [SPARK-12089] [SQL] Fix memory corrupt due to freeing a page being referenced When the spillable sort iterator was spilled, it was mistakenly keeping the last page in memory rather than the current page. This causes the current record to get corrupted. Author: Nong Closes #10142 from nongli/spark-12089. --- .../util/collection/unsafe/sort/UnsafeExternalSorter.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5a97f4f11340c..79d74b23ceaef 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -443,6 +443,7 @@ public long spill() throws IOException { UnsafeInMemorySorter.SortedIterator inMemIterator = ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); while (inMemIterator.hasNext()) { @@ -458,9 +459,11 @@ public long spill() throws IOException { long released = 0L; synchronized (UnsafeExternalSorter.this) { - // release the pages except the one that is used + // release the pages except the one that is used. There can still be a caller that + // is accessing the current record. We free this page in that caller's next loadNext() + // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) { + if (!loaded || page.getBaseObject() != upstream.getBaseObject()) { released += page.size(); freePage(page); } else { From d0d82227785dcd6c49a986431c476c7838a9541c Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Fri, 4 Dec 2015 12:03:45 -0800 Subject: [PATCH 1123/2089] [SPARK-6990][BUILD] Add Java linting script; fix minor warnings This replaces https://github.com/apache/spark/pull/9696 Invoke Checkstyle and print any errors to the console, failing the step. Use Google's style rules modified according to https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide Some important checks are disabled (see TODOs in `checkstyle.xml`) due to multiple violations being present in the codebase. Suggest fixing those TODOs in a separate PR(s). More on Checkstyle can be found on the [official website](http://checkstyle.sourceforge.net/). Sample output (from [build 46345](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/46345/consoleFull)) (duplicated because I run the build twice with different profiles): > Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java:[217,7] (coding) MissingSwitchDefault: switch without "default" clause. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[198,10] (modifier) ModifierOrder: 'protected' modifier out of order with the JLS suggestions. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java:[217,7] (coding) MissingSwitchDefault: switch without "default" clause. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[198,10] (modifier) ModifierOrder: 'protected' modifier out of order with the JLS suggestions. > [error] running /home/jenkins/workspace/SparkPullRequestBuilder2/dev/lint-java ; received return code 1 Also fix some of the minor violations that didn't require sweeping changes. Apologies for the previous botched PRs - I finally figured out the issue. cr: JoshRosen, pwendell > I state that the contribution is my original work, and I license the work to the project under the project's open source license. Author: Dmitry Erastov Closes #9867 from dskrvk/master. --- checkstyle-suppressions.xml | 33 ++++ checkstyle.xml | 164 ++++++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 +- dev/lint-java | 30 ++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 + dev/sparktestsupport/__init__.py | 1 + .../examples/ml/JavaSimpleParamsExample.java | 2 +- .../spark/examples/mllib/JavaLDAExample.java | 3 +- ...ultiLabelClassificationMetricsExample.java | 12 +- ...ulticlassClassificationMetricsExample.java | 12 +- .../mllib/JavaRankingMetricsExample.java | 4 +- .../mllib/JavaRecommendationExample.java | 2 +- .../mllib/JavaRegressionMetricsExample.java | 3 +- .../streaming/JavaSqlNetworkWordCount.java | 4 +- .../ml/feature/JavaStringIndexerSuite.java | 6 +- .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../shuffle/ExternalShuffleBlockResolver.java | 2 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- pom.xml | 24 +++ .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/types/SQLUserDefinedType.java | 2 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../apache/spark/sql/hive/test/Complex.java | 86 ++++++--- .../spark/streaming/util/WriteAheadLog.java | 10 +- .../apache/spark/streaming/JavaAPISuite.java | 4 +- .../streaming/JavaTrackStateByKeySuite.java | 2 +- .../apache/spark/tags/ExtendedHiveTest.java | 1 + .../apache/spark/tags/ExtendedYarnTest.java | 1 + .../apache/spark/unsafe/types/UTF8String.java | 8 +- 31 files changed, 368 insertions(+), 70 deletions(-) create mode 100644 checkstyle-suppressions.xml create mode 100644 checkstyle.xml create mode 100755 dev/lint-java diff --git a/checkstyle-suppressions.xml b/checkstyle-suppressions.xml new file mode 100644 index 0000000000000..9242be3d0357a --- /dev/null +++ b/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 0000000000000..a493ee443c752 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,164 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c91e88f31bf9b..c16cbce9a0f6c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -175,7 +175,7 @@ private SortedIterator(int numRecords) { this.position = 0; } - public SortedIterator clone () { + public SortedIterator clone() { SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index d87a1d2a56d99..a5c583f9f2844 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -356,8 +356,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); diff --git a/dev/lint-java b/dev/lint-java new file mode 100755 index 0000000000000..fe8ab83d562d1 --- /dev/null +++ b/dev/lint-java @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 4f390ef1eaa32..7aecea25b2099 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -119,6 +119,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', diff --git a/dev/run-tests.py b/dev/run-tests.py index 9e1abb0697192..e7e10f1d8c725 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,6 +198,11 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) @@ -522,6 +527,8 @@ def main(): # style checks if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") for f in changed_files): + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 8ab6d9e37ca2f..0e8032d13341e 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -31,5 +31,6 @@ "BLOCK_SPARK_UNIT_TESTS": 18, "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, "BLOCK_TIMEOUT": 124 } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 94beeced3d479..ea83e8fef9eb9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,7 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double thresholds[] = {0.45, 0.55}; + double[] thresholds = {0.45, 0.55}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index fd53c81cc4974..de8e739ac9256 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -41,8 +41,9 @@ public static void main(String[] args) { public Vector call(String s) { String[] sarray = s.trim().split(" "); double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) + for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); + } return Vectors.dense(values); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java index b54e1ea3f2bcf..5ba01e0d08816 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -57,12 +57,12 @@ public static void main(String[] args) { // Stats by labels for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision - (metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics - .labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure - (metrics.labels()[i])); + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision( + metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure( + metrics.labels()[i])); } // Micro stats diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java index 21f628fb51b6e..5247c9c748618 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -74,12 +74,12 @@ public Tuple2 call(LabeledPoint p) { // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision - (metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics - .labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure - (metrics.labels()[i])); + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( + metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( + metrics.labels()[i])); } //Weighted stats diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index 7c4c97e74681f..47ab3fc358246 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -120,8 +120,8 @@ public List call(Rating[] docs) { } } ); - JavaRDD, List>> relevantDocs = userMoviesList.join - (userRecommendedList).values(); + JavaRDD, List>> relevantDocs = userMoviesList.join( + userRecommendedList).values(); // Instantiate the metrics object RankingMetrics metrics = RankingMetrics.of(relevantDocs); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java index 1065fde953b96..c179e7578cdfa 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -29,7 +29,7 @@ // $example off$ public class JavaRecommendationExample { - public static void main(String args[]) { + public static void main(String[] args) { // $example on$ SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); JavaSparkContext jsc = new JavaSparkContext(conf); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java index d2efc6bf97776..4e89dd0c37c52 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -43,8 +43,9 @@ public static void main(String[] args) { public LabeledPoint call(String line) { String[] parts = line.split(" "); double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) + for (int i = 1; i < parts.length - 1; i++) { v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + } return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb57..3515d7be45d37 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -112,8 +112,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 6b2c48ef1c342..b2df79ba74feb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -58,7 +58,7 @@ public void testStringIndexer() { createStructField("label", StringType, false) }); List data = Arrays.asList( - c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() @@ -67,12 +67,12 @@ public void testStringIndexer() { DataFrame output = indexer.fit(dataset).transform(dataset); Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, output.orderBy("id").select("id", "labelIndex").collect()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 3fea359a3b46c..225a216270b3b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -144,7 +144,7 @@ public Boolean call(Tuple2 tuple2) { } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0d4dd6afac769..e5cb68c8a4dbb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -419,7 +419,7 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); public final int major; public final int minor; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 19c870aebb023..f573d962fe361 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -61,7 +61,7 @@ public class SaslIntegrationSuite { // Use a long timeout to account for slow / overloaded build machines. In the normal case, // tests should finish way before the timeout expires. - private final static long TIMEOUT_MS = 10_000; + private static final long TIMEOUT_MS = 10_000; static TransportServer server; static TransportConf conf; diff --git a/pom.xml b/pom.xml index 234fd5dea1a6e..16e656d11961d 100644 --- a/pom.xml +++ b/pom.xml @@ -2256,6 +2256,30 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + true + false + ${basedir}/src/main/java + ${basedir}/src/test/java + checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3986d6e18f770..352002b3499a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -51,7 +51,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b36..1e4e5ede8cc11 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 2ed30c1f5a8d9..842dcb8c93dc2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -195,7 +195,7 @@ protected static final class NullIntIterator extends IntIterator { * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. */ - static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index e010112bb9327..4ef1f276d1bbb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -489,6 +489,7 @@ public void setFieldValue(_Fields field, Object value) { } break; + default: } } @@ -512,6 +513,7 @@ public Object getFieldValue(_Fields field) { case M_STRING_STRING: return getMStringString(); + default: } throw new IllegalStateException(); } @@ -535,75 +537,91 @@ public boolean isSet(_Fields field) { return isSetLintString(); case M_STRING_STRING: return isSetMStringString(); + default: } throw new IllegalStateException(); } @Override public boolean equals(Object that) { - if (that == null) + if (that == null) { return false; - if (that instanceof Complex) + } + if (that instanceof Complex) { return this.equals((Complex)that); + } return false; } public boolean equals(Complex that) { - if (that == null) + if (that == null) { return false; + } boolean this_present_aint = true; boolean that_present_aint = true; if (this_present_aint || that_present_aint) { - if (!(this_present_aint && that_present_aint)) + if (!(this_present_aint && that_present_aint)) { return false; - if (this.aint != that.aint) + } + if (this.aint != that.aint) { return false; + } } boolean this_present_aString = true && this.isSetAString(); boolean that_present_aString = true && that.isSetAString(); if (this_present_aString || that_present_aString) { - if (!(this_present_aString && that_present_aString)) + if (!(this_present_aString && that_present_aString)) { return false; - if (!this.aString.equals(that.aString)) + } + if (!this.aString.equals(that.aString)) { return false; + } } boolean this_present_lint = true && this.isSetLint(); boolean that_present_lint = true && that.isSetLint(); if (this_present_lint || that_present_lint) { - if (!(this_present_lint && that_present_lint)) + if (!(this_present_lint && that_present_lint)) { return false; - if (!this.lint.equals(that.lint)) + } + if (!this.lint.equals(that.lint)) { return false; + } } boolean this_present_lString = true && this.isSetLString(); boolean that_present_lString = true && that.isSetLString(); if (this_present_lString || that_present_lString) { - if (!(this_present_lString && that_present_lString)) + if (!(this_present_lString && that_present_lString)) { return false; - if (!this.lString.equals(that.lString)) + } + if (!this.lString.equals(that.lString)) { return false; + } } boolean this_present_lintString = true && this.isSetLintString(); boolean that_present_lintString = true && that.isSetLintString(); if (this_present_lintString || that_present_lintString) { - if (!(this_present_lintString && that_present_lintString)) + if (!(this_present_lintString && that_present_lintString)) { return false; - if (!this.lintString.equals(that.lintString)) + } + if (!this.lintString.equals(that.lintString)) { return false; + } } boolean this_present_mStringString = true && this.isSetMStringString(); boolean that_present_mStringString = true && that.isSetMStringString(); if (this_present_mStringString || that_present_mStringString) { - if (!(this_present_mStringString && that_present_mStringString)) + if (!(this_present_mStringString && that_present_mStringString)) { return false; - if (!this.mStringString.equals(that.mStringString)) + } + if (!this.mStringString.equals(that.mStringString)) { return false; + } } return true; @@ -615,33 +633,39 @@ public int hashCode() { boolean present_aint = true; builder.append(present_aint); - if (present_aint) + if (present_aint) { builder.append(aint); + } boolean present_aString = true && (isSetAString()); builder.append(present_aString); - if (present_aString) + if (present_aString) { builder.append(aString); + } boolean present_lint = true && (isSetLint()); builder.append(present_lint); - if (present_lint) + if (present_lint) { builder.append(lint); + } boolean present_lString = true && (isSetLString()); builder.append(present_lString); - if (present_lString) + if (present_lString) { builder.append(lString); + } boolean present_lintString = true && (isSetLintString()); builder.append(present_lintString); - if (present_lintString) + if (present_lintString) { builder.append(lintString); + } boolean present_mStringString = true && (isSetMStringString()); builder.append(present_mStringString); - if (present_mStringString) + if (present_mStringString) { builder.append(mStringString); + } return builder.toHashCode(); } @@ -737,7 +761,9 @@ public String toString() { sb.append("aint:"); sb.append(this.aint); first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("aString:"); if (this.aString == null) { sb.append("null"); @@ -745,7 +771,9 @@ public String toString() { sb.append(this.aString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lint:"); if (this.lint == null) { sb.append("null"); @@ -753,7 +781,9 @@ public String toString() { sb.append(this.lint); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lString:"); if (this.lString == null) { sb.append("null"); @@ -761,7 +791,9 @@ public String toString() { sb.append(this.lString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lintString:"); if (this.lintString == null) { sb.append("null"); @@ -769,7 +801,9 @@ public String toString() { sb.append(this.lintString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("mStringString:"); if (this.mStringString == null) { sb.append("null"); diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 3738fc1a235c2..2803cad8095dd 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -37,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. */ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. */ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 609bb4413b6b1..9722c60bba1c3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1332,12 +1332,12 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( + List> initial = Arrays.asList( new Tuple2<>("california", 1), new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( Arrays.asList(new Tuple2<>("california", 5), diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java index eac4cdd14a683..89d0bb7b617e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -95,7 +95,7 @@ public Double call(Optional one, State state) { JavaTrackStateDStream stateDstream2 = wordsDstream.trackStateByKey( - StateSpec. function(trackStateFunc2) + StateSpec.function(trackStateFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java index 1b0c416b0fe4e..83279e5e93c0e 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java index 2a631bfc88cf0..108300168e173 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4bd3fd7772079..5b61386808769 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -900,9 +900,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -965,7 +965,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; From 302d68de87dbaf1974accf49de26fc01fc0eb089 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 4 Dec 2015 12:08:42 -0800 Subject: [PATCH 1124/2089] [SPARK-12058][STREAMING][KINESIS][TESTS] fix Kinesis python tests Python tests require access to the `KinesisTestUtils` file. When this file exists under src/test, python can't access it, since it is not available in the assembly jar. However, if we move KinesisTestUtils to src/main, we need to add the KinesisProducerLibrary as a dependency. In order to avoid this, I moved KinesisTestUtils to src/main, and extended it with ExtendedKinesisTestUtils which is under src/test that adds support for the KPL. cc zsxwing tdas Author: Burak Yavuz Closes #10050 from brkyvz/kinesis-py. --- .../streaming/kinesis/KinesisTestUtils.scala | 88 +++++++++---------- .../kinesis/KPLBasedKinesisTestUtils.scala | 72 +++++++++++++++ .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../kinesis/KinesisStreamSuite.scala | 2 +- python/pyspark/streaming/tests.py | 1 - 5 files changed, 115 insertions(+), 50 deletions(-) rename extras/kinesis-asl/src/{test => main}/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala (82%) create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 82% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 7487aa1c12639..0ace453ee9280 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,13 +31,13 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ -import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} -import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging /** - * Shared utility methods for performing Kinesis tests that actually transfer data + * Shared utility methods for performing Kinesis tests that actually transfer data. + * + * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! */ private[kinesis] class KinesisTestUtils extends Logging { @@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging { @volatile private var _streamName: String = _ - private lazy val kinesisClient = { + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) client @@ -66,14 +66,12 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } - private lazy val kinesisProducer: KinesisProducer = { - val conf = new KinesisProducerConfiguration() - .setRecordMaxBufferedTime(1000) - .setMaxConnections(1) - .setRegion(regionName) - .setMetricsLevel("none") - - new KinesisProducer(conf) + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + throw new UnsupportedOperationException("Aggregation is not supported through this code path") + } } def streamName: String = { @@ -104,41 +102,8 @@ private[kinesis] class KinesisTestUtils extends Logging { */ def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") - val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() - - testData.foreach { num => - val str = num.toString - val data = ByteBuffer.wrap(str.getBytes()) - if (aggregate) { - val future = kinesisProducer.addUserRecord(streamName, str, data) - val kinesisCallBack = new FutureCallback[UserRecordResult]() { - override def onFailure(t: Throwable): Unit = {} // do nothing - - override def onSuccess(result: UserRecordResult): Unit = { - val shardId = result.getShardId - val seqNumber = result.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - - Futures.addCallback(future, kinesisCallBack) - kinesisProducer.flushSync() // make sure we send all data before returning the map - } else { - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(data) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - + val producer = getProducer(aggregate) + val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") shardIdToSeqNumbers.toMap } @@ -264,3 +229,32 @@ private[kinesis] object KinesisTestUtils { } } } + +/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ +private[kinesis] trait KinesisDataGenerator { + /** Sends the data to Kinesis and returns the metadata for everything that has been sent. */ + def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] +} + +private[kinesis] class SimpleDataGenerator( + client: AmazonKinesisClient) extends KinesisDataGenerator { + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = client.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala new file mode 100644 index 0000000000000..fdb270eaad8c9 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} + +private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils { + override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + new KPLDataGenerator(regionName) + } + } +} + +/** A wrapper for the KinesisProducer provided in the KPL. */ +private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator { + + private lazy val producer: KPLProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KPLProducer(conf) + } + + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val future = producer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + Futures.addCallback(future, kinesisCallBack) + } + producer.flushSync() + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 52c61dfb1c023..d85b4cda8ce98 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -40,7 +40,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index dee30444d8cc6..78cec021b78c1 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -63,7 +63,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun sc = new SparkContext(conf) runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() } } diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d50c6b8d4a428..a2bfd79e1abcd 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1458,7 +1458,6 @@ def test_kinesis_stream_api(self): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") - @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write( From d64806b37373c5cc4fd158a9f5005743bd00bf28 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 4 Dec 2015 13:05:07 -0800 Subject: [PATCH 1125/2089] [SPARK-11314][BUILD][HOTFIX] Add exclusion for moved YARN classes. Author: Marcelo Vanzin Closes #10147 from vanzin/SPARK-11314. --- project/MimaExcludes.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d3a3c0ceb68c8..b4aa6adc3c620 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -159,7 +159,10 @@ object MimaExcludes { // SPARK-3580 Add getNumPartitions method to JavaRDD ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") - ) + ) ++ + // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a + // private class. + MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From b7204e1d41271d2e8443484371770936664350b1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 5 Dec 2015 08:15:30 +0800 Subject: [PATCH 1126/2089] [SPARK-12112][BUILD] Upgrade to SBT 0.13.9 We should upgrade to SBT 0.13.9, since this is a requirement in order to use SBT's new Maven-style resolution features (which will be done in a separate patch, because it's blocked by some binary compatibility issues in the POM reader plugin). I also upgraded Scalastyle to version 0.8.0, which was necessary in order to fix a Scala 2.10.5 compatibility issue (see https://github.com/scalastyle/scalastyle/issues/156). The newer Scalastyle is slightly stricter about whitespace surrounding tokens, so I fixed the new style violations. Author: Josh Rosen Closes #10112 from JoshRosen/upgrade-to-sbt-0.13.9. --- .../org/apache/spark/deploy/JsonProtocol.scala | 2 +- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 10 +++++----- .../spark/serializer/KryoSerializerSuite.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../streaming/clickstream/PageViewStream.scala | 6 ++++-- .../streaming/flume/sink/SparkSinkSuite.scala | 4 ++-- .../spark/graphx/lib/TriangleCountSuite.scala | 2 +- .../scala/org/apache/spark/ml/param/params.scala | 2 ++ .../mllib/stat/test/StreamingTestMethod.scala | 4 ++-- .../DecisionTreeClassifierSuite.scala | 4 ++-- .../ml/regression/DecisionTreeRegressorSuite.scala | 6 +++--- .../spark/mllib/tree/DecisionTreeSuite.scala | 14 +++++++------- pom.xml | 2 +- project/build.properties | 2 +- project/plugins.sbt | 7 +------ .../spark/sql/catalyst/expressions/CastSuite.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 8 ++++---- .../sql/execution/columnar/ColumnTypeSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 4 ++-- .../sql/execution/metric/SQLMetricsSuite.scala | 2 +- 20 files changed, 47 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index ccffb36652988..220b20bf7cbd1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -45,7 +45,7 @@ private[deploy] object JsonProtocol { ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ - ("user" -> obj.desc.user) ~ + ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 46ed5c04f4338..007a71f87cf10 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -101,21 +101,21 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[UnionRDD[_]]) } test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) } test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) intercept[IllegalArgumentException] { new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e428414cf6e85..f81fe3113106f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -144,10 +144,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("Bug: SPARK-10251") { @@ -174,10 +174,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("ranges") { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index bea7a47cb2855..2fcccb22dddf7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -51,8 +51,8 @@ object PageView extends Serializable { */ // scalastyle:on object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, "http://foo.com/contact" -> .1) val httpStatus = Map(200 -> .95, 404 -> .05) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 4ef238606f82e..723616817f6a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -86,8 +86,10 @@ object PageViewStream { .map("Unique active users: " + _) // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2 -> "Reynold Xin", 3 -> "Matei Zaharia").toSeq) + val userList = ssc.sparkContext.parallelize(Seq( + 1 -> "Patrick Wendell", + 2 -> "Reynold Xin", + 3 -> "Matei Zaharia")) metric match { case "pageCounts" => pageCounts.print() diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index d2654700ea729..941fde45cd7b7 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -36,11 +36,11 @@ import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory // Spark core main, which has too many dependencies to require here manually. // For this reason, we continue to use FunSuite and ignore the scalastyle checks // that fail if this is detected. -//scalastyle:off +// scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { -//scalastyle:on +// scalastyle:on val eventsPerBatch = 1000 val channelCapacity = 5000 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index c47552cf3a3bd..608e43cf3ff53 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -26,7 +26,7 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => - val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2) + val rawEdges = sc.parallelize(Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ), 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d182b0a98896c..ee7e89edd8798 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -82,7 +82,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali def w(value: T): ParamPair[T] = this -> value /** Creates a param pair with the given value (for Scala). */ + // scalastyle:off def ->(value: T): ParamPair[T] = ParamPair(this, value) + // scalastyle:on /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ def jsonEncode(value: T): String = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala index a7eaed51b4d55..911b4b9237356 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -152,8 +152,8 @@ private[stat] object StudentTTest extends StreamingTestMethod with Logging { private[stat] object StreamingTestMethod { // Note: after new `StreamingTestMethod`s are implemented, please update this map. private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( - "welch"->WelchTTest, - "student"->StudentTTest) + "welch" -> WelchTTest, + "student" -> StudentTTest) def getTestMethodFromName(method: String): StreamingTestMethod = TEST_NAME_TO_OBJECT.get(method) match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 92b8f84144ab0..fda2711fed0fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -73,7 +73,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxDepth(2) .setMaxBins(100) .setSeed(1) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } @@ -214,7 +214,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxBins(2) .setMaxDepth(2) .setMinInstancesPerNode(2) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val numClasses = 2 compareAPIs(rdd, dt, categoricalFeatures, numClasses) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index e0d5afa7a7e97..6999a910c34a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -50,7 +50,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setMaxDepth(2) .setMaxBins(100) .setSeed(1) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } @@ -59,12 +59,12 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } test("copied model must have the same parent") { - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) val model = new DecisionTreeRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 1a4299db4eab2..bf8fe1acac2fe 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -64,7 +64,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -178,7 +178,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) @@ -237,7 +237,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) @@ -421,7 +421,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -455,7 +455,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -484,7 +484,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -788,7 +788,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, - maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2), numClasses = 2, minInstancesPerNode = 2) val rootNode = DecisionTree.train(rdd, strategy).topNode diff --git a/pom.xml b/pom.xml index 16e656d11961d..ae2ff8878b0a5 100644 --- a/pom.xml +++ b/pom.xml @@ -2235,7 +2235,7 @@ org.scalastyle scalastyle-maven-plugin - 0.7.0 + 0.8.0 false true diff --git a/project/build.properties b/project/build.properties index 064ec843da9ea..86ca8755820a4 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.7 +sbt.version=0.13.9 diff --git a/project/plugins.sbt b/project/plugins.sbt index c06687d8f197b..5e23224cf8aa5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -10,14 +10,9 @@ addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") -// For Sonatype publishing -//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) - -//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") - addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ab77a764483e8..a98e16c253214 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -734,7 +734,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val complex = Literal.create( Row( Seq("123", "true", "f"), - Map("a" ->"123", "b" -> "true", "c" -> "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index fdd745f48e973..bb3e2786978c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -862,9 +862,9 @@ private[sql] object ParquetRelation extends Logging { // The parquet compression short names val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, + "NONE" -> CompressionCodecName.UNCOMPRESSED, "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) + "SNAPPY" -> CompressionCodecName.SNAPPY, + "GZIP" -> CompressionCodecName.GZIP, + "LZO" -> CompressionCodecName.LZO) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 34dd96929e6c1..706ff1f998501 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -35,7 +35,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index ce3affba55c71..95642e93ae9f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -100,11 +100,11 @@ class RunLengthEncodingSuite extends SparkFunSuite { } test(s"$RunLengthEncoding with $typeName: simple case") { - skeleton(2, Seq(0 -> 2, 1 ->2)) + skeleton(2, Seq(0 -> 2, 1 -> 2)) } test(s"$RunLengthEncoding with $typeName: run length == 1") { - skeleton(2, Seq(0 -> 1, 1 ->1)) + skeleton(2, Seq(0 -> 1, 1 -> 1)) } test(s"$RunLengthEncoding with $typeName: single long run") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4f2cad19bfb6b..4339f7260dcb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -116,7 +116,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) val df = person.select('name) testSparkPlanMetrics(df, 1, Map( - 0L ->("Project", Map( + 0L -> ("Project", Map( "number of rows" -> 2L))) ) } From bbfc16ec9d690c2dfa20896bd6d33f9783b9c109 Mon Sep 17 00:00:00 2001 From: meiyoula <1039320815@qq.com> Date: Fri, 4 Dec 2015 16:50:40 -0800 Subject: [PATCH 1127/2089] [SPARK-12142][CORE]Reply false when container allocator is not ready and reset target Using Dynamic Allocation function, when a new AM is starting, and ExecutorAllocationManager send RequestExecutor message to AM. If the container allocator is not ready, the whole app will hang on Author: meiyoula <1039320815@qq.com> Closes #10138 from XuTingjun/patch-1. --- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 1 + .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6419218f47c85..34c32ce312964 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -370,6 +370,7 @@ private[spark] class ExecutorAllocationManager( } else { logWarning( s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") + numExecutorsTarget = oldNumExecutorsTarget 0 } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 13ef4dfd64165..1970f7d150feb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -600,11 +600,12 @@ private[spark] class ApplicationMaster( localityAwareTasks, hostToLocalTaskCount)) { resetAllocatorInterval() } + context.reply(true) case None => logWarning("Container allocator is not ready to request executors yet.") + context.reply(false) } - context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") From f30373f5ee60f9892c28771e34b208e4f1f675a6 Mon Sep 17 00:00:00 2001 From: rotems Date: Fri, 4 Dec 2015 16:58:31 -0800 Subject: [PATCH 1128/2089] [SPARK-12080][CORE] Kryo - Support multiple user registrators Author: rotems Closes #10078 from Botnaim/KryoMultipleCustomRegistrators. --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 6 ++++-- docs/configuration.md | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d5ba690ed04be..7b77f78ce6f1a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -70,7 +70,9 @@ class KryoSerializer(conf: SparkConf) private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val userRegistrators = conf.get("spark.kryo.registrator", "") + .split(',') + .filter(!_.isEmpty) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) @@ -119,7 +121,7 @@ class KryoSerializer(conf: SparkConf) classesToRegister .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. - userRegistrator + userRegistrators .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname diff --git a/docs/configuration.md b/docs/configuration.md index c39b4890851bc..fd61ddc244f44 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -647,10 +647,10 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This + If you use Kryo serialization, give a comma-separated list of classes that register your custom classes with Kryo. This property is useful if you need to register your classes in a custom way, e.g. to specify a custom field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends + set to classes that extend KryoRegistrator. See the tuning guide for more details. From 3af53e61fd604fe8000e1fdf656d60b79c842d1c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 4 Dec 2015 17:02:04 -0800 Subject: [PATCH 1129/2089] [SPARK-12084][CORE] Fix codes that uses ByteBuffer.array incorrectly `ByteBuffer` doesn't guarantee all contents in `ByteBuffer.array` are valid. E.g, a ByteBuffer returned by `ByteBuffer.slice`. We should not use the whole content of `ByteBuffer` unless we know that's correct. This patch fixed all places that use `ByteBuffer.array` incorrectly. Author: Shixiong Zhu Closes #10083 from zsxwing/bytebuffer-array. --- .../network/netty/NettyBlockTransferService.scala | 12 +++--------- .../org/apache/spark/scheduler/DAGScheduler.scala | 6 ++++-- .../scala/org/apache/spark/scheduler/Task.scala | 4 ++-- .../spark/serializer/GenericAvroSerializer.scala | 5 ++++- .../apache/spark/serializer/KryoSerializer.scala | 4 ++-- .../spark/storage/TachyonBlockManager.scala | 2 +- .../main/scala/org/apache/spark/util/Utils.scala | 15 ++++++++++++++- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 5 +++-- .../apache/spark/scheduler/TaskContextSuite.scala | 3 ++- .../pythonconverters/AvroConverters.scala | 5 ++++- .../spark/streaming/flume/FlumeInputDStream.scala | 6 +++--- .../streaming/flume/FlumePollingStreamSuite.scala | 4 ++-- .../spark/streaming/flume/FlumeStreamSuite.scala | 4 ++-- .../streaming/kinesis/KinesisStreamSuite.scala | 3 ++- .../parquet/UnsafeRowParquetRecordReader.java | 7 ++++--- .../spark/sql/execution/SparkSqlSerializer.scala | 3 ++- .../columnar/InMemoryColumnarTableScan.scala | 5 ++++- .../parquet/CatalystRowConverter.scala | 8 ++++---- .../scheduler/ReceivedBlockTracker.scala | 5 +++-- .../streaming/util/BatchedWriteAheadLog.scala | 15 ++++++--------- .../util/FileBasedWriteAheadLogWriter.scala | 14 +++----------- .../spark/streaming/JavaWriteAheadLogSuite.java | 15 +++++++-------- 22 files changed, 81 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 82c16e855b0c0..40604a4da18d5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded // using our binary protocol. - val levelBytes = serializer.newInstance().serialize(level).array() + val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level)) // Convert or copy nio buffer into array in order to serialize it. - val nioBuffer = blockData.nioByteBuffer() - val array = if (nioBuffer.hasArray) { - nioBuffer.array() - } else { - val data = new Array[Byte](nioBuffer.remaining()) - nioBuffer.get(data) - data - } + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e01a9609b9a0d..5582720bbcff2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout @@ -997,9 +998,10 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2fcd5aa57d11b..5fe5ae8c45819 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -191,8 +191,8 @@ private[spark] object Task { // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) + val taskBytes = serializer.serialize(task) + Utils.writeByteBuffer(taskBytes, out) ByteBuffer.wrap(out.toByteArray) } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 62f8aae7f2126..8d6af9cae8927 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) * seen values so to limit the number of times that decompression has to be done. */ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { - val bis = new ByteArrayInputStream(schemaBytes.array()) + val bis = new ByteArrayInputStream( + schemaBytes.array(), + schemaBytes.arrayOffset() + schemaBytes.position(), + schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) new Schema.Parser().parse(new String(bytes, "UTF-8")) }) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7b77f78ce6f1a..62d445f3d7bd9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -309,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -321,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 22878783fca67..d14fe4613528a 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) try { - os.write(bytes.array()) + Utils.writeByteBuffer(bytes, os) } catch { case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index af632349c9cae..9dbe66e7eefbd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -178,7 +178,20 @@ private[spark] object Utils extends Logging { /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] + */ + def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index a5c583f9f2844..8724a34988421 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -41,6 +41,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; @@ -430,7 +431,7 @@ public void randomizedStressTest() { } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); @@ -480,7 +481,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 450ab7b9fe92b..d83d0aee42254 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Matchers.any import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import org.apache.spark.metrics.source.JvmSource @@ -57,7 +58,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 805184e740f06..cf12c98b4af6c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable { def unpackBytes(obj: Any): Array[Byte] = { val bytes: Array[Byte] = obj match { - case buf: java.nio.ByteBuffer => buf.array() + case buf: java.nio.ByteBuffer => + val arr = new Array[Byte](buf.remaining()) + buf.get(arr) + arr case arr: Array[Byte] => arr case other => throw new SparkException( s"Unknown BYTES type ${other.getClass.getName}") diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index c8780aa83bdbd..2b9116eb3c790 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -93,9 +93,9 @@ class SparkFlumeEvent() extends Externalizable { /* Serialize to bytes. */ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) + val body = event.getBody + out.writeInt(body.remaining()) + Utils.writeByteBuffer(body, out) val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 5fd2711f5f7df..bb951a6ef100d 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -24,11 +24,11 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} @@ -119,7 +119,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { case (key, value) => (key.toString, value.toString) }).map(_.asJava) - val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody)) utils.assertOutput(headers.asJava, bodies.asJava) } } finally { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f315e0a7ca23c..b29e591c07374 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} @@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) output should be (input) } } finally { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 78cec021b78c1..6fe24fe81165b 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ @@ -196,7 +197,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun testIfEnabled("custom message handling") { val awsCredentials = KinesisTestUtils.getAWSCredentials() - def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5 + def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, addFive, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index dade488ca281b..0cc4566c9cdde 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -332,12 +332,13 @@ private void decodeBinaryBatch(int col, int num) throws IOException { for (int n = 0; n < num; ++n) { if (columnReaders[col].next()) { ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); - int len = bytes.limit() - bytes.position(); + int len = bytes.remaining(); if (originalTypes[col] == OriginalType.UTF8) { - UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + UTF8String str = + UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len); rowWriters[n].write(col, str); } else { - rowWriters[n].write(col, bytes.array(), bytes.position(), len); + rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len); } rows[n].setNotNullAt(col); } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 8317f648ccb4e..45a8e03248267 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair @@ -76,7 +77,7 @@ private[sql] object SparkSqlSerializer { def serialize[T: ClassTag](o: T): Array[Byte] = acquireRelease { k => - k.serialize(o).array() + JavaUtils.bufferToArray(k.serialize(o)) } def deserialize[T: ClassTag](bytes: Array[Byte]): T = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index ce701fb3a7f28..3c5a8cb2aa935 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -163,7 +164,9 @@ private[sql] case class InMemoryRelation( .flatMap(_.values)) batchStats += stats - CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats) + CachedBatch(rowCount, columnBuilders.map { builder => + JavaUtils.bufferToArray(builder.build()) + }, stats) } def hasNext: Boolean = rowIterator.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 94298fae2d69b..8851bc23cd050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -327,8 +327,8 @@ private[parquet] class CatalystRowConverter( // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying // it. val buffer = value.toByteBuffer - val offset = buffer.position() - val numBytes = buffer.limit() - buffer.position() + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) } } @@ -644,8 +644,8 @@ private[parquet] object CatalystRowConverter { // copying it. val buffer = binary.toByteBuffer val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() + val start = buffer.arrayOffset() + buffer.position() + val end = buffer.arrayOffset() + buffer.limit() var unscaled = 0L var i = start diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 500dc70c98506..4dab64d696b3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} @@ -210,9 +211,9 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") writeAheadLog.readAll().asScala.foreach { byteBuffer => - logTrace("Recovering record " + byteBuffer) + logInfo("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( - byteBuffer.array, Thread.currentThread().getContextClassLoader) match { + JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 6e6ed8d819721..7158abc08894a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils /** @@ -197,17 +198,10 @@ private[util] object BatchedWriteAheadLog { */ case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) - /** Copies the byte array of a ByteBuffer. */ - private def getByteArray(buffer: ByteBuffer): Array[Byte] = { - val byteArray = new Array[Byte](buffer.remaining()) - buffer.get(byteArray) - byteArray - } - /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ def aggregate(records: Seq[Record]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( - records.map(record => getByteArray(record.data)).toArray)) + records.map(record => JavaUtils.bufferToArray(record.data)).toArray)) } /** @@ -216,10 +210,13 @@ private[util] object BatchedWriteAheadLog { * method therefore needs to be backwards compatible. */ def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + val prevPosition = buffer.position() try { - Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap) } catch { case _: ClassCastException => // users may restart a stream with batching enabled + // Restore `position` so that the user can read `buffer` later + buffer.position(prevPosition) Array(buffer) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index e146bec32a456..1185f30265f63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -24,6 +24,8 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.spark.util.Utils + /** * A writer for writing byte-buffers to a write ahead log file. */ @@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: val lengthToWrite = data.remaining() val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } + Utils.writeByteBuffer(data, stream: OutputStream) flush() nextOffset = stream.getPos() segment diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 09b5f8ed03279..f02fa87f6194b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; @@ -27,6 +26,7 @@ import com.google.common.base.Function; import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; import org.apache.spark.streaming.util.WriteAheadLogUtils; @@ -112,20 +112,19 @@ public void testCustomWAL() { WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = - wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); + WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); + Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); - wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); - wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); + wal.write(JavaUtils.stringToBytes("data2"), 1235); + wal.write(JavaUtils.stringToBytes("data3"), 1236); + wal.write(JavaUtils.stringToBytes("data4"), 1237); wal.clean(1236, false); Iterator dataIterator = wal.readAll(); List readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); + readData.add(JavaUtils.bytesToString(dataIterator.next())); } Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } From ee94b70ce56661ea26c5aad17778ade32f3f1d3d Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 5 Dec 2015 15:27:31 +0000 Subject: [PATCH 1130/2089] [SPARK-12096][MLLIB] remove the old constraint in word2vec jira: https://issues.apache.org/jira/browse/SPARK-12096 word2vec now can handle much bigger vocabulary. The old constraint vocabSize.toLong * vectorSize < Ine.max / 8 should be removed. new constraint is vocabSize.toLong * vectorSize < max array length (usually a little less than Int.MaxValue) I tested with vocabsize over 18M and vectorsize = 100. srowen jkbradley Sorry to miss this in last PR. I was reminded today. Author: Yuhao Yang Closes #10103 from hhbyyh/w2vCapacity. --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 655ac0bb5545b..be12d45286034 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -306,10 +306,10 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + if (vocabSize.toLong * vectorSize >= Int.MaxValue) { throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + - "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.") } val syn0Global = From e9c9ae22b96e08e5bb40a029e84d342efb1aec0c Mon Sep 17 00:00:00 2001 From: Antonio Murgia Date: Sat, 5 Dec 2015 15:42:02 +0000 Subject: [PATCH 1131/2089] [SPARK-11994][MLLIB] Word2VecModel load and save cause SparkException when model is bigger than spark.kryoserializer.buffer.max Author: Antonio Murgia Closes #9989 from tmnd1991/SPARK-11932. --- .../apache/spark/mllib/feature/Word2Vec.scala | 16 ++++++++++++---- .../spark/mllib/feature/Word2VecSuite.scala | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index be12d45286034..b693f3c8e4bd9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -604,13 +604,21 @@ object Word2VecModel extends Loader[Word2VecModel] { val vectorSize = model.values.head.size val numWords = model.size - val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ - ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + val metadata = compact(render( + ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + // We want to partition the model in partitions of size 32MB + val partitionSize = (1L << 25) + // We calculate the approximate size of the model + // We only calculate the array size, not considering + // the string size, the formula is: + // floatSize * numWords * vectorSize + val approxSize = 4L * numWords * vectorSize + val nPartitions = ((approxSize / partitionSize) + 1).toInt val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index a864eec460f2b..37d01e2876695 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -92,4 +92,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("big model load / save") { + // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25 + val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + } From 7da674851928ed23eb651a3e2f8233e7a684ac41 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 5 Dec 2015 15:52:52 +0000 Subject: [PATCH 1132/2089] [SPARK-11988][ML][MLLIB] Update JPMML to 1.2.7 Update JPMML pmml-model to 1.2.7 Author: Sean Owen Closes #9972 from srowen/SPARK-11988. --- LICENSE | 3 +- mllib/pom.xml | 2 +- .../BinaryClassificationPMMLModelExport.scala | 32 +++++++------- .../GeneralizedLinearPMMLModelExport.scala | 26 +++++------ .../pmml/export/KMeansPMMLModelExport.scala | 44 +++++++++---------- .../mllib/pmml/export/PMMLModelExport.scala | 17 +++---- 6 files changed, 59 insertions(+), 65 deletions(-) diff --git a/LICENSE b/LICENSE index 0db2d14465bd3..a2f75b817ab37 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,3 @@ - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -237,7 +236,7 @@ The following components are provided under a BSD-style license. See project lin The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/mllib/pom.xml b/mllib/pom.xml index 70139121d8c78..df50aca1a3f76 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -109,7 +109,7 @@ org.jpmml pmml-model - 1.1.15 + 1.2.7 com.sun.xml.fastinfoset diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 622b53a252ac5..7abb1bf7ce967 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -45,7 +45,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( val fields = new SArray[FieldName](model.weights.size) val dataDictionary = new DataDictionary val miningSchema = new MiningSchema - val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1") var interceptNO = threshold if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { if (threshold <= 0) { @@ -56,35 +56,35 @@ private[mllib] class BinaryClassificationPMMLModelExport( interceptNO = -math.log(1 / threshold - 1) } } - val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionTableNO = new RegressionTable(interceptNO).setTargetCategory("0") val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.CLASSIFICATION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withNormalizationMethod(normalizationMethod) - .withRegressionTables(regressionTableYES, regressionTableNO) + .setFunctionName(MiningFunctionType.CLASSIFICATION) + .setMiningSchema(miningSchema) + .setModelName(description) + .setNormalizationMethod(normalizationMethod) + .addRegressionTables(regressionTableYES, regressionTableNO) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // add target field val targetField = FieldName.create("target") dataDictionary - .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + .addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 1874786af0002..4d951d2973a6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -45,31 +45,31 @@ private[mllib] class GeneralizedLinearPMMLModelExport( val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.REGRESSION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withRegressionTables(regressionTable) + .setFunctionName(MiningFunctionType.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName(description) + .addRegressionTables(regressionTable) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTable.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // for completeness add target field val targetField = FieldName.create("target") - dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index 069e7afc9fca0..b5b824bb9c9b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -42,42 +42,42 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val comparisonMeasure = new ComparisonMeasure() - .withKind(ComparisonMeasure.Kind.DISTANCE) - .withMeasure(new SquaredEuclidean()) + .setKind(ComparisonMeasure.Kind.DISTANCE) + .setMeasure(new SquaredEuclidean()) val clusteringModel = new ClusteringModel() - .withModelName("k-means") - .withMiningSchema(miningSchema) - .withComparisonMeasure(comparisonMeasure) - .withFunctionName(MiningFunctionType.CLUSTERING) - .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) - .withNumberOfClusters(model.clusterCenters.length) + .setModelName("k-means") + .setMiningSchema(miningSchema) + .setComparisonMeasure(comparisonMeasure) + .setFunctionName(MiningFunctionType.CLUSTERING) + .setModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .setNumberOfClusters(model.clusterCenters.length) for (i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - clusteringModel.withClusteringFields( - new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + clusteringModel.addClusteringFields( + new ClusteringField(fields(i)).setCompareFunction(CompareFunctionType.ABS_DIFF)) } - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) - for (i <- 0 until model.clusterCenters.length) { + for (i <- model.clusterCenters.indices) { val cluster = new Cluster() - .withName("cluster_" + i) - .withArray(new org.dmg.pmml.Array() - .withType(Array.Type.REAL) - .withN(clusterCenter.size) - .withValue(model.clusterCenters(i).toArray.mkString(" "))) + .setName("cluster_" + i) + .setArray(new org.dmg.pmml.Array() + .setType(Array.Type.REAL) + .setN(clusterCenter.size) + .setValue(model.clusterCenters(i).toArray.mkString(" "))) // we don't have the size of the single cluster but only the centroids (withValue) // .withSize(value) - clusteringModel.withClusters(cluster) + clusteringModel.addClusters(cluster) } pmml.setDataDictionary(dataDictionary) - pmml.withModels(clusteringModel) + pmml.addModels(clusteringModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 9267e6dbdb857..426bb818c9266 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -30,19 +30,14 @@ private[mllib] trait PMMLModelExport { * Holder of the exported model in PMML format */ @BeanProperty - val pmml: PMML = new PMML - - pmml.setVersion("4.2") - setHeader(pmml) - - private def setHeader(pmml: PMML): Unit = { + val pmml: PMML = { val version = getClass.getPackage.getImplementationVersion - val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) val header = new Header() - .withApplication(app) - .withTimestamp(timestamp) - pmml.setHeader(header) + .setApplication(app) + .setTimestamp(timestamp) + new PMML("4.2", header, null) } } From c8d0e160dadf3b23c5caa379ba9ad5547794eaa0 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sat, 5 Dec 2015 15:49:51 -0800 Subject: [PATCH 1133/2089] [SPARK-11774][SPARKR] Implement struct(), encode(), decode() functions in SparkR. Author: Sun Rui Closes #9804 from sun-rui/SPARK-11774. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/functions.R | 59 ++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 12 +++++++ R/pkg/inst/tests/test_sparkSQL.R | 37 ++++++++++++++++---- 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 43e5e0119e7fe..565a2b1a68b5f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -134,8 +134,10 @@ exportMethods("%in%", "datediff", "dayofmonth", "dayofyear", + "decode", "dense_rank", "desc", + "encode", "endsWith", "exp", "explode", @@ -225,6 +227,7 @@ exportMethods("%in%", "stddev", "stddev_pop", "stddev_samp", + "struct", "sqrt", "startsWith", "substr", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b30331c61c9a7..7432cb8e7ccf6 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -357,6 +357,40 @@ setMethod("dayofyear", column(jc) }) +#' decode +#' +#' Computes the first argument into a string from a binary using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname decode +#' @name decode +#' @family string_funcs +#' @export +#' @examples \dontrun{decode(df$c, "UTF-8")} +setMethod("decode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset) + column(jc) + }) + +#' encode +#' +#' Computes the first argument into a binary from a string using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname encode +#' @name encode +#' @family string_funcs +#' @export +#' @examples \dontrun{encode(df$c, "UTF-8")} +setMethod("encode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset) + column(jc) + }) + #' exp #' #' Computes the exponential of the given value. @@ -1039,6 +1073,31 @@ setMethod("stddev_samp", column(jc) }) +#' struct +#' +#' Creates a new struct column that composes multiple input columns. +#' +#' @rdname struct +#' @name struct +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' struct(df$c, df$d) +#' struct("col1", "col2") +#' } +setMethod("struct", + signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "Column") { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...)) + } + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 711ce38f9e104..4b5f786d39461 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -744,10 +744,18 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) +#' @rdname decode +#' @export +setGeneric("decode", function(x, charset) { standardGeneric("decode") }) + #' @rdname dense_rank #' @export setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) +#' @rdname encode +#' @export +setGeneric("encode", function(x, charset) { standardGeneric("encode") }) + #' @rdname explode #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) @@ -1001,6 +1009,10 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @export setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) +#' @rdname struct +#' @export +setGeneric("struct", function(x, ...) { standardGeneric("struct") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1e7cb54099703..2d26b92ac7275 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -27,6 +27,11 @@ checkStructField <- function(actual, expectedName, expectedType, expectedNullabl expect_equal(actual$nullable(), expectedNullable) } +markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -551,11 +556,6 @@ test_that("collect() and take() on a DataFrame return the same number of rows an }) test_that("collect() support Unicode characters", { - markUtf8 <- function(s) { - Encoding(s) <- "UTF-8" - s - } - lines <- c("{\"name\":\"안녕하세요\"}", "{\"name\":\"您好\", \"age\":30}", "{\"name\":\"ã“ã‚“ã«ã¡ã¯\", \"age\":19}", @@ -933,8 +933,33 @@ test_that("column functions", { # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) + + # Test struct() + df <- createDataFrame(sqlContext, + list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + schema = c("a", "b", "c")) + result <- collect(select(df, struct("a", "c"))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) + expect_equal(result, expected) + + result <- collect(select(df, struct(df$a, df$b))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) + expect_equal(result, expected) + + # Test encode(), decode() + bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) + df <- createDataFrame(sqlContext, + list(list(markUtf8("大åƒä¸–界"), "utf-8", bytes)), + schema = c("a", "b", "c")) + result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) + expect_equal(result[[1]][[1]], bytes) + expect_equal(result[[2]], markUtf8("大åƒä¸–界")) }) -# + test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, \"b\":6}", From 895b6c474735d7e0a38283f92292daa5c35875ee Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 5 Dec 2015 16:00:12 -0800 Subject: [PATCH 1134/2089] [SPARK-11715][SPARKR] Add R support corr for Column Aggregration Need to match existing method signature Author: felixcheung Closes #9680 from felixcheung/rcorr. --- R/pkg/R/functions.R | 15 +++++++++++++++ R/pkg/R/generics.R | 2 +- R/pkg/R/stats.R | 9 +++++---- R/pkg/inst/tests/test_sparkSQL.R | 2 +- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7432cb8e7ccf6..25231451df3d2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -259,6 +259,21 @@ setMethod("column", function(x) { col(x) }) +#' corr +#' +#' Computes the Pearson Correlation Coefficient for two Columns. +#' +#' @rdname corr +#' @name corr +#' @family math_funcs +#' @export +#' @examples \dontrun{corr(df$c, df$d)} +setMethod("corr", signature(x = "Column"), + function(x, col2) { + stopifnot(class(col2) == "Column") + jc <- callJStatic("org.apache.spark.sql.functions", "corr", x@jc, col2@jc) + column(jc) + }) #' cos #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4b5f786d39461..acfd4841e19af 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -411,7 +411,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @rdname statfunctions #' @export -setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) +setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname summary #' @export diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index f79329b115404..d17cce9c756e2 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -77,7 +77,7 @@ setMethod("cov", #' Calculates the correlation of two columns of a DataFrame. #' Currently only supports the Pearson Correlation Coefficient. #' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. -#' +#' #' @param x A SparkSQL DataFrame #' @param col1 the name of the first column #' @param col2 the name of the second column @@ -95,8 +95,9 @@ setMethod("cov", #' corr <- corr(df, "title", "gender", method = "pearson") #' } setMethod("corr", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2, method = "pearson") { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "corr", col1, col2, method) }) @@ -109,7 +110,7 @@ setMethod("corr", #' #' @param x A SparkSQL DataFrame. #' @param cols A vector column names to search frequent items in. -#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. #' Should be greater than 1e-4. Default support = 0.01. #' @return a local R data.frame with the frequent items in each column #' @@ -131,7 +132,7 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), #' sampleBy #' #' Returns a stratified sample without replacement based on the fraction given on each stratum. -#' +#' #' @param x A SparkSQL DataFrame #' @param col column that defines strata #' @param fractions A named list giving sampling fraction for each stratum. If a stratum is diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 2d26b92ac7275..a5a234a02d9f2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -892,7 +892,7 @@ test_that("column functions", { c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c14 <- cume_dist() + ntile(1) + c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() # Test if base::rank() is exposed From 6979edf4e1a93caafa8d286692097dd377d7616d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 5 Dec 2015 16:39:01 -0800 Subject: [PATCH 1135/2089] [SPARK-12115][SPARKR] Change numPartitions() to getNumPartitions() to be consistent with Scala/Python Change ```numPartitions()``` to ```getNumPartitions()``` to be consistent with Scala/Python. Note: If we can not catch up with 1.6 release, it will be breaking change for 1.7 that we also need to explain in release note. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10123 from yanboliang/spark-12115. --- R/pkg/R/RDD.R | 55 ++++++++++++++++++++++--------------- R/pkg/R/generics.R | 6 +++- R/pkg/R/pairRDD.R | 4 +-- R/pkg/inst/tests/test_rdd.R | 10 +++---- 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 47945c2825da9..00c40c38cabc9 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -306,17 +306,28 @@ setMethod("checkpoint", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' numPartitions(rdd) # 2L +#' getNumPartitions(rdd) # 2L #'} -#' @rdname numPartitions +#' @rdname getNumPartitions +#' @aliases getNumPartitions,RDD-method +#' @noRd +setMethod("getNumPartitions", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "getNumPartitions") + }) + +#' Gets the number of partitions of an RDD, the same as getNumPartitions. +#' But this function has been deprecated, please use getNumPartitions. +#' +#' @rdname getNumPartitions #' @aliases numPartitions,RDD-method #' @noRd setMethod("numPartitions", signature(x = "RDD"), function(x) { - jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "partitions") - callJMethod(partitions, "size") + .Deprecated("getNumPartitions") + getNumPartitions(x) }) #' Collect elements of an RDD @@ -443,7 +454,7 @@ setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, numPartitions(x))) + collect(reduceByKey(ones, `+`, getNumPartitions(x))) }) #' Apply a function to all elements @@ -759,7 +770,7 @@ setMethod("take", resList <- list() index <- -1 jrdd <- getJRDD(x) - numPartitions <- numPartitions(x) + numPartitions <- getNumPartitions(x) serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size @@ -823,7 +834,7 @@ setMethod("first", #' @noRd setMethod("distinct", signature(x = "RDD"), - function(x, numPartitions = SparkR:::numPartitions(x)) { + function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) reduced <- reduceByKey(identical.mapped, function(x, y) { x }, @@ -993,8 +1004,8 @@ setMethod("keyBy", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) -#' numPartitions(rdd) # 4 -#' numPartitions(repartition(rdd, 2L)) # 2 +#' getNumPartitions(rdd) # 4 +#' getNumPartitions(repartition(rdd, 2L)) # 2 #'} #' @rdname repartition #' @aliases repartition,RDD @@ -1014,8 +1025,8 @@ setMethod("repartition", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) -#' numPartitions(rdd) # 3 -#' numPartitions(coalesce(rdd, 1L)) # 1 +#' getNumPartitions(rdd) # 3 +#' getNumPartitions(coalesce(rdd, 1L)) # 1 #'} #' @rdname coalesce #' @aliases coalesce,RDD @@ -1024,7 +1035,7 @@ setMethod("coalesce", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) - if (shuffle || numPartitions > SparkR:::numPartitions(x)) { + if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed start <- as.integer(base::sample(numPartitions, 1) - 1) @@ -1112,7 +1123,7 @@ setMethod("saveAsTextFile", #' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1144,7 +1155,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- numPartitions(newRdd) + numPartitions <- getNumPartitions(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1368,7 +1379,7 @@ setMethod("setName", setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1409,7 +1420,7 @@ setMethod("zipWithUniqueId", setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) if (n > 1) { nums <- collect(lapplyPartition(x, function(part) { @@ -1521,8 +1532,8 @@ setMethod("unionRDD", setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- numPartitions(x) - n2 <- numPartitions(other) + n1 <- getNumPartitions(x) + n2 <- getNumPartitions(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1588,7 +1599,7 @@ setMethod("cartesian", #' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) @@ -1620,7 +1631,7 @@ setMethod("subtract", #' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1661,7 +1672,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, numPartitions) + nPart <- sapply(rrdds, getNumPartitions) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index acfd4841e19af..29dd11f41ff5e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -133,7 +133,11 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @export setGeneric("name", function(x) { standardGeneric("name") }) -# @rdname numPartitions +# @rdname getNumPartitions +# @export +setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) + +# @rdname getNumPartitions # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 991bea4d2022d..334c11d2f89a1 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -750,7 +750,7 @@ setMethod("cogroup", #' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -818,7 +818,7 @@ setMethod("sortByKey", #' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 71aed2bb9d6a8..7423b4f2bed1f 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -28,8 +28,8 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(numPartitions(rdd), 2) - expect_equal(numPartitions(intRdd), 2) + expect_equal(getNumPartitions(rdd), 2) + expect_equal(getNumPartitions(intRdd), 2) }) test_that("first on RDD", { @@ -304,18 +304,18 @@ test_that("repartition/coalesce on RDDs", { # repartition r1 <- repartition(rdd, 2) - expect_equal(numPartitions(r1), 2L) + expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartition(rdd, 6) - expect_equal(numPartitions(r2), 6L) + expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) - expect_equal(numPartitions(r3), 1L) + expect_equal(getNumPartitions(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) From b6e8e63a0dbe471187a146c96fdaddc6b8a8e55e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 5 Dec 2015 22:51:05 -0800 Subject: [PATCH 1136/2089] [SPARK-12044][SPARKR] Fix usage of isnan, isNaN 1, Add ```isNaN``` to ```Column``` for SparkR. ```Column``` should has three related variable functions: ```isNaN, isNull, isNotNull```. 2, Replace ```DataFrame.isNaN``` with ```DataFrame.isnan``` at SparkR side. Because ```DataFrame.isNaN``` has been deprecated and will be removed at Spark 2.0. 3, Add ```isnull``` to ```DataFrame``` for SparkR. ```DataFrame``` should has two related functions: ```isnan, isnull```. cc shivaram sun-rui felixcheung Author: Yanbo Liang Closes #10037 from yanboliang/spark-12044. --- R/pkg/R/column.R | 2 +- R/pkg/R/functions.R | 26 +++++++++++++++++++------- R/pkg/R/generics.R | 8 ++++++-- R/pkg/inst/tests/test_sparkSQL.R | 6 +++++- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 20de3907b7dd9..7bb8ef2595b59 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -56,7 +56,7 @@ operators <- list( "&" = "and", "|" = "or", #, "!" = "unary_$bang" "^" = "pow" ) -column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") createOperator <- function(op) { diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 25231451df3d2..09e4e04335a33 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -537,19 +537,31 @@ setMethod("initcap", column(jc) }) -#' isNaN +#' is.nan #' -#' Return true iff the column is NaN. +#' Return true if the column is NaN, alias for \link{isnan} #' -#' @rdname isNaN -#' @name isNaN +#' @rdname is.nan +#' @name is.nan #' @family normal_funcs #' @export -#' @examples \dontrun{isNaN(df$c)} -setMethod("isNaN", +#' @examples +#' \dontrun{ +#' is.nan(df$c) +#' isnan(df$c) +#' } +setMethod("is.nan", + signature(x = "Column"), + function(x) { + isnan(x) + }) + +#' @rdname is.nan +#' @name isnan +setMethod("isnan", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 29dd11f41ff5e..c383e6e78b8b4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -625,6 +625,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -808,9 +812,9 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname isNaN +#' @rdname is.nan #' @export -setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname kurtosis #' @export diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index a5a234a02d9f2..6ef03ae97635e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -883,7 +883,7 @@ test_that("column functions", { c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) c3 <- cosh(c) + count(c) + crc32(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) - c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) @@ -894,6 +894,10 @@ test_that("column functions", { c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() + c16 <- is.nan(c) + isnan(c) + isNaN(c) + + # Test if base::is.nan() is exposed + expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") From 04b6799932707f0a4aa4da0f2fc838bdb29794ce Mon Sep 17 00:00:00 2001 From: gcc Date: Sun, 6 Dec 2015 16:27:40 +0000 Subject: [PATCH 1137/2089] [SPARK-12048][SQL] Prevent to close JDBC resources twice Author: gcc Closes #10101 from rh99/master. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index b9dd7f6b4099b..1c348ed62fc78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -511,6 +511,7 @@ private[sql] class JDBCRDD( } catch { case e: Exception => logWarning("Exception closing connection", e) } + closed = true } override def hasNext: Boolean = { From 49efd03bacad6060d99ed5e2fe53ba3df1d1317e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 Dec 2015 11:15:02 -0800 Subject: [PATCH 1138/2089] [SPARK-12138][SQL] Escape \u in the generated comments of codegen When \u appears in a comment block (i.e. in /**/), code gen will break. So, in Expression and CodegenFallback, we escape \u to \\u. yhuai Please review it. I did reproduce it and it works after the fix. Thanks! Author: gatorsmile Closes #10155 from gatorsmile/escapeU. --- .../spark/sql/catalyst/expressions/Expression.scala | 4 +++- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 614f0c075fd23..6d807c9ecf302 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -220,7 +220,9 @@ abstract class Expression extends TreeNode[Expression] { * Returns the string representation of this expression that is safe to be put in * code comments of generated code. */ - protected def toCommentSafeString: String = this.toString.replace("*/", "\\*\\/") + protected def toCommentSafeString: String = this.toString + .replace("*/", "\\*\\/") + .replace("\\u", "\\\\u") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index fe754240dcd67..cd2ef7dcd0cd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -107,4 +107,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { true, InternalRow(UTF8String.fromString("*/"))) } + + test("\\u in the data") { + // When \ u appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape \ u to \\u. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)), + true, + InternalRow(UTF8String.fromString("\\u"))) + } } From 80a824d36eec9d9a9f092ee1741453851218ec73 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 6 Dec 2015 17:35:01 -0800 Subject: [PATCH 1139/2089] [SPARK-12152][PROJECT-INFRA] Speed up Scalastyle checks by only invoking SBT once Currently, `dev/scalastyle` invokes SBT four times, but these invocations can be replaced with a single invocation, saving about one minute of build time. Author: Josh Rosen Closes #10151 from JoshRosen/speed-up-scalastyle. --- dev/scalastyle | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/dev/scalastyle b/dev/scalastyle index ad93f7e85b27c..8fd3604b9f451 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,14 +17,17 @@ # limitations under the License. # -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver scalastyle > scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt -# Check style with YARN built too -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 scalastyle >> scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 test:scalastyle >> scalastyle.txt - -ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') -rm scalastyle.txt +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pyarn \ + -Phive \ + -Phive-thriftserver \ + scalastyle test:scalastyle \ + | awk '{if($1~/error/)print}' \ +) if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" From 6fd9e70e3ed43836a0685507fff9949f921234f4 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 7 Dec 2015 00:21:55 -0800 Subject: [PATCH 1140/2089] [SPARK-12106][STREAMING][FLAKY-TEST] BatchedWAL test transiently flaky when Jenkins load is high We need to make sure that the last entry is indeed the last entry in the queue. Author: Burak Yavuz Closes #10110 from brkyvz/batch-wal-test-fix. --- .../streaming/util/BatchedWriteAheadLog.scala | 6 ++++-- .../spark/streaming/util/WriteAheadLogSuite.scala | 14 ++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 7158abc08894a..b2cd524f28b74 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -166,10 +166,12 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp var segment: WriteAheadLogRecordHandle = null if (buffer.length > 0) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // threads may not be able to add items in order by time + val sortedByTime = buffer.sortBy(_.time) // We take the latest record for the timestamp. Please refer to the class Javadoc for // detailed explanation - val time = buffer.last.time - segment = wrappedLog.write(aggregate(buffer), time) + val time = sortedByTime.last.time + segment = wrappedLog.write(aggregate(sortedByTime), time) } buffer.foreach(_.promise.success(segment)) } catch { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index eaa88ea3cd380..ef1e89df31305 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -480,7 +480,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( p } - test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { + test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") { val blockingWal = new BlockingWriteAheadLog(wal, walHandle) val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) @@ -500,8 +500,14 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // rest of the records will be batched while it takes time for 3 to get written writeAsync(batchedWal, event2, 5L) writeAsync(batchedWal, event3, 8L) - writeAsync(batchedWal, event4, 12L) - writeAsync(batchedWal, event5, 10L) + // we would like event 5 to be written before event 4 in order to test that they get + // sorted before being aggregated + writeAsync(batchedWal, event5, 12L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 3) + } + writeAsync(batchedWal, event4, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) assert(batchedWal.invokePrivate(queueLength()) === 4) @@ -517,7 +523,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) - verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L)) val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) assert(records.toSet === queuedEvents) } From 9cde7d5fa87e7ddfff0b9c1212920a1d9000539b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 7 Dec 2015 10:34:18 -0800 Subject: [PATCH 1141/2089] [SPARK-12032] [SQL] Re-order inner joins to do join with conditions first Currently, the order of joins is exactly the same as SQL query, some conditions may not pushed down to the correct join, then those join will become cross product and is extremely slow. This patch try to re-order the inner joins (which are common in SQL query), pick the joins that have self-contain conditions first, delay those that does not have conditions. After this patch, the TPCDS query Q64/65 can run hundreds times faster. cc marmbrus nongli Author: Davies Liu Closes #10073 from davies/reorder_joins. --- .../sql/catalyst/optimizer/Optimizer.scala | 56 ++++++++++- .../sql/catalyst/planning/patterns.scala | 40 +++++++- .../catalyst/optimizer/JoinOrderSuite.scala | 95 +++++++++++++++++++ 3 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 06d14fcf8b9c2..f6088695a9276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,14 +18,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet + import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -44,6 +42,7 @@ object DefaultOptimizer extends Optimizer { // Operator push down SetOperationPushDown, SamplePushDown, + ReorderJoin, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -711,6 +710,53 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } } +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + .exists(_.references.subsetOf(refs)) + } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f4f11406d7c4..cd3f15cbe107b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef /** * A pattern that matches any number of project or filter operations on top of another relational @@ -132,6 +131,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } } +/** + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ +object ExtractFiltersAndInnerJoins extends PredicateHelper { + + // flatten all inner joins, which are next to each other + def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { + case Join(left, right, Inner, cond) => + val (plans, conditions) = flattenJoin(left) + (plans ++ Seq(right), conditions ++ cond.toSeq) + + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq(plan), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + Some(flattenJoin(f)) + case j @ Join(_, _, Inner, _) => + Some(flattenJoin(j)) + case _ => None + } +} + /** * A pattern that collects all adjacent unions and returns their children as a Seq. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala new file mode 100644 index 0000000000000..9b1e16c727647 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class JoinOrderSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Filter Pushdown", Once, + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + ReorderJoin, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, + ProjectCollapsing) :: Nil + + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + + test("extract filters and joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) { + assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + } + + testExtract(x, None) + testExtract(x.where("x.b".attr === 1), None) + testExtract(x.join(y), Some(Seq(x, y), Seq())) + testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq())) + testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z), + Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq())) + testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), + Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))) + } + + test("reorder inner joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + val originalQuery = { + x.join(y).join(z) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + .analyze + + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + } +} From 39d677c8f1ee7ebd7e142bec0415cf8f90ac84b6 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Mon, 7 Dec 2015 10:38:17 -0800 Subject: [PATCH 1142/2089] [SPARK-12034][SPARKR] Eliminate warnings in SparkR test cases. This PR: 1. Suppress all known warnings. 2. Cleanup test cases and fix some errors in test cases. 3. Fix errors in HiveContext related test cases. These test cases are actually not run previously due to a bug of creating TestHiveContext. 4. Support 'testthat' package version 0.11.0 which prefers that test cases be under 'tests/testthat' 5. Make sure the default Hadoop file system is local when running test cases. 6. Turn on warnings into errors. Author: Sun Rui Closes #10030 from sun-rui/SPARK-12034. --- R/pkg/inst/tests/{ => testthat}/jarTest.R | 0 .../tests/{ => testthat}/packageInAJarTest.R | 0 R/pkg/inst/tests/{ => testthat}/test_Serde.R | 0 .../tests/{ => testthat}/test_binaryFile.R | 0 .../{ => testthat}/test_binary_function.R | 0 .../tests/{ => testthat}/test_broadcast.R | 0 R/pkg/inst/tests/{ => testthat}/test_client.R | 0 .../inst/tests/{ => testthat}/test_context.R | 0 .../tests/{ => testthat}/test_includeJAR.R | 2 +- .../{ => testthat}/test_includePackage.R | 0 R/pkg/inst/tests/{ => testthat}/test_mllib.R | 14 ++-- .../{ => testthat}/test_parallelize_collect.R | 0 R/pkg/inst/tests/{ => testthat}/test_rdd.R | 0 .../inst/tests/{ => testthat}/test_shuffle.R | 0 .../inst/tests/{ => testthat}/test_sparkSQL.R | 68 +++++++++++-------- R/pkg/inst/tests/{ => testthat}/test_take.R | 0 .../inst/tests/{ => testthat}/test_textFile.R | 0 R/pkg/inst/tests/{ => testthat}/test_utils.R | 0 R/pkg/tests/run-all.R | 3 + R/run-tests.sh | 2 +- 20 files changed, 50 insertions(+), 39 deletions(-) rename R/pkg/inst/tests/{ => testthat}/jarTest.R (100%) rename R/pkg/inst/tests/{ => testthat}/packageInAJarTest.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_Serde.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_binaryFile.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_binary_function.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_broadcast.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_client.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_context.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_includeJAR.R (94%) rename R/pkg/inst/tests/{ => testthat}/test_includePackage.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_mllib.R (90%) rename R/pkg/inst/tests/{ => testthat}/test_parallelize_collect.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_rdd.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_shuffle.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_sparkSQL.R (97%) rename R/pkg/inst/tests/{ => testthat}/test_take.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_textFile.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_utils.R (100%) diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R similarity index 100% rename from R/pkg/inst/tests/jarTest.R rename to R/pkg/inst/tests/testthat/jarTest.R diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R similarity index 100% rename from R/pkg/inst/tests/packageInAJarTest.R rename to R/pkg/inst/tests/testthat/packageInAJarTest.R diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R similarity index 100% rename from R/pkg/inst/tests/test_Serde.R rename to R/pkg/inst/tests/testthat/test_Serde.R diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R similarity index 100% rename from R/pkg/inst/tests/test_binaryFile.R rename to R/pkg/inst/tests/testthat/test_binaryFile.R diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R similarity index 100% rename from R/pkg/inst/tests/test_binary_function.R rename to R/pkg/inst/tests/testthat/test_binary_function.R diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R similarity index 100% rename from R/pkg/inst/tests/test_broadcast.R rename to R/pkg/inst/tests/testthat/test_broadcast.R diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/testthat/test_client.R similarity index 100% rename from R/pkg/inst/tests/test_client.R rename to R/pkg/inst/tests/testthat/test_client.R diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/testthat/test_context.R similarity index 100% rename from R/pkg/inst/tests/test_context.R rename to R/pkg/inst/tests/testthat/test_context.R diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/testthat/test_includeJAR.R similarity index 94% rename from R/pkg/inst/tests/test_includeJAR.R rename to R/pkg/inst/tests/testthat/test_includeJAR.R index cc1faeabffe30..f89aa8e507fd5 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/testthat/test_includeJAR.R @@ -20,7 +20,7 @@ runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) - scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/testthat/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, args = c(jarPath, scriptPath), diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R similarity index 100% rename from R/pkg/inst/tests/test_includePackage.R rename to R/pkg/inst/tests/testthat/test_includePackage.R diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R similarity index 90% rename from R/pkg/inst/tests/test_mllib.R rename to R/pkg/inst/tests/testthat/test_mllib.R index e0667e5e22c18..08099dd96a87b 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -26,7 +26,7 @@ sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) test_that("glm and predict", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) test <- select(training, "Sepal_Length") model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) @@ -39,7 +39,7 @@ test_that("glm and predict", { }) test_that("glm should work with long formula", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -51,7 +51,7 @@ test_that("glm should work with long formula", { }) test_that("predictions match with native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) @@ -59,7 +59,7 @@ test_that("predictions match with native glm", { }) test_that("dot minus and intercept vs native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ . - Species + 0, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) @@ -67,7 +67,7 @@ test_that("dot minus and intercept vs native glm", { }) test_that("feature interaction vs native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) @@ -75,7 +75,7 @@ test_that("feature interaction vs native glm", { }) test_that("summary coefficients match with native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) coefs <- unlist(stats$coefficients) devianceResiduals <- unlist(stats$devianceResiduals) @@ -92,7 +92,7 @@ test_that("summary coefficients match with native glm", { }) test_that("summary coefficients match with native glm of family 'binomial'", { - df <- createDataFrame(sqlContext, iris) + df <- suppressWarnings(createDataFrame(sqlContext, iris)) training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R similarity index 100% rename from R/pkg/inst/tests/test_parallelize_collect.R rename to R/pkg/inst/tests/testthat/test_parallelize_collect.R diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R similarity index 100% rename from R/pkg/inst/tests/test_rdd.R rename to R/pkg/inst/tests/testthat/test_rdd.R diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R similarity index 100% rename from R/pkg/inst/tests/test_shuffle.R rename to R/pkg/inst/tests/testthat/test_shuffle.R diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R similarity index 97% rename from R/pkg/inst/tests/test_sparkSQL.R rename to R/pkg/inst/tests/testthat/test_sparkSQL.R index 6ef03ae97635e..39fc94aea5fb1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -133,38 +133,45 @@ test_that("create DataFrame from RDD", { expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- jsonFile(sqlContext, jsonPathNa) - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") - insertInto(df, "people") - expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) - expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) - schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) - df2 <- createDataFrame(sqlContext, df.toRDD, schema) - df2AsDF <- as.DataFrame(sqlContext, df.toRDD, schema) + df <- read.df(sqlContext, jsonPathNa, "json", schema) + df2 <- createDataFrame(sqlContext, toRDD(df), schema) + df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) expect_equal(columns(df2), c("name", "age", "height")) expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) - expect_equal(collect(where(df2AsDF, df2$name == "Bob")), c("Bob", 16, 176.5)) + expect_equal(as.list(collect(where(df2, df2$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + expect_equal(as.list(collect(where(df2AsDF, df2AsDF$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) localDF <- data.frame(name=c("John", "Smith", "Sarah"), - age=c(19, 23, 18), - height=c(164.10, 181.4, 173.7)) + age=c(19L, 23L, 18L), + height=c(176.5, 181.4, 173.7)) df <- createDataFrame(sqlContext, localDF, schema) expect_is(df, "DataFrame") expect_equal(count(df), 3) expect_equal(columns(df), c("name", "age", "height")) expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) + expect_equal(as.list(collect(where(df, df$name == "John"))), + list(name = "John", age = 19L, height = 176.5)) + + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + df <- read.df(hiveCtx, jsonPathNa, "json", schema) + invisible(insertInto(df, "people")) + expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + c(16)) + expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + c(176.5)) }) test_that("convert NAs to null type in DataFrames", { @@ -250,7 +257,7 @@ test_that("create DataFrame from list or data.frame", { ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) - irisdf <- createDataFrame(sqlContext, iris) + irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) iris_collected <- collect(irisdf) expect_equivalent(iris_collected[,-5], iris[,-5]) expect_equal(iris_collected$Species, as.character(iris$Species)) @@ -463,7 +470,7 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) expect_is(unioned, "RDD") - expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(getSerializedMode(unioned), "byte") expect_equal(collect(unioned)[[2]]$name, "Andy") }) @@ -485,13 +492,13 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { unionByte <- unionRDD(rdd, dfRDD) expect_is(unionByte, "RDD") - expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(getSerializedMode(unionByte), "byte") expect_equal(collect(unionByte)[[1]], 1) expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) expect_is(unionString, "RDD") - expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(getSerializedMode(unionString), "byte") expect_equal(collect(unionString)[[1]], "Michael") expect_equal(collect(unionString)[[5]]$name, "Andy") }) @@ -504,7 +511,7 @@ test_that("objectFile() works with row serialization", { objectIn <- objectFile(sc, objectPath) expect_is(objectIn, "RDD") - expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -849,6 +856,7 @@ test_that("write.df() as parquet file", { }) test_that("test HiveContext", { + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) }, @@ -863,10 +871,10 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - saveAsTable(df, "json", "json", "append", path = jsonPath2) - df3 <- sql(hiveCtx, "select * from json") + invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") - expect_equal(count(df3), 6) + expect_equal(count(df3), 3) }) test_that("column operators", { @@ -1311,7 +1319,7 @@ test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") - expect_equal(SparkR:::getSerializedMode(testRDD), "string") + expect_equal(getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -1641,7 +1649,7 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) -irisDF <- createDataFrame(sqlContext, iris) +irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF), collect(irisDF)) @@ -1670,7 +1678,7 @@ test_that("attach() on a DataFrame", { }) test_that("with() on a DataFrame", { - df <- createDataFrame(sqlContext, iris) + df <- suppressWarnings(createDataFrame(sqlContext, iris)) expect_error(Sepal_Length) sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/testthat/test_take.R similarity index 100% rename from R/pkg/inst/tests/test_take.R rename to R/pkg/inst/tests/testthat/test_take.R diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R similarity index 100% rename from R/pkg/inst/tests/test_textFile.R rename to R/pkg/inst/tests/testthat/test_textFile.R diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R similarity index 100% rename from R/pkg/inst/tests/test_utils.R rename to R/pkg/inst/tests/testthat/test_utils.R diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 4f8a1ed2d83ef..1d04656ac2594 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -18,4 +18,7 @@ library(testthat) library(SparkR) +# Turn all warnings into errors +options("warn" = 2) + test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..e64a4ea94c584 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then From ef3f047c07ef0ac4a3a97e6bc11e1c28c6c8f9a0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 7 Dec 2015 11:00:25 -0800 Subject: [PATCH 1143/2089] [SPARK-12132] [PYSPARK] raise KeyboardInterrupt inside SIGINT handler Currently, the current line is not cleared by Cltr-C After this patch ``` >>> asdfasdf^C Traceback (most recent call last): File "~/spark/python/pyspark/context.py", line 225, in signal_handler raise KeyboardInterrupt() KeyboardInterrupt ``` It's still worse than 1.5 (and before). Author: Davies Liu Closes #10134 from davies/fix_cltrc. --- python/pyspark/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 77710a13394c6..529d16b480399 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -222,6 +222,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # create a signal handler which would be invoked on receiving SIGINT def signal_handler(signal, frame): self.cancelAllJobs() + raise KeyboardInterrupt() # see http://stackoverflow.com/questions/23206787/ if isinstance(threading.current_thread(), threading._MainThread): From 5d80d8c6a54b2113022eff31187e6d97521bd2cf Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Dec 2015 11:03:59 -0800 Subject: [PATCH 1144/2089] [SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present The reason is that TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovery from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to it. This is because RDD checkpoints do not preserve the partitioner (SPARK-12004). While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there may be a non-zero chance that the saving and recovery fails. To be resilient, this PR repartitions the previous state RDD if the partitioner is not detected. Author: Tathagata Das Closes #9988 from tdas/SPARK-11932. --- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/dstream/TrackStateDStream.scala | 39 ++-- .../spark/streaming/rdd/TrackStateRDD.scala | 29 ++- .../spark/streaming/CheckpointSuite.scala | 189 +++++++++++++----- .../spark/streaming/TestSuiteBase.scala | 6 + .../streaming/TrackStateByKeySuite.scala | 77 +++++-- 6 files changed, 258 insertions(+), 84 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index fd0e8d5d690b6..d0046afdeb447 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -277,7 +277,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) - logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") + logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 0ada1111ce30a..ea6213420e7ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, V, S, E]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, validTime - ) + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + TrackStateRDD.createFromRDD[K, V, S, E]( + rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) + } else { + rdd + } + case None => + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime + ) } + // Compute the new state RDD with previous state RDD and partitioned data RDD - parent.getOrCompute(validTime).map { dataRDD => - val partitionedDataRDD = dataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds } + Some(new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 7050378d0feb0..30aafcf1460e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: private[streaming] object TrackStateRDD { - def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, T] = { + updateTime: Time): TrackStateRDD[K, V, S, E] = { val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } + + def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + rdd: RDD[(K, S, Long)], + partitioner: Partitioner, + updateTime: Time): TrackStateRDD[K, V, S, E] = { + + val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, (state, updateTime)) => + stateMap.put(key, state, updateTime) + } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index b1cbc7163bee3..cd28d3cf408d5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -33,17 +33,149 @@ import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.TestUtils +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Milliseconds(500), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that number of batches to be run = number of inputs + val totalNumBatches = input.size + val batchDurationMillis = batchDuration.milliseconds + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + operatedStream.print() + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + val beforeRestartOutput = generateOutput[V](ssc, + Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest) + assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true) + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + + val restartedSsc = new StreamingContext(checkpointDir) + val afterRestartOutput = generateOutput[V](restartedSsc, + Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest) + assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) + } + + private def generateOutput[V: ClassTag]( + ssc: StreamingContext, + targetBatchTime: Time, + checkpointDir: String, + stopSparkContext: Boolean + ): Seq[Seq[V]] = { + try { + val batchDuration = ssc.graph.batchDuration + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val currentTime = clock.getTimeMillis() + + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) + clock.setTime(targetBatchTime.milliseconds) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = ssc.graph.getOutputStreams().filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(batchCounter.getLastCompletedBatchTime === targetBatchTime) + } + + eventually(timeout(10 seconds)) { + val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { + _.toString.contains(clock.getTimeMillis.toString) + } + // Checkpoint files are written twice for every batch interval. So assert that both + // are written to make sure that both of them have been written. + assert(checkpointFilesOfLatestTime.size === 2) + } + outputStream.output.map(_.flatten) + + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } + + private def assertOutput[V: ClassTag]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + beforeRestart: Boolean): Unit = { + val expectedPartialOutput = if (beforeRestart) { + expectedOutput.take(output.size) + } else { + expectedOutput.takeRight(output.size) + } + val setComparison = output.zip(expectedPartialOutput).forall { + case (o, e) => o.toSet === e.toSet + } + assert(setComparison, s"set comparison failed\n" + + s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" + + s"Generated output items: ${output.mkString("\n")}" + ) + } +} + /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { var ssc: StreamingContext = null @@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase { override def afterFunction() { super.afterFunction() - if (ssc != null) ssc.stop() + if (ssc != null) { ssc.stop() } Utils.deleteRecursively(new File(checkpointDir)) } @@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase { Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), - Seq(("", 2)), Seq() ), + Seq(("", 2)), + Seq() + ), 3 ) } @@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase { checkpointWriter.stop() } - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. - */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) - - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null - } - /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index a45c92d9c7bc8..be0f4636a6cb8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) { // All access to this state should be guarded by `BatchCounter.this.synchronized` private var numCompletedBatches = 0 private var numStartedBatches = 0 + private var lastCompletedBatchTime: Time = null private val listener = new StreamingListener { override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = @@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) { override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = BatchCounter.this.synchronized { numCompletedBatches += 1 + lastCompletedBatchTime = batchCompleted.batchInfo.batchTime BatchCounter.this.notifyAll() } } @@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) { numStartedBatches } + def getLastCompletedBatchTime: Time = this.synchronized { + lastCompletedBatchTime + } + /** * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 58aef74c0040f..1fc320d31b18b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -25,31 +25,27 @@ import scala.reflect.ClassTag import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { +class TrackStateByKeySuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null - private var ssc: StreamingContext = null - private var checkpointDir: File = null - private val batchDuration = Seconds(1) + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) before { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) - } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } checkpointDir = Utils.createTempDir("checkpoint") - - ssc = new StreamingContext(sc, batchDuration) - ssc.checkpoint(checkpointDir.toString) } after { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } } override def beforeAll(): Unit = { @@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(dstreamImpl.stateClass === classOf[Double]) assert(dstreamImpl.emittedClass === classOf[Long]) } - + val ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types @@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef expectedCheckpointDuration: Duration, explicitCheckpointDuration: Option[Duration] = None ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + try { - ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) val dummyFunc = (value: Option[Int], state: State[Int]) => 0 val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) @@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef trackStateStream.checkpoint(d) } trackStateStream.register() + ssc.checkpoint(checkpointDir.toString) ssc.start() // should initialize all the checkpoint durations assert(trackStateStream.checkpointDuration === null) assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) } finally { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + ssc.stop(stopSparkContext = false) } } @@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) } + + test("trackStateByKey - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + trackStateStream.checkpoint(checkpointDuration) + trackStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], trackStateSpec: StateSpec[K, Int, S, T], @@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] @@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef stateSnapshotStream.register() val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.advance(batchDuration.milliseconds * numBatches) batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) (collectedOutputs, collectedStateSnapshots) } From 3f4efb5c23b029496b112760fa062ff070c20334 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 7 Dec 2015 12:01:09 -0800 Subject: [PATCH 1145/2089] [SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize Merged #10051 again since #10083 is resolved. This reverts commit 328b757d5d4486ea3c2e246780792d7a57ee85e5. Author: Shixiong Zhu Closes #10167 from zsxwing/merge-SPARK-12060. --- .../spark/serializer/JavaSerializer.scala | 7 ++--- .../spark/util/ByteBufferOutputStream.scala | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7d..ea718a0edbe71 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 0000000000000..92e45224db81c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} From 871e85d9c14c6b19068cc732951a8ae8db61b411 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 7 Dec 2015 13:16:47 -0800 Subject: [PATCH 1146/2089] [SPARK-11963][DOC] Add docs for QuantileDiscretizer https://issues.apache.org/jira/browse/SPARK-11963 Author: Xusen Yin Closes #9962 from yinxusen/SPARK-11963. --- docs/ml-features.md | 65 +++++++++++++++++ .../ml/JavaQuantileDiscretizerExample.java | 71 +++++++++++++++++++ .../ml/QuantileDiscretizerExample.scala | 49 +++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 05c2c96c5ec5a..b499d6ec3b90e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1705,6 +1705,71 @@ print(output.select("features", "clicked").first()) +## QuantileDiscretizer + +`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned +categorical features. +The bin ranges are chosen by taking a sample of the data and dividing it into roughly equal parts. +The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. +This attempts to find `numBuckets` partitions based on a sample of the given input data, but it may +find fewer depending on the data sample values. + +Note that the result may be different every time you run it, since the sample strategy behind it is +non-deterministic. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`: + +~~~ + id | hour +----|------ + 0 | 18.0 +----|------ + 1 | 19.0 +----|------ + 2 | 8.0 +----|------ + 3 | 5.0 +----|------ + 4 | 2.2 +~~~ + +`hour` is a continuous feature with `Double` type. We want to turn the continuous feature into +categorical one. Given `numBuckets = 3`, we should get the following DataFrame: + +~~~ + id | hour | result +----|------|------ + 0 | 18.0 | 2.0 +----|------|------ + 1 | 19.0 | 2.0 +----|------|------ + 2 | 8.0 | 1.0 +----|------|------ + 3 | 5.0 | 1.0 +----|------|------ + 4 | 2.2 | 0.0 +~~~ + +
    +
    + +Refer to the [QuantileDiscretizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.QuantileDiscretizer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala %} +
    + +
    + +Refer to the [QuantileDiscretizer Java docs](api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java %} +
    +
    + # Feature Selectors ## VectorSlicer diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java new file mode 100644 index 0000000000000..251ae79d9a108 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.QuantileDiscretizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaQuantileDiscretizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize( + Arrays.asList( + RowFactory.create(0, 18.0), + RowFactory.create(1, 19.0), + RowFactory.create(2, 8.0), + RowFactory.create(3, 5.0), + RowFactory.create(4, 2.2) + ) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("hour", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + QuantileDiscretizer discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3); + + DataFrame result = discretizer.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala new file mode 100644 index 0000000000000..8f29b7eaa6d26 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.QuantileDiscretizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object QuantileDiscretizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("QuantileDiscretizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) + val df = sc.parallelize(data).toDF("id", "hour") + + val discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3) + + val result = discretizer.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 84b809445f39b9030f272528bdaa39d1559cbc6e Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 7 Dec 2015 14:58:09 -0800 Subject: [PATCH 1147/2089] [SPARK-11884] Drop multiple columns in the DataFrame API See the thread Ben started: http://search-hadoop.com/m/q3RTtveEuhjsr7g/ This PR adds drop() method to DataFrame which accepts multiple column names Author: tedyu Closes #9862 from ted-yu/master. --- .../org/apache/spark/sql/DataFrame.scala | 24 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 7 ++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index eb8700369275e..243a8c853f90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1261,16 +1261,24 @@ class DataFrame private[sql]( * @since 1.4.0 */ def drop(colName: String): DataFrame = { + drop(Seq(colName) : _*) + } + + /** + * Returns a new [[DataFrame]] with columns dropped. + * This is a no-op if schema doesn't contain column name(s). + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def drop(colNames: String*): DataFrame = { val resolver = sqlContext.analyzer.resolver - val shouldDrop = schema.exists(f => resolver(f.name, colName)) - if (shouldDrop) { - val colsAfterDrop = schema.filter { field => - val name = field.name - !resolver(name, colName) - }.map(f => Column(f.name)) - select(colsAfterDrop : _*) - } else { + val remainingCols = + schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) + if (remainingCols.size == this.schema.size) { this + } else { + this.select(remainingCols: _*) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 76e9648aa7533..605a6549dd686 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -378,6 +378,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("value")) } + test("drop columns using drop") { + val src = Seq((0, 2, 3)).toDF("a", "b", "c") + val df = src.drop("a", "b") + checkAnswer(df, Row(3)) + assert(df.schema.map(_.name) === Seq("c")) + } + test("drop unknown column (no-op)") { val df = testData.drop("random") checkAnswer( From 36282f78b888743066843727426c6d806231aa97 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 7 Dec 2015 15:01:00 -0800 Subject: [PATCH 1148/2089] [SPARK-12184][PYTHON] Make python api doc for pivot consistant with scala doc In SPARK-11946 the API for pivot was changed a bit and got updated doc, the doc changes were not made for the python api though. This PR updates the python doc to be consistent. Author: Andrew Ray Closes #10176 from aray/sql-pivot-python-doc. --- python/pyspark/sql/group.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 1911588309aff..9ca303a974cd4 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,16 +169,20 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and perform the specified aggregation. + """ + Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + There are two versions of pivot function: one that requires the caller to specify the list + of distinct values to pivot on, and one that does not. The latter is more concise but less + efficient, because Spark needs to first compute the list of distinct values internally. - :param pivot_col: Column to pivot - :param values: Optional list of values of pivot column that will be translated to columns in - the output DataFrame. If values are not provided the method will do an immediate call - to .distinct() on the pivot column. + :param pivot_col: Name of the column to pivot. + :param values: List of values that will be translated to columns in the output DataFrame. + // Compute the sum of earnings for each year by course with each course as a separate column >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + // Or without specifying column values (less efficient) >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ From 3e7e05f5ee763925ed60410d7de04cf36b723de1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 7 Dec 2015 16:37:09 -0800 Subject: [PATCH 1149/2089] [SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib Switched from using SQLContext constructor to using getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib. There were no uses of the constructor in spark.ml. CC: mengxr yhuai Author: Joseph K. Bradley Closes #10161 from jkbradley/mllib-sqlcontext-fix. --- .../apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++---- .../classification/impl/GLMClassificationModel.scala | 4 ++-- .../spark/mllib/clustering/GaussianMixtureModel.scala | 4 ++-- .../org/apache/spark/mllib/clustering/KMeansModel.scala | 4 ++-- .../spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- .../mllib/recommendation/MatrixFactorizationModel.scala | 4 ++-- .../spark/mllib/regression/IsotonicRegression.scala | 4 ++-- .../spark/mllib/regression/impl/GLMRegressionModel.scala | 4 ++-- .../apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4 ++-- .../spark/mllib/tree/model/treeEnsembleModels.scala | 4 ++-- 13 files changed, 29 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 54b03a9f90283..2aa6aec0b4347 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. - val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext) + val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) sqlContext.createDataFrame(indexedRowMatrix.rows) } @@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. - val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) + val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } @@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. - val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) sqlContext.createDataFrame(blockMatrix.blocks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a956084ae06e8..aef9ef2cb052d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index fe09f6b75d28b..2910c027ae06d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 2115f7d99c182..74d13e4f77945 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a741584982725..91fa9b0d3590d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 7cd9b08fa8e0e..bb1804505948b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( @@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index d4d022afde051..eaa99cfe82e27 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b693f3c8e4bd9..23b1514e3080e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -587,7 +587,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -599,7 +599,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val vectorSize = model.values.head.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 46562eb2ad0f7..0dc40483dd0ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ec78ea24539b5..f235089873ab8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 317d3a5702636..02af281fb726b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 54c136aecf660..89c470d573431 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -201,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -242,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 90e032e3d9842..3f427f0be3af2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -408,7 +408,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -468,7 +468,7 @@ private[tree] object TreeEnsembleModel extends Logging { path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) From 78209b0ccaf3f22b5e2345dfb2b98edfdb746819 Mon Sep 17 00:00:00 2001 From: somideshmukh Date: Mon, 7 Dec 2015 23:26:34 -0800 Subject: [PATCH 1150/2089] [SPARK-11551][DOC][EXAMPLE] Replace example code in ml-features.md using include_example Made new patch contaning only markdown examples moved to exmaple/folder. Ony three java code were not shfted since they were contaning compliation error ,these classes are 1)StandardScale 2)NormalizerExample 3)VectorIndexer Author: Xusen Yin Author: somideshmukh Closes #10002 from somideshmukh/SomilBranch1.33. --- docs/ml-features.md | 1109 +---------------- .../examples/ml/JavaBinarizerExample.java | 68 + .../examples/ml/JavaBucketizerExample.java | 70 ++ .../spark/examples/ml/JavaDCTExample.java | 65 + .../ml/JavaElementwiseProductExample.java | 75 ++ .../examples/ml/JavaMinMaxScalerExample.java | 50 + .../spark/examples/ml/JavaNGramExample.java | 71 ++ .../examples/ml/JavaNormalizerExample.java | 52 + .../examples/ml/JavaOneHotEncoderExample.java | 77 ++ .../spark/examples/ml/JavaPCAExample.java | 71 ++ .../ml/JavaPolynomialExpansionExample.java | 71 ++ .../examples/ml/JavaRFormulaExample.java | 69 + .../ml/JavaStandardScalerExample.java | 53 + .../ml/JavaStopWordsRemoverExample.java | 65 + .../examples/ml/JavaStringIndexerExample.java | 66 + .../examples/ml/JavaTokenizerExample.java | 75 ++ .../ml/JavaVectorAssemblerExample.java | 67 + .../examples/ml/JavaVectorIndexerExample.java | 60 + .../examples/ml/JavaVectorSlicerExample.java | 73 ++ .../src/main/python/ml/binarizer_example.py | 43 + .../src/main/python/ml/bucketizer_example.py | 42 + .../python/ml/elementwise_product_example.py | 39 + examples/src/main/python/ml/n_gram_example.py | 42 + .../src/main/python/ml/normalizer_example.py | 41 + .../main/python/ml/onehot_encoder_example.py | 47 + examples/src/main/python/ml/pca_example.py | 42 + .../python/ml/polynomial_expansion_example.py | 43 + .../src/main/python/ml/rformula_example.py | 44 + .../main/python/ml/standard_scaler_example.py | 42 + .../python/ml/stopwords_remover_example.py | 40 + .../main/python/ml/string_indexer_example.py | 39 + .../src/main/python/ml/tokenizer_example.py | 44 + .../python/ml/vector_assembler_example.py | 42 + .../main/python/ml/vector_indexer_example.py | 39 + .../spark/examples/ml/BinarizerExample.scala | 48 + .../spark/examples/ml/BucketizerExample.scala | 51 + .../apache/spark/examples/ml/DCTExample.scala | 54 + .../ml/ElementWiseProductExample.scala | 53 + .../examples/ml/MinMaxScalerExample.scala | 49 + .../spark/examples/ml/NGramExample.scala | 47 + .../spark/examples/ml/NormalizerExample.scala | 50 + .../examples/ml/OneHotEncoderExample.scala | 58 + .../apache/spark/examples/ml/PCAExample.scala | 54 + .../ml/PolynomialExpansionExample.scala | 53 + .../spark/examples/ml/RFormulaExample.scala | 49 + .../examples/ml/StandardScalerExample.scala | 51 + .../examples/ml/StopWordsRemoverExample.scala | 48 + .../examples/ml/StringIndexerExample.scala | 49 + .../spark/examples/ml/TokenizerExample.scala | 54 + .../examples/ml/VectorAssemblerExample.scala | 49 + .../examples/ml/VectorIndexerExample.scala | 53 + .../examples/ml/VectorSlicerExample.scala | 58 + 52 files changed, 2806 insertions(+), 1058 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java create mode 100644 examples/src/main/python/ml/binarizer_example.py create mode 100644 examples/src/main/python/ml/bucketizer_example.py create mode 100644 examples/src/main/python/ml/elementwise_product_example.py create mode 100644 examples/src/main/python/ml/n_gram_example.py create mode 100644 examples/src/main/python/ml/normalizer_example.py create mode 100644 examples/src/main/python/ml/onehot_encoder_example.py create mode 100644 examples/src/main/python/ml/pca_example.py create mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py create mode 100644 examples/src/main/python/ml/rformula_example.py create mode 100644 examples/src/main/python/ml/standard_scaler_example.py create mode 100644 examples/src/main/python/ml/stopwords_remover_example.py create mode 100644 examples/src/main/python/ml/string_indexer_example.py create mode 100644 examples/src/main/python/ml/tokenizer_example.py create mode 100644 examples/src/main/python/ml/vector_assembler_example.py create mode 100644 examples/src/main/python/ml/vector_indexer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index b499d6ec3b90e..5105a948fec8e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,25 +170,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    @@ -197,44 +179,7 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    @@ -243,21 +188,7 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +237,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +245,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +253,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +270,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +278,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +286,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +303,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +311,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +319,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +333,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +341,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +349,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +363,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +371,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +379,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
    @@ -771,22 +403,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
    @@ -794,39 +411,7 @@ dctDf.select("featuresDCT").show(3) Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); -DataFrame dctDf = dct.transform(df); -dctDf.select("featuresDCT").show(3); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}}
    @@ -881,18 +466,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +474,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,16 +482,7 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
    @@ -961,29 +496,7 @@ indexed.show() Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -991,45 +504,7 @@ encoded.select("id", "categoryVec").foreach(println) Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1037,24 +512,7 @@ DataFrame encoded = encoder.transform(indexed); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -1078,23 +536,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    @@ -1102,30 +544,7 @@ val indexedData = indexerModel.transform(data) Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1133,17 +552,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
    @@ -1160,22 +569,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    @@ -1183,24 +577,7 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    @@ -1208,19 +585,7 @@ DataFrame lInfNormData = Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -1244,23 +609,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    @@ -1268,25 +617,7 @@ val scaledData = scalerModel.transform(dataFrame) Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    @@ -1294,20 +625,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    @@ -1337,21 +655,7 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) - -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %}
    @@ -1360,24 +664,7 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
    @@ -1401,23 +688,7 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    @@ -1425,38 +696,7 @@ val bucketedData = bucketizer.transform(dataFrame) Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    @@ -1464,19 +704,7 @@ DataFrame bucketedData = bucketizer.transform(dataFrame); Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
    @@ -1508,25 +736,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1534,41 +744,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1576,19 +752,8 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() - -{% endhighlight %} +{% include_example python/ml/elementwise_product_example.py %}
    - ## VectorAssembler @@ -1632,19 +797,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1652,36 +805,7 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    @@ -1689,19 +813,7 @@ System.out.println(output.select("features", "clicked").first()); Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler - -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %}
    @@ -1831,33 +943,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1865,41 +951,7 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); - -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - -DataFrame output = vectorSlicer.transform(dataset); - -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
    @@ -1936,21 +988,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
    @@ -1958,38 +996,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -1997,21 +1004,7 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 0000000000000..9698cac504371 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 0000000000000..b06a23e76d604 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 0000000000000..35c0d534a45e9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 0000000000000..2898accec61b0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 0000000000000..138b3ab6aba44 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + DataFrame scaledData = scalerModel.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 0000000000000..8fd75ed8b5f4e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 0000000000000..6283a355e1fef --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 0000000000000..172a9cc6feb28 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 0000000000000..8282fab084f36 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 0000000000000..668f71e64056b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 0000000000000..1e1062b541ad9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 0000000000000..0cbdc97e8ae30 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 0000000000000..b6b201c6b68d2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 0000000000000..05d12c1e702f1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 0000000000000..617dc3f66e3bf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label"). take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 0000000000000..7e230b5897c1e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 0000000000000..06b4bf6bf8ff6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + DataFrame indexedData = indexerModel.transform(data); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 0000000000000..4d5cb04ff5e2b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 0000000000000..960ad208be12e --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 0000000000000..a12750aa9248a --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 0000000000000..c85cb0d89543c --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 0000000000000..f2d85f53e7219 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 0000000000000..833d93e976a7e --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 0000000000000..7529dfd09213a --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 0000000000000..8b66140a40a7a --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data,["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 0000000000000..030a6132a451a --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 0000000000000..b544a14700762 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 0000000000000..139acecbfb53f --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. + scaledData = scalerModel.transform(dataFrame) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 0000000000000..01f94af8ca752 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 0000000000000..58a8cb5d56b73 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 0000000000000..ce9b225be5357 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 0000000000000..04f64839f188d --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 0000000000000..cc00d1454f2e0 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 0000000000000..e724aa587294b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 0000000000000..30c2776d39688 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. + val bucketedData = bucketizer.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 0000000000000..314c2c28a2a10 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala new file mode 100644 index 0000000000000..ac50bb7b2b155 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 0000000000000..dac3679a5bf7e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. + val scaledData = scalerModel.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 0000000000000..8a85f71b56f3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 0000000000000..17571f0aad793 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 0000000000000..4512736943dd5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder().setInputCol("categoryIndex"). + setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 0000000000000..a18d4f33973d8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 0000000000000..b8e9e6952a5ea --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 0000000000000..286866edea502 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 0000000000000..646ce0f13ecf5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 0000000000000..655ffce08d3ab --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 0000000000000..1be8a5f33f7c0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 0000000000000..01e0d1388a2f4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 0000000000000..d527924419f81 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 0000000000000..14279d610fda8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 0000000000000..04f19829eff87 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 73896588dd3af6ba77c9692cd5120ee32448eb22 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 7 Dec 2015 23:34:16 -0800 Subject: [PATCH 1151/2089] Closes #10098 From 7d05a624510f7299b3dd07f87c203db1ff7caa3e Mon Sep 17 00:00:00 2001 From: Takahashi Hiroshi Date: Mon, 7 Dec 2015 23:46:55 -0800 Subject: [PATCH 1152/2089] [SPARK-10259][ML] Add @since annotation to ml.classification Add since annotation to ml.classification Author: Takahashi Hiroshi Closes #8534 from taishi-oss/issue10259. --- .../DecisionTreeClassifier.scala | 30 +++++++-- .../ml/classification/GBTClassifier.scala | 35 ++++++++-- .../classification/LogisticRegression.scala | 64 +++++++++++++++---- .../MultilayerPerceptronClassifier.scala | 23 +++++-- .../spark/ml/classification/NaiveBayes.scala | 19 ++++-- .../spark/ml/classification/OneVsRest.scala | 24 +++++-- .../RandomForestClassifier.scala | 34 ++++++++-- 7 files changed, 185 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index c478aea44ace8..8c4cec1326653 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest @@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class DecisionTreeClassifier(override val uid: String) +final class DecisionTreeClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) // Override parameter setters from parent trait for Java API compatibility. + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { @@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.1") override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } @@ -104,12 +119,13 @@ object DecisionTreeClassifier { * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental final class DecisionTreeClassificationModel private[ml] ( - override val uid: String, - override val rootNode: Node, - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.4.0")override val uid: String, + @Since("1.4.0")override val rootNode: Node, + @Since("1.6.0")override val numFeatures: Int, + @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 74aef94bf7675..cda2bca58c50d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel @@ -44,36 +44,47 @@ import org.apache.spark.sql.types.DoubleType * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ +@Since("1.4.0") @Experimental -final class GBTClassifier(override val uid: String) +final class GBTClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." */ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this @@ -81,8 +92,10 @@ final class GBTClassifier(override val uid: String) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = { logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") super.setSeed(value) @@ -90,8 +103,10 @@ final class GBTClassifier(override val uid: String) // Parameters from GBTParams: + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) // Parameters for GBTClassifier: @@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String) * (default = logistic) * @group param */ + @Since("1.4.0") val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", @@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String) setDefault(lossType -> "logistic") /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ + @Since("1.4.0") def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ @@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } + @Since("1.4.1") override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ + @Since("1.4.0") final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) } @@ -164,12 +185,13 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ +@Since("1.6.0") @Experimental final class GBTClassificationModel private[ml]( - override val uid: String, + @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - override val numFeatures: Int) + @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] with TreeEnsembleModel with Serializable { @@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml]( * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ + @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = this(uid, _trees, _treeWeights, -1) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml]( if (prediction > 0.0) 1.0 else 0.0 } + @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"GBTClassificationModel (uid=$uid) with $numTrees trees" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d320d64dd90d0..19cc323d5073f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -154,11 +154,14 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * Currently, this class only supports binary classification. It will support multiclass * in the future. */ +@Since("1.2.0") @Experimental -class LogisticRegression(override val uid: String) +class LogisticRegression @Since("1.2.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("logreg")) /** @@ -166,6 +169,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0. * @group setParam */ + @Since("1.2.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -176,6 +180,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -184,6 +189,7 @@ class LogisticRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.2.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -193,6 +199,7 @@ class LogisticRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) @@ -201,6 +208,7 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.4.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -213,11 +221,14 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.5.0") def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold /** @@ -226,11 +237,14 @@ class LogisticRegression(override val uid: String) * Default is empty, so all instances have weight one. * @group setParam */ + @Since("1.6.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(weightCol -> "") + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { @@ -384,11 +398,14 @@ class LogisticRegression(override val uid: String) model.setSummary(logRegSummary) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } +@Since("1.6.0") object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + @Since("1.6.0") override def load(path: String): LogisticRegression = super.load(path) } @@ -396,23 +413,28 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { * :: Experimental :: * Model produced by [[LogisticRegression]]. */ +@Since("1.4.0") @Experimental class LogisticRegressionModel private[ml] ( - override val uid: String, - val coefficients: Vector, - val intercept: Double) + @Since("1.4.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.3.0") val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds /** Margin (rawPrediction) for class label 1. For binary classification only. */ @@ -426,8 +448,10 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } + @Since("1.6.0") override val numFeatures: Int = coefficients.size + @Since("1.3.0") override val numClasses: Int = 2 private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None @@ -436,6 +460,7 @@ class LogisticRegressionModel private[ml] ( * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. */ + @Since("1.5.0") def summary: LogisticRegressionTrainingSummary = trainingSummary match { case Some(summ) => summ case None => @@ -451,6 +476,7 @@ class LogisticRegressionModel private[ml] ( } /** Indicates whether a training summary exists for this model instance. */ + @Since("1.5.0") def hasSummary: Boolean = trainingSummary.isDefined /** @@ -493,6 +519,7 @@ class LogisticRegressionModel private[ml] ( Vectors.dense(-m, m) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) @@ -710,12 +737,13 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionTrainingSummary private[classification] ( - predictions: DataFrame, - probabilityCol: String, - labelCol: String, - featuresCol: String, - val objectiveHistory: Array[Double]) + @Since("1.5.0") predictions: DataFrame, + @Since("1.5.0") probabilityCol: String, + @Since("1.5.0") labelCol: String, + @Since("1.6.0") featuresCol: String, + @Since("1.5.0") val objectiveHistory: Array[Double]) extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { @@ -731,11 +759,13 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionSummary private[classification] ( - @transient override val predictions: DataFrame, - override val probabilityCol: String, - override val labelCol: String, - override val featuresCol: String) extends LogisticRegressionSummary { + @Since("1.5.0") @transient override val predictions: DataFrame, + @Since("1.5.0") override val probabilityCol: String, + @Since("1.5.0") override val labelCol: String, + @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary { + private val sqlContext = predictions.sqlContext import sqlContext.implicits._ @@ -760,6 +790,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** @@ -768,6 +799,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** @@ -777,6 +809,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** @@ -785,6 +818,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") } @@ -797,6 +831,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") } @@ -809,6 +844,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index cd7462596dd9e..a691aa005ef54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} @@ -104,19 +104,23 @@ private object LabelConverter { * Each layer has sigmoid activation function, output layer has softmax. * Number of inputs has to be equal to the size of feature vectors. * Number of outputs has to be equal to the total number of labels. - * */ +@Since("1.5.0") @Experimental -class MultilayerPerceptronClassifier(override val uid: String) +class MultilayerPerceptronClassifier @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { + @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) /** @group setParam */ + @Since("1.5.0") def setLayers(value: Array[Int]): this.type = set(layers, value) /** @group setParam */ + @Since("1.5.0") def setBlockSize(value: Int): this.type = set(blockSize, value) /** @@ -124,6 +128,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @@ -132,14 +137,17 @@ class MultilayerPerceptronClassifier(override val uid: String) * Default is 1E-4. * @group setParam */ + @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** * Set the seed for weights initialization. * @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** @@ -173,14 +181,16 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model */ +@Since("1.5.0") @Experimental class MultilayerPerceptronClassificationModel private[ml] ( - override val uid: String, - val layers: Array[Int], - val weights: Vector) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val layers: Array[Int], + @Since("1.5.0") val weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { + @Since("1.6.0") override val numFeatures: Int = layers.head private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -200,6 +210,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index c512a2cb8bf3d..718f49d3aedcd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -72,11 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ +@Since("1.5.0") @Experimental -class NaiveBayes(override val uid: String) +class NaiveBayes @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) /** @@ -84,6 +87,7 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ + @Since("1.5.0") def setSmoothing(value: Double): this.type = set(smoothing, value) setDefault(smoothing -> 1.0) @@ -93,6 +97,7 @@ class NaiveBayes(override val uid: String) * Default is "multinomial" * @group setParam */ + @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) @@ -102,6 +107,7 @@ class NaiveBayes(override val uid: String) NaiveBayesModel.fromOld(oldModel, this) } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } @@ -119,11 +125,12 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ +@Since("1.5.0") @Experimental class NaiveBayesModel private[ml] ( - override val uid: String, - val pi: Vector, - val theta: Matrix) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val pi: Vector, + @Since("1.5.0") val theta: Matrix) extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { @@ -148,8 +155,10 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } + @Since("1.6.0") override val numFeatures: Int = theta.numCols + @Since("1.5.0") override val numClasses: Int = pi.size private def multinomialCalculation(features: Vector) = { @@ -206,10 +215,12 @@ class NaiveBayesModel private[ml] ( } } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } + @Since("1.5.0") override def toString: String = { s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index debc164bf2432..08a51109d6c62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap} @@ -70,17 +70,20 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ +@Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - override val uid: String, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_, _]]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") labelMetadata: Metadata, + @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -134,6 +137,7 @@ final class OneVsRestModel private[ml] ( .drop(accColName) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) @@ -150,30 +154,39 @@ final class OneVsRestModel private[ml] ( * Each example is scored against all k models and the model with highest score * is picked to label the example. */ +@Since("1.4.0") @Experimental -final class OneVsRest(override val uid: String) +final class OneVsRest @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ + @Since("1.4.0") def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } + @Since("1.4.0") override def fit(dataset: DataFrame): OneVsRestModel = { // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) @@ -222,6 +235,7 @@ final class OneVsRest(override val uid: String) copyValues(model) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRest = { val copied = defaultCopy(extra).asInstanceOf[OneVsRest] if (isDefined(classifier)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bae329692a68d..d6d85ad2533a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} @@ -38,44 +38,59 @@ import org.apache.spark.sql.functions._ * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class RandomForestClassifier(override val uid: String) +final class RandomForestClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) @@ -99,15 +114,19 @@ final class RandomForestClassifier(override val uid: String) new RandomForestClassificationModel(trees, numFeatures, numClasses) } + @Since("1.4.1") override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies } @@ -120,12 +139,13 @@ object RandomForestClassifier { * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ +@Since("1.4.0") @Experimental final class RandomForestClassificationModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.6.0") override val numFeatures: Int, + @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -141,11 +161,13 @@ final class RandomForestClassificationModel private[ml] ( numClasses: Int) = this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -186,11 +208,13 @@ final class RandomForestClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } From 4a39b5a1bee28cec792d509654f6236390cafdcb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 7 Dec 2015 23:50:57 -0800 Subject: [PATCH 1153/2089] [SPARK-11958][SPARK-11957][ML][DOC] SQLTransformer user guide and example code Add ```SQLTransformer``` user guide, example code and make Scala API doc more clear. Author: Yanbo Liang Closes #10006 from yanboliang/spark-11958. --- docs/ml-features.md | 59 +++++++++++++++++++ .../ml/JavaSQLTransformerExample.java | 59 +++++++++++++++++++ .../src/main/python/ml/sql_transformer.py | 40 +++++++++++++ .../examples/ml/SQLTransformerExample.scala | 45 ++++++++++++++ .../spark/ml/feature/SQLTransformer.scala | 11 +++- 5 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java create mode 100644 examples/src/main/python/ml/sql_transformer.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 5105a948fec8e..f85e0d56d2e40 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -756,6 +756,65 @@ for more details on the API. +## SQLTransformer + +`SQLTransformer` implements the transformations which are defined by SQL statement. +Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +where `"__THIS__"` represents the underlying table of the input dataset. +The select clause specifies the fields, constants, and expressions to display in +the output, it can be any select clause that Spark SQL supports. Users can also +use Spark SQL built-in function and UDFs to operate on these selected columns. +For example, `SQLTransformer` supports statements like: + +* `SELECT a, a + b AS a_b FROM __THIS__` +* `SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5` +* `SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b` + +**Examples** + +Assume that we have the following DataFrame with columns `id`, `v1` and `v2`: + +~~~~ + id | v1 | v2 +----|-----|----- + 0 | 1.0 | 3.0 + 2 | 2.0 | 5.0 +~~~~ + +This is the output of the `SQLTransformer` with statement `"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"`: + +~~~~ + id | v1 | v2 | v3 | v4 +----|-----|-----|-----|----- + 0 | 1.0 | 3.0 | 4.0 | 3.0 + 2 | 2.0 | 5.0 | 7.0 |10.0 +~~~~ + +
    +
    + +Refer to the [SQLTransformer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.SQLTransformer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/SQLTransformerExample.scala %} +
    + +
    + +Refer to the [SQLTransformer Java docs](api/java/org/apache/spark/ml/feature/SQLTransformer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java %} +
    + +
    + +Refer to the [SQLTransformer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.SQLTransformer) for more details on the API. + +{% include_example python/ml/sql_transformer.py %} +
    +
    + ## VectorAssembler `VectorAssembler` is a transformer that combines a given list of columns into a single vector diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java new file mode 100644 index 0000000000000..d55c70796a967 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.SQLTransformer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaSQLTransformerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 1.0, 3.0), + RowFactory.create(2, 2.0, 5.0) + )); + StructType schema = new StructType(new StructField [] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("v2", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + SQLTransformer sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"); + + sqlTrans.transform(df).show(); + // $example off$ + } +} diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py new file mode 100644 index 0000000000000..9575d728d8159 --- /dev/null +++ b/examples/src/main/python/ml/sql_transformer.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import SQLTransformer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="SQLTransformerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, 1.0, 3.0), + (2, 2.0, 5.0) + ], ["id", "v1", "v2"]) + sqlTrans = SQLTransformer( + statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + sqlTrans.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala new file mode 100644 index 0000000000000..014abd1fdbc63 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.SQLTransformer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + + +object SQLTransformerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SQLTransformerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + + sqlTrans.transform(df).show() + // $example off$ + } +} +// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 3a735017ba836..c09f4d076c964 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -27,9 +27,16 @@ import org.apache.spark.sql.types.StructType /** * :: Experimental :: - * Implements the transforms which are defined by SQL statement. - * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * Implements the transformations which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...' * where '__THIS__' represents the underlying table of the input dataset. + * The select clause specifies the fields, constants, and expressions to display in + * the output, it can be any select clause that Spark SQL supports. Users can also + * use Spark SQL built-in function and UDFs to operate on these selected columns. + * For example, [[SQLTransformer]] supports statements like: + * - SELECT a, a + b AS a_b FROM __THIS__ + * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5 + * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b */ @Experimental @Since("1.6.0") From 48a9804b2ad89b3fb204c79f0dbadbcfea15d8dc Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 8 Dec 2015 11:02:35 +0000 Subject: [PATCH 1154/2089] =?UTF-8?q?[SPARK-12103][STREAMING][KAFKA][DOC]?= =?UTF-8?q?=20document=20that=20K=20means=20Key=20and=20V=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …means Value Author: cody koeninger Closes #10132 from koeninger/SPARK-12103. --- .../spark/streaming/kafka/KafkaUtils.scala | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index ad2fb8aa5f24c..fe572220528d5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -51,6 +51,7 @@ object KafkaUtils { * in its own thread * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( ssc: StreamingContext, @@ -74,6 +75,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel Storage level to use for storing the received objects + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, @@ -93,6 +99,7 @@ object KafkaUtils { * @param groupId The group id for this consumer * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -111,6 +118,7 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -135,6 +143,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread * @param storageLevel RDD storage level. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K, V, U <: Decoder[_], T <: Decoder[_]]( jssc: JavaStreamingContext, @@ -219,6 +232,11 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[ K: ClassTag, @@ -251,6 +269,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[ K: ClassTag, @@ -288,6 +312,15 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @param keyClass type of Kafka message key + * @param valueClass type of Kafka message value + * @param keyDecoderClass type of Kafka message key decoder + * @param valueDecoderClass type of Kafka message value decoder + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, @@ -321,6 +354,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, @@ -373,6 +412,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[ K: ClassTag, @@ -419,6 +464,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[ K: ClassTag, @@ -470,6 +520,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, @@ -529,6 +585,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, From 708129187a460aca30790281e9221c0cd5e271df Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 8 Dec 2015 11:05:06 +0000 Subject: [PATCH 1155/2089] [SPARK-12166][TEST] Unset hadoop related environment in testing Author: Jeff Zhang Closes #10172 from zjffdu/SPARK-12166. --- bin/spark-class | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bin/spark-class b/bin/spark-class index 87d06693af4fe..5d964ba96abd8 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -71,6 +71,12 @@ fi export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" +# For tests +if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR +fi + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. From 037b7e76a7f8b59e031873a768d81417dd180472 Mon Sep 17 00:00:00 2001 From: Nakul Jindal Date: Tue, 8 Dec 2015 11:08:27 +0000 Subject: [PATCH 1156/2089] [SPARK-11439][ML] Optimization of creating sparse feature without dense one Sparse feature generated in LinearDataGenerator does not create dense vectors as an intermediate any more. Author: Nakul Jindal Closes #9756 from nakul02/SPARK-11439_sparse_without_creating_dense_feature. --- .../mllib/util/LinearDataGenerator.scala | 44 ++-- .../evaluation/RegressionEvaluatorSuite.scala | 6 +- .../ml/regression/LinearRegressionSuite.scala | 214 ++++++++++-------- 3 files changed, 142 insertions(+), 122 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 6ff07eed6cfd2..094528e2ece06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -24,7 +24,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{BLAS, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -131,35 +131,27 @@ object LinearDataGenerator { eps: Double, sparsity: Double): Seq[LabeledPoint] = { require(0.0 <= sparsity && sparsity <= 1.0) - val rnd = new Random(seed) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble())) - - val sparseRnd = new Random(seed) - x.foreach { v => - var i = 0 - val len = v.length - while (i < len) { - if (sparseRnd.nextDouble() < sparsity) { - v(i) = 0.0 - } else { - v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) - } - i += 1 - } - } - val y = x.map { xi => - blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() - } + val rnd = new Random(seed) + def rndElement(i: Int) = {(rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)} - y.zip(x).map { p => - if (sparsity == 0.0) { + if (sparsity == 0.0) { + (0 until nPoints).map { _ => + val features = Vectors.dense(weights.indices.map { rndElement(_) }.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with DenseVector - LabeledPoint(p._1, Vectors.dense(p._2)) - } else { + LabeledPoint(label, features) + } + } else { + (0 until nPoints).map { _ => + val indices = weights.indices.filter { _ => rnd.nextDouble() <= sparsity} + val values = indices.map { rndElement(_) } + val features = Vectors.sparse(weights.length, indices.toArray, values.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with SparseVector - LabeledPoint(p._1, Vectors.dense(p._2).toSparse) + LabeledPoint(label, features) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 60886bf77d2f0..954d3bedc14bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -65,15 +65,15 @@ class RegressionEvaluatorSuite // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1013829 absTol 0.01) // r2 score evaluator.setMetricName("r2") - assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.9998387 absTol 0.01) // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2bdc0e184d734..2f3e703f4c252 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -42,6 +42,7 @@ class LinearRegressionSuite In `LinearRegressionSuite`, we will make sure that the model trained by SparkML is the same as the one trained by R's glmnet package. The following instruction describes how to reproduce the data in R. + In a spark-shell, use the following code: import org.apache.spark.mllib.util.LinearDataGenerator val data = @@ -184,15 +185,15 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.995908 - as.numeric.data.V3. 5.275131 + as.numeric.data.V2. 6.973403 + as.numeric.data.V3. 5.284370 */ - val coefficientsR = Vectors.dense(6.995908, 5.275131) + val coefficientsR = Vectors.dense(6.973403, 5.284370) - assert(model1.intercept ~== 0 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR relTol 1E-3) - assert(model2.intercept ~== 0 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR relTol 1E-3) + assert(model1.intercept ~== 0 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) + assert(model2.intercept ~== 0 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) /* Then again with the data with no intercept: @@ -235,14 +236,14 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.24300 - as.numeric.data.V2. 4.024821 - as.numeric.data.V3. 6.679841 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 */ - val interceptR1 = 6.24300 - val coefficientsR1 = Vectors.dense(4.024821, 6.679841) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -296,14 +297,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.299752 - as.numeric.data.V3. 4.772913 + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 4.782604 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(6.299752, 4.772913) + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -312,14 +313,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.232193 - as.numeric.data.V3. 4.764229 + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 4.775780 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(6.232193, 4.764229) + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -347,15 +348,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.269376 - as.numeric.data.V2. 3.736216 - as.numeric.data.V3. 5.712356) + (Intercept) 5.260103 + as.numeric.d1.V2. 3.725522 + as.numeric.d1.V3. 5.711203 */ - val interceptR1 = 5.269376 - val coefficientsR1 = Vectors.dense(3.736216, 5.712356) + val interceptR1 = 5.260103 + val coefficientsR1 = Vectors.dense(3.725522, 5.711203) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -363,15 +364,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.791109 - as.numeric.data.V2. 3.435466 - as.numeric.data.V3. 5.910406 + (Intercept) 5.790885 + as.numeric.d1.V2. 3.432373 + as.numeric.d1.V3. 5.919196 */ - val interceptR2 = 5.791109 - val coefficientsR2 = Vectors.dense(3.435466, 5.910406) + val interceptR2 = 5.790885 + val coefficientsR2 = Vectors.dense(3.432373, 5.919196) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -398,15 +399,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.data.V2. 5.522875 - as.numeric.data.V3. 4.214502 + (Intercept) . + as.numeric.d1.V2. 5.493430 + as.numeric.d1.V3. 4.223082 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.522875, 4.214502) + val coefficientsR1 = Vectors.dense(5.493430, 4.223082) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -415,14 +416,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.263704 - as.numeric.data.V3. 4.187419 + as.numeric.d1.V2. 5.244324 + as.numeric.d1.V3. 4.203106 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.263704, 4.187419) + val coefficientsR2 = Vectors.dense(5.244324, 4.203106) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -457,15 +458,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.324108 - as.numeric.data.V2. 3.168435 - as.numeric.data.V3. 5.200403 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 */ - val interceptR1 = 5.696056 - val coefficientsR1 = Vectors.dense(3.670489, 6.001122) + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 @@ -473,15 +474,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.114723 - as.numeric.data.V2. 3.409937 - as.numeric.data.V3. 6.146531 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 6.152512 */ - val interceptR2 = 6.114723 - val coefficientsR2 = Vectors.dense(3.409937, 6.146531) + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -518,15 +519,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.dataM.V2. 5.673348 - as.numeric.dataM.V3. 4.322251 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 4.331519 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.673348, 4.322251) + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, @@ -535,14 +536,15 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.477988 - as.numeric.data.V3. 4.297622 + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.477988, 4.297622) + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -592,21 +594,47 @@ class LinearRegressionSuite } /* - Use the following R code to generate model training results. - - predictions <- predict(fit, newx=features) - residuals <- label - predictions - > mean(residuals^2) # MSE - [1] 0.009720325 - > mean(abs(residuals)) # MAD - [1] 0.07863206 - > cor(predictions, label)^2# r^2 - [,1] - s0 0.9998749 + # Use the following R code to generate model training results. + + # path/part-00000 is the file generated by running LinearDataGenerator.generateLinearInput + # as described before the beforeAll() method. + d1 <- read.csv("path/part-00000", header=FALSE, stringsAsFactors=FALSE) + fit <- glm(V1 ~ V2 + V3, data = d1, family = "gaussian") + names(f1)[1] = c("V2") + names(f1)[2] = c("V3") + f1 <- data.frame(as.numeric(d1$V2), as.numeric(d1$V3)) + predictions <- predict(fit, newdata=f1) + l1 <- as.numeric(d1$V1) + + residuals <- l1 - predictions + > mean(residuals^2) # MSE + [1] 0.00985449 + > mean(abs(residuals)) # MAD + [1] 0.07961668 + > cor(predictions, l1)^2 # r^2 + [1] 0.9998737 + + > summary(fit) + + Call: + glm(formula = V1 ~ V2 + V3, family = "gaussian", data = d1) + + Deviance Residuals: + Min 1Q Median 3Q Max + -0.47082 -0.06797 0.00002 0.06725 0.34635 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** + V2 4.6982442 0.0011805 3980 <2e-16 *** + V3 7.1994344 0.0009044 7961 <2e-16 *** + --- + + .... */ - assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) - assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) - assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) + assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) + assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". This algorithm does not generate // objective history because it does not run through iterations. @@ -621,14 +649,14 @@ class LinearRegressionSuite // To clalify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) - val devianceResidualsR = Array(-0.35566, 0.34504) - val seCoefR = Array(0.0011756, 0.0009032, 0.0018489) - val tValsR = Array(3998, 7971, 3407) + val devianceResidualsR = Array(-0.47082, 0.34635) + val seCoefR = Array(0.0011805, 0.0009044, 0.0018600) + val tValsR = Array(3980, 7961, 3388) val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } } From da2012a0e152aa078bdd19a5c7f91786a2dd7016 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 8 Dec 2015 19:18:59 +0800 Subject: [PATCH 1157/2089] [SPARK-11551][DOC][EXAMPLE] Revert PR #10002 This reverts PR #10002, commit 78209b0ccaf3f22b5e2345dfb2b98edfdb746819. The original PR wasn't tested on Jenkins before being merged. Author: Cheng Lian Closes #10200 from liancheng/revert-pr-10002. --- docs/ml-features.md | 1109 ++++++++++++++++- .../examples/ml/JavaBinarizerExample.java | 68 - .../examples/ml/JavaBucketizerExample.java | 70 -- .../spark/examples/ml/JavaDCTExample.java | 65 - .../ml/JavaElementwiseProductExample.java | 75 -- .../examples/ml/JavaMinMaxScalerExample.java | 50 - .../spark/examples/ml/JavaNGramExample.java | 71 -- .../examples/ml/JavaNormalizerExample.java | 52 - .../examples/ml/JavaOneHotEncoderExample.java | 77 -- .../spark/examples/ml/JavaPCAExample.java | 71 -- .../ml/JavaPolynomialExpansionExample.java | 71 -- .../examples/ml/JavaRFormulaExample.java | 69 - .../ml/JavaStandardScalerExample.java | 53 - .../ml/JavaStopWordsRemoverExample.java | 65 - .../examples/ml/JavaStringIndexerExample.java | 66 - .../examples/ml/JavaTokenizerExample.java | 75 -- .../ml/JavaVectorAssemblerExample.java | 67 - .../examples/ml/JavaVectorIndexerExample.java | 60 - .../examples/ml/JavaVectorSlicerExample.java | 73 -- .../src/main/python/ml/binarizer_example.py | 43 - .../src/main/python/ml/bucketizer_example.py | 42 - .../python/ml/elementwise_product_example.py | 39 - examples/src/main/python/ml/n_gram_example.py | 42 - .../src/main/python/ml/normalizer_example.py | 41 - .../main/python/ml/onehot_encoder_example.py | 47 - examples/src/main/python/ml/pca_example.py | 42 - .../python/ml/polynomial_expansion_example.py | 43 - .../src/main/python/ml/rformula_example.py | 44 - .../main/python/ml/standard_scaler_example.py | 42 - .../python/ml/stopwords_remover_example.py | 40 - .../main/python/ml/string_indexer_example.py | 39 - .../src/main/python/ml/tokenizer_example.py | 44 - .../python/ml/vector_assembler_example.py | 42 - .../main/python/ml/vector_indexer_example.py | 39 - .../spark/examples/ml/BinarizerExample.scala | 48 - .../spark/examples/ml/BucketizerExample.scala | 51 - .../apache/spark/examples/ml/DCTExample.scala | 54 - .../ml/ElementWiseProductExample.scala | 53 - .../examples/ml/MinMaxScalerExample.scala | 49 - .../spark/examples/ml/NGramExample.scala | 47 - .../spark/examples/ml/NormalizerExample.scala | 50 - .../examples/ml/OneHotEncoderExample.scala | 58 - .../apache/spark/examples/ml/PCAExample.scala | 54 - .../ml/PolynomialExpansionExample.scala | 53 - .../spark/examples/ml/RFormulaExample.scala | 49 - .../examples/ml/StandardScalerExample.scala | 51 - .../examples/ml/StopWordsRemoverExample.scala | 48 - .../examples/ml/StringIndexerExample.scala | 49 - .../spark/examples/ml/TokenizerExample.scala | 54 - .../examples/ml/VectorAssemblerExample.scala | 49 - .../examples/ml/VectorIndexerExample.scala | 53 - .../examples/ml/VectorSlicerExample.scala | 58 - 52 files changed, 1058 insertions(+), 2806 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java delete mode 100644 examples/src/main/python/ml/binarizer_example.py delete mode 100644 examples/src/main/python/ml/bucketizer_example.py delete mode 100644 examples/src/main/python/ml/elementwise_product_example.py delete mode 100644 examples/src/main/python/ml/n_gram_example.py delete mode 100644 examples/src/main/python/ml/normalizer_example.py delete mode 100644 examples/src/main/python/ml/onehot_encoder_example.py delete mode 100644 examples/src/main/python/ml/pca_example.py delete mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py delete mode 100644 examples/src/main/python/ml/rformula_example.py delete mode 100644 examples/src/main/python/ml/standard_scaler_example.py delete mode 100644 examples/src/main/python/ml/stopwords_remover_example.py delete mode 100644 examples/src/main/python/ml/string_indexer_example.py delete mode 100644 examples/src/main/python/ml/tokenizer_example.py delete mode 100644 examples/src/main/python/ml/vector_assembler_example.py delete mode 100644 examples/src/main/python/ml/vector_indexer_example.py delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index f85e0d56d2e40..01d6abeb5ba6a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,7 +170,25 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} + +val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") +)).toDF("label", "sentence") +val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) +{% endhighlight %}
    @@ -179,7 +197,44 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) +}); +DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); +Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); +DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); +for (Row r : wordsDataFrame.select("words", "label").take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); +} + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); +{% endhighlight %}
    @@ -188,7 +243,21 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% include_example python/ml/tokenizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Tokenizer, RegexTokenizer + +sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") +], ["label", "sentence"]) +tokenizer = Tokenizer(inputCol="sentence", outputCol="words") +wordsDataFrame = tokenizer.transform(sentenceDataFrame) +for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) +{% endhighlight %}
    @@ -237,7 +306,19 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StopWordsRemover + +val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") +val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) +)).toDF("id", "raw") + +remover.transform(dataSet).show() +{% endhighlight %}
    @@ -245,7 +326,34 @@ for more details on the API. Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) +)); +StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame dataset = jsql.createDataFrame(rdd, schema); + +remover.transform(dataset).show(); +{% endhighlight %}
    @@ -253,7 +361,17 @@ for more details on the API. Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% include_example python/ml/stopwords_remover_example.py %} +{% highlight python %} +from pyspark.ml.feature import StopWordsRemover + +sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) +], ["label", "raw"]) + +remover = StopWordsRemover(inputCol="raw", outputCol="filtered") +remover.transform(sentenceData).show(truncate=False) +{% endhighlight %}
    @@ -270,7 +388,19 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.NGram + +val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) +)).toDF("label", "words") + +val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") +val ngramDataFrame = ngram.transform(wordDataFrame) +ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) +{% endhighlight %}
    @@ -278,7 +408,38 @@ for more details on the API. Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); +NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); +DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); +for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); +} +{% endhighlight %}
    @@ -286,7 +447,19 @@ for more details on the API. Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% include_example python/ml/n_gram_example.py %} +{% highlight python %} +from pyspark.ml.feature import NGram + +wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) +], ["label", "words"]) +ngram = NGram(inputCol="words", outputCol="ngrams") +ngramDataFrame = ngram.transform(wordDataFrame) +for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) +{% endhighlight %}
    @@ -303,7 +476,26 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Binarizer +import org.apache.spark.sql.DataFrame + +val data = Array( + (0, 0.1), + (1, 0.8), + (2, 0.2) +) +val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + +val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + +val binarizedDataFrame = binarizer.transform(dataFrame) +val binarizedFeatures = binarizedDataFrame.select("binarized_feature") +binarizedFeatures.collect().foreach(println) +{% endhighlight %}
    @@ -311,7 +503,40 @@ for more details on the API. Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); +Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); +DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); +DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); +for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); +} +{% endhighlight %}
    @@ -319,7 +544,20 @@ for more details on the API. Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% include_example python/ml/binarizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Binarizer + +continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) +], ["label", "feature"]) +binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") +binarizedDataFrame = binarizer.transform(continuousDataFrame) +binarizedFeatures = binarizedDataFrame.select("binarized_feature") +for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) +{% endhighlight %}
    @@ -333,7 +571,25 @@ for more details on the API. Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) +val pcaDF = pca.transform(df) +val result = pcaDF.select("pcaFeatures") +result.show() +{% endhighlight %}
    @@ -341,7 +597,42 @@ for more details on the API. Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.PCA +import org.apache.spark.ml.feature.PCAModel +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); +DataFrame result = pca.transform(df).select("pcaFeatures"); +result.show(); +{% endhighlight %}
    @@ -349,7 +640,19 @@ for more details on the API. Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% include_example python/ml/pca_example.py %} +{% highlight python %} +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] +df = sqlContext.createDataFrame(data,["features"]) +pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") +model = pca.fit(df) +result = model.transform(df).select("pcaFeatures") +result.show(truncate=False) +{% endhighlight %}
    @@ -363,7 +666,23 @@ for more details on the API. Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) +val polyDF = polynomialExpansion.transform(df) +polyDF.select("polyFeatures").take(3).foreach(println) +{% endhighlight %}
    @@ -371,7 +690,43 @@ for more details on the API. Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DataFrame polyDF = polyExpansion.transform(df); +Row[] row = polyDF.select("polyFeatures").take(3); +for (Row r : row) { + System.out.println(r.get(0)); +} +{% endhighlight %}
    @@ -379,7 +734,20 @@ for more details on the API. Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% include_example python/ml/polynomial_expansion_example.py %} +{% highlight python %} +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors + +df = sqlContext.createDataFrame( + [(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) +px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") +polyDF = px.transform(df) +for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) +{% endhighlight %}
    @@ -403,7 +771,22 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors + +val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) +val dctDf = dct.transform(df) +dctDf.select("featuresDCT").show(3) +{% endhighlight %}
    @@ -411,7 +794,39 @@ for more details on the API. Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); +DataFrame dctDf = dct.transform(df); +dctDf.select("featuresDCT").show(3); +{% endhighlight %}
    @@ -466,7 +881,18 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %}
    @@ -474,7 +900,37 @@ for more details on the API. Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", DoubleType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %}
    @@ -482,7 +938,16 @@ for more details on the API. Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% include_example python/ml/string_indexer_example.py %} +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %}
    @@ -496,7 +961,29 @@ for more details on the API. Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} + +val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") +)).toDF("id", "category") + +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) +val indexed = indexer.transform(df) + +val encoder = new OneHotEncoder().setInputCol("categoryIndex"). + setOutputCol("categoryVec") +val encoded = encoder.transform(indexed) +encoded.select("id", "categoryVec").foreach(println) +{% endhighlight %}
    @@ -504,7 +991,45 @@ for more details on the API. Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); +DataFrame indexed = indexer.transform(df); + +OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); +DataFrame encoded = encoder.transform(indexed); +{% endhighlight %}
    @@ -512,7 +1037,24 @@ for more details on the API. Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% highlight python %} +from pyspark.ml.feature import OneHotEncoder, StringIndexer + +df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") +], ["id", "category"]) + +stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +model = stringIndexer.fit(df) +indexed = model.transform(df) +encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") +encoded = encoder.transform(indexed) +{% endhighlight %}
    @@ -536,7 +1078,23 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.VectorIndexer + +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) +val indexerModel = indexer.fit(data) +val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet +println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + +// Create new column "indexed" with categorical values transformed to indices +val indexedData = indexerModel.transform(data) +{% endhighlight %}
    @@ -544,7 +1102,30 @@ for more details on the API. Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %} +{% highlight java %} +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); +VectorIndexerModel indexerModel = indexer.fit(data); +Map> categoryMaps = indexerModel.javaCategoryMaps(); +System.out.print("Chose " + categoryMaps.size() + "categorical features:"); +for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); +} +System.out.println(); + +// Create new column "indexed" with categorical values transformed to indices +DataFrame indexedData = indexerModel.transform(data); +{% endhighlight %}
    @@ -552,7 +1133,17 @@ for more details on the API. Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% include_example python/ml/vector_indexer_example.py %} +{% highlight python %} +from pyspark.ml.feature import VectorIndexer + +data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) +indexerModel = indexer.fit(data) + +# Create new column "indexed" with categorical values transformed to indices +indexedData = indexerModel.transform(data) +{% endhighlight %}
    @@ -569,7 +1160,22 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Normalizer + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") + +// Normalize each Vector using $L^1$ norm. +val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) +val l1NormData = normalizer.transform(dataFrame) + +// Normalize each Vector using $L^\infty$ norm. +val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) +{% endhighlight %}
    @@ -577,7 +1183,24 @@ for more details on the API. Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %} +{% highlight java %} +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + +// Normalize each Vector using $L^1$ norm. +Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); +DataFrame l1NormData = normalizer.transform(dataFrame); + +// Normalize each Vector using $L^\infty$ norm. +DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); +{% endhighlight %}
    @@ -585,7 +1208,19 @@ for more details on the API. Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% include_example python/ml/normalizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Normalizer + +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") + +# Normalize each Vector using $L^1$ norm. +normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) +l1NormData = normalizer.transform(dataFrame) + +# Normalize each Vector using $L^\infty$ norm. +lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) +{% endhighlight %}
    @@ -609,7 +1244,23 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StandardScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + +// Compute summary statistics by fitting the StandardScaler +val scalerModel = scaler.fit(dataFrame) + +// Normalize each feature to have unit standard deviation. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -617,7 +1268,25 @@ for more details on the API. Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %} +{% highlight java %} +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + +// Compute summary statistics by fitting the StandardScaler +StandardScalerModel scalerModel = scaler.fit(dataFrame); + +// Normalize each feature to have unit standard deviation. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %}
    @@ -625,7 +1294,20 @@ for more details on the API. Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% include_example python/ml/standard_scaler_example.py %} +{% highlight python %} +from pyspark.ml.feature import StandardScaler + +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + +# Compute summary statistics by fitting the StandardScaler +scalerModel = scaler.fit(dataFrame) + +# Normalize each feature to have unit standard deviation. +scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -655,7 +1337,21 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.MinMaxScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + +// Compute summary statistics and generate MinMaxScalerModel +val scalerModel = scaler.fit(dataFrame) + +// rescale each feature to range [min, max]. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -664,7 +1360,24 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %} +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + +// Compute summary statistics and generate MinMaxScalerModel +MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + +// rescale each feature to range [min, max]. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %}
    @@ -688,7 +1401,23 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Bucketizer +import org.apache.spark.sql.DataFrame + +val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + +val data = Array(-0.5, -0.3, 0.0, 0.2) +val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + +val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + +// Transform original data into its bucket index. +val bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %}
    @@ -696,7 +1425,38 @@ for more details on the API. Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame dataFrame = jsql.createDataFrame(data, schema); + +Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + +// Transform original data into its bucket index. +DataFrame bucketedData = bucketizer.transform(dataFrame); +{% endhighlight %}
    @@ -704,7 +1464,19 @@ for more details on the API. Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% include_example python/ml/bucketizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Bucketizer + +splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + +data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] +dataFrame = sqlContext.createDataFrame(data, ["features"]) + +bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + +# Transform original data into its bucket index. +bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %}
    @@ -736,7 +1508,25 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + +// Batch transform the vectors to create new column: +transformer.transform(dataFrame).show() + +{% endhighlight %}
    @@ -744,7 +1534,41 @@ for more details on the API. Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +// Create some vector data; also works for sparse vectors +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) +)); +List fields = new ArrayList(2); +fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); +fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); +StructType schema = DataTypes.createStructType(fields); +DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); +// Batch transform the vectors to create new column: +transformer.transform(dataFrame).show(); + +{% endhighlight %}
    @@ -752,8 +1576,19 @@ for more details on the API. Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% include_example python/ml/elementwise_product_example.py %} +{% highlight python %} +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] +df = sqlContext.createDataFrame(data, ["vector"]) +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") +transformer.transform(df).show() + +{% endhighlight %}
    + ## SQLTransformer @@ -856,7 +1691,19 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %} +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.feature.VectorAssembler + +val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) +).toDF("id", "hour", "mobile", "userFeatures", "clicked") +val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") +val output = assembler.transform(dataset) +println(output.select("features", "clicked").first()) +{% endhighlight %}
    @@ -864,7 +1711,36 @@ for more details on the API. Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) +}); +Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); +JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + +DataFrame output = assembler.transform(dataset); +System.out.println(output.select("features", "clicked").first()); +{% endhighlight %}
    @@ -872,7 +1748,19 @@ for more details on the API. Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% include_example python/ml/vector_assembler_example.py %} +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler + +dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) +assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") +output = assembler.transform(dataset) +print(output.select("features", "clicked").first()) +{% endhighlight %}
    @@ -1002,7 +1890,33 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %} +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0) +) + +val defaultAttr = NumericAttribute.defaultAttr +val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) +val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + +val dataRDD = sc.parallelize(data).map(Row.apply) +val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) + +val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + +slicer.setIndices(1).setNames("f3") +// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + +val output = slicer.transform(dataset) +println(output.select("userFeatures", "features").first()) +{% endhighlight %}
    @@ -1010,7 +1924,41 @@ for more details on the API. Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") +}; +AttributeGroup group = new AttributeGroup("userFeatures", attrs); + +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) +)); + +DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + +VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + +vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); +// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + +DataFrame output = vectorSlicer.transform(dataset); + +System.out.println(output.select("userFeatures", "features").first()); +{% endhighlight %}
    @@ -1047,7 +1995,21 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.RFormula + +val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) +)).toDF("id", "country", "hour", "clicked") +val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") +val output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %}
    @@ -1055,7 +2017,38 @@ for more details on the API. Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) +}); +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) +)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + +DataFrame output = formula.fit(dataset).transform(dataset); +output.select("features", "label").show(); +{% endhighlight %}
    @@ -1063,7 +2056,21 @@ for more details on the API. Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% include_example python/ml/rformula_example.py %} +{% highlight python %} +from pyspark.ml.feature import RFormula + +dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) +formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") +output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java deleted file mode 100644 index 9698cac504371..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaBinarizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) - }); - DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); - Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); - DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); - DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); - for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java deleted file mode 100644 index b06a23e76d604..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Bucketizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaBucketizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) - }); - DataFrame dataFrame = jsql.createDataFrame(data, schema); - - Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - - // Transform original data into its bucket index. - DataFrame bucketedData = bucketizer.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java deleted file mode 100644 index 35c0d534a45e9..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaDCTExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - DataFrame df = jsql.createDataFrame(data, schema); - DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); - DataFrame dctDf = dct.transform(df); - dctDf.select("featuresDCT").show(3); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java deleted file mode 100644 index 2898accec61b0..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaElementwiseProductExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - // Create some vector data; also works for sparse vectors - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) - )); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); - fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); - - StructType schema = DataTypes.createStructType(fields); - - DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); - - Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); - - ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); - - // Batch transform the vectors to create new column: - transformer.transform(dataFrame).show(); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java deleted file mode 100644 index 138b3ab6aba44..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaMinMaxScalerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - - // Compute summary statistics and generate MinMaxScalerModel - MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - - // rescale each feature to range [min, max]. - DataFrame scaledData = scalerModel.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java deleted file mode 100644 index 8fd75ed8b5f4e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaNGramExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField( - "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) - }); - - DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); - - NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); - - DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); - - for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java deleted file mode 100644 index 6283a355e1fef..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaNormalizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - // Normalize each Vector using $L^1$ norm. - Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); - - DataFrame l1NormData = normalizer.transform(dataFrame); - - // Normalize each Vector using $L^\infty$ norm. - DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java deleted file mode 100644 index 172a9cc6feb28..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaOneHotEncoderExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) - }); - - DataFrame df = sqlContext.createDataFrame(jrdd, schema); - - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - DataFrame indexed = indexer.transform(df); - - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - DataFrame encoded = encoder.transform(indexed); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java deleted file mode 100644 index 8282fab084f36..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.PCA; -import org.apache.spark.ml.feature.PCAModel; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaPCAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - DataFrame df = jsql.createDataFrame(data, schema); - - PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); - - DataFrame result = pca.transform(df).select("pcaFeatures"); - result.show(); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java deleted file mode 100644 index 668f71e64056b..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.PolynomialExpansion; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaPolynomialExpansionExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); - - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - DataFrame df = jsql.createDataFrame(data, schema); - DataFrame polyDF = polyExpansion.transform(df); - - Row[] row = polyDF.select("polyFeatures").take(3); - for (Row r : row) { - System.out.println(r.get(0)); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java deleted file mode 100644 index 1e1062b541ad9..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaRFormulaExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - StructType schema = createStructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) - }); - - JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) - )); - - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - DataFrame output = formula.fit(dataset).transform(dataset); - output.select("features", "label").show(); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java deleted file mode 100644 index 0cbdc97e8ae30..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaStandardScalerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - - // Compute summary statistics by fitting the StandardScaler - StandardScalerModel scalerModel = scaler.fit(dataFrame); - - // Normalize each feature to have unit standard deviation. - DataFrame scaledData = scalerModel.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java deleted file mode 100644 index b6b201c6b68d2..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaStopWordsRemoverExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - - JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField( - "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) - }); - - DataFrame dataset = jsql.createDataFrame(rdd, schema); - remover.transform(dataset).show(); - // $example off$ - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java deleted file mode 100644 index 05d12c1e702f1..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaStringIndexerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") - )); - StructType schema = new StructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("category", StringType, false) - }); - DataFrame df = sqlContext.createDataFrame(jrdd, schema); - StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); - DataFrame indexed = indexer.fit(df).transform(df); - indexed.show(); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java deleted file mode 100644 index 617dc3f66e3bf..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaTokenizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) - }); - - DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); - - Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - - DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); - for (Row r : wordsDataFrame.select("words", "label"). take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); - } - - RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); - // $example off$ - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java deleted file mode 100644 index 7e230b5897c1e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaVectorAssemblerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - StructType schema = createStructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) - }); - Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); - JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - - VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - - DataFrame output = assembler.transform(dataset); - System.out.println(output.select("features", "clicked").first()); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java deleted file mode 100644 index 06b4bf6bf8ff6..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaVectorIndexerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); - VectorIndexerModel indexerModel = indexer.fit(data); - - Map> categoryMaps = indexerModel.javaCategoryMaps(); - System.out.print("Chose " + categoryMaps.size() + " categorical features:"); - - for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); - } - System.out.println(); - - // Create new column "indexed" with categorical values transformed to indices - DataFrame indexedData = indexerModel.transform(data); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java deleted file mode 100644 index 4d5cb04ff5e2b..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import com.google.common.collect.Lists; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.attribute.Attribute; -import org.apache.spark.ml.attribute.AttributeGroup; -import org.apache.spark.ml.attribute.NumericAttribute; -import org.apache.spark.ml.feature.VectorSlicer; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -// $example off$ - -public class JavaVectorSlicerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") - }; - AttributeGroup group = new AttributeGroup("userFeatures", attrs); - - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) - )); - - DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - - VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - - vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - - DataFrame output = vectorSlicer.transform(dataset); - - System.out.println(output.select("userFeatures", "features").first()); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py deleted file mode 100644 index 960ad208be12e..0000000000000 --- a/examples/src/main/python/ml/binarizer_example.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Binarizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="BinarizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) - ], ["label", "feature"]) - binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") - binarizedDataFrame = binarizer.transform(continuousDataFrame) - binarizedFeatures = binarizedDataFrame.select("binarized_feature") - for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py deleted file mode 100644 index a12750aa9248a..0000000000000 --- a/examples/src/main/python/ml/bucketizer_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Bucketizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="BucketizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - - data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] - dataFrame = sqlContext.createDataFrame(data, ["features"]) - - bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - - # Transform original data into its bucket index. - bucketedData = bucketizer.transform(dataFrame) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py deleted file mode 100644 index c85cb0d89543c..0000000000000 --- a/examples/src/main/python/ml/elementwise_product_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="ElementwiseProductExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] - df = sqlContext.createDataFrame(data, ["vector"]) - transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") - transformer.transform(df).show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py deleted file mode 100644 index f2d85f53e7219..0000000000000 --- a/examples/src/main/python/ml/n_gram_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import NGram -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="NGramExample") - sqlContext = SQLContext(sc) - - # $example on$ - wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) - ], ["label", "words"]) - ngram = NGram(inputCol="words", outputCol="ngrams") - ngramDataFrame = ngram.transform(wordDataFrame) - for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py deleted file mode 100644 index 833d93e976a7e..0000000000000 --- a/examples/src/main/python/ml/normalizer_example.py +++ /dev/null @@ -1,41 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Normalizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="NormalizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Normalize each Vector using $L^1$ norm. - normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) - l1NormData = normalizer.transform(dataFrame) - - # Normalize each Vector using $L^\infty$ norm. - lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py deleted file mode 100644 index 7529dfd09213a..0000000000000 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="OneHotEncoderExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) - - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py deleted file mode 100644 index 8b66140a40a7a..0000000000000 --- a/examples/src/main/python/ml/pca_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="PCAExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] - df = sqlContext.createDataFrame(data,["features"]) - pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") - model = pca.fit(df) - result = model.transform(df).select("pcaFeatures") - result.show(truncate=False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py deleted file mode 100644 index 030a6132a451a..0000000000000 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="PolynomialExpansionExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) - px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") - polyDF = px.transform(df) - for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py deleted file mode 100644 index b544a14700762..0000000000000 --- a/examples/src/main/python/ml/rformula_example.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import RFormula -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="RFormulaExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) - formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") - output = formula.fit(dataset).transform(dataset) - output.select("features", "label").show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py deleted file mode 100644 index 139acecbfb53f..0000000000000 --- a/examples/src/main/python/ml/standard_scaler_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StandardScaler -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StandardScalerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - - # Compute summary statistics by fitting the StandardScaler - scalerModel = scaler.fit(dataFrame) - - # Normalize each feature to have unit standard deviation. - scaledData = scalerModel.transform(dataFrame) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py deleted file mode 100644 index 01f94af8ca752..0000000000000 --- a/examples/src/main/python/ml/stopwords_remover_example.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StopWordsRemover -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StopWordsRemoverExample") - sqlContext = SQLContext(sc) - - # $example on$ - sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) - ], ["label", "raw"]) - - remover = StopWordsRemover(inputCol="raw", outputCol="filtered") - remover.transform(sentenceData).show(truncate=False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py deleted file mode 100644 index 58a8cb5d56b73..0000000000000 --- a/examples/src/main/python/ml/string_indexer_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StringIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StringIndexerExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) - indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - indexed = indexer.fit(df).transform(df) - indexed.show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py deleted file mode 100644 index ce9b225be5357..0000000000000 --- a/examples/src/main/python/ml/tokenizer_example.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Tokenizer, RegexTokenizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="TokenizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") - ], ["label", "sentence"]) - tokenizer = Tokenizer(inputCol="sentence", outputCol="words") - wordsDataFrame = tokenizer.transform(sentenceDataFrame) - for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) - regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") - # alternatively, pattern="\\w+", gaps(False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py deleted file mode 100644 index 04f64839f188d..0000000000000 --- a/examples/src/main/python/ml/vector_assembler_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="VectorAssemblerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) - assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") - output = assembler.transform(dataset) - print(output.select("features", "clicked").first()) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py deleted file mode 100644 index cc00d1454f2e0..0000000000000 --- a/examples/src/main/python/ml/vector_indexer_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import VectorIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="VectorIndexerExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) - indexerModel = indexer.fit(data) - - # Create new column "indexed" with categorical values transformed to indices - indexedData = indexerModel.transform(data) - # $example off$ - - sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala deleted file mode 100644 index e724aa587294b..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Binarizer -// $example off$ -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} - -object BinarizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BinarizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - // $example on$ - val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) - val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - - val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - - val binarizedDataFrame = binarizer.transform(dataFrame) - val binarizedFeatures = binarizedDataFrame.select("binarized_feature") - binarizedFeatures.collect().foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala deleted file mode 100644 index 30c2776d39688..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Bucketizer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object BucketizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BucketizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - - val data = Array(-0.5, -0.3, 0.0, 0.2) - val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - - val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - - // Transform original data into its bucket index. - val bucketedData = bucketizer.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala deleted file mode 100644 index 314c2c28a2a10..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object DCTExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("DCTExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) - - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - - val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) - - val dctDf = dct.transform(df) - dctDf.select("featuresDCT").show(3) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala deleted file mode 100644 index ac50bb7b2b155..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object ElementwiseProductExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("ElementwiseProductExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Create some vector data; also works for sparse vectors - val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - - val transformingVector = Vectors.dense(0.0, 1.0, 2.0) - val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - - // Batch transform the vectors to create new column: - transformer.transform(dataFrame).show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala deleted file mode 100644 index dac3679a5bf7e..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.MinMaxScaler -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object MinMaxScalerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("MinMaxScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - - // Compute summary statistics and generate MinMaxScalerModel - val scalerModel = scaler.fit(dataFrame) - - // rescale each feature to range [min, max]. - val scaledData = scalerModel.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala deleted file mode 100644 index 8a85f71b56f3d..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.NGram -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object NGramExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NGramExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) - )).toDF("label", "words") - - val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") - val ngramDataFrame = ngram.transform(wordDataFrame) - ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala deleted file mode 100644 index 17571f0aad793..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Normalizer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object NormalizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NormalizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - // Normalize each Vector using $L^1$ norm. - val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) - - val l1NormData = normalizer.transform(dataFrame) - - // Normalize each Vector using $L^\infty$ norm. - val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala deleted file mode 100644 index 4512736943dd5..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object OneHotEncoderExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("OneHotEncoderExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") - val encoded = encoder.transform(indexed) - encoded.select("id", "categoryVec").foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala deleted file mode 100644 index a18d4f33973d8..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object PCAExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PCAExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) - ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) - val pcaDF = pca.transform(df) - val result = pcaDF.select("pcaFeatures") - result.show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala deleted file mode 100644 index b8e9e6952a5ea..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object PolynomialExpansionExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PolynomialExpansionExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) - ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) - val polyDF = polynomialExpansion.transform(df) - polyDF.select("polyFeatures").take(3).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala deleted file mode 100644 index 286866edea502..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.RFormula -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object RFormulaExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RFormulaExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) - )).toDF("id", "country", "hour", "clicked") - val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") - val output = formula.fit(dataset).transform(dataset) - output.select("features", "label").show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala deleted file mode 100644 index 646ce0f13ecf5..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StandardScaler -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StandardScalerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StandardScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - - // Compute summary statistics by fitting the StandardScaler. - val scalerModel = scaler.fit(dataFrame) - - // Normalize each feature to have unit standard deviation. - val scaledData = scalerModel.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala deleted file mode 100644 index 655ffce08d3ab..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StopWordsRemover -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StopWordsRemoverExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StopWordsRemoverExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") - - val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) - )).toDF("id", "raw") - - remover.transform(dataSet).show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala deleted file mode 100644 index 1be8a5f33f7c0..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StringIndexer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StringIndexerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StringIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) - ).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - - val indexed = indexer.fit(df).transform(df) - indexed.show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala deleted file mode 100644 index 01e0d1388a2f4..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object TokenizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("TokenizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") - )).toDF("label", "sentence") - - val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") - val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - - val tokenized = tokenizer.transform(sentenceDataFrame) - tokenized.select("words", "label").take(3).foreach(println) - val regexTokenized = regexTokenizer.transform(sentenceDataFrame) - regexTokenized.select("words", "label").take(3).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala deleted file mode 100644 index d527924419f81..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorAssemblerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorAssemblerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) - ).toDF("id", "hour", "mobile", "userFeatures", "clicked") - - val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") - - val output = assembler.transform(dataset) - println(output.select("features", "clicked").first()) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala deleted file mode 100644 index 14279d610fda8..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.VectorIndexer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorIndexerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) - - val indexerModel = indexer.fit(data) - - val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - - // Create new column "indexed" with categorical values transformed to indices - val indexedData = indexerModel.transform(data) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala deleted file mode 100644 index 04f19829eff87..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.Row -import org.apache.spark.sql.types.StructType -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorSlicerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorSlicerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) - - val defaultAttr = NumericAttribute.defaultAttr - val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) - val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - - val dataRDD = sc.parallelize(data) - val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) - - val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - - slicer.setIndices(Array(1)).setNames(Array("f3")) - // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - - val output = slicer.transform(dataset) - println(output.select("userFeatures", "features").first()) - // $example off$ - sc.stop() - } -} -// scalastyle:on println From e3735ce1602826f0a8e0ca9e08730923843449ee Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 8 Dec 2015 14:34:47 +0000 Subject: [PATCH 1158/2089] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Fix commons-collection group ID to commons-collections for version 3.x Patches earlier PR at https://github.com/apache/spark/pull/9731 Author: Sean Owen Closes #10198 from srowen/SPARK-11652.2. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ae2ff8878b0a5..5daca03f61436 100644 --- a/pom.xml +++ b/pom.xml @@ -478,7 +478,7 @@ ${commons.math3.version}
    - org.apache.commons + commons-collections commons-collections ${commons.collections.version} From 6cb06e8711fd6ac10c57faeb94bc323cae1cef27 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Tue, 8 Dec 2015 11:44:51 -0600 Subject: [PATCH 1159/2089] [SPARK-11155][WEB UI] Stage summary json should include stage duration The json endpoint for stages doesn't include information on the stage duration that is present in the UI. This looks like a simple oversight, they should be included. eg., the metrics should be included at api/v1/applications//stages. Metrics I've added are: submissionTime, firstTaskLaunchedTime and completionTime Author: Xin Ren Closes #10107 from keypointt/SPARK-11155. --- .../status/api/v1/AllStagesResource.scala | 14 ++++- .../org/apache/spark/status/api/v1/api.scala | 3 + .../complete_stage_list_json_expectation.json | 11 +++- .../failed_stage_list_json_expectation.json | 5 +- .../one_stage_attempt_json_expectation.json | 5 +- .../one_stage_json_expectation.json | 5 +- .../stage_list_json_expectation.json | 14 ++++- ...ist_with_accumulable_json_expectation.json | 5 +- ...age_with_accumulable_json_expectation.json | 5 +- .../api/v1/AllStagesResourceSuite.scala | 62 +++++++++++++++++++ project/MimaExcludes.scala | 4 +- 11 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 24a0b5220695c..31b4dd7c0f427 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,8 +17,8 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.{GET, PathParam, Produces, QueryParam} import javax.ws.rs.core.MediaType +import javax.ws.rs.{GET, Produces, QueryParam} import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} @@ -59,6 +59,15 @@ private[v1] object AllStagesResource { stageUiData: StageUIData, includeDetails: Boolean): StageData = { + val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + + val firstTaskLaunchedTime: Option[Date] = + if (taskLaunchTimes.nonEmpty) { + Some(new Date(taskLaunchTimes.min)) + } else { + None + } + val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) } else { @@ -92,6 +101,9 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + submissionTime = stageInfo.submissionTime.map(new Date(_)), + firstTaskLaunchedTime, + completionTime = stageInfo.completionTime.map(new Date(_)), inputBytes = stageUiData.inputBytes, inputRecords = stageUiData.inputRecords, outputBytes = stageUiData.outputBytes, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index baddfc50c1a40..5feb1dc2e5b74 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -120,6 +120,9 @@ class StageData private[spark]( val numFailedTasks: Int, val executorRunTime: Long, + val submissionTime: Option[Date], + val firstTaskLaunchedTime: Option[Date], + val completionTime: Option[Date], val inputBytes: Long, val inputRecords: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 31ac9beea8788..8f8067f86d57f 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -64,4 +73,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index bff6a4f69d077..08b692eda8028 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -20,4 +23,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 111cb8163eb3d..b07011d4f113f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index ef339f89afa45..2f71520549e1f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 056fac7088594..5b957ed549556 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -72,6 +81,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -86,4 +98,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 79ccacd309693..afa425f8c27bb 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -24,4 +27,4 @@ "name" : "my counter", "value" : "5050" } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 32d5731676ad5..12665a152c9ec 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -239,4 +242,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala new file mode 100644 index 0000000000000..88817dccf3497 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1 + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} +import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} + +class AllStagesResourceSuite extends SparkFunSuite { + + def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { + val tasks = new HashMap[Long, TaskUIData] + taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => + tasks(idx.toLong) = new TaskUIData( + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None, None) + } + + val stageUiData = new StageUIData() + stageUiData.taskData = tasks + val status = StageStatus.ACTIVE + val stageInfo = new StageInfo( + 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc", Seq.empty) + val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) + + stageData.firstTaskLaunchedTime + } + + test("firstTaskLaunchedTime when there are no tasks") { + val result = getFirstTaskLaunchTime(Seq()) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks but none launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks and some launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) + assert(result == Some(new Date(1449255596000L))) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b4aa6adc3c620..685cb419ca8a7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -132,7 +132,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") ) ++ Seq ( ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationInfo.this") + "org.apache.spark.status.api.v1.ApplicationInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.StageData.this") ) ++ Seq( // SPARK-11766 add toJson to Vector ProblemFilters.exclude[MissingMethodProblem]( From 75c60bf4ba91e45e76a6e27f054a1c550eb6ff94 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 8 Dec 2015 10:01:44 -0800 Subject: [PATCH 1160/2089] [SPARK-12074] Avoid memory copy involving ByteBuffer.wrap(ByteArrayOutputStream.toByteArray) SPARK-12060 fixed JavaSerializerInstance.serialize This PR applies the same technique on two other classes. zsxwing Author: tedyu Closes #10177 from tedyu/master. --- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 7 +++---- .../main/scala/org/apache/spark/storage/BlockManager.scala | 4 ++-- .../org/apache/spark/util/ByteBufferOutputStream.scala | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5fe5ae8c45819..d4bc3a5c900f7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -27,8 +27,7 @@ import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** @@ -172,7 +171,7 @@ private[spark] object Task { serializer: SerializerInstance) : ByteBuffer = { - val out = new ByteArrayOutputStream(4096) + val out = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -193,7 +192,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task) Utils.writeByteBuffer(taskBytes, out) - ByteBuffer.wrap(out.toByteArray) + out.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ab0007fb78993..ed05143877e20 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1202,9 +1202,9 @@ private[spark] class BlockManager( blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new ByteArrayOutputStream(4096) + val byteStream = new ByteBufferOutputStream(4096) dataSerializeStream(blockId, byteStream, values, serializer) - ByteBuffer.wrap(byteStream.toByteArray) + byteStream.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index 92e45224db81c..8527e3ae692e2 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -23,7 +23,9 @@ import java.nio.ByteBuffer /** * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer */ -private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { +private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) { + + def this() = this(32) def toByteBuffer: ByteBuffer = { return ByteBuffer.wrap(buf, 0, count) From 381f17b540d92507cc07adf18bce8bc7e5ca5407 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Dec 2015 10:13:40 -0800 Subject: [PATCH 1161/2089] [SPARK-12201][SQL] add type coercion rule for greatest/least checked with hive, greatest/least should cast their children to a tightest common type, i.e. `(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error` Author: Wenchen Fan Closes #10196 from cloud-fan/type-coercion. --- .../catalyst/analysis/HiveTypeCoercion.scala | 14 +++++++++++ .../ExpressionTypeCheckingSuite.scala | 10 ++++++++ .../analysis/HiveTypeCoercionSuite.scala | 23 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 29502a59915f0..dbcbd6854b474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -594,6 +594,20 @@ object HiveTypeCoercion { case None => c } + case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case None => g + } + + case l @ Least(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case None => l + } + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ba1866efc84e1..915c585ec91fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, + 'decimalField.decimal(8, 0), 'arrayField.array(StringType), 'mapField.map(StringType, LongType)) @@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } + + test("check types for Greatest/Least") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('intField, 'stringField)), "should all have the same type") + assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") + assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d3fafaae89938..142915056f451 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + operator(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Nil)) + } + } + test("nanvl casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), From c0b13d5565c45ae2acbe8cfb17319c92b6a634e4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 Dec 2015 10:15:58 -0800 Subject: [PATCH 1162/2089] [SPARK-12195][SQL] Adding BigDecimal, Date and Timestamp into Encoder This PR is to add three more data types into Encoder, including `BigDecimal`, `Date` and `Timestamp`. marmbrus cloud-fan rxin Could you take a quick look at these three types? Not sure if it can be merged to 1.6. Thank you very much! Author: gatorsmile Closes #10188 from gatorsmile/dataTypesinEncoder. --- .../scala/org/apache/spark/sql/Encoder.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index c40061ae0aafd..3ca5ade7f30f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,24 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + /** + * An encoder for nullable decimal type. + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + /** * Creates an encoder for Java Bean of type T. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ae47f4fe0e231..383a2d0badb53 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -18,6 +18,9 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; import java.util.*; import scala.Tuple2; @@ -385,6 +388,20 @@ public void testNestedTupleEncoder() { Assert.assertEquals(data3, ds3.collectAsList()); } + @Test + public void testPrimitiveEncoder() { + Encoder> encoder = + Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), + Encoders.FLOAT()); + List> data = + Arrays.asList(new Tuple5( + 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), + Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); + Dataset> ds = + context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + @Test public void testTypedAggregation() { Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); From 5d96a710a5ed543ec81e383620fc3b2a808b26a1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 Dec 2015 10:25:57 -0800 Subject: [PATCH 1163/2089] [SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs This PR contains the following updates: - Created a new private variable `boundTEncoder` that can be shared by multiple functions, `RDD`, `select` and `collect`. - Replaced all the `queryExecution.analyzed` by the function call `logicalPlan` - A few API comments are using wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`) - A few API descriptions are wrong. (e.g., `mapPartitions`) marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you! Author: gatorsmile Closes #10184 from gatorsmile/datasetClean. --- .../scala/org/apache/spark/sql/Dataset.scala | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d6bb1d2ad8e50..3bd18a14f9e8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -67,15 +67,21 @@ class Dataset[T] private[sql]( tEncoder: Encoder[T]) extends Queryable with Serializable { /** - * An unresolved version of the internal encoder for the type of this dataset. This one is marked - * implicit so that we can use it when constructing new [[Dataset]] objects that have the same - * object type (that will be possibly resolved to a different schema). + * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is + * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the + * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) + unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + + /** + * The encoder where the expressions used to construct an object from an input row have been + * bound to the ordinals of the given schema. + */ + private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) private implicit def classTag = resolvedTEncoder.clsTag @@ -89,7 +95,7 @@ class Dataset[T] private[sql]( override def schema: StructType = resolvedTEncoder.schema /** - * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format. * @since 1.6.0 */ override def printSchema(): Unit = toDF().printSchema() @@ -111,7 +117,7 @@ class Dataset[T] private[sql]( * ************* */ /** - * Returns a new `Dataset` where each record has been mapped on to the specified type. The + * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The * method used to map columns depend on the type of `U`: * - When `U` is a class, fields for the class will be mapped to columns of the same name * (case sensitivity is determined by `spark.sql.caseSensitive`) @@ -145,7 +151,7 @@ class Dataset[T] private[sql]( def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) /** - * Returns this Dataset. + * Returns this [[Dataset]]. * @since 1.6.0 */ // This is declared with parentheses to prevent the Scala compiler from treating @@ -153,15 +159,12 @@ class Dataset[T] private[sql]( def toDS(): Dataset[T] = this /** - * Converts this Dataset to an RDD. + * Converts this [[Dataset]] to an [[RDD]]. * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = resolvedTEncoder - val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => - val bound = tEnc.bind(input) - iter.map(bound.fromRow) + iter.map(boundTEncoder.fromRow) } } @@ -189,7 +192,7 @@ class Dataset[T] private[sql]( def show(numRows: Int): Unit = show(numRows, truncate = true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters * will be truncated, and all cells will be aligned right. * * @since 1.6.0 @@ -197,7 +200,7 @@ class Dataset[T] private[sql]( def show(): Unit = show(20) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[Dataset]] in a tabular form. * * @param truncate Whether truncate long strings. If true, strings more than 20 characters will * be truncated and all cells will be aligned right @@ -207,7 +210,7 @@ class Dataset[T] private[sql]( def show(truncate: Boolean): Unit = show(20, truncate) /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[Dataset]] in a tabular form. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -291,7 +294,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { @@ -307,7 +310,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { @@ -341,28 +344,28 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Runs `func` on each element of this Dataset. + * Runs `func` on each element of this [[Dataset]]. * @since 1.6.0 */ def foreach(func: T => Unit): Unit = rdd.foreach(func) /** * (Java-specific) - * Runs `func` on each element of this Dataset. + * Runs `func` on each element of this [[Dataset]]. * @since 1.6.0 */ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** * (Scala-specific) - * Runs `func` on each partition of this Dataset. + * Runs `func` on each partition of this [[Dataset]]. * @since 1.6.0 */ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) /** * (Java-specific) - * Runs `func` on each partition of this Dataset. + * Runs `func` on each partition of this [[Dataset]]. * @since 1.6.0 */ def foreachPartition(func: ForeachPartitionFunction[T]): Unit = @@ -374,7 +377,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -382,7 +385,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -390,11 +393,11 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { - val inputPlan = queryExecution.analyzed + val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) @@ -429,18 +432,18 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ - def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = - groupBy(f.call(_))(encoder) + def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(func.call(_))(encoder) /* ****************** * * Typed Relational * * ****************** */ /** - * Selects a set of column based expressions. + * Returns a new [[DataFrame]] by selecting a set of column based expressions. * {{{ * df.select($"colA", $"colB" + 1) * }}} @@ -464,8 +467,8 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder.bind(queryExecution.analyzed.output), - queryExecution.analyzed.output).named :: Nil, + boundTEncoder, + logicalPlan.output).named :: Nil, logicalPlan)) } @@ -477,7 +480,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named) + columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) @@ -654,7 +657,7 @@ class Dataset[T] private[sql]( * Returns an array that contains all the elements in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @since 1.6.0 @@ -662,17 +665,14 @@ class Dataset[T] private[sql]( def collect(): Array[T] = { // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders // to convert the rows into objects of type T. - val tEnc = resolvedTEncoder - val input = queryExecution.analyzed.output - val bound = tEnc.bind(input) - queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow) + queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) } /** * Returns an array that contains all the elements in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @since 1.6.0 @@ -683,7 +683,7 @@ class Dataset[T] private[sql]( * Returns the first `num` elements of this [[Dataset]] as an array. * * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -692,7 +692,7 @@ class Dataset[T] private[sql]( * Returns the first `num` elements of this [[Dataset]] as an array. * * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) From 872a2ee281d84f40a786f765bf772cdb06e8c956 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Dec 2015 10:29:51 -0800 Subject: [PATCH 1164/2089] [SPARK-10393] use ML pipeline in LDA example jira: https://issues.apache.org/jira/browse/SPARK-10393 Since the logic of the text processing part has been moved to ML estimators/transformers, replace the related code in LDA Example with the ML pipeline. Author: Yuhao Yang Author: yuhaoyang Closes #8551 from hhbyyh/ldaExUpdate. --- .../spark/examples/mllib/LDAExample.scala | 153 +++++------------- 1 file changed, 40 insertions(+), 113 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 75b0f69cf91aa..70010b05e4345 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,19 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import java.text.BreakIterator - -import scala.collection.mutable - import scopt.OptionParser import org.apache.log4j.{Level, Logger} - -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} /** * An example Latent Dirichlet Allocation (LDA) app. Run with @@ -192,115 +189,45 @@ object LDAExample { vocabSize: Int, stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + // Get dataset of document texts // One document per line in each text file. If the input consists of many small files, // this can result in a large number of small partitions, which can degrade performance. // In this case, consider using coalesce() to create fewer, larger partitions. - val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) - - // Split text into words - val tokenizer = new SimpleTokenizer(sc, stopwordFile) - val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => - id -> tokenizer.getWords(text) - } - tokenized.cache() - - // Counts words: RDD[(word, wordCount)] - val wordCounts: RDD[(String, Long)] = tokenized - .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } - .reduceByKey(_ + _) - wordCounts.cache() - val fullVocabSize = wordCounts.count() - // Select vocab - // (vocab: Map[word -> id], total tokens after selecting vocab) - val (vocab: Map[String, Int], selectedTokenCount: Long) = { - val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocabSize) - } - (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) - } - - val documents = tokenized.map { case (id, tokens) => - // Filter tokens by vocabulary, and create word count vector representation of document. - val wc = new mutable.HashMap[Int, Int]() - tokens.foreach { term => - if (vocab.contains(term)) { - val termIndex = vocab(term) - wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 - } - } - val indices = wc.keys.toArray.sorted - val values = indices.map(i => wc(i).toDouble) - - val sb = Vectors.sparse(vocab.size, indices, values) - (id, sb) - } - - val vocabArray = new Array[String](vocab.size) - vocab.foreach { case (term, i) => vocabArray(i) = term } - - (documents, vocabArray, selectedTokenCount) - } -} - -/** - * Simple Tokenizer. - * - * TODO: Formalize the interface, and make this a public class in mllib.feature - */ -private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { - - private val stopwords: Set[String] = if (stopwordFile.isEmpty) { - Set.empty[String] - } else { - val stopwordText = sc.textFile(stopwordFile).collect() - stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet - } - - // Matches sequences of Unicode letters - private val allWordRegex = "^(\\p{L}*)$".r - - // Ignore words shorter than this length. - private val minWordLength = 3 - - def getWords(text: String): IndexedSeq[String] = { - - val words = new mutable.ArrayBuffer[String]() - - // Use Java BreakIterator to tokenize text into words. - val wb = BreakIterator.getWordInstance - wb.setText(text) - - // current,end index start,end of each word - var current = wb.first() - var end = wb.next() - while (end != BreakIterator.DONE) { - // Convert to lowercase - val word: String = text.substring(current, end).toLowerCase - // Remove short words and strings that aren't only letters - word match { - case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => - words += w - case _ => - } - - current = end - try { - end = wb.next() - } catch { - case e: Exception => - // Ignore remaining text in line. - // This is a known bug in BreakIterator (for some Java versions), - // which fails when it sees certain characters. - end = BreakIterator.DONE - } + val df = sc.textFile(paths.mkString(",")).toDF("docs") + val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) { + Array.empty[String] + } else { + val stopWordText = sc.textFile(stopwordFile).collect() + stopWordText.flatMap(_.stripMargin.split("\\s+")) } - words + val tokenizer = new RegexTokenizer() + .setInputCol("docs") + .setOutputCol("rawTokens") + val stopWordsRemover = new StopWordsRemover() + .setInputCol("rawTokens") + .setOutputCol("tokens") + stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) + val countVectorizer = new CountVectorizer() + .setVocabSize(vocabSize) + .setInputCol("tokens") + .setOutputCol("features") + + val pipeline = new Pipeline() + .setStages(Array(tokenizer, stopWordsRemover, countVectorizer)) + + val model = pipeline.fit(df) + val documents = model.transform(df) + .select("features") + .map { case Row(features: Vector) => features } + .zipWithIndex() + .map(_.swap) + + (documents, + model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary + documents.map(_._2.numActives).sum().toLong) // total token count } - } // scalastyle:on println From 4bcb894948c1b7294d84e2bf58abb1d79e6759c6 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Tue, 8 Dec 2015 10:52:17 -0800 Subject: [PATCH 1165/2089] [SPARK-12205][SQL] Pivot fails Analysis when aggregate is UnresolvedFunction Delays application of ResolvePivot until all aggregates are resolved to prevent problems with UnresolvedFunction and adds unit test Author: Andrew Ray Closes #10202 from aray/sql-pivot-unresolved-function. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../scala/org/apache/spark/sql/DataFramePivotSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d3163dcd4db94..ca00a5e49f668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -259,7 +259,7 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved => p + case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index fc53aba68ebb7..bc1a336ea4fd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -85,4 +85,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } + + test("pivot with UnresolvedFunction") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg("earnings" -> "sum"), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } } From 5cb4695051e3dac847b1ea14d62e54dcf672c31c Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Dec 2015 11:46:26 -0800 Subject: [PATCH 1166/2089] [SPARK-11605][MLLIB] ML 1.6 QA: API: Java compatibility, docs jira: https://issues.apache.org/jira/browse/SPARK-11605 Check Java compatibility for MLlib for this release. fix: 1. `StreamingTest.registerStream` needs java friendly interface. 2. `GradientBoostedTreesModel.computeInitialPredictionAndError` and `GradientBoostedTreesModel.updatePredictionError` has java compatibility issue. Mark them as `developerAPI`. TBD: [updated] no fix for now per discussion. `org.apache.spark.mllib.classification.LogisticRegressionModel` `public scala.Option getThreshold();` has wrong return type for Java invocation. `SVMModel` has the similar issue. Yet adding a `scala.Option getThreshold()` would result in an overloading error due to the same function signature. And adding a new function with different name seems to be not necessary. cc jkbradley feynmanliang Author: Yuhao Yang Closes #10102 from hhbyyh/javaAPI. --- .../examples/mllib/StreamingTestExample.scala | 4 +- .../spark/mllib/stat/test/StreamingTest.scala | 50 +++++++++++++++---- .../mllib/tree/model/treeEnsembleModels.scala | 6 ++- .../spark/mllib/stat/JavaStatisticsSuite.java | 38 ++++++++++++-- .../spark/mllib/stat/StreamingTestSuite.scala | 25 +++++----- 5 files changed, 96 insertions(+), 27 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index b6677c6476639..49f5df39443e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.SparkConf -import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.util.Utils @@ -66,7 +66,7 @@ object StreamingTestExample { // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { - case Array(label, value) => (label.toBoolean, value.toDouble) + case Array(label, value) => BinarySample(label.toBoolean, value.toDouble) }) val streamingTest = new StreamingTest() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 75c6a51d09571..e990fe0768bc9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,12 +17,30 @@ package org.apache.spark.mllib.stat.test +import scala.beans.BeanInfo + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.JavaDStream import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter +/** + * Class that represents the group and value of a sample. + * + * @param isExperiment if the sample is of the experiment group. + * @param value numeric value of the observation. + */ +@Since("1.6.0") +@BeanInfo +case class BinarySample @Since("1.6.0") ( + @Since("1.6.0") isExperiment: Boolean, + @Since("1.6.0") value: Double) { + override def toString: String = { + s"($isExperiment, $value)" + } +} + /** * :: Experimental :: * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The @@ -83,13 +101,13 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** * Register a [[DStream]] of values for significance testing. * - * @param data stream of (key,value) pairs where the key denotes group membership (true = - * experiment, false = control) and the value is the numerical metric to test for - * significance + * @param data stream of BinarySample(key,value) pairs where the key denotes group membership + * (true = experiment, false = control) and the value is the numerical metric to + * test for significance * @return stream of significance testing results */ @Since("1.6.0") - def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = { val dataAfterPeacePeriod = dropPeacePeriod(data) val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) val pairedSummaries = pairSummaries(summarizedData) @@ -97,9 +115,22 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { testMethod.doTest(pairedSummaries) } + /** + * Register a [[JavaDStream]] of values for significance testing. + * + * @param data stream of BinarySample(isExperiment,value) pairs where the isExperiment denotes + * group (true = experiment, false = control) and the value is the numerical metric + * to test for significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = { + JavaDStream.fromDStream(registerStream(data.dstream)) + } + /** Drop all batches inside the peace period. */ private[stat] def dropPeacePeriod( - data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data: DStream[BinarySample]): DStream[BinarySample] = { data.transform { (rdd, time) => if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { rdd @@ -111,9 +142,10 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** Compute summary statistics over each key and the specified test window size. */ private[stat] def summarizeByKeyAndWindow( - data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = { + val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value)) if (this.windowSize == 0) { - data.updateStateByKey[StatCounter]( + categoryValuePair.updateStateByKey[StatCounter]( (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { val newSummary = oldSummary.getOrElse(new StatCounter()) newSummary.merge(newValues) @@ -121,7 +153,7 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { }) } else { val windowDuration = data.slideDuration * this.windowSize - data + categoryValuePair .groupByKeyAndWindow(windowDuration) .mapValues { values => val summary = new StatCounter() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 3f427f0be3af2..feabcee24fa2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Since +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -186,6 +186,7 @@ class GradientBoostedTreesModel @Since("1.2.0") ( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** + * :: DeveloperApi :: * Compute the initial predictions and errors for a dataset for the first * iteration of gradient boosting. * @param data: training data. @@ -196,6 +197,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to every sample. */ @Since("1.4.0") + @DeveloperApi def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -209,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { } /** + * :: DeveloperApi :: * Update a zipped predictionError RDD * (as obtained with computeInitialPredictionAndError) * @param data: training data. @@ -220,6 +223,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to each sample. */ @Since("1.4.0") + @DeveloperApi def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 4795809e47a46..66b2ceacb05f2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -18,34 +18,49 @@ package org.apache.spark.mllib.stat; import java.io.Serializable; - import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.BinarySample; import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; +import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; + private transient JavaStreamingContext ssc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaStatistics"); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("JavaStatistics") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + sc = new JavaSparkContext(conf); + ssc = new JavaStreamingContext(sc, new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After public void tearDown() { - sc.stop(); + ssc.stop(); + ssc = null; sc = null; } @@ -76,4 +91,21 @@ public void chiSqTest() { new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); ChiSqTestResult[] testResults = Statistics.chiSqTest(data); } + + @Test + public void streamingTest() { + List trainingBatch = Arrays.asList( + new BinarySample(true, 1.0), + new BinarySample(false, 2.0)); + JavaDStream training = + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + int numBatches = 2; + StreamingTest model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod("welch"); + model.registerStream(training); + attachTestOutputStream(training); + runStreams(ssc, numBatches, numBatches); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index d3e9ef4ff079c..3c657c8cfe743 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, + WelchTTest, BinarySample} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter @@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) @@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -157,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, - (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream)) val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) val outputCounts = outputBatches.flatten.map(_._2.count) @@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream)) val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) @@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { .setPeacePeriod(0) val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) - .map(batch => batch.filter(_._1)) // only keep one test group + .map(batch => batch.filter(_.isExperiment)) // only keep one test group // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) @@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { stdevA: Double, meanB: Double, stdevB: Double, - seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = { val rand = new XORShiftRandom(seed) val numTrues = pointsPerBatch / 2 val data = (0 until numBatches).map { i => - (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++ (pointsPerBatch / 2 until pointsPerBatch).map { idx => - (false, meanB + stdevB * rand.nextGaussian()) + BinarySample(false, meanB + stdevB * rand.nextGaussian()) } } From 06746b3005e5e9892d0314bee3bfdfaebc36d3d4 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Tue, 8 Dec 2015 12:45:34 -0800 Subject: [PATCH 1167/2089] [SPARK-12159][ML] Add user guide section for IndexToString transformer Documentation regarding the `IndexToString` label transformer with code snippets in Scala/Java/Python. Author: BenFradet Closes #10166 from BenFradet/SPARK-12159. --- docs/ml-features.md | 104 +++++++++++++++--- .../examples/ml/JavaIndexToStringExample.java | 75 +++++++++++++ .../main/python/ml/index_to_string_example.py | 45 ++++++++ .../examples/ml/IndexToStringExample.scala | 60 ++++++++++ 4 files changed, 268 insertions(+), 16 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java create mode 100644 examples/src/main/python/ml/index_to_string_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 01d6abeb5ba6a..e15c26836affc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -835,10 +835,10 @@ dctDf.select("featuresDCT").show(3); `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string -values. When downstream pipeline components such as `Estimator` or -`Transformer` make use of this string-indexed label, you must set the input -column of the component to this string-indexed column name. In many cases, +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, you can set the input column with `setInputCol`. **Examples** @@ -951,9 +951,78 @@ indexed.show() + +## IndexToString + +Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices +back to a column containing the original labels as strings. The common use case +is to produce indices from labels with `StringIndexer`, train a model with those +indices and retrieve the original labels from the column of predicted indices +with `IndexToString`. However, you are free to supply your own labels. + +**Examples** + +Building on the `StringIndexer` example, let's assume we have the following +DataFrame with columns `id` and `categoryIndex`: + +~~~~ + id | categoryIndex +----|--------------- + 0 | 0.0 + 1 | 2.0 + 2 | 1.0 + 3 | 0.0 + 4 | 0.0 + 5 | 1.0 +~~~~ + +Applying `IndexToString` with `categoryIndex` as the input column, +`originalCategory` as the output column, we are able to retrieve our original +labels (they will be inferred from the columns' metadata): + +~~~~ + id | categoryIndex | originalCategory +----|---------------|----------------- + 0 | 0.0 | a + 1 | 2.0 | b + 2 | 1.0 | c + 3 | 0.0 | a + 4 | 0.0 | a + 5 | 1.0 | c +~~~~ + +
    +
    + +Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %} + +
    + +
    + +Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %} + +
    + +
    + +Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString) +for more details on the API. + +{% include_example python/ml/index_to_string_example.py %} + +
    +
    + ## OneHotEncoder -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
    @@ -979,10 +1048,11 @@ val indexer = new StringIndexer() .fit(df) val indexed = indexer.transform(df) -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") +val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) +encoded.select("id", "categoryVec").show() {% endhighlight %}
    @@ -1015,7 +1085,7 @@ JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(5, "c") )); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("category", DataTypes.StringType, false, Metadata.empty()) }); DataFrame df = sqlContext.createDataFrame(jrdd, schema); @@ -1029,6 +1099,7 @@ OneHotEncoder encoder = new OneHotEncoder() .setInputCol("categoryIndex") .setOutputCol("categoryVec"); DataFrame encoded = encoder.transform(indexed); +encoded.select("id", "categoryVec").show(); {% endhighlight %}
    @@ -1054,6 +1125,7 @@ model = stringIndexer.fit(df) indexed = model.transform(df) encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") encoded = encoder.transform(indexed) +encoded.select("id", "categoryVec").show() {% endhighlight %} @@ -1582,7 +1654,7 @@ from pyspark.mllib.linalg import Vectors data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), inputCol="vector", outputCol="transformedVector") transformer.transform(df).show() @@ -1837,15 +1909,15 @@ for more details on the API. sub-array of the original features. It is useful for extracting features from a vector column. `VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column -whose values are selected via those indices. There are two types of indices, +whose values are selected via those indices. There are two types of indices, 1. Integer indices that represents the indices into the vector, `setIndices()`; - 2. String indices that represents the names of features into the vector, `setNames()`. + 2. String indices that represents the names of features into the vector, `setNames()`. *This requires the vector column to have an `AttributeGroup` since the implementation matches on the name field of an `Attribute`.* -Specification by integer and string are both acceptable. Moreover, you can use integer index and +Specification by integer and string are both acceptable. Moreover, you can use integer index and string name simultaneously. At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names. Note that if names of features are selected, an exception will be threw out when encountering with empty input attributes. @@ -1858,9 +1930,9 @@ followed by the selected names (in the order given). Suppose that we have a DataFrame with the column `userFeatures`: ~~~ - userFeatures + userFeatures ------------------ - [0.0, 10.0, 0.5] + [0.0, 10.0, 0.5] ~~~ `userFeatures` is a vector column that contains three user features. Assuming that the first column @@ -1874,7 +1946,7 @@ column named `features`: [0.0, 10.0, 0.5] | [10.0, 0.5] ~~~ -Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. `["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. ~~~ diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java new file mode 100644 index 0000000000000..3ccd6993261e2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.IndexToString; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaIndexToStringExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + IndexToString converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory"); + DataFrame converted = converter.transform(indexed); + converted.select("id", "originalCategory").show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py new file mode 100644 index 0000000000000..fb0ba2950bbd6 --- /dev/null +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import IndexToString, StringIndexer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="IndexToStringExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + + converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory") + converted = converter.transform(indexed) + + converted.select("id", "originalCategory").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala new file mode 100644 index 0000000000000..52537e5bb568d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{StringIndexer, IndexToString} +// $example off$ + +object IndexToStringExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("IndexToStringExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory") + + val converted = converter.transform(indexed) + converted.select("id", "originalCategory").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 2ff17bcfb1818a1b1e357343cd507883dcfdd2ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 Dec 2015 13:13:56 -0800 Subject: [PATCH 1168/2089] [SPARK-3873][BUILD] Add style checker to enforce import ordering. The checker tries to follow as closely as possible the guidelines of the code style document, and makes some decisions where the guide is not clear. In particular: - wildcard imports come first when there are other imports in the same package - multi-import blocks come before single imports - lower-case names inside multi-import blocks come before others In some projects, such as graphx, there seems to be a convention to separate o.a.s imports from the project's own; to simplify the checker, I chose not to allow that, which is a strict interpretation of the code style guide, even though I think it makes sense. Since the checks are based on syntax only, some edge cases may generate spurious warnings; for example, when class names start with a lower case letter (and are thus treated as a package name by the checker). The checker is currently only generating warnings, and since there are many of those, the build output does get a little noisy. The idea is to fix the code (and the checker, as needed) little by little instead of having a huge change that touches everywhere. Author: Marcelo Vanzin Closes #6502 from vanzin/SPARK-3873. --- scalastyle-config.xml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 050c3f360476f..dab1ebddc666e 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -153,7 +153,7 @@ This file is divided into 3 sections: @VisibleForTesting @@ -203,6 +203,18 @@ This file is divided into 3 sections: + + + + java,scala,3rdParty,spark + javax?\..+ + scala\..+ + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + From 9494521695a1f1526aae76c0aea34a3bead96251 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 8 Dec 2015 14:34:15 -0800 Subject: [PATCH 1169/2089] [SPARK-12187] *MemoryPool classes should not be fully public This patch tightens them to `private[memory]`. Author: Andrew Or Closes #10182 from andrewor14/memory-visibility. --- .../scala/org/apache/spark/memory/ExecutionMemoryPool.scala | 2 +- core/src/main/scala/org/apache/spark/memory/MemoryPool.scala | 2 +- .../main/scala/org/apache/spark/memory/StorageMemoryPool.scala | 2 +- .../scala/org/apache/spark/memory/UnifiedMemoryManager.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 7825bae425877..9023e1ac012b7 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -39,7 +39,7 @@ import org.apache.spark.Logging * @param lock a [[MemoryManager]] instance to synchronize on * @param poolName a human-readable name for this pool, for use in log messages */ -class ExecutionMemoryPool( +private[memory] class ExecutionMemoryPool( lock: Object, poolName: String ) extends MemoryPool(lock) with Logging { diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala index bfeec47e3892e..1b9edf9c43bda 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -27,7 +27,7 @@ import javax.annotation.concurrent.GuardedBy * to `Object` to avoid programming errors, since this object should only be used for * synchronization purposes. */ -abstract class MemoryPool(lock: Object) { +private[memory] abstract class MemoryPool(lock: Object) { @GuardedBy("lock") private[this] var _poolSize: Long = 0 diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 6a322eabf81ed..fc4f0357e9f16 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} * * @param lock a [[MemoryManager]] instance to synchronize on */ -class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { +private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { @GuardedBy("lock") private[this] var _memoryUsed: Long = 0L diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 48b4e23433e43..0f1ea9ab39c07 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -49,7 +49,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, val maxMemory: Long, - private val storageRegionSize: Long, + storageRegionSize: Long, numCores: Int) extends MemoryManager( conf, From 39594894232e0b70c5ca8b0df137da0d61223fd5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Dec 2015 15:58:35 -0800 Subject: [PATCH 1170/2089] [SPARK-12069][SQL] Update documentation with Datasets Author: Michael Armbrust Closes #10060 from marmbrus/docs. --- docs/_layouts/global.html | 2 +- docs/index.md | 2 +- docs/sql-programming-guide.md | 268 +++++++++++------- .../scala/org/apache/spark/sql/Encoder.scala | 48 +++- .../scala/org/apache/spark/sql/Column.scala | 21 +- 5 files changed, 237 insertions(+), 104 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 1b09e2221e173..0b5b0cd48af64 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -71,7 +71,7 @@
  • Spark Programming Guide
  • Spark Streaming
  • -
  • DataFrames and SQL
  • +
  • DataFrames, Datasets and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • diff --git a/docs/index.md b/docs/index.md index f1d9e012c6cf0..ae26f97c86c21 100644 --- a/docs/index.md +++ b/docs/index.md @@ -87,7 +87,7 @@ options for deployment: in all supported languages (Scala, Java, Python, R) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL, Datasets, and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7b1d97baa3823..9f87accd30f40 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Spark SQL and DataFrame Guide +displayTitle: Spark SQL, DataFrames and Datasets Guide title: Spark SQL and DataFrames --- @@ -9,18 +9,51 @@ title: Spark SQL and DataFrames # Overview -Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided +by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, +Spark SQL uses this extra information to perform extra optimizations. There are several ways to +interact with Spark SQL including SQL, the DataFrames API and the Datasets API. When computing a result +the same execution engine is used, independent of which API/language you are using to express the +computation. This unification means that developers can easily switch back and forth between the +various APIs based on which provides the most natural way to express a given transformation. -Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. +All of the examples on this page use sample data included in the Spark distribution and can be run in +the `spark-shell`, `pyspark` shell, or `sparkR` shell. -# DataFrames +## SQL -A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. +One use of Spark SQL is to execute SQL queries written using either a basic SQL syntax or HiveQL. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to +configure this feature, please refer to the [Hive Tables](#hive-tables) section. When running +SQL from within another programming language the results will be returned as a [DataFrame](#DataFrames). +You can also interact with the SQL interface using the [command-line](#running-the-spark-sql-cli) +or over [JDBC/ODBC](#running-the-thrift-jdbcodbc-server). -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). +## DataFrames -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R/Python, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of [sources](#data-sources) such +as: structured data files, tables in Hive, external databases, or existing RDDs. +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), +[Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), +[Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). + +## Datasets + +A Dataset is a new experimental interface added in Spark 1.6 that tries to provide the benefits of +RDDs (strong typing, ability to use powerful lambda functions) with the benefits of Spark SQL's +optimized execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then manipulated +using functional transformations (map, flatMap, filter, etc.). + +The unified Dataset API can be used both in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset) and +[Java](api/java/index.html?org/apache/spark/sql/Dataset.html). Python does not yet have support for +the Dataset API, but due to its dynamic nature many of the benefits are already available (i.e. you can +access the field of a row by name naturally `row.columnName`). Full python support will be added +in a future release. + +# Getting Started ## Starting Point: SQLContext @@ -29,7 +62,7 @@ All of the examples on this page use sample data included in the Spark distribut The entry point into all functionality in Spark SQL is the [`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} val sc: SparkContext // An existing SparkContext. @@ -45,7 +78,7 @@ import sqlContext.implicits._ The entry point into all functionality in Spark SQL is the [`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. @@ -58,7 +91,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); The entry point into all relational functionality in Spark is the [`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one -of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext @@ -70,7 +103,7 @@ sqlContext = SQLContext(sc)
    The entry point into all relational functionality in Spark is the -`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight r %} sqlContext <- sparkRSQL.init(sc) @@ -82,18 +115,18 @@ sqlContext <- sparkRSQL.init(sc) In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a superset of the functionality provided by the basic `SQLContext`. Additional features include the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the -ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an existing Hive setup, and all of the data sources available to a `SQLContext` are still available. `HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using `HiveContext` -is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up to feature parity with a `HiveContext`. The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, this is recommended for most use cases. @@ -215,7 +248,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$).
    @@ -270,7 +303,7 @@ df.groupBy("age").count().show(); For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). @@ -331,7 +364,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). @@ -385,7 +418,7 @@ showDF(count(groupBy(df, "age"))) For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). @@ -398,14 +431,14 @@ The `sql` function on a `SQLContext` enables applications to run SQL queries pro
    {% highlight scala %} -val sqlContext = ... // An existing SQLContext +val sqlContext = ... // An existing SQLContext val df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    {% highlight java %} -SQLContext sqlContext = ... // An existing SQLContext +SQLContext sqlContext = ... // An existing SQLContext DataFrame df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    @@ -428,15 +461,54 @@ df <- sql(sqlContext, "SELECT * FROM table")
    +## Creating Datasets + +Datasets are similar to RDDs, however, instead of using Java Serialization or Kryo they use +a specialized [Encoder](api/scala/index.html#org.apache.spark.sql.Encoder) to serialize the objects +for processing or transmitting over the network. While both encoders and standard serialization are +responsible for turning an object into bytes, encoders are code generated dynamically and use a format +that allows Spark to perform many operations like filtering, sorting and hashing without deserializing +the bytes back into an object. + +
    +
    + +{% highlight scala %} +// Encoders for most common types are automatically provided by importing sqlContext.implicits._ +val ds = Seq(1, 2, 3).toDS() +ds.map(_ + 1).collect() // Returns: Array(2, 3, 4) + +// Encoders are also created for case classes. +case class Person(name: String, age: Long) +val ds = Seq(Person("Andy", 32)).toDS() + +// DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name. +val path = "examples/src/main/resources/people.json" +val people = sqlContext.read.json(path).as[Person] + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} + +
    +
    + ## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first -method uses reflection to infer the schema of an RDD that contains specific types of objects. This +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating DataFrames is through a programmatic interface that allows you to -construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows +construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection @@ -445,11 +517,11 @@ you to construct DataFrames when the columns and their types are not known until
    The Scala interface for Spark SQL supports automatically converting an RDD containing case classes -to a DataFrame. The case class -defines the schema of the table. The names of the arguments to the case class are read using +to a DataFrame. The case class +defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be -registered as a table. Tables can be used in subsequent SQL statements. +registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. @@ -486,9 +558,9 @@ teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(printl
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain -nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a +nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. {% highlight java %} @@ -559,9 +631,9 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
    -Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by looking at the first row. Since we currently only look at the first +and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we plan to more completely infer the schema by looking at more data, similar to the inference that is performed on JSON files. @@ -780,7 +852,7 @@ for name in names.collect(): Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a DataFrame as a table allows you to run SQL queries over its data. This section +Registering a DataFrame as a table allows you to run SQL queries over its data. This section describes the general methods for loading and saving data using the Spark Data Sources and then goes into specific options that are available for the built-in data sources. @@ -834,9 +906,9 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") ### Manually Specifying Options You can also manually specify the data source that will be used along with any extra options -that you would like to pass to the data source. Data sources are specified by their fully qualified +that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
    @@ -923,8 +995,8 @@ df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users. ### Save Modes Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if -present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the new data. @@ -960,7 +1032,7 @@ new data.
    Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
    @@ -968,14 +1040,14 @@ new data. ### Saving to Persistent Tables When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the -`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the -contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables will still exist even after your Spark program has restarted, as long as you maintain your connection -to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` method on a `SQLContext` with the name of the table. By default `saveAsTable` will create a "managed table", meaning that the location of the data will -be controlled by the metastore. Managed tables will also have their data deleted automatically +be controlled by the metastore. Managed tables will also have their data deleted automatically when a table is dropped. ## Parquet Files @@ -1003,7 +1075,7 @@ val people: RDD[Person] = ... // An RDD of case class objects, from the previous // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.write.parquet("people.parquet") -// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.read.parquet("people.parquet") @@ -1025,7 +1097,7 @@ DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write().parquet("people.parquet"); -// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); @@ -1051,7 +1123,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write.parquet("people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.read.parquet("people.parquet") @@ -1075,7 +1147,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. saveAsParquetFile(schemaPeople, "people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile <- parquetFile(sqlContext, "people.parquet") @@ -1110,10 +1182,10 @@ SELECT * FROM parquetTable ### Partition Discovery -Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in -the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For example, we can store all our previously used +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1155,7 +1227,7 @@ root {% endhighlight %} -Notice that the data types of the partitioning columns are automatically inferred. Currently, +Notice that the data types of the partitioning columns are automatically inferred. Currently, numeric data types and string type are supported. Sometimes users may not want to automatically infer the data types of the partitioning columns. For these use cases, the automatic type inference can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to @@ -1164,13 +1236,13 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w ### Schema Merging -Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with -a simple schema, and gradually add more columns to the schema as needed. In this way, users may end -up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we -turned it off by default starting from 1.5.0. You may enable it by +turned it off by default starting from 1.5.0. You may enable it by 1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the examples below), or @@ -1284,10 +1356,10 @@ processing. 1. Hive considers all columns nullable, while nullability in Parquet is significant Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a -Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: 1. Fields that have the same name in both schema must have the same data type regardless of - nullability. The reconciled field should have the data type of the Parquet side, so that + nullability. The reconciled field should have the data type of the Parquet side, so that nullability is respected. 1. The reconciled schema contains exactly those fields defined in Hive metastore schema. @@ -1298,8 +1370,8 @@ Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation r #### Metadata Refreshing -Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table -conversion is enabled, metadata of those converted tables are also cached. If these tables are +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are updated by Hive or other external tools, you need to refresh them manually to ensure consistent metadata. @@ -1362,7 +1434,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1400,7 +1472,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

    The output committer class used by Parquet. The specified class needs to be a subclass of - org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a + org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a subclass of org.apache.parquet.hadoop.ParquetOutputCommitter.

    @@ -1628,7 +1700,7 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a `HiveContext`. When not configured by the +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the hive-site.xml, the context automatically creates `metastore_db` in the current directory and creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts @@ -1738,10 +1810,10 @@ The following options can be used to configure the version of Hive that is used enabled. When this option is chosen, spark.sql.hive.metastore.version must be either 1.2.1 or not defined.

  • maven
  • - Use Hive jars of specified version downloaded from Maven repositories. This configuration + Use Hive jars of specified version downloaded from Maven repositories. This configuration is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive - and its dependencies, including the correct version of Hadoop. These jars only need to be +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
  • @@ -1776,7 +1848,7 @@ The following options can be used to configure the version of Hive that is used ## JDBC To Other Databases -Spark SQL also includes a data source that can read data from other databases using JDBC. This +Spark SQL also includes a data source that can read data from other databases using JDBC. This functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). This is because the results are returned as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. @@ -1786,7 +1858,7 @@ provide a ClassTag. run queries using Spark SQL). To get started you will need to include the JDBC driver for you particular database on the -spark classpath. For example, to connect to postgres from the Spark Shell you would run the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: {% highlight bash %} @@ -1794,7 +1866,7 @@ SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. The following options are supported: @@ -1807,8 +1879,8 @@ the Data Sources API. The following options are supported: @@ -1816,7 +1888,7 @@ the Data Sources API. The following options are supported: @@ -1825,7 +1897,7 @@ the Data Sources API. The following options are supported: @@ -1947,7 +2019,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ ## Other Configuration Options -The following options can also be used to tune the performance of query execution. It is possible +The following options can also be used to tune the performance of query execution. It is possible that these options will be deprecated in future release as more optimizations are performed automatically.
    Property NameMeaning
    dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a + The JDBC table that should be read. Note that anything that is valid in a FROM clause of + a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses.
    driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded + The class name of the JDBC driver needed to connect to this URL. This class will be loaded on the master and workers before running an JDBC commands to allow the driver to register itself with the JDBC subsystem.
    partitionColumn, lowerBound, upperBound, numPartitions - These options must all be specified if any of them is specified. They describe how to + These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the @@ -1938,7 +2010,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ spark.sql.inMemoryColumnarStorage.batchSize 10000 - Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data.
    @@ -1957,7 +2029,7 @@ that these options will be deprecated in future release as more optimizations ar @@ -1995,8 +2067,8 @@ To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to -specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of -all available options. By default, the server listens on localhost:10000. You may override this +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this behaviour via either environment variables, i.e.: {% highlight bash %} @@ -2062,10 +2134,10 @@ options. ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC - connection owns a copy of their own SQL configuration and temporary function registry. Cached - tables are still shared though. If you prefer to run the Thrift server in the old single-session - mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add + - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + connection owns a copy of their own SQL configuration and temporary function registry. Cached + tables are still shared though. If you prefer to run the Thrift server in the old single-session + mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`: {% highlight bash %} @@ -2077,20 +2149,20 @@ options. ## Upgrading From Spark SQL 1.4 to 1.5 - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with - code generation for expression evaluation. These features can both be disabled by setting + code generation for expression evaluation. These features can both be disabled by setting `spark.sql.tungsten.enabled` to `false`. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if - your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum - precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. - Timestamps are now stored at a precision of 1us, rather than 1ns - - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe @@ -2183,38 +2255,38 @@ sqlContext.setConf("spark.sql.retainGroupColumns", "false") ## Upgrading from Spark SQL 1.0-1.2 to 1.3 In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the -available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other -releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked as unstable (i.e., DeveloperAPI or Experimental). #### Rename of SchemaRDD to DataFrame The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has -been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD directly, but instead provide most of the functionality that RDDs provide though their own -implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for -some use cases. It is still recommended that users update their code to use `DataFrame` instead. +some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. #### Unification of the Java and Scala APIs Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) -that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. #### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought -all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. Users should now write `import sqlContext.implicits._`. @@ -2222,7 +2294,7 @@ Additionally, the implicit conversions now only augment RDDs that are composed o case classes or tuples) with a method `toDF`, instead of applying automatically. When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import -`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: `import org.apache.spark.sql.functions._`. #### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 3ca5ade7f30f1..bb0fdc4c3d83b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -19,20 +19,60 @@ package org.apache.spark.sql import java.lang.reflect.Modifier +import scala.annotation.implicitNotFound import scala.reflect.{ClassTag, classTag} +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** + * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * - * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking - * and reuse internal buffers to improve performance. + * == Scala == + * Encoders are generally created automatically through implicits from a `SQLContext`. + * + * {{{ + * import sqlContext.implicits._ + * + * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder) + * }}} + * + * == Java == + * Encoders are specified by calling static methods on [[Encoders]]. + * + * {{{ + * List data = Arrays.asList("abc", "abc", "xyz"); + * Dataset ds = context.createDataset(data, Encoders.STRING()); + * }}} + * + * Encoders can be composed into tuples: + * + * {{{ + * Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + * List> data2 = Arrays.asList(new scala.Tuple2(1, "a"); + * Dataset> ds2 = context.createDataset(data2, encoder2); + * }}} + * + * Or constructed from Java Beans: + * + * {{{ + * Encoders.bean(MyClass.class); + * }}} + * + * == Implementation == + * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard + * against concurrent access if they reuse internal buffers to improve performance. * * @since 1.6.0 */ +@Experimental +@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + + "(Int, String, etc) and Product types (case classes) are supported by importing " + + "sqlContext.implicits._ Support for serializing other types will be added in future " + + "releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ @@ -43,10 +83,12 @@ trait Encoder[T] extends Serializable { } /** - * Methods for creating encoders. + * :: Experimental :: + * Methods for creating an [[Encoder]]. * * @since 1.6.0 */ +@Experimental object Encoders { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad6af481fadc4..d641fcac1c8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -73,7 +73,26 @@ class TypedColumn[-T, U]( /** * :: Experimental :: - * A column in a [[DataFrame]]. + * A column that will be computed based on the data in a [[DataFrame]]. + * + * A new column is constructed based on the input columns present in a dataframe: + * + * {{{ + * df("columnName") // On a specific DataFrame. + * col("columnName") // A generic column no yet associcated with a DataFrame. + * col("columnName.field") // Extracting a struct field + * col("`a.column.with.dots`") // Escape `.` in column names. + * $"columnName" // Scala short hand for a named column. + * expr("a + 1") // A column that is constructed from a parsed SQL Expression. + * lit("1") // A column that produces a literal (constant) value. + * }}} + * + * [[Column]] objects can be composed to form complex expressions: + * + * {{{ + * $"a" + 1 + * $"a" === $"b" + * }}} * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators From 765c67f5f2e0b1367e37883f662d313661e3a0d9 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 8 Dec 2015 18:40:21 -0800 Subject: [PATCH 1171/2089] [SPARK-8517][ML][DOC] Reorganizes the spark.ml user guide This PR moves pieces of the spark.ml user guide to reflect suggestions in SPARK-8517. It does not introduce new content, as requested. screen shot 2015-12-08 at 11 36 00 am Author: Timothy Hunter Closes #10207 from thunterdb/spark-8517. --- docs/_data/menu-ml.yaml | 18 +- docs/ml-advanced.md | 13 + docs/ml-ann.md | 62 -- docs/ml-classification-regression.md | 775 ++++++++++++++++++++++ docs/ml-clustering.md | 5 + docs/ml-features.md | 4 +- docs/ml-intro.md | 941 +++++++++++++++++++++++++++ docs/mllib-guide.md | 15 +- 8 files changed, 1752 insertions(+), 81 deletions(-) create mode 100644 docs/ml-advanced.md delete mode 100644 docs/ml-ann.md create mode 100644 docs/ml-classification-regression.md create mode 100644 docs/ml-intro.md diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index dff3d33bf4ed1..fe37d0573e46b 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -1,10 +1,10 @@ -- text: Feature extraction, transformation, and selection +- text: "Overview: estimators, transformers and pipelines" + url: ml-intro.html +- text: Extracting, transforming and selecting features url: ml-features.html -- text: Decision trees for classification and regression - url: ml-decision-tree.html -- text: Ensembles - url: ml-ensembles.html -- text: Linear methods with elastic-net regularization - url: ml-linear-methods.html -- text: Multilayer perceptron classifier - url: ml-ann.html +- text: Classification and Regression + url: ml-classification-regression.html +- text: Clustering + url: ml-clustering.html +- text: Advanced topics + url: ml-advanced.html diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md new file mode 100644 index 0000000000000..b005633e56c11 --- /dev/null +++ b/docs/ml-advanced.md @@ -0,0 +1,13 @@ +--- +layout: global +title: Advanced topics - spark.ml +displayTitle: Advanced topics +--- + +# Optimization of linear methods + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. diff --git a/docs/ml-ann.md b/docs/ml-ann.md deleted file mode 100644 index 6e763e8f41568..0000000000000 --- a/docs/ml-ann.md +++ /dev/null @@ -1,62 +0,0 @@ ---- -layout: global -title: Multilayer perceptron classifier - ML -displayTitle: ML - Multilayer perceptron classifier ---- - - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). -MLPC consists of multiple layers of nodes. -Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs -by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. -It can be written in matrix form for MLPC with `$K+1$` layers as follows: -`\[ -\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) -\]` -Nodes in intermediate layers use sigmoid (logistic) function: -`\[ -\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} -\]` -Nodes in the output layer use softmax function: -`\[ -\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} -\]` -The number of nodes `$N$` in the output layer corresponds to the number of classes. - -MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. - -**Examples** - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} -
    - -
    -{% include_example python/ml/multilayer_perceptron_classification.py %} -
    - -
    diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md new file mode 100644 index 0000000000000..3663ffee32753 --- /dev/null +++ b/docs/ml-classification-regression.md @@ -0,0 +1,775 @@ +--- +layout: global +title: Classification and regression - spark.ml +displayTitle: Classification and regression in spark.ml +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +In MLlib, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details. In `spark.ml`, we also include Pipelines API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + + +# Classification + +## Logistic regression + +Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. +For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + + > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + +**Example** + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} +
    + +
    +{% include_example python/ml/logistic_regression_with_elastic_net.py %} +
    + +
    + +The `spark.ml` implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as `Dataframe` in +`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +only available on the driver. + +
    + +
    + +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} +
    + +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} +
    + + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
    + + +## Decision tree classifier + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} + +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} + +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% include_example python/ml/decision_tree_classification_example.py %} + +
    + +
    + +## Random forest classifier + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% include_example python/ml/random_forest_classifier_example.py %} +
    +
    + +## Gradient-boosted tree classifier + +Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% include_example python/ml/gradient_boosted_tree_classifier_example.py %} +
    +
    + +## Multilayer perceptron classifier + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs +by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} +
    + +
    +{% include_example python/ml/multilayer_perceptron_classification.py %} +
    + +
    + + +## One-vs-Rest classifier (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." + +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +**Example** + +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. + +{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %} +
    +
    + + +# Regression + +## Linear regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. + +**Example** + +The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} +
    + +
    + +{% include_example python/ml/linear_regression_with_elastic_net.py %} +
    + +
    + + +## Decision tree regression + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). + +{% include_example python/ml/decision_tree_regression_example.py %} +
    + +
    + + +## Random forest regression + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% include_example python/ml/random_forest_regressor_example.py %} +
    +
    + +## Gradient-boosted tree regression + +Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% include_example python/ml/gradient_boosted_tree_regressor_example.py %} +
    +
    + + +## Survival regression + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
    + +
    +{% include_example python/ml/aft_survival_regression.py %} +
    + +
    + + + +# Decision trees + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). +The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). + +## Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +### Input Columns + +
    10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when - performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently + performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
    + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    + + +# Tree Ensembles + +The Pipelines API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [MLlib decision trees](ml-decision-tree.html) as their base models. + +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + +* support for ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +## Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +MLlib supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    + + + +## Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +MLlib supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +Note that `GBTClassifier` currently only supports binary labels. + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    + +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index cfefb5dfbde9e..697777714b05b 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -6,6 +6,11 @@ displayTitle: ML - Clustering In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, diff --git a/docs/ml-features.md b/docs/ml-features.md index e15c26836affc..55e401221917e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction, Transformation, and Selection - SparkML -displayTitle: ML - Features +title: Extracting, transforming and selecting features +displayTitle: Extracting, transforming and selecting features --- This section covers algorithms for working with features, roughly divided into these groups: diff --git a/docs/ml-intro.md b/docs/ml-intro.md new file mode 100644 index 0000000000000..d95a66ba23566 --- /dev/null +++ b/docs/ml-intro.md @@ -0,0 +1,941 @@ +--- +layout: global +title: "Overview: estimators, transformers and pipelines - spark.ml" +displayTitle: "Overview: estimators, transformers and pipelines" +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. + +**Table of contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + + +# Main concepts in Pipelines + +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. + +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. + +* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. + +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. + +* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. + +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. + +## DataFrame + +Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. + +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. + +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. + +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." + +## Pipeline components + +### Transformers + +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. +For example: + +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. +* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. + +### Estimators + +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. + +### Properties of pipeline components + +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. + +Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). + +## Pipeline + +In machine learning, it is common to run a sequence of algorithms to process and learn from data. +E.g., a simple text document processing workflow might include several stages: + +* Split each document's text into words. +* Convert each document's words into a numerical feature vector. +* Learn a prediction model using the feature vectors and labels. + +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. + +### How it works + +A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. +These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. + +We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. + +

    + Spark ML Pipeline Example +

    + +Above, the top row represents a `Pipeline` with three stages. +The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. +Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. + +A `Pipeline` is an `Estimator`. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage. + +

    + Spark ML PipelineModel Example +

    + +In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. +Each stage's `transform()` method updates the dataset and passes it to the next stage. + +`Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. + +### Details + +*DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. + +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. + +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + +## Parameters + +Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. + +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. + +There are two main ways to pass parameters to an algorithm: + +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. +2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. + +Parameters belong to specific instances of `Estimator`s and `Transformer`s. +For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. +This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. + +# Code examples + +This section gives code examples illustrating the functionality discussed above. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. + +## Example: Estimator, Transformer, and Param + +This example covers the concepts of `Estimator`, `Transformer`, and `Param`. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.sql.Row + +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") + +// Create a LogisticRegression instance. This instance is an Estimator. +val lr = new LogisticRegression() +// Print out the parameters, documentation, and any default values. +println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + +// We may set parameters using setter methods. +lr.setMaxIter(10) + .setRegParam(0.01) + +// Learn a LogisticRegression model. This uses the parameters stored in lr. +val model1 = lr.fit(training) +// Since model1 is a Model (i.e., a Transformer produced by an Estimator), +// we can view the parameters it used during fit(). +// This prints the parameter (name: value) pairs, where names are unique IDs for this +// LogisticRegression instance. +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + +// We may alternatively specify parameters using a ParamMap, +// which supports several methods for specifying parameters. +val paramMap = ParamMap(lr.maxIter -> 20) + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + +// One can also combine ParamMaps. +val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name +val paramMapCombined = paramMap ++ paramMap2 + +// Now learn a new model using the paramMapCombined parameters. +// paramMapCombined overrides all parameters set earlier via lr.set* methods. +val model2 = lr.fit(training, paramMapCombined) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + +// Prepare test data. +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") + +// Make predictions on test data using the Transformer.transform() method. +// LogisticRegression.transform will only use the 'features' column. +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +model2.transform(test) + .select("features", "label", "myProbability", "prediction") + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Prepare training data. +// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans +// into DataFrames, where it uses the bean metadata to infer the schema. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); + +// Create a LogisticRegression instance. This instance is an Estimator. +LogisticRegression lr = new LogisticRegression(); +// Print out the parameters, documentation, and any default values. +System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + +// We may set parameters using setter methods. +lr.setMaxIter(10) + .setRegParam(0.01); + +// Learn a LogisticRegression model. This uses the parameters stored in lr. +LogisticRegressionModel model1 = lr.fit(training); +// Since model1 is a Model (i.e., a Transformer produced by an Estimator), +// we can view the parameters it used during fit(). +// This prints the parameter (name: value) pairs, where names are unique IDs for this +// LogisticRegression instance. +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + +// We may alternatively specify parameters using a ParamMap. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + +// One can also combine ParamMaps. +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + +// Now learn a new model using the paramMapCombined parameters. +// paramMapCombined overrides all parameters set earlier via lr.set* methods. +LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + +// Prepare test documents. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); + +// Make predictions on test documents using the Transformer.transform() method. +// LogisticRegression.transform will only use the 'features' column. +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +DataFrame results = model2.transform(test); +for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params + +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training, paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +{% endhighlight %} +
    + +
    + +## Example: Pipeline + +This example follows the simple text document `Pipeline` illustrated in the figures above. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row + +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") +val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) +val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + +// Fit the pipeline to training documents. +val model = pipeline.fit(training) + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") + +// Make predictions on test documents. +model.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + +// Prepare training documents, which are labeled. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); +HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); +LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + +// Fit the pipeline to training documents. +PipelineModel model = pipeline.fit(training); + +// Prepare test documents, which are unlabeled. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop") +), Document.class); + +// Make predictions on test documents. +DataFrame predictions = model.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row + +# Prepare training documents from a list of (id, text, label) tuples. +LabeledDocument = Row("id", "text", "label") +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) + +# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. +tokenizer = Tokenizer(inputCol="text", outputCol="words") +hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") +lr = LogisticRegression(maxIter=10, regParam=0.01) +pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + +# Fit the pipeline to training documents. +model = pipeline.fit(training) + +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) + +# Make predictions on test documents and print columns of interest. +prediction = model.transform(test) +selected = prediction.select("id", "text", "prediction") +for row in selected.collect(): + print(row) + +{% endhighlight %} +
    + +
    + +## Example: model selection via cross-validation + +An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. +`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. + +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). +`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. +`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. + +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` +method in each of these evaluators. + +The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. +`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +The following example demonstrates using `CrossValidator` to select from a grid of parameters. +To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. + +Note that cross-validation over a grid of parameters is expensive. +E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. +In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). +In other words, using `CrossValidator` can be very expensive. +However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") +val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") +val lr = new LogisticRegression() + .setMaxIter(10) +val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, +// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. +val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice + +// Run cross-validation, and choose the best set of parameters. +val cvModel = cv.fit(training) + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") + +// Make predictions on test documents. cvModel uses the best model found (lrModel). +cvModel.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + + +// Prepare training documents, which are labeled. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0), + new LabeledDocument(4L, "b spark who", 1.0), + new LabeledDocument(5L, "g d a y", 0.0), + new LabeledDocument(6L, "spark fly", 1.0), + new LabeledDocument(7L, "was mapreduce", 0.0), + new LabeledDocument(8L, "e spark program", 1.0), + new LabeledDocument(9L, "a e c l", 0.0), + new LabeledDocument(10L, "spark compile", 1.0), + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); +HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); +LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, +// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) + .addGrid(lr.regParam(), new double[]{0.1, 0.01}) + .build(); + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice + +// Run cross-validation, and choose the best set of parameters. +CrossValidatorModel cvModel = cv.fit(training); + +// Prepare test documents, which are unlabeled. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop") +), Document.class); + +// Make predictions on test documents. cvModel uses the best model found (lrModel). +DataFrame predictions = cvModel.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    + +## Example: model selection via train validation split +In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. +`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in + case of `CrossValidator`. It is therefore less expensive, + but will not produce as reliable results when the training dataset is not sufficiently large. + +`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, +and an `Evaluator`. +It begins by splitting the dataset into two parts using `trainRatio` parameter +which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default), +`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. +Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. +For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. +The `ParamMap` which produces the best evaluation metric is selected as the best option. +`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} + +// Prepare training and test data. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + +val lr = new LinearRegression() + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) + +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() + +{% endhighlight %} +
    + +
    +{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.sql.DataFrame; + +DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; + +LinearRegression lr = new LinearRegression(); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); + +{% endhighlight %} +
    + +
    \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 43772adcf26e1..3bc2b780601c2 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -66,15 +66,14 @@ We list major functionality from both below, with links to detailed guides. # spark.ml: high-level APIs for ML pipelines -**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major -concepts. It also contains sections on using algorithms within the Pipelines API, for example: - -* [Feature extraction, transformation, and selection](ml-features.html) +* [Overview: estimators, transformers and pipelines](ml-intro.html) +* [Extracting, transforming and selecting features](ml-features.html) +* [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) -* [Decision trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) +* [Advanced topics](ml-advanced.html) + +Some techniques are not available yet in spark.ml, most notably dimensionality reduction +Users can seemlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. # Dependencies From a0046e379bee0852c39ece4ea719cde70d350b0e Mon Sep 17 00:00:00 2001 From: Dominik Dahlem Date: Tue, 8 Dec 2015 18:54:10 -0800 Subject: [PATCH 1172/2089] [SPARK-11343][ML] Documentation of float and double prediction/label columns in RegressionEvaluator felixcheung , mengxr Just added a message to require() Author: Dominik Dahlem Closes #9598 from dahlem/ddahlem_regression_evaluator_double_predictions_message_04112015. --- .../apache/spark/ml/evaluation/RegressionEvaluator.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index daaa174a086e0..b6b25ecd01b3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -73,10 +73,15 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema + val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType - require(predictionType == FloatType || predictionType == DoubleType) + require(predictionType == FloatType || predictionType == DoubleType, + s"Prediction column $predictionColName must be of type float or double, " + + s" but not $predictionType") + val labelColName = $(labelCol) val labelType = schema($(labelCol)).dataType - require(labelType == FloatType || labelType == DoubleType) + require(labelType == FloatType || labelType == DoubleType, + s"Label column $labelColName must be of type float or double, but not $labelType") val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) From 3934562d34bbe08d91c54b4bbee27870e93d7571 Mon Sep 17 00:00:00 2001 From: Fei Wang Date: Tue, 8 Dec 2015 21:32:31 -0800 Subject: [PATCH 1173/2089] [SPARK-12222] [CORE] Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception Jira: https://issues.apache.org/jira/browse/SPARK-12222 Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception: ``` com.esotericsoftware.kryo.KryoException: Buffer underflow. at com.esotericsoftware.kryo.io.Input.require(Input.java:156) at com.esotericsoftware.kryo.io.Input.skip(Input.java:131) at com.esotericsoftware.kryo.io.Input.skip(Input.java:264) ``` This is caused by a bug of kryo's `Input.skip(long count)`(https://github.com/EsotericSoftware/kryo/issues/119) and we call this method in `KryoInputDataInputBridge`. Instead of upgrade kryo's version, this pr bypass the kryo's `Input.skip(long count)` by directly call another `skip` method in kryo's Input.java(https://github.com/EsotericSoftware/kryo/blob/kryo-2.21/src/com/esotericsoftware/kryo/io/Input.java#L124), i.e. write the bug-fixed version of `Input.skip(long count)` in KryoInputDataInputBridge's `skipBytes` method. more detail link to https://github.com/apache/spark/pull/9748#issuecomment-162860246 Author: Fei Wang Closes #10213 from scwf/patch-1. --- .../spark/serializer/KryoSerializer.scala | 10 ++++++- .../serializer/KryoSerializerSuite.scala | 28 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 62d445f3d7bd9..cb2ac5ea167ec 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -400,7 +400,15 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readUTF(): String = input.readString() // readString in kryo does utf8 override def readInt(): Int = input.readInt() override def readUnsignedShort(): Int = input.readShortUnsigned() - override def skipBytes(n: Int): Int = input.skip(n.toLong).toInt + override def skipBytes(n: Int): Int = { + var remaining: Long = n + while (remaining > 0) { + val skip = Math.min(Integer.MAX_VALUE, remaining).asInstanceOf[Int] + input.skip(skip) + remaining -= skip + } + n + } override def readFully(b: Array[Byte]): Unit = input.read(b) override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len) override def readLine(): String = throw new UnsupportedOperationException("readLine") diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index f81fe3113106f..9fcc22b608c65 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileOutputStream, FileInputStream} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} + +import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.Utils import org.apache.spark.storage.BlockManagerId class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -350,6 +354,28 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { + val dir = Utils.createTempDir() + val tmpfile = dir.toString + "/RoaringBitmap" + val outStream = new FileOutputStream(tmpfile) + val output = new KryoOutput(outStream) + val bitmap = new RoaringBitmap + bitmap.add(1) + bitmap.add(3) + bitmap.add(5) + bitmap.serialize(new KryoOutputDataOutputBridge(output)) + output.flush() + output.close() + + val inStream = new FileInputStream(tmpfile) + val input = new KryoInput(inStream) + val ret = new RoaringBitmap + ret.deserialize(new KryoInputDataInputBridge(input)) + input.close() + assert(ret == bitmap) + Utils.deleteRecursively(dir) + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) From f6883bb7afa7d5df480e1c2b3db6cb77198550be Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 9 Dec 2015 15:15:30 +0800 Subject: [PATCH 1174/2089] [SPARK-11676][SQL] Parquet filter tests all pass if filters are not really pushed down Currently Parquet predicate tests all pass even if filters are not pushed down or this is disabled. In this PR, For checking evaluating filters, Simply it makes the expression from `expression.Filter` and then try to create filters just like Spark does. For checking the results, this manually accesses to the child rdd (of `expression.Filter`) and produces the results which should be filtered properly, and then compares it to expected values. Now, if filters are not pushed down or this is disabled, this throws exceptions. Author: hyukjinkwon Closes #9659 from HyukjinKwon/SPARK-11676. --- .../parquet/ParquetFilterSuite.scala | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index cc5aae03d5516..daf41bc292cc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -50,27 +50,33 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten.reduceLeftOption(_ && _) - assert(maybeAnalyzedPredicate.isDefined) - - val selectedFilters = maybeAnalyzedPredicate.flatMap(DataSourceStrategy.translateFilter) - assert(selectedFilters.nonEmpty) - - selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach { f => - // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) - assert(f.getClass === filterClass) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[ParquetRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(relation: ParquetRelation, _)) => + maybeRelation = Some(relation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) + assert(f.getClass === filterClass) + } } + checker(stripSparkFilter(query), expected) } - checker(query, expected) } } @@ -104,6 +110,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -347,19 +368,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) val df = sqlContext.read.parquet(path).filter("a = 2") - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - // The result should be single row. // When a filter is pushed to Parquet, Parquet can apply it to every row. // So, we can check the number of rows returned from the Parquet // to make sure our filter pushdown work. - assert(childRDD.count == 1) + assert(stripSparkFilter(df).count == 1) } } } From a113216865fd45ea39ae8f104e784af2cf667dcf Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 9 Dec 2015 15:09:40 +0000 Subject: [PATCH 1175/2089] [SPARK-12031][CORE][BUG] Integer overflow when do sampling Author: uncleGen Closes #10023 from uncleGen/1.6-bugfix. --- .../src/main/scala/org/apache/spark/Partitioner.scala | 4 ++-- .../org/apache/spark/util/random/SamplingUtils.scala | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e4df7af81a6d2..ef9a2dab1c106 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -253,7 +253,7 @@ private[spark] object RangePartitioner { */ def sketch[K : ClassTag]( rdd: RDD[K], - sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => @@ -262,7 +262,7 @@ private[spark] object RangePartitioner { iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) }.collect() - val numItems = sketched.map(_._2.toLong).sum + val numItems = sketched.map(_._2).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index c9a864ae62778..f98932a470165 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -34,7 +34,7 @@ private[spark] object SamplingUtils { input: Iterator[T], k: Int, seed: Long = Random.nextLong()) - : (Array[T], Int) = { + : (Array[T], Long) = { val reservoir = new Array[T](k) // Put the first k elements in the reservoir. var i = 0 @@ -52,16 +52,17 @@ private[spark] object SamplingUtils { (trimReservoir, i) } else { // If input size > k, continue the sampling process. + var l = i.toLong val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = rand.nextInt(i) + val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { - reservoir(replacementIndex) = item + reservoir(replacementIndex.toInt) = item } - i += 1 + l += 1 } - (reservoir, i) + (reservoir, l) } } From 6e1c55eac4849669e119ce0d51f6d051830deb9f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 9 Dec 2015 23:30:42 +0800 Subject: [PATCH 1176/2089] [SPARK-12012][SQL] Show more comprehensive PhysicalRDD metadata when visualizing SQL query plan This PR adds a `private[sql]` method `metadata` to `SparkPlan`, which can be used to describe detail information about a physical plan during visualization. Specifically, this PR uses this method to provide details of `PhysicalRDD`s translated from a data source relation. For example, a `ParquetRelation` converted from Hive metastore table `default.psrc` is now shown as the following screenshot: ![image](https://cloud.githubusercontent.com/assets/230655/11526657/e10cb7e6-9916-11e5-9afa-f108932ec890.png) And here is the screenshot for a regular `ParquetRelation` (not converted from Hive metastore table) loaded from a really long path: ![output](https://cloud.githubusercontent.com/assets/230655/11680582/37c66460-9e94-11e5-8f50-842db5309d5a.png) Author: Cheng Lian Closes #10004 from liancheng/spark-12012.physical-rdd-metadata. --- python/pyspark/sql/dataframe.py | 2 +- .../spark/sql/execution/ExistingRDD.scala | 19 +++++--- .../spark/sql/execution/SparkPlan.scala | 5 +++ .../spark/sql/execution/SparkPlanInfo.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../datasources/DataSourceStrategy.scala | 22 +++++++-- .../datasources/parquet/ParquetRelation.scala | 10 +++++ .../sql/execution/ui/SparkPlanGraph.scala | 45 +++++++++++-------- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 ++- 11 files changed, 87 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 746bb55e14f22..78ab475eb466b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -213,7 +213,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan PhysicalRDD[age#0,name#1] + Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 623348f6768a4..b8a43025882e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -97,22 +97,31 @@ private[sql] case class LogicalRDD( private[sql] case class PhysicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], - extraInformation: String, + override val nodeName: String, + override val metadata: Map[String, String] = Map.empty, override val outputsUnsafeRows: Boolean = false) extends LeafNode { protected override def doExecute(): RDD[InternalRow] = rdd - override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" + } } private[sql] object PhysicalRDD { + // Metadata keys + val INPUT_PATHS = "InputPaths" + val PUSHED_FILTERS = "PushedFilters" + def createFromDataSource( output: Seq[Attribute], rdd: RDD[InternalRow], relation: BaseRelation, - extraInformation: String = ""): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString + extraInformation, - relation.isInstanceOf[HadoopFsRelation]) + metadata: Map[String, String] = Map.empty): PhysicalRDD = { + // All HadoopFsRelations output UnsafeRows + val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] + PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a78177751c9dc..ec98f81041343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,6 +67,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } + /** + * Return all metadata that describes more details of this SparkPlan. + */ + private[sql] def metadata: Map[String, String] = Map.empty + /** * Return all metrics containing metrics of this SparkPlan. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 486ce34064e43..4f750ad13ab84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -30,6 +30,7 @@ class SparkPlanInfo( val nodeName: String, val simpleString: String, val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], val metrics: Seq[SQLMetricInfo]) private[sql] object SparkPlanInfo { @@ -41,6 +42,6 @@ private[sql] object SparkPlanInfo { } val children = plan.children.map(fromSparkPlan) - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, plan.metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f67c951bc0663..25e98c0bdd431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,7 +363,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => apply(child) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 544d5eccec037..8a15a51d825ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala @@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} @@ -315,7 +318,20 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) - val pushedFiltersString = pushedFilters.mkString(" PushedFilter: [", ",", "] ") + val metadata: Map[String, String] = { + val pairs = ArrayBuffer.empty[(String, String)] + + if (pushedFilters.nonEmpty) { + pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) + } + + relation.relation match { + case r: HadoopFsRelation => pairs += INPUT_PATHS -> r.paths.mkString(", ") + case _ => + } + + pairs.toMap + } if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && @@ -334,7 +350,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, pushedFiltersString) + relation.relation, metadata) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. @@ -344,7 +360,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, pushedFiltersString) + relation.relation, metadata) execution.Project( projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bb3e2786978c5..1af2a394f399a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -146,6 +146,12 @@ private[sql] class ParquetRelation( meta } + override def toString: String = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map { tableName => + s"${getClass.getSimpleName}: $tableName" + }.getOrElse(super.toString) + } + override def equals(other: Any): Boolean = other match { case that: ParquetRelation => val schemaEquality = if (shouldMergeSchemas) { @@ -521,6 +527,10 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + // If a ParquetRelation is converted from a Hive metastore table, this option is set to the + // original Hive table name. + private[sql] val METASTORE_TABLE_NAME = "metastoreTableName" + /** * If parquet's block size (row group size) setting is larger than the min split size, * we use parquet's block size setting as the min split size. Otherwise, we will create diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 7af0ff09c5c6d..3a6eff9399825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -66,7 +66,9 @@ private[sql] object SparkPlanGraph { SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + nodes += node val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) @@ -85,26 +87,33 @@ private[sql] object SparkPlanGraph { * @param metrics metrics that this SparkPlan node will track */ private[ui] case class SparkPlanGraphNode( - id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { + id: Long, + name: String, + desc: String, + metadata: Map[String, String], + metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { - val values = { - for (metric <- metrics; - value <- metricsValue.get(metric.accumulatorId)) yield { - metric.name + ": " + value - } + val builder = new mutable.StringBuilder(name) + + val values = for { + metric <- metrics + value <- metricsValue.get(metric.accumulatorId) + } yield { + metric.name + ": " + value } - val label = if (values.isEmpty) { - name - } else { - // If there are metrics, display all metrics in a separate line. We should use an escaped - // "\n" here to follow the dot syntax. - // - // Note: whitespace between two "\n"s is to create an empty line between the name of - // SparkPlan and metrics. If removing it, it won't display the empty line in UI. - name + "\\n \\n" + values.mkString("\\n") - } - s""" $id [label="$label"];""" + + if (values.nonEmpty) { + // If there are metrics, display each entry in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + builder ++= "\\n \\n" + builder ++= values.mkString("\\n") + } + + s""" $id [label="${builder.toString()}"];""" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 9ace25dc7d21b..fc8ce6901dfca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -422,7 +422,7 @@ abstract class HadoopFsRelation private[sql]( parameters: Map[String, String]) extends BaseRelation with FileRelation with Logging { - override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") + override def toString: String = getClass.getSimpleName def this() = this(None, Map.empty[String, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a4626259b2823..2fb439f50117a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -169,7 +169,7 @@ class PlannerSuite extends SharedSQLContext { withTempTable("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan - assert(exp.toString.contains("PushedFilter: [EqualTo(key,15)]")) + assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9a981d02ad67c..08b291e088238 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -411,7 +411,12 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // evil case insensitivity issue, which is reconciled within `ParquetRelation`. val parquetOptions = Map( ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, + ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + metastoreRelation.tableName, + Some(metastoreRelation.databaseName) + ).unquotedString + ) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) From 22b9a8740d51289434553d19b6b1ac34aecdc09a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 9 Dec 2015 16:45:13 +0000 Subject: [PATCH 1177/2089] [SPARK-10299][ML] word2vec should allow users to specify the window size Currently word2vec has the window hard coded at 5, some users may want different sizes (for example if using on n-gram input or similar). User request comes from http://stackoverflow.com/questions/32231975/spark-word2vec-window-size . Author: Holden Karau Author: Holden Karau Closes #8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size. --- .../apache/spark/ml/feature/Word2Vec.scala | 15 +++++++ .../apache/spark/mllib/feature/Word2Vec.scala | 11 ++++- .../spark/ml/feature/Word2VecSuite.scala | 43 +++++++++++++++++-- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index a8d61b6dea00b..f105a983a34f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getVectorSize: Int = $(vectorSize) + /** + * The window size (context words from [-window, window]) default 5. + * @group expertParam + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window])") + setDefault(windowSize -> 5) + + /** @group expertGetParam */ + def getWindowSize: Int = $(windowSize) + /** * Number of partitions for sentences of words. * Default: 1 @@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) + /** @group expertSetParam */ + def setWindowSize(value: Int): this.type = set(windowSize, value) + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setNumPartitions($(numPartitions)) .setSeed($(seed)) .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 23b1514e3080e..1f400e1430eba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets the window of words (default: 5) + */ + @Since("1.6.0") + def setWindowSize(window: Int): this.type = { + this.window = window + this + } + /** * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). @@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging { private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ - private val window = 5 + private var window = 5 private var trainWordsCount = 0 private var vocabSize = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a773244cd735e..d561bbbb25529 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - val sqlContext = new SQLContext(sc) + + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul expectedSimilarity.zip(similarity).map { case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) } + } + + test("window size") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + val (synonyms, similarity) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } test("Word2Vec read/write") { From 6900f0173790ad2fa4c79a426bd2dec2d149daa2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 9 Dec 2015 09:50:43 -0800 Subject: [PATCH 1178/2089] [SPARK-10582][YARN][CORE] Fix AM failure situation for dynamic allocation Because of AM failure, the target executor number between driver and AM will be different, which will lead to unexpected behavior in dynamic allocation. So when AM is re-registered with driver, state in `ExecutorAllocationManager` and `CoarseGrainedSchedulerBacked` should be reset. This issue is originally addressed in #8737 , here re-opened again. Thanks a lot KaiXinXiaoLei for finding this issue. andrewor14 and vanzin would you please help to review this, thanks a lot. Author: jerryshao Closes #9963 from jerryshao/SPARK-10582. --- .../spark/ExecutorAllocationManager.scala | 18 +++- .../CoarseGrainedSchedulerBackend.scala | 19 +++++ .../ExecutorAllocationManagerSuite.scala | 84 +++++++++++++++++++ .../cluster/YarnSchedulerBackend.scala | 23 +++++ 4 files changed, 142 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 34c32ce312964..6176e258989db 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -89,6 +89,8 @@ private[spark] class ExecutorAllocationManager( private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Integer.MAX_VALUE) + private val initialNumExecutors = conf.getInt("spark.dynamicAllocation.initialExecutors", + minNumExecutors) // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -121,8 +123,7 @@ private[spark] class ExecutorAllocationManager( // The desired number of executors at this moment in time. If all our executors were to die, this // is the number of executors we would immediately want from the cluster manager. - private var numExecutorsTarget = - conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) + private var numExecutorsTarget = initialNumExecutors // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -240,6 +241,19 @@ private[spark] class ExecutorAllocationManager( executor.awaitTermination(10, TimeUnit.SECONDS) } + /** + * Reset the allocation manager to the initial state. Currently this will only be called in + * yarn-client mode when AM re-registers after a failure. + */ + def reset(): Unit = synchronized { + initializing = true + numExecutorsTarget = initialNumExecutors + numExecutorsToAdd = 1 + + executorsPendingToRemove.clear() + removeTimes.clear() + } + /** * The maximum number of executors we would need under the current load to satisfy all running * and pending tasks, rounded up. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 505c161141c88..7efe16749e59d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -341,6 +341,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only + * be called in the yarn-client mode when AM re-registers after a failure, also dynamic + * allocation is enabled. + * */ + protected def reset(): Unit = synchronized { + if (Utils.isDynamicAllocationEnabled(conf)) { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + + // Remove all the lingering executors that should be removed but not yet. The reason might be + // because (1) disconnected event is not yet received; (2) executors die silently. + executorDataMap.toMap.foreach { case (eid, _) => + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + } + } + } + override def reviveOffers() { driverEndpoint.send(ReviveOffers) } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 116f027a0f987..fedfbd547b91b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -805,6 +805,90 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) } + test("reset the state of allocation manager") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + + // Allocation manager is reset when adding executor requests are sent without reporting back + // executor added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 4) + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 5) + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + + // Allocation manager is reset when executors are added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + + // Allocation manager is reset when executors are pending to remove + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + removeExecutor(manager, "first") + removeExecutor(manager, "second") + assert(executorsPendingToRemove(manager) === Set("first", "second")) + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorsPendingToRemove(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index e3dd87798f018..1431bceb256a7 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -60,6 +60,9 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + // Flag to specify whether this schedulerBackend should be reset. + private var shouldResetOnAmRegister = false + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -155,6 +158,16 @@ private[spark] abstract class YarnSchedulerBackend( new YarnDriverEndpoint(rpcEnv, properties) } + /** + * Reset the state of SchedulerBackend to the initial state. This is happened when AM is failed + * and re-registered itself to driver after a failure. The stale state in driver should be + * cleaned. + */ + override protected def reset(): Unit = { + super.reset() + sc.executorAllocationManager.foreach(_.reset()) + } + /** * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. * This endpoint communicates with the executors and queries the AM for an executor's exit @@ -218,6 +231,8 @@ private[spark] abstract class YarnSchedulerBackend( case None => logWarning("Attempted to check for an executor loss reason" + " before the AM has registered!") + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) } } @@ -225,6 +240,13 @@ private[spark] abstract class YarnSchedulerBackend( case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") amEndpoint = Option(am) + if (!shouldResetOnAmRegister) { + shouldResetOnAmRegister = true + } else { + // AM is already registered before, this potentially means that AM failed and + // a new one registered after the failure. This will only happen in yarn-client mode. + reset() + } case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) @@ -270,6 +292,7 @@ private[spark] abstract class YarnSchedulerBackend( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (amEndpoint.exists(_.address == remoteAddress)) { logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + amEndpoint = None } } From 442a7715a590ba2ea2446c73b1f914a16ae0ed4b Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 9 Dec 2015 10:25:38 -0800 Subject: [PATCH 1179/2089] [SPARK-12241][YARN] Improve failure reporting in Yarn client obtainTokenForHBase() This lines up the HBase token logic with that done for Hive in SPARK-11265: reflection with only CFNE being swallowed. There is a test, one which doesn't try to put HBase on the yarn/test class and really do the reflection (the way the hive introspection does). If people do want that then it could be added with careful POM work +also: cut an incorrect comment from the Hive test case before copying it, and a couple of imports that may have been related to the hive test in the past. Author: Steve Loughran Closes #10227 from steveloughran/stevel/patches/SPARK-12241-obtainTokenForHBase. --- .../org/apache/spark/deploy/yarn/Client.scala | 32 ++---------- .../deploy/yarn/YarnSparkHadoopUtil.scala | 51 ++++++++++++++++++- .../yarn/YarnSparkHadoopUtilSuite.scala | 12 ++++- 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f0590d2d222ec..7742ec92eb4e8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1369,40 +1369,16 @@ object Client extends Logging { } /** - * Obtain security token for HBase. + * Obtain a security token for HBase. */ def obtainTokenForHBase( sparkConf: SparkConf, conf: Configuration, credentials: Credentials): Unit = { if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val confCreate = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). - getMethod("create", classOf[Configuration]) - val obtainToken = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). - getMethod("obtainToken", classOf[Configuration]) - - logDebug("Attempting to fetch HBase security token.") - - val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] - if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { - val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] - credentials.addToken(token.getService, token) - logInfo("Added HBase security token to credentials.") - } - } catch { - case e: java.lang.NoSuchMethodException => - logInfo("HBase Method not found: " + e) - case e: java.lang.ClassNotFoundException => - logDebug("HBase Class not found: " + e) - case e: java.lang.NoClassDefFoundError => - logDebug("HBase Class not found: " + e) - case e: Exception => - logError("Exception when obtaining HBase security token: " + e) + YarnSparkHadoopUtil.get.obtainTokenForHBase(conf).foreach { token => + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index a290ebeec9001..36a2d61429887 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.security.token.Token +import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -216,6 +216,55 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { None } } + + /** + * Obtain a security token for HBase. + * + * Requirements + * + * 1. `"hbase.security.authentication" == "kerberos"` + * 2. The HBase classes `HBaseConfiguration` and `TokenUtil` could be loaded + * and invoked. + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if the requirements were met, `None` if not. + */ + def obtainTokenForHBase(conf: Configuration): Option[Token[TokenIdentifier]] = { + try { + obtainTokenForHBaseInner(conf) + } catch { + case e: ClassNotFoundException => + logInfo(s"HBase class not found $e") + logDebug("HBase class not found", e) + None + } + } + + /** + * Obtain a security token for HBase if `"hbase.security.authentication" == "kerberos"` + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if one was needed + */ + def obtainTokenForHBaseInner(conf: Configuration): Option[Token[TokenIdentifier]] = { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + logDebug("Attempting to fetch HBase security token.") + Some(obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]]) + } else { + None + } + } + } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index a70e66d39a64e..3fafc91a166aa 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -27,7 +27,6 @@ import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -259,7 +258,6 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") }) - // expect exception trapping code to unwind this hive-side exception assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf) }) @@ -276,6 +274,16 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging inner } + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + val util = new YarnSparkHadoopUtil + intercept[ClassNotFoundException] { + util.obtainTokenForHBaseInner(hadoopConf) + } + util.obtainTokenForHBase(hadoopConf) should be (None) + } + // This test needs to live here because it depends on isYarnMode returning true, which can only // happen in the YARN module. test("security manager token generation") { From aec5ea000ebb8921f42f006b694ef26f5df67d83 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Dec 2015 11:39:59 -0800 Subject: [PATCH 1180/2089] [SPARK-12165][SPARK-12189] Fix bugs in eviction of storage memory by execution This patch fixes a bug in the eviction of storage memory by execution. ## The bug: In general, execution should be able to evict storage memory when the total storage memory usage is greater than `maxMemory * spark.memory.storageFraction`. Due to a bug, however, Spark might wind up evicting no storage memory in certain cases where the storage memory usage was between `maxMemory * spark.memory.storageFraction` and `maxMemory`. For example, here is a regression test which illustrates the bug: ```scala val maxMemory = 1000L val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) // Since we used the default storage fraction (0.5), we should be able to allocate 500 bytes // of storage memory which are immune to eviction by execution memory pressure. // Acquire enough storage memory to exceed the storage region size assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) // At this point, storage is using 250 more bytes of memory than it is guaranteed, so execution // should be able to reclaim up to 250 bytes of storage memory. // Therefore, execution should now be able to require up to 500 bytes of memory: assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) // <--- fails by only returning 250L assert(mm.storageMemoryUsed === 500L) assert(mm.executionMemoryUsed === 500L) assertEvictBlocksToFreeSpaceCalled(ms, 250L) ``` The problem relates to the control flow / interaction between `StorageMemoryPool.shrinkPoolToReclaimSpace()` and `MemoryStore.ensureFreeSpace()`. While trying to allocate the 500 bytes of execution memory, the `UnifiedMemoryManager` discovers that it will need to reclaim 250 bytes of memory from storage, so it calls `StorageMemoryPool.shrinkPoolToReclaimSpace(250L)`. This method, in turn, calls `MemoryStore.ensureFreeSpace(250L)`. However, `ensureFreeSpace()` first checks whether the requested space is less than `maxStorageMemory - storageMemoryUsed`, which will be true if there is any free execution memory because it turns out that `MemoryStore.maxStorageMemory = (maxMemory - onHeapExecutionMemoryPool.memoryUsed)` when the `UnifiedMemoryManager` is used. The control flow here is somewhat confusing (it grew to be messy / confusing over time / as a result of the merging / refactoring of several components). In the pre-Spark 1.6 code, `ensureFreeSpace` was called directly by the `MemoryStore` itself, whereas in 1.6 it's involved in a confusing control flow where `MemoryStore` calls `MemoryManager.acquireStorageMemory`, which then calls back into `MemoryStore.ensureFreeSpace`, which, in turn, calls `MemoryManager.freeStorageMemory`. ## The solution: The solution implemented in this patch is to remove the confusing circular control flow between `MemoryManager` and `MemoryStore`, making the storage memory acquisition process much more linear / straightforward. The key changes: - Remove a layer of inheritance which made the memory manager code harder to understand (53841174760a24a0df3eb1562af1f33dbe340eb9). - Move some bounds checks earlier in the call chain (13ba7ada77f87ef1ec362aec35c89a924e6987cb). - Refactor `ensureFreeSpace()` so that the part which evicts blocks can be called independently from the part which checks whether there is enough free space to avoid eviction (7c68ca09cb1b12f157400866983f753ac863380e). - Realize that this lets us remove a layer of overloads from `ensureFreeSpace` (eec4f6c87423d5e482b710e098486b3bbc4daf06). - Realize that `ensureFreeSpace()` can simply be replaced with an `evictBlocksToFreeSpace()` method which is called [after we've already figured out](https://github.com/apache/spark/blob/2dc842aea82c8895125d46a00aa43dfb0d121de9/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala#L88) how much memory needs to be reclaimed via eviction; (2dc842aea82c8895125d46a00aa43dfb0d121de9). Along the way, I fixed some problems with the mocks in `MemoryManagerSuite`: the old mocks would [unconditionally](https://github.com/apache/spark/blob/80a824d36eec9d9a9f092ee1741453851218ec73/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala#L84) report that a block had been evicted even if there was enough space in the storage pool such that eviction would be avoided. I also fixed a problem where `StorageMemoryPool._memoryUsed` might become negative due to freed memory being double-counted when excution evicts storage. The problem was that `StorageMemoryPoolshrinkPoolToFreeSpace` would [decrement `_memoryUsed`](https://github.com/apache/spark/commit/7c68ca09cb1b12f157400866983f753ac863380e#diff-935c68a9803be144ed7bafdd2f756a0fL133) even though `StorageMemoryPool.freeMemory` had already decremented it as each evicted block was freed. See SPARK-12189 for details. Author: Josh Rosen Author: Andrew Or Closes #10170 from JoshRosen/SPARK-12165. --- .../apache/spark/memory/MemoryManager.scala | 11 +- .../spark/memory/StaticMemoryManager.scala | 37 ++++- .../spark/memory/StorageMemoryPool.scala | 37 +++-- .../spark/memory/UnifiedMemoryManager.scala | 8 +- .../apache/spark/storage/MemoryStore.scala | 76 ++-------- .../spark/memory/MemoryManagerSuite.scala | 137 +++++++++--------- .../memory/StaticMemoryManagerSuite.scala | 52 ++++--- .../memory/UnifiedMemoryManagerSuite.scala | 76 +++++++--- 8 files changed, 230 insertions(+), 204 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ceb8ea434e1be..ae9e1ac0e246b 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -77,9 +77,7 @@ private[spark] abstract class MemoryManager( def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) - } + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -109,12 +107,7 @@ private[spark] abstract class MemoryManager( def acquireExecutionMemory( numBytes: Long, taskAttemptId: Long, - memoryMode: MemoryMode): Long = synchronized { - memoryMode match { - case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) - case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) - } - } + memoryMode: MemoryMode): Long /** * Release numBytes of execution memory belonging to the given task. diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 12a094306861f..3554b558f2123 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -49,19 +49,50 @@ private[spark] class StaticMemoryManager( } // Max number of bytes worth of blocks to evict when unrolling - private val maxMemoryToEvictForUnroll: Long = { + private val maxUnrollMemory: Long = { (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + false + } else { + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } + } + override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory - val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) - val numBytesToFree = math.min(numBytes, maxNumBytesToFree) + val freeMemory = storageMemoryPool.memoryFree + // When unrolling, we will use all of the existing free memory, and, if necessary, + // some extra space freed from evicting cached blocks. We must place a cap on the + // amount of memory to be evicted by unrolling, however, otherwise unrolling one + // big block can blow away the entire cache. + val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) + // Keep it within the range 0 <= X <= maxNumBytesToFree + val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } + + private[memory] + override def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + } + } } diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index fc4f0357e9f16..70af83b5ee092 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -65,7 +65,8 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { - acquireMemory(blockId, numBytes, numBytes, evictedBlocks) + val numBytesToFree = math.max(0, numBytes - memoryFree) + acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } /** @@ -73,7 +74,7 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w * * @param blockId the ID of the block we are acquiring storage memory for * @param numBytesToAcquire the size of this block - * @param numBytesToFree the size of space to be freed through evicting blocks + * @param numBytesToFree the amount of space to be freed through evicting blocks * @return whether all N bytes were successfully granted. */ def acquireMemory( @@ -84,16 +85,18 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w assert(numBytesToAcquire >= 0) assert(numBytesToFree >= 0) assert(memoryUsed <= poolSize) - memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + if (numBytesToFree > 0) { + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } } // NOTE: If the memory store evicts blocks, then those evictions will synchronously call - // back into this StorageMemoryPool in order to free. Therefore, these variables should have - // been updated. + // back into this StorageMemoryPool in order to free memory. Therefore, these variables + // should have been updated. val enoughMemory = numBytesToAcquire <= memoryFree if (enoughMemory) { _memoryUsed += numBytesToAcquire @@ -121,18 +124,20 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w */ def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { // First, shrink the pool by reclaiming free memory: - val spaceFreedByReleasingUnusedMemory = Math.min(spaceToFree, memoryFree) + val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) decrementPoolSize(spaceFreedByReleasingUnusedMemory) - if (spaceFreedByReleasingUnusedMemory == spaceToFree) { - spaceFreedByReleasingUnusedMemory - } else { + val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory + if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - memoryStore.ensureFreeSpace(spaceToFree - spaceFreedByReleasingUnusedMemory, evictedBlocks) + memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, evictedBlocks) val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum - _memoryUsed -= spaceFreedByEviction + // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do + // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. decrementPoolSize(spaceFreedByEviction) spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } else { + spaceFreedByReleasingUnusedMemory } } } diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 0f1ea9ab39c07..0b9f6a9dc0525 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -100,7 +100,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( case MemoryMode.OFF_HEAP => // For now, we only support on-heap caching of data, so we do not need to interact with // the storage pool when allocating off-heap memory. This will change in the future, though. - super.acquireExecutionMemory(numBytes, taskAttemptId, memoryMode) + offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) } } @@ -110,6 +110,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + return false + } if (numBytes > storageMemoryPool.memoryFree) { // There is not enough free memory in the storage pool, so try to borrow free memory from // the execution pool. diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 4dbac388e098b..bdab8c2332fae 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -406,85 +406,41 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } /** - * Try to free up a given amount of space by evicting existing blocks. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(None, space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a block by evicting existing ones. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - blockId: BlockId, - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(Some(blockId), space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a particular block, but can fail if - * either the block is bigger than our memory or it would require replacing another block - * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * @param blockId the ID of the block we are freeing space for, if any - * @param space the size of this block - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private def ensureFreeSpace( + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def evictBlocksToFreeSpace( blockId: Option[BlockId], space: Long, droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + assert(space > 0) memoryManager.synchronized { - val freeMemory = maxMemory - memoryUsed + var freedMemory = 0L val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L - - logInfo(s"Ensuring $space bytes of free space " + - blockId.map { id => s"for block $id" }.getOrElse("") + - s"(free: $freeMemory, max: $maxMemory)") - - // Fail fast if the block simply won't fit - if (space > maxMemory) { - logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") + - s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)") - return false - } - - // No need to evict anything if there is already enough free space - if (freeMemory >= space) { - return true - } - // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. entries.synchronized { val iterator = entries.entrySet().iterator() - while (freeMemory + selectedMemory < space && iterator.hasNext) { + while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { selectedBlocks += blockId - selectedMemory += pair.getValue.size + freedMemory += pair.getValue.size } } } - if (freeMemory + selectedMemory >= space) { + if (freedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index f55d435fa33a6..555b640cb4244 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -24,9 +24,10 @@ import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} import org.mockito.Matchers.{any, anyLong} -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite @@ -36,105 +37,105 @@ import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, StorageLevel /** * Helper trait for sharing code among [[MemoryManager]] tests. */ -private[memory] trait MemoryManagerSuite extends SparkFunSuite { +private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAfterEach { - import MemoryManagerSuite.DEFAULT_ENSURE_FREE_SPACE_CALLED + protected val evictedBlocks = new mutable.ArrayBuffer[(BlockId, BlockStatus)] + + import MemoryManagerSuite.DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED // Note: Mockito's verify mechanism does not provide a way to reset method call counts // without also resetting stubbed methods. Since our test code relies on the latter, - // we need to use our own variable to track invocations of `ensureFreeSpace`. + // we need to use our own variable to track invocations of `evictBlocksToFreeSpace`. /** - * The amount of free space requested in the last call to [[MemoryStore.ensureFreeSpace]] + * The amount of space requested in the last call to [[MemoryStore.evictBlocksToFreeSpace]]. * - * This set whenever [[MemoryStore.ensureFreeSpace]] is called, and cleared when the test - * code makes explicit assertions on this variable through [[assertEnsureFreeSpaceCalled]]. + * This set whenever [[MemoryStore.evictBlocksToFreeSpace]] is called, and cleared when the test + * code makes explicit assertions on this variable through + * [[assertEvictBlocksToFreeSpaceCalled]]. */ - private val ensureFreeSpaceCalled = new AtomicLong(DEFAULT_ENSURE_FREE_SPACE_CALLED) + private val evictBlocksToFreeSpaceCalled = new AtomicLong(0) + + override def beforeEach(): Unit = { + super.beforeEach() + evictedBlocks.clear() + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) + } /** - * Make a mocked [[MemoryStore]] whose [[MemoryStore.ensureFreeSpace]] method is stubbed. + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is stubbed. * - * This allows our test code to release storage memory when [[MemoryStore.ensureFreeSpace]] - * is called without relying on [[org.apache.spark.storage.BlockManager]] and all of its - * dependencies. + * This allows our test code to release storage memory when these methods are called + * without relying on [[org.apache.spark.storage.BlockManager]] and all of its dependencies. */ protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { - val ms = mock(classOf[MemoryStore]) - when(ms.ensureFreeSpace(anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 0)) - when(ms.ensureFreeSpace(any(), anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 1)) + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) + .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) mm.setMemoryStore(ms) ms } /** - * Make an [[Answer]] that stubs [[MemoryStore.ensureFreeSpace]] with the right arguments. - */ - private def ensureFreeSpaceAnswer(mm: MemoryManager, numBytesPos: Int): Answer[Boolean] = { + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. + */ + private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Boolean] = { new Answer[Boolean] { override def answer(invocation: InvocationOnMock): Boolean = { val args = invocation.getArguments - require(args.size > numBytesPos, s"bad test: expected >$numBytesPos arguments " + - s"in ensureFreeSpace, found ${args.size}") - require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + - s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") - val numBytes = args(numBytesPos).asInstanceOf[Long] - val success = mockEnsureFreeSpace(mm, numBytes) - if (success) { + val numBytesToFree = args(1).asInstanceOf[Long] + assert(numBytesToFree > 0) + require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "bad test: evictBlocksToFreeSpace() variable was not reset") + evictBlocksToFreeSpaceCalled.set(numBytesToFree) + if (numBytesToFree <= mm.storageMemoryUsed) { + // We can evict enough blocks to fulfill the request for space + mm.releaseStorageMemory(numBytesToFree) args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( - (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytes, 0L, 0L))) + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + // We need to add this call so that that the suite-level `evictedBlocks` is updated when + // execution evicts storage; in that case, args.last will not be equal to evictedBlocks + // because it will be a temporary buffer created inside of the MemoryManager rather than + // being passed in by the test code. + if (!(evictedBlocks eq args.last)) { + evictedBlocks.append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + } + true + } else { + // No blocks were evicted because eviction would not free enough space. + false } - success - } - } - } - - /** - * Simulate the part of [[MemoryStore.ensureFreeSpace]] that releases storage memory. - * - * This is a significant simplification of the real method, which actually drops existing - * blocks based on the size of each block. Instead, here we simply release as many bytes - * as needed to ensure the requested amount of free space. This allows us to set up the - * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in - * many other dependencies. - * - * Every call to this method will set a global variable, [[ensureFreeSpaceCalled]], that - * records the number of bytes this is called with. This variable is expected to be cleared - * by the test code later through [[assertEnsureFreeSpaceCalled]]. - */ - private def mockEnsureFreeSpace(mm: MemoryManager, numBytes: Long): Boolean = mm.synchronized { - require(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "bad test: ensure free space variable was not reset") - // Record the number of bytes we freed this call - ensureFreeSpaceCalled.set(numBytes) - if (numBytes <= mm.maxStorageMemory) { - def freeMemory = mm.maxStorageMemory - mm.storageMemoryUsed - val spaceToRelease = numBytes - freeMemory - if (spaceToRelease > 0) { - mm.releaseStorageMemory(spaceToRelease) } - freeMemory >= numBytes - } else { - // We attempted to free more bytes than our max allowable memory - false } } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters. */ - protected def assertEnsureFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { - assert(ensureFreeSpaceCalled.get() === numBytes, - s"expected ensure free space to be called with $numBytes") - ensureFreeSpaceCalled.set(DEFAULT_ENSURE_FREE_SPACE_CALLED) + protected def assertEvictBlocksToFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === numBytes, + s"expected evictBlocksToFreeSpace() to be called with $numBytes") + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is NOT called. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is NOT called. */ - protected def assertEnsureFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { - assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "ensure free space should not have been called!") + protected def assertEvictBlocksToFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "evictBlocksToFreeSpace() should not have been called!") + assert(evictedBlocks.isEmpty) } /** @@ -291,5 +292,5 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { } private object MemoryManagerSuite { - private val DEFAULT_ENSURE_FREE_SPACE_CALLED = -1L + private val DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED = -1L } diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 54cb28c389c2f..6700b94f0f57f 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.mockito.Mockito.when import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.storage.{MemoryStore, TestBlockId} class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] /** * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. @@ -85,33 +82,38 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxStorageMem + 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -133,7 +135,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 50L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released @@ -151,21 +153,25 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val dummyBlock = TestBlockId("lonely water") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + when(ms.currentUnrollMemory).thenReturn(100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 100L) mm.releaseUnrollMemory(40L) assert(mm.storageMemoryUsed === 60L) when(ms.currentUnrollMemory).thenReturn(60L) - assert(mm.acquireUnrollMemory(dummyBlock, 500L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. - assertEnsureFreeSpaceCalled(ms, 340L) - assert(mm.storageMemoryUsed === 560L) - when(ms.currentUnrollMemory).thenReturn(560L) - assert(!mm.acquireUnrollMemory(dummyBlock, 800L, evictedBlocks)) - assert(mm.storageMemoryUsed === 560L) - // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed - assertEnsureFreeSpaceCalled(ms, 0L) + // Since we already occupy 60 bytes, we will try to evict only 400 - 60 = 340 bytes. + assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 + assert(mm.storageMemoryUsed === 1000L) + evictedBlocks.clear() + assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 400 - 300 + assert(mm.storageMemoryUsed === 900L) // 100 bytes were evicted // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e97c898a44783..71221deeb4c28 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.storage.{MemoryStore, TestBlockId} class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { private val dummyBlock = TestBlockId("--") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] private val storageFraction: Double = 0.5 @@ -78,33 +75,40 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val (mm, ms) = makeThings(maxMemory) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxMemory + 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -117,25 +121,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val (mm, ms) = makeThings(maxMemory) // Acquire enough storage memory to exceed the storage region assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) // Execution needs to request 250 bytes to evict storage memory assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 300L) - assertEnsureFreeSpaceCalled(ms, 50L) - assert(mm.executionMemoryUsed === 300L) + assert(mm.storageMemoryUsed === 700L) + assertEvictBlocksToFreeSpaceCalled(ms, 50L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() mm.releaseAllStorageMemory() require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 400L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 400L) assert(mm.executionMemoryUsed === 300L) // Execution cannot evict storage because the latter is within the storage fraction, @@ -143,7 +149,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("execution memory requests smaller than free memory should evict storage (SPARK-12165)") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + // Acquire enough storage memory to exceed the storage region size + assert(mm.acquireStorageMemory(dummyBlock, 700L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 700L) + // SPARK-12165: previously, MemoryStore would not evict anything because it would + // mistakenly think that the 300 bytes of free space was still available even after + // using it to expand the execution pool. Consequently, no storage memory was released + // and the following call granted only 300 bytes to execution. + assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) + assertEvictBlocksToFreeSpaceCalled(ms, 200L) + assert(mm.storageMemoryUsed === 500L) + assert(mm.executionMemoryUsed === 500L) + assert(evictedBlocks.nonEmpty) } test("storage does not evict execution") { @@ -154,32 +180,34 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 250L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should still not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) // since there were 800 bytes free assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 850L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) } test("small heap") { From 1eb7c22ce72a1b82ed194a51bbcf0da9c771605a Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 9 Dec 2015 19:47:38 +0000 Subject: [PATCH 1181/2089] [SPARK-11824][WEBUI] WebUI does not render descriptions with 'bad' HTML, throws console error Don't warn when description isn't valid HTML since it may properly be like "SELECT ... where foo <= 1" The tests for this code indicate that it's normal to handle strings like this that don't contain HTML as a string rather than markup. Hence logging every such instance as a warning is too noisy since it's not a problem. this is an issue for stages whose name contain SQL like the above CC tdas as author of this bit of code Author: Sean Owen Closes #10159 from srowen/SPARK-11824. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 1e8194f57888e..81a6f07ec836a 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -448,7 +448,6 @@ private[spark] object UIUtils extends Logging { new RuleTransformer(rule).transform(xml) } catch { case NonFatal(e) => - logWarning(s"Invalid job description: $desc ", e) {desc} } } From 051c6a066f7b5fcc7472412144c15b50a5319bd5 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 9 Dec 2015 12:00:48 -0800 Subject: [PATCH 1182/2089] [SPARK-11551][DOC] Replace example code in ml-features.md using include_example PR on behalf of somideshmukh, thanks! Author: Xusen Yin Author: somideshmukh Closes #10219 from yinxusen/SPARK-11551. --- docs/ml-features.md | 1112 +---------------- .../examples/ml/JavaBinarizerExample.java | 68 + .../examples/ml/JavaBucketizerExample.java | 71 ++ .../spark/examples/ml/JavaDCTExample.java | 65 + .../ml/JavaElementwiseProductExample.java | 75 ++ .../examples/ml/JavaMinMaxScalerExample.java | 51 + .../spark/examples/ml/JavaNGramExample.java | 71 ++ .../examples/ml/JavaNormalizerExample.java | 54 + .../examples/ml/JavaOneHotEncoderExample.java | 78 ++ .../spark/examples/ml/JavaPCAExample.java | 71 ++ .../ml/JavaPolynomialExpansionExample.java | 71 ++ .../examples/ml/JavaRFormulaExample.java | 69 + .../ml/JavaStandardScalerExample.java | 54 + .../ml/JavaStopWordsRemoverExample.java | 65 + .../examples/ml/JavaStringIndexerExample.java | 66 + .../examples/ml/JavaTokenizerExample.java | 75 ++ .../ml/JavaVectorAssemblerExample.java | 67 + .../examples/ml/JavaVectorIndexerExample.java | 61 + .../examples/ml/JavaVectorSlicerExample.java | 73 ++ .../src/main/python/ml/binarizer_example.py | 43 + .../src/main/python/ml/bucketizer_example.py | 43 + .../python/ml/elementwise_product_example.py | 39 + examples/src/main/python/ml/n_gram_example.py | 42 + .../src/main/python/ml/normalizer_example.py | 43 + .../main/python/ml/onehot_encoder_example.py | 48 + examples/src/main/python/ml/pca_example.py | 42 + .../python/ml/polynomial_expansion_example.py | 43 + .../src/main/python/ml/rformula_example.py | 44 + .../main/python/ml/standard_scaler_example.py | 43 + .../python/ml/stopwords_remover_example.py | 40 + .../main/python/ml/string_indexer_example.py | 39 + .../src/main/python/ml/tokenizer_example.py | 44 + .../python/ml/vector_assembler_example.py | 42 + .../main/python/ml/vector_indexer_example.py | 40 + .../spark/examples/ml/BinarizerExample.scala | 48 + .../spark/examples/ml/BucketizerExample.scala | 52 + .../apache/spark/examples/ml/DCTExample.scala | 54 + .../ml/ElementWiseProductExample.scala | 52 + .../examples/ml/MinMaxScalerExample.scala | 50 + .../spark/examples/ml/NGramExample.scala | 47 + .../spark/examples/ml/NormalizerExample.scala | 52 + .../examples/ml/OneHotEncoderExample.scala | 58 + .../apache/spark/examples/ml/PCAExample.scala | 53 + .../ml/PolynomialExpansionExample.scala | 51 + .../spark/examples/ml/RFormulaExample.scala | 49 + .../examples/ml/StandardScalerExample.scala | 52 + .../examples/ml/StopWordsRemoverExample.scala | 48 + .../examples/ml/StringIndexerExample.scala | 48 + .../spark/examples/ml/TokenizerExample.scala | 54 + .../examples/ml/VectorAssemblerExample.scala | 49 + .../examples/ml/VectorIndexerExample.scala | 54 + .../examples/ml/VectorSlicerExample.scala | 58 + 52 files changed, 2820 insertions(+), 1061 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java create mode 100644 examples/src/main/python/ml/binarizer_example.py create mode 100644 examples/src/main/python/ml/bucketizer_example.py create mode 100644 examples/src/main/python/ml/elementwise_product_example.py create mode 100644 examples/src/main/python/ml/n_gram_example.py create mode 100644 examples/src/main/python/ml/normalizer_example.py create mode 100644 examples/src/main/python/ml/onehot_encoder_example.py create mode 100644 examples/src/main/python/ml/pca_example.py create mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py create mode 100644 examples/src/main/python/ml/rformula_example.py create mode 100644 examples/src/main/python/ml/standard_scaler_example.py create mode 100644 examples/src/main/python/ml/stopwords_remover_example.py create mode 100644 examples/src/main/python/ml/string_indexer_example.py create mode 100644 examples/src/main/python/ml/tokenizer_example.py create mode 100644 examples/src/main/python/ml/vector_assembler_example.py create mode 100644 examples/src/main/python/ml/vector_indexer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 55e401221917e..7ad7c4eb7ea65 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,25 +170,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    @@ -197,44 +179,7 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    @@ -243,21 +188,7 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +237,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +245,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +253,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +270,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +278,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +286,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +303,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +311,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +319,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +333,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +341,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +349,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +363,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +371,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +379,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
    @@ -771,22 +403,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
    @@ -794,39 +411,7 @@ dctDf.select("featuresDCT").show(3) Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); -DataFrame dctDf = dct.transform(df); -dctDf.select("featuresDCT").show(3); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}}
    @@ -881,18 +466,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +474,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,16 +482,7 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
    @@ -1030,30 +565,7 @@ for more details on the API. Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -1061,46 +573,7 @@ encoded.select("id", "categoryVec").show() Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -encoded.select("id", "categoryVec").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1108,25 +581,7 @@ encoded.select("id", "categoryVec").show(); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").show() -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -1150,23 +605,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    @@ -1174,30 +613,7 @@ val indexedData = indexerModel.transform(data) Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1205,17 +621,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
    @@ -1232,22 +638,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    @@ -1255,24 +646,7 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    @@ -1280,19 +654,7 @@ DataFrame lInfNormData = Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -1316,23 +678,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    @@ -1340,25 +686,7 @@ val scaledData = scalerModel.transform(dataFrame) Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    @@ -1366,20 +694,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    @@ -1409,21 +724,7 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) - -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %}
    @@ -1432,24 +733,7 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
    @@ -1473,23 +757,7 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    @@ -1497,38 +765,7 @@ val bucketedData = bucketizer.transform(dataFrame) Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    @@ -1536,19 +773,7 @@ DataFrame bucketedData = bucketizer.transform(dataFrame); Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
    @@ -1580,25 +805,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1606,41 +813,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1648,19 +821,8 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() - -{% endhighlight %} +{% include_example python/ml/elementwise_product_example.py %}
    - ## SQLTransformer @@ -1763,19 +925,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1783,36 +933,7 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    @@ -1820,19 +941,7 @@ System.out.println(output.select("features", "clicked").first()); Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler - -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %}
    @@ -1962,33 +1071,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1996,41 +1079,7 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); - -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - -DataFrame output = vectorSlicer.transform(dataset); - -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
    @@ -2067,21 +1116,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
    @@ -2089,38 +1124,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -2128,21 +1132,7 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 0000000000000..9698cac504371 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 0000000000000..8ad369cc93e8a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + bucketedData.show(); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 0000000000000..35c0d534a45e9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 0000000000000..2898accec61b0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 0000000000000..2d50ba7faa1a1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 0000000000000..8fd75ed8b5f4e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 0000000000000..ed3f6163c0558 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + l1NormData.show(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 0000000000000..bc509607084b1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + encoded.select("id", "categoryVec").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 0000000000000..8282fab084f36 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 0000000000000..668f71e64056b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 0000000000000..1e1062b541ad9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 0000000000000..da4756643f3c4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 0000000000000..b6b201c6b68d2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 0000000000000..05d12c1e702f1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 0000000000000..617dc3f66e3bf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label"). take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 0000000000000..7e230b5897c1e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 0000000000000..545758e31d972 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + DataFrame indexedData = indexerModel.transform(data); + indexedData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 0000000000000..4d5cb04ff5e2b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 0000000000000..317cfa638a5a9 --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 0000000000000..4304255f350db --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 0000000000000..c85cb0d89543c --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 0000000000000..f2d85f53e7219 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 0000000000000..d490221474c24 --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + lInfNormData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 0000000000000..0f94c26638d35 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 0000000000000..a17181f1b8a51 --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data, ["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 0000000000000..3d4fafd1a42e9 --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext\ + .createDataFrame([(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 0000000000000..b544a14700762 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 0000000000000..ae7aa85005bcd --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 0000000000000..01f94af8ca752 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 0000000000000..58a8cb5d56b73 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 0000000000000..ce9b225be5357 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 0000000000000..04f64839f188d --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 0000000000000..146f41c1dd903 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + indexedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 0000000000000..e724aa587294b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 0000000000000..7c75e3d72b47b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. + val bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 0000000000000..314c2c28a2a10 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala new file mode 100644 index 0000000000000..872de51dc75df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 0000000000000..fb7f28c9886bb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 0000000000000..8a85f71b56f3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 0000000000000..1990b55e8c5e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + lInfNormData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 0000000000000..66602e2118506 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 0000000000000..4c806f71a32c3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 0000000000000..39fb79af35766 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 0000000000000..286866edea502 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 0000000000000..e0a41e383a7ea --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 0000000000000..655ffce08d3ab --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 0000000000000..9fa494cd2473b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 0000000000000..01e0d1388a2f4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 0000000000000..d527924419f81 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 0000000000000..685891c164e70 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + indexedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 0000000000000..04f19829eff87 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 7a8e587dc04c2fabc875d1754eae7f85b4fba6ba Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 9 Dec 2015 17:16:01 -0800 Subject: [PATCH 1183/2089] [SPARK-12211][DOC][GRAPHX] Fix version number in graphx doc for migration from 1.1 Migration from 1.1 section added to the GraphX doc in 1.2.0 (see https://spark.apache.org/docs/1.2.0/graphx-programming-guide.html#migrating-from-spark-11) uses \{{site.SPARK_VERSION}} as the version where changes were introduced, it should be just 1.2. Author: Andrew Ray Closes #10206 from aray/graphx-doc-1.1-migration. --- docs/graphx-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 6a512ab234bb2..9dea9b5904d2d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -70,7 +70,7 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato ## Migrating from Spark 1.1 -GraphX in Spark {{site.SPARK_VERSION}} contains a few user facing API changes: +GraphX in Spark 1.2 contains a few user facing API changes: 1. To improve performance we have introduced a new version of [`mapReduceTriplets`][Graph.mapReduceTriplets] called From 8770bd1213f9b1051dabde9c5424ae7b32143a44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 9 Dec 2015 17:24:04 -0800 Subject: [PATCH 1184/2089] [SPARK-12165][ADDENDUM] Fix outdated comments on unroll test JoshRosen Author: Andrew Or Closes #10229 from andrewor14/unroll-test-comments. --- .../spark/memory/StaticMemoryManagerSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 6700b94f0f57f..272253bc94e91 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -163,15 +163,20 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to evict only 400 - 60 = 340 bytes. + // As of this point, cache memory is 800 bytes and current unroll memory is 60 bytes. + // Requesting 240 more bytes of unroll memory will leave our total unroll memory at + // 300 bytes, still under the 400-byte limit. Therefore, all 240 bytes are granted. assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) - assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 860 + 240 - 1000 when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 assert(mm.storageMemoryUsed === 1000L) evictedBlocks.clear() + // We already have 300 bytes of unroll memory, so requesting 150 more will leave us + // above the 400-byte limit. Since there is not enough free memory, this request will + // fail even after evicting as much as we can (400 - 300 = 100 bytes). assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) - assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 400 - 300 - assert(mm.storageMemoryUsed === 900L) // 100 bytes were evicted + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.storageMemoryUsed === 900L) // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) From ac8cdf1cdc148bd21290ecf4d4f9874f8c87cc14 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 9 Dec 2015 18:09:36 -0800 Subject: [PATCH 1185/2089] [SPARK-11678][SQL][DOCS] Document basePath in the programming guide. This PR adds document for `basePath`, which is a new parameter used by `HadoopFsRelation`. The compiled doc is shown below. ![image](https://cloud.githubusercontent.com/assets/2072857/11673132/1ba01192-9dcb-11e5-98d9-ac0b4e92e98c.png) JIRA: https://issues.apache.org/jira/browse/SPARK-11678 Author: Yin Huai Closes #10211 from yhuai/basePathDoc. --- docs/sql-programming-guide.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9f87accd30f40..3f9a831eddc88 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1233,6 +1233,13 @@ infer the data types of the partitioning columns. For these use cases, the autom can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to `true`. When type inference is disabled, string type will be used for the partitioning columns. +Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths +by default. For the above example, if users pass `path/to/table/gender=male` to either +`SQLContext.read.parquet` or `SQLContext.read.load`, `gender` will not be considered as a +partitioning column. If users need to specify the base path that partition discovery +should start with, they can set `basePath` in the data source options. For example, +when `path/to/table/gender=male` is the path of the data and +users set `basePath` to `path/to/table/`, `gender` will be a partitioning column. ### Schema Merging From 2166c2a75083c2262e071a652dd52b1a33348b6e Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Wed, 9 Dec 2015 18:37:35 -0800 Subject: [PATCH 1186/2089] [SPARK-11796] Fix httpclient and httpcore depedency issues related to docker-client This commit fixes dependency issues which prevented the Docker-based JDBC integration tests from running in the Maven build. Author: Mark Grover Closes #9876 from markgrover/master_docker. --- docker-integration-tests/pom.xml | 22 ++++++++++++++++++++++ pom.xml | 28 ++++++++++++++++++++++++++++ sql/core/pom.xml | 2 -- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index dee0c4aa37ae8..39d3f344615e1 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -71,6 +71,18 @@
    + + org.apache.httpcomponents + httpclient + 4.5 + test + + + org.apache.httpcomponents + httpcore + 4.4.1 + test + com.google.guava @@ -109,6 +121,16 @@ ${project.version} test + + mysql + mysql-connector-java + test + + + org.postgresql + postgresql + test + prob=$prob, prediction=$prediction") - } - -{% endhighlight %} - - -
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// Fit the pipeline to training documents. -PipelineModel model = pipeline.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. -DataFrame predictions = model.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} -
    - -
    -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row - -# Prepare training documents from a list of (id, text, label) tuples. -LabeledDocument = Row("id", "text", "label") -training = sqlContext.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) - -# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. -tokenizer = Tokenizer(inputCol="text", outputCol="words") -hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") -lr = LogisticRegression(maxIter=10, regParam=0.01) -pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) - -# Fit the pipeline to training documents. -model = pipeline.fit(training) - -# Prepare test documents, which are unlabeled (id, text) tuples. -test = sqlContext.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")], ["id", "text"]) - -# Make predictions on test documents and print columns of interest. -prediction = model.transform(test) -selected = prediction.select("id", "text", "prediction") -for row in selected.collect(): - print(row) - -{% endhighlight %} -
    - - - -## Example: model selection via cross-validation - -An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. -`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. - -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). -`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. -`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. - -The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` -method in each of these evaluators. - -The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. -`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -The following example demonstrates using `CrossValidator` to select from a grid of parameters. -To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. - -Note that cross-validation over a grid of parameters is expensive. -E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. -In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). -In other words, using `CrossValidator` can be very expensive. -However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training data from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0), - (4L, "b spark who", 1.0), - (5L, "g d a y", 0.0), - (6L, "spark fly", 1.0), - (7L, "was mapreduce", 0.0), - (8L, "e spark program", 1.0), - (9L, "a e c l", 0.0), - (10L, "spark compile", 1.0), - (11L, "hadoop software", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -val cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2) // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -val cvModel = cv.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} -
    - -
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -CrossValidator cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2); // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = cv.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -DataFrame predictions = cvModel.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} -
    - -
    - -## Example: model selection via train validation split -In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. -`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in - case of `CrossValidator`. It is therefore less expensive, - but will not produce as reliable results when the training dataset is not sufficiently large. - -`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, -and an `Evaluator`. -It begins by splitting the dataset into two parts using `trainRatio` parameter -which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default), -`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. -Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. -For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. -The `ParamMap` which produces the best evaluation metric is selected as the best option. -`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} - -// Prepare training and test data. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") -val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - -val lr = new LinearRegression() - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - // 80% of the data will be used for training and the remaining 20% for validation. - .setTrainRatio(0.8) - -// Run train validation split, and choose the best set of parameters. -val model = trainValidationSplit.fit(training) - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show() - -{% endhighlight %} -
    - -
    -{% highlight java %} -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.DataFrame; - -DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); -DataFrame training = splits[0]; -DataFrame test = splits[1]; - -LinearRegression lr = new LinearRegression(); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation - -// Run train validation split, and choose the best set of parameters. -TrainValidationSplitModel model = trainValidationSplit.fit(training); - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show(); - -{% endhighlight %} -
    - -
    \ No newline at end of file diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 0c13d7d0c82b3..a8754835cab95 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -1,148 +1,8 @@ --- layout: global -title: Linear Methods - ML -displayTitle: ML - Linear Methods +title: Linear methods - spark.ml +displayTitle: Linear methods - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In MLlib, we implement popular linear methods such as logistic -regression and linear least squares with $L_1$ or $L_2$ regularization. -Refer to [the linear methods in mllib](mllib-linear-methods.html) for -details. In `spark.ml`, we also include Pipelines API for [Elastic -net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid -of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization -and variable selection via the elastic -net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). -Mathematically, it is defined as a convex combination of the $L_1$ and -the $L_2$ regularization terms: -`\[ -\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 -\]` -By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ -regularization as special cases. For example, if a [linear -regression](https://en.wikipedia.org/wiki/Linear_regression) model is -trained with the elastic net parameter $\alpha$ set to $1$, it is -equivalent to a -[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. -On the other hand, if $\alpha$ is set to $0$, the trained model reduces -to a [ridge -regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. -We implement Pipelines API for both linear regression and logistic -regression with elastic net regularization. - -## Example: Logistic Regression - -The following example shows how to train a logistic regression model -with elastic net regularization. `elasticNetParam` corresponds to -$\alpha$ and `regParam` corresponds to $\lambda$. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} -
    - -
    -{% include_example python/ml/logistic_regression_with_elastic_net.py %} -
    - -
    - -The `spark.ml` implementation of logistic regression also supports -extracting a summary of the model over the training set. Note that the -predictions and metrics which are stored as `Dataframe` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence -only available on the driver. - -
    - -
    - -[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) -provides a summary for a -[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} -
    - -
    -[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) -provides a summary for a -[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} -
    - - -
    -Logistic regression model summary is not yet supported in Python. -
    - -
    - -## Example: Linear Regression - -The interface for working with linear regression models and model -summaries is similar to the logistic regression case. The following -example demonstrates training an elastic net regularized linear -regression model and extracting model summary statistics. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} -
    - -
    - -{% include_example python/ml/linear_regression_with_elastic_net.py %} -
    - -
    - -# Optimization - -The optimization algorithm underlying the implementation is called -[Orthant-Wise Limited-memory -QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 -regularization and elastic net. - + > This section has been moved into the + [classification and regression section](ml-classification-regression.html). diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md index ab275213b9a84..856ceb2f4e7f6 100644 --- a/docs/ml-survival-regression.md +++ b/docs/ml-survival-regression.md @@ -1,96 +1,8 @@ --- layout: global -title: Survival Regression - ML -displayTitle: ML - Survival Regression +title: Survival Regression - spark.ml +displayTitle: Survival Regression - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) -model which is a parametric survival regression model for censored data. -It describes a model for the log of survival time, so it's often called -log-linear model for survival analysis. Different from -[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model -designed for the same purpose, the AFT model is more easily to parallelize -because each instance contribute to the objective function independently. - -Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of -subjects i = 1, ..., n, with possible right-censoring, -the likelihood function under the AFT model is given as: -`\[ -L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} -\]` -Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. -Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function -assumes the form: -`\[ -\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] -\]` -Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, -and $f_{0}(\epsilon_{i})$ is corresponding density function. - -The most commonly used AFT model is based on the Weibull distribution of the survival time. -The Weibull distribution for lifetime corresponding to extreme value distribution for -log of the lifetime, and the $S_{0}(\epsilon)$ function is: -`\[ -S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) -\]` -the $f_{0}(\epsilon_{i})$ function is: -`\[ -f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) -\]` -The log-likelihood function for AFT model with Weibull distribution of lifetime is: -`\[ -\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] -\]` -Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, -the loss function we use to optimize is $-\iota(\beta,\sigma)$. -The gradient functions for $\beta$ and $\log\sigma$ respectively are: -`\[ -\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} -\]` -`\[ -\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] -\]` - -The AFT model can be formulated as a convex optimization problem, -i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ -that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. -The optimization algorithm underlying the implementation is L-BFGS. -The implementation matches the result from R's survival function -[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) - -## Example: - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} -
    - -
    -{% include_example python/ml/aft_survival_regression.py %} -
    - -
    \ No newline at end of file + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#survival-regression). diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 0210950b89906..aaf8bd465c9ab 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -1,10 +1,10 @@ --- layout: global -title: Classification and Regression - MLlib -displayTitle: MLlib - Classification and Regression +title: Classification and Regression - spark.mllib +displayTitle: Classification and Regression - spark.mllib --- -MLlib supports various methods for +The `spark.mllib` package supports various methods for [binary classification](http://en.wikipedia.org/wiki/Binary_classification), [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification), and diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 8fbced6c87d9f..48d64cd402b11 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -1,7 +1,7 @@ --- layout: global -title: Clustering - MLlib -displayTitle: MLlib - Clustering +title: Clustering - spark.mllib +displayTitle: Clustering - spark.mllib --- [Clustering](https://en.wikipedia.org/wiki/Cluster_analysis) is an unsupervised learning problem whereby we aim to group subsets @@ -10,19 +10,19 @@ often used for exploratory analysis and/or as a component of a hierarchical [supervised learning](https://en.wikipedia.org/wiki/Supervised_learning) pipeline (in which distinct classifiers or regression models are trained for each cluster). -MLlib supports the following models: +The `spark.mllib` package supports the following models: * Table of contents {:toc} ## K-means -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +[K-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the most commonly used clustering algorithms that clusters the data points into a -predefined number of clusters. The MLlib implementation includes a parallelized +predefined number of clusters. The `spark.mllib` implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in `spark.mllib` has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -171,7 +171,7 @@ sameModel = KMeansModel.load(sc, "myModelPath") A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, -each with its own probability. The MLlib implementation uses the +each with its own probability. The `spark.mllib` implementation uses the [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) algorithm to induce the maximum-likelihood model given a set of samples. The implementation has the following parameters: @@ -308,13 +308,13 @@ graph given pairwise similarties as edge properties, described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. -MLlib includes an implementation of PIC using GraphX as its backend. +`spark.mllib` includes an implementation of PIC using GraphX as its backend. It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. The similarities must be nonnegative. PIC assumes that the similarity measure is symmetric. A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. If a pair is missing from input, their similarity is treated as zero. -MLlib's PIC implementation takes the following (hyper-)parameters: +`spark.mllib`'s PIC implementation takes the following (hyper-)parameters: * `k`: number of clusters * `maxIterations`: maximum number of power iterations @@ -323,7 +323,7 @@ MLlib's PIC implementation takes the following (hyper-)parameters: **Examples** -In the following, we show code snippets to demonstrate how to use PIC in MLlib. +In the following, we show code snippets to demonstrate how to use PIC in `spark.mllib`.
    @@ -493,7 +493,7 @@ checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. -All of MLlib's LDA models support: +All of `spark.mllib`'s LDA models support: * `describeTopics`: Returns topics as arrays of most important terms and term weights @@ -721,7 +721,7 @@ sameModel = LDAModel.load(sc, "myModelPath") ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, +updating them as new data arrive. `spark.mllib` provides support for streaming k-means clustering, with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7cd1b894e7cb5..1ebb4654aef12 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - MLlib -displayTitle: MLlib - Collaborative Filtering +title: Collaborative Filtering - spark.mllib +displayTitle: Collaborative Filtering - spark.mllib --- * Table of contents @@ -11,12 +11,12 @@ displayTitle: MLlib - Collaborative Filtering [Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) is commonly used for recommender systems. These techniques aim to fill in the -missing entries of a user-item association matrix. MLlib currently supports +missing entries of a user-item association matrix. `spark.mllib` currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -MLlib uses the [alternating least squares +`spark.mllib` uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) -algorithm to learn these latent factors. The implementation in MLlib has the +algorithm to learn these latent factors. The implementation in `spark.mllib` has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). @@ -34,7 +34,7 @@ The standard approach to matrix factorization based collaborative filtering trea the entries in the user-item matrix as *explicit* preferences given by the user to the item. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, -clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with such data is taken +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). Essentially instead of trying to model the matrix of ratings directly, this approach treats the data @@ -119,4 +119,4 @@ a dependency. ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for -[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). +[personalized movie recommendation with `spark.mllib`](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 3c0c0479674df..363dc7c13b306 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global title: Data Types - MLlib -displayTitle: MLlib - Data Types +displayTitle: Data Types - MLlib --- * Table of contents diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 77ce34e91af3c..a8612b6c84fe9 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Trees - MLlib -displayTitle: MLlib - Decision Trees +title: Decision Trees - spark.mllib +displayTitle: Decision Trees - spark.mllib --- * Table of contents @@ -15,7 +15,7 @@ feature scaling, and are able to capture non-linearities and feature interaction algorithms such as random forests and boosting are among the top performers for classification and regression tasks. -MLlib supports decision trees for binary and multiclass classification and for regression, +`spark.mllib` supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions of instances. diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index ac3526908a9f4..11d8e0bd1d23d 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -1,7 +1,7 @@ --- layout: global -title: Dimensionality Reduction - MLlib -displayTitle: MLlib - Dimensionality Reduction +title: Dimensionality Reduction - spark.mllib +displayTitle: Dimensionality Reduction - spark.mllib --- * Table of contents @@ -11,7 +11,7 @@ displayTitle: MLlib - Dimensionality Reduction of reducing the number of variables under consideration. It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -MLlib provides support for dimensionality reduction on the RowMatrix class. +`spark.mllib` provides support for dimensionality reduction on the RowMatrix class. ## Singular value decomposition (SVD) @@ -57,7 +57,7 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver. ### SVD Example -MLlib provides SVD functionality to row-oriented matrices, provided in the +`spark.mllib` provides SVD functionality to row-oriented matrices, provided in the RowMatrix class.
    @@ -141,7 +141,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. +`spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 50450e05d2abb..2416b6fa0aeb3 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - MLlib -displayTitle: MLlib - Ensembles +title: Ensembles - spark.mllib +displayTitle: Ensembles - spark.mllib --- * Table of contents @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -33,9 +33,9 @@ Like decision trees, random forests handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports random forests for binary and multiclass classification and for regression, +`spark.mllib` supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -MLlib implements random forests using the existing [decision tree](mllib-decision-tree.html) +`spark.mllib` implements random forests using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. ### Basic algorithm @@ -155,9 +155,9 @@ Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports GBTs for binary classification and for regression, +`spark.mllib` supports GBTs for binary classification and for regression, using both continuous and categorical features. -MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. +`spark.mllib` implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. *Note*: GBTs do not yet support multiclass classification. For multiclass problems, please use [decision trees](mllib-decision-tree.html) or [Random Forests](mllib-ensembles.html#Random-Forest). @@ -171,7 +171,7 @@ The specific mechanism for re-labeling instances is defined by a loss function ( #### Losses -The table below lists the losses currently supported by GBTs in MLlib. +The table below lists the losses currently supported by GBTs in `spark.mllib`. Note that each loss is applicable to one of classification or regression, not both. Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 6924037b941f3..774826c2703f8 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -1,20 +1,20 @@ --- layout: global -title: Evaluation Metrics - MLlib -displayTitle: MLlib - Evaluation Metrics +title: Evaluation Metrics - spark.mllib +displayTitle: Evaluation Metrics - spark.mllib --- * Table of contents {:toc} -Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +`spark.mllib` comes with a number of machine learning algorithms that can be used to learn from and make predictions on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance -of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +of the model on some criteria, which depends on the application and its requirements. `spark.mllib` also provides a suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those -metrics that are currently available in Spark's MLlib are detailed in this section. +metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 5bee170c61fe9..7796bac697562 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction and Transformation - MLlib -displayTitle: MLlib - Feature Extraction and Transformation +title: Feature Extraction and Transformation - spark.mllib +displayTitle: Feature Extraction and Transformation - spark.mllib --- * Table of contents @@ -31,7 +31,7 @@ The TF-IDF measure is simply the product of TF and IDF: TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). \]` There are several variants on the definition of term frequency and document frequency. -In MLlib, we separate TF and IDF to make them flexible. +In `spark.mllib`, we separate TF and IDF to make them flexible. Our implementation of term frequency utilizes the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). @@ -44,7 +44,7 @@ To reduce the chance of collision, we can increase the target feature dimension, the number of buckets of the hash table. The default feature dimension is `$2^{20} = 1,048,576$`. -**Note:** MLlib doesn't provide tools for text segmentation. +**Note:** `spark.mllib` doesn't provide tools for text segmentation. We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and [scalanlp/chalk](https://github.com/scalanlp/chalk). @@ -86,7 +86,7 @@ val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} -MLlib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. @@ -134,7 +134,7 @@ idf = IDF().fit(tf) tfidf = idf.transform(tf) {% endhighlight %} -MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index fe42896a05d8e..2c8a8f236163f 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -1,7 +1,7 @@ --- layout: global -title: Frequent Pattern Mining - MLlib -displayTitle: MLlib - Frequent Pattern Mining +title: Frequent Pattern Mining - spark.mllib +displayTitle: Frequent Pattern Mining - spark.mllib --- Mining frequent items, itemsets, subsequences, or other substructures is usually among the @@ -9,7 +9,7 @@ first steps to analyze a large-scale dataset, which has been an active research data mining for years. We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) for more information. -MLlib provides a parallel implementation of FP-growth, +`spark.mllib` provides a parallel implementation of FP-growth, a popular algorithm to mining frequent itemsets. ## FP-growth @@ -22,13 +22,13 @@ Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) al the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. -In MLlib, we implemented a parallel version of FP-growth called PFP, +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffices of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. -MLlib's FP-growth implementation takes the following (hyper-)parameters: +`spark.mllib`'s FP-growth implementation takes the following (hyper-)parameters: * `minSupport`: the minimum support for an itemset to be identified as frequent. For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. @@ -126,7 +126,7 @@ PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. -MLlib's PrefixSpan implementation takes the following parameters: +`spark.mllib`'s PrefixSpan implementation takes the following parameters: * `minSupport`: the minimum support required to be considered a frequent sequential pattern. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 3bc2b780601c2..7fef6b5c61f99 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -66,7 +66,7 @@ We list major functionality from both below, with links to detailed guides. # spark.ml: high-level APIs for ML pipelines -* [Overview: estimators, transformers and pipelines](ml-intro.html) +* [Overview: estimators, transformers and pipelines](ml-guide.html) * [Extracting, transforming and selecting features](ml-features.html) * [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 85f9226b43416..8ede4407d5843 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Isotonic regression - MLlib -displayTitle: MLlib - Regression +title: Isotonic regression - spark.mllib +displayTitle: Regression - spark.mllib --- ## Isotonic regression @@ -23,7 +23,7 @@ Essentially isotonic regression is a [monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) best fitting the original data points. -MLlib supports a +`spark.mllib` supports a [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 132f8c354aa9c..20b35612cab95 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear Methods - MLlib -displayTitle: MLlib - Linear Methods +title: Linear Methods - spark.mllib +displayTitle: Linear Methods - spark.mllib --- * Table of contents @@ -41,7 +41,7 @@ the objective function is of the form Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several of MLlib's classification and regression algorithms fall into this category, +Several of `spark.mllib`'s classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: @@ -55,7 +55,7 @@ training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions The following table summarizes the loss functions and their gradients or sub-gradients for the -methods MLlib supports: +methods `spark.mllib` supports: @@ -83,7 +83,7 @@ methods MLlib supports: The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to encourage simple models and avoid overfitting. We support the following -regularizers in MLlib: +regularizers in `spark.mllib`:
    @@ -115,7 +115,10 @@ especially when the number of training examples is small. ### Optimization -Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. +Under the hood, linear methods use convex optimization methods to optimize the objective functions. +`spark.mllib` uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). +Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. +Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. ## Classification @@ -126,16 +129,16 @@ The most common classification type is categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). -MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +`spark.mllib` supports two linear methods for classification: linear Support Vector Machines (SVMs) and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. -For both methods, MLlib supports L1 and L2 regularized variants. +For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either $+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -207,7 +210,7 @@ val sameModel = SVMModel.load(sc, "myModelPath") The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -293,7 +296,7 @@ public class SVMClassifier { The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -375,7 +378,7 @@ Binary logistic regression can be generalized into train and predict multiclass classification problems. For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately regressed against the pivot outcome. -In MLlib, the first class $0$ is chosen as the "pivot" class. +In `spark.mllib`, the first class $0$ is chosen as the "pivot" class. See Section 4.4 of [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for references. @@ -726,7 +729,7 @@ a dependency. ###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +updating the parameters of the model as new data arrives. `spark.mllib` currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -852,7 +855,7 @@ will get better! # Implementation (developer) -Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent +Behind the scene, `spark.mllib` implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the optimization section). All provided algorithms take as input a regularization parameter (`regParam`) along with various parameters associated with stochastic diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 774b85d1f773a..73e4fddf67fc0 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -1,7 +1,7 @@ --- layout: global -title: Old Migration Guides - MLlib -displayTitle: MLlib - Old Migration Guides +title: Old Migration Guides - spark.mllib +displayTitle: Old Migration Guides - spark.mllib description: MLlib migration guides from before Spark SPARK_VERSION_SHORT --- diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 60ac6c7e5bb1a..d0d594af6a4ad 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -1,7 +1,7 @@ --- layout: global -title: Naive Bayes - MLlib -displayTitle: MLlib - Naive Bayes +title: Naive Bayes - spark.mllib +displayTitle: Naive Bayes - spark.mllib --- [Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple @@ -12,7 +12,7 @@ distribution of each feature given label, and then it applies Bayes' theorem to compute the conditional probability distribution of label given an observation and use it for prediction. -MLlib supports [multinomial naive +`spark.mllib` supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index ad7bcd9bfd407..f90b66f8e2c44 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -1,7 +1,7 @@ --- layout: global -title: Optimization - MLlib -displayTitle: MLlib - Optimization +title: Optimization - spark.mllib +displayTitle: Optimization - spark.mllib --- * Table of contents @@ -87,7 +87,7 @@ in the `$t$`-th iteration, with the input parameter `$s=$ stepSize`. Note that s step-size for SGD methods can often be delicate in practice and is a topic of active research. **Gradients.** -A table of (sub)gradients of the machine learning methods implemented in MLlib, is available in +A table of (sub)gradients of the machine learning methods implemented in `spark.mllib`, is available in the classification and regression section. @@ -140,7 +140,7 @@ other first-order optimization. ### Choosing an Optimization Method -[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in MLlib support both SGD and L-BFGS. +[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in `spark.mllib` support both SGD and L-BFGS. Different optimization methods can have different convergence guarantees depending on the properties of the objective function, and we cannot cover the literature here. In general, when L-BFGS is available, we recommend using it instead of SGD since L-BFGS tends to converge faster (in fewer iterations). diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index 615287125c032..b532ad907dfc5 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -1,21 +1,21 @@ --- layout: global -title: PMML model export - MLlib -displayTitle: MLlib - PMML model export +title: PMML model export - spark.mllib +displayTitle: PMML model export - spark.mllib --- * Table of contents {:toc} -## MLlib supported models +## `spark.mllib` supported models -MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). +`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). -The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. +The table below outlines the `spark.mllib` models that can be exported to PMML and their equivalent PMML model.
    - + diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de209f68e19ca..652d215fa8653 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -1,7 +1,7 @@ --- layout: global -title: Basic Statistics - MLlib -displayTitle: MLlib - Basic Statistics +title: Basic Statistics - spark.mllib +displayTitle: Basic Statistics - spark.mllib --- * Table of contents @@ -112,7 +112,7 @@ print(summary.numNonzeros()) ## Correlations -Calculating the correlation between two series of data is a common operation in Statistics. In MLlib +Calculating the correlation between two series of data is a common operation in Statistics. In `spark.mllib` we provide the flexibility to calculate pairwise correlations among many series. The supported correlation methods are currently Pearson's and Spearman's correlation. @@ -209,7 +209,7 @@ print(Statistics.corr(data, method="pearson")) ## Stratified sampling -Unlike the other statistics functions, which reside in MLlib, stratified sampling methods, +Unlike the other statistics functions, which reside in `spark.mllib`, stratified sampling methods, `sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified sampling, the keys can be thought of as a label and the value as a specific attribute. For example the key can be man or woman, or document ids, and the respective values can be the list of ages @@ -294,12 +294,12 @@ approxSample = data.sampleByKey(False, fractions); ## Hypothesis testing Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically -significant, whether this result occurred by chance or not. MLlib currently supports Pearson's +significant, whether this result occurred by chance or not. `spark.mllib` currently supports Pearson's chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. -MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared +`spark.mllib` also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared independence tests.
    @@ -438,7 +438,7 @@ for i, result in enumerate(featureTestResults):
    -Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +Additionally, `spark.mllib` provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test for equality of probability distributions. By providing the name of a theoretical distribution (currently solely supported for the normal distribution) and its parameters, or a function to calculate the cumulative distribution according to a given theoretical distribution, the user can @@ -522,7 +522,7 @@ print(testResult) # summary of the test including the p-value, test statistic, ### Streaming Significance Testing -MLlib provides online implementations of some tests to support use cases +`spark.mllib` provides online implementations of some tests to support use cases like A/B testing. These tests may be performed on a Spark Streaming `DStream[(Boolean,Double)]` where the first element of each tuple indicates control group (`false`) or treatment group (`true`) and the @@ -550,7 +550,7 @@ provides streaming hypothesis testing. ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. -MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +`spark.mllib` supports generating random RDDs with i.i.d. values drawn from a given distribution: uniform, standard normal, or Poisson.
    From 4a46b8859d3314b5b45a67cdc5c81fecb6e9e78c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 10 Dec 2015 13:26:30 -0800 Subject: [PATCH 1199/2089] [SPARK-11563][CORE][REPL] Use RpcEnv to transfer REPL-generated classes. This avoids bringing up yet another HTTP server on the driver, and instead reuses the file server already managed by the driver's RpcEnv. As a bonus, the repl now inherits the security features of the network library. There's also a small change to create the directory for storing classes under the root temp dir for the application (instead of directly under java.io.tmpdir). Author: Marcelo Vanzin Closes #9923 from vanzin/SPARK-11563. --- .../org/apache/spark/HttpFileServer.scala | 5 ++ .../scala/org/apache/spark/HttpServer.scala | 26 ++++++---- .../scala/org/apache/spark/SparkContext.scala | 6 +++ .../org/apache/spark/executor/Executor.scala | 6 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 18 +++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 5 ++ .../spark/rpc/netty/NettyStreamManager.scala | 22 ++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 25 +++++++++- docs/configuration.md | 8 ---- docs/security.md | 8 ---- .../org/apache/spark/repl/SparkILoop.scala | 17 ++++--- .../org/apache/spark/repl/SparkIMain.scala | 28 ++--------- .../scala/org/apache/spark/repl/Main.scala | 23 ++++----- .../spark/repl/ExecutorClassLoader.scala | 36 +++++++------- .../spark/repl/ExecutorClassLoaderSuite.scala | 48 +++++++++++++++---- 15 files changed, 183 insertions(+), 98 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7cf7bc0dc6810..77d8ec9bb1607 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -71,6 +71,11 @@ private[spark] class HttpFileServer( serverUri + "/jars/" + file.getName } + def addDirectory(path: String, resourceBase: String): String = { + httpServer.addDirectory(path, resourceBase) + serverUri + path + } + def addFileToDir(file: File, dir: File) : String = { // Check whether the file is a directory. If it is, throw a more meaningful exception. // If we don't catch this, Guava throws a very confusing error message: diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 8de3a6c04df34..faa3ef3d7561d 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -23,10 +23,9 @@ import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} - import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} +import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils @@ -52,6 +51,11 @@ private[spark] class HttpServer( private var server: Server = null private var port: Int = requestedPort + private val servlets = { + val handler = new ServletContextHandler() + handler.setContextPath("/") + handler + } def start() { if (server != null) { @@ -65,6 +69,14 @@ private[spark] class HttpServer( } } + def addDirectory(contextPath: String, resourceBase: String): Unit = { + val holder = new ServletHolder() + holder.setInitParameter("resourceBase", resourceBase) + holder.setInitParameter("pathInfoOnly", "true") + holder.setServlet(new DefaultServlet()) + servlets.addServlet(holder, contextPath.stripSuffix("/") + "/*") + } + /** * Actually start the HTTP server on the given port. * @@ -85,21 +97,17 @@ private[spark] class HttpServer( val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + addDirectory("/", resourceBase.getAbsolutePath) if (securityManager.isAuthenticationEnabled()) { logDebug("HttpServer is using security") val sh = setupSecurityHandler(securityManager) // make sure we go through security handler to get resources - sh.setHandler(handlerList) + sh.setHandler(servlets) server.setHandler(sh) } else { logDebug("HttpServer is not using security") - server.setHandler(handlerList) + server.setHandler(servlets) } server.start() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8a62b71c3fa68..194ecc0a0434e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -457,6 +457,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) + // If running the REPL, register the repl's output dir with the file server. + _conf.getOption("spark.repl.class.outputDir").foreach { path => + val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path)) + _conf.set("spark.repl.class.uri", replUri) + } + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) _statusTracker = new SparkStatusTracker(this) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 7b68dfe5ad06e..552b644d13aaf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -364,9 +364,9 @@ private[spark] class Executor( val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], - classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, _userClassPathFirst) + val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv], + classOf[String], classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3d7d281b0dd66..64a4a8bf7c5eb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -179,6 +179,24 @@ private[spark] trait RpcEnvFileServer { */ def addJar(file: File): String + /** + * Adds a local directory to be served via this file server. + * + * @param baseUri Leading URI path (files can be retrieved by appending their relative + * path to this base URI). This cannot be "files" nor "jars". + * @param path Path to the local directory. + * @return URI for the root of the directory in the file server. + */ + def addDirectory(baseUri: String, path: File): String + + /** Validates and normalizes the base URI for directories. */ + protected def validateDirectoryUri(baseUri: String): String = { + val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/") + require(fixedBaseUri != "/files" && fixedBaseUri != "/jars", + "Directory URI cannot be /files nor /jars.") + fixedBaseUri + } + } private[spark] case class RpcEnvConfig( diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 94dbec593c315..9d098154f7190 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -273,6 +273,11 @@ private[akka] class AkkaFileServer( getFileServer().addJar(file) } + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) + } + def shutdown(): Unit = { if (httpFileServer != null) { httpFileServer.stop() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index a2768b4252dcb..ecd96972455d0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -25,12 +25,22 @@ import org.apache.spark.rpc.RpcEnvFileServer /** * StreamManager implementation for serving files from a NettyRpcEnv. + * + * Three kinds of resources can be registered in this manager, all backed by actual files: + * + * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]]. + * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]]. + * - arbitrary directories; all files under the directory become available through the manager, + * respecting the directory's hierarchy. + * + * Only streaming (openStream) is supported. */ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) extends StreamManager with RpcEnvFileServer { private val files = new ConcurrentHashMap[String, File]() private val jars = new ConcurrentHashMap[String, File]() + private val dirs = new ConcurrentHashMap[String, File]() override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { throw new UnsupportedOperationException() @@ -41,7 +51,10 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) val file = ftype match { case "files" => files.get(fname) case "jars" => jars.get(fname) - case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + case other => + val dir = dirs.get(ftype) + require(dir != null, s"Invalid stream URI: $ftype not found.") + new File(dir, fname) } require(file != null && file.isFile(), s"File not found: $streamId") @@ -60,4 +73,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" } + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null, + s"URI '$fixedBaseUri' already registered.") + s"${rpcEnv.address.toSparkURL}$fixedBaseUri" + } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6cc958a5f6bc8..a61d0479aacdb 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -734,9 +734,28 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val jar = new File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) + val dir1 = new File(tempDir, "dir1") + assert(dir1.mkdir()) + val subFile1 = new File(dir1, "file1") + Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + + val dir2 = new File(tempDir, "dir2") + assert(dir2.mkdir()) + val subFile2 = new File(dir2, "file2") + Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + val fileUri = env.fileServer.addFile(file) val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) + val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) + val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) + + // Try registering directories with invalid names. + Seq("/files", "/jars").foreach { uri => + intercept[IllegalArgumentException] { + env.fileServer.addDirectory(uri, dir1) + } + } val destDir = Utils.createTempDir() val sm = new SecurityManager(conf) @@ -745,7 +764,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val files = Seq( (file, fileUri), (empty, emptyUri), - (jar, jarUri)) + (jar, jarUri), + (subFile1, dir1Uri + "/file1"), + (subFile2, dir2Uri + "/file2")) files.foreach { case (f, uri) => val destFile = new File(destDir, f.getName()) Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) @@ -753,7 +774,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } // Try to download files that do not exist. - Seq("files", "jars").foreach { root => + Seq("files", "jars", "dir1").foreach { root => intercept[Exception] { val uri = env.address.toSparkURL + s"/$root/doesNotExist" Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) diff --git a/docs/configuration.md b/docs/configuration.md index fd61ddc244f44..873a2d0b303cd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1053,14 +1053,6 @@ Apart from these, the following properties are also available, and may be useful to port + maxRetries.
    - - - - - diff --git a/docs/security.md b/docs/security.md index e1af221d446b0..0bfc791c5744e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -169,14 +169,6 @@ configure those ports. - - - - - - - - diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 304b1e8cdbed5..22749c4609345 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -253,7 +253,7 @@ class SparkILoop( case xs => xs find (_.name == cmd) } } - private var fallbackMode = false + private var fallbackMode = false private def toggleFallbackMode() { val old = fallbackMode @@ -261,9 +261,9 @@ class SparkILoop( System.setProperty("spark.repl.fallback", fallbackMode.toString) echo(s""" |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would + | If you have defined classes in the repl, it would |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` + |into issues it would be good to restart the repl and turn on `:fallback` |mode as first command. """.stripMargin) } @@ -350,7 +350,7 @@ class SparkILoop( shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. + |disable/enable advanced repl changes, these fix some issues but may introduce others. |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) @@ -1009,8 +1009,13 @@ class SparkILoop( val conf = new SparkConf() .setMaster(getMaster()) .setJars(jars) - .set("spark.repl.class.uri", intp.classServerUri) .setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1025,7 +1030,7 @@ class SparkILoop( val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + .newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") } catch { diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 829b12269fd2b..7fcb423575d39 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -37,7 +37,7 @@ import scala.reflect.{ ClassTag, classTag } import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable -import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.spark.annotation.DeveloperApi @@ -96,10 +96,9 @@ import org.apache.spark.annotation.DeveloperApi private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - private lazy val outputDir = { - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - Utils.createTempDir(rootDir) + private[repl] val outputDir = { + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + Utils.createTempDir(root = rootDir, namePrefix = "repl") } if (SPARK_DEBUG_REPL) { echo("Output directory: " + outputDir) @@ -114,8 +113,6 @@ import org.apache.spark.annotation.DeveloperApi private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - private val classServerPort = conf.getInt("spark.replClassServer.port", 0) - private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings private var printResults = true // whether to print result lines private var totalSilence = false // whether to print anything @@ -124,22 +121,6 @@ import org.apache.spark.annotation.DeveloperApi private var bindExceptions = true // whether to bind the lastException variable private var _executionWrapper = "" // code to be wrapped around all lines - - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - if (SPARK_DEBUG_REPL) { - echo("Class server started, URI = " + classServer.uri) - } - - /** - * URI of the class server used to feed REPL compiled classes. - * - * @return The string representing the class server uri - */ - @DeveloperApi - def classServerUri = classServer.uri - /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -994,7 +975,6 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi def close() { reporter.flush() - classServer.stop() } /** diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 455a6b9a93aad..44650f25f7a18 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.SQLContext object Main extends Logging { val conf = new SparkConf() - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - val outputDir = Utils.createTempDir(rootDir) + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator)), true) // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed - lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. @@ -45,7 +47,6 @@ object Main extends Logging { } def main(args: Array[String]) { - val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", @@ -57,11 +58,7 @@ object Main extends Logging { if (!hasErrors) { if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() interp.process(settings) // Repl starts and goes in loop of R.E.P.L - classServer.stop() Option(sparkContext).map(_.stop) } } @@ -82,9 +79,13 @@ object Main extends Logging { val conf = new SparkConf() .setMaster(getMaster) .setJars(jars) - .set("spark.repl.class.uri", classServer.uri) .setIfMissing("spark.app.name", "Spark shell") - logInfo("Spark class server started at " + classServer.uri) + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a8859fcd4584b..da8f0aa1e3360 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -19,6 +19,7 @@ package org.apache.spark.repl import java.io.{IOException, ByteArrayOutputStream, InputStream} import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.nio.channels.Channels import scala.util.control.NonFatal @@ -38,7 +39,11 @@ import org.apache.spark.util.ParentClassLoader * This class loader delegates getting/finding resources to parent loader, * which makes sense until REPL never provide resource dynamically. */ -class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, +class ExecutorClassLoader( + conf: SparkConf, + env: SparkEnv, + classUri: String, + parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { val uri = new URI(classUri) val directory = uri.getPath @@ -48,13 +53,12 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 - // Hadoop FileSystem object for our URI, if it isn't using HTTP - var fileSystem: FileSystem = { - if (Set("http", "https", "ftp").contains(uri.getScheme)) { - null - } else { - FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) - } + private val fetchFn: (String) => InputStream = uri.getScheme() match { + case "spark" => getClassFileInputStreamFromSparkRPC + case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer + case _ => + val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + getClassFileInputStreamFromFileSystem(fileSystem) } override def getResource(name: String): URL = { @@ -90,6 +94,11 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { + val channel = env.rpcEnv.openChannel(s"$classUri/$path") + Channels.newInputStream(channel) + } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) @@ -126,7 +135,8 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } - private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( + pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) if (fileSystem.exists(path)) { fileSystem.open(path) @@ -139,13 +149,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val pathInDirectory = name.replace('.', '/') + ".class" var inputStream: InputStream = null try { - inputStream = { - if (fileSystem != null) { - getClassFileInputStreamFromFileSystem(pathInDirectory) - } else { - getClassFileInputStreamFromHttpServer(pathInDirectory) - } - } + inputStream = fetchFn(pathInDirectory) val bytes = readAndTransformClass(name, inputStream) Some(defineClass(name, bytes, 0, bytes.length)) } catch { diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c1211f7596b9c..1360f09e7fa1f 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -18,24 +18,29 @@ package org.apache.spark.repl import java.io.File -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} +import java.nio.channels.{FileChannel, ReadableByteChannel} import java.nio.charset.StandardCharsets +import java.nio.file.{Paths, StandardOpenOption} import java.util -import com.google.common.io.Files - import scala.concurrent.duration._ import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps +import com.google.common.io.Files import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.Matchers.anyString import org.mockito.Mockito._ import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils class ExecutorClassLoaderSuite @@ -78,7 +83,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") @@ -86,7 +91,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -94,7 +99,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -102,7 +107,7 @@ class ExecutorClassLoaderSuite test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() } @@ -110,7 +115,7 @@ class ExecutorClassLoaderSuite test("resource from parent") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val resourceName: String = parentResourceNames.head val is = classLoader.getResourceAsStream(resourceName) assert(is != null, s"Resource $resourceName not found") @@ -120,7 +125,7 @@ class ExecutorClassLoaderSuite test("resources from parent") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val resourceName: String = parentResourceNames.head val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) assert(resources.hasMoreElements, s"Resource $resourceName not found") @@ -142,7 +147,7 @@ class ExecutorClassLoaderSuite SparkEnv.set(mockEnv) // Create an ExecutorClassLoader that's configured to load classes from the HTTP server val parentLoader = new URLClassLoader(Array.empty, null) - val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + val classLoader = new ExecutorClassLoader(conf, null, classServer.uri, parentLoader, false) classLoader.httpUrlConnectionTimeoutMillis = 500 // Check that this class loader can actually load classes that exist val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() @@ -177,4 +182,27 @@ class ExecutorClassLoaderSuite failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) } + test("fetch classes using Spark's RpcEnv") { + val env = mock[SparkEnv] + val rpcEnv = mock[RpcEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() { + override def answer(invocation: InvocationOnMock): ReadableByteChannel = { + val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) + val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/")) + FileChannel.open(path, StandardOpenOption.READ) + } + }) + + val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", + getClass().getClassLoader(), false) + + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + } + } + } From 6a6c1fc5c807ba4e8aba3e260537aa527ff5d46a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 10 Dec 2015 14:21:15 -0800 Subject: [PATCH 1200/2089] [SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark Adding ability to define an initial state RDD for use with updateStateByKey PySpark. Added unit test and changed stateful_network_wordcount example to use initial RDD. Author: Bryan Cutler Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713. --- .../streaming/stateful_network_wordcount.py | 5 ++++- python/pyspark/streaming/dstream.py | 13 ++++++++++-- python/pyspark/streaming/tests.py | 20 +++++++++++++++++++ .../streaming/api/python/PythonDStream.scala | 14 +++++++++++-- 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 16ef646b7c42e..f8bbc659c2ea7 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -44,13 +44,16 @@ ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") + # RDD with initial state (key, value) pairs + initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)]) + def updateFunc(new_values, last_sum): return sum(new_values) + (last_sum or 0) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) running_counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda word: (word, 1))\ - .updateStateByKey(updateFunc) + .updateStateByKey(updateFunc, initialRDD=initialStateRDD) running_counts.pprint() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index acec850f02c2d..f61137cb88c47 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -568,7 +568,7 @@ def invReduceFunc(t, a, b): self._ssc._jduration(slideDuration)) return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) - def updateStateByKey(self, updateFunc, numPartitions=None): + def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. @@ -579,6 +579,9 @@ def updateStateByKey(self, updateFunc, numPartitions=None): if numPartitions is None: numPartitions = self._sc.defaultParallelism + if initialRDD and not isinstance(initialRDD, RDD): + initialRDD = self._sc.parallelize(initialRDD) + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) @@ -590,7 +593,13 @@ def reduceFunc(t, a, b): jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) - dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + if initialRDD: + initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc, + initialRDD._jrdd) + else: + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a2bfd79e1abcd..4949cd68e3212 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,26 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_update_state_by_key_initial_rdd(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + initial = [('k', [0, 1])] + initial = self.sc.parallelize(initial, 1) + + input = [[('k', i)] for i in range(2, 5)] + + def func(dstream): + return dstream.updateStateByKey(updater, initialRDD=initial) + + expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + def test_failed_func(self): # Test failure in # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 994309ddd0a3e..056248ccc7bcd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction, + initialRDD: Option[RDD[Array[Byte]]]) extends PythonDStream(parent, reduceFunc) { + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None) + + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction, + initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd)) + super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -274,7 +284,7 @@ private[python] class PythonStateDStream( val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - func(lastState, rdd, validTime) + func(lastState.orElse(initialRDD), rdd, validTime) } else { lastState } From 23a9e62bad9669e9ff5dc4bd714f58d12f9be0b5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 10 Dec 2015 15:29:04 -0800 Subject: [PATCH 1201/2089] [SPARK-12251] Document and improve off-heap memory configurations This patch adds documentation for Spark configurations that affect off-heap memory and makes some naming and validation improvements for those configs. - Change `spark.memory.offHeapSize` to `spark.memory.offHeap.size`. This is fine because this configuration has not shipped in any Spark release yet (it's new in Spark 1.6). - Deprecated `spark.unsafe.offHeap` in favor of a new `spark.memory.offHeap.enabled` configuration. The motivation behind this change is to gather all memory-related configurations under the same prefix. - Add a check which prevents users from setting `spark.memory.offHeap.enabled=true` when `spark.memory.offHeap.size == 0`. After SPARK-11389 (#9344), which was committed in Spark 1.6, Spark enforces a hard limit on the amount of off-heap memory that it will allocate to tasks. As a result, enabling off-heap execution memory without setting `spark.memory.offHeap.size` will lead to immediate OOMs. The new configuration validation makes this scenario easier to diagnose, helping to avoid user confusion. - Document these configurations on the configuration page. Author: Josh Rosen Closes #10237 from JoshRosen/SPARK-12251. --- .../scala/org/apache/spark/SparkConf.scala | 4 +++- .../apache/spark/memory/MemoryManager.scala | 10 +++++++-- .../spark/memory/TaskMemoryManagerSuite.java | 21 +++++++++++++++---- .../sort/PackedRecordPointerSuite.java | 6 ++++-- .../sort/ShuffleInMemorySorterSuite.java | 4 ++-- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 ++-- .../sort/UnsafeExternalSorterSuite.java | 2 +- .../sort/UnsafeInMemorySorterSuite.java | 4 ++-- .../memory/StaticMemoryManagerSuite.scala | 2 +- .../memory/UnifiedMemoryManagerSuite.scala | 2 +- docs/configuration.md | 16 ++++++++++++++ .../sql/execution/joins/HashedRelation.scala | 6 +++++- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 +- .../UnsafeKVExternalSorterSuite.scala | 2 +- 15 files changed, 65 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 19633a3ce6a02..d3384fb297732 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -597,7 +597,9 @@ private[spark] object SparkConf extends Logging { "spark.streaming.fileStream.minRememberDuration" -> Seq( AlternateConfig("spark.streaming.minRememberDuration", "1.5")), "spark.yarn.max.executor.failures" -> Seq( - AlternateConfig("spark.yarn.max.worker.failures", "1.5")) + AlternateConfig("spark.yarn.max.worker.failures", "1.5")), + "spark.memory.offHeap.enabled" -> Seq( + AlternateConfig("spark.unsafe.offHeap", "1.6")) ) /** diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ae9e1ac0e246b..e707e27d96b50 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -50,7 +50,7 @@ private[spark] abstract class MemoryManager( storageMemoryPool.incrementPoolSize(storageMemory) onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) - offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeapSize", 0)) + offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeap.size", 0)) /** * Total available memory for storage, in bytes. This amount can vary over time, depending on @@ -182,7 +182,13 @@ private[spark] abstract class MemoryManager( * sun.misc.Unsafe. */ final val tungstenMemoryMode: MemoryMode = { - if (conf.getBoolean("spark.unsafe.offHeap", false)) MemoryMode.OFF_HEAP else MemoryMode.ON_HEAP + if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { + require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } } /** diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 711eed0193bc0..776a2997cf91f 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -29,7 +29,7 @@ public class TaskMemoryManagerSuite { public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), + new SparkConf().set("spark.memory.offHeap.enabled", "false"), Long.MAX_VALUE, Long.MAX_VALUE, 1), @@ -41,8 +41,10 @@ public void leakedPageMemoryIsDetected() { @Test public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); // In off-heap mode, an offset is an absolute address that may require more than 51 bits to // encode. This test exercises that corner-case: @@ -55,7 +57,7 @@ public void encodePageNumberAndOffsetOffHeap() { @Test public void encodePageNumberAndOffsetOnHeap() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); @@ -104,4 +106,15 @@ public void cooperativeSpilling() { assert(manager.cleanUpAllAllocatedMemory() == 0); } + @Test + public void offHeapConfigurationBackwardsCompatibility() { + // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which + // was deprecated in Spark 1.6 and replaced by `spark.memory.offHeap.enabled` (see SPARK-12251). + final SparkConf conf = new SparkConf() + .set("spark.unsafe.offHeap", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); + assert(manager.tungstenMemoryMode == MemoryMode.OFF_HEAP); + } + } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 9a43f1f3a9235..fe5abc5c23049 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -35,7 +35,7 @@ public class PackedRecordPointerSuite { @Test public void heap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); @@ -54,7 +54,9 @@ public void heap() throws IOException { @Test public void offHeap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true"); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "10000"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index faa5a863ee630..0328e63e45439 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -34,7 +34,7 @@ public class ShuffleInMemorySorterSuite { final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); @@ -64,7 +64,7 @@ public void testBasicSorting() throws Exception { "Lychee", "Mango" }; - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index bc85918c59aab..5fe64bde3604a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -108,7 +108,7 @@ public void setUp() throws IOException { spillFilesCreated.clear(); conf = new SparkConf() .set("spark.buffer.pageSize", "1m") - .set("spark.unsafe.offHeap", "false"); + .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 8724a34988421..702ba5469b8b4 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -85,8 +85,8 @@ public void setup() { memoryManager = new TestMemoryManager( new SparkConf() - .set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()) - .set("spark.memory.offHeapSize", "256mb")); + .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeap.size", "256mb")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a1c9f6fab8e65..e0ee281e98b71 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -58,7 +58,7 @@ public class UnsafeExternalSorterSuite { final LinkedList spillFilesCreated = new LinkedList(); final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index a203a09648ac0..93efd033eb940 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -46,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, @@ -71,7 +71,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { "Mango" }; final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 272253bc94e91..68cf26fc3ed5d 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -47,7 +47,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { conf.clone .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString), + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxStorageMemory = 0, numCores = 1) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 71221deeb4c28..e21a028b7faec 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -42,7 +42,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) .set("spark.memory.storageFraction", storageFraction.toString) UnifiedMemoryManager(conf, numCores = 1) } diff --git a/docs/configuration.md b/docs/configuration.md index 873a2d0b303cd..55cf4b2dac5f5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -738,6 +738,22 @@ Apart from these, the following properties are also available, and may be useful this description. + + + + + + + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index aebfea5832402..8c7099ab5a34d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -334,7 +334,11 @@ private[joins] final class UnsafeHashedRelation( // so that tests compile: val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0) + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7ceaee38d131b..5a8406789ab81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -61,7 +61,7 @@ class UnsafeFixedWidthAggregationMapSuite } test(name) { - val conf = new SparkConf().set("spark.unsafe.offHeap", "false") + val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 7b80963ec8708..29027a664b4b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { pageSize: Long, spill: Boolean): Unit = { val memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")) + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, From 5030923ea8bb94ac8fa8e432de9fc7089aa93986 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 10 Dec 2015 15:30:08 -0800 Subject: [PATCH 1202/2089] [SPARK-12155][SPARK-12253] Fix executor OOM in unified memory management **Problem.** In unified memory management, acquiring execution memory may lead to eviction of storage memory. However, the space freed from evicting cached blocks is distributed among all active tasks. Thus, an incorrect upper bound on the execution memory per task can cause the acquisition to fail, leading to OOM's and premature spills. **Example.** Suppose total memory is 1000B, cached blocks occupy 900B, `spark.memory.storageFraction` is 0.4, and there are two active tasks. In this case, the cap on task execution memory is 100B / 2 = 50B. If task A tries to acquire 200B, it will evict 100B of storage but can only acquire 50B because of the incorrect cap. For another example, see this [regression test](https://github.com/andrewor14/spark/blob/fix-oom/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala#L233) that I stole from JoshRosen. **Solution.** Fix the cap on task execution memory. It should take into account the space that could have been freed by storage in addition to the current amount of memory available to execution. In the example above, the correct cap should have been 600B / 2 = 300B. This patch also guards against the race condition (SPARK-12253): (1) Existing tasks collectively occupy all execution memory (2) New task comes in and blocks while existing tasks spill (3) After tasks finish spilling, another task jumps in and puts in a large block, stealing the freed memory (4) New task still cannot acquire memory and goes back to sleep Author: Andrew Or Closes #10240 from andrewor14/fix-oom. --- .../spark/memory/ExecutionMemoryPool.scala | 57 +++++++++++++------ .../spark/memory/UnifiedMemoryManager.scala | 57 ++++++++++++++----- .../org/apache/spark/scheduler/Task.scala | 6 ++ .../memory/UnifiedMemoryManagerSuite.scala | 25 ++++++++ 4 files changed, 114 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 9023e1ac012b7..dbb0ad8d5c673 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -70,11 +70,28 @@ private[memory] class ExecutionMemoryPool( * active tasks) before it is forced to spill. This can happen if the number of tasks increase * but an older task had a lot of memory already. * + * @param numBytes number of bytes to acquire + * @param taskAttemptId the task attempt acquiring memory + * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in + * one parameter (Long) that represents the desired amount of memory by + * which this pool should be expanded. + * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool + * at this given moment. This is not a field because the max pool + * size is variable in certain cases. For instance, in unified + * memory management, the execution pool can be expanded by evicting + * cached blocks, thereby shrinking the storage pool. + * * @return the number of bytes granted to the task. */ - def acquireMemory(numBytes: Long, taskAttemptId: Long): Long = lock.synchronized { + private[memory] def acquireMemory( + numBytes: Long, + taskAttemptId: Long, + maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit, + computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized { assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + // TODO: clean up this clunky method signature + // Add this task to the taskMemory map just so we can keep an accurate count of the number // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` if (!memoryForTask.contains(taskAttemptId)) { @@ -91,25 +108,31 @@ private[memory] class ExecutionMemoryPool( val numActiveTasks = memoryForTask.keys.size val curMem = memoryForTask(taskAttemptId) - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = - math.min(numBytes, math.max(0, (poolSize / numActiveTasks) - curMem)) + // In every iteration of this loop, we should first try to reclaim any borrowed execution + // space from storage. This is necessary because of the potential race condition where new + // storage blocks may steal the free execution memory that this task was waiting for. + maybeGrowPool(numBytes - memoryFree) + + // Maximum size the pool would have after potentially growing the pool. + // This is used to compute the upper bound of how much memory each task can occupy. This + // must take into account potential free memory as well as the amount this pool currently + // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management, + // we did not take into account space that could have been freed by evicting cached blocks. + val maxPoolSize = computeMaxPoolSize() + val maxMemoryPerTask = maxPoolSize / numActiveTasks + val minMemoryPerTask = poolSize / (2 * numActiveTasks) + + // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks + val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem)) // Only give it as much memory as is free, which might be none if it reached 1 / numTasks val toGrant = math.min(maxToGrant, memoryFree) - if (curMem < poolSize / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if (memoryFree >= math.min(maxToGrant, poolSize / (2 * numActiveTasks) - curMem)) { - memoryForTask(taskAttemptId) += toGrant - return toGrant - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") - lock.wait() - } + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) { + logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() } else { memoryForTask(taskAttemptId) += toGrant return toGrant diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 0b9f6a9dc0525..829f054dba0e9 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -81,22 +81,51 @@ private[spark] class UnifiedMemoryManager private[memory] ( assert(numBytes >= 0) memoryMode match { case MemoryMode.ON_HEAP => - if (numBytes > onHeapExecutionMemoryPool.memoryFree) { - val extraMemoryNeeded = numBytes - onHeapExecutionMemoryPool.memoryFree - // There is not enough free memory in the execution pool, so try to reclaim memory from - // storage. We can reclaim any free memory from the storage pool. If the storage pool - // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim - // the memory that storage has borrowed from execution. - val memoryReclaimableFromStorage = - math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) - if (memoryReclaimableFromStorage > 0) { - // Only reclaim as much space as is necessary and available: - val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( - math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) - onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + + /** + * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool. + * + * When acquiring memory for a task, the execution pool may need to make multiple + * attempts. Each attempt must be able to evict storage in case another task jumps in + * and caches a large block between the attempts. This is called once per attempt. + */ + def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = { + if (extraMemoryNeeded > 0) { + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = + math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + } } } - onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + + /** + * The size the execution pool would have after evicting storage memory. + * + * The execution memory pool divides this quantity among the active tasks evenly to cap + * the execution memory allocation for each task. It is important to keep this greater + * than the execution pool size, which doesn't take into account potential memory that + * could be freed by evicting storage. Otherwise we may hit SPARK-12155. + * + * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness + * in execution memory allocation across tasks, Otherwise, a task may occupy more than + * its fair share of execution memory, mistakenly thinking that other tasks can acquire + * the portion of storage memory that cannot be evicted. + */ + def computeMaxExecutionPoolSize(): Long = { + maxMemory - math.min(storageMemoryUsed, storageRegionSize) + } + + onHeapExecutionMemoryPool.acquireMemory( + numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize) + case MemoryMode.OFF_HEAP => // For now, we only support on-heap caching of data, so we do not need to interact with // the storage pool when allocating off-heap memory. This will change in the future, though. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d4bc3a5c900f7..9f27eed626be3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -92,6 +92,12 @@ private[spark] abstract class Task[T]( Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } } } finally { TaskContext.unset() diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e21a028b7faec..6cc48597d38f9 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -230,4 +230,29 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(exception.getMessage.contains("larger heap size")) } + test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + val ms = makeMemoryStore(mm) + assert(mm.maxMemory === 1000) + // Have two tasks each acquire some execution memory so that the memory pool registers that + // there are two active tasks: + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assert(mm.acquireExecutionMemory(100L, 1, MemoryMode.ON_HEAP) === 100L) + // Fill up all of the remaining memory with storage. + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 800) + assert(mm.executionMemoryUsed === 200) + // A task should still be able to allocate 100 bytes execution memory by evicting blocks + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.executionMemoryUsed === 300) + assert(mm.storageMemoryUsed === 700) + assert(evictedBlocks.nonEmpty) + } + } From 24d3357d66e14388faf8709b368edca70ea96432 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 10 Dec 2015 15:31:46 -0800 Subject: [PATCH 1203/2089] [STREAMING][DOC][MINOR] Update the description of direct Kafka stream doc With the merge of [SPARK-8337](https://issues.apache.org/jira/browse/SPARK-8337), now the Python API has the same functionalities compared to Scala/Java, so here changing the description to make it more precise. zsxwing tdas , please review, thanks a lot. Author: jerryshao Closes #10246 from jerryshao/direct-kafka-doc-update. --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index b00351b2fbcc0..5be73c42560f5 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -74,7 +74,7 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. This approach has the following advantages over the receiver-based approach (i.e. Approach 1). From b1b4ee7f3541d92c8bc2b0b4fdadf46cfdb09504 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 10 Dec 2015 17:22:18 -0800 Subject: [PATCH 1204/2089] [SPARK-12258][SQL] passing null into ScalaUDF Check nullability and passing them into ScalaUDF. Closes #10249 Author: Davies Liu Closes #10259 from davies/udf_null. --- .../apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 7 +++++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 03b89221ef2d3..5deb2f81d1738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,8 +1029,11 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zip(evals).map { - case (converter, eval) => s"$converter.apply(${eval.value})" + val funcArguments = converterTerms.zipWithIndex.map { + case (converter, i) => + val eval = evals(i) + val dt = children(i).dataType + s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" }.mkString(",") val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 605a6549dd686..8887dc68a50e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1138,14 +1138,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-11725: correctly handle null inputs for ScalaUDF") { - val df = Seq( + val df = sparkContext.parallelize(Seq( new java.lang.Integer(22) -> "John", - null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") + // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) null else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else i * 2 } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) From 518ab5101073ee35d62e33c8f7281a1e6342101e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 11 Dec 2015 02:35:53 -0500 Subject: [PATCH 1205/2089] [SPARK-10991][ML] logistic regression training summary handle empty prediction col LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related too https://issues.apache.org/jira/browse/SPARK-9718 ) Author: Holden Karau Author: Holden Karau Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col. --- .../classification/LogisticRegression.scala | 20 +++++++++++++++++-- .../LogisticRegressionSuite.scala | 11 ++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 19cc323d5073f..486043e8d9741 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.unpersist() val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() val logRegSummary = new BinaryLogisticRegressionTrainingSummary( - model.transform(dataset), - $(probabilityCol), + summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol), objectiveHistory) @@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] ( new NullPointerException()) } + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. + */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } + } + private[classification] def setSummary( summary: LogisticRegressionTrainingSummary): this.type = { this.trainingSummary = Some(summary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a9a6ff8a783d5..1087afb0cdf79 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -99,6 +99,17 @@ class LogisticRegressionSuite assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + test("setThreshold, getThreshold") { val lr = new LogisticRegression // default From c119a34d1e9e599e302acfda92e5de681086a19f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 11 Dec 2015 11:15:53 -0800 Subject: [PATCH 1206/2089] [SPARK-12258] [SQL] passing null into ScalaUDF (follow-up) This is a follow-up PR for #10259 Author: Davies Liu Closes #10266 from davies/null_udf2. --- .../sql/catalyst/expressions/ScalaUDF.scala | 31 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 8 +++-- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5deb2f81d1738..85faa19bbf5ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,24 +1029,27 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zipWithIndex.map { - case (converter, i) => - val eval = evals(i) - val dt = children(i).dataType - s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" - }.mkString(",") - val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + - s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + - s".apply($funcTerm.apply($funcArguments));" + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip - evalCode + s""" - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - Boolean ${ev.isNull}; + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + s""" + $evalCode + ${converters.mkString("\n")} $callFunc - ${ev.value} = $resultTerm; - ${ev.isNull} = $resultTerm == null; + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8887dc68a50e7..5353fefaf4b84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) -10 else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else null } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) + + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) From 0fb9825556dbbcc98d7eafe9ddea8676301e09bb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Dec 2015 11:47:35 -0800 Subject: [PATCH 1207/2089] [SPARK-12146][SPARKR] SparkR jsonFile should support multiple input files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ```jsonFile``` should support multiple input files, such as: ```R jsonFile(sqlContext, c(“path1â€, “path2â€)) # character vector as arguments jsonFile(sqlContext, “path1,path2â€) ``` * Meanwhile, ```jsonFile``` has been deprecated by Spark SQL and will be removed at Spark 2.0. So we mark ```jsonFile``` deprecated and use ```read.json``` at SparkR side. * Replace all ```jsonFile``` with ```read.json``` at test_sparkSQL.R, but still keep jsonFile test case. * If this PR is accepted, we should also make almost the same change for ```parquetFile```. cc felixcheung sun-rui shivaram Author: Yanbo Liang Closes #10145 from yanboliang/spark-12146. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 102 +++++++++--------- R/pkg/R/SQLContext.R | 29 ++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 120 ++++++++++++---------- examples/src/main/r/dataframe.R | 2 +- 5 files changed, 138 insertions(+), 116 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba64bc59edee5..cab39d68c3f52 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -267,6 +267,7 @@ export("as.DataFrame", "createExternalTable", "dropTempTable", "jsonFile", + "read.json", "loadDF", "parquetFile", "read.df", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c4a2585e291..975b058c0aaf1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -24,14 +24,14 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame #' @description DataFrames can be created using functions like \link{createDataFrame}, -#' \link{jsonFile}, \link{table} etc. +#' \link{read.json}, \link{table} etc. #' @family DataFrame functions #' @rdname DataFrame #' @docType class #' #' @slot env An R environment that stores bookkeeping states of the DataFrame #' @slot sdf A Java object reference to the backing Scala DataFrame -#' @seealso \link{createDataFrame}, \link{jsonFile}, \link{table} +#' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} #' @export #' @examples @@ -77,7 +77,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -102,7 +102,7 @@ setMethod("printSchema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -126,7 +126,7 @@ setMethod("schema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -157,7 +157,7 @@ setMethod("explain", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -182,7 +182,7 @@ setMethod("isLocal", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -207,7 +207,7 @@ setMethod("showDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -234,7 +234,7 @@ setMethod("show", "DataFrame", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -261,7 +261,7 @@ setMethod("dtypes", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' columns(df) #' colnames(df) #'} @@ -376,7 +376,7 @@ setMethod("coltypes", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' coltypes(df) <- c("character", "integer") #' coltypes(df) <- c(NA, "numeric") #'} @@ -423,7 +423,7 @@ setMethod("coltypes<-", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "json_df") #' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} @@ -476,7 +476,7 @@ setMethod("insertInto", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -504,7 +504,7 @@ setMethod("cache", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -532,7 +532,7 @@ setMethod("persist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -560,7 +560,7 @@ setMethod("unpersist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -585,7 +585,7 @@ setMethod("repartition", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newRDD <- toJSON(df) #'} setMethod("toJSON", @@ -613,7 +613,7 @@ setMethod("toJSON", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -637,7 +637,7 @@ setMethod("saveAsParquetFile", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -672,7 +672,7 @@ setMethod("unique", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} @@ -711,7 +711,7 @@ setMethod("sample_frac", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' count(df) #' } setMethod("count", @@ -741,7 +741,7 @@ setMethod("nrow", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' ncol(df) #' } setMethod("ncol", @@ -762,7 +762,7 @@ setMethod("ncol", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dim(df) #' } setMethod("dim", @@ -786,7 +786,7 @@ setMethod("dim", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } @@ -858,7 +858,7 @@ setMethod("collect", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -879,7 +879,7 @@ setMethod("limit", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -908,7 +908,7 @@ setMethod("take", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' head(df) #' } setMethod("head", @@ -931,7 +931,7 @@ setMethod("head", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' first(df) #' } setMethod("first", @@ -952,7 +952,7 @@ setMethod("first", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' rdd <- toRDD(df) #'} setMethod("toRDD", @@ -1298,7 +1298,7 @@ setMethod("select", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -1327,7 +1327,7 @@ setMethod("selectExpr", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -1352,7 +1352,7 @@ setMethod("withColumn", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) @@ -1402,7 +1402,7 @@ setMethod("transform", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1427,7 +1427,7 @@ setMethod("withColumnRenamed", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1471,7 +1471,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) @@ -1547,7 +1547,7 @@ setMethod("orderBy", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1591,8 +1591,8 @@ setMethod("where", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1648,8 +1648,8 @@ setMethod("join", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' merge(df1, df2) # Performs a Cartesian #' merge(df1, df2, by = "col1") # Performs an inner join based on expression #' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) @@ -1781,8 +1781,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1824,8 +1824,8 @@ setMethod("rbind", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1851,8 +1851,8 @@ setMethod("intersect", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1892,7 +1892,7 @@ setMethod("except", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") #' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } @@ -1957,7 +1957,7 @@ setMethod("saveDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1998,7 +1998,7 @@ setMethod("saveAsTable", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -2054,7 +2054,7 @@ setMethod("summary", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' dropna(df) #' } setMethod("dropna", @@ -2108,7 +2108,7 @@ setMethod("na.omit", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' fillna(df, 1) #' fillna(df, list("age" = 20, "name" = "unknown")) #' } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f678c70a7a77c..9243d70e66f75 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -208,24 +208,33 @@ setMethod("toDF", signature(x = "RDD"), #' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame +#' @rdname read.json +#' @name read.json #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) #' df <- jsonFile(sqlContext, path) #' } - -jsonFile <- function(sqlContext, path) { +read.json <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- suppressWarnings(normalizePath(path)) - # Convert a string vector of paths to a string containing comma separated paths - path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlContext, "jsonFile", path) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } +#' @rdname read.json +#' @name jsonFile +#' @export +jsonFile <- function(sqlContext, path) { + .Deprecated("read.json") + read.json(sqlContext, path) +} + #' JSON RDD #' @@ -299,7 +308,7 @@ parquetFile <- function(sqlContext, ...) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } @@ -323,7 +332,7 @@ sql <- function(sqlContext, sqlQuery) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- table(sqlContext, "table") #' } @@ -396,7 +405,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' cacheTable(sqlContext, "table") #' } @@ -418,7 +427,7 @@ cacheTable <- function(sqlContext, tableName) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' uncacheTable(sqlContext, "table") #' } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2051784427be3..ed9b2c9d4d16c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -330,7 +330,7 @@ writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { # ArrayType - df <- jsonFile(sqlContext, complexTypeJsonPath) + df <- read.json(sqlContext, complexTypeJsonPath) ldf <- collect(df) expect_equal(nrow(ldf), 3) @@ -357,7 +357,7 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) # StructType - df <- jsonFile(sqlContext, mapTypeJsonPath) + df <- read.json(sqlContext, mapTypeJsonPath) expect_equal(dtypes(df), list(c("info", "struct"), c("name", "string"))) ldf <- collect(df) @@ -371,10 +371,22 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) }) -test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) +test_that("read.json()/jsonFile() on a local file returns a DataFrame", { + df <- read.json(sqlContext, jsonPath) expect_is(df, "DataFrame") expect_equal(count(df), 3) + # read.json()/jsonFile() works with multiple input paths + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".json") + write.df(df, jsonPath2, "json", mode="overwrite") + jsonDF1 <- read.json(sqlContext, c(jsonPath, jsonPath2)) + expect_is(jsonDF1, "DataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath, jsonPath2))) + expect_is(jsonDF2, "DataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) }) test_that("jsonRDD() on a RDD with json string", { @@ -391,7 +403,7 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") cacheTable(sqlContext, "table1") uncacheTable(sqlContext, "table1") @@ -400,7 +412,7 @@ test_that("test cache, uncache and clearCache", { }) test_that("test tableNames and tables", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) @@ -409,7 +421,7 @@ test_that("test tableNames and tables", { }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "DataFrame") @@ -445,7 +457,7 @@ test_that("insertInto() on a registered table", { }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") expect_is(tabledf, "DataFrame") @@ -458,14 +470,14 @@ test_that("table() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) @@ -487,7 +499,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) @@ -505,7 +517,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) @@ -516,7 +528,7 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row @@ -528,7 +540,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_equal(names(rdf)[1], "age") @@ -550,14 +562,14 @@ test_that("collect() returns a data.frame", { }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfLimited <- limit(df, 2) expect_is(dfLimited, "DataFrame") expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(nrow(collect(df)), nrow(take(df, 10))) expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) @@ -584,7 +596,7 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -601,7 +613,7 @@ test_that("multiple pipeline transformations result in an RDD with the correct v }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -620,7 +632,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testSchema <- schema(df) expect_equal(length(testSchema$fields()), 2) expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") @@ -641,7 +653,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("names() colnames() set the column names", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) names(df) <- c("col1", "col2") expect_equal(colnames(df)[2], "col2") @@ -661,7 +673,7 @@ test_that("names() colnames() set the column names", { }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testHead <- head(df) expect_equal(nrow(testHead), 3) expect_equal(ncol(testHead), 2) @@ -694,7 +706,7 @@ test_that("distinct() and unique on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) uniques <- distinct(df) expect_is(uniques, "DataFrame") expect_equal(count(uniques), 3) @@ -705,7 +717,7 @@ test_that("distinct() and unique on DataFrames", { }) test_that("sample on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "DataFrame") @@ -721,7 +733,7 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + df <- select(read.json(sqlContext, jsonPath), "name", "age") expect_is(df$name, "Column") expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") @@ -747,7 +759,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df1 <- select(df, "name") expect_equal(columns(df1), c("name")) expect_equal(count(df1), 3) @@ -770,8 +782,8 @@ test_that("select with column", { }) test_that("subsetting", { - # jsonFile returns columns in random order - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + # read.json returns columns in random order + df <- select(read.json(sqlContext, jsonPath), "name", "age") filtered <- df[df$age > 20,] expect_equal(count(filtered), 1) expect_equal(columns(filtered), c("name", "age")) @@ -808,7 +820,7 @@ test_that("subsetting", { }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -819,12 +831,12 @@ test_that("selectExpr() on a DataFrame", { }) test_that("expr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) }) test_that("column calculation", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -915,7 +927,7 @@ test_that("column functions", { expect_equal(class(rank())[[1]], "Column") expect_equal(rank(1:3), as.numeric(c(1:3))) - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) expect_equal(collect(df2)[[2, 1]], TRUE) expect_equal(collect(df2)[[2, 2]], FALSE) @@ -983,7 +995,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -1004,7 +1016,7 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -1100,7 +1112,7 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { }) test_that("group by, agg functions", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -1145,7 +1157,7 @@ test_that("group by, agg functions", { "{\"name\":\"ID2\", \"value\": \"-3\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - gd2 <- groupBy(jsonFile(sqlContext, jsonPath2), "name") + gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") df6 <- agg(gd2, value = "sum") df6_local <- collect(df6) expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) @@ -1162,7 +1174,7 @@ test_that("group by, agg functions", { "{\"name\":\"Justin\", \"age\":1}") jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines3, jsonPath3) - df8 <- jsonFile(sqlContext, jsonPath3) + df8 <- read.json(sqlContext, jsonPath3) gd3 <- groupBy(df8, "name") gd3_local <- collect(sum(gd3)) expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) @@ -1181,7 +1193,7 @@ test_that("group by, agg functions", { }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) sorted <- arrange(df, df$age) expect_equal(collect(sorted)[1,2], "Michael") @@ -1207,7 +1219,7 @@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) filtered <- filter(df, "age > 20") expect_equal(count(filtered), 1) expect_equal(collect(filtered)$name, "Andy") @@ -1230,7 +1242,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() and merge() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -1238,7 +1250,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlContext, jsonPath2) + df2 <- read.json(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -1313,14 +1325,14 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines3, jsonPath3) - df3 <- jsonFile(sqlContext, jsonPath3) + df3 <- read.json(sqlContext, jsonPath3) expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") expect_equal(getSerializedMode(testRDD), "string") @@ -1328,7 +1340,7 @@ test_that("toJSON() returns an RDD of the correct values", { }) test_that("showDF()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) s <- capture.output(showDF(df)) expected <- paste("+----+-------+\n", "| age| name|\n", @@ -1341,12 +1353,12 @@ test_that("showDF()", { }) test_that("isLocal()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", @@ -1383,7 +1395,7 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1395,7 +1407,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate(), transform(), rename() and names()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1425,7 +1437,7 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("write.df() on DataFrame and works with read.parquet", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- read.parquet(sqlContext, parquetPath) expect_is(parquetDF, "DataFrame") @@ -1433,7 +1445,7 @@ test_that("write.df() on DataFrame and works with read.parquet", { }) test_that("read.parquet()/parquetFile() works with multiple input paths", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") @@ -1452,7 +1464,7 @@ test_that("read.parquet()/parquetFile() works with multiple input paths", { }) test_that("describe() and summarize() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") @@ -1470,7 +1482,7 @@ test_that("describe() and summarize() on a DataFrame", { }) test_that("dropna() and na.omit() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) + df <- read.json(sqlContext, jsonPathNa) rows <- collect(df) # drop with columns @@ -1556,7 +1568,7 @@ test_that("dropna() and na.omit() on a DataFrame", { }) test_that("fillna() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) + df <- read.json(sqlContext, jsonPathNa) rows <- collect(df) # fill with value @@ -1665,7 +1677,7 @@ test_that("Method as.data.frame as a synonym for collect()", { }) test_that("attach() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_error(age) attach(df) expect_is(age, "DataFrame") @@ -1713,7 +1725,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") - df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 53b817144f6ac..62f60e57eebe6 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -35,7 +35,7 @@ printSchema(df) # Create a DataFrame from a JSON file path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") -peopleDF <- jsonFile(sqlContext, path) +peopleDF <- read.json(sqlContext, path) printSchema(peopleDF) # Register this DataFrame as a table. From aa305dcaf5b4148aba9e669e081d0b9235f50857 Mon Sep 17 00:00:00 2001 From: anabranch Date: Fri, 11 Dec 2015 12:55:56 -0800 Subject: [PATCH 1208/2089] [SPARK-11964][DOCS][ML] Add in Pipeline Import/Export Documentation Adding in Pipeline Import and Export Documentation. Author: anabranch Author: Bill Chambers Closes #10179 from anabranch/master. --- docs/ml-guide.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5c96c2b7d5cc9..44a316a07dfef 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -192,6 +192,10 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +## Saving and Loading Pipelines + +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. + # Code examples This section gives code examples illustrating the functionality discussed above. @@ -455,6 +459,15 @@ val pipeline = new Pipeline() // Fit the pipeline to training documents. val model = pipeline.fit(training) +// now we can optionally save the fitted pipeline to disk +model.save("/tmp/spark-logistic-regression-model") + +// we can also save this unfit pipeline to disk +pipeline.save("/tmp/unfit-lr-model") + +// and load it back in during production +val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model") + // Prepare test documents, which are unlabeled (id, text) tuples. val test = sqlContext.createDataFrame(Seq( (4L, "spark i j k"), From 713e6959d21d24382ef99bbd7e9da751a7ed388c Mon Sep 17 00:00:00 2001 From: proflin Date: Fri, 11 Dec 2015 13:50:36 -0800 Subject: [PATCH 1209/2089] [SPARK-12273][STREAMING] Make Spark Streaming web UI list Receivers in order Currently the Streaming web UI does NOT list Receivers in order; however, it seems more convenient for the users if Receivers are listed in order. ![spark-12273](https://cloud.githubusercontent.com/assets/15843379/11736602/0bb7f7a8-a00b-11e5-8e86-96ba9297fb12.png) Author: proflin Closes #10264 from proflin/Spark-12273. --- .../scala/org/apache/spark/streaming/ui/StreamingPage.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4588b2163cd44..88a4483e8068f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -392,8 +392,9 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { - val content = listener.receivedEventRateWithBatchTime.map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { + case (streamId, eventRates) => + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off From 1b8220387e6903564f765fabb54be0420c3e99d7 Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Fri, 11 Dec 2015 14:21:33 -0800 Subject: [PATCH 1210/2089] [SPARK-11497][MLLIB][PYTHON] PySpark RowMatrix Constructor Has Type Erasure Issue As noted in PR #9441, implementing `tallSkinnyQR` uncovered a bug with our PySpark `RowMatrix` constructor. As discussed on the dev list [here](http://apache-spark-developers-list.1001551.n3.nabble.com/K-Means-And-Class-Tags-td10038.html), there appears to be an issue with type erasure with RDDs coming from Java, and by extension from PySpark. Although we are attempting to construct a `RowMatrix` from an `RDD[Vector]` in [PythonMLlibAPI](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala#L1115), the `Vector` type is erased, resulting in an `RDD[Object]`. Thus, when calling Scala's `tallSkinnyQR` from PySpark, we get a Java `ClassCastException` in which an `Object` cannot be cast to a Spark `Vector`. As noted in the aforementioned dev list thread, this issue was also encountered with `DecisionTrees`, and the fix involved an explicit `retag` of the RDD with a `Vector` type. `IndexedRowMatrix` and `CoordinateMatrix` do not appear to have this issue likely due to their related helper functions in `PythonMLlibAPI` creating the RDDs explicitly from DataFrames with pattern matching, thus preserving the types. This PR currently contains that retagging fix applied to the `createRowMatrix` helper function in `PythonMLlibAPI`. This PR blocks #9441, so once this is merged, the other can be rebased. cc holdenk Author: Mike Dusenberry Closes #9458 from dusenberrymw/SPARK-11497_PySpark_RowMatrix_Constructor_Has_Type_Erasure_Issue. --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2aa6aec0b4347..8d546e3d6099b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1143,7 +1143,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Wrapper around RowMatrix constructor. */ def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { - new RowMatrix(rows.rdd, numRows, numCols) + new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols) } /** From aea676ca2d07c72b1a752e9308c961118e5bfc3c Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 11 Dec 2015 15:43:00 -0800 Subject: [PATCH 1211/2089] [SPARK-12217][ML] Document invalid handling for StringIndexer Added a paragraph regarding StringIndexer#setHandleInvalid to the ml-features documentation. I wonder if I should also add a snippet to the code example, input welcome. Author: BenFradet Closes #10257 from BenFradet/SPARK-12217. --- docs/ml-features.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 6494fed0a01e5..8b00cc652dc7a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -459,6 +459,42 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. +Additionaly, there are two strategies regarding how `StringIndexer` will handle +unseen labels when you have fit a `StringIndexer` on one dataset and then use it +to transform another: + +- throw an exception (which is the default) +- skip the row containing the unseen label entirely + +**Examples** + +Let's go back to our previous example but this time reuse our previously defined +`StringIndexer` on the following dataset: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | d +~~~~ + +If you've not set how `StringIndexer` handles unseen labels or set it to +"error", an exception will be thrown. +However, if you had called `setHandleInvalid("skip")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 +~~~~ + +Notice that the row containing "d" does not appear. +
    From a0ff6d16ef4bcc1b6ff7282e82a9b345d8449454 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Dec 2015 18:02:24 -0800 Subject: [PATCH 1212/2089] [SPARK-11978][ML] Move dataset_example.py to examples/ml and rename to dataframe_example.py Since ```Dataset``` has a new meaning in Spark 1.6, we should rename it to avoid confusion. #9873 finished the work of Scala example, here we focus on the Python one. Move dataset_example.py to ```examples/ml``` and rename to ```dataframe_example.py```. BTW, fix minor missing issues of #9873. cc mengxr Author: Yanbo Liang Closes #9957 from yanboliang/SPARK-11978. --- .../dataframe_example.py} | 56 +++++++++++-------- .../spark/examples/ml/DataFrameExample.scala | 8 +-- 2 files changed, 38 insertions(+), 26 deletions(-) rename examples/src/main/python/{mllib/dataset_example.py => ml/dataframe_example.py} (53%) diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/ml/dataframe_example.py similarity index 53% rename from examples/src/main/python/mllib/dataset_example.py rename to examples/src/main/python/ml/dataframe_example.py index e23ecc0c5d302..d2644ca335654 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -16,8 +16,8 @@ # """ -An example of how to use DataFrame as a dataset for ML. Run with:: - bin/spark-submit examples/src/main/python/mllib/dataset_example.py +An example of how to use DataFrame for ML. Run with:: + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -28,36 +28,48 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.mllib.util import MLUtils from pyspark.mllib.stat import Statistics - -def summarize(dataset): - print("schema: %s" % dataset.schema().json()) - labels = dataset.map(lambda r: r.label) - print("label average: %f" % labels.mean()) - features = dataset.map(lambda r: r.features) - summary = Statistics.colStats(features) - print("features average: %r" % summary.mean()) - if __name__ == "__main__": if len(sys.argv) > 2: - print("Usage: dataset_example.py ", file=sys.stderr) + print("Usage: dataframe_example.py ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="DatasetExample") + sc = SparkContext(appName="DataFrameExample") sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" - points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() - summarize(dataset0) + + # Load input data + print("Loading LIBSVM file with UDT from " + input + ".") + df = sqlContext.read.format("libsvm").load(input).cache() + print("Schema from LIBSVM:") + df.printSchema() + print("Loaded training data as a DataFrame with " + + str(df.count()) + " records.") + + # Show statistical summary of labels. + labelSummary = df.describe("label") + labelSummary.show() + + # Convert features column to an RDD of vectors. + features = df.select("features").map(lambda r: r.features) + summary = Statistics.colStats(features) + print("Selected features column with average values:\n" + + str(summary.mean())) + + # Save the records in a parquet file. tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print("Save dataset as a Parquet file to %s." % tempdir) - dataset0.saveAsParquetFile(tempdir) - print("Load it back and summarize it again.") - dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() - summarize(dataset1) + print("Saving to " + tempdir + " as Parquet file.") + df.write.parquet(tempdir) + + # Load the records back. + print("Loading Parquet file with UDT from " + tempdir) + newDF = sqlContext.read.parquet(tempdir) + print("Schema from Parquet:") + newDF.printSchema() shutil.rmtree(tempdir) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 424f00158c2f2..0a477abae5679 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -44,10 +44,10 @@ object DataFrameExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using DataFrame as a Dataset for ML.") + val parser = new OptionParser[Params]("DataFrameExample") { + head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataset") + .text(s"input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -88,7 +88,7 @@ object DataFrameExample { // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() - val outputDir = new File(tmpDir, "dataset").toString + val outputDir = new File(tmpDir, "dataframe").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) From 1e799d617a28cd0eaa8f22d103ea8248c4655ae5 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Fri, 11 Dec 2015 19:07:48 -0800 Subject: [PATCH 1213/2089] [SPARK-12298][SQL] Fix infinite loop in DataFrame.sortWithinPartitions Modifies the String overload to call the Column overload and ensures this is called in a test. Author: Ankur Dave Closes #10271 from ankurdave/SPARK-12298. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index da180a2ba09d9..497bd48266770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -609,7 +609,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { - sortWithinPartitions(sortCol, sortCols : _*) + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5353fefaf4b84..c0bbf73ab1188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1090,8 +1090,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } // Distribute into one partition and order by. This partition should contain all the values. - val df6 = data.repartition(1, $"a").sortWithinPartitions($"b".asc) - // Walk each partition and verify that it is sorted descending and not globally sorted. + val df6 = data.repartition(1, $"a").sortWithinPartitions("b") + // Walk each partition and verify that it is sorted ascending and not globally sorted. df6.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true From 1e3526c2d3de723225024fedd45753b556e18fc6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 11 Dec 2015 20:55:16 -0800 Subject: [PATCH 1214/2089] [SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases The existing sample functions miss the parameter `seed`, however, the corresponding function interface in `generics` has such a parameter. Thus, although the function caller can call the function with the 'seed', we are not using the value. This could cause SparkR unit tests failed. For example, I hit it in another PR: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull Author: gatorsmile Closes #10160 from gatorsmile/sampleR. --- R/pkg/R/DataFrame.R | 17 +++++++++++------ R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 975b058c0aaf1..764597d1e32b4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -662,6 +662,7 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value #' #' @family DataFrame functions #' @rdname sample @@ -677,13 +678,17 @@ setMethod("unique", #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", - # TODO : Figure out how to send integer as java.lang.Long to JVM so - # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { + function(x, withReplacement, fraction, seed) { if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + if (!missing(seed)) { + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + } else { + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + } dataFrame(sdf) }) @@ -692,8 +697,8 @@ setMethod("sample", setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { - sample(x, withReplacement, fraction) + function(x, withReplacement, fraction, seed) { + sample(x, withReplacement, fraction, seed) }) #' nrow diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ed9b2c9d4d16c..071fd310fd58a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -724,6 +724,10 @@ test_that("sample on a DataFrame", { sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) < 3) + count1 <- count(sample(df, FALSE, 0.1, 0)) + count2 <- count(sample(df, FALSE, 0.1, 0)) + expect_equal(count1, count2) + # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) From 03138b67d3ef7f5278ea9f8b9c75f0e357ef79d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Sat, 12 Dec 2015 08:51:52 +0000 Subject: [PATCH 1215/2089] [SPARK-11193] Use Java ConcurrentHashMap instead of SynchronizedMap trait in order to avoid ClassCastException due to KryoSerializer in KinesisReceiver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #10203 from jbonofre/SPARK-11193. --- .../streaming/kinesis/KinesisReceiver.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 05080835fc4ad..80edda59e1719 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.kinesis import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -124,8 +125,7 @@ private[kinesis] class KinesisReceiver[T]( private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] /** Sequence number ranges of data added to each generated block */ - private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] - with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + private val blockIdToSeqNumRanges = new ConcurrentHashMap[StreamBlockId, SequenceNumberRanges] /** * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. @@ -135,8 +135,8 @@ private[kinesis] class KinesisReceiver[T]( /** * Latest sequence number ranges that have been stored successfully. * This is used for checkpointing through KCL */ - private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String] - with mutable.SynchronizedMap[String, String] + private val shardIdToLatestStoredSeqNum = new ConcurrentHashMap[String, String] + /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). @@ -222,7 +222,7 @@ private[kinesis] class KinesisReceiver[T]( /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { - shardIdToLatestStoredSeqNum.get(shardId) + Option(shardIdToLatestStoredSeqNum.get(shardId)) } /** @@ -257,7 +257,7 @@ private[kinesis] class KinesisReceiver[T]( * for next block. Internally, this is synchronized with `rememberAddedRange()`. */ private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { - blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray) + blockIdToSeqNumRanges.put(blockId, SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)) seqNumRangesInCurrentBlock.clear() logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") } @@ -265,7 +265,7 @@ private[kinesis] class KinesisReceiver[T]( /** Store the block along with its associated ranges */ private def storeBlockWithRanges( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = { - val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) + val rangesToReportOption = Option(blockIdToSeqNumRanges.remove(blockId)) if (rangesToReportOption.isEmpty) { stop("Error while storing block into Spark, could not find sequence number ranges " + s"for block $blockId") @@ -294,7 +294,7 @@ private[kinesis] class KinesisReceiver[T]( // Note that we are doing this sequentially because the array of sequence number ranges // is assumed to be rangesToReport.ranges.foreach { range => - shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber + shardIdToLatestStoredSeqNum.put(range.shardId, range.toSeqNumber) } } From 98b212d36b34ab490c391ea2adf5b141e4fb9289 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sat, 12 Dec 2015 17:47:01 -0800 Subject: [PATCH 1216/2089] [SPARK-12199][DOC] Follow-up: Refine example code in ml-features.md https://issues.apache.org/jira/browse/SPARK-12199 Follow-up PR of SPARK-11551. Fix some errors in ml-features.md mengxr Author: Xusen Yin Closes #10193 from yinxusen/SPARK-12199. --- docs/ml-features.md | 22 +++++++++---------- .../examples/ml/JavaBinarizerExample.java | 2 +- .../python/ml/polynomial_expansion_example.py | 6 ++--- ....scala => ElementwiseProductExample.scala} | 0 4 files changed, 15 insertions(+), 15 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/ml/{ElementWiseProductExample.scala => ElementwiseProductExample.scala} (100%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 8b00cc652dc7a..158f3f201899c 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2Vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. @@ -411,7 +411,7 @@ for more details on the API. Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
    @@ -669,7 +669,7 @@ for more details on the API. The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^2$ norm and unit $L^\infty$ norm.
    -
    +
    Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. @@ -677,7 +677,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    -
    +
    Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. @@ -685,7 +685,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    -
    +
    Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. @@ -709,7 +709,7 @@ Note that if the standard deviation of a feature is zero, it will return default The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
    -
    +
    Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. @@ -717,7 +717,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    -
    +
    Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. @@ -725,7 +725,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    -
    +
    Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. @@ -788,7 +788,7 @@ More details can be found in the API docs for [Bucketizer](api/scala/index.html# The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
    -
    +
    Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. @@ -796,7 +796,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    -
    +
    Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. @@ -804,7 +804,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    -
    +
    Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java index 9698cac504371..1eda1f694fc27 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -59,7 +59,7 @@ public static void main(String[] args) { DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); + Double binarized_value = r.getDouble(0); System.out.println(binarized_value); } // $example off$ diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py index 3d4fafd1a42e9..89f5cbe8f2f41 100644 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -30,9 +30,9 @@ # $example on$ df = sqlContext\ - .createDataFrame([(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], + .createDataFrame([(Vectors.dense([-2.0, 2.3]),), + (Vectors.dense([0.0, 0.0]),), + (Vectors.dense([0.6, -1.1]),)], ["features"]) px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") polyDF = px.transform(df) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala From 8af2f8c61ae4a59d129fb3530d0f6e9317f4bff8 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 12 Dec 2015 21:58:55 -0800 Subject: [PATCH 1217/2089] [SPARK-12267][CORE] Store the remote RpcEnv address to send the correct disconnetion message Author: Shixiong Zhu Closes #10261 from zsxwing/SPARK-12267. --- .../spark/deploy/master/ApplicationInfo.scala | 1 + .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 21 ++++++++++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 42 +++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ac553b71115df..7e2cf956c7253 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -66,6 +66,7 @@ private[spark] class ApplicationInfo( nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] executorLimit = Integer.MAX_VALUE + appUIUrlAtHistoryServer = None } private def newExecutorId(useID: Option[Int] = None): Int = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1afc1ff59f2f9..f41efb097b4be 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -690,7 +690,7 @@ private[deploy] object Worker extends Logging { val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.masters, args.workDir) + args.memory, args.masters, args.workDir, conf = conf) rpcEnv.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68c5f44145b0d..f82fd4eb5756d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -553,6 +553,9 @@ private[netty] class NettyRpcHandler( // A variable to track whether we should dispatch the RemoteProcessConnected message. private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() + // A variable to track the remote RpcEnv addresses of all clients + private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() + override def receive( client: TransportClient, message: ByteBuffer, @@ -580,6 +583,12 @@ private[netty] class NettyRpcHandler( // Create a new message with the socket address of the client as the sender. RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) } else { + // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for + // the listening address + val remoteEnvAddress = requestMessage.senderAddress + if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } requestMessage } } @@ -591,6 +600,12 @@ private[netty] class NettyRpcHandler( if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessConnectionError for the remote RpcEnv listening address + val remoteEnvAddress = remoteAddresses.get(clientAddr) + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. // See java.net.Socket.getRemoteSocketAddress @@ -606,6 +621,12 @@ private[netty] class NettyRpcHandler( val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + val remoteEnvAddress = remoteAddresses.remove(clientAddr) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessDisconnected for the remote RpcEnv listening address + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index a61d0479aacdb..6d153eb04e04f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -545,6 +545,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("network events between non-client-mode RpcEnvs") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events-non-client") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) + assert(events.contains(("onDisconnected", remoteAddress))) + } + } + test("sendWithReply: unserializable error") { env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { override val rpcEnv = env From 2aecda284e22ec608992b6221e2f5ffbd51fcd24 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sun, 13 Dec 2015 22:06:39 -0800 Subject: [PATCH 1218/2089] [SPARK-12281][CORE] Fix a race condition when reporting ExecutorState in the shutdown hook 1. Make sure workers and masters exit so that no worker or master will still be running when triggering the shutdown hook. 2. Set ExecutorState to FAILED if it's still RUNNING when executing the shutdown hook. This should fix the potential exceptions when exiting a local cluster ``` java.lang.AssertionError: assertion failed: executor 4 state transfer from RUNNING to RUNNING is illegal at scala.Predef$.assert(Predef.scala:179) at org.apache.spark.deploy.master.Master$$anonfun$receive$1.applyOrElse(Master.scala:260) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:116) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:204) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:100) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) java.lang.IllegalStateException: Shutdown hooks cannot be modified during shutdown. at org.apache.spark.util.SparkShutdownHookManager.add(ShutdownHookManager.scala:246) at org.apache.spark.util.ShutdownHookManager$.addShutdownHook(ShutdownHookManager.scala:191) at org.apache.spark.util.ShutdownHookManager$.addShutdownHook(ShutdownHookManager.scala:180) at org.apache.spark.deploy.worker.ExecutorRunner.start(ExecutorRunner.scala:73) at org.apache.spark.deploy.worker.Worker$$anonfun$receive$1.applyOrElse(Worker.scala:474) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:116) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:204) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:100) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` Author: Shixiong Zhu Closes #10269 from zsxwing/executor-state. --- .../scala/org/apache/spark/deploy/LocalSparkCluster.scala | 2 ++ .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 ++--- .../org/apache/spark/deploy/worker/ExecutorRunner.scala | 5 +++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 83ccaadfe7447..5bb62d37d6374 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -75,6 +75,8 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected workerRpcEnvs.foreach(_.shutdown()) masterRpcEnvs.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.awaitTermination()) + masterRpcEnvs.foreach(_.awaitTermination()) masterRpcEnvs.clear() workerRpcEnvs.clear() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 04b20e0d6ab9c..1355e1ad1b523 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -257,9 +257,8 @@ private[deploy] class Master( exec.state = state if (state == ExecutorState.RUNNING) { - if (oldState != ExecutorState.LAUNCHING) { - logWarning(s"Executor $execId state transfer from $oldState to RUNNING is unexpected") - } + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") appInfo.resetRetryCount() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 25a17473e4b53..9a42487bb37aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -71,6 +71,11 @@ private[deploy] class ExecutorRunner( workerThread.start() // Shutdown hook that kills actors on shutdown. shutdownHook = ShutdownHookManager.addShutdownHook { () => + // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will + // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`. + if (state == ExecutorState.RUNNING) { + state = ExecutorState.FAILED + } killProcess(Some("Worker shutting down")) } } From 834e71489bf560302f9d743dff669df1134e9b74 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 13 Dec 2015 22:57:01 -0800 Subject: [PATCH 1219/2089] [SPARK-12213][SQL] use multiple partitions for single distinct query Currently, we could generate different plans for query with single distinct (depends on spark.sql.specializeSingleDistinctAggPlanning), one works better on low cardinality columns, the other works better for high cardinality column (default one). This PR change to generate a single plan (three aggregations and two exchanges), which work better in both cases, then we could safely remove the flag `spark.sql.specializeSingleDistinctAggPlanning` (introduced in 1.6). For a query like `SELECT COUNT(DISTINCT a) FROM table` will be ``` AGG-4 (count distinct) Shuffle to a single reducer Partial-AGG-3 (count distinct, no grouping) Partial-AGG-2 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on a) ``` This PR also includes large refactor for aggregation (reduce 500+ lines of code) cc yhuai nongli marmbrus Author: Davies Liu Closes #10228 from davies/single_distinct. --- .../spark/sql/catalyst/CatalystConf.scala | 7 - .../DistinctAggregationRewriter.scala | 11 +- .../scala/org/apache/spark/sql/SQLConf.scala | 15 - .../aggregate/AggregationIterator.scala | 417 ++++++----------- .../aggregate/SortBasedAggregate.scala | 29 +- .../SortBasedAggregationIterator.scala | 47 +- .../aggregate/TungstenAggregate.scala | 25 +- .../TungstenAggregationIterator.scala | 439 +++--------------- .../spark/sql/execution/aggregate/utils.scala | 280 +++++------ .../execution/AggregationQuerySuite.scala | 142 +++--- 10 files changed, 422 insertions(+), 990 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 7c2b8a9407884..2c7c58e66b855 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean - - protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } - - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { - throw new UnsupportedOperationException - } } /** A CatalystConf that can be used for local testing. */ case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 9c78f6d4cc71b..4e7d1341028ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { - // When the flag is set to specialize single distinct agg planning, - // we will rely on our Aggregation strategy to handle queries with a single - // distinct column. - distinctAggGroups.size > 1 - } else { - distinctAggGroups.size >= 1 - } - if (shouldRewrite) { + // Aggregation strategy can handle the query with single distinct + if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 58adf64e49869..3d819262859f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -449,18 +449,6 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) - val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = - booleanConf("spark.sql.specializeSingleDistinctAggPlanning", - defaultValue = Some(false), - isPublic = false, - doc = "When true, if a query only has a single distinct column and it has " + - "grouping expressions, we will use our planner rule to handle this distinct " + - "column (other cases are handled by DistinctAggregationRewriter). " + - "When false, we will always use DistinctAggregationRewriter to plan " + - "aggregation queries with DISTINCT keyword. This is an internal flag that is " + - "used to benchmark the performance impact of using DistinctAggregationRewriter to " + - "plan aggregation queries with a single distinct column.") - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = - getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 008478a6a0e17..0c74df0aa5fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import scala.collection.mutable.ArrayBuffer - /** - * The base class of [[SortBasedAggregationIterator]]. + * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]]. * It mainly contains two parts: * 1. It initializes aggregate functions. * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of @@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer * is used to generate result. */ abstract class AggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends Iterator[InternalRow] with Logging { + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// // Initializing functions. /////////////////////////////////////////////////////////////////////////// - // An Seq of all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - protected val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - require( - allAggregateExpressions.map(_.mode).distinct.length <= 2, - s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.") - /** - * The distinct modes of AggregateExpressions. Right now, we can handle the following mode: - * - Partial-only: all AggregateExpressions have the mode of Partial; - * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge); - * - Final-only: all AggregateExpressions have the mode of Final; - * - Final-Complete: some AggregateExpressions have the mode of Final and - * others have the mode of Complete; - * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions - * with mode Complete in completeAggregateExpressions; and - * - Grouping-only: there is no AggregateExpression. - */ - protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. + */ + { + val modes = aggregateExpressions.map(_.mode).distinct.toSet + require(modes.size <= 2, + s"$aggregateExpressions are not supported because they have more than 2 distinct modes.") + require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)), + s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.") + } // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. - protected val allAggregateFunctions: Array[AggregateFunction] = { + protected def initializeAggregateFunctions( + expressions: Seq[AggregateExpression], + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction](expressions.length) var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match { + while (i < expressions.length) { + val func = expressions(i).aggregateFunction + val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, valueAttributes) + BindReferences.bindReference(func, inputAttributes) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. @@ -117,15 +111,18 @@ abstract class AggregationIterator( functions } + protected val aggregateFunctions: Array[AggregateFunction] = + initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset) + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are imperative aggregate functions. // ImperativeAggregateFunctionPositions will be [1, 2]. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { case agg: DeclarativeAggregate => case _ => positions += i } @@ -134,17 +131,9 @@ abstract class AggregationIterator( positions.toArray } - // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - - // All imperative aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - // The projection used to initialize buffer values for all expression-based aggregates. - private[this] val expressionAggInitialProjection = { - val initExpressions = allAggregateFunctions.flatMap { + protected[this] val expressionAggInitialProjection = { + val initExpressions = aggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.initialValues // For the positions corresponding to imperative aggregate functions, we'll use special // no-op expressions which are ignored during projection code-generation. @@ -154,248 +143,112 @@ abstract class AggregationIterator( } // All imperative AggregateFunctions. - private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = allImperativeAggregateFunctionPositions - .map(allAggregateFunctions) + .map(aggregateFunctions) .map(_.asInstanceOf[ImperativeAggregate]) - /////////////////////////////////////////////////////////////////////////// - // Methods and fields used by sub-classes. - /////////////////////////////////////////////////////////////////////////// - // Initializing functions used to process a row. - protected val processRow: (MutableRow, InternalRow) => Unit = { - val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { - // If initialInputBufferOffset, the input value does not contain - // grouping keys. - // This part is pretty hacky. - allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq - } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - } - // val inputAggregationBufferSchema = - // groupingKeyAttributes ++ - // allAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferSchema ++ inputAggregationBufferSchema)() - - (currentBuffer: MutableRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - aggregationBufferSchema ++ - groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalExpressionAggMergeProjection = - newMutableProjection(mergeExpressions, mergeInputSchema)() - - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalExpressionAggMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 + protected def generateProcessRow( + expressions: Seq[AggregateExpression], + functions: Seq[AggregateFunction], + inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + val joinedRow = new JoinedRow + if (expressions.nonEmpty) { + val mergeExpressions = functions.zipWithIndex.flatMap { + case (ae: DeclarativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = - completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val updateFunctions = functions.zipWithIndex.collect { + case (ae: ImperativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => + (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + case PartialMerge | Final => + (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } + } + // This projection is used to merge buffer values for all expression-based aggregates. + val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) + val updateProjection = + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all expression-based aggregate functions. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < updateFunctions.length) { + updateFunctions(i)(currentBuffer, row) + i += 1 } - + } + } else { // Grouping only. - case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + (currentBuffer: MutableRow, row: InternalRow) => {} } } - // Initializing the function used to generate the output row. - protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { - val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) - val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) - } else { - safeOutputRow - } - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not - // support generic getter), we create a mutable projection to output the - // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes) - val resultProjection = - newMutableProjection( - groupingKeyAttributes ++ bufferSchema, - groupingKeyAttributes ++ bufferSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) - // rowToBeEvaluated(currentGroupingKey, currentBuffer) - } + protected val processRow: (MutableRow, InternalRow) => Unit = + generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) - // Final-only, Complete-only and Final-Complete: every output row contains values representing - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val bufferSchemata = - allAggregateFunctions.flatMap(_.aggBufferAttributes) - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - // TODO: Use unsafe row. - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - newMutableProjection( - resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() - resultProjection.target(mutableOutput) + protected val groupingProjection: UnsafeProjection = + UnsafeProjection.create(groupingExpressions, inputAttributes) + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + // Initializing the function used to generate the output row. + protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val joinedRow = new JoinedRow + val modes = aggregateExpressions.map(_.mode).distinct + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + if (modes.contains(Final) || modes.contains(Complete)) { + val evalExpressions = aggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction => NoOp + } + val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + expressionAggEvalProjection.target(aggregateResult) + + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 } - + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + val resultProjection = UnsafeProjection.create( + groupingAttributes ++ bufferAttributes, + groupingAttributes ++ bufferAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + } + } else { // Grouping-only: we only output values of grouping expressions. - case (None, None) => - val resultProjection = - newMutableProjection(resultExpressions, groupingKeyAttributes)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(currentGroupingKey) - } - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } } } + protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + generateResultProjection() + /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) @@ -405,10 +258,4 @@ abstract class AggregationIterator( i += 1 } } - - /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ - protected def newBuffer: MutableRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ee982453c3287..c5470a6989de7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) @@ -42,10 +40,8 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = false - + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -76,31 +72,24 @@ case class SortBasedAggregate( if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. - Iterator[InternalRow]() + Iterator[UnsafeRow]() } else { - val groupingKeyProjection = - UnsafeProjection.create(groupingExpressions, child.output) - val outputIter = new SortBasedAggregationIterator( - groupingKeyProjection, - groupingExpressions.map(_.toAttribute), + groupingExpressions, child.output, iter, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows, numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. numOutputRows += 1 - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter } @@ -109,7 +98,7 @@ case class SortBasedAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index fe5c3195f867b..ac920aa8bc7f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric /** * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been - * sorted by values of [[groupingKeyAttributes]]. + * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( - groupingKeyProjection: InternalRow => InternalRow, - groupingKeyAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean, numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends AggregationIterator( - groupingKeyAttributes, + groupingExpressions, valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + newMutableProjection) { + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + private def newBuffer: MutableRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) @@ -76,10 +73,10 @@ class SortBasedAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The partition key of the current partition. - private[this] var currentGroupingKey: InternalRow = _ + private[this] var currentGroupingKey: UnsafeRow = _ // The partition key of next partition. - private[this] var nextGroupingKey: InternalRow = _ + private[this] var nextGroupingKey: UnsafeRow = _ // The first row of next partition. private[this] var firstRowInNextGroup: InternalRow = _ @@ -94,7 +91,7 @@ class SortBasedAggregationIterator( if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) val inputRow = inputIterator.next() - nextGroupingKey = groupingKeyProjection(inputRow).copy() + nextGroupingKey = groupingProjection(inputRow).copy() firstRowInNextGroup = inputRow.copy() numInputRows += 1 sortedInputHasNewGroup = true @@ -120,7 +117,7 @@ class SortBasedAggregationIterator( while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. val currentRow = inputIterator.next() - val groupingKey = groupingKeyProjection(currentRow) + val groupingKey = groupingProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row. @@ -146,7 +143,7 @@ class SortBasedAggregationIterator( override final def hasNext: Boolean = sortedInputHasNewGroup - override final def next(): InternalRow = { + override final def next(): UnsafeRow = { if (hasNext) { // Process the current group. processCurrentSortedGroup() @@ -162,8 +159,8 @@ class SortBasedAggregationIterator( } } - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 920de615e1d86..b8849c827048a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { private[this] val aggregateBufferAttributes = { - (nonCompleteAggregateExpressions ++ completeAggregateExpressions) - .flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } - require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), @@ -53,9 +50,7 @@ case class TungstenAggregate( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,10 +89,8 @@ case class TungstenAggregate( val aggregationIterator = new TungstenAggregationIterator( groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, @@ -119,7 +112,7 @@ case class TungstenAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => @@ -135,9 +128,7 @@ case class TungstenAggregate( } object TungstenAggregate { - def supportsAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 04391443920ac..582fdbe547061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. @@ -63,15 +61,11 @@ import org.apache.spark.sql.types.StructType * * @param groupingExpressions * expressions for grouping keys - * @param nonCompleteAggregateExpressions + * @param aggregateExpressions * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], * [[PartialMerge]], or [[Final]]. - * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * @param aggregateAttributes the attributes of the aggregateExpressions' * outputs when they are stored in the final aggregation buffer. - * @param completeAggregateExpressions - * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. - * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs - * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -83,10 +77,8 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), @@ -97,378 +89,62 @@ class TungstenAggregationIterator( numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) - extends Iterator[UnsafeRow] with Logging { + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// - // A Seq containing all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression] = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - // Check to make sure we do not have more than three modes in our AggregateExpressions. - // If we have, users are hitting a bug and we throw an IllegalStateException. - if (allAggregateExpressions.map(_.mode).distinct.length > 2) { - throw new IllegalStateException( - s"$allAggregateExpressions should have no more than 2 kinds of modes.") - } - // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - // - // The modes of AggregateExpressions. Right now, we can handle the following mode: - // - Partial-only: - // All AggregateExpressions have the mode of Partial. - // For this case, aggregationMode is (Some(Partial), None). - // - PartialMerge-only: - // All AggregateExpressions have the mode of PartialMerge). - // For this case, aggregationMode is (Some(PartialMerge), None). - // - Final-only: - // All AggregateExpressions have the mode of Final. - // For this case, aggregationMode is (Some(Final), None). - // - Final-Complete: - // Some AggregateExpressions have the mode of Final and - // others have the mode of Complete. For this case, - // aggregationMode is (Some(Final), Some(Complete)). - // - Complete-only: - // nonCompleteAggregateExpressions is empty and we have AggregateExpressions - // with mode Complete in completeAggregateExpressions. For this case, - // aggregationMode is (None, Some(Complete)). - // - Grouping-only: - // There is no AggregateExpression. For this case, AggregationMode is (None,None). - // - private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption - } - - // Initialize all AggregateFunctions by binding references, if necessary, - // and setting inputBufferOffset and mutableBufferOffset. - private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) - var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length - // We need to use this mode instead of func.mode in order to handle aggregation mode switching - // when switching to sort-based aggregation: - val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 - val funcWithBoundReferences = mode match { - case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => - // We need to create BoundReferences if the function is not an - // expression-based aggregate function (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, originalInputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - val updatedFunc = func match { - case function: ImperativeAggregate => - function.withNewInputAggBufferOffset(inputBufferOffset) - case function => function - } - inputBufferOffset += func.aggBufferSchema.length - updatedFunc - } - val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { - case function: ImperativeAggregate => - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - function.withNewMutableAggBufferOffset(mutableBufferOffset) - case function => function - } - mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length - functions(i) = funcWithUpdatedAggBufferOffset - i += 1 - } - functions - } - - private[this] var allAggregateFunctions: Array[AggregateFunction] = - initializeAllAggregateFunctions(initialInputBufferOffset) - - // Positions of those imperative aggregate functions in allAggregateFunctions. - // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are imperative aggregate functions. Then - // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be - // updated when falling back to sort-based aggregation because the positions of the aggregate - // functions do not change in that case. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { - case agg: DeclarativeAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - /////////////////////////////////////////////////////////////////////////// // Part 2: Methods and fields used by setting aggregation buffer values, // processing input rows from inputIter, and generating output // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values for all expression-based aggregates. - // Note that this projection does not need to be updated when switching to sort-based aggregation - // because the schema of empty aggregation buffers does not change in that case. - private[this] val expressionAggInitialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.initialValues - // For the positions corresponding to imperative aggregate functions, we'll use special - // no-op expressions which are ignored during projection code-generation. - case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) - } - newMutableProjection(initExpressions, Nil)() - } - // Creates a new aggregation buffer and initializes buffer values. - // This function should be only called at most three times (when we create the hash map, - // when we switch to sort-based aggregation, and when we create the re-used buffer for - // sort-based aggregation). + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) .apply(new GenericMutableRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values - allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } - // Creates a function used to process a row based on the given inputAttributes. - private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { - - val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val joinedRow = new JoinedRow() - - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffer values in row to - // currentBuffer. - finalMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") - } - } - // Creates a function used to generate output rows. - private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { - - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - unsafeRowJoiner.join(currentGroupingKey, currentBuffer) - } - - // Final-only, Complete-only and Final-Complete: a output row is generated based on - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val joinedRow = new JoinedRow() - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() - // These are the attributes of the row produced by `expressionAggEvalProjection` - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) - - val allImperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } - - // Grouping-only: a output row is generated from values of grouping expressions. - case (None, None) => - val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(currentGroupingKey) - } - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() } } - // An UnsafeProjection used to extract grouping keys from the input rows. - private[this] val groupProjection = - UnsafeProjection.create(groupingExpressions, originalInputAttributes) - - // A function used to process a input row. Its first argument is the aggregation buffer - // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, InternalRow) => Unit = - generateProcessRow(originalInputAttributes) - - // A function used to generate output rows based on the grouping keys (first argument) - // and the corresponding aggregation buffer (second argument). - private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = - generateResultProjection() - // An aggregation buffer containing initial buffer values. It is used to // initialize other aggregation buffers. private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() @@ -482,7 +158,7 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), + StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity @@ -499,7 +175,7 @@ class TungstenAggregationIterator( if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. - val groupingKey = groupProjection.apply(null) + val groupingKey = groupingProjection.apply(null) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) while (inputIter.hasNext) { val newInput = inputIter.next() @@ -511,7 +187,7 @@ class TungstenAggregationIterator( while (inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) + val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null if (i < fallbackStartsAt) { buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -565,25 +241,18 @@ class TungstenAggregationIterator( private def switchToSortBasedAggregation(): Unit = { logInfo("falling back to sort based aggregation.") - // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. - val newAggregationMode = aggregationMode match { - case (Some(Partial), None) => (Some(PartialMerge), None) - case (None, Some(Complete)) => (Some(Final), None) - case (Some(Final), Some(Complete)) => (Some(Final), None) + // Basically the value of the KVIterator returned by externalSorter + // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. + val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _) => + agg.copy(mode = Final) case other => other } - aggregationMode = newAggregationMode - - allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) - - // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use inputAggBufferAttributes. - val newInputAttributes: Seq[Attribute] = - allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - - // Set up new processRow and generateOutput. - processRow = generateProcessRow(newInputAttributes) - generateOutput = generateResultProjection() + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) // Step 5: Get the sorted iterator from the externalSorter. sortedKVIterator = externalSorter.sortedIterator() @@ -632,6 +301,9 @@ class TungstenAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + // The function used to process rows in a group + private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + // Processes rows in the current group. It will stop when it find a new group. private def processCurrentSortedGroup(): Unit = { // First, we need to copy nextGroupingKey to currentGroupingKey. @@ -640,7 +312,7 @@ class TungstenAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -655,16 +327,15 @@ class TungstenAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey.equals(groupingKey)) { - processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer) hasNext = sortedKVIterator.next() } else { // We find a new group. findNextPartition = true // copyFrom will fail when - nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() - firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() - + nextGroupingKey.copyFrom(groupingKey) + firstRowInNextGroup.copyFrom(inputAggregationBuffer) } } // We have not seen a new group. It means that there is no new row in the input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 76b938cdb694e..83379ae90f703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -42,16 +42,45 @@ object Utils { SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = Nil, - nonCompleteAggregateAttributes = Nil, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = resultExpressions, child = child ) :: Nil } + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -59,9 +88,6 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. @@ -73,29 +99,14 @@ object Utils { groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + val partialAggregate = createAggregate( + requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialResultExpressions, - child = child) - } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -105,29 +116,14 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val finalAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = groupingExpressions.length, - resultExpressions = resultExpressions, - child = partialAggregate) - } else { - SortBasedAggregate( + val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, child = partialAggregate) - } finalAggregate :: Nil } @@ -140,99 +136,99 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - val namedDistinctColumnExpressions = distinctColumnExpressions.map { + val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute) + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { - val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. - val partialAggregateGroupingExpressions = - groupingExpressions ++ namedDistinctColumnExpressions - val partialAggregateResult = - groupingAttributes ++ - distinctColumnAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) } // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { - val partialMergeAggregateExpressions = - functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialMergeAggregateResult = - groupingAttributes ++ - distinctColumnAttributes ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + // 3. Create an Aggregate operator for partial aggregation (for distinct) + val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap + val rewrittenDistinctFunctions = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression(aggregateFunction, mode, true) => + aggregateFunction.transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] } - // 3. Create an Aggregate Operator for the final aggregation. + val partialDistinctAggregate: SparkPlan = { + val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val mergeAggregateAttributes = mergeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Partial, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) + }.unzip + + val partialAggregateResult = groupingAttributes ++ + mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = partialAggregateResult, + child = partialMergeAggregate) + } + + // 4. Create an Aggregate Operator for the final aggregation. val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -241,49 +237,27 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val distinctColumnAttributeLookup = - distinctColumnExpressions.zip(distinctColumnAttributes).toMap - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction - .transformDown(distinctColumnAttributeLookup) - .asInstanceOf[AggregateFunction] + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val rewrittenAggregateExpression = - AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true) - - val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) - (rewrittenAggregateExpression, aggregateFunctionAttribute) + val expr = AggregateExpression(func, Final, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) }.unzip - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = partialDistinctAggregate) } finalAndCompleteAggregate :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 064c0004b801e..5550198c02fbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ -import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -552,80 +551,73 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("single distinct column set") { - Seq(true, false).foreach { specializeSingleDistinctAgg => - val conf = - (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key, - specializeSingleDistinctAgg.toString) - withSQLConf(conf) { - // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. - checkAnswer( - sqlContext.sql( - """ - |SELECT - | min(distinct value1), - | sum(distinct value1), - | avg(value1), - | avg(value2), - | max(distinct value1) - |FROM agg2 - """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | mydoubleavg(distinct value1), - | avg(value1), - | avg(value2), - | key, - | mydoubleavg(value1 - 1), - | mydoubleavg(distinct value1) * 0.1, - | avg(value1 + value2) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: - Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: - Row(null, null, 3.0, 3, null, null, null) :: - Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | key, - | mydoubleavg(distinct value1), - | mydoublesum(value2), - | mydoublesum(distinct value1), - | mydoubleavg(distinct value1), - | mydoubleavg(value1) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: - Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: - Row(3, null, 3.0, null, null, null) :: - Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | count(value1), - | count(*), - | count(1), - | count(DISTINCT value1), - | key - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(3, 3, 3, 2, 1) :: - Row(3, 4, 4, 2, 2) :: - Row(0, 2, 2, 0, 3) :: - Row(3, 4, 4, 3, null) :: Nil) - } - } + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("single distinct multiple columns set") { From ed87f6d3b48a85391628c29c43d318c26e2c6de7 Mon Sep 17 00:00:00 2001 From: yucai Date: Sun, 13 Dec 2015 23:08:21 -0800 Subject: [PATCH 1220/2089] [SPARK-12275][SQL] No plan for BroadcastHint in some condition When SparkStrategies.BasicOperators's "case BroadcastHint(child) => apply(child)" is hit, it only recursively invokes BasicOperators.apply with this "child". It makes many strategies have no change to process this plan, which probably leads to "No plan" issue, so we use planLater to go through all strategies. https://issues.apache.org/jira/browse/SPARK-12275 Author: yucai Closes #10265 from yucai/broadcast_hint. --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 25e98c0bdd431..688555cf136e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -364,7 +364,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil - case BroadcastHint(child) => apply(child) + case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56ad71ea4f487..c70397f9853ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -120,5 +120,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // planner should not crash without a join broadcast(df1).queryExecution.executedPlan + + // SPARK-12275: no physical plan for BroadcastHint in some condition + withTempPath { path => + df1.write.parquet(path.getCanonicalPath) + val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + assert(df1.join(broadcast(pf1)).count() === 4) + } } } From e25f1fe42747be71c6b6e6357ca214f9544e3a46 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 14 Dec 2015 13:50:30 +0000 Subject: [PATCH 1221/2089] [MINOR][DOC] Fix broken word2vec link Follow-up of [SPARK-12199](https://issues.apache.org/jira/browse/SPARK-12199) and #10193 where a broken link has been left as is. Author: BenFradet Closes #10282 from BenFradet/SPARK-12199. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 158f3f201899c..677e4bfb916e8 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. From b51a4cdff3a7e640a8a66f7a9c17021f3056fd34 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Dec 2015 09:59:42 -0800 Subject: [PATCH 1222/2089] [SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark JIRA: https://issues.apache.org/jira/browse/SPARK-12016 We should not directly use Word2VecModel in pyspark. We need to wrap it in a Word2VecModelWrapper when loading it in pyspark. Author: Liang-Chi Hsieh Closes #10100 from viirya/fix-load-py-wordvecmodel. --- .../mllib/api/python/PythonMLLibAPI.scala | 33 ---------- .../api/python/Word2VecModelWrapper.scala | 62 +++++++++++++++++++ python/pyspark/mllib/feature.py | 6 +- 3 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 8d546e3d6099b..29160a10e16b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -680,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable { } } - private[python] class Word2VecModelWrapper(model: Word2VecModel) { - def transform(word: String): Vector = { - model.transform(word) - } - - /** - * Transforms an RDD of words to its vector representation - * @param rdd an RDD of words - * @return an RDD of vector representations of words - */ - def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { - rdd.rdd.map(model.transform) - } - - def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) - } - - def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) - val similarity = Vectors.dense(result.map(_._2)) - val words = result.map(_._1) - List(words, similarity).map(_.asInstanceOf[Object]).asJava - } - - def getVectors: JMap[String, JList[Float]] = { - model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava - } - - def save(sc: SparkContext, path: String): Unit = model.save(sc, path) - } - /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala new file mode 100644 index 0000000000000..0f55980481dcb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.feature.Word2VecModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Wrapper around Word2VecModel to provide helper methods in Python + */ +private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): JList[Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + List(words, similarity).map(_.asInstanceOf[Object]).asJava + } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7b077b058c3fd..7254679ebb533 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -504,7 +504,8 @@ def load(cls, sc, path): """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) - return Word2VecModel(jmodel) + model = sc._jvm.Word2VecModelWrapper(jmodel) + return Word2VecModel(model) @ignore_unicode_prefix @@ -546,6 +547,9 @@ class Word2Vec(object): >>> sameModel = Word2VecModel.load(sc, path) >>> model.transform("a") == sameModel.transform("a") True + >>> syms = sameModel.findSynonyms("a", 2) + >>> [s[0] for s in syms] + [u'b', u'c'] >>> from shutil import rmtree >>> try: ... rmtree(path) From fb3778de685881df66bf0222b520f94dca99e8c8 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 14 Dec 2015 16:13:55 -0800 Subject: [PATCH 1223/2089] [SPARK-12327] Disable commented code lintr temporarily cc yhuai felixcheung shaneknapp Author: Shivaram Venkataraman Closes #10300 from shivaram/comment-lintr-disable. --- R/pkg/.lintr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e6..39c872663ad44 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") From 9ea1a8efca7869618c192546ef4e6d94c17689b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Dec 2015 16:48:11 -0800 Subject: [PATCH 1224/2089] [SPARK-12274][SQL] WrapOption should not have type constraint for child I think it was a mistake, and we have not catched it so far until https://github.com/apache/spark/pull/10260 which begin to check if the `fromRowExpression` is resolved. Author: Wenchen Fan Closes #10263 from cloud-fan/encoder. --- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index b2facfda24446..96bc4fe67a985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -296,15 +296,12 @@ case class UnwrapOption( * (in the case of reference types) equality with null. * @param child The expression to evaluate and wrap. */ -case class WrapOption(child: Expression) - extends UnaryExpression with ExpectsInputTypes { +case class WrapOption(child: Expression) extends UnaryExpression { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") From d13ff82cba10a1ce9889c4b416f1e43c717e3f10 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 14 Dec 2015 18:33:45 -0800 Subject: [PATCH 1225/2089] [SPARK-12188][SQL][FOLLOW-UP] Code refactoring and comment correction in Dataset APIs marmbrus This PR is to address your comment. Thanks for your review! Author: gatorsmile Closes #10214 from gatorsmile/followup12188. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3bd18a14f9e8f..dc69822e92908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -79,7 +79,7 @@ class Dataset[T] private[sql]( /** * The encoder where the expressions used to construct an object from an input row have been - * bound to the ordinals of the given schema. + * bound to the ordinals of this [[Dataset]]'s output schema. */ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) From 606f99b942a25dc3d86138b00026462ffe58cded Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 14 Dec 2015 19:42:16 -0800 Subject: [PATCH 1226/2089] [SPARK-12288] [SQL] Support UnsafeRow in Coalesce/Except/Intersect. Support UnsafeRow for the Coalesce/Except/Intersect. Could you review if my code changes are ok? davies Thank you! Author: gatorsmile Closes #10285 from gatorsmile/unsafeSupportCIE. --- .../spark/sql/execution/basicOperators.scala | 12 ++++++- .../execution/RowFormatConvertersSuite.scala | 35 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a42aea0b96d43..b3e4688557ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -137,7 +137,7 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = @@ -250,7 +250,9 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { child.execute().coalesce(numPartitions, shuffle = false) } + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -263,6 +265,10 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -275,6 +281,10 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 13d68a103a225..2328899bb2f8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -58,6 +58,41 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { assert(!preparedPlan.outputsUnsafeRows) } + test("coalesce can process unsafe rows") { + val plan = Coalesce(1, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except can process unsafe rows") { + val plan = Except(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except requires all of its input rows' formats to agree") { + val plan = Except(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect can process unsafe rows") { + val plan = Intersect(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect requires all of its input rows' formats to agree") { + val plan = Intersect(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + test("execute() fails an assertion if inputs rows are of different formats") { val e = intercept[AssertionError] { Union(Seq(outputsSafe, outputsUnsafe)).execute() From c59df8c51609a0d6561ae1868e7970b516fb1811 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 15 Dec 2015 11:38:57 +0000 Subject: [PATCH 1227/2089] [SPARK-12332][TRIVIAL][TEST] Fix minor typo in ResetSystemProperties Fix a minor typo (unbalanced bracket) in ResetSystemProperties. Author: Holden Karau Closes #10303 from holdenk/SPARK-12332-trivial-typo-in-ResetSystemProperties-comment. --- .../scala/org/apache/spark/util/ResetSystemProperties.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index c58db5e606f7c..60fb7abb66d32 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -45,7 +45,7 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // we need SerializationUtils.clone instead of `new Properties(System.getProperties())` because // the later way of creating a copy does not copy the properties but it initializes a new // Properties object with the given properties as defaults. They are not recognized at all // by standard Scala wrapper over Java Properties then. From bc1ff9f4a41401599d3a87fb3c23a2078228a29b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 15 Dec 2015 09:41:40 -0800 Subject: [PATCH 1228/2089] [STREAMING][MINOR] Fix typo in function name of StateImpl cc\ tdas zsxwing , please review. Thanks a lot. Author: jerryshao Closes #10305 from jerryshao/fix-typo-state-impl. --- streaming/src/main/scala/org/apache/spark/streaming/State.scala | 2 +- .../scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 2 +- .../scala/org/apache/spark/streaming/MapWithStateSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index b47bdda2c2137..42424d67d8838 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -206,7 +206,7 @@ private[streaming] class StateImpl[S] extends State[S] { * Update the internal data and flags in `this` to the given state that is going to be timed out. * This method allows `this` object to be reused across many state records. */ - def wrapTiminoutState(newState: S): Unit = { + def wrapTimingOutState(newState: S): Unit = { this.state = newState defined = true timingOut = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index ed95171f73ee1..fdf61674a37f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -67,7 +67,7 @@ private[streaming] object MapWithStateRDDRecord { // data returned if (removeTimedoutData && timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - wrappedState.wrapTiminoutState(state) + wrappedState.wrapTimingOutState(state) val returned = mappingFunction(batchTime, key, None, wrappedState) mappedData ++= returned newStateMap.remove(key) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 4b08085e09b1f..6b21433f1781b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -125,7 +125,7 @@ class MapWithStateSuite extends SparkFunSuite state.remove() testState(None, shouldBeRemoved = true) - state.wrapTiminoutState(3) + state.wrapTimingOutState(3) testState(Some(3), shouldBeTimingOut = true) } From b24c12d7338b47b637698e7458ba90f34cba28c0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 15 Dec 2015 16:29:39 -0800 Subject: [PATCH 1229/2089] [MINOR][ML] Rename weights to coefficients for examples/DeveloperApiExample Rename ```weights``` to ```coefficients``` for examples/DeveloperApiExample. cc mengxr jkbradley Author: Yanbo Liang Closes #10280 from yanboliang/spark-coefficients. --- .../examples/ml/JavaDeveloperApiExample.java | 22 +++++++++---------- .../examples/ml/DeveloperApiExample.scala | 16 +++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 0b4c0d9ba9f8b..b9dd3ad957714 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) throws Exception { } if (sumPredictions != 0.0) { throw new Exception("MyJavaLogisticRegression predicted something other than 0," + - " even though all weights are 0!"); + " even though all coefficients are 0!"); } jsc.stop(); @@ -149,12 +149,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. int numFeatures = oldDataset.take(1).get(0).features().size(); - Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this); } @Override @@ -173,12 +173,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) { class MyJavaLogisticRegressionModel extends ClassificationModel { - private Vector weights_; - public Vector weights() { return weights_; } + private Vector coefficients_; + public Vector coefficients() { return coefficients_; } - public MyJavaLogisticRegressionModel(String uid, Vector weights) { + public MyJavaLogisticRegressionModel(String uid, Vector coefficients) { this.uid_ = uid; - this.weights_ = weights; + this.coefficients_ = coefficients; } private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); @@ -208,7 +208,7 @@ public String uid() { * modifier. */ public Vector predictRaw(Vector features) { - double margin = BLAS.dot(features, weights_); + double margin = BLAS.dot(features, coefficients_); // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). return Vectors.dense(-margin, margin); @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Number of features the model was trained on. */ - public int numFeatures() { return weights_.size(); } + public int numFeatures() { return coefficients_.size(); } /** * Create a copy of the model. @@ -235,7 +235,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra) .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3758edc56198a..c1f63c6a1dce3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -75,7 +75,7 @@ object DeveloperApiExample { prediction }.sum assert(sumPredictions == 0.0, - "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + "MyLogisticRegression predicted something other than 0, even though all coefficients are 0!") sc.stop() } @@ -124,12 +124,12 @@ private class MyLogisticRegression(override val uid: String) // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. val numFeatures = oldDataset.take(1)(0).features.size - val weights = Vectors.zeros(numFeatures) // Learning would happen here. + val coefficients = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(uid, weights).setParent(this) + new MyLogisticRegressionModel(uid, coefficients).setParent(this) } override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) @@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String) */ private class MyLogisticRegressionModel( override val uid: String, - val weights: Vector) + val coefficients: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -163,7 +163,7 @@ private class MyLogisticRegressionModel( * confidence for that label. */ override protected def predictRaw(features: Vector): Vector = { - val margin = BLAS.dot(features, weights) + val margin = BLAS.dot(features, coefficients) // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). Vectors.dense(-margin, margin) @@ -173,7 +173,7 @@ private class MyLogisticRegressionModel( override val numClasses: Int = 2 /** Number of features the model was trained on. */ - override val numFeatures: Int = weights.size + override val numFeatures: Int = coefficients.size /** * Create a copy of the model. @@ -182,7 +182,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) + copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent) } } // scalastyle:on println From 86ea64dd146757c8f997d05fb5bb44f6aa58515c Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 15 Dec 2015 16:55:58 -0800 Subject: [PATCH 1230/2089] [SPARK-12271][SQL] Improve error message when Dataset.as[ ] has incompatible schemas. Author: Nong Li Closes #10260 from nongli/spark-11271. --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/encoders/ExpressionEncoder.scala | 1 + .../spark/sql/catalyst/expressions/objects.scala | 12 +++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 +++++++++- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9013fd050b5f9..ecff8605706de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -184,7 +184,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath - WrapOption(constructorFor(optType, path, newTypePath)) + WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 3e8420ecb9ccf..363178b0e21a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -251,6 +251,7 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) + SimpleAnalyzer.checkAnalysis(analyzedPlan) val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 96bc4fe67a985..10ec75eca37f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -23,11 +23,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} -import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -295,13 +293,17 @@ case class UnwrapOption( * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. * @param child The expression to evaluate and wrap. + * @param optType The type of this option. */ -case class WrapOption(child: Expression) extends UnaryExpression { +case class WrapOption(child: Expression, optType: DataType) + extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true + override def inputTypes: Seq[AbstractDataType] = optType :: Nil + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 542e4d6c43b9f..8f8db318261db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -481,10 +481,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) } -} + test("verify mismatching field names fail with a good error") { + val ds = Seq(ClassData("a", 1)).toDS() + val e = intercept[AnalysisException] { + ds.as[ClassData2].collect() + } + assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) + } +} case class ClassData(a: String, b: Int) +case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) /** From 28112657ea5919451291c21b4b8e1eb3db0ec8d4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 17:02:14 -0800 Subject: [PATCH 1231/2089] [SPARK-12236][SQL] JDBC filter tests all pass if filters are not really pushed down https://issues.apache.org/jira/browse/SPARK-12236 Currently JDBC filters are not tested properly. All the tests pass even if the filters are not pushed down due to Spark-side filtering. In this PR, Firstly, I corrected the tests to properly check the pushed down filters by removing Spark-side filtering. Also, `!=` was being tested which is actually not pushed down. So I removed them. Lastly, I moved the `stripSparkFilter()` function to `SQLTestUtils` as this functions would be shared for all tests for pushed down filters. This function would be also shared with ORC datasource as the filters for that are also not being tested properly. Author: hyukjinkwon Closes #10221 from HyukjinKwon/SPARK-12236. --- .../datasources/parquet/ParquetFilterSuite.scala | 15 --------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 10 ++++------ .../org/apache/spark/sql/test/SQLTestUtils.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index daf41bc292cc9..6178e37d2a585 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -110,21 +110,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } - /** - * Strip Spark-side filtering in order to check if a datasource filters rows correctly. - */ - protected def stripSparkFilter(df: DataFrame): DataFrame = { - val schema = df.schema - val childRDD = df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - .map(row => Row.fromSeq(row.toSeq(schema))) - - sqlContext.createDataFrame(childRDD, schema) - } - test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 8c24aa3151bc1..a360947152996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -176,12 +176,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) } test("SELECT * WHERE (quoted strings)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9214569f18e93..e87da1527c4d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -179,6 +179,21 @@ private[sql] trait SQLTestUtils try f finally sqlContext.sql(s"USE default") } + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + /** * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier * way to construct [[DataFrame]] directly out of local data without relying on implicits. From 31b391019ff6eb5a483f4b3e62fd082de7ff8416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Tue, 15 Dec 2015 18:06:30 -0800 Subject: [PATCH 1232/2089] [SPARK-12105] [SQL] add convenient show functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #10130 from jbonofre/SPARK-12105. --- .../org/apache/spark/sql/DataFrame.scala | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd48266770..b69d4411425d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,17 +160,24 @@ class DataFrame private[sql]( } } + /** + * Compose the string representing rows for output + */ + def showString(): String = { + showString(20) + } + /** * Compose the string representing rows for output - * @param _numRows Number of rows to show + * @param numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { - val numRows = _numRows.max(0) + def showString(numRows: Int, truncate: Boolean = true): String = { + val _numRows = numRows.max(0) val sb = new StringBuilder - val takeResult = take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val takeResult = take(_numRows + 1) + val hasMoreData = takeResult.length > _numRows + val data = takeResult.take(_numRows) val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets @@ -224,10 +231,10 @@ class DataFrame private[sql]( sb.append(sep) - // For Data that has more than "numRows" records + // For Data that has more than "_numRows" records if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") + val rowsString = if (_numRows == 1) "row" else "rows" + sb.append(s"only showing top $_numRows $rowsString\n") } sb.toString() From 840bd2e008da5b22bfa73c587ea2c57666fffc60 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Dec 2015 18:11:53 -0800 Subject: [PATCH 1233/2089] [HOTFIX] Compile error from commit 31b3910 --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b69d4411425d4..33b03be1138be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -234,7 +234,7 @@ class DataFrame private[sql]( // For Data that has more than "_numRows" records if (hasMoreData) { val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top $_numRows $rowsString\n") + sb.append(s"only showing top ${_numRows} $rowsString\n") } sb.toString() From f725b2ec1ab0d89e35b5e2d3ddeddb79fec85f6d Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 15 Dec 2015 18:15:10 -0800 Subject: [PATCH 1234/2089] [SPARK-12056][CORE] Part 2 Create a TaskAttemptContext only after calling setConf This is continuation of SPARK-12056 where change is applied to SqlNewHadoopRDD.scala andrewor14 FYI Author: tedyu Closes #10164 from tedyu/master. --- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 56cb63d9eff2a..eea780cbaa7e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -148,14 +148,14 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private[this] var reader: RecordReader[Void, V] = null /** From 369127f03257e7081d2aa1fc445e773b26f0d5e3 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 15 Dec 2015 18:16:22 -0800 Subject: [PATCH 1235/2089] [SPARK-12130] Replace shuffleManagerClass with shortShuffleMgrNames in ExternalShuffleBlockResolver Replace shuffleManagerClassName with shortShuffleMgrName is to reduce time of string's comparison. and put sort's comparison on the front. cc JoshRosen andrewor14 Author: Lianhui Wang Closes #10131 from lianhuiwang/spark-12130. --- .../scala/org/apache/spark/shuffle/ShuffleManager.scala | 4 ++++ .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 ++ .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 2 ++ .../main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../network/shuffle/ExternalShuffleBlockResolver.java | 7 +++---- .../apache/spark/network/sasl/SaslIntegrationSuite.java | 3 +-- .../network/shuffle/ExternalShuffleBlockResolverSuite.java | 6 +++--- .../network/shuffle/ExternalShuffleIntegrationSuite.java | 4 ++-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 978366d1a1d1b..a3444bf4daa3b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -28,6 +28,10 @@ import org.apache.spark.{TaskContext, ShuffleDependency} * boolean isDriver as parameters. */ private[spark] trait ShuffleManager { + + /** Return short name for the ShuffleManager */ + val shortName: String + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index d2e2fc4c110a7..4f30da0878ee1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -34,6 +34,8 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) + override val shortName: String = "hash" + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 66b6bbc61fe8e..9b1a279528428 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -79,6 +79,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + override val shortName: String = "sort" + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed05143877e20..540e1ec003a2b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -200,7 +200,7 @@ private[spark] class BlockManager( val shuffleConfig = new ExecutorShuffleInfo( diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, - shuffleManager.getClass.getName) + shuffleManager.shortName) val MAX_ATTEMPTS = 3 val SLEEP_TIME_SECS = 5 diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e5cb68c8a4dbb..fe933ed650caf 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -183,11 +183,10 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) - || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { + if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else if ("hash".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); } else { throw new UnsupportedOperationException( "Unsupported shuffle manager: " + executor.shuffleManager); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index f573d962fe361..0ea631ea14d70 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -221,8 +221,7 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { // Register an executor so that the next steps work. ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( - new String[] { System.getProperty("java.io.tmpdir") }, 1, - "org.apache.spark.shuffle.sort.SortShuffleManager"); + new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index a9958232a1d28..60a1b8b0451fe 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -83,7 +83,7 @@ public void testBadRequests() throws IOException { // Nonexistent shuffle block resolver.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); try { resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); fail("Should have failed"); @@ -96,7 +96,7 @@ public void testBadRequests() throws IOException { public void testSortShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); @@ -115,7 +115,7 @@ public void testSortShuffleBlocks() throws IOException { public void testHashShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + dataContext.createExecutorInfo("hash")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 2095f41d79c16..5e706bf401693 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -49,8 +49,8 @@ public class ExternalShuffleIntegrationSuite { static String APP_ID = "app-id"; - static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + static String SORT_MANAGER = "sort"; + static String HASH_MANAGER = "hash"; // Executor 0 is sort-based static TestShuffleDataContext dataContext0; From c2de99a7c3a52b0da96517c7056d2733ef45495f Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 15 Dec 2015 18:20:00 -0800 Subject: [PATCH 1236/2089] [SPARK-12351][MESOS] Add documentation about submitting Spark with mesos cluster mode. Adding more documentation about submitting jobs with mesos cluster mode. Author: Timothy Chen Closes #10086 from tnachen/mesos_supervise_docs. --- docs/running-on-mesos.md | 26 +++++++++++++++++++++----- docs/submitting-applications.md | 15 ++++++++++++++- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index a197d0e373027..3193e17853483 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -150,14 +150,30 @@ it does not need to be redundantly passed in as a system property. Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client can find the results of the driver from the Mesos Web UI. -To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, -passing in the Mesos master url (e.g: mesos://host:5050). +To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url -to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL +to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the Spark cluster Web UI. -Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves. +For example: +{% highlight bash %} +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster + --supervise + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 +{% endhighlight %} + + +Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves, as the Spark driver doesn't automatically upload local jars. # Mesos Run Modes diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index ac2a14eb56fea..acbb0f298fe47 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -115,6 +115,18 @@ export HADOOP_CONF_DIR=XXX --master spark://207.184.161.138:7077 \ examples/src/main/python/pi.py \ 1000 + +# Run on a Mesos cluster in cluster deploy mode with supervise +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster + --supervise + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 + {% endhighlight %} # Master URLs @@ -132,9 +144,10 @@ The master URL passed to Spark can be in one of the following formats:
    ") + } else { + if (!forceAdd) { + stackTrace.remove() + } + } +} + +function expandAllThreadStackTrace(toggleButton) { + $('.accordion-heading').each(function() { + //get thread ID + if (!$(this).hasClass("hidden")) { + var trId = $(this).attr('id').match(/thread_([0-9]+)_tr/m)[1] + toggleThreadStackTrace(trId, true) + } + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden') + } +} + +function collapseAllThreadStackTrace(toggleButton) { + $('.accordion-body').each(function() { + $(this).remove() + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden'); + } +} + + +// inOrOut - true: over, false: out +function onMouseOverAndOut(threadId) { + $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); +} + +function onSearchStringChange() { + var searchString = $('#search').val().toLowerCase(); + //remove the stacktrace + collapseAllThreadStackTrace(false) + if (searchString.length == 0) { + $('tr').each(function() { + $(this).removeClass('hidden') + }) + } else { + $('tr').each(function(){ + if($(this).attr('id') && $(this).attr('id').match(/thread_[0-9]+_tr/) ) { + var children = $(this).children() + var found = false + for (i = 0; i < children.length; i++) { + if (children.eq(i).text().toLowerCase().indexOf(searchString) >= 0) { + found = true + } + } + if (found) { + $(this).removeClass('hidden') + } else { + $(this).addClass('hidden') + } + } + }); + } +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index c628a0c706553..b54e33a96fa23 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -221,10 +221,8 @@ a.expandbutton { cursor: pointer; } -.executor-thread { - background: #E6E6E6; -} - -.non-executor-thread { - background: #FAFAFA; +.threaddump-td-mouseover { + background-color: #49535a !important; + color: white; + cursor:pointer; } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index b0a2cb4aa4d4b..58575d154ce5c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,44 +60,49 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } } }.map { thread => - val threadName = thread.threadName - val className = "accordion-heading " + { - if (threadName.contains("Executor task launch")) { - "executor-thread" - } else { - "non-executor-thread" - } - } -
    - -
    + + + + + + } + +
    +

    Updated at {UIUtils.formatDate(time)}

    + { + // scalastyle:off +

    + Expand All +

    +

    +
    +
    +
    +
    + Search:
    +
    +

    + // scalastyle:on } - -
    -

    Updated at {UIUtils.formatDate(time)}

    - { - // scalastyle:off -

    - Expand All -

    -

    - // scalastyle:on - } -
    {dumpRows}
    -
    +
    MLlib modelPMML model
    `spark.mllib` modelPMML model
    spark.replClassServer.port(random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. -
    spark.rpc.numRetries 3Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager instead.
    ExecutorDriver(random)Class file serverspark.replClassServer.portJetty-based. Only used in Spark shells.
    Executor / Driver Executor / Driver
    spark.memory.offHeap.enabledtrue + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. +
    spark.memory.offHeap.size0 + The absolute amount of memory which can be used for off-heap allocation. + This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This must be set to a positive value when spark.memory.offHeap.enabled=true. +
    spark.memory.useLegacyMode false
    mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... + To submit with --deploy-mode cluster, the HOST:PORT should be configured to connect to the MesosClusterDispatcher.
    yarn Connect to a YARN cluster in - client or cluster mode depending on the value of --deploy-mode. + client or cluster mode depending on the value of --deploy-mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
    yarn-client Equivalent to yarn with --deploy-mode client, From a63d9edcfb8a714a17492517927aa114dea8fea0 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 15 Dec 2015 18:21:00 -0800 Subject: [PATCH 1237/2089] [SPARK-9516][UI] Improvement of Thread Dump Page https://issues.apache.org/jira/browse/SPARK-9516 - [x] new look of Thread Dump Page - [x] click column title to sort - [x] grep - [x] search as you type squito JoshRosen It's ready for the review now Author: CodingCat Closes #7910 from CodingCat/SPARK-9516. --- .../org/apache/spark/ui/static/sorttable.js | 6 +- .../org/apache/spark/ui/static/table.js | 72 ++++++++++++++++++ .../org/apache/spark/ui/static/webui.css | 10 +-- .../ui/exec/ExecutorThreadDumpPage.scala | 73 ++++++++++--------- 4 files changed, 118 insertions(+), 43 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index a73d9a5cbc215..ff241470f32df 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -169,7 +169,7 @@ sorttable = { for (var i=0; i
    " +
    +            stackTraceText +  "
    {threadId}{thread.threadName}{thread.threadState}
    + + + + + + {dumpRows} +
    Thread IDThread NameThread State
    +
    }.getOrElse(Text("Error fetching thread dump")) UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) } From 765a488494dac0ed38d2b81742c06467b79d96b2 Mon Sep 17 00:00:00 2001 From: "Richard W. Eggert II" Date: Tue, 15 Dec 2015 18:22:58 -0800 Subject: [PATCH 1238/2089] [SPARK-9026][SPARK-4514] Modifications to JobWaiter, FutureAction, and AsyncRDDActions to support non-blocking operation These changes rework the implementations of `SimpleFutureAction`, `ComplexFutureAction`, `JobWaiter`, and `AsyncRDDActions` such that asynchronous callbacks on the generated `Futures` NEVER block waiting for a job to complete. A small amount of mutex synchronization is necessary to protect the internal fields that manage cancellation, but these locks are only held very briefly and in practice should almost never cause any blocking to occur. The existing blocking APIs of these classes are retained, but they simply delegate to the underlying non-blocking API and `Await` the results with indefinite timeouts. Associated JIRA ticket: https://issues.apache.org/jira/browse/SPARK-9026 Also fixes: https://issues.apache.org/jira/browse/SPARK-4514 This pull request contains all my own original work, which I release to the Spark project under its open source license. Author: Richard W. Eggert II Closes #9264 from reggert/fix-futureaction. --- .../scala/org/apache/spark/FutureAction.scala | 164 +++++++----------- .../apache/spark/rdd/AsyncRDDActions.scala | 48 ++--- .../apache/spark/scheduler/DAGScheduler.scala | 8 +- .../apache/spark/scheduler/JobWaiter.scala | 48 ++--- .../test/scala/org/apache/spark/Smuggle.scala | 82 +++++++++ .../org/apache/spark/StatusTrackerSuite.scala | 26 +++ .../spark/rdd/AsyncRDDActionsSuite.scala | 33 +++- 7 files changed, 251 insertions(+), 158 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/Smuggle.scala diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 48792a958130c..2a8220ff40090 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -20,13 +20,15 @@ package org.apache.spark import java.util.Collections import java.util.concurrent.TimeUnit +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobWaiter -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.{Failure, Try} /** * A future for the result of an action to support cancellation. This is an extension of the @@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] { * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ +@DeveloperApi class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { @@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { - if (!atMost.isFinite()) { - awaitResult() - } else jobWaiter.synchronized { - val finishTime = System.currentTimeMillis() + atMost.toMillis - while (!isCompleted) { - val time = System.currentTimeMillis() - if (time >= finishTime) { - throw new TimeoutException - } else { - jobWaiter.wait(finishTime - time) - } - } - } + jobWaiter.completionFuture.ready(atMost) this } @throws(classOf[Exception]) override def result(atMost: Duration)(implicit permit: CanAwait): T = { - ready(atMost)(permit) - awaitResult() match { - case scala.util.Success(res) => res - case scala.util.Failure(e) => throw e - } + jobWaiter.completionFuture.ready(atMost) + assert(value.isDefined, "Future has not completed properly") + value.get.get } override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { - executor.execute(new Runnable { - override def run() { - func(awaitResult()) - } - }) + jobWaiter.completionFuture onComplete {_ => func(value.get)} } override def isCompleted: Boolean = jobWaiter.jobFinished override def isCancelled: Boolean = _cancelled - override def value: Option[Try[T]] = { - if (jobWaiter.jobFinished) { - Some(awaitResult()) - } else { - None - } - } - - private def awaitResult(): Try[T] = { - jobWaiter.awaitResult() match { - case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception) => scala.util.Failure(e) - } - } + override def value: Option[Try[T]] = + jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } +/** + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. + */ +@DeveloperApi +trait JobSubmitter { + /** + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): FutureAction[R] +} + + /** * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, - * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the - * action thread if it is being blocked by a job. + * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending + * jobs. */ -class ComplexFutureAction[T] extends FutureAction[T] { +@DeveloperApi +class ComplexFutureAction[T](run : JobSubmitter => Future[T]) + extends FutureAction[T] { self => - // Pointer to the thread that is executing the action. It is set when the action is run. - @volatile private var thread: Thread = _ + @volatile private var _cancelled = false - // A flag indicating whether the future has been cancelled. This is used in case the future - // is cancelled before the action was even run (and thus we have no thread to interrupt). - @volatile private var _cancelled: Boolean = false - - @volatile private var jobs: Seq[Int] = Nil + @volatile private var subActions: List[FutureAction[_]] = Nil // A promise used to signal the future. - private val p = promise[T]() + private val p = Promise[T]().tryCompleteWith(run(jobSubmitter)) - override def cancel(): Unit = this.synchronized { + override def cancel(): Unit = synchronized { _cancelled = true - if (thread != null) { - thread.interrupt() - } - } - - /** - * Executes some action enclosed in the closure. To properly enable cancellation, the closure - * should use runJob implementation in this promise. See takeAsync for example. - */ - def run(func: => T)(implicit executor: ExecutionContext): this.type = { - scala.concurrent.future { - thread = Thread.currentThread - try { - p.success(func) - } catch { - case e: Exception => p.failure(e) - } finally { - // This lock guarantees when calling `thread.interrupt()` in `cancel`, - // thread won't be set to null. - ComplexFutureAction.this.synchronized { - thread = null - } - } - } - this + p.tryFailure(new SparkException("Action has been cancelled")) + subActions.foreach(_.cancel()) } - /** - * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ - def runJob[T, U, R]( + private def jobSubmitter = new JobSubmitter { + def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { - // If the action hasn't been cancelled yet, submit the job. The check and the submitJob - // command need to be in an atomic block. - val job = this.synchronized { + resultFunc: => R): FutureAction[R] = self.synchronized { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. if (!isCancelled) { - rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + val job = rdd.context.submitJob( + rdd, + processPartition, + partitions, + resultHandler, + resultFunc) + subActions = job :: subActions + job } else { throw new SparkException("Action has been cancelled") } } - - this.jobs = jobs ++ job.jobIds - - // Wait for the job to complete. If the action is cancelled (with an interrupt), - // cancel the job and stop the execution. This is not in a synchronized block because - // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. - try { - Await.ready(job, Duration.Inf) - } catch { - case e: InterruptedException => - job.cancel() - throw new SparkException("Action has been cancelled") - } } override def isCancelled: Boolean = _cancelled @@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds: Seq[Int] = jobs + def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) } + private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { @@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S Await.ready(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) - case Failure(exception) => + case scala.util.Failure(exception) => if (isCancelled) { throw new CancellationException("Job cancelled").initCause(exception) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index d5e853613b05b..14f541f937b4c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,13 +19,12 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.util.ThreadUtils - import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.reflect.ClassTag -import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.util.ThreadUtils /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { - val f = new ComplexFutureAction[Seq[T]] val callSite = self.context.getCallSite - - f.run { - // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which - // is a cached thread pool. - val results = new ArrayBuffer[T](num) - val totalParts = self.partitions.length - var partsScanned = 0 - self.context.setCallSite(callSite) - while (results.size < num && partsScanned < totalParts) { + val localProperties = self.context.getLocalProperties + // Cached thread pool to handle aggregation of subtasks. + implicit val executionContext = AsyncRDDActions.futureExecutionContext + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + + /* + Recursively triggers jobs to scan partitions until either the requested + number of elements are retrieved, or the partitions to scan are exhausted. + This implementation is non-blocking, asynchronously handling the + results of each job and triggering the next job using callbacks on futures. + */ + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + if (results.size >= num || partsScanned >= totalParts) { + Future.successful(results.toSeq) + } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 @@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val buf = new Array[Array[T]](p.size) - f.runJob(self, + self.context.setCallSite(callSite) + self.context.setLocalProperties(localProperties) + val job = jobSubmitter.submitJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - - buf.foreach(results ++= _.take(num - results.size)) - partsScanned += numPartsToTry + job.flatMap {_ => + buf.foreach(results ++= _.take(num - results.size)) + continue(partsScanned + numPartsToTry) + } } - results.toSeq - }(AsyncRDDActions.futureExecutionContext) - f + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5582720bbcff2..8d0e0c8624a55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.mutable.{HashMap, HashSet, Stack} +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -610,11 +611,12 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - waiter.awaitResult() match { - case JobSucceeded => + Await.ready(waiter.completionFuture, atMost = Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - case JobFailed(exception: Exception) => + case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 382b09422a4a0..4326135186a73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,6 +17,10 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -28,17 +32,15 @@ private[spark] class JobWaiter[T]( resultHandler: (Int, T) => Unit) extends JobListener { - private var finishedTasks = 0 - - // Is the job as a whole finished (succeeded or failed)? - @volatile - private var _jobFinished = totalTasks == 0 - - def jobFinished: Boolean = _jobFinished - + private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. - private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + private val jobPromise: Promise[Unit] = + if (totalTasks == 0) Promise.successful(()) else Promise() + + def jobFinished: Boolean = jobPromise.isCompleted + + def completionFuture: Future[Unit] = jobPromise.future /** * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled @@ -49,29 +51,17 @@ private[spark] class JobWaiter[T]( dagScheduler.cancelJob(jobId) } - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - if (_jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + override def taskSucceeded(index: Int, result: Any): Unit = { + // resultHandler call must be synchronized in case resultHandler itself is not thread safe. + synchronized { + resultHandler(index, result.asInstanceOf[T]) } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - _jobFinished = true - jobResult = JobSucceeded - this.notifyAll() + if (finishedTasks.incrementAndGet() == totalTasks) { + jobPromise.success(()) } } - override def jobFailed(exception: Exception): Unit = synchronized { - _jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } + override def jobFailed(exception: Exception): Unit = + jobPromise.failure(exception) - def awaitResult(): JobResult = synchronized { - while (!_jobFinished) { - this.wait() - } - return jobResult - } } diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala new file mode 100644 index 0000000000000..01694a6e6f741 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.UUID +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable + +/** + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ +class Smuggle[T] private(val key: Symbol) extends Serializable { + def smuggledObject: T = Smuggle.get(key) +} + + +object Smuggle { + /** + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ + def apply[T](smuggledObject: T): Smuggle[T] = { + val key = Symbol(UUID.randomUUID().toString) + lock.writeLock().lock() + try { + smuggledObjects += key -> smuggledObject + } finally { + lock.writeLock().unlock() + } + new Smuggle(key) + } + + private val lock = new ReentrantReadWriteLock + private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any] + + private def get[T](key: Symbol) : T = { + lock.readLock().lock() + try { + smuggledObjects(key).asInstanceOf[T] + } finally { + lock.readLock().unlock() + } + } + + /** + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ + implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 46516e8d25298..5483f2b8434aa 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont Set(firstJobId, secondJobId)) } } + + test("getJobIdsForGroup() with takeAsync()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId)) + } + } + + test("getJobIdsForGroup() with takeAsync() across multiple partitions") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index ec99f2a1bad66..de015ebd5d237 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import scala.concurrent.{Await, TimeoutException} +import scala.concurrent._ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim Await.result(f, Duration(20, "milliseconds")) } } + + private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { + val executionContextInvoked = Promise[Unit] + val fakeExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = { + executionContextInvoked.success(()) + } + override def reportFailure(t: Throwable): Unit = () + } + val starter = Smuggle(new Semaphore(0)) + starter.drainPermits() + val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr} + val f = action(rdd) + f.onComplete(_ => ())(fakeExecutionContext) + // Here we verify that registering the callback didn't cause a thread to be consumed. + assert(!executionContextInvoked.isCompleted) + // Now allow the executors to proceed with task processing. + starter.release(rdd.partitions.length) + // Waiting for the result verifies that the tasks were successfully processed. + Await.result(executionContextInvoked.future, atMost = 15.seconds) + } + + test("SimpleFutureAction callback must not consume a thread while waiting") { + testAsyncAction(_.countAsync()) + } + + test("ComplexFutureAction callback must not consume a thread while waiting") { + testAsyncAction((_.takeAsync(100))) + } } From 63ccdef81329e785807f37b4e918a9247fc70e3c Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 15 Dec 2015 18:24:23 -0800 Subject: [PATCH 1239/2089] [SPARK-10123][DEPLOY] Support specifying deploy mode from configuration Please help to review, thanks a lot. Author: jerryshao Closes #10195 from jerryshao/SPARK-10123. --- .../spark/deploy/SparkSubmitArguments.scala | 5 ++- .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++++++ docs/configuration.md | 15 +++++-- .../apache/spark/launcher/SparkLauncher.java | 3 ++ .../launcher/SparkSubmitCommandBuilder.java | 7 ++-- 5 files changed, 64 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 18a1c52ae53fb..915ef81b4eae3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -176,7 +176,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull - deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull + deployMode = Option(deployMode) + .orElse(sparkProperties.get("spark.submit.deployMode")) + .orElse(env.get("DEPLOY_MODE")) + .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d494b0caab85f..2626f5a16dfb8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -136,6 +136,47 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("specify deploy mode through configuration") { + val clArgs = Seq( + "--master", "yarn", + "--conf", "spark.submit.deployMode=client", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + + appArgs.deployMode should be ("client") + sysProps("spark.submit.deployMode") should be ("client") + + // Both cmd line and configuration are specified, cmdline option takes the priority + val clArgs1 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--conf", "spark.submit.deployMode=client", + "-class", "org.SomeClass", + "thejar.jar" + ) + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + + appArgs1.deployMode should be ("cluster") + sysProps1("spark.submit.deployMode") should be ("cluster") + + // Neither cmdline nor configuration are specified, client mode is the default choice + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + appArgs2.deployMode should be (null) + + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + appArgs2.deployMode should be ("client") + sysProps2("spark.submit.deployMode") should be ("client") + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/docs/configuration.md b/docs/configuration.md index 55cf4b2dac5f5..38d3d059f9d31 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,7 +48,7 @@ The following format is accepted: 1y (years) -Properties that specify a byte size should be configured with a unit of size. +Properties that specify a byte size should be configured with a unit of size. The following format is accepted: 1b (bytes) @@ -192,6 +192,15 @@ of the most common options to set are: allowed master URL's. + + spark.submit.deployMode + (none) + + The deploy mode of Spark driver program, either "client" or "cluster", + Which means to launch driver program locally ("client") + or remotely ("cluster") on one of the nodes inside the cluster. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -1095,7 +1104,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. @@ -1559,7 +1568,7 @@ Apart from these, the following properties are also available, and may be useful spark.streaming.stopGracefullyOnShutdown false - If true, Spark shuts down the StreamingContext gracefully on JVM + If true, Spark shuts down the StreamingContext gracefully on JVM shutdown rather than immediately. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index dd1c93af6ca4c..20e6003a00c19 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -40,6 +40,9 @@ public class SparkLauncher { /** The Spark master. */ public static final String SPARK_MASTER = "spark.master"; + /** The Spark deploy mode. */ + public static final String DEPLOY_MODE = "spark.submit.deployMode"; + /** Configuration key for the driver memory. */ public static final String DRIVER_MEMORY = "spark.driver.memory"; /** Configuration key for the driver class path. */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 312df0b269f32..a95f0f17517d1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -294,10 +294,11 @@ private void constructEnvVarArgs( private boolean isClientMode(Map userProps) { String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); - // Default master is "local[*]", so assume client mode in that case. + String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE)); + // Default master is "local[*]", so assume client mode in that case return userMaster == null || - "client".equals(deployMode) || - (!userMaster.equals("yarn-cluster") && deployMode == null); + "client".equals(userDeployMode) || + (!userMaster.equals("yarn-cluster") && userDeployMode == null); } /** From 8a215d2338c6286253e20122640592f9d69896c8 Mon Sep 17 00:00:00 2001 From: Naveen Date: Tue, 15 Dec 2015 18:25:22 -0800 Subject: [PATCH 1240/2089] [SPARK-9886][CORE] Fix to use ShutdownHookManager in ExternalBlockStore.scala Author: Naveen Closes #10313 from naveenminchu/branch-fix-SPARK-9886. --- .../spark/storage/ExternalBlockStore.scala | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index db965d54bafd6..94883a54a74e4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -177,15 +177,6 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - externalBlockManager.map(_.shutdown()) - } - }) - } - // Create concrete block manager and fall back to Tachyon by default for backward compatibility. private def createBlkManager(): Option[ExternalBlockManager] = { val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) @@ -196,7 +187,10 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) - addShutdownHook(); + ShutdownHookManager.addShutdownHook { () => + logDebug("Shutdown hook called") + externalBlockManager.map(_.shutdown()) + } Some(instance) } catch { case NonFatal(t) => From c5b6b398d5e368626e589feede80355fb74c2bd8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 15 Dec 2015 18:28:16 -0800 Subject: [PATCH 1241/2089] [SPARK-12062][CORE] Change Master to asyc rebuild UI when application completes This change builds the event history of completed apps asynchronously so the RPC thread will not be blocked and allow new workers to register/remove if the event log history is very large and takes a long time to rebuild. Author: Bryan Cutler Closes #10284 from BryanCutler/async-MasterUI-SPARK-12062. --- .../apache/spark/deploy/master/Master.scala | 79 ++++++++++++------- .../spark/deploy/master/MasterMessages.scala | 2 + 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1355e1ad1b523..fc42bf06e40a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,9 +21,11 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date -import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps import scala.util.Random @@ -56,6 +58,10 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + private val rebuildUIThread = + ThreadUtils.newDaemonSingleThreadExecutor("master-rebuild-ui-thread") + private val rebuildUIContext = ExecutionContext.fromExecutor(rebuildUIThread) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -78,7 +84,8 @@ private[deploy] class Master( private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 - private val appIdToUI = new HashMap[String, SparkUI] + // Using ConcurrentHashMap so that master-rebuild-ui-thread can add a UI after asyncRebuildUI + private val appIdToUI = new ConcurrentHashMap[String, SparkUI] private val drivers = new HashSet[DriverInfo] private val completedDrivers = new ArrayBuffer[DriverInfo] @@ -191,6 +198,7 @@ private[deploy] class Master( checkForWorkerTimeOutTask.cancel(true) } forwardMessageThread.shutdownNow() + rebuildUIThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -367,6 +375,10 @@ private[deploy] class Master( case CheckForWorkerTimeOut => { timeOutDeadWorkers() } + + case AttachCompletedRebuildUI(appId) => + // An asyncRebuildSparkUI has completed, so need to attach to master webUi + Option(appIdToUI.get(appId)).foreach { ui => webUi.attachSparkUI(ui) } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -809,7 +821,7 @@ private[deploy] class Master( if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { - appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) } + Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) }) completedApps.trimStart(toRemove) @@ -818,7 +830,7 @@ private[deploy] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - rebuildSparkUI(app) + asyncRebuildSparkUI(app) for (exec <- app.executors.values) { killExecutor(exec) @@ -923,49 +935,57 @@ private[deploy] class Master( * Return the UI if successful, else None */ private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { + val futureUI = asyncRebuildSparkUI(app) + Await.result(futureUI, Duration.Inf) + } + + /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ + private[master] def asyncRebuildSparkUI(app: ApplicationInfo): Future[Option[SparkUI]] = { val appName = app.desc.name val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" - try { - val eventLogDir = app.desc.eventLogDir - .getOrElse { - // Event logging is not enabled for this application - app.appUIUrlAtHistoryServer = Some(notFoundBasePath) - return None - } - + val eventLogDir = app.desc.eventLogDir + .getOrElse { + // Event logging is disabled for this application + app.appUIUrlAtHistoryServer = Some(notFoundBasePath) + return Future.successful(None) + } + val futureUI = Future { val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + - EventLoggingListener.IN_PROGRESS)) + EventLoggingListener.IN_PROGRESS)) - if (inProgressExists) { + val eventLogFile = if (inProgressExists) { // Event logging is enabled for this application, but the application is still in progress logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") - } - - val (eventLogFile, status) = if (inProgressExists) { - (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)") + eventLogFilePrefix + EventLoggingListener.IN_PROGRESS } else { - (eventLogFilePrefix, " (completed)") + eventLogFilePrefix } val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) - val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { - replayBus.replay(logInput, eventLogFile, maybeTruncated) + replayBus.replay(logInput, eventLogFile, inProgressExists) } finally { logInput.close() } - appIdToUI(app.id) = ui - webUi.attachSparkUI(ui) + + Some(ui) + }(rebuildUIContext) + + futureUI.onSuccess { case Some(ui) => + appIdToUI.put(app.id, ui) + self.send(AttachCompletedRebuildUI(app.id)) // Application UI is successfully rebuilt, so link the Master UI to it + // NOTE - app.appUIUrlAtHistoryServer is volatile app.appUIUrlAtHistoryServer = Some(ui.basePath) - Some(ui) - } catch { + }(ThreadUtils.sameThread) + + futureUI.onFailure { case fnf: FileNotFoundException => // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" @@ -974,7 +994,7 @@ private[deploy] class Master( msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title") - None + case e: Exception => // Relay exception message to application UI page val title = s"Application history load error (${app.id})" @@ -984,8 +1004,9 @@ private[deploy] class Master( msg = URLEncoder.encode(msg, "UTF-8") app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title") - None - } + }(ThreadUtils.sameThread) + + futureUI } /** Generate a new app ID given a app's submission date */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index a952cee36eb44..a055d097674ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -39,4 +39,6 @@ private[master] object MasterMessages { case object BoundPortsRequest case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) + + case class AttachCompletedRebuildUI(appId: String) } From a89e8b6122ee5a1517fbcf405b1686619db56696 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Dec 2015 18:29:19 -0800 Subject: [PATCH 1242/2089] [SPARK-10477][SQL] using DSL in ColumnPruningSuite to improve readability Author: Wenchen Fan Closes #8645 from cloud-fan/test. --- .../spark/sql/catalyst/dsl/package.scala | 7 ++-- .../optimizer/ColumnPruningSuite.scala | 41 +++++++++++-------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index af594c25c54cb..e50971173c499 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -275,13 +275,14 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + alias: Option[String] = None, + outputNames: Seq[String] = Nil): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, + outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 4a1e7ceaf394b..9bf61ae091786 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning for Generate when Generate.join = false") { val input = LocalRelation('a.int, 'b.array(StringType)) - val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val query = input.generate(Explode('b), join = false).analyze + val optimized = Optimize.execute(query) - val correctAnswer = - Generate(Explode('b), false, false, None, 's.string :: Nil, - Project('b.attr :: Nil, input)).analyze + val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze comparePlans(optimized, correctAnswer) } @@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) val query = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - Project(Seq('a, 'c), - input))).analyze + input + .select('a, 'c) + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze comparePlans(optimized, correctAnswer) } @@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('b.array(StringType)) val query = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = true, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), false, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = false, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze comparePlans(optimized, correctAnswer) } From ca0690b5ef10b14ce57a0c30d5308eb02f163f39 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Tue, 15 Dec 2015 18:30:59 -0800 Subject: [PATCH 1243/2089] [SPARK-4117][YARN] Spark on Yarn handle AM being told command from RM Spark on Yarn handle AM being told command from RM When RM throws ApplicationAttemptNotFoundException for allocate invocation, making the ApplicationMaster to finish immediately without any retries. Author: Devaraj K Closes #10129 from devaraj-kavali/SPARK-4117. --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1970f7d150feb..fc742df73d731 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -376,7 +376,14 @@ private[spark] class ApplicationMaster( case i: InterruptedException => case e: Throwable => { failureCount += 1 - if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + // this exception was introduced in hadoop 2.4 and this code would not compile + // with earlier versions if we refer it directly. + if ("org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException" == + e.getClass().getName()) { + logError("Exception from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, + e.getMessage) + } else if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + s"$failureCount time(s) from Reporter thread.") From d52bf47e13e0186590437f71040100d2f6f11da9 Mon Sep 17 00:00:00 2001 From: proflin Date: Tue, 15 Dec 2015 20:22:56 -0800 Subject: [PATCH 1244/2089] =?UTF-8?q?[SPARK-12304][STREAMING]=20Make=20Spa?= =?UTF-8?q?rk=20Streaming=20web=20UI=20display=20more=20fri=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …endly Receiver graphs Currently, the Spark Streaming web UI uses the same maxY when displays 'Input Rate Times& Histograms' and 'Per-Receiver Times& Histograms'. This may lead to somewhat un-friendly graphs: once we have tens of Receivers or more, every 'Per-Receiver Times' line almost hits the ground. This issue proposes to calculate a new maxY against the original one, which is shared among all the `Per-Receiver Times& Histograms' graphs. Before: ![before-5](https://cloud.githubusercontent.com/assets/15843379/11761362/d790c356-a0fa-11e5-860e-4b834603de1d.png) After: ![after-5](https://cloud.githubusercontent.com/assets/15843379/11761361/cfabf692-a0fa-11e5-97d0-4ad124aaca2a.png) Author: proflin Closes #10318 from proflin/SPARK-12304. --- .../org/apache/spark/streaming/ui/StreamingPage.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 88a4483e8068f..b3692c3ea302b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -392,9 +392,15 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { + val maxYCalculated = listener.receivedEventRateWithBatchTime.values + .flatMap { case streamAndRates => streamAndRates.map { case (_, eventRate) => eventRate } } + .reduceOption[Double](math.max) + .map(_.ceil.toLong) + .getOrElse(0L) + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxYCalculated) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off From 0f6936b5f1c9b0be1c33b98ffb62a72ae0c3e2a8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:22:49 -0800 Subject: [PATCH 1245/2089] [SPARK-12249][SQL] JDBC non-equality comparison operator not pushed down. https://issues.apache.org/jira/browse/SPARK-12249 Currently `!=` operator is not pushed down correctly. I simply added a case for this. Author: hyukjinkwon Closes #10233 from HyukjinKwon/SPARK-12249. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1c348ed62fc78..c18a2d2cc0768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -281,6 +281,7 @@ private[sql] class JDBCRDD( */ private def compileFilter(f: Filter): String = f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" case LessThan(attr, value) => s"$attr < ${compileValue(value)}" case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a360947152996..aca1443057343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -177,9 +177,11 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("SELECT * WHERE (simple predicates)") { assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) } test("SELECT * WHERE (quoted strings)") { From 7f443a6879fa33ca8adb682bd85df2d56fb5fcda Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:25:08 -0800 Subject: [PATCH 1246/2089] [SPARK-12314][SQL] isnull operator not pushed down for JDBC datasource. https://issues.apache.org/jira/browse/SPARK-12314 `IsNull` filter is not being pushed down for JDBC datasource. It looks it is SQL standard according to [SQL-92](http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt), SQL:1999, [SQL:2003](http://www.wiscorp.com/sql_2003_standard.zip) and [SQL:201x](http://www.wiscorp.com/sql20nn.zip) and I believe most databases support this. In this PR, I simply added the case for `IsNull` filter to produce a proper filter string. Author: hyukjinkwon This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10286 from HyukjinKwon/SPARK-12314. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index c18a2d2cc0768..3271b46be18fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -286,6 +286,7 @@ private[sql] class JDBCRDD( case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index aca1443057343..0305667ff66ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -182,6 +182,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) } test("SELECT * WHERE (quoted strings)") { From 2aad2d372469aaf2773876cae98ef002fef03aa3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:30:35 -0800 Subject: [PATCH 1247/2089] [SPARK-12315][SQL] isnotnull operator not pushed down for JDBC datasource. https://issues.apache.org/jira/browse/SPARK-12315 `IsNotNull` filter is not being pushed down for JDBC datasource. It looks it is SQL standard according to [SQL-92](http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt), SQL:1999, [SQL:2003](http://www.wiscorp.com/sql_2003_standard.zip) and [SQL:201x](http://www.wiscorp.com/sql20nn.zip) and I believe most databases support this. In this PR, I simply added the case for `IsNotNull` filter to produce a proper filter string. Author: hyukjinkwon This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10287 from HyukjinKwon/SPARK-12315. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3271b46be18fb..2d38562e0901a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -287,6 +287,7 @@ private[sql] class JDBCRDD( case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0305667ff66ee..d6aeb523ea8d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -183,6 +183,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) + assert(stripSparkFilter( + sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size === 0) } test("SELECT * WHERE (quoted strings)") { From 554d840a9ade79722c96972257435a05e2aa9d88 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Dec 2015 22:32:51 -0800 Subject: [PATCH 1248/2089] Style fix for the previous 3 JDBC filter push down commits. --- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d6aeb523ea8d6..2b91f62c2fa22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -176,15 +176,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) - assert(stripSparkFilter( - sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) } test("SELECT * WHERE (quoted strings)") { From 18ea11c3a84e5eafd81fa0fe7c09224e79c4e93f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Dec 2015 00:57:07 -0800 Subject: [PATCH 1249/2089] Revert "[HOTFIX] Compile error from commit 31b3910" This reverts commit 840bd2e008da5b22bfa73c587ea2c57666fffc60. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 33b03be1138be..b69d4411425d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -234,7 +234,7 @@ class DataFrame private[sql]( // For Data that has more than "_numRows" records if (hasMoreData) { val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top ${_numRows} $rowsString\n") + sb.append(s"only showing top $_numRows $rowsString\n") } sb.toString() From 1a3d0cd9f013aee1f03b1c632c91ae0951bccbb0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Dec 2015 00:57:34 -0800 Subject: [PATCH 1250/2089] Revert "[SPARK-12105] [SQL] add convenient show functions" This reverts commit 31b391019ff6eb5a483f4b3e62fd082de7ff8416. --- .../org/apache/spark/sql/DataFrame.scala | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b69d4411425d4..497bd48266770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,24 +160,17 @@ class DataFrame private[sql]( } } - /** - * Compose the string representing rows for output - */ - def showString(): String = { - showString(20) - } - /** * Compose the string representing rows for output - * @param numRows Number of rows to show + * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - def showString(numRows: Int, truncate: Boolean = true): String = { - val _numRows = numRows.max(0) + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) val sb = new StringBuilder - val takeResult = take(_numRows + 1) - val hasMoreData = takeResult.length > _numRows - val data = takeResult.take(_numRows) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets @@ -231,10 +224,10 @@ class DataFrame private[sql]( sb.append(sep) - // For Data that has more than "_numRows" records + // For Data that has more than "numRows" records if (hasMoreData) { - val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top $_numRows $rowsString\n") + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") } sb.toString() From a6325fc401f68d9fa30cc947c44acc9d64ebda7b Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Wed, 16 Dec 2015 10:12:33 -0800 Subject: [PATCH 1251/2089] [SPARK-12324][MLLIB][DOC] Fixes the sidebar in the ML documentation This fixes the sidebar, using a pure CSS mechanism to hide it when the browser's viewport is too narrow. Credit goes to the original author Titan-C (mentioned in the NOTICE). Note that I am not a CSS expert, so I can only address comments up to some extent. Default view: screen shot 2015-12-14 at 12 46 39 pm When collapsed manually by the user: screen shot 2015-12-14 at 12 54 02 pm Disappears when column is too narrow: screen shot 2015-12-14 at 12 47 22 pm Can still be opened by the user if necessary: screen shot 2015-12-14 at 12 51 15 pm Author: Timothy Hunter Closes #10297 from thunterdb/12324. --- NOTICE | 9 ++- docs/_layouts/global.html | 35 +++++++--- docs/css/main.css | 137 ++++++++++++++++++++++++++++++++------ docs/js/main.js | 2 +- 4 files changed, 149 insertions(+), 34 deletions(-) diff --git a/NOTICE b/NOTICE index 7f7769f73047f..571f8c2fff7ff 100644 --- a/NOTICE +++ b/NOTICE @@ -606,4 +606,11 @@ Vis.js uses and redistributes the following third-party libraries: - keycharm https://github.com/AlexDM0/keycharm - The MIT License \ No newline at end of file + The MIT License + +=============================================================================== + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. +=============================================================================== diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 0b5b0cd48af64..3089474c13385 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -1,3 +1,4 @@ + @@ -127,20 +128,32 @@
    {% if page.url contains "/ml" %} - {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} - {% endif %} - + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    + {% else %} +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} - {{ content }} + {{ content }} -
    +
    + {% endif %} +
    diff --git a/docs/css/main.css b/docs/css/main.css index 356b324d6303b..175e8004fca0e 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -40,17 +40,14 @@ } body .container-wrapper { - position: absolute; - width: 100%; - display: flex; -} - -body #content { + background-color: #FFF; + color: #1D1F22; + max-width: 1024px; + margin-top: 10px; + margin-left: auto; + margin-right: auto; + border-radius: 15px; position: relative; - - line-height: 1.6; /* Inspired by Github's wiki style */ - background-color: white; - padding-left: 15px; } .title { @@ -101,6 +98,24 @@ a:hover code { max-width: 914px; } +.content { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 15px; +} + +.content-with-sidebar { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 30px; +} + .dropdown-menu { /* Remove the default 2px top margin which causes a small gap between the hover trigger area and the popup menu */ @@ -171,24 +186,104 @@ a.anchorjs-link:hover { text-decoration: none; } * The left navigation bar. */ .left-menu-wrapper { - position: absolute; - height: 100%; - - width: 256px; - margin-top: -20px; - padding-top: 20px; + margin-left: 0px; + margin-right: 0px; background-color: #F0F8FC; + border-top-width: 0px; + border-left-width: 0px; + border-bottom-width: 0px; + margin-top: 0px; + width: 210px; + float: left; + position: absolute; } .left-menu { - position: fixed; - max-width: 350px; - - padding-right: 10px; - width: 256px; + padding: 0px; + width: 199px; } .left-menu h3 { margin-left: 10px; line-height: 30px; +} + +/** + * The collapsing button for the navigation bar. + */ +.nav-trigger { + position: fixed; + clip: rect(0, 0, 0, 0); +} + +.nav-trigger + label:after { + content: '»'; +} + +label { + z-index: 10; +} + +label[for="nav-trigger"] { + position: fixed; + margin-left: 0px; + padding-top: 100px; + padding-left: 5px; + width: 10px; + height: 80%; + cursor: pointer; + background-size: contain; + background-color: #D4F0FF; +} + +label[for="nav-trigger"]:hover { + background-color: #BEE9FF; +} + +.nav-trigger:checked + label { + margin-left: 200px; +} + +.nav-trigger:checked + label:after { + content: '«'; +} + +.nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 200px; +} + +.nav-trigger + label, div.content-with-sidebar { + transition: left 0.4s; +} + +/** + * Rules to collapse the menu automatically when the screen becomes too thin. + */ + +@media all and (max-width: 780px) { + + div.content-with-sidebar { + margin-left: 200px; + } + .nav-trigger + label:after { + content: '«'; + } + label[for="nav-trigger"] { + margin-left: 200px; + } + + .nav-trigger:checked + label { + margin-left: 0px; + } + .nav-trigger:checked + label:after { + content: '»'; + } + .nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 0px; + } + + div.container-index { + margin-left: -215px; + } + } \ No newline at end of file diff --git a/docs/js/main.js b/docs/js/main.js index f5d66b16f7b21..2329eb8327dd5 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -83,7 +83,7 @@ $(function() { // Display anchor links when hovering over headers. For documentation of the // configuration options, see the AnchorJS documentation. anchors.options = { - placement: 'left' + placement: 'right' }; anchors.add(); From 54c512ba906edfc25b8081ad67498e99d884452b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 10:22:48 -0800 Subject: [PATCH 1252/2089] [SPARK-8745] [SQL] remove GenerateProjection cc rxin Author: Davies Liu Closes #10316 from davies/remove_generate_projection. --- .../codegen/GenerateProjection.scala | 238 ------------------ .../codegen/GenerateSafeProjection.scala | 4 + .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/ExpressionEvalHelper.scala | 39 +-- .../expressions/MathFunctionsSuite.scala | 4 +- .../CodegenExpressionCachingSuite.scala | 18 -- .../sql/execution/local/ExpandNode.scala | 4 +- .../spark/sql/execution/local/LocalNode.scala | 18 -- 8 files changed, 11 insertions(+), 319 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala deleted file mode 100644 index f229f2000d8e1..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * Java can not access Projection (in package object) - */ -abstract class BaseProjection extends Projection {} - -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - -/** - * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] - * object is custom generated based on the output types of the [[Expression]] to avoid boxing of - * primitive values. - */ -object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) - - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) - - // Make Mutablility optional... - protected def create(expressions: Seq[Expression]): Projection = { - val ctx = newCodeGenContext() - val columns = expressions.zipWithIndex.map { - case (e, i) => - s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n") - - val initColumns = expressions.zipWithIndex.map { - case (e, i) => - val eval = e.gen(ctx) - s""" - { - // column$i - ${eval.code} - nullBits[$i] = ${eval.isNull}; - if (!${eval.isNull}) { - c$i = ${eval.value}; - } - } - """ - }.mkString("\n") - - val getCases = (0 until expressions.size).map { i => - s"case $i: return c$i;" - }.mkString("\n") - - val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n") - - val specificAccessorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: return c$i;") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val getter = "get" + ctx.primitiveTypeName(jt) - s""" - public $jt $getter(int i) { - if (isNullAt(i)) { - return ${ctx.defaultValue(jt)}; - } - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i - + " in $getter"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val specificMutatorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: { c$i = value; return; }") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val setter = "set" + ctx.primitiveTypeName(jt) - s""" - public void $setter(int i, $jt value) { - nullBits[i] = false; - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i + - " in $setter}"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = s"c$i" - val nonNull = e.dataType match { - case BooleanType => s"$col ? 0 : 1" - case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType | TimestampType => s"$col ^ ($col >>> 32)" - case FloatType => s"Float.floatToIntBits($col)" - case DoubleType => - s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" - case BinaryType => s"java.util.Arrays.hashCode($col)" - case _ => s"$col.hashCode()" - } - s"isNullAt($i) ? 0 : ($nonNull)" - } - - val hashUpdates: String = hashValues.map( v => - s""" - result *= 37; result += $v;""" - ).mkString("\n") - - val columnChecks = expressions.zipWithIndex.map { case (e, i) => - s""" - if (nullBits[$i] != row.nullBits[$i] || - (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { - return false; - } - """ - }.mkString("\n") - - val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n") - - val code = s""" - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); - } - - class SpecificProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} - - public SpecificProjection($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} - } - - public java.lang.Object apply(java.lang.Object r) { - // GenerateProjection does not work with UnsafeRows. - assert(!(r instanceof ${classOf[UnsafeRow].getName})); - return new SpecificRow((InternalRow) r); - } - - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { - - $columns - - public SpecificRow(InternalRow ${ctx.INPUT_ROW}) { - $initColumns - } - - public int numFields() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } - - public java.lang.Object genericGet(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases - } - return null; - } - public void update(int i, java.lang.Object value) { - if (value == null) { - setNullAt(i); - return; - } - nullBits[i] = false; - switch (i) { - $updateCases - } - } - $specificAccessorFunctions - $specificMutatorFunctions - - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - - public boolean equals(java.lang.Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; - } - return super.equals(other); - } - - public InternalRow copy() { - java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); - } - } - } - """ - - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + - CodeFormatter.format(code)) - - compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b7926bda3de19..13634b69457a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProjection extends Projection {} /** * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cd2ef7dcd0cd3..0c42e2fc7c5e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -38,7 +38,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val futures = (1 to 20).map { _ => future { GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) - GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 465f7d08aa142..f869a96edb1ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -43,7 +43,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } @@ -120,42 +119,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def checkEvaluationWithGeneratedProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { - - val plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) - - val actual = plan(inputRow) - val expectedRow = InternalRow(expected) - - // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our - // interpreted version. - if (actual.hashCode() != expectedRow.hashCode()) { - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |Expressions: $expression - |Code: $evaluated - """.stripMargin) - } - - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail("Incorrect Evaluation in codegen mode: " + - s"$expression, actual: $actual, expected: $expectedRow$input") - } - if (actual.copy() != expectedRow) { - fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") - } - } - protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, @@ -202,7 +165,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 88ed9fdd6465f..4ad65db0977c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 2d3f98dbbd3d1..c9616cdb26c20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -34,12 +34,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance.apply(null).getBoolean(0) === false) } - test("GenerateProjection should initialize expressions") { - val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateProjection.generate(Seq(expr)) - assert(instance.apply(null).getBoolean(0) === false) - } - test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr))() @@ -64,18 +58,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection should not share expression instances") { - val expr1 = MutableExpression() - val instance1 = GenerateProjection.generate(Seq(expr1)) - assert(instance1.apply(null).getBoolean(0) === false) - - val expr2 = MutableExpression() - expr2.mutableState = true - val instance2 = GenerateProjection.generate(Seq(expr2)) - assert(instance1.apply(null).getBoolean(0) === false) - assert(instance2.apply(null).getBoolean(0) === true) - } - test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 2aff156d18b54..85111bd6d1c98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} +import org.apache.spark.sql.catalyst.expressions._ case class ExpandNode( conf: SQLConf, @@ -36,7 +36,7 @@ case class ExpandNode( override def open(): Unit = { child.open() - groups = projections.map(ee => newProjection(ee, child.output)).toArray + groups = projections.map(ee => newMutableProjection(ee, child.output)()).toArray idx = groups.length } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index d3381eac91d43..6a882c9234df4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -103,24 +103,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin result } - protected def newProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema") - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { From 2eb5af5f0d3c424dc617bb1a18dd0210ea9ba0bc Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 16 Dec 2015 10:32:32 -0800 Subject: [PATCH 1253/2089] [SPARK-12318][SPARKR] Save mode in SparkR should be error by default shivaram Please help review. Author: Jeff Zhang Closes #10290 from zjffdu/SPARK-12318. --- R/pkg/R/DataFrame.R | 10 +++++----- docs/sparkr.md | 9 ++++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 764597d1e32b4..380a13fe2b7c6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1886,7 +1886,7 @@ setMethod("except", #' @param df A SparkSQL DataFrame #' @param path A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' #' @family DataFrame functions #' @rdname write.df @@ -1903,7 +1903,7 @@ setMethod("except", #' } setMethod("write.df", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1928,7 +1928,7 @@ setMethod("write.df", #' @export setMethod("saveDF", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ write.df(df, path, source, mode, ...) }) @@ -1951,7 +1951,7 @@ setMethod("saveDF", #' @param df A SparkSQL DataFrame #' @param tableName A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' #' @family DataFrame functions #' @rdname saveAsTable @@ -1968,7 +1968,7 @@ setMethod("saveDF", setMethod("saveAsTable", signature(df = "DataFrame", tableName = "character", source = "character", mode = "character"), - function(df, tableName, source = NULL, mode="append", ...){ + function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", diff --git a/docs/sparkr.md b/docs/sparkr.md index 01148786b79d7..9ddd2eda3fe8b 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -148,7 +148,7 @@ printSchema(people)
    The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example -to a Parquet file using `write.df` +to a Parquet file using `write.df` (Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API)
    {% highlight r %} @@ -387,3 +387,10 @@ The following functions are masked by the SparkR package: Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + + +# Migration Guide + +## Upgrading From SparkR 1.6 to 1.7 + + - Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API. From 22f6cd86fc2e2d6f6ad2c3aae416732c46ebf1b1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 10:34:30 -0800 Subject: [PATCH 1254/2089] [SPARK-12310][SPARKR] Add write.json and write.parquet for SparkR Add ```write.json``` and ```write.parquet``` for SparkR, and deprecated ```saveAsParquetFile```. Author: Yanbo Liang Closes #10281 from yanboliang/spark-12310. --- R/pkg/NAMESPACE | 4 +- R/pkg/R/DataFrame.R | 51 +++++++++-- R/pkg/R/generics.R | 16 +++- R/pkg/inst/tests/testthat/test_sparkSQL.R | 104 ++++++++++++---------- 4 files changed, 119 insertions(+), 56 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index cab39d68c3f52..ccc01fe169601 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -92,7 +92,9 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", - "write.df") + "write.df", + "write.json", + "write.parquet") exportClasses("Column") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 380a13fe2b7c6..0cfa12b997d69 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -596,17 +596,44 @@ setMethod("toJSON", RDD(jrdd, serializedMode = "string") }) -#' saveAsParquetFile +#' write.json +#' +#' Save the contents of a DataFrame as a JSON file (one object per line). Files written out +#' with this method can be read back in as a DataFrame using read.json(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.json +#' @name write.json +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' write.json(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.json", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "json", path)) + }) + +#' write.parquet #' #' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out -#' with this method can be read back in as a DataFrame using parquetFile(). +#' with this method can be read back in as a DataFrame using read.parquet(). #' #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' #' @family DataFrame functions -#' @rdname saveAsParquetFile -#' @name saveAsParquetFile +#' @rdname write.parquet +#' @name write.parquet #' @export #' @examples #'\dontrun{ @@ -614,12 +641,24 @@ setMethod("toJSON", #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- read.json(sqlContext, path) -#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#' write.parquet(df, "/tmp/sparkr-tmp1/") +#' saveAsParquetFile(df, "/tmp/sparkr-tmp2/") #'} +setMethod("write.parquet", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "parquet", path)) + }) + +#' @rdname write.parquet +#' @name saveAsParquetFile +#' @export setMethod("saveAsParquetFile", signature(x = "DataFrame", path = "character"), function(x, path) { - invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + .Deprecated("write.parquet") + write.parquet(x, path) }) #' Distinct diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c383e6e78b8b4..62be2ddc8f522 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -519,10 +519,6 @@ setGeneric("sample_frac", #' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) -#' @rdname saveAsParquetFile -#' @export -setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) - #' @rdname saveAsTable #' @export setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { @@ -541,6 +537,18 @@ setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @export setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) +#' @rdname write.json +#' @export +setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) + +#' @rdname write.parquet +#' @export +setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) + +#' @rdname write.parquet +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + #' @rdname schema #' @export setGeneric("schema", function(x) { standardGeneric("schema") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 071fd310fd58a..135c7576e5291 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -371,22 +371,49 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) }) -test_that("read.json()/jsonFile() on a local file returns a DataFrame", { +test_that("read/write json files", { + # Test read.df + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json df <- read.json(sqlContext, jsonPath) expect_is(df, "DataFrame") expect_equal(count(df), 3) - # read.json()/jsonFile() works with multiple input paths + + # Test write.df jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".json") write.df(df, jsonPath2, "json", mode="overwrite") - jsonDF1 <- read.json(sqlContext, c(jsonPath, jsonPath2)) + + # Test write.json + jsonPath3 <- tempfile(pattern="jsonPath3", fileext=".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(sqlContext, c(jsonPath2, jsonPath3)) expect_is(jsonDF1, "DataFrame") expect_equal(count(jsonDF1), 6) # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath, jsonPath2))) + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath2, jsonPath3))) expect_is(jsonDF2, "DataFrame") expect_equal(count(jsonDF2), 6) unlink(jsonPath2) + unlink(jsonPath3) }) test_that("jsonRDD() on a RDD with json string", { @@ -454,6 +481,9 @@ test_that("insertInto() on a registered table", { expect_equal(count(sql(sqlContext, "select * from table1")), 2) expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") + + unlink(jsonPath2) + unlink(parquetPath2) }) test_that("table() returns a new DataFrame", { @@ -848,33 +878,6 @@ test_that("column calculation", { expect_equal(count(df2), 3) }) -test_that("read.df() from json file", { - df <- read.df(sqlContext, jsonPath, "json") - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - - # Check if we can apply a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_is(df1, "DataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Run the same with loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_is(df2, "DataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) -}) - -test_that("write.df() as parquet file", { - df <- read.df(sqlContext, jsonPath, "json") - write.df(df, parquetPath, "parquet", mode="overwrite") - df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) -}) - test_that("test HiveContext", { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ @@ -895,6 +898,8 @@ test_that("test HiveContext", { df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") expect_equal(count(df3), 3) + + unlink(jsonPath2) }) test_that("column operators", { @@ -1333,6 +1338,9 @@ test_that("join() and merge() on a DataFrame", { expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) + + unlink(jsonPath2) + unlink(jsonPath3) }) test_that("toJSON() returns an RDD of the correct values", { @@ -1396,6 +1404,8 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { # Test base::intersect is working expect_equal(length(intersect(1:20, 3:23)), 18) + + unlink(jsonPath2) }) test_that("withColumn() and withColumnRenamed()", { @@ -1440,31 +1450,35 @@ test_that("mutate(), transform(), rename() and names()", { detach(airquality) }) -test_that("write.df() on DataFrame and works with read.parquet", { - df <- read.json(sqlContext, jsonPath) +test_that("read/write Parquet files", { + df <- read.df(sqlContext, jsonPath, "json") + # Test write.df and read.df write.df(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- read.parquet(sqlContext, parquetPath) - expect_is(parquetDF, "DataFrame") - expect_equal(count(df), count(parquetDF)) -}) + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) -test_that("read.parquet()/parquetFile() works with multiple input paths", { - df <- read.json(sqlContext, jsonPath) - write.df(df, parquetPath, "parquet", mode="overwrite") + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- read.parquet(sqlContext, c(parquetPath, parquetPath2)) + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(sqlContext, c(parquetPath2, parquetPath3)) expect_is(parquetDF, "DataFrame") expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath, parquetPath2)) + parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath2, parquetPath3)) expect_is(parquetDF2, "DataFrame") expect_equal(count(parquetDF2), count(df) * 2) # Test if varargs works with variables saveMode <- "overwrite" mergeSchema <- "true" - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) }) test_that("describe() and summarize() on a DataFrame", { From 26d70bd2b42617ff731b6e9e6d77933b38597ebe Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Dec 2015 10:43:45 -0800 Subject: [PATCH 1255/2089] [SPARK-12215][ML][DOC] User guide section for KMeans in spark.ml cc jkbradley Author: Yu ISHIKAWA Closes #10244 from yu-iskw/SPARK-12215. --- docs/ml-clustering.md | 71 +++++++++++++++++++ .../spark/examples/ml/JavaKMeansExample.java | 8 ++- .../spark/examples/ml/KMeansExample.scala | 49 ++++++------- 3 files changed, 100 insertions(+), 28 deletions(-) diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index a59f7e3005a3e..440c455cd077c 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -11,6 +11,77 @@ In this section, we introduce the pipeline API for [clustering in mllib](mllib-c * This will become a table of contents (this text will be scraped). {:toc} +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a +predefined number of clusters. The MLlib implementation includes a parallelized +variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method +called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). + +`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model. + +### Input Columns + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    featuresColVector"features"Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    predictionColInt"prediction"Predicted cluster center
    + +### Example + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details. + +{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %} +
    + +
    + + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index 47665ff2b1f3c..96481d882a5d5 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -23,6 +23,9 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +// $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; import org.apache.spark.mllib.linalg.Vector; @@ -30,11 +33,10 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +// $example off$ /** @@ -74,6 +76,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); + // $example on$ // Loads data JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; @@ -91,6 +94,7 @@ public static void main(String[] args) { for (Vector center: centers) { System.out.println(center); } + // $example off$ jsc.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index 5ce38462d1181..af90652b55a16 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -17,57 +17,54 @@ package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -import org.apache.spark.ml.clustering.KMeans -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example demonstrating a k-means clustering. * Run with * {{{ - * bin/run-example ml.KMeansExample + * bin/run-example ml.KMeansExample * }}} */ object KMeansExample { - final val FEATURES_COL = "features" - def main(args: Array[String]): Unit = { - if (args.length != 2) { - // scalastyle:off println - System.err.println("Usage: ml.KMeansExample ") - // scalastyle:on println - System.exit(1) - } - val input = args(0) - val k = args(1).toInt - // Creates a Spark context and a SQL context val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) + // $example on$ + // Crates a DataFrame + val dataset: DataFrame = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(0.0, 0.0, 0.0)), + (2, Vectors.dense(0.1, 0.1, 0.1)), + (3, Vectors.dense(0.2, 0.2, 0.2)), + (4, Vectors.dense(9.0, 9.0, 9.0)), + (5, Vectors.dense(9.1, 9.1, 9.1)), + (6, Vectors.dense(9.2, 9.2, 9.2)) + )).toDF("id", "features") // Trains a k-means model val kmeans = new KMeans() - .setK(k) - .setFeaturesCol(FEATURES_COL) + .setK(2) + .setFeaturesCol("features") + .setPredictionCol("prediction") val model = kmeans.fit(dataset) // Shows the result - // scalastyle:off println println("Final Centers: ") model.clusterCenters.foreach(println) - // scalastyle:on println + // $example off$ sc.stop() } } +// scalastyle:on println From ad8c1f0b840284d05da737fb2cc5ebf8848f4490 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Wed, 16 Dec 2015 10:54:15 -0800 Subject: [PATCH 1256/2089] [SPARK-12345][MESOS] Filter SPARK_HOME when submitting Spark jobs with Mesos cluster mode. SPARK_HOME is now causing problem with Mesos cluster mode since spark-submit script has been changed recently to take precendence when running spark-class scripts to look in SPARK_HOME if it's defined. We should skip passing SPARK_HOME from the Spark client in cluster mode with Mesos, since Mesos shouldn't use this configuration but should use spark.executor.home instead. Author: Timothy Chen Closes #10332 from tnachen/scheduler_ui. --- .../apache/spark/deploy/rest/mesos/MesosRestServer.scala | 7 ++++++- .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 868cc35d06ef3..24510db2bd0ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -94,7 +94,12 @@ private[mesos] class MesosSubmitRequestServlet( val driverMemory = sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // We don't want to pass down SPARK_HOME when launching Spark apps + // with Mesos cluster mode since it's populated by default on the client and it will + // cause spark-submit script to look for files in SPARK_HOME instead. + // We only need the ability to specify where to find spark-submit script + // which user can user spark.executor.home or spark.home configurations. + val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 721861fbbc517..573355ba58132 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper - * methods and Mesos scheduler will use. + * methods the Mesos scheduler will use. */ private[mesos] trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered From 7b6dc29d0ebbfb3bb941130f8542120b6bc3e234 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Dec 2015 10:55:42 -0800 Subject: [PATCH 1257/2089] [SPARK-6518][MLLIB][EXAMPLE][DOC] Add example code and user guide for bisecting k-means This PR includes only an example code in order to finish it quickly. I'll send another PR for the docs soon. Author: Yu ISHIKAWA Closes #9952 from yu-iskw/SPARK-6518. --- docs/mllib-clustering.md | 35 ++++++++++ docs/mllib-guide.md | 1 + .../mllib/JavaBisectingKMeansExample.java | 69 +++++++++++++++++++ .../mllib/BisectingKMeansExample.scala | 60 ++++++++++++++++ 4 files changed, 165 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 48d64cd402b11..93cd0c1c61ae9 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -718,6 +718,41 @@ sameModel = LDAModel.load(sc, "myModelPath")
    +## Bisecting k-means + +Bisecting K-means can often be much faster than regular K-means, but it will generally produce a different clustering. + +Bisecting k-means is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering). +Hierarchical clustering is one of the most commonly used method of cluster analysis which seeks to build a hierarchy of clusters. +Strategies for hierarchical clustering generally fall into two types: + +- Agglomerative: This is a "bottom up" approach: each observation starts in its own cluster, and pairs of clusters are merged as one moves up the hierarchy. +- Divisive: This is a "top down" approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. + +Bisecting k-means algorithm is a kind of divisive algorithms. +The implementation in MLlib has the following parameters: + +* *k*: the desired number of leaf clusters (default: 4). The actual number could be smaller if there are no divisible leaf clusters. +* *maxIterations*: the max number of k-means iterations to split clusters (default: 20) +* *minDivisibleClusterSize*: the minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster (default: 1) +* *seed*: a random seed (default: hash value of the class name) + +**Examples** + +
    +
    +Refer to the [`BisectingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeans) and [`BisectingKMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeansModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala %} +
    + +
    +Refer to the [`BisectingKMeans` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeans.html) and [`BisectingKMeansModel` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeansModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java %} +
    +
    + ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 7fef6b5c61f99..680ed4861d9f4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -49,6 +49,7 @@ We list major functionality from both below, with links to detailed guides. * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [bisecting k-means](mllib-clustering.html#bisecting-kmeans) * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java new file mode 100644 index 0000000000000..0001500f4fa5a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +// $example on$ +import com.google.common.collect.Lists; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.clustering.BisectingKMeans; +import org.apache.spark.mllib.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ + +/** + * Java example for graph clustering using power iteration clustering (PIC). + */ +public class JavaBisectingKMeansExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // $example on$ + ArrayList localData = Lists.newArrayList( + Vectors.dense(0.1, 0.1), Vectors.dense(0.3, 0.3), + Vectors.dense(10.1, 10.1), Vectors.dense(10.3, 10.3), + Vectors.dense(20.1, 20.1), Vectors.dense(20.3, 20.3), + Vectors.dense(30.1, 30.1), Vectors.dense(30.3, 30.3) + ); + JavaRDD data = sc.parallelize(localData, 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4); + BisectingKMeansModel model = bkm.run(data); + + System.out.println("Compute Cost: " + model.computeCost(data)); + for (Vector center: model.clusterCenters()) { + System.out.println(""); + } + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala new file mode 100644 index 0000000000000..3a596cccb87d3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +// scalastyle:off println +// $example on$ +import org.apache.spark.mllib.clustering.BisectingKMeans +import org.apache.spark.mllib.linalg.{Vector, Vectors} +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example demonstrating a bisecting k-means clustering in spark.mllib. + * + * Run with + * {{{ + * bin/run-example mllib.BisectingKMeansExample + * }}} + */ +object BisectingKMeansExample { + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("mllib.BisectingKMeansExample") + val sc = new SparkContext(sparkConf) + + // $example on$ + // Loads and parses data + def parse(line: String): Vector = Vectors.dense(line.split(" ").map(_.toDouble)) + val data = sc.textFile("data/mllib/kmeans_data.txt").map(parse).cache() + + // Clustering the data into 6 clusters by BisectingKMeans. + val bkm = new BisectingKMeans().setK(6) + val model = bkm.run(data) + + // Show the compute cost and the cluster centers + println(s"Compute Cost: ${model.computeCost(data)}") + model.clusterCenters.zipWithIndex.foreach { case (center, idx) => + println(s"Cluster Center ${idx}: ${center}") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println From 860dc7f2f8dd01f2562ba83b7af27ba29d91cb62 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 11:05:37 -0800 Subject: [PATCH 1258/2089] [SPARK-9694][ML] Add random seed Param to Scala CrossValidator Add random seed Param to Scala CrossValidator Author: Yanbo Liang Closes #9108 from yanboliang/spark-9694. --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 11 ++++++++--- .../scala/org/apache/spark/mllib/util/MLUtils.scala | 8 ++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5c09f1aaff80d..40f8857fc5866 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams { +private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 @@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema @@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 414ea99cfd8c8..4c9151f0cb4fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -265,6 +265,14 @@ object MLUtils { */ @Since("1.0.0") def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + kFold(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFold()]] taking a Long seed. + */ + @Since("2.0.0") + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, From d252b2d544a75f6c5523be3492494955050acf50 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 11:07:54 -0800 Subject: [PATCH 1259/2089] [SPARK-12309][ML] Use sqlContext from MLlibTestSparkContext for spark.ml test suites Use ```sqlContext``` from ```MLlibTestSparkContext``` rather than creating new one for spark.ml test suites. I have checked thoroughly and found there are four test cases need to update. cc mengxr jkbradley Author: Yanbo Liang Closes #10279 from yanboliang/spark-12309. --- .../scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 4 +--- .../scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 3 +-- .../scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala | 4 +--- mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala | 2 +- .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 3 +-- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 09183fe65b722..035bfc07b684d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -21,13 +21,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.dense(1, 0, Long.MinValue), Vectors.dense(2, 0, 0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index de3d438ce83be..468833901995a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -61,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.sparse(3, Seq()) ) - val sqlContext = new SQLContext(sc) dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 74706a23e0936..8acc3369c489c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -54,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } test("Test vector slicer") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 460849c79f04f..4e2d0e93bd412 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val sqlContext = new SQLContext(data.sparkContext) + val sqlContext = SQLContext.getOrCreate(data.sparkContext) import sqlContext.implicits._ val df = data.toDF() val numFeatures = data.first().features.size diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index dd6366050c020..d281084f913c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType class CrossValidatorSuite @@ -39,7 +39,6 @@ class CrossValidatorSuite override def beforeAll(): Unit = { super.beforeAll() - val sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } From 6a880afa831348b413ba95b98ff089377b950666 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 16 Dec 2015 11:29:47 -0800 Subject: [PATCH 1260/2089] [SPARK-12361][PYSPARK][TESTS] Should set PYSPARK_DRIVER_PYTHON before Python tests Although this patch still doesn't solve the issue why the return code is 0 (see JIRA description), it resolves the issue of python version mismatch. Author: Jeff Zhang Closes #10322 from zjffdu/SPARK-12361. --- python/run-tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/run-tests.py b/python/run-tests.py index f5857f8c62214..ee73eb1506ca4 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -56,7 +56,8 @@ def print_red(text): def run_individual_python_test(test_name, pyspark_python): env = dict(os.environ) - env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python), + 'PYSPARK_DRIVER_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: From 8148cc7a5c9f52c82c2eb7652d9aeba85e72d406 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 16 Dec 2015 11:53:04 -0800 Subject: [PATCH 1261/2089] [SPARK-11608][MLLIB][DOC] Added migration guide for MLlib 1.6 No known breaking changes, but some deprecations and changes of behavior. CC: mengxr Author: Joseph K. Bradley Closes #10235 from jkbradley/mllib-guide-update-1.6. --- docs/mllib-guide.md | 38 ++++++++++++++++++++-------------- docs/mllib-migration-guides.md | 19 +++++++++++++++++ 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 680ed4861d9f4..7ef91a178ccd1 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -74,7 +74,7 @@ We list major functionality from both below, with links to detailed guides. * [Advanced topics](ml-advanced.html) Some techniques are not available yet in spark.ml, most notably dimensionality reduction -Users can seemlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. +Users can seamlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. # Dependencies @@ -101,24 +101,32 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. -## From 1.4 to 1.5 +## From 1.5 to 1.6 -In the `spark.mllib` package, there are no break API changes but several behavior changes: +There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are +deprecations and changes of behavior. -* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): - `RegressionMetrics.explainedVariance` returns the average regression sum of squares. -* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become - sorted. -* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default - convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. +Deprecations: -In the `spark.ml` package, there exists one break API change and one behavior change: +* [SPARK-11358](https://issues.apache.org/jira/browse/SPARK-11358): + In `spark.mllib.clustering.KMeans`, the `runs` parameter has been deprecated. +* [SPARK-10592](https://issues.apache.org/jira/browse/SPARK-10592): + In `spark.ml.classification.LogisticRegressionModel` and + `spark.ml.regression.LinearRegressionModel`, the `weights` field has been deprecated in favor of + the new name `coefficients`. This helps disambiguate from instance (row) "weights" given to + algorithms. -* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed - from `Params.setDefault` due to a - [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). -* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is - added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. +Changes of behavior: + +* [SPARK-7770](https://issues.apache.org/jira/browse/SPARK-7770): + `spark.mllib.tree.GradientBoostedTrees`: `validationTol` has changed semantics in 1.6. + Previously, it was a threshold for absolute change in error. Now, it resembles the behavior of + `GradientDescent`'s `convergenceTol`: For large errors, it uses relative error (relative to the + previous error); for small errors (`< 0.01`), it uses absolute error. +* [SPARK-11069](https://issues.apache.org/jira/browse/SPARK-11069): + `spark.ml.feature.RegexTokenizer`: Previously, it did not convert strings to lowercase before + tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the + behavior of the simpler `Tokenizer` transformer. ## Previous Spark versions diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 73e4fddf67fc0..f3daef2dbadbe 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,25 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.4 to 1.5 + +In the `spark.mllib` package, there are no breaking API changes but several behavior changes: + +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. + +In the `spark.ml` package, there exists one breaking API change and one behavior change: + +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. + ## From 1.3 to 1.4 In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: From 1a8b2a17db7ab7a213d553079b83274aeebba86f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 12:59:22 -0800 Subject: [PATCH 1262/2089] [SPARK-12364][ML][SPARKR] Add ML example for SparkR We have DataFrame example for SparkR, we also need to add ML example under ```examples/src/main/r```. cc mengxr jkbradley shivaram Author: Yanbo Liang Closes #10324 from yanboliang/spark-12364. --- examples/src/main/r/ml.R | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 examples/src/main/r/ml.R diff --git a/examples/src/main/r/ml.R b/examples/src/main/r/ml.R new file mode 100644 index 0000000000000..a0c903939cbbb --- /dev/null +++ b/examples/src/main/r/ml.R @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/sparkR examples/src/main/r/ml.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkContext and SQLContext +sc <- sparkR.init(appName="SparkR-ML-example") +sqlContext <- sparkRSQL.init(sc) + +# Train GLM of family 'gaussian' +training1 <- suppressWarnings(createDataFrame(sqlContext, iris)) +test1 <- training1 +model1 <- glm(Sepal_Length ~ Sepal_Width + Species, training1, family = "gaussian") + +# Model summary +summary(model1) + +# Prediction +predictions1 <- predict(model1, test1) +head(select(predictions1, "Sepal_Length", "prediction")) + +# Train GLM of family 'binomial' +training2 <- filter(training1, training1$Species != "setosa") +test2 <- training2 +model2 <- glm(Species ~ Sepal_Length + Sepal_Width, data = training2, family = "binomial") + +# Model summary +summary(model2) + +# Prediction (Currently the output of prediction for binomial GLM is the indexed label, +# we need to transform back to the original string label later) +predictions2 <- predict(model2, test2) +head(select(predictions2, "Species", "prediction")) + +# Stop the SparkContext now +sparkR.stop() From a783a8ed49814a09fde653433a3d6de398ddf888 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 16 Dec 2015 13:18:56 -0800 Subject: [PATCH 1263/2089] [SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder Author: Wenchen Fan Closes #10293 from cloud-fan/err-msg. --- .../spark/sql/catalyst/dsl/package.scala | 3 +- .../catalyst/encoders/ExpressionEncoder.scala | 36 ++++++++++- .../expressions/complexTypeExtractors.scala | 10 ++-- .../encoders/EncoderResolutionSuite.scala | 60 ++++++++++++++++--- .../expressions/ComplexTypeSuite.scala | 2 +- 5 files changed, 93 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e50971173c499..8102c93c6f107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -227,9 +227,10 @@ package object dsl { AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ - def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) def struct(structType: StructType): AttributeReference = AttributeReference(s, structType, nullable = true)() + def struct(attrs: AttributeReference*): AttributeReference = + struct(StructType.fromAttributes(attrs)) } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 363178b0e21a2..7a4401cf5810e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -244,9 +244,41 @@ case class ExpressionEncoder[T]( def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) + def fail(st: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + + " - Target schema: " + this.schema.simpleString) + } + + var maxOrdinal = -1 + fromRowExpression.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { + fail(StructType.fromAttributes(schema), maxOrdinal) + } + val unbound = fromRowExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) + case b: BoundReference => schema(b.ordinal) + } + + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + unbound.foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } } val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 10ce10aaf6da2..58f6a7ec8a5f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,14 +104,14 @@ object ExtractValue { case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] - override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${name.getOrElse(field.name)}" + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal, field.dataType) + input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 0289988342e78..815a03f7c1a89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -64,22 +64,21 @@ class EncoderResolutionSuite extends PlanTest { val innerCls = classOf[StringLongClass] val cls = classOf[ComplexClass] - val structType = new StructType().add("a", IntegerType).add("b", LongType) - val attrs = Seq('a.int, 'b.struct(structType)) + val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, Seq( 'a.int.cast(LongType), If( - 'b.struct(structType).isNull, + 'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), NewInstance( innerCls, Seq( toExternalString( - GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), - GetStructField('b.struct(structType), 1, Some("b"))), + GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), + GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), false, ObjectType(innerCls)) )), @@ -94,8 +93,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[Long]) val cls = classOf[StringLongClass] - val structType = new StructType().add("a", StringType).add("b", ByteType) - val attrs = Seq('a.struct(structType), 'b.int) + val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( classOf[Tuple2[_, _]], @@ -103,8 +101,8 @@ class EncoderResolutionSuite extends PlanTest { NewInstance( cls, Seq( - toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), - GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), false, ObjectType(cls)), 'b.int.cast(LongType)), @@ -113,6 +111,50 @@ class EncoderResolutionSuite extends PlanTest { compareExpressions(fromRowExpr, expected) } + test("the real number of fields doesn't match encoder schema: tuple encoder") { + val encoder = ExpressionEncoder[(String, Long)] + + { + val attrs = Seq('a.string, 'b.long, 'c.int) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + + { + val attrs = Seq('a.string) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + } + + test("the real number of fields doesn't match encoder schema: nested tuple encoder") { + val encoder = ExpressionEncoder[(String, (Long, String))] + + { + val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + + { + val attrs = Seq('a.string, 'b.struct('x.long)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + } + private def toExternalString(e: Expression): Expression = { Invoke(e, "toString", ObjectType(classOf[String]), Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 62fd47234b33b..9f1b19253e7c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { "b", create_row(Map("a" -> "b"))) checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), 1, create_row(create_row(1))) } From edf65cd961b913ef54104770630a50fd4b120b4b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 16 Dec 2015 13:22:34 -0800 Subject: [PATCH 1264/2089] [SPARK-12164][SQL] Decode the encoded values and then display Based on the suggestions from marmbrus cloud-fan in https://github.com/apache/spark/pull/10165 , this PR is to print the decoded values(user objects) in `Dataset.show` ```scala implicit val kryoEncoder = Encoders.kryo[KryoClassData] val ds = Seq(KryoClassData("a", 1), KryoClassData("b", 2), KryoClassData("c", 3)).toDS() ds.show(20, false); ``` The current output is like ``` +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |value | +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |[1, 0, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 75, 114, 121, 111, 67, 108, 97, 115, 115, 68, 97, 116, -31, 1, 1, -126, 97, 2]| |[1, 0, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 75, 114, 121, 111, 67, 108, 97, 115, 115, 68, 97, 116, -31, 1, 1, -126, 98, 4]| |[1, 0, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 75, 114, 121, 111, 67, 108, 97, 115, 115, 68, 97, 116, -31, 1, 1, -126, 99, 6]| +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` After the fix, it will be like the below if and only if the users override the `toString` function in the class `KryoClassData` ```scala override def toString: String = s"KryoClassData($a, $b)" ``` ``` +-------------------+ |value | +-------------------+ |KryoClassData(a, 1)| |KryoClassData(b, 2)| |KryoClassData(c, 3)| +-------------------+ ``` If users do not override the `toString` function, the results will be like ``` +---------------------------------------+ |value | +---------------------------------------+ |org.apache.spark.sql.KryoClassData68ef| |org.apache.spark.sql.KryoClassData6915| |org.apache.spark.sql.KryoClassData693b| +---------------------------------------+ ``` Question: Should we add another optional parameter in the function `show`? It will decide if the function `show` will display the hex values or the object values? Author: gatorsmile Closes #10215 from gatorsmile/showDecodedValue. --- .../org/apache/spark/sql/DataFrame.scala | 50 +------------- .../scala/org/apache/spark/sql/Dataset.scala | 37 ++++++++++- .../spark/sql/execution/Queryable.scala | 65 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 15 +++++ .../org/apache/spark/sql/DatasetSuite.scala | 14 ++++ 5 files changed, 133 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd48266770..6250e952169d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -165,13 +165,11 @@ class DataFrame private[sql]( * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) - val sb = new StringBuilder val takeResult = take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." @@ -179,6 +177,7 @@ class DataFrame private[sql]( row.toSeq.map { cell => val str = cell match { case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") case array: Array[_] => array.mkString("[", ", ", "]") case seq: Seq[_] => seq.mkString("[", ", ", "]") case _ => cell.toString @@ -187,50 +186,7 @@ class DataFrame private[sql]( }: Seq[String] } - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - - sb.append(sep) - - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - - sb.append(sep) - - // For Data that has more than "numRows" records - if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") - } - - sb.toString() + formatString ( rows, numRows, hasMoreData, truncate ) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dc69822e92908..79b4244ac0cd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -225,7 +225,42 @@ class Dataset[T] private[sql]( * * @since 1.6.0 */ - def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate) + // scalastyle:off println + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + // scalastyle:on println + + /** + * Compose the string representing rows for output + * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right + */ + override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) + + // For array values, replace Seq and Array with square brackets + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map { + case r: Row => r + case tuple: Product => Row.fromTuple(tuple) + case o => Row(o) + } map { row => + row.toSeq.map { cell => + val str = cell match { + case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case _ => cell.toString + } + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str + }: Seq[String] + }) + + formatString ( rows, numRows, hasMoreData, truncate ) + } /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index f2f5997d1b7c6..b397d42612cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import scala.util.control.NonFatal +import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType @@ -42,4 +43,68 @@ private[sql] trait Queryable { def explain(extended: Boolean): Unit def explain(): Unit + + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String + + /** + * Format the string representing rows for output + * @param rows The rows to show + * @param numRows Number of rows to show + * @param hasMoreData Whether some rows are not shown due to the limit + * @param truncate Whether truncate long strings and align cells right + * + */ + private[sql] def formatString ( + rows: Seq[Seq[String]], + numRows: Int, + hasMoreData : Boolean, + truncate: Boolean = true): String = { + val sb = new StringBuilder + val numCols = schema.fieldNames.length + + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + rows.head.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + + sb.append(sep) + + // data + rows.tail.map { + _.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") + } + + sb.toString() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c0bbf73ab1188..0644bdaaa35ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -585,6 +585,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary") { + val df = Seq( + ("12".getBytes, "ABC.".getBytes), + ("34".getBytes, "12346".getBytes) + ).toDF() + val expectedAnswer = """+-------+----------------+ + || _1| _2| + |+-------+----------------+ + ||[31 32]| [41 42 43 2E]| + ||[33 34]|[31 32 33 34 36]| + |+-------+----------------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8f8db318261db..f1b6b98dc160c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -426,6 +426,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } + test("showString: Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + val expectedAnswer = """+-----------+ + || value| + |+-----------+ + ||KryoData(1)| + ||KryoData(2)| + |+-----------+ + |""".stripMargin + assert(ds.showString(10) === expectedAnswer) + } + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() From 9657ee87888422c5596987fe760b49117a0ea4e2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 16 Dec 2015 13:24:49 -0800 Subject: [PATCH 1265/2089] [SPARK-11677][SQL] ORC filter tests all pass if filters are actually not pushed down. Currently ORC filters are not tested properly. All the tests pass even if the filters are not pushed down or disabled. In this PR, I add some logics for this. Since ORC does not filter record by record fully, this checks the count of the result and if it contains the expected values. Author: hyukjinkwon Closes #9687 from HyukjinKwon/SPARK-11677. --- .../spark/sql/hive/orc/OrcQuerySuite.scala | 53 +++++++++++++------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 7efeab528c1dd..2156806d21f96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -350,28 +350,47 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { import testImplicits._ - val path = dir.getCanonicalPath - sqlContext.range(10).coalesce(1).write.orc(path) + + // For field "a", the first column has odds integers. This is to check the filtered count + // when `isNull` is performed. For Field "b", `isNotNull` of ORC file filters rows + // only when all the values are null (maybe this works differently when the data + // or query is complicated). So, simply here a column only having `null` is added. + val data = (0 until 10).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + val nullValue: Option[String] = None + (maybeInt, nullValue) + } + createDataFrame(data).toDF("a", "b").write.orc(path) val df = sqlContext.read.orc(path) - def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { - checkAnswer(df.where(pred), answer.map(Row(_))) + def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { + val sourceDf = stripSparkFilter(df.where(pred)) + val data = sourceDf.collect().toSet + val expectedData = answer.toSet + + // When a filter is pushed to ORC, ORC can apply it to rows. So, we can check + // the number of rows returned from the ORC to make sure our filter pushdown work. + // A tricky part is, ORC does not process filter rows fully but return some possible + // results. So, this checks if the number of result is less than the original count + // of data, and then checks if it contains the expected data. + val isOrcFiltered = sourceDf.count < 10 && expectedData.subsetOf(data) + assert(isOrcFiltered) } - checkPredicate('id === 5, Seq(5L)) - checkPredicate('id <=> 5, Seq(5L)) - checkPredicate('id < 5, 0L to 4L) - checkPredicate('id <= 5, 0L to 5L) - checkPredicate('id > 5, 6L to 9L) - checkPredicate('id >= 5, 5L to 9L) - checkPredicate('id.isNull, Seq.empty[Long]) - checkPredicate('id.isNotNull, 0L to 9L) - checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) - checkPredicate('id > 0 && 'id < 3, 1L to 2L) - checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) - checkPredicate(!('id > 3), 0L to 3L) - checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + checkPredicate('a === 5, List(5).map(Row(_, null))) + checkPredicate('a <=> 5, List(5).map(Row(_, null))) + checkPredicate('a < 5, List(1, 3).map(Row(_, null))) + checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate('a > 5, List(7, 9).map(Row(_, null))) + checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate('a.isNull, List(null).map(Row(_, null))) + checkPredicate('b.isNotNull, List()) + checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) + checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) + checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) } } } From 3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb Mon Sep 17 00:00:00 2001 From: Martin Menestret Date: Wed, 16 Dec 2015 14:05:35 -0800 Subject: [PATCH 1266/2089] [SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed Extend CrossValidator with HasSeed in PySpark. This PR replaces [https://github.com/apache/spark/pull/7997] CC: yanboliang thunterdb mmenestret Would one of you mind taking a look? Thanks! Author: Joseph K. Bradley Author: Martin MENESTRET Closes #10268 from jkbradley/pyspark-cv-seed. --- python/pyspark/ml/tuning.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 705ee53685752..08f8db57f4400 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -19,8 +19,9 @@ import numpy as np from pyspark import since -from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model +from pyspark.ml.param import Params, Param +from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -89,7 +90,7 @@ def build(self): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class CrossValidator(Estimator, HasSeed): """ K-fold cross validation. @@ -129,9 +130,11 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) """ super(CrossValidator, self).__init__() #: param for estimator to be cross-validated @@ -151,9 +154,11 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): Sets params for cross validator. """ kwargs = self.setParams._input_kwargs @@ -225,9 +230,10 @@ def _fit(self, dataset): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h From 27b98e99d21a0cc34955337f82a71a18f9220ab2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 15:48:11 -0800 Subject: [PATCH 1267/2089] [SPARK-12380] [PYSPARK] use SQLContext.getOrCreate in mllib MLlib should use SQLContext.getOrCreate() instead of creating new SQLContext. Author: Davies Liu Closes #10338 from davies/create_context. --- python/pyspark/mllib/common.py | 6 +++--- python/pyspark/mllib/evaluation.py | 10 +++++----- python/pyspark/mllib/feature.py | 4 +--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index a439a488de5cc..9fda1b1682f57 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -102,7 +102,7 @@ def _java2py(sc, r, encoding="bytes"): return RDD(jrdd, sc) if clsName == 'DataFrame': - return DataFrame(r, SQLContext(sc)) + return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) @@ -125,7 +125,7 @@ def callJavaFunc(sc, func, *args): def callMLlibFunc(name, *args): """ Call API in PythonMLLibAPI """ - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() api = getattr(sc._jvm.PythonMLLibAPI(), name) return callJavaFunc(sc, api, *args) @@ -135,7 +135,7 @@ class JavaModelWrapper(object): Wrapper for the model in JVM """ def __init__(self, java_model): - self._sc = SparkContext._active_spark_context + self._sc = SparkContext.getOrCreate() self._java_model = java_model def __del__(self): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 8c87ee9df2132..22e68ea5b4511 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -44,7 +44,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): def __init__(self, scoreAndLabels): sc = scoreAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ StructField("score", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -103,7 +103,7 @@ class RegressionMetrics(JavaModelWrapper): def __init__(self, predictionAndObservations): sc = predictionAndObservations.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("observation", DoubleType(), nullable=False)])) @@ -197,7 +197,7 @@ class MulticlassMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -338,7 +338,7 @@ class RankingMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_model = callMLlibFunc("newRankingMetrics", df._jdf) @@ -424,7 +424,7 @@ class MultilabelMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7254679ebb533..acd7ec57d69da 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext, since +from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -100,8 +100,6 @@ def transform(self, vector): :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ - sc = SparkContext._active_spark_context - assert sc is not None, "SparkContext should be initialized first" if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: From 861549acdbc11920cde51fc57752a8bc241064e5 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 16 Dec 2015 16:13:48 -0800 Subject: [PATCH 1268/2089] [MINOR] Add missing interpolation in NettyRPCEnv ``` Exception in thread "main" org.apache.spark.rpc.RpcTimeoutException: Cannot receive any reply in ${timeout.duration}. This timeout is controlled by spark.rpc.askTimeout at org.apache.spark.rpc.RpcTimeout.org$apache$spark$rpc$RpcTimeout$$createRpcTimeoutException(RpcTimeout.scala:48) at org.apache.spark.rpc.RpcTimeout$$anonfun$addMessageIfTimeout$1.applyOrElse(RpcTimeout.scala:63) at org.apache.spark.rpc.RpcTimeout$$anonfun$addMessageIfTimeout$1.applyOrElse(RpcTimeout.scala:59) at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:33) ``` Author: Andrew Or Closes #10334 from andrewor14/rpc-typo. --- .../src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f82fd4eb5756d..de3db6ba624f8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -232,7 +232,7 @@ private[netty] class NettyRpcEnv( val timeoutCancelable = timeoutScheduler.schedule(new Runnable { override def run(): Unit = { promise.tryFailure( - new TimeoutException("Cannot receive any reply in ${timeout.duration}")) + new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) promise.future.onComplete { v => From ce5fd4008e890ef8ebc2d3cb703a666783ad6c02 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 16 Dec 2015 17:05:57 -0800 Subject: [PATCH 1269/2089] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1217 (requested by ankurdave, srowen) Closes #4650 (requested by andrewor14) Closes #5307 (requested by vanzin) Closes #5664 (requested by andrewor14) Closes #5713 (requested by marmbrus) Closes #5722 (requested by andrewor14) Closes #6685 (requested by srowen) Closes #7074 (requested by srowen) Closes #7119 (requested by andrewor14) Closes #7997 (requested by jkbradley) Closes #8292 (requested by srowen) Closes #8975 (requested by andrewor14, vanzin) Closes #8980 (requested by andrewor14, davies) From 38d9795a4fa07086d65ff705ce86648345618736 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 Dec 2015 19:01:05 -0800 Subject: [PATCH 1270/2089] [SPARK-10248][CORE] track exceptions in dagscheduler event loop in tests `DAGSchedulerEventLoop` normally only logs errors (so it can continue to process more events, from other jobs). However, this is not desirable in the tests -- the tests should be able to easily detect any exception, and also shouldn't silently succeed if there is an exception. This was suggested by mateiz on https://github.com/apache/spark/pull/7699. It may have already turned up an issue in "zero split job". Author: Imran Rashid Closes #8466 from squito/SPARK-10248. --- .../apache/spark/scheduler/DAGScheduler.scala | 5 ++-- .../spark/scheduler/DAGSchedulerSuite.scala | 28 +++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8d0e0c8624a55..b128ed50cad52 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -805,7 +805,8 @@ class DAGScheduler( private[scheduler] def cleanUpAfterSchedulerStop() { for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") + val error = + new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down") job.listener.jobFailed(error) // Tell the listeners that all of the running stages have ended. Don't bother // cancelling the stages because if the DAG scheduler is stopped, the entire application @@ -1295,7 +1296,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case other => + case _: ExecutorLostFailure | TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 653d41fc053c9..2869f0fde4c53 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -45,6 +45,13 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) case NonFatal(e) => onError(e) } } + + override def onError(e: Throwable): Unit = { + logError("Error in DAGSchedulerEventLoop: ", e) + dagScheduler.stop() + throw e + } + } /** @@ -300,13 +307,18 @@ class DAGSchedulerSuite test("zero split job") { var numResults = 0 + var failureReason: Option[Exception] = None val fakeListener = new JobListener() { - override def taskSucceeded(partition: Int, value: Any) = numResults += 1 - override def jobFailed(exception: Exception) = throw exception + override def taskSucceeded(partition: Int, value: Any): Unit = numResults += 1 + override def jobFailed(exception: Exception): Unit = { + failureReason = Some(exception) + } } val jobId = submit(new MyRDD(sc, 0, Nil), Array(), listener = fakeListener) assert(numResults === 0) cancel(jobId) + assert(failureReason.isDefined) + assert(failureReason.get.getMessage() === "Job 0 cancelled ") } test("run trivial job") { @@ -1675,6 +1687,18 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("catch errors in event loop") { + // this is a test of our testing framework -- make sure errors in event loop don't get ignored + + // just run some bad event that will throw an exception -- we'll give a null TaskEndReason + val rdd1 = new MyRDD(sc, 1, Nil) + submit(rdd1, Array(0)) + intercept[Exception] { + complete(taskSets(0), Seq( + (null, makeMapStatus("hostA", 1)))) + } + } + test("simple map stage submission") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) From f590178d7a06221a93286757c68b23919bee9f03 Mon Sep 17 00:00:00 2001 From: tedyu Date: Wed, 16 Dec 2015 19:02:12 -0800 Subject: [PATCH 1271/2089] [SPARK-12365][CORE] Use ShutdownHookManager where Runtime.getRuntime.addShutdownHook() is called SPARK-9886 fixed ExternalBlockStore.scala This PR fixes the remaining references to Runtime.getRuntime.addShutdownHook() Author: tedyu Closes #10325 from ted-yu/master. --- .../spark/deploy/ExternalShuffleService.scala | 18 +++++--------- .../deploy/mesos/MesosClusterDispatcher.scala | 13 ++++------ .../spark/util/ShutdownHookManager.scala | 4 ++++ scalastyle-config.xml | 12 ++++++++++ .../hive/thriftserver/SparkSQLCLIDriver.scala | 24 +++++++++---------- 5 files changed, 38 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index e8a1e35c3fc48..7fc96e4f764b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -28,7 +28,7 @@ import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -118,19 +118,13 @@ object ExternalShuffleService extends Logging { server = newShuffleService(sparkConf, securityManager) server.start() - installShutdownHook() + ShutdownHookManager.addShutdownHook { () => + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } // keep running until the process is terminated barrier.await() } - - private def installShutdownHook(): Unit = { - Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { - override def run() { - logInfo("Shutting down shuffle service.") - server.stop() - barrier.countDown() - } - }) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 5d4e5b899dfdc..389eff5e0645b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{ShutdownHookManager, SignalLogger} import org.apache.spark.{Logging, SecurityManager, SparkConf} /* @@ -103,14 +103,11 @@ private[mesos] object MesosClusterDispatcher extends Logging { } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() - val shutdownHook = new Thread() { - override def run() { - logInfo("Shutdown hook is shutting down dispatcher") - dispatcher.stop() - dispatcher.awaitShutdown() - } + ShutdownHookManager.addShutdownHook { () => + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() } - Runtime.getRuntime.addShutdownHook(shutdownHook) dispatcher.awaitShutdown() } } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 620f226a23e15..1a0f3b477ba3f 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -162,7 +162,9 @@ private[spark] object ShutdownHookManager extends Logging { val hook = new Thread { override def run() {} } + // scalastyle:off runtimeaddshutdownhook Runtime.getRuntime.addShutdownHook(hook) + // scalastyle:on runtimeaddshutdownhook Runtime.getRuntime.removeShutdownHook(hook) } catch { case ise: IllegalStateException => return true @@ -228,7 +230,9 @@ private [util] class SparkShutdownHookManager { .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) case Failure(_) => + // scalastyle:off runtimeaddshutdownhook Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + // scalastyle:on runtimeaddshutdownhook } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index dab1ebddc666e..6925e18737b75 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -157,6 +157,18 @@ This file is divided into 3 sections: ]]> + + Runtime\.getRuntime\.addShutdownHook + + + Class\.forName - try { - h.flush() - } catch { - case e: IOException => - logWarning("WARNING: Failed to write command history file: " + e.getMessage) - } - case _ => - } + ShutdownHookManager.addShutdownHook { () => + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => } - })) + } // TODO: missing /* From fdb38227564c1af40cbfb97df420b23eb04c002b Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Wed, 16 Dec 2015 19:04:33 -0800 Subject: [PATCH 1272/2089] [SPARK-12186][WEB UI] Send the complete request URI including the query string when redirecting. Author: Rohit Agarwal Closes #10180 from mindprince/SPARK-12186. --- .../scala/org/apache/spark/deploy/history/HistoryServer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d4f327cc588fe..f31fef0eccc3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -103,7 +103,9 @@ class HistoryServer( // Note we don't use the UI retrieved from the cache; the cache loader above will register // the app's UI, and all we need to do is redirect the user to the same URI that was // requested, and the proper data should be served at that point. - res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) + // Also, make sure that the redirect url contains the query string present in the request. + val requestURI = req.getRequestURI + Option(req.getQueryString).map("?" + _).getOrElse("") + res.sendRedirect(res.encodeRedirectURL(requestURI)) } // SPARK-5983 ensure TRACE is not supported From d1508dd9b765489913bc948575a69ebab82f217b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 16 Dec 2015 19:47:49 -0800 Subject: [PATCH 1273/2089] [SPARK-12386][CORE] Fix NPE when spark.executor.port is set. Author: Marcelo Vanzin Closes #10339 from vanzin/SPARK-12386. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 84230e32a4462..52acde1b414eb 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -256,7 +256,12 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { - val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 + val actorSystemPort = + if (port == 0 || rpcEnv.address == null) { + port + } else { + rpcEnv.address.port + 1 + } // Create a ActorSystem for legacy codes AkkaUtils.createActorSystem( actorSystemName + "ActorSystem", From 97678edeaaafc19ea18d044233a952d2e2e89fbc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 16 Dec 2015 20:01:47 -0800 Subject: [PATCH 1274/2089] [SPARK-12390] Clean up unused serializer parameter in BlockManager No change in functionality is intended. This only changes internal API. Author: Andrew Or Closes #10343 from andrewor14/clean-bm-serializer. --- .../apache/spark/storage/BlockManager.scala | 29 +++++++------------ .../org/apache/spark/storage/DiskStore.scala | 10 ------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 540e1ec003a2b..6074fc58d70db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1190,20 +1190,16 @@ private[spark] class BlockManager( def dataSerializeStream( blockId: BlockId, outputStream: OutputStream, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): Unit = { + values: Iterator[Any]): Unit = { val byteStream = new BufferedOutputStream(outputStream) - val ser = serializer.newInstance() + val ser = defaultSerializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a byte buffer. */ - def dataSerialize( - blockId: BlockId, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { + def dataSerialize(blockId: BlockId, values: Iterator[Any]): ByteBuffer = { val byteStream = new ByteBufferOutputStream(4096) - dataSerializeStream(blockId, byteStream, values, serializer) + dataSerializeStream(blockId, byteStream, values) byteStream.toByteBuffer } @@ -1211,24 +1207,21 @@ private[spark] class BlockManager( * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserialize( - blockId: BlockId, - bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + def dataDeserialize(blockId: BlockId, bytes: ByteBuffer): Iterator[Any] = { bytes.rewind() - dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true)) } /** * Deserializes a InputStream into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserializeStream( - blockId: BlockId, - inputStream: InputStream, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + def dataDeserializeStream(blockId: BlockId, inputStream: InputStream): Iterator[Any] = { val stream = new BufferedInputStream(inputStream) - serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator + defaultSerializer + .newInstance() + .deserializeStream(wrapForCompression(blockId, stream)) + .asIterator } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index c008b9dc16327..6c4477184d5b4 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -144,16 +144,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } - /** - * A version of getValues that allows a custom serializer. This is used as part of the - * shuffle short-circuit code. - */ - def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - // TODO: Should bypass getBytes and use a stream based implementation, so that - // we won't use a lot of memory during e.g. external sort merge. - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) - } - override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) if (file.exists()) { From 437583f692e30b8dc03b339a34e92595d7b992ba Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Wed, 16 Dec 2015 22:10:24 -0800 Subject: [PATCH 1275/2089] [SPARK-11904][PYSPARK] reduceByKeyAndWindow does not require checkpointing when invFunc is None when invFunc is None, `reduceByKeyAndWindow(func, None, winsize, slidesize)` is equivalent to reduceByKey(func).window(winsize, slidesize).reduceByKey(winsize, slidesize) and no checkpoint is necessary. The corresponding Scala code does exactly that, but Python code always creates a windowed stream with obligatory checkpointing. The patch fixes this. I do not know how to unit-test this. Author: David Tolpin Closes #9888 from dtolpin/master. --- python/pyspark/streaming/dstream.py | 45 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index f61137cb88c47..b994a53bf2b85 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -542,31 +542,32 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None reduced = self.reduceByKey(func, numPartitions) - def reduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - r = a.union(b).reduceByKey(func, numPartitions) if a else b - if filterFunc: - r = r.filter(filterFunc) - return r - - def invReduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) - if kv[1] is not None else kv[0]) - - jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invFunc: + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream( + reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) else: - jinvReduceFunc = None - if slideDuration is None: - slideDuration = self._slideDuration - dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), - jreduceFunc, jinvReduceFunc, - self._ssc._jduration(windowDuration), - self._ssc._jduration(slideDuration)) - return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + return reduced.window(windowDuration, slideDuration).reduceByKey(func, numPartitions) def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ From 9d66c4216ad830812848c657bbcd8cd50949e199 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 16 Dec 2015 23:18:53 -0800 Subject: [PATCH 1276/2089] [SPARK-12057][SQL] Prevent failure on corrupt JSON records This PR makes JSON parser and schema inference handle more cases where we have unparsed records. It is based on #10043. The last commit fixes the failed test and updates the logic of schema inference. Regarding the schema inference change, if we have something like ``` {"f1":1} [1,2,3] ``` originally, we will get a DF without any column. After this change, we will get a DF with columns `f1` and `_corrupt_record`. Basically, for the second row, `[1,2,3]` will be the value of `_corrupt_record`. When merge this PR, please make sure that the author is simplyianm. JIRA: https://issues.apache.org/jira/browse/SPARK-12057 Closes #10043 Author: Ian Macalinao Author: Yin Huai Closes #10288 from yhuai/handleCorruptJson. --- .../datasources/json/InferSchema.scala | 37 +++++++++++++++++-- .../datasources/json/JacksonParser.scala | 19 ++++++---- .../datasources/json/JsonSuite.scala | 37 +++++++++++++++++++ .../datasources/json/TestJsonData.scala | 9 ++++- 4 files changed, 90 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 922fd5b21167b..59ba4ae2cba0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -61,7 +61,10 @@ private[json] object InferSchema { StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + }.treeAggregate[DataType]( + StructType(Seq()))( + compatibleRootType(columnNameOfCorruptRecords), + compatibleRootType(columnNameOfCorruptRecords)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -170,12 +173,38 @@ private[json] object InferSchema { case other => Some(other) } + private def withCorruptField( + struct: StructType, + columnNameOfCorruptRecords: String): StructType = { + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + struct.add(columnNameOfCorruptRecords, StringType, nullable = true) + } else { + // Otherwise, just return this struct. + struct + } + } + /** * Remove top-level ArrayType wrappers and merge the remaining schemas */ - private def compatibleRootType: (DataType, DataType) => DataType = { - case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2) + private def compatibleRootType( + columnNameOfCorruptRecords: String): (DataType, DataType) => DataType = { + // Since we support array of json objects at the top level, + // we need to check the element type and find the root level data type. + case (ArrayType(ty1, _), ty2) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + case (struct: StructType, NullType) => struct + case (NullType, struct: StructType) => struct + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, columnNameOfCorruptRecords) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, columnNameOfCorruptRecords) + // If we get anything else, we call compatibleType. + // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index bfa1405041058..55a1c24e9e000 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils +private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) + object JacksonParser { def parse( @@ -110,7 +112,7 @@ object JacksonParser { lowerCaseValue.equals("-inf")) { value.toFloat } else { - sys.error(s"Cannot parse $value as FloatType.") + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => @@ -127,7 +129,7 @@ object JacksonParser { lowerCaseValue.equals("-inf")) { value.toDouble } else { - sys.error(s"Cannot parse $value as DoubleType.") + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) => @@ -174,7 +176,11 @@ object JacksonParser { convertField(factory, parser, udt.sqlType) case (token, dataType) => - sys.error(s"Failed to parse a value for data type $dataType (current token: $token).") + // We cannot parse this token based on the given data type. So, we throw a + // SparkSQLJsonProcessingException and this exception will be caught by + // parseJson method. + throw new SparkSQLJsonProcessingException( + s"Failed to parse a value for data type $dataType (current token: $token).") } } @@ -267,15 +273,14 @@ object JacksonParser { array.toArray[InternalRow](schema) } case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of " + - "the file (or each string in the RDD) is a valid JSON object or " + - "an array of JSON objects.") + failedRecord(record) } } } catch { case _: JsonProcessingException => failedRecord(record) + case _: SparkSQLJsonProcessingException => + failedRecord(record) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index ba7718c864637..baa258ad26152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1427,4 +1427,41 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } } + + test("SPARK-12057 additional corrupt records do not throw exceptions") { + // Test if we can query corrupt records. + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("dummy", StringType, true) :: Nil) + + { + // We need to make sure we can infer the schema. + val jsonDF = sqlContext.read.json(additionalCorruptRecords) + assert(jsonDF.schema === schema) + } + + { + val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) + jsonDF.registerTempTable("jsonTable") + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT dummy, _unparsed + |FROM jsonTable + """.stripMargin), + Row("test", null) :: + Row(null, """[1,2,3]""") :: + Row(null, """":"test", "a":1}""") :: + Row(null, """42""") :: + Row(null, """ ","ian":"test"}""") :: Nil + ) + } + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 713d1da1cb515..cb61f7eeca0de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -188,6 +188,14 @@ private[json] trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + def additionalCorruptRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"dummy":"test"}""" :: + """[1,2,3]""" :: + """":"test", "a":1}""" :: + """42""" :: + """ ","ian":"test"}""" :: Nil) + def emptyRecords: RDD[String] = sqlContext.sparkContext.parallelize( """{""" :: @@ -197,7 +205,6 @@ private[json] trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) From 5a514b61bbfb609c505d8d65f2483068a56f1f70 Mon Sep 17 00:00:00 2001 From: echo2mei <534384876@qq.com> Date: Thu, 17 Dec 2015 07:59:17 -0800 Subject: [PATCH 1277/2089] Once driver register successfully, stop it to connect to master. This commit is to resolve SPARK-12396. Author: echo2mei <534384876@qq.com> Closes #10354 from echoTomei/master. --- .../main/scala/org/apache/spark/deploy/client/AppClient.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 1e2f469214b84..3cf7464a15615 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -130,6 +130,7 @@ private[spark] class AppClient( if (registered.get) { registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() + registrationRetryTimer.cancel(true) } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { From cd3d937b0cc89cb5e4098f7d9f5db2712e3de71e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 08:01:27 -0800 Subject: [PATCH 1278/2089] Revert "Once driver register successfully, stop it to connect to master." This reverts commit 5a514b61bbfb609c505d8d65f2483068a56f1f70. --- .../main/scala/org/apache/spark/deploy/client/AppClient.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 3cf7464a15615..1e2f469214b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -130,7 +130,6 @@ private[spark] class AppClient( if (registered.get) { registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() - registrationRetryTimer.cancel(true) } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { From a170d34a1b309fecc76d1370063e0c4f44dc2142 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 08:04:11 -0800 Subject: [PATCH 1279/2089] [SPARK-12395] [SQL] fix resulting columns of outer join For API DataFrame.join(right, usingColumns, joinType), if the joinType is right_outer or full_outer, the resulting join columns could be wrong (will be null). The order of columns had been changed to match that with MySQL and PostgreSQL [1]. This PR also fix the nullability of output for outer join. [1] http://www.postgresql.org/docs/9.2/static/queries-table-expressions.html Author: Davies Liu Closes #10353 from davies/fix_join. --- .../org/apache/spark/sql/DataFrame.scala | 25 +++++++++++++++---- .../apache/spark/sql/DataFrameJoinSuite.scala | 20 ++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6250e952169d7..d74131231499d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} @@ -455,10 +455,8 @@ class DataFrame private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join] - // Project only one of the join columns. - val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col)) val condition = usingColumns.map { col => catalyst.expressions.EqualTo( withPlan(joined.left).resolve(col), @@ -467,9 +465,26 @@ class DataFrame private[sql]( catalyst.expressions.And(cond, eqTo) } + // Project only one of the join columns. + val joinedCols = JoinType(joinType) match { + case Inner | LeftOuter | LeftSemi => + usingColumns.map(col => withPlan(joined.left).resolve(col)) + case RightOuter => + usingColumns.map(col => withPlan(joined.right).resolve(col)) + case FullOuter => + usingColumns.map { col => + val leftCol = withPlan(joined.left).resolve(col) + val rightCol = withPlan(joined.right).resolve(col) + Alias(Coalesce(Seq(leftCol, rightCol)), col)() + } + } + // The nullability of output of joined could be different than original column, + // so we can only compare them by exprId + val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil) + val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId)) withPlan { Project( - joined.output.filterNot(joinedCols.contains(_)), + resultCols, Join( joined.left, joined.right, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c70397f9853ae..39a65413bd592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -43,16 +43,28 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } test("join - join using multiple columns and specifying join type") { - val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") - val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "inner"), + Row(1, "1", 2, 3) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "left"), - Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "right"), - Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "outer"), + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "left_semi"), + Row(1, "1", 2) :: Nil) } test("join - join using self join") { From 6e0771665b3c9330fc0a5b2c7740a796b4cd712e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Dec 2015 09:19:46 -0800 Subject: [PATCH 1280/2089] [SQL] Update SQLContext.read.text doc Since we rename the column name from ```text``` to ```value``` for DataFrame load by ```SQLContext.read.text```, we need to update doc. Author: Yanbo Liang Closes #10349 from yanboliang/text-value. --- python/pyspark/sql/readwriter.py | 2 +- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2 +- .../spark/sql/execution/datasources/text/DefaultSource.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 2e75f0c8a1827..a3d7eca04b616 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -207,7 +207,7 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) def text(self, paths): - """Loads a text file and returns a [[DataFrame]] with a single string column named "text". + """Loads a text file and returns a [[DataFrame]] with a single string column named "value". Each line in the text file is a new row in the resulting DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3ed1e55adec6d..c1a8f19313a7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -339,7 +339,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { } /** - * Loads a text file and returns a [[DataFrame]] with a single string column named "text". + * Loads a text file and returns a [[DataFrame]] with a single string column named "value". * Each line in the text file is a new row in the resulting DataFrame. For example: * {{{ * // Scala: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index fbd387bc2ef47..4a1cbe4c38fa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -76,7 +76,7 @@ private[sql] class TextRelation( (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - /** Data schema is always a single column, named "text". */ + /** Data schema is always a single column, named "value". */ override def dataSchema: StructType = new StructType().add("value", StringType) /** This is an internal data source that outputs internal row format. */ From 86e405f357711ae93935853a912bc13985c259db Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 17 Dec 2015 09:55:37 -0800 Subject: [PATCH 1281/2089] [SPARK-12220][CORE] Make Utils.fetchFile support files that contain special characters This PR encodes and decodes the file name to fix the issue. Author: Shixiong Zhu Closes #10208 from zsxwing/uri. --- .../org/apache/spark/HttpFileServer.scala | 6 ++--- .../spark/rpc/netty/NettyStreamManager.scala | 5 ++-- .../scala/org/apache/spark/util/Utils.scala | 26 ++++++++++++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +++ .../org/apache/spark/util/UtilsSuite.scala | 11 ++++++++ 5 files changed, 46 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 77d8ec9bb1607..46f9f9e9af7da 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -63,12 +63,12 @@ private[spark] class HttpFileServer( def addFile(file: File) : String = { addFileToDir(file, fileDir) - serverUri + "/files/" + file.getName + serverUri + "/files/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addJar(file: File) : String = { addFileToDir(file, jarDir) - serverUri + "/jars/" + file.getName + serverUri + "/jars/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addDirectory(path: String, resourceBase: String): String = { @@ -85,7 +85,7 @@ private[spark] class HttpFileServer( throw new IllegalArgumentException(s"$file cannot be a directory.") } Files.copy(file, new File(dir, file.getName)) - dir + "/" + file.getName + dir + "/" + Utils.encodeFileNameToURIRawPath(file.getName) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index ecd96972455d0..394cde4fa076d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc.RpcEnvFileServer +import org.apache.spark.util.Utils /** * StreamManager implementation for serving files from a NettyRpcEnv. @@ -64,13 +65,13 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) override def addFile(file: File): String = { require(files.putIfAbsent(file.getName(), file) == null, s"File ${file.getName()} already registered.") - s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addJar(file: File): String = { require(jars.putIfAbsent(file.getName(), file) == null, s"JAR ${file.getName()} already registered.") - s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addDirectory(baseUri: String, path: File): String = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9dbe66e7eefbd..fce89dfccfe23 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -331,6 +331,30 @@ private[spark] object Utils extends Logging { } /** + * A file name may contain some invalid URI characters, such as " ". This method will convert the + * file name to a raw path accepted by `java.net.URI(String)`. + * + * Note: the file name must not contain "/" or "\" + */ + def encodeFileNameToURIRawPath(fileName: String): String = { + require(!fileName.contains("/") && !fileName.contains("\\")) + // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as + // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. + // We should remove it after we get the raw path. + new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1) + } + + /** + * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/", + * return the name before the last "/". + */ + def decodeFileNameInURI(uri: URI): String = { + val rawPath = uri.getRawPath + val rawFileName = rawPath.split("/").last + new URI("file:///" + rawFileName).getPath.substring(1) + } + + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -351,7 +375,7 @@ private[spark] object Utils extends Logging { hadoopConf: Configuration, timestamp: Long, useCache: Boolean) { - val fileName = url.split("/").last + val fileName = decodeFileNameInURI(new URI(url)) val targetFile = new File(targetDir, fileName) val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) if (useCache && fetchCacheEnabled) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6d153eb04e04f..49e3e0191c38a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -771,6 +771,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val fileWithSpecialChars = new File(tempDir, "file name") + Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) val empty = new File(tempDir, "empty") Files.write("", empty, UTF_8); val jar = new File(tempDir, "jar") @@ -787,6 +789,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) val fileUri = env.fileServer.addFile(file) + val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) @@ -805,6 +808,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val files = Seq( (file, fileUri), + (fileWithSpecialChars, fileWithSpecialCharsUri), (empty, emptyUri), (jar, jarUri), (subFile1, dir1Uri + "/file1"), diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 68b0da76bc134..fdb51d440eff6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -734,4 +734,15 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "0")) === true) } + test("encodeFileNameToURIRawPath") { + assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") + assert(Utils.encodeFileNameToURIRawPath("abc xyz") === "abc%20xyz") + assert(Utils.encodeFileNameToURIRawPath("abc:xyz") === "abc:xyz") + } + + test("decodeFileNameInURI") { + assert(Utils.decodeFileNameInURI(new URI("files:///abc/xyz")) === "xyz") + assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") + assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") + } } From 8184568810e8a2e7d5371db2c6a0366ef4841f70 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 18 Dec 2015 03:19:31 +0900 Subject: [PATCH 1282/2089] [SPARK-12345][MESOS] Properly filter out SPARK_HOME in the Mesos REST server Fix problem with #10332, this one should fix Cluster mode on Mesos Author: Iulian Dragos Closes #10359 from dragos/issue/fix-spark-12345-one-more-time. --- .../org/apache/spark/deploy/rest/mesos/MesosRestServer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 24510db2bd0ba..c0b93596508f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -99,7 +99,7 @@ private[mesos] class MesosSubmitRequestServlet( // cause spark-submit script to look for files in SPARK_HOME instead. // We only need the ability to specify where to find spark-submit script // which user can user spark.executor.home or spark.home configurations. - val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) + val environmentVariables = request.environmentVariables.filterKeys(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description From 540b5aeadc84d1a5d61bda4414abd6bf35dc7ff9 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 17 Dec 2015 13:23:48 -0800 Subject: [PATCH 1283/2089] [SPARK-12410][STREAMING] Fix places that use '.' and '|' directly in split String.split accepts a regular expression, so we should escape "." and "|". Author: Shixiong Zhu Closes #10361 from zsxwing/reg-bug. --- .../main/scala/org/apache/spark/examples/ml/MovieLensALS.scala | 2 +- .../apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 3ae53e57dbdb8..02ed746954f23 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -50,7 +50,7 @@ object MovieLensALS { def parseMovie(str: String): Movie = { val fields = str.split("::") assert(fields.size == 3) - Movie(fields(0).toInt, fields(1), fields(2).split("|")) + Movie(fields(0).toInt, fields(1), fields(2).split("\\|")) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index a99b570835831..b946e0d8e9271 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -253,7 +253,7 @@ private[streaming] object FileBasedWriteAheadLog { def getCallerName(): Option[String] = { val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) - stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split("\\.").lastOption) } /** Convert a sequence of files to a sequence of sorted LogInfo objects */ From e096a652b92fc64a7b3457cd0766ab324bcc980b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Dec 2015 14:16:49 -0800 Subject: [PATCH 1284/2089] [SPARK-12397][SQL] Improve error messages for data sources when they are not found Point users to spark-packages.org to find them. Author: Reynold Xin Closes #10351 from rxin/SPARK-12397. --- .../datasources/ResolvedDataSource.scala | 50 ++++++++++++------- .../sql/sources/ResolvedDataSourceSuite.scala | 17 +++++++ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 86a306b8f941d..e02ee6cd6b907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -57,24 +57,38 @@ object ResolvedDataSource extends Logging { val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider.", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") + // the provider format did not match any given registered aliases + case Nil => + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + if (provider == "avro" || provider == "com.databricks.spark.avro") { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please use Spark package " + + "http://spark-packages.org/package/databricks/spark-avro", + error) + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please find packages at " + + "http://spark-packages.org", + error) + } + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 27d1cd92fca1a..cb6e5179b31ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -57,4 +57,21 @@ class ResolvedDataSourceSuite extends SparkFunSuite { ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } + + test("error message for unknown data sources") { + val error1 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("avro") + } + assert(error1.getMessage.contains("spark-packages")) + + val error2 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("com.databricks.spark.avro") + } + assert(error2.getMessage.contains("spark-packages")) + + val error3 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("asfdwefasdfasdf") + } + assert(error3.getMessage.contains("spark-packages")) + } } From ed6ebda5c898bad76194fe3a090bef5a14f861c2 Mon Sep 17 00:00:00 2001 From: Evan Chen Date: Thu, 17 Dec 2015 14:22:30 -0800 Subject: [PATCH 1285/2089] [SPARK-12376][TESTS] Spark Streaming Java8APISuite fails in assertOrderInvariantEquals method org.apache.spark.streaming.Java8APISuite.java is failing due to trying to sort immutable list in assertOrderInvariantEquals method. Author: Evan Chen Closes #10336 from evanyc15/SPARK-12376-StreamingJavaAPISuite. --- .../org/apache/spark/streaming/Java8APISuite.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 89e0c7fdf7eec..e8a0dfc0f0a5f 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -439,9 +439,14 @@ public void testPairFlatMap() { */ public static > void assertOrderInvariantEquals( List> expected, List> actual) { - expected.forEach((List list) -> Collections.sort(list)); - actual.forEach((List list) -> Collections.sort(list)); - Assert.assertEquals(expected, actual); + expected.forEach(list -> Collections.sort(list)); + List> sortedActual = new ArrayList<>(); + actual.forEach(list -> { + List sortedList = new ArrayList<>(list); + Collections.sort(sortedList); + sortedActual.add(sortedList); + }); + Assert.assertEquals(expected, sortedActual); } @Test From 658f66e6208a52367e3b43a6fee9c90f33fb6226 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 17 Dec 2015 15:16:35 -0800 Subject: [PATCH 1286/2089] [SPARK-8641][SQL] Native Spark Window functions This PR removes Hive windows functions from Spark and replaces them with (native) Spark ones. The PR is on par with Hive in terms of features. This has the following advantages: * Better memory management. * The ability to use spark UDAFs in Window functions. cc rxin / yhuai Author: Herman van Hovell Closes #9819 from hvanhovell/SPARK-8641-2. --- .../sql/catalyst/analysis/Analyzer.scala | 81 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 39 +- .../catalyst/analysis/FunctionRegistry.scala | 12 +- .../expressions/aggregate/interfaces.scala | 7 +- .../expressions/windowExpressions.scala | 318 +++++++-- .../analysis/AnalysisErrorSuite.scala | 31 +- .../apache/spark/sql/execution/Window.scala | 649 ++++++++++-------- .../spark/sql/expressions/WindowSpec.scala | 54 +- .../org/apache/spark/sql/functions.scala | 16 +- .../spark/sql/DataFrameWindowSuite.scala} | 200 +++--- .../HiveWindowFunctionQuerySuite.scala | 10 +- .../apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 22 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 224 ------ .../sql/hive/execution/WindowQuerySuite.scala | 230 +++++++ 15 files changed, 1148 insertions(+), 746 deletions(-) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala => core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala} (57%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca00a5e49f668..64dd83a915711 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -77,6 +77,8 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveWindowOrder :: + ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -127,14 +129,12 @@ class Analyzer( // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { - case plan => plan.transformExpressions { + case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = s"Window specification $windowName is not defined in the WINDOW clause." val windowSpecDefinition = - windowDefinitions - .get(windowName) - .getOrElse(failAnalysis(errorMessage)) + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) WindowExpression(c, windowSpecDefinition) } } @@ -577,6 +577,10 @@ class Analyzer( AggregateExpression(max, Complete, isDistinct = false) case min: Min if isDistinct => AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. @@ -597,11 +601,17 @@ class Analyzer( } def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.foreach(_.foreach { - case agg: AggregateExpression => return true - case _ => - }) - false + // Collect all Windowed Aggregate Expressions. + val windowedAggExprs = exprs.flatMap { expr => + expr.collect { + case WindowExpression(ae: AggregateExpression, _) => ae + } + }.toSet + + // Find the first Aggregate Expression that is not Windowed. + exprs.exists(_.collectFirst { + case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae + }.isDefined) } } @@ -875,26 +885,37 @@ class Analyzer( // Now, we extract regular expressions from expressionsWithWindowFunctions // by using extractExpr. + val seenWindowAggregates = new ArrayBuffer[AggregateExpression] val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). case wf : WindowFunction => - val newChildren = wf.children.map(extractExpr(_)) + val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) // Extracts expressions from the partition spec and order spec. case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => - val newPartitionSpec = partitionSpec.map(extractExpr(_)) + val newPartitionSpec = partitionSpec.map(extractExpr) val newOrderSpec = orderSpec.map { so => val newChild = extractExpr(so.child) so.copy(child = newChild) } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + // Extract Windowed AggregateExpression + case we @ WindowExpression( + AggregateExpression(function, mode, isDistinct), + spec: WindowSpecDefinition) => + val newChildren = function.children.map(extractExpr) + val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val newAgg = AggregateExpression(newFunction, mode, isDistinct) + seenWindowAggregates += newAgg + WindowExpression(newAgg, spec) + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). - case agg: AggregateExpression => + case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute @@ -1102,6 +1123,42 @@ class Analyzer( } } } + + /** + * Check and add proper window frames for all window functions. + */ + object ResolveWindowFrame extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, + WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + if wf.frame != UnspecifiedFrame && wf.frame != f => + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, + s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if wf.frame != UnspecifiedFrame => + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + we.copy(windowSpec = s.copy(frameSpecification = frame)) + } + } + } + + /** + * Check and add order to [[AggregateWindowFunction]]s. + */ + object ResolveWindowOrder extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"WindowFunction $wf requires window to be ordered") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b2c93d63d673..440f679913802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -70,15 +70,32 @@ trait CheckAnalysis { failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") + case w @ WindowExpression(AggregateExpression(_, _, true), _) => + failAnalysis(s"Distinct window functions are not supported: $w") + + case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, + SpecifiedWindowFrame(frame, + FrameBoundary(l), + FrameBoundary(h)))) + if order.isEmpty || frame != RowFrame || l != h => + failAnalysis("An offset window function can only be evaluated in an ordered " + + s"row-based window frame with a single offset: $w") + + case w @ WindowExpression(e, s) => + // Only allow window functions with an aggregate expression or an offset window + // function. + e match { + case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + case _ => + failAnalysis(s"Expression '$e' not supported within a window function.") + } + // Make sure the window specification is valid. + s.validate match { + case Some(m) => + failAnalysis(s"Window specification $s is not valid because $m") + case None => w + } - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") } operator match { @@ -204,10 +221,12 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && + !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => // The rule above is used to check Aggregate operator. failAnalysis( - s"""nondeterministic expressions are only allowed in Project or Filter, found: + s"""nondeterministic expressions are only allowed in + |Project, Filter, Aggregate or Window, found: | ${o.expressions.map(_.prettyString).mkString(",")} |in operator ${operator.simpleString} """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f9c04d7ec0b0c..12c24cc768225 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -283,7 +283,17 @@ object FunctionRegistry { expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), expression[InputFileName]("input_file_name"), - expression[MonotonicallyIncreasingID]("monotonically_increasing_id") + expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + + // window functions + expression[Lead]("lead"), + expression[Lag]("lag"), + expression[RowNumber]("row_number"), + expression[CumeDist]("cume_dist"), + expression[NTile]("ntile"), + expression[Rank]("rank"), + expression[DenseRank]("dense_rank"), + expression[PercentRank]("percent_rank") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 3b441de34a49f..b6d2ddc5b1364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -144,9 +144,6 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu */ def defaultResult: Option[Literal] = None - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - /** * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, @@ -187,7 +184,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. */ -abstract class ImperativeAggregate extends AggregateFunction { +abstract class ImperativeAggregate extends AggregateFunction with CodegenFallback { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 1680aa8252ecb..06252ac4fc616 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{DataType, NumericType} +import org.apache.spark.sql.catalyst.expressions.aggregate.{NoOp, DeclarativeAggregate} +import org.apache.spark.sql.types._ /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -117,6 +118,19 @@ sealed trait FrameBoundary { def notFollows(other: FrameBoundary): Boolean } +/** + * Extractor for making working with frame boundaries easier. + */ +object FrameBoundary { + def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + case CurrentRow => Some(0) + case ValuePreceding(offset) => Some(-offset) + case ValueFollowing(offset) => Some(offset) + case _ => None + } +} + /** UNBOUNDED PRECEDING boundary. */ case object UnboundedPreceding extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { @@ -243,85 +257,281 @@ object SpecifiedWindowFrame { } } +case class UnresolvedWindowExpression( + child: Expression, + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} + +case class WindowExpression( + windowFunction: Expression, + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { + + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil + + override def dataType: DataType = windowFunction.dataType + override def foldable: Boolean = windowFunction.foldable + override def nullable: Boolean = windowFunction.nullable + + override def toString: String = s"$windowFunction $windowSpec" +} + /** - * Every window function needs to maintain a output buffer for its output. - * It should expect that for a n-row window frame, it will be called n times - * to retrieve value corresponding with these n rows. + * A window function is a function that can only be evaluated in the context of a window operator. */ trait WindowFunction extends Expression { - def init(): Unit + /** Frame in which the window operator must be executed. */ + def frame: WindowFrame = UnspecifiedFrame +} + +/** + * An offset window function is a window function that returns the value of the input column offset + * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with + * offset -2, will get the value of x 2 rows back in the partition. + */ +abstract class OffsetWindowFunction + extends Expression with WindowFunction with Unevaluable with ImplicitCastInputTypes { + /** + * Input expression to evaluate against a row which a number of rows below or above (depending on + * the value and sign of the offset) the current row. + */ + val input: Expression + + /** + * Default result value for the function when the input expression returns NULL. The default will + * evaluated against the current row instead of the offset row. + */ + val default: Expression - def reset(): Unit + /** + * (Foldable) expression that contains the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val offset: Expression - def prepareInputParameters(input: InternalRow): AnyRef + /** + * Direction (above = 1/below = -1) of the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val direction: SortDirection - def update(input: AnyRef): Unit + override def children: Seq[Expression] = Seq(input, offset, default) - def batchUpdate(inputs: Array[AnyRef]): Unit + /* + * The result of an OffsetWindowFunction is dependent on the frame in which the + * OffsetWindowFunction is executed, the input expression and the default expression. Even when + * both the input and the default expression are foldable, the result is still not foldable due to + * the frame. + */ + override def foldable: Boolean = input.foldable && (default == null || default.foldable) - def evaluate(): Unit + override def nullable: Boolean = input.nullable && (default == null || default.nullable) - def get(index: Int): Any + override lazy val frame = { + // This will be triggered by the Analyzer. + val offsetValue = offset.eval() match { + case o: Int => o + case x => throw new AnalysisException( + s"Offset expression must be a foldable integer expression: $x") + } + val boundary = direction match { + case Ascending => ValueFollowing(offsetValue) + case Descending => ValuePreceding(offsetValue) + } + SpecifiedWindowFrame(RowFrame, boundary, boundary) + } + + override def dataType: DataType = input.dataType - def newInstance(): WindowFunction + override def inputTypes: Seq[AbstractDataType] = + Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType)) + + override def toString: String = s"$prettyName($input, $offset, $default)" } -case class UnresolvedWindowFunction( - name: String, - children: Seq[Expression]) - extends Expression with WindowFunction with Unevaluable { +case class Lead(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) - override def init(): Unit = throw new UnresolvedException(this, "init") - override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: InternalRow): AnyRef = - throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") - override def batchUpdate(inputs: Array[AnyRef]): Unit = - throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = throw new UnresolvedException(this, "get") + def this(input: Expression) = this(input, Literal(1)) - override def toString: String = s"'$name(${children.mkString(",")})" + def this() = this(Literal(null)) - override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") + override val direction = Ascending } -case class UnresolvedWindowExpression( - child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) + + def this(input: Expression) = this(input, Literal(1)) + + def this() = this(Literal(null)) + + override val direction = Descending } -case class WindowExpression( - windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { +abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { + self: Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def nullable: Boolean = false + override def supportsPartial: Boolean = false + override lazy val mergeExpressions = + throw new UnsupportedOperationException("Window Functions do not support merging.") +} - override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def inputTypes: Seq[AbstractDataType] = Nil + protected val zero = Literal(0) + protected val one = Literal(1) + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = zero :: Nil + override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil +} - override def dataType: DataType = windowFunction.dataType - override def foldable: Boolean = windowFunction.foldable - override def nullable: Boolean = windowFunction.nullable +/** + * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. + */ +trait SizeBasedWindowFunction extends AggregateWindowFunction { + protected def n: AttributeReference = SizeBasedWindowFunction.n +} - override def toString: String = s"$windowFunction $windowSpec" +object SizeBasedWindowFunction { + val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() +} + +case class RowNumber() extends RowNumberLike { + override val evaluateExpression = rowNumber +} + +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { + override def dataType: DataType = DoubleType + // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must + // return the same value for equal values in the partition. + override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) +} + +case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { + def this() = this(Literal(1)) + + // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant + // for each partition. + buckets.eval() match { + case b: Int if b > 0 => // Ok + case x => throw new AnalysisException( + "Buckets expression must be a foldable positive integer expression: $x") + } + + private val bucket = AttributeReference("bucket", IntegerType, nullable = false)() + private val bucketThreshold = + AttributeReference("bucketThreshold", IntegerType, nullable = false)() + private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() + private val bucketsWithPadding = + AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() + private def bucketOverflow(e: Expression) = + If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + + override val aggBufferAttributes = Seq( + rowNumber, + bucket, + bucketThreshold, + bucketSize, + bucketsWithPadding + ) + + override val initialValues = Seq( + zero, + zero, + zero, + Cast(Divide(n, buckets), IntegerType), + Cast(Remainder(n, buckets), IntegerType) + ) + + override val updateExpressions = Seq( + Add(rowNumber, one), + Add(bucket, bucketOverflow(one)), + Add(bucketThreshold, bucketOverflow( + Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + NoOp, + NoOp + ) + + override val evaluateExpression = bucket } /** - * Extractor for making working with frame boundaries easier. + * A RankLike function is a WindowFunction that changes its value based on a change in the value of + * the order of the window in which is processed. For instance, when the value of 'x' changes in a + * window ordered by 'x' the rank function also changes. The size of the change of the rank function + * is (typically) not dependent on the size of the change in 'x'. */ -object FrameBoundaryExtractor { - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None +abstract class RankLike extends AggregateWindowFunction { + override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) + + /** Store the values of the window 'order' expressions. */ + protected val orderAttrs = children.map{ expr => + AttributeReference(expr.prettyString, expr.dataType)() } + + /** Predicate that detects if the order attributes have changed. */ + protected val orderEquals = children.zip(orderAttrs) + .map(EqualNullSafe.tupled) + .reduceOption(And) + .getOrElse(Literal(true)) + + protected val orderInit = children.map(e => Literal.create(null, e.dataType)) + protected val rank = AttributeReference("rank", IntegerType, nullable = false)() + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + protected val zero = Literal(0) + protected val one = Literal(1) + protected val increaseRowNumber = Add(rowNumber, one) + + /** + * Different RankLike implementations use different source expressions to update their rank value. + * Rank for instance uses the number of rows seen, whereas DenseRank uses the number of changes. + */ + protected def rankSource: Expression = rowNumber + + /** Increase the rank when the current rank == 0 or when the one of order attributes changes. */ + protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + + override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs + override val initialValues = zero +: one +: orderInit + override val updateExpressions = increaseRank +: increaseRowNumber +: children + override val evaluateExpression: Expression = rank + + def withOrder(order: Seq[Expression]): RankLike +} + +case class Rank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): Rank = Rank(order) +} + +case class DenseRank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) + override protected def rankSource = Add(rank, one) + override val updateExpressions = increaseRank +: children + override val aggBufferAttributes = rank +: orderAttrs + override val initialValues = zero +: orderInit +} + +case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) + override def dataType: DataType = DoubleType + override val evaluateExpression = If(GreaterThan(n, one), + Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), + Literal(0.0d)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ee435578743fc..12079992b5b84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Count, Sum, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -133,17 +134,37 @@ class AnalysisErrorSuite extends AnalysisTest { "requires int type" :: "'null' is of date type" :: Nil) errorTest( - "unresolved window function", + "invalid window function", testRelation2.select( WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), + Literal(0), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) + "not supported within a window function" :: Nil) + + errorTest( + "distinct window function", + testRelation2.select( + WindowExpression( + AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "Distinct window functions are not supported" :: Nil) + + errorTest( + "offset window function", + testRelation2.select( + WindowExpression( + new Lead(UnresolvedAttribute("b")), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + "window frame" :: "must match the required frame" :: Nil) errorTest( "too many generators", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index b1280c32a6a43..9852b6e7beeba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.CompactBuffer /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -42,6 +45,8 @@ import org.apache.spark.util.collection.CompactBuffer * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame * and we add some rows to the frame. Examples are: * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consist of one row, which is an offset number of rows away from the + * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame. * * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame * boundary can be either Row or Range based: @@ -122,12 +127,10 @@ case class Window( // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() // Flip the sign of the offset when processing the order is descending - val boundOffset = - if (sortExpr.direction == Descending) { - -offset - } else { - offset - } + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output)() @@ -149,43 +152,102 @@ case class Window( } /** - * Create a frame processor. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame boundaries. - * @param functions to process in the frame. - * @param ordinal at which the processor starts writing to the output. - * @return a frame processor. + * Collection containing an entry for each window frame to process. Each entry contains a frames' + * WindowExpressions and factory function for the WindowFrameFunction. */ - private[this] def createFrameProcessor( - frame: WindowFrame, - functions: Array[WindowFunction], - ordinal: Int): WindowFunctionFrame = frame match { - // Growing Frame. - case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => - val uBoundOrdering = createBoundOrdering(frameType, high) - new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) - - // Shrinking Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => - val lBoundOrdering = createBoundOrdering(frameType, low) - new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) - - // Moving Frame. - case SpecifiedWindowFrame(frameType, - FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => - val lBoundOrdering = createBoundOrdering(frameType, low) - val uBoundOrdering = createBoundOrdering(frameType, high) - new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) - - // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => - new UnboundedWindowFunctionFrame(ordinal, functions) - - // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es.append(e) + fns.append(fn) + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + functions, + child.output, + newMutableProjection, + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. + (expressions, factory) + } } /** @@ -210,43 +272,16 @@ case class Window( } protected override def doExecute(): RDD[InternalRow] = { - // Prepare processing. - // Group the window expression by their processing frame. - val windowExprs = windowExpression.flatMap { - _.collect { - case e: WindowExpression => e - } - } - - // Create Frame processor factories and order the unbound window expressions by the frame they - // are processed in; this is the order in which their results will be written to window - // function result buffer. - val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) - val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = scala.collection.mutable.Buffer.empty[Expression] - framedWindowExprs.zipWithIndex.foreach { - case ((frame, unboundFrameExpressions), index) => - // Track the ordinal. - val ordinal = unboundExpressions.size - - // Track the unbound expressions - unboundExpressions ++= unboundFrameExpressions - - // Bind the expressions. - val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction, child.output) - }.toArray - - // Create the frame processor factory. - factories(index) = () => createFrameProcessor(frame, functions, ordinal) - } + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray // Start processing. child.execute().mapPartitions { stream => new Iterator[InternalRow] { // Get all relevant projections. - val result = createResultProjection(unboundExpressions) + val result = createResultProjection(expressions) val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. @@ -266,14 +301,15 @@ case class Window( fetchNextRow() // Manage the current partition. - var rows: CompactBuffer[InternalRow] = _ - val frames: Array[WindowFunctionFrame] = factories.map(_()) + val rows = ArrayBuffer.empty[InternalRow] + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. // Before we start to fetch new input rows, make a copy of nextGroup. val currentGroup = nextGroup.copy() - rows = new CompactBuffer + rows.clear() while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() fetchNextRow() @@ -297,7 +333,6 @@ case class Window( override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow - val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. if (rowIndex >= rowsSize && nextRowAvailable) { @@ -308,7 +343,7 @@ case class Window( // Get the results for the window frames. var i = 0 while (i < numFrames) { - frames(i).write(windowFunctionResult) + frames(i).write() i += 1 } @@ -355,140 +390,96 @@ private[execution] final case class RangeBoundOrdering( * A window function calculates the results of a number of window functions for a window frame. * Before use a frame must be prepared by passing it all the rows in the current partition. After * preparation the update method can be called to fill the output rows. - * - * TODO How to improve performance? A few thoughts: - * - Window functions are expensive due to its distribution and ordering requirements. - * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project - * Tungsten are on the way. - * - The window frame processing bit can be improved though. But before we start doing that we - * need to see how much of the time and resources are spent on partitioning and ordering, and - * how much time and resources are spent processing the partitions. There are a couple ways to - * improve on the current situation: - * - Reduce memory footprint by performing streaming calculations. This can only be done when - * there are no Unbound/Unbounded Following calculations present. - * - Use Tungsten style memory usage. - * - Use code generation in general, and use the approach to aggregation taken in the - * GeneratedAggregate class in specific. - * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. */ -private[execution] abstract class WindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) { - - // Make sure functions are initialized. - functions.foreach(_.init()) - - /** Number of columns the window function frame is managing */ - val numColumns = functions.length - - /** - * Create a fresh thread safe copy of the frame. - * - * @return the copied frame. - */ - def copy: WindowFunctionFrame - - /** - * Create new instances of the functions. - * - * @return an array containing copies of the current window functions. - */ - protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) - +private[execution] abstract class WindowFunctionFrame { /** * Prepare the frame for calculating the results for a partition. * * @param rows to calculate the frame results for. */ - def prepare(rows: CompactBuffer[InternalRow]): Unit + def prepare(rows: ArrayBuffer[InternalRow]): Unit /** - * Write the result for the current row to the given target row. - * - * @param target row to write the result for the current row to. + * Write the current results to the target row. */ - def write(target: GenericMutableRow): Unit + def write(): Unit +} - /** Reset the current window functions. */ - protected final def reset(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).reset() - i += 1 - } - } +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[execution] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + offset: Int) extends WindowFunctionFrame { - /** Prepare an input row for processing. */ - protected final def prepare(input: InternalRow): Array[AnyRef] = { - val prepared = new Array[AnyRef](numColumns) - var i = 0 - while (i < numColumns) { - prepared(i) = functions(i).prepareInputParameters(input) - i += 1 - } - prepared - } + /** Rows of the partition currently being processed. */ + private[this] var input: ArrayBuffer[InternalRow] = null - /** Evaluate a prepared buffer (iterator). */ - protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { - reset() - while (iterator.hasNext) { - val prepared = iterator.next() - var i = 0 - while (i < numColumns) { - functions(i).update(prepared(i)) - i += 1 - } - } - evaluate() - } + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 - /** Evaluate a prepared buffer (array). */ - protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], - fromIndex: Int, toIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - val function = functions(i) - function.reset() - var j = fromIndex - while (j < toIndex) { - function.update(prepared(j)(i)) - j += 1 - } - function.evaluate() - i += 1 - } - } + /** Index of the current output row. */ + private[this] var outputIndex = 0 - /** Update an array of window functions. */ - protected final def update(input: InternalRow): Unit = { - var i = 0 - while (i < numColumns) { - val aggregate = functions(i) - val preparedInput = aggregate.prepareInputParameters(input) - aggregate.update(preparedInput) - i += 1 + /** Row used when there is no valid input. */ + private[this] val emptyRow = new GenericInternalRow(inputSchema.size) + + /** Row used to combine the offset and the current row. */ + private[this] val join = new JoinedRow + + /** Create the projection. */ + private[this] val projection = { + // Collect the expressions and bind them. + val numInputAttributes = inputSchema.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => + val input = BindReferences.bindReference(e.input, inputSchema) + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // Without default value. + input + } else { + // With default value. + val default = BindReferences.bindReference(e.default, inputSchema).transform { + // Shift the input reference to its default version. + case BoundReference(o, dataType, nullable) => + BoundReference(o + numInputAttributes, dataType, nullable) + } + org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + } + case e => + BindReferences.bindReference(e, inputSchema) } + + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) } - /** Evaluate the window functions. */ - protected final def evaluate(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).evaluate() - i += 1 - } + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + input = rows + inputIndex = offset + outputIndex = 0 } - /** Fill a target row with the current window function results. */ - protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - target.update(ordinal + i, functions(i).get(rowIndex)) - i += 1 + override def write(): Unit = { + val size = input.size + if (inputIndex >= 0 && inputIndex < size) { + join(input(inputIndex), input(outputIndex)) + } else { + join(emptyRow, input(outputIndex)) } + projection(join) + inputIndex += 1 + outputIndex += 1 } } @@ -496,19 +487,19 @@ private[execution] abstract class WindowFunctionFrame( * The sliding window frame calculates frames with the following SQL form: * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class SlidingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], + target: MutableRow, + processor: AggregateProcessor, lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -518,30 +509,25 @@ private[execution] final class SlidingWindowFunctionFrame( * current output row. */ private[this] var inputLowIndex = 0 - /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new java.util.ArrayDeque[Array[AnyRef]] - /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputHighIndex = 0 inputLowIndex = 0 outputIndex = 0 - buffer.clear() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. while (inputHighIndex < input.size && - ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(prepare(input(inputHighIndex))) + ubound.compare(input, inputHighIndex, outputIndex) <= 0) { inputHighIndex += 1 bufferUpdated = true } @@ -549,25 +535,21 @@ private[execution] final class SlidingWindowFunctionFrame( // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. while (inputLowIndex < inputHighIndex && - lbound.compare(input, inputLowIndex, outputIndex) < 0) { - buffer.pop() + lbound.compare(input, inputLowIndex, outputIndex) < 0) { inputLowIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer.iterator()) - fill(target, outputIndex) + processor.initialize(input.size) + processor.update(input, inputLowIndex, inputHighIndex) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) } /** @@ -578,36 +560,25 @@ private[execution] final class SlidingWindowFunctionFrame( * Its results are the same for each and every row in the partition. This class can be seen as a * special case of a sliding window, but is optimized for the unbound case. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. */ private[execution] final class UnboundedWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { - - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + target: MutableRow, + processor: AggregateProcessor) extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() - outputIndex = 0 - val iterator = rows.iterator - while (iterator.hasNext) { - update(iterator.next()) - } - evaluate() + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + processor.initialize(rows.size) + processor.update(rows, 0, rows.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - fill(target, outputIndex) - outputIndex += 1 + override def write(): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) } - - /** Copy the frame. */ - override def copy: UnboundedWindowFunctionFrame = - new UnboundedWindowFunctionFrame(ordinal, copyFunctions) } /** @@ -620,58 +591,53 @@ private[execution] final class UnboundedWindowFunctionFrame( * is not the case when there is no lower bound, given the additive nature of most aggregates * streaming updates and partial evaluation suffice and no buffering is needed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class UnboundedPrecedingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + * output row. */ private[this] var inputIndex = 0 /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 + processor.initialize(input.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - update(input(inputIndex)) + processor.update(input(inputIndex)) inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluate() - fill(target, outputIndex) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: UnboundedPrecedingWindowFunctionFrame = - new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) } /** @@ -686,45 +652,34 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * buffer and must do full recalculation after each row. Reverse iteration would be possible, if * the communitativity of the used window functions can be guaranteed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ private[execution] final class UnboundedFollowingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { - - /** Buffer used for storing prepared input for the window functions. */ - private[this] var buffer: Array[Array[AnyRef]] = _ + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + * current output row. */ private[this] var inputIndex = 0 /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 - val size = input.size - buffer = Array.ofDim(size) - var i = 0 - while (i < size) { - buffer(i) = prepare(input(i)) - i += 1 - } - evaluatePrepared(buffer, 0, buffer.length) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Drop all rows from the buffer for which the input row value is smaller than @@ -736,15 +691,151 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer, inputIndex, buffer.length) - fill(target, outputIndex) + processor.initialize(input.size) + processor.update(input, inputIndex, input.size) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } +} + +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, + * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition processed, this value is exposed to them when the processor is + * constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and a the actual + * processor class. + */ +private[execution] object AggregateProcessor { + def apply(functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction]) + if (trackPartitionSize) { + aggBufferAttributes += SizeBasedWindowFunction.n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. + functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") + } + + // Create the projections. + val initialProjection = newMutableProjection( + initialValues, + Seq(SizeBasedWindowFunction.n))() + val updateProjection = newMutableProjection( + updateExpressions, + aggBufferAttributes ++ inputAttributes)() + val evaluateProjection = newMutableProjection( + evaluateExpressions, + aggBufferAttributes)() + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProjection, + updateProjection, + evaluateProjection, + imperatives.toArray, + trackPartitionSize) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[execution] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. + if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } + } + + /** Bulk update the given buffer. */ + def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = { + var i = begin + while (i < end) { + update(input(i)) + i += 1 + } + } - /** Copy the frame. */ - override def copy: UnboundedFollowingWindowFunctionFrame = - new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) + /** Evaluate buffer. */ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 893e800a61438..9397fb84105a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -140,57 +140,7 @@ class WindowSpec private[sql]( * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ private[sql] def withAggregate(aggregate: Column): Column = { - val windowExpr = aggregate.expr match { - // First, we check if we get an aggregate function without the DISTINCT keyword. - // Right now, we do not support using a DISTINCT aggregate function as a - // window function. - case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct => - aggregateFunction match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(children) => WindowExpression( - UnresolvedWindowFunction("count", children), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"$x is not supported in a window operation.") - } - - case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct => - throw new UnsupportedOperationException( - s"Distinct aggregate function ${aggregateFunction} is not supported " + - s"in window operation.") - - case wf: WindowFunction => - WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - - case x => - throw new UnsupportedOperationException(s"$x is not supported in a window operation.") - } - - new Column(windowExpr) + val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame) + new Column(WindowExpression(aggregate.expr, spec)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e79defbbbdeea..65733dcf83e76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -577,7 +577,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + def cume_dist(): Column = withExpr { new CumeDist } /** * @group window_funcs @@ -597,7 +597,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { new DenseRank } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -648,7 +648,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { - UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + Lag(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -700,7 +700,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { - UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + Lead(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -713,7 +713,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } + def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } /** * @group window_funcs @@ -735,7 +735,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { new PercentRank } /** * Window function: returns the rank of rows within a window partition. @@ -750,7 +750,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } + def rank(): Column = withExpr { new Rank } /** * @group window_funcs @@ -765,7 +765,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { RowNumber() } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala similarity index 57% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 2c98f1c3cc49c..b50d7604e0ec7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -15,16 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.hive +package org.apache.spark.sql -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.types.{DataType, LongType, StructType} -class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { - import hiveContext.implicits._ - import hiveContext.sql +class DataFrameWindowSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -55,10 +54,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lead(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) } test("lag") { @@ -68,10 +64,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) } test("lead with default value") { @@ -81,10 +74,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - sql( - """SELECT - | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) } test("lag with default value") { @@ -94,10 +84,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) } test("rank functions in unspecific window") { @@ -112,78 +99,52 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { count("key").over(Window.partitionBy("value").orderBy("key")), sum("key").over(Window.partitionBy("value").orderBy("key")), ntile(2).over(Window.partitionBy("value").orderBy("key")), - rowNumber().over(Window.partitionBy("value").orderBy("key")), - denseRank().over(Window.partitionBy("value").orderBy("key")), + row_number().over(Window.partitionBy("value").orderBy("key")), + dense_rank().over(Window.partitionBy("value").orderBy("key")), rank().over(Window.partitionBy("value").orderBy("key")), - cumeDist().over(Window.partitionBy("value").orderBy("key")), - percentRank().over(Window.partitionBy("value").orderBy("key"))), - sql( - s"""SELECT - |key, - |max(key) over (partition by value order by key), - |min(key) over (partition by value order by key), - |avg(key) over (partition by value order by key), - |count(key) over (partition by value order by key), - |sum(key) over (partition by value order by key), - |ntile(2) over (partition by value order by key), - |row_number() over (partition by value order by key), - |dense_rank() over (partition by value order by key), - |rank() over (partition by value order by key), - |cume_dist() over (partition by value order by key), - |percent_rank() over (partition by value order by key) - |FROM window_table""".stripMargin).collect()) + cume_dist().over(Window.partitionBy("value").orderBy("key")), + percent_rank().over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) } - test("aggregation and range betweens") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + test("aggregation and range between") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), + Row(2.0d), Row(2.0d))) } - test("aggregation and rows betweens with unbounded") { + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) - | FROM window_table""".stripMargin).collect()) + last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), + Row(4, 4, 4, 4))) } - test("aggregation and range betweens with unbounded") { + test("aggregation and range between with unbounded") { val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -200,18 +161,12 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) - | FROM window_table""".stripMargin).collect()) + Seq(Row(3, null, 3.0d, 4.0d, 3.0d), + Row(5, false, 4.0d, 5.0d, 5.0d), + Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), + Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), + Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), + Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) } test("reverse sliding range frame") { @@ -254,6 +209,87 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { sum($"value").over(window.rangeBetween(1, Long.MaxValue))), Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("statistical functions") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.partitionBy($"key") + checkAnswer( + df.select( + $"key", + var_pop($"value").over(window), + var_samp($"value").over(window), + approxCountDistinct($"value").over(window)), + Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + } + + test("window function with aggregates") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.groupBy($"key") + .agg( + sum($"value"), + sum(sum($"value")).over(window) - sum($"value")), + Seq(Row("a", 6, 9), Row("b", 9, 6))) + } + + test("window function with udaf") { + val udaf = new UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) + } + val df = Seq( + ("a", 1, 1), + ("a", 1, 5), + ("a", 2, 10), + ("a", 2, -1), + ("b", 4, 7), + ("b", 3, 8), + ("b", 2, 4)) + .toDF("key", "a", "b") + val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) + checkAnswer( + df.select( + $"key", + $"a", + $"b", + udaf($"a", $"b").over(window)), + Seq( + Row("a", 1, 1, 6), + Row("a", 1, 5, 6), + Row("a", 2, 10, 24), + Row("a", 2, -1, 24), + Row("b", 4, 7, 60), + Row("b", 3, 8, 32), + Row("b", 2, 4, 8))) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 92bb9e6d73af1..98bbdf0653c2a 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -454,6 +454,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) """.stripMargin, reset = false) + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 15. testExpressions", s""" |select p_mfgr,p_name, p_size, @@ -472,7 +475,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 16. testMultipleWindows", s""" |select p_mfgr,p_name, p_size, @@ -530,6 +533,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // when running this test suite under Java 7 and 8. // We change the original sql query a little bit for making the test suite passed // under different JDK + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 20. testSTATs", """ |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp @@ -547,7 +553,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |) t lateral view explode(uniq_size) d as uniq_data |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 21. testDISTs", """ |select p_mfgr,p_name, p_size, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5958777b0d064..0eeb62ca2cb3f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -476,7 +476,6 @@ class HiveContext private[hive]( catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUDFs :: - ResolveHiveWindowFunction :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 091caab921fe9..da41b659e3fce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -353,6 +353,14 @@ private[hive] object HiveQl extends Logging { } /** Extractor for matching Hive's AST Tokens. */ + private[hive] case class Token(name: String, children: Seq[ASTNode]) extends Node { + def getName(): String = name + def getChildren(): java.util.List[Node] = { + val col = new java.util.ArrayList[Node](children.size) + children.foreach(col.add(_)) + col + } + } object Token { /** @return matches of the form (tokenName, children). */ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { @@ -360,6 +368,7 @@ private[hive] object HiveQl extends Logging { CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + case t: Token => Some((t.name, t.children)) case _ => None } } @@ -1617,17 +1626,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) /* Window Functions */ - case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - // Safe to use Literal(1)? - val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) + case Token(name, args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = nodeToExpr(Token(name, args)) nodesToWindowSpecification(spec) match { case reference: WindowSpecReference => UnresolvedWindowExpression(function, reference) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 2e8c026259efe..a1787fc92d6d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -260,230 +260,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr } } -/** - * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. - */ -private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { - private def shouldResolveFunction( - unresolvedWindowFunction: UnresolvedWindowFunction, - windowSpec: WindowSpecDefinition): Boolean = { - unresolvedWindowFunction.childrenResolved && windowSpec.childrenResolved - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: LogicalPlan if !p.childrenResolved => p - - // We are resolving WindowExpressions at here. When we get here, we have already - // replaced those WindowSpecReferences. - case p: LogicalPlan => - p transformExpressions { - // We will not start to resolve the function unless all arguments are resolved - // and all expressions in window spec are fixed. - case WindowExpression( - u @ UnresolvedWindowFunction(name, children), - windowSpec: WindowSpecDefinition) if shouldResolveFunction(u, windowSpec) => - // First, let's find the window function info. - val windowFunctionInfo: WindowFunctionInfo = - Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"Couldn't find window function $name")) - - // Get the class of this function. - // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use - // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. - val functionClass = windowFunctionInfo.getFunctionClass() - val newChildren = - // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit - // input parameters and requires implicit parameters, which - // are expressions in Order By clause. - if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { - if (children.nonEmpty) { - throw new AnalysisException(s"$name does not take input parameters.") - } - windowSpec.orderSpec.map(_.child) - } else { - children - } - - // If the class is UDAF, we need to use UDAFBridge. - val isUDAFBridgeRequired = - if (classOf[UDAF].isAssignableFrom(functionClass)) { - true - } else { - false - } - - // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of - // HiveWindowFunction. - val windowFunction = - HiveWindowFunction( - new HiveFunctionWrapper(functionClass.getName), - windowFunctionInfo.isPivotResult, - isUDAFBridgeRequired, - newChildren) - - // Second, check if the specified window function can accept window definition. - windowSpec.frameSpecification match { - case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => - // This Hive window function does not support user-speficied window frame. - throw new AnalysisException( - s"Window function $name does not take a frame specification.") - case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && - windowFunctionInfo.isPivotResult => - // These two should not be true at the same time when a window frame is defined. - // If so, throw an exception. - throw new AnalysisException(s"Could not handle Hive window function $name because " + - s"it supports both a user specified window frame and pivot result.") - case _ => // OK - } - // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs - // a window frame specification to work. - val newWindowSpec = windowSpec.frameSpecification match { - case UnspecifiedFrame => - val newWindowFrame = - SpecifiedWindowFrame.defaultWindowFrame( - windowSpec.orderSpec.nonEmpty, - windowFunctionInfo.isSupportsWindow) - WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) - case _ => windowSpec - } - - // Finally, we create a WindowExpression with the resolved window function and - // specified window spec. - WindowExpression(windowFunction, newWindowSpec) - } - } -} - -/** - * A [[WindowFunction]] implementation wrapping Hive's window function. - * @param funcWrapper The wrapper for the Hive Window Function. - * @param pivotResult If it is true, the Hive function will return a list of values representing - * the values of the added columns. Otherwise, a single value is returned for - * current row. - * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's - * createFunction is UDAF, we need to use GenericUDAFBridge to wrap - * it as a GenericUDAFResolver2. - * @param children Input parameters. - */ -private[hive] case class HiveWindowFunction( - funcWrapper: HiveFunctionWrapper, - pivotResult: Boolean, - isUDAFBridgeRequired: Boolean, - children: Seq[Expression]) extends WindowFunction - with HiveInspectors with Unevaluable { - - // Hive window functions are based on GenericUDAFResolver2. - type UDFType = GenericUDAFResolver2 - - @transient - protected lazy val resolver: GenericUDAFResolver2 = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[GenericUDAFResolver2]() - } - - @transient - protected lazy val inputInspectors = children.map(toInspector).toArray - - // The GenericUDAFEvaluator used to evaluate the window function. - @transient - protected lazy val evaluator: GenericUDAFEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - // The object inspector of values returned from the Hive window function. - @transient - protected lazy val returnInspector = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - override val dataType: DataType = - if (!pivotResult) { - inspectorToDataType(returnInspector) - } else { - // If pivotResult is true, we should take the element type out as the data type of this - // function. - inspectorToDataType(returnInspector) match { - case ArrayType(dt, _) => dt - case _ => - sys.error( - s"error resolve the data type of window function ${funcWrapper.functionClassName}") - } - } - - override def nullable: Boolean = true - - @transient - lazy val inputProjection = new InterpretedProjection(children) - - @transient - private var hiveEvaluatorBuffer: AggregationBuffer = _ - // Output buffer. - private var outputBuffer: Any = _ - - @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - override def init(): Unit = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - // Reset the hiveEvaluatorBuffer and outputPosition - override def reset(): Unit = { - // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. - // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. - // However, RowNumberBuffer.init does not really reset this buffer. - hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer - evaluator.reset(hiveEvaluatorBuffer) - } - - override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap( - inputProjection(input), - inputInspectors, - new Array[AnyRef](children.length), - inputDataTypes) - } - - // Add input parameters for a single row. - override def update(input: AnyRef): Unit = { - evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) - } - - override def batchUpdate(inputs: Array[AnyRef]): Unit = { - var i = 0 - while (i < inputs.length) { - evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) - i += 1 - } - } - - override def evaluate(): Unit = { - outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) - } - - override def get(index: Int): Any = { - if (!pivotResult) { - // if pivotResult is false, we will get a single value for all rows in the frame. - outputBuffer - } else { - // if pivotResult is true, we will get a ArrayData having the same size with the size - // of the window frame. At here, we will return the result at the position of - // index in the output buffer. - outputBuffer.asInstanceOf[ArrayData].get(index, dataType) - } - } - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } - - override def newInstance(): WindowFunction = - new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) -} - /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[Generator]]. Note that the semantics of Generators do not allow diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala new file mode 100644 index 0000000000000..c05dbfd7608d9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.test.SQLTestUtils + +/** + * This suite contains a couple of Hive window tests which fail in the typical setup due to tiny + * numerical differences or due semantic differences between Hive and Spark. + */ +class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + override def beforeAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + sql( + """ + |CREATE TABLE part( + | p_partkey INT, + | p_name STRING, + | p_mfgr STRING, + | p_brand STRING, + | p_type STRING, + | p_size INT, + | p_container STRING, + | p_retailprice DOUBLE, + | p_comment STRING) + """.stripMargin) + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part + """.stripMargin) + } + + override def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + } + + test("windowing.q -- 15. testExpressions") { + // Moved because: + // - Spark uses a different default stddev (sample instead of pop) + // - Tiny numerical differences in stddev results. + // - Different StdDev behavior when n=1 (NaN instead of 0) + checkAnswer(sql(s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |percent_rank() over(distribute by p_mfgr sort by p_name) as pr, + |ntile(3) over(distribute by p_mfgr sort by p_name) as nt, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |avg(p_size) over(distribute by p_mfgr sort by p_name) as avg, + |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, + |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, + |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) + // scalastyle:on + } + + test("windowing.q -- 20. testSTATs") { + // Moved because: + // - Spark uses a different default stddev/variance (sample instead of pop) + // - Tiny numerical differences in aggregation results. + checkAnswer(sql(""" + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( + |select p_mfgr,p_name, p_size, + |stddev_pop(p_retailprice) over w1 as sdev, + |stddev_pop(p_retailprice) over w1 as sdev_pop, + |collect_set(p_size) over w1 as uniq_size, + |var_pop(p_retailprice) over w1 as var, + |corr(p_size, p_retailprice) over w1 as cor, + |covar_pop(p_size, p_retailprice) over w1 as covarp + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 2, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 6, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 34, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 2, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 34, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 2, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 6, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 28, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 34, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 2, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 6, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 28, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 34, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 42, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 6, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 28, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 34, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 42, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 6, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 28, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 42, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 2, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 14, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 40, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 2, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 14, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 25, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 40, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 2, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 14, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 18, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 25, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 40, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 2, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 18, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 25, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 40, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 2, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 18, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 25, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 14, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 17, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 19, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 1, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 14, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 17, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 19, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 1, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 14, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 17, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 19, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 45, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 1, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 14, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 19, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 45, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 1, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 19, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 45, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 10, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 27, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 39, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 7, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 10, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 27, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 39, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 7, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 10, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 12, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 27, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 39, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 7, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 12, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 27, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 39, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 7, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 12, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 27, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 2, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 6, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 31, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 2, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 6, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 31, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 46, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 2, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 6, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 23, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 31, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 46, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 2, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 6, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 23, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 46, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 2, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 23, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) + // scalastyle:on + } +} From f4346f612b6798517153a786f9172cf41618d34d Mon Sep 17 00:00:00 2001 From: jhu-chang Date: Thu, 17 Dec 2015 17:53:15 -0800 Subject: [PATCH 1287/2089] [SPARK-11749][STREAMING] Duplicate creating the RDD in file stream when recovering from checkpoint data Add a transient flag `DStream.restoredFromCheckpointData` to control the restore processing in DStream to avoid duplicate works: check this flag first in `DStream.restoreCheckpointData`, only when `false`, the restore process will be executed. Author: jhu-chang Closes #9765 from jhu-chang/SPARK-11749. --- .../spark/streaming/dstream/DStream.scala | 15 +++-- .../spark/streaming/CheckpointSuite.scala | 56 +++++++++++++++++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1a6edf9473d84..91a43e14a8b1b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -97,6 +97,8 @@ abstract class DStream[T: ClassTag] ( private[streaming] val mustCheckpoint = false private[streaming] var checkpointDuration: Duration = null private[streaming] val checkpointData = new DStreamCheckpointData(this) + @transient + private var restoredFromCheckpointData = false // Reference to whole DStream graph private[streaming] var graph: DStreamGraph = null @@ -507,11 +509,14 @@ abstract class DStream[T: ClassTag] ( * override the updateCheckpointData() method would also need to override this method. */ private[streaming] def restoreCheckpointData() { - // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data") - checkpointData.restore() - dependencies.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") + if (!restoredFromCheckpointData) { + // Create RDDs from the checkpoint data + logInfo("Restoring checkpoint data") + checkpointData.restore() + dependencies.foreach(_.restoreCheckpointData()) + restoredFromCheckpointData = true + logInfo("Restored checkpoint data") + } } @throws(classOf[IOException]) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index cd28d3cf408d5..f5f446f14a0da 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag @@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} -import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils} + +/** + * A input stream that records the times of restore() invoked + */ +private[streaming] +class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) { + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1))) + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + @transient + var restoredTimes = 0 + override def restore() { + restoredTimes += 1 + super.restore() + } + } +} /** * A trait of that can be mixed in to get methods for testing DStream operations under @@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) } - private def generateOutput[V: ClassTag]( + protected def generateOutput[V: ClassTag]( ssc: StreamingContext, targetBatchTime: Time, checkpointDir: String, @@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { } } + test("DStreamCheckpointData.restore invoking times") { + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val checkpointData = inputDStream.checkpointData + val mappedDStream = inputDStream.map(_ + 100) + val outputStream = new TestOutputStreamWithPartitions(mappedDStream) + outputStream.register() + // do two more times output + mappedDStream.foreachRDD(rdd => rdd.count()) + mappedDStream.foreachRDD(rdd => rdd.count()) + assert(checkpointData.restoredTimes === 0) + val batchDurationMillis = ssc.progressListener.batchDuration + generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true) + assert(checkpointData.restoredTimes === 0) + } + logInfo("*********** RESTARTING ************") + withStreamingContext(new StreamingContext(checkpointDir)) { ssc => + val checkpointData = + ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData + assert(checkpointData.restoredTimes === 1) + ssc.start() + ssc.stop() + assert(checkpointData.restoredTimes === 1) + } + } + // This tests whether spark can deserialize array object // refer to SPARK-5569 test("recovery from checkpoint contains array object") { From 0370abdfd636566cd8df954c6f9ea5a794d275ef Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 17 Dec 2015 18:18:12 -0800 Subject: [PATCH 1288/2089] [MINOR] Hide the error logs for 'SQLListenerMemoryLeakSuite' Hide the error logs for 'SQLListenerMemoryLeakSuite' to avoid noises. Most of changes are space changes. Author: Shixiong Zhu Closes #10363 from zsxwing/hide-log. --- .../sql/execution/ui/SQLListenerSuite.scala | 64 ++++++++++--------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 12a4e1356fed0..11a6ce91116f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -336,39 +336,45 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { class SQLListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly - .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) + val oldLogLevel = org.apache.log4j.Logger.getRootLogger().getLevel() try { - SQLContext.clearSqlListener() - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - // Run 100 successful executions and 100 failed executions. - // Each execution only has one job and one stage. - for (i <- 0 until 100) { - val df = Seq( - (1, 1), - (2, 2) - ).toDF() - df.collect() - try { - df.foreach(_ => throw new RuntimeException("Oops")) - } catch { - case e: SparkException => // This is expected for a failed job + org.apache.log4j.Logger.getRootLogger().setLevel(org.apache.log4j.Level.FATAL) + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new SparkContext(conf) + try { + SQLContext.clearSqlListener() + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. + for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() } - sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) - // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) } finally { - sc.stop() + org.apache.log4j.Logger.getRootLogger().setLevel(oldLogLevel) } } } From 40e52a27c74259237dd1906c0e8b54d2ae645dfb Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 18 Dec 2015 00:49:56 -0800 Subject: [PATCH 1289/2089] [CORE][TESTS] minor fix of JavaSerializerSuite Not jira is created. The original test is passed because the class cast is lazy (only when the object's method is invoked). Author: Jeff Zhang Closes #10371 from zjffdu/minor_fix. --- .../apache/spark/serializer/JavaSerializerSuite.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 20f45670bc2ba..6a6ea42797fb6 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -23,13 +23,18 @@ class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - instance.deserialize[JavaSerializer](instance.serialize(serializer)) + val obj = instance.deserialize[JavaSerializer](instance.serialize(serializer)) + // enforce class cast + obj.getClass } test("Deserialize object containing a primitive Class as attribute") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + val obj = instance.deserialize[ContainsPrimitiveClass](instance.serialize( + new ContainsPrimitiveClass())) + // enforce class cast + obj.getClass } } From 2bebaa39d9da33bc93ef682959cd42c1968a6a3e Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Fri, 18 Dec 2015 20:18:00 +0900 Subject: [PATCH 1290/2089] [SPARK-12413] Fix Mesos ZK persistence I believe this fixes SPARK-12413. I'm currently running an integration test to verify. Author: Michael Gummelt Closes #10366 from mgummelt/fix-zk-mesos. --- .../apache/spark/deploy/rest/mesos/MesosRestServer.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index c0b93596508f1..87d0fa8b52fe1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -99,7 +99,11 @@ private[mesos] class MesosSubmitRequestServlet( // cause spark-submit script to look for files in SPARK_HOME instead. // We only need the ability to specify where to find spark-submit script // which user can user spark.executor.home or spark.home configurations. - val environmentVariables = request.environmentVariables.filterKeys(!_.equals("SPARK_HOME")) + // + // Do not use `filterKeys` here to avoid SI-6654, which breaks ZK persistence + val environmentVariables = request.environmentVariables.filter { case (k, _) => + k != "SPARK_HOME" + } val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description From ea59b0f3a6600f8046e5f3f55e89257614fb1f10 Mon Sep 17 00:00:00 2001 From: Jeff L Date: Fri, 18 Dec 2015 15:06:54 +0000 Subject: [PATCH 1291/2089] [SPARK-9057][STREAMING] Twitter example joining to static RDD of word sentiment values Example of joining a static RDD of word sentiments to a streaming RDD of Tweets in order to demo the usage of the transform() method. Author: Jeff L Closes #8431 from Agent007/SPARK-9057. --- data/streaming/AFINN-111.txt | 2477 +++++++++++++++++ .../JavaTwitterHashTagJoinSentiments.java | 180 ++ .../streaming/network_wordjoinsentiments.py | 77 + .../TwitterHashTagJoinSentiments.scala | 96 + 4 files changed, 2830 insertions(+) create mode 100644 data/streaming/AFINN-111.txt create mode 100644 examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java create mode 100644 examples/src/main/python/streaming/network_wordjoinsentiments.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala diff --git a/data/streaming/AFINN-111.txt b/data/streaming/AFINN-111.txt new file mode 100644 index 0000000000000..0f6fb8ebaa0bf --- /dev/null +++ b/data/streaming/AFINN-111.txt @@ -0,0 +1,2477 @@ +abandon -2 +abandoned -2 +abandons -2 +abducted -2 +abduction -2 +abductions -2 +abhor -3 +abhorred -3 +abhorrent -3 +abhors -3 +abilities 2 +ability 2 +aboard 1 +absentee -1 +absentees -1 +absolve 2 +absolved 2 +absolves 2 +absolving 2 +absorbed 1 +abuse -3 +abused -3 +abuses -3 +abusive -3 +accept 1 +accepted 1 +accepting 1 +accepts 1 +accident -2 +accidental -2 +accidentally -2 +accidents -2 +accomplish 2 +accomplished 2 +accomplishes 2 +accusation -2 +accusations -2 +accuse -2 +accused -2 +accuses -2 +accusing -2 +ache -2 +achievable 1 +aching -2 +acquit 2 +acquits 2 +acquitted 2 +acquitting 2 +acrimonious -3 +active 1 +adequate 1 +admire 3 +admired 3 +admires 3 +admiring 3 +admit -1 +admits -1 +admitted -1 +admonish -2 +admonished -2 +adopt 1 +adopts 1 +adorable 3 +adore 3 +adored 3 +adores 3 +advanced 1 +advantage 2 +advantages 2 +adventure 2 +adventures 2 +adventurous 2 +affected -1 +affection 3 +affectionate 3 +afflicted -1 +affronted -1 +afraid -2 +aggravate -2 +aggravated -2 +aggravates -2 +aggravating -2 +aggression -2 +aggressions -2 +aggressive -2 +aghast -2 +agog 2 +agonise -3 +agonised -3 +agonises -3 +agonising -3 +agonize -3 +agonized -3 +agonizes -3 +agonizing -3 +agree 1 +agreeable 2 +agreed 1 +agreement 1 +agrees 1 +alarm -2 +alarmed -2 +alarmist -2 +alarmists -2 +alas -1 +alert -1 +alienation -2 +alive 1 +allergic -2 +allow 1 +alone -2 +amaze 2 +amazed 2 +amazes 2 +amazing 4 +ambitious 2 +ambivalent -1 +amuse 3 +amused 3 +amusement 3 +amusements 3 +anger -3 +angers -3 +angry -3 +anguish -3 +anguished -3 +animosity -2 +annoy -2 +annoyance -2 +annoyed -2 +annoying -2 +annoys -2 +antagonistic -2 +anti -1 +anticipation 1 +anxiety -2 +anxious -2 +apathetic -3 +apathy -3 +apeshit -3 +apocalyptic -2 +apologise -1 +apologised -1 +apologises -1 +apologising -1 +apologize -1 +apologized -1 +apologizes -1 +apologizing -1 +apology -1 +appalled -2 +appalling -2 +appease 2 +appeased 2 +appeases 2 +appeasing 2 +applaud 2 +applauded 2 +applauding 2 +applauds 2 +applause 2 +appreciate 2 +appreciated 2 +appreciates 2 +appreciating 2 +appreciation 2 +apprehensive -2 +approval 2 +approved 2 +approves 2 +ardent 1 +arrest -2 +arrested -3 +arrests -2 +arrogant -2 +ashame -2 +ashamed -2 +ass -4 +assassination -3 +assassinations -3 +asset 2 +assets 2 +assfucking -4 +asshole -4 +astonished 2 +astound 3 +astounded 3 +astounding 3 +astoundingly 3 +astounds 3 +attack -1 +attacked -1 +attacking -1 +attacks -1 +attract 1 +attracted 1 +attracting 2 +attraction 2 +attractions 2 +attracts 1 +audacious 3 +authority 1 +avert -1 +averted -1 +averts -1 +avid 2 +avoid -1 +avoided -1 +avoids -1 +await -1 +awaited -1 +awaits -1 +award 3 +awarded 3 +awards 3 +awesome 4 +awful -3 +awkward -2 +axe -1 +axed -1 +backed 1 +backing 2 +backs 1 +bad -3 +badass -3 +badly -3 +bailout -2 +bamboozle -2 +bamboozled -2 +bamboozles -2 +ban -2 +banish -1 +bankrupt -3 +bankster -3 +banned -2 +bargain 2 +barrier -2 +bastard -5 +bastards -5 +battle -1 +battles -1 +beaten -2 +beatific 3 +beating -1 +beauties 3 +beautiful 3 +beautifully 3 +beautify 3 +belittle -2 +belittled -2 +beloved 3 +benefit 2 +benefits 2 +benefitted 2 +benefitting 2 +bereave -2 +bereaved -2 +bereaves -2 +bereaving -2 +best 3 +betray -3 +betrayal -3 +betrayed -3 +betraying -3 +betrays -3 +better 2 +bias -1 +biased -2 +big 1 +bitch -5 +bitches -5 +bitter -2 +bitterly -2 +bizarre -2 +blah -2 +blame -2 +blamed -2 +blames -2 +blaming -2 +bless 2 +blesses 2 +blessing 3 +blind -1 +bliss 3 +blissful 3 +blithe 2 +block -1 +blockbuster 3 +blocked -1 +blocking -1 +blocks -1 +bloody -3 +blurry -2 +boastful -2 +bold 2 +boldly 2 +bomb -1 +boost 1 +boosted 1 +boosting 1 +boosts 1 +bore -2 +bored -2 +boring -3 +bother -2 +bothered -2 +bothers -2 +bothersome -2 +boycott -2 +boycotted -2 +boycotting -2 +boycotts -2 +brainwashing -3 +brave 2 +breakthrough 3 +breathtaking 5 +bribe -3 +bright 1 +brightest 2 +brightness 1 +brilliant 4 +brisk 2 +broke -1 +broken -1 +brooding -2 +bullied -2 +bullshit -4 +bully -2 +bullying -2 +bummer -2 +buoyant 2 +burden -2 +burdened -2 +burdening -2 +burdens -2 +calm 2 +calmed 2 +calming 2 +calms 2 +can't stand -3 +cancel -1 +cancelled -1 +cancelling -1 +cancels -1 +cancer -1 +capable 1 +captivated 3 +care 2 +carefree 1 +careful 2 +carefully 2 +careless -2 +cares 2 +cashing in -2 +casualty -2 +catastrophe -3 +catastrophic -4 +cautious -1 +celebrate 3 +celebrated 3 +celebrates 3 +celebrating 3 +censor -2 +censored -2 +censors -2 +certain 1 +chagrin -2 +chagrined -2 +challenge -1 +chance 2 +chances 2 +chaos -2 +chaotic -2 +charged -3 +charges -2 +charm 3 +charming 3 +charmless -3 +chastise -3 +chastised -3 +chastises -3 +chastising -3 +cheat -3 +cheated -3 +cheater -3 +cheaters -3 +cheats -3 +cheer 2 +cheered 2 +cheerful 2 +cheering 2 +cheerless -2 +cheers 2 +cheery 3 +cherish 2 +cherished 2 +cherishes 2 +cherishing 2 +chic 2 +childish -2 +chilling -1 +choke -2 +choked -2 +chokes -2 +choking -2 +clarifies 2 +clarity 2 +clash -2 +classy 3 +clean 2 +cleaner 2 +clear 1 +cleared 1 +clearly 1 +clears 1 +clever 2 +clouded -1 +clueless -2 +cock -5 +cocksucker -5 +cocksuckers -5 +cocky -2 +coerced -2 +collapse -2 +collapsed -2 +collapses -2 +collapsing -2 +collide -1 +collides -1 +colliding -1 +collision -2 +collisions -2 +colluding -3 +combat -1 +combats -1 +comedy 1 +comfort 2 +comfortable 2 +comforting 2 +comforts 2 +commend 2 +commended 2 +commit 1 +commitment 2 +commits 1 +committed 1 +committing 1 +compassionate 2 +compelled 1 +competent 2 +competitive 2 +complacent -2 +complain -2 +complained -2 +complains -2 +comprehensive 2 +conciliate 2 +conciliated 2 +conciliates 2 +conciliating 2 +condemn -2 +condemnation -2 +condemned -2 +condemns -2 +confidence 2 +confident 2 +conflict -2 +conflicting -2 +conflictive -2 +conflicts -2 +confuse -2 +confused -2 +confusing -2 +congrats 2 +congratulate 2 +congratulation 2 +congratulations 2 +consent 2 +consents 2 +consolable 2 +conspiracy -3 +constrained -2 +contagion -2 +contagions -2 +contagious -1 +contempt -2 +contemptuous -2 +contemptuously -2 +contend -1 +contender -1 +contending -1 +contentious -2 +contestable -2 +controversial -2 +controversially -2 +convince 1 +convinced 1 +convinces 1 +convivial 2 +cool 1 +cool stuff 3 +cornered -2 +corpse -1 +costly -2 +courage 2 +courageous 2 +courteous 2 +courtesy 2 +cover-up -3 +coward -2 +cowardly -2 +coziness 2 +cramp -1 +crap -3 +crash -2 +crazier -2 +craziest -2 +crazy -2 +creative 2 +crestfallen -2 +cried -2 +cries -2 +crime -3 +criminal -3 +criminals -3 +crisis -3 +critic -2 +criticism -2 +criticize -2 +criticized -2 +criticizes -2 +criticizing -2 +critics -2 +cruel -3 +cruelty -3 +crush -1 +crushed -2 +crushes -1 +crushing -1 +cry -1 +crying -2 +cunt -5 +curious 1 +curse -1 +cut -1 +cute 2 +cuts -1 +cutting -1 +cynic -2 +cynical -2 +cynicism -2 +damage -3 +damages -3 +damn -4 +damned -4 +damnit -4 +danger -2 +daredevil 2 +daring 2 +darkest -2 +darkness -1 +dauntless 2 +dead -3 +deadlock -2 +deafening -1 +dear 2 +dearly 3 +death -2 +debonair 2 +debt -2 +deceit -3 +deceitful -3 +deceive -3 +deceived -3 +deceives -3 +deceiving -3 +deception -3 +decisive 1 +dedicated 2 +defeated -2 +defect -3 +defects -3 +defender 2 +defenders 2 +defenseless -2 +defer -1 +deferring -1 +defiant -1 +deficit -2 +degrade -2 +degraded -2 +degrades -2 +dehumanize -2 +dehumanized -2 +dehumanizes -2 +dehumanizing -2 +deject -2 +dejected -2 +dejecting -2 +dejects -2 +delay -1 +delayed -1 +delight 3 +delighted 3 +delighting 3 +delights 3 +demand -1 +demanded -1 +demanding -1 +demands -1 +demonstration -1 +demoralized -2 +denied -2 +denier -2 +deniers -2 +denies -2 +denounce -2 +denounces -2 +deny -2 +denying -2 +depressed -2 +depressing -2 +derail -2 +derailed -2 +derails -2 +deride -2 +derided -2 +derides -2 +deriding -2 +derision -2 +desirable 2 +desire 1 +desired 2 +desirous 2 +despair -3 +despairing -3 +despairs -3 +desperate -3 +desperately -3 +despondent -3 +destroy -3 +destroyed -3 +destroying -3 +destroys -3 +destruction -3 +destructive -3 +detached -1 +detain -2 +detained -2 +detention -2 +determined 2 +devastate -2 +devastated -2 +devastating -2 +devoted 3 +diamond 1 +dick -4 +dickhead -4 +die -3 +died -3 +difficult -1 +diffident -2 +dilemma -1 +dipshit -3 +dire -3 +direful -3 +dirt -2 +dirtier -2 +dirtiest -2 +dirty -2 +disabling -1 +disadvantage -2 +disadvantaged -2 +disappear -1 +disappeared -1 +disappears -1 +disappoint -2 +disappointed -2 +disappointing -2 +disappointment -2 +disappointments -2 +disappoints -2 +disaster -2 +disasters -2 +disastrous -3 +disbelieve -2 +discard -1 +discarded -1 +discarding -1 +discards -1 +disconsolate -2 +disconsolation -2 +discontented -2 +discord -2 +discounted -1 +discouraged -2 +discredited -2 +disdain -2 +disgrace -2 +disgraced -2 +disguise -1 +disguised -1 +disguises -1 +disguising -1 +disgust -3 +disgusted -3 +disgusting -3 +disheartened -2 +dishonest -2 +disillusioned -2 +disinclined -2 +disjointed -2 +dislike -2 +dismal -2 +dismayed -2 +disorder -2 +disorganized -2 +disoriented -2 +disparage -2 +disparaged -2 +disparages -2 +disparaging -2 +displeased -2 +dispute -2 +disputed -2 +disputes -2 +disputing -2 +disqualified -2 +disquiet -2 +disregard -2 +disregarded -2 +disregarding -2 +disregards -2 +disrespect -2 +disrespected -2 +disruption -2 +disruptions -2 +disruptive -2 +dissatisfied -2 +distort -2 +distorted -2 +distorting -2 +distorts -2 +distract -2 +distracted -2 +distraction -2 +distracts -2 +distress -2 +distressed -2 +distresses -2 +distressing -2 +distrust -3 +distrustful -3 +disturb -2 +disturbed -2 +disturbing -2 +disturbs -2 +dithering -2 +dizzy -1 +dodging -2 +dodgy -2 +does not work -3 +dolorous -2 +dont like -2 +doom -2 +doomed -2 +doubt -1 +doubted -1 +doubtful -1 +doubting -1 +doubts -1 +douche -3 +douchebag -3 +downcast -2 +downhearted -2 +downside -2 +drag -1 +dragged -1 +drags -1 +drained -2 +dread -2 +dreaded -2 +dreadful -3 +dreading -2 +dream 1 +dreams 1 +dreary -2 +droopy -2 +drop -1 +drown -2 +drowned -2 +drowns -2 +drunk -2 +dubious -2 +dud -2 +dull -2 +dumb -3 +dumbass -3 +dump -1 +dumped -2 +dumps -1 +dupe -2 +duped -2 +dysfunction -2 +eager 2 +earnest 2 +ease 2 +easy 1 +ecstatic 4 +eerie -2 +eery -2 +effective 2 +effectively 2 +elated 3 +elation 3 +elegant 2 +elegantly 2 +embarrass -2 +embarrassed -2 +embarrasses -2 +embarrassing -2 +embarrassment -2 +embittered -2 +embrace 1 +emergency -2 +empathetic 2 +emptiness -1 +empty -1 +enchanted 2 +encourage 2 +encouraged 2 +encouragement 2 +encourages 2 +endorse 2 +endorsed 2 +endorsement 2 +endorses 2 +enemies -2 +enemy -2 +energetic 2 +engage 1 +engages 1 +engrossed 1 +enjoy 2 +enjoying 2 +enjoys 2 +enlighten 2 +enlightened 2 +enlightening 2 +enlightens 2 +ennui -2 +enrage -2 +enraged -2 +enrages -2 +enraging -2 +enrapture 3 +enslave -2 +enslaved -2 +enslaves -2 +ensure 1 +ensuring 1 +enterprising 1 +entertaining 2 +enthral 3 +enthusiastic 3 +entitled 1 +entrusted 2 +envies -1 +envious -2 +envy -1 +envying -1 +erroneous -2 +error -2 +errors -2 +escape -1 +escapes -1 +escaping -1 +esteemed 2 +ethical 2 +euphoria 3 +euphoric 4 +eviction -1 +evil -3 +exaggerate -2 +exaggerated -2 +exaggerates -2 +exaggerating -2 +exasperated 2 +excellence 3 +excellent 3 +excite 3 +excited 3 +excitement 3 +exciting 3 +exclude -1 +excluded -2 +exclusion -1 +exclusive 2 +excuse -1 +exempt -1 +exhausted -2 +exhilarated 3 +exhilarates 3 +exhilarating 3 +exonerate 2 +exonerated 2 +exonerates 2 +exonerating 2 +expand 1 +expands 1 +expel -2 +expelled -2 +expelling -2 +expels -2 +exploit -2 +exploited -2 +exploiting -2 +exploits -2 +exploration 1 +explorations 1 +expose -1 +exposed -1 +exposes -1 +exposing -1 +extend 1 +extends 1 +exuberant 4 +exultant 3 +exultantly 3 +fabulous 4 +fad -2 +fag -3 +faggot -3 +faggots -3 +fail -2 +failed -2 +failing -2 +fails -2 +failure -2 +failures -2 +fainthearted -2 +fair 2 +faith 1 +faithful 3 +fake -3 +fakes -3 +faking -3 +fallen -2 +falling -1 +falsified -3 +falsify -3 +fame 1 +fan 3 +fantastic 4 +farce -1 +fascinate 3 +fascinated 3 +fascinates 3 +fascinating 3 +fascist -2 +fascists -2 +fatalities -3 +fatality -3 +fatigue -2 +fatigued -2 +fatigues -2 +fatiguing -2 +favor 2 +favored 2 +favorite 2 +favorited 2 +favorites 2 +favors 2 +fear -2 +fearful -2 +fearing -2 +fearless 2 +fearsome -2 +fed up -3 +feeble -2 +feeling 1 +felonies -3 +felony -3 +fervent 2 +fervid 2 +festive 2 +fiasco -3 +fidgety -2 +fight -1 +fine 2 +fire -2 +fired -2 +firing -2 +fit 1 +fitness 1 +flagship 2 +flees -1 +flop -2 +flops -2 +flu -2 +flustered -2 +focused 2 +fond 2 +fondness 2 +fool -2 +foolish -2 +fools -2 +forced -1 +foreclosure -2 +foreclosures -2 +forget -1 +forgetful -2 +forgive 1 +forgiving 1 +forgotten -1 +fortunate 2 +frantic -1 +fraud -4 +frauds -4 +fraudster -4 +fraudsters -4 +fraudulence -4 +fraudulent -4 +free 1 +freedom 2 +frenzy -3 +fresh 1 +friendly 2 +fright -2 +frightened -2 +frightening -3 +frikin -2 +frisky 2 +frowning -1 +frustrate -2 +frustrated -2 +frustrates -2 +frustrating -2 +frustration -2 +ftw 3 +fuck -4 +fucked -4 +fucker -4 +fuckers -4 +fuckface -4 +fuckhead -4 +fucking -4 +fucktard -4 +fud -3 +fuked -4 +fuking -4 +fulfill 2 +fulfilled 2 +fulfills 2 +fuming -2 +fun 4 +funeral -1 +funerals -1 +funky 2 +funnier 4 +funny 4 +furious -3 +futile 2 +gag -2 +gagged -2 +gain 2 +gained 2 +gaining 2 +gains 2 +gallant 3 +gallantly 3 +gallantry 3 +generous 2 +genial 3 +ghost -1 +giddy -2 +gift 2 +glad 3 +glamorous 3 +glamourous 3 +glee 3 +gleeful 3 +gloom -1 +gloomy -2 +glorious 2 +glory 2 +glum -2 +god 1 +goddamn -3 +godsend 4 +good 3 +goodness 3 +grace 1 +gracious 3 +grand 3 +grant 1 +granted 1 +granting 1 +grants 1 +grateful 3 +gratification 2 +grave -2 +gray -1 +great 3 +greater 3 +greatest 3 +greed -3 +greedy -2 +green wash -3 +green washing -3 +greenwash -3 +greenwasher -3 +greenwashers -3 +greenwashing -3 +greet 1 +greeted 1 +greeting 1 +greetings 2 +greets 1 +grey -1 +grief -2 +grieved -2 +gross -2 +growing 1 +growth 2 +guarantee 1 +guilt -3 +guilty -3 +gullibility -2 +gullible -2 +gun -1 +ha 2 +hacked -1 +haha 3 +hahaha 3 +hahahah 3 +hail 2 +hailed 2 +hapless -2 +haplessness -2 +happiness 3 +happy 3 +hard -1 +hardier 2 +hardship -2 +hardy 2 +harm -2 +harmed -2 +harmful -2 +harming -2 +harms -2 +harried -2 +harsh -2 +harsher -2 +harshest -2 +hate -3 +hated -3 +haters -3 +hates -3 +hating -3 +haunt -1 +haunted -2 +haunting 1 +haunts -1 +havoc -2 +healthy 2 +heartbreaking -3 +heartbroken -3 +heartfelt 3 +heaven 2 +heavenly 4 +heavyhearted -2 +hell -4 +help 2 +helpful 2 +helping 2 +helpless -2 +helps 2 +hero 2 +heroes 2 +heroic 3 +hesitant -2 +hesitate -2 +hid -1 +hide -1 +hides -1 +hiding -1 +highlight 2 +hilarious 2 +hindrance -2 +hoax -2 +homesick -2 +honest 2 +honor 2 +honored 2 +honoring 2 +honour 2 +honoured 2 +honouring 2 +hooligan -2 +hooliganism -2 +hooligans -2 +hope 2 +hopeful 2 +hopefully 2 +hopeless -2 +hopelessness -2 +hopes 2 +hoping 2 +horrendous -3 +horrible -3 +horrific -3 +horrified -3 +hostile -2 +huckster -2 +hug 2 +huge 1 +hugs 2 +humerous 3 +humiliated -3 +humiliation -3 +humor 2 +humorous 2 +humour 2 +humourous 2 +hunger -2 +hurrah 5 +hurt -2 +hurting -2 +hurts -2 +hypocritical -2 +hysteria -3 +hysterical -3 +hysterics -3 +idiot -3 +idiotic -3 +ignorance -2 +ignorant -2 +ignore -1 +ignored -2 +ignores -1 +ill -2 +illegal -3 +illiteracy -2 +illness -2 +illnesses -2 +imbecile -3 +immobilized -1 +immortal 2 +immune 1 +impatient -2 +imperfect -2 +importance 2 +important 2 +impose -1 +imposed -1 +imposes -1 +imposing -1 +impotent -2 +impress 3 +impressed 3 +impresses 3 +impressive 3 +imprisoned -2 +improve 2 +improved 2 +improvement 2 +improves 2 +improving 2 +inability -2 +inaction -2 +inadequate -2 +incapable -2 +incapacitated -2 +incensed -2 +incompetence -2 +incompetent -2 +inconsiderate -2 +inconvenience -2 +inconvenient -2 +increase 1 +increased 1 +indecisive -2 +indestructible 2 +indifference -2 +indifferent -2 +indignant -2 +indignation -2 +indoctrinate -2 +indoctrinated -2 +indoctrinates -2 +indoctrinating -2 +ineffective -2 +ineffectively -2 +infatuated 2 +infatuation 2 +infected -2 +inferior -2 +inflamed -2 +influential 2 +infringement -2 +infuriate -2 +infuriated -2 +infuriates -2 +infuriating -2 +inhibit -1 +injured -2 +injury -2 +injustice -2 +innovate 1 +innovates 1 +innovation 1 +innovative 2 +inquisition -2 +inquisitive 2 +insane -2 +insanity -2 +insecure -2 +insensitive -2 +insensitivity -2 +insignificant -2 +insipid -2 +inspiration 2 +inspirational 2 +inspire 2 +inspired 2 +inspires 2 +inspiring 3 +insult -2 +insulted -2 +insulting -2 +insults -2 +intact 2 +integrity 2 +intelligent 2 +intense 1 +interest 1 +interested 2 +interesting 2 +interests 1 +interrogated -2 +interrupt -2 +interrupted -2 +interrupting -2 +interruption -2 +interrupts -2 +intimidate -2 +intimidated -2 +intimidates -2 +intimidating -2 +intimidation -2 +intricate 2 +intrigues 1 +invincible 2 +invite 1 +inviting 1 +invulnerable 2 +irate -3 +ironic -1 +irony -1 +irrational -1 +irresistible 2 +irresolute -2 +irresponsible 2 +irreversible -1 +irritate -3 +irritated -3 +irritating -3 +isolated -1 +itchy -2 +jackass -4 +jackasses -4 +jailed -2 +jaunty 2 +jealous -2 +jeopardy -2 +jerk -3 +jesus 1 +jewel 1 +jewels 1 +jocular 2 +join 1 +joke 2 +jokes 2 +jolly 2 +jovial 2 +joy 3 +joyful 3 +joyfully 3 +joyless -2 +joyous 3 +jubilant 3 +jumpy -1 +justice 2 +justifiably 2 +justified 2 +keen 1 +kill -3 +killed -3 +killing -3 +kills -3 +kind 2 +kinder 2 +kiss 2 +kudos 3 +lack -2 +lackadaisical -2 +lag -1 +lagged -2 +lagging -2 +lags -2 +lame -2 +landmark 2 +laugh 1 +laughed 1 +laughing 1 +laughs 1 +laughting 1 +launched 1 +lawl 3 +lawsuit -2 +lawsuits -2 +lazy -1 +leak -1 +leaked -1 +leave -1 +legal 1 +legally 1 +lenient 1 +lethargic -2 +lethargy -2 +liar -3 +liars -3 +libelous -2 +lied -2 +lifesaver 4 +lighthearted 1 +like 2 +liked 2 +likes 2 +limitation -1 +limited -1 +limits -1 +litigation -1 +litigious -2 +lively 2 +livid -2 +lmao 4 +lmfao 4 +loathe -3 +loathed -3 +loathes -3 +loathing -3 +lobby -2 +lobbying -2 +lol 3 +lonely -2 +lonesome -2 +longing -1 +loom -1 +loomed -1 +looming -1 +looms -1 +loose -3 +looses -3 +loser -3 +losing -3 +loss -3 +lost -3 +lovable 3 +love 3 +loved 3 +lovelies 3 +lovely 3 +loving 2 +lowest -1 +loyal 3 +loyalty 3 +luck 3 +luckily 3 +lucky 3 +lugubrious -2 +lunatic -3 +lunatics -3 +lurk -1 +lurking -1 +lurks -1 +mad -3 +maddening -3 +made-up -1 +madly -3 +madness -3 +mandatory -1 +manipulated -1 +manipulating -1 +manipulation -1 +marvel 3 +marvelous 3 +marvels 3 +masterpiece 4 +masterpieces 4 +matter 1 +matters 1 +mature 2 +meaningful 2 +meaningless -2 +medal 3 +mediocrity -3 +meditative 1 +melancholy -2 +menace -2 +menaced -2 +mercy 2 +merry 3 +mess -2 +messed -2 +messing up -2 +methodical 2 +mindless -2 +miracle 4 +mirth 3 +mirthful 3 +mirthfully 3 +misbehave -2 +misbehaved -2 +misbehaves -2 +misbehaving -2 +mischief -1 +mischiefs -1 +miserable -3 +misery -2 +misgiving -2 +misinformation -2 +misinformed -2 +misinterpreted -2 +misleading -3 +misread -1 +misreporting -2 +misrepresentation -2 +miss -2 +missed -2 +missing -2 +mistake -2 +mistaken -2 +mistakes -2 +mistaking -2 +misunderstand -2 +misunderstanding -2 +misunderstands -2 +misunderstood -2 +moan -2 +moaned -2 +moaning -2 +moans -2 +mock -2 +mocked -2 +mocking -2 +mocks -2 +mongering -2 +monopolize -2 +monopolized -2 +monopolizes -2 +monopolizing -2 +moody -1 +mope -1 +moping -1 +moron -3 +motherfucker -5 +motherfucking -5 +motivate 1 +motivated 2 +motivating 2 +motivation 1 +mourn -2 +mourned -2 +mournful -2 +mourning -2 +mourns -2 +mumpish -2 +murder -2 +murderer -2 +murdering -3 +murderous -3 +murders -2 +myth -1 +n00b -2 +naive -2 +nasty -3 +natural 1 +naïve -2 +needy -2 +negative -2 +negativity -2 +neglect -2 +neglected -2 +neglecting -2 +neglects -2 +nerves -1 +nervous -2 +nervously -2 +nice 3 +nifty 2 +niggas -5 +nigger -5 +no -1 +no fun -3 +noble 2 +noisy -1 +nonsense -2 +noob -2 +nosey -2 +not good -2 +not working -3 +notorious -2 +novel 2 +numb -1 +nuts -3 +obliterate -2 +obliterated -2 +obnoxious -3 +obscene -2 +obsessed 2 +obsolete -2 +obstacle -2 +obstacles -2 +obstinate -2 +odd -2 +offend -2 +offended -2 +offender -2 +offending -2 +offends -2 +offline -1 +oks 2 +ominous 3 +once-in-a-lifetime 3 +opportunities 2 +opportunity 2 +oppressed -2 +oppressive -2 +optimism 2 +optimistic 2 +optionless -2 +outcry -2 +outmaneuvered -2 +outrage -3 +outraged -3 +outreach 2 +outstanding 5 +overjoyed 4 +overload -1 +overlooked -1 +overreact -2 +overreacted -2 +overreaction -2 +overreacts -2 +oversell -2 +overselling -2 +oversells -2 +oversimplification -2 +oversimplified -2 +oversimplifies -2 +oversimplify -2 +overstatement -2 +overstatements -2 +overweight -1 +oxymoron -1 +pain -2 +pained -2 +panic -3 +panicked -3 +panics -3 +paradise 3 +paradox -1 +pardon 2 +pardoned 2 +pardoning 2 +pardons 2 +parley -1 +passionate 2 +passive -1 +passively -1 +pathetic -2 +pay -1 +peace 2 +peaceful 2 +peacefully 2 +penalty -2 +pensive -1 +perfect 3 +perfected 2 +perfectly 3 +perfects 2 +peril -2 +perjury -3 +perpetrator -2 +perpetrators -2 +perplexed -2 +persecute -2 +persecuted -2 +persecutes -2 +persecuting -2 +perturbed -2 +pesky -2 +pessimism -2 +pessimistic -2 +petrified -2 +phobic -2 +picturesque 2 +pileup -1 +pique -2 +piqued -2 +piss -4 +pissed -4 +pissing -3 +piteous -2 +pitied -1 +pity -2 +playful 2 +pleasant 3 +please 1 +pleased 3 +pleasure 3 +poised -2 +poison -2 +poisoned -2 +poisons -2 +pollute -2 +polluted -2 +polluter -2 +polluters -2 +pollutes -2 +poor -2 +poorer -2 +poorest -2 +popular 3 +positive 2 +positively 2 +possessive -2 +postpone -1 +postponed -1 +postpones -1 +postponing -1 +poverty -1 +powerful 2 +powerless -2 +praise 3 +praised 3 +praises 3 +praising 3 +pray 1 +praying 1 +prays 1 +prblm -2 +prblms -2 +prepared 1 +pressure -1 +pressured -2 +pretend -1 +pretending -1 +pretends -1 +pretty 1 +prevent -1 +prevented -1 +preventing -1 +prevents -1 +prick -5 +prison -2 +prisoner -2 +prisoners -2 +privileged 2 +proactive 2 +problem -2 +problems -2 +profiteer -2 +progress 2 +prominent 2 +promise 1 +promised 1 +promises 1 +promote 1 +promoted 1 +promotes 1 +promoting 1 +propaganda -2 +prosecute -1 +prosecuted -2 +prosecutes -1 +prosecution -1 +prospect 1 +prospects 1 +prosperous 3 +protect 1 +protected 1 +protects 1 +protest -2 +protesters -2 +protesting -2 +protests -2 +proud 2 +proudly 2 +provoke -1 +provoked -1 +provokes -1 +provoking -1 +pseudoscience -3 +punish -2 +punished -2 +punishes -2 +punitive -2 +pushy -1 +puzzled -2 +quaking -2 +questionable -2 +questioned -1 +questioning -1 +racism -3 +racist -3 +racists -3 +rage -2 +rageful -2 +rainy -1 +rant -3 +ranter -3 +ranters -3 +rants -3 +rape -4 +rapist -4 +rapture 2 +raptured 2 +raptures 2 +rapturous 4 +rash -2 +ratified 2 +reach 1 +reached 1 +reaches 1 +reaching 1 +reassure 1 +reassured 1 +reassures 1 +reassuring 2 +rebellion -2 +recession -2 +reckless -2 +recommend 2 +recommended 2 +recommends 2 +redeemed 2 +refuse -2 +refused -2 +refusing -2 +regret -2 +regretful -2 +regrets -2 +regretted -2 +regretting -2 +reject -1 +rejected -1 +rejecting -1 +rejects -1 +rejoice 4 +rejoiced 4 +rejoices 4 +rejoicing 4 +relaxed 2 +relentless -1 +reliant 2 +relieve 1 +relieved 2 +relieves 1 +relieving 2 +relishing 2 +remarkable 2 +remorse -2 +repulse -1 +repulsed -2 +rescue 2 +rescued 2 +rescues 2 +resentful -2 +resign -1 +resigned -1 +resigning -1 +resigns -1 +resolute 2 +resolve 2 +resolved 2 +resolves 2 +resolving 2 +respected 2 +responsible 2 +responsive 2 +restful 2 +restless -2 +restore 1 +restored 1 +restores 1 +restoring 1 +restrict -2 +restricted -2 +restricting -2 +restriction -2 +restricts -2 +retained -1 +retard -2 +retarded -2 +retreat -1 +revenge -2 +revengeful -2 +revered 2 +revive 2 +revives 2 +reward 2 +rewarded 2 +rewarding 2 +rewards 2 +rich 2 +ridiculous -3 +rig -1 +rigged -1 +right direction 3 +rigorous 3 +rigorously 3 +riot -2 +riots -2 +risk -2 +risks -2 +rob -2 +robber -2 +robed -2 +robing -2 +robs -2 +robust 2 +rofl 4 +roflcopter 4 +roflmao 4 +romance 2 +rotfl 4 +rotflmfao 4 +rotflol 4 +ruin -2 +ruined -2 +ruining -2 +ruins -2 +sabotage -2 +sad -2 +sadden -2 +saddened -2 +sadly -2 +safe 1 +safely 1 +safety 1 +salient 1 +sappy -1 +sarcastic -2 +satisfied 2 +save 2 +saved 2 +scam -2 +scams -2 +scandal -3 +scandalous -3 +scandals -3 +scapegoat -2 +scapegoats -2 +scare -2 +scared -2 +scary -2 +sceptical -2 +scold -2 +scoop 3 +scorn -2 +scornful -2 +scream -2 +screamed -2 +screaming -2 +screams -2 +screwed -2 +screwed up -3 +scumbag -4 +secure 2 +secured 2 +secures 2 +sedition -2 +seditious -2 +seduced -1 +self-confident 2 +self-deluded -2 +selfish -3 +selfishness -3 +sentence -2 +sentenced -2 +sentences -2 +sentencing -2 +serene 2 +severe -2 +sexy 3 +shaky -2 +shame -2 +shamed -2 +shameful -2 +share 1 +shared 1 +shares 1 +shattered -2 +shit -4 +shithead -4 +shitty -3 +shock -2 +shocked -2 +shocking -2 +shocks -2 +shoot -1 +short-sighted -2 +short-sightedness -2 +shortage -2 +shortages -2 +shrew -4 +shy -1 +sick -2 +sigh -2 +significance 1 +significant 1 +silencing -1 +silly -1 +sincere 2 +sincerely 2 +sincerest 2 +sincerity 2 +sinful -3 +singleminded -2 +skeptic -2 +skeptical -2 +skepticism -2 +skeptics -2 +slam -2 +slash -2 +slashed -2 +slashes -2 +slashing -2 +slavery -3 +sleeplessness -2 +slick 2 +slicker 2 +slickest 2 +sluggish -2 +slut -5 +smart 1 +smarter 2 +smartest 2 +smear -2 +smile 2 +smiled 2 +smiles 2 +smiling 2 +smog -2 +sneaky -1 +snub -2 +snubbed -2 +snubbing -2 +snubs -2 +sobering 1 +solemn -1 +solid 2 +solidarity 2 +solution 1 +solutions 1 +solve 1 +solved 1 +solves 1 +solving 1 +somber -2 +some kind 0 +son-of-a-bitch -5 +soothe 3 +soothed 3 +soothing 3 +sophisticated 2 +sore -1 +sorrow -2 +sorrowful -2 +sorry -1 +spam -2 +spammer -3 +spammers -3 +spamming -2 +spark 1 +sparkle 3 +sparkles 3 +sparkling 3 +speculative -2 +spirit 1 +spirited 2 +spiritless -2 +spiteful -2 +splendid 3 +sprightly 2 +squelched -1 +stab -2 +stabbed -2 +stable 2 +stabs -2 +stall -2 +stalled -2 +stalling -2 +stamina 2 +stampede -2 +startled -2 +starve -2 +starved -2 +starves -2 +starving -2 +steadfast 2 +steal -2 +steals -2 +stereotype -2 +stereotyped -2 +stifled -1 +stimulate 1 +stimulated 1 +stimulates 1 +stimulating 2 +stingy -2 +stolen -2 +stop -1 +stopped -1 +stopping -1 +stops -1 +stout 2 +straight 1 +strange -1 +strangely -1 +strangled -2 +strength 2 +strengthen 2 +strengthened 2 +strengthening 2 +strengthens 2 +stressed -2 +stressor -2 +stressors -2 +stricken -2 +strike -1 +strikers -2 +strikes -1 +strong 2 +stronger 2 +strongest 2 +struck -1 +struggle -2 +struggled -2 +struggles -2 +struggling -2 +stubborn -2 +stuck -2 +stunned -2 +stunning 4 +stupid -2 +stupidly -2 +suave 2 +substantial 1 +substantially 1 +subversive -2 +success 2 +successful 3 +suck -3 +sucks -3 +suffer -2 +suffering -2 +suffers -2 +suicidal -2 +suicide -2 +suing -2 +sulking -2 +sulky -2 +sullen -2 +sunshine 2 +super 3 +superb 5 +superior 2 +support 2 +supported 2 +supporter 1 +supporters 1 +supporting 1 +supportive 2 +supports 2 +survived 2 +surviving 2 +survivor 2 +suspect -1 +suspected -1 +suspecting -1 +suspects -1 +suspend -1 +suspended -1 +suspicious -2 +swear -2 +swearing -2 +swears -2 +sweet 2 +swift 2 +swiftly 2 +swindle -3 +swindles -3 +swindling -3 +sympathetic 2 +sympathy 2 +tard -2 +tears -2 +tender 2 +tense -2 +tension -1 +terrible -3 +terribly -3 +terrific 4 +terrified -3 +terror -3 +terrorize -3 +terrorized -3 +terrorizes -3 +thank 2 +thankful 2 +thanks 2 +thorny -2 +thoughtful 2 +thoughtless -2 +threat -2 +threaten -2 +threatened -2 +threatening -2 +threatens -2 +threats -2 +thrilled 5 +thwart -2 +thwarted -2 +thwarting -2 +thwarts -2 +timid -2 +timorous -2 +tired -2 +tits -2 +tolerant 2 +toothless -2 +top 2 +tops 2 +torn -2 +torture -4 +tortured -4 +tortures -4 +torturing -4 +totalitarian -2 +totalitarianism -2 +tout -2 +touted -2 +touting -2 +touts -2 +tragedy -2 +tragic -2 +tranquil 2 +trap -1 +trapped -2 +trauma -3 +traumatic -3 +travesty -2 +treason -3 +treasonous -3 +treasure 2 +treasures 2 +trembling -2 +tremulous -2 +tricked -2 +trickery -2 +triumph 4 +triumphant 4 +trouble -2 +troubled -2 +troubles -2 +true 2 +trust 1 +trusted 2 +tumor -2 +twat -5 +ugly -3 +unacceptable -2 +unappreciated -2 +unapproved -2 +unaware -2 +unbelievable -1 +unbelieving -1 +unbiased 2 +uncertain -1 +unclear -1 +uncomfortable -2 +unconcerned -2 +unconfirmed -1 +unconvinced -1 +uncredited -1 +undecided -1 +underestimate -1 +underestimated -1 +underestimates -1 +underestimating -1 +undermine -2 +undermined -2 +undermines -2 +undermining -2 +undeserving -2 +undesirable -2 +uneasy -2 +unemployment -2 +unequal -1 +unequaled 2 +unethical -2 +unfair -2 +unfocused -2 +unfulfilled -2 +unhappy -2 +unhealthy -2 +unified 1 +unimpressed -2 +unintelligent -2 +united 1 +unjust -2 +unlovable -2 +unloved -2 +unmatched 1 +unmotivated -2 +unprofessional -2 +unresearched -2 +unsatisfied -2 +unsecured -2 +unsettled -1 +unsophisticated -2 +unstable -2 +unstoppable 2 +unsupported -2 +unsure -1 +untarnished 2 +unwanted -2 +unworthy -2 +upset -2 +upsets -2 +upsetting -2 +uptight -2 +urgent -1 +useful 2 +usefulness 2 +useless -2 +uselessness -2 +vague -2 +validate 1 +validated 1 +validates 1 +validating 1 +verdict -1 +verdicts -1 +vested 1 +vexation -2 +vexing -2 +vibrant 3 +vicious -2 +victim -3 +victimize -3 +victimized -3 +victimizes -3 +victimizing -3 +victims -3 +vigilant 3 +vile -3 +vindicate 2 +vindicated 2 +vindicates 2 +vindicating 2 +violate -2 +violated -2 +violates -2 +violating -2 +violence -3 +violent -3 +virtuous 2 +virulent -2 +vision 1 +visionary 3 +visioning 1 +visions 1 +vitality 3 +vitamin 1 +vitriolic -3 +vivacious 3 +vociferous -1 +vulnerability -2 +vulnerable -2 +walkout -2 +walkouts -2 +wanker -3 +want 1 +war -2 +warfare -2 +warm 1 +warmth 2 +warn -2 +warned -2 +warning -3 +warnings -3 +warns -2 +waste -1 +wasted -2 +wasting -2 +wavering -1 +weak -2 +weakness -2 +wealth 3 +wealthy 2 +weary -2 +weep -2 +weeping -2 +weird -2 +welcome 2 +welcomed 2 +welcomes 2 +whimsical 1 +whitewash -3 +whore -4 +wicked -2 +widowed -1 +willingness 2 +win 4 +winner 4 +winning 4 +wins 4 +winwin 3 +wish 1 +wishes 1 +wishing 1 +withdrawal -3 +woebegone -2 +woeful -3 +won 3 +wonderful 4 +woo 3 +woohoo 3 +wooo 4 +woow 4 +worn -1 +worried -3 +worry -3 +worrying -3 +worse -3 +worsen -3 +worsened -3 +worsening -3 +worsens -3 +worshiped 3 +worst -3 +worth 2 +worthless -2 +worthy 2 +wow 4 +wowow 4 +wowww 4 +wrathful -3 +wreck -2 +wrong -2 +wronged -2 +wtf -4 +yeah 1 +yearning 1 +yeees 2 +yes 1 +youthful 2 +yucky -2 +yummy 3 +zealot -2 +zealots -2 +zealous 2 \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java new file mode 100644 index 0000000000000..030ee30b93381 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import org.apache.commons.io.IOUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.twitter.TwitterUtils; +import scala.Tuple2; +import twitter4j.Status; + +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +/** + * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of + * the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) + */ +public class JavaTwitterHashTagJoinSentiments { + + public static void main(String[] args) throws IOException { + if (args.length < 4) { + System.err.println("Usage: JavaTwitterHashTagJoinSentiments " + + " []"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + String consumerKey = args[0]; + String consumerSecret = args[1]; + String accessToken = args[2]; + String accessTokenSecret = args[3]; + String[] filters = Arrays.copyOfRange(args, 4, args.length); + + // Set the system properties so that Twitter4j library used by Twitter stream + // can use them to generate OAuth credentials + System.setProperty("twitter4j.oauth.consumerKey", consumerKey); + System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret); + System.setProperty("twitter4j.oauth.accessToken", accessToken); + System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret); + + SparkConf sparkConf = new SparkConf().setAppName("JavaTwitterHashTagJoinSentiments"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); + JavaReceiverInputDStream stream = TwitterUtils.createStream(jssc, filters); + + JavaDStream words = stream.flatMap(new FlatMapFunction() { + @Override + public Iterable call(Status s) { + return Arrays.asList(s.getText().split(" ")); + } + }); + + JavaDStream hashTags = words.filter(new Function() { + @Override + public Boolean call(String word) throws Exception { + return word.startsWith("#"); + } + }); + + // Read in the word-sentiment list and create a static RDD from it + String wordSentimentFilePath = "data/streaming/AFINN-111.txt"; + final JavaPairRDD wordSentiments = jssc.sparkContext().textFile(wordSentimentFilePath) + .mapToPair(new PairFunction(){ + @Override + public Tuple2 call(String line) { + String[] columns = line.split("\t"); + return new Tuple2(columns[0], + Double.parseDouble(columns[1])); + } + }); + + JavaPairDStream hashTagCount = hashTags.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + // leave out the # character + return new Tuple2(s.substring(1), 1); + } + }); + + JavaPairDStream hashTagTotals = hashTagCount.reduceByKeyAndWindow( + new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }, new Duration(10000)); + + // Determine the hash tags with the highest sentiment values by joining the streaming RDD + // with the static RDD inside the transform() method and then multiplying + // the frequency of the hash tag by its sentiment value + JavaPairDStream> joinedTuples = + hashTagTotals.transformToPair(new Function, + JavaPairRDD>>() { + @Override + public JavaPairRDD> call(JavaPairRDD topicCount) + throws Exception { + return wordSentiments.join(topicCount); + } + }); + + JavaPairDStream topicHappiness = joinedTuples.mapToPair( + new PairFunction>, String, Double>() { + @Override + public Tuple2 call(Tuple2> topicAndTuplePair) throws Exception { + Tuple2 happinessAndCount = topicAndTuplePair._2(); + return new Tuple2(topicAndTuplePair._1(), + happinessAndCount._1() * happinessAndCount._2()); + } + }); + + JavaPairDStream happinessTopicPairs = topicHappiness.mapToPair( + new PairFunction, Double, String>() { + @Override + public Tuple2 call(Tuple2 topicHappiness) + throws Exception { + return new Tuple2(topicHappiness._2(), + topicHappiness._1()); + } + }); + + JavaPairDStream happiest10 = happinessTopicPairs.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD happinessAndTopics) throws Exception { + return happinessAndTopics.sortByKey(false); + } + } + ); + + // Print hash tags with the most positive sentiment values + happiest10.foreachRDD(new Function, Void>() { + @Override + public Void call(JavaPairRDD happinessTopicPairs) throws Exception { + List> topList = happinessTopicPairs.take(10); + System.out.println( + String.format("\nHappiest topics in last 10 seconds (%s total):", + happinessTopicPairs.count())); + for (Tuple2 pair : topList) { + System.out.println( + String.format("%s (%s happiness)", pair._2(), pair._1())); + } + return null; + } + }); + + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py new file mode 100644 index 0000000000000..b85517dfdd913 --- /dev/null +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Shows the most positive words in UTF8 encoded, '\n' delimited text directly received the network + every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list + (http://neuro.imm.dtu.dk/wiki/AFINN) + + Usage: network_wordjoinsentiments.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordjoinsentiments.py \ + localhost 9999` +""" + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + + +def print_happiest_words(rdd): + top_list = rdd.take(5) + print("Happiest topics in the last 5 seconds (%d total):" % rdd.count()) + for tuple in top_list: + print("%s (%d happiness)" % (tuple[1], tuple[0])) + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") + ssc = StreamingContext(sc, 5) + + # Read in the word-sentiment list and create a static RDD from it + word_sentiments_file_path = "data/streaming/AFINN-111.txt" + word_sentiments = ssc.sparkContext.textFile(word_sentiments_file_path) \ + .map(lambda line: tuple(line.split("\t"))) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + + word_counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a + b) + + # Determine the words with the highest sentiment values by joining the streaming RDD + # with the static RDD inside the transform() method and then multiplying + # the frequency of the words by its sentiment value + happiest_words = word_counts.transform(lambda rdd: word_sentiments.join(rdd)) \ + .map(lambda (word, tuple): (word, float(tuple[0]) * tuple[1])) \ + .map(lambda (word, happiness): (happiness, word)) \ + .transform(lambda rdd: rdd.sortByKey(False)) + + happiest_words.foreachRDD(print_happiest_words) + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala new file mode 100644 index 0000000000000..0328fa81ea897 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.twitter.TwitterUtils +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of + * the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) + */ +object TwitterHashTagJoinSentiments { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: TwitterHashTagJoinSentiments " + + " []") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(consumerKey, consumerSecret, accessToken, accessTokenSecret) = args.take(4) + val filters = args.takeRight(args.length - 4) + + // Set the system properties so that Twitter4j library used by Twitter stream + // can use them to generate OAuth credentials + System.setProperty("twitter4j.oauth.consumerKey", consumerKey) + System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret) + System.setProperty("twitter4j.oauth.accessToken", accessToken) + System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret) + + val sparkConf = new SparkConf().setAppName("TwitterHashTagJoinSentiments") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + val stream = TwitterUtils.createStream(ssc, None, filters) + + val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) + + // Read in the word-sentiment list and create a static RDD from it + val wordSentimentFilePath = "data/streaming/AFINN-111.txt" + val wordSentiments = ssc.sparkContext.textFile(wordSentimentFilePath).map { line => + val Array(word, happinessValue) = line.split("\t") + (word, happinessValue) + } cache() + + // Determine the hash tags with the highest sentiment values by joining the streaming RDD + // with the static RDD inside the transform() method and then multiplying + // the frequency of the hash tag by its sentiment value + val happiest60 = hashTags.map(hashTag => (hashTag.tail, 1)) + .reduceByKeyAndWindow(_ + _, Seconds(60)) + .transform{topicCount => wordSentiments.join(topicCount)} + .map{case (topic, tuple) => (topic, tuple._1 * tuple._2)} + .map{case (topic, happinessValue) => (happinessValue, topic)} + .transform(_.sortByKey(false)) + + val happiest10 = hashTags.map(hashTag => (hashTag.tail, 1)) + .reduceByKeyAndWindow(_ + _, Seconds(10)) + .transform{topicCount => wordSentiments.join(topicCount)} + .map{case (topic, tuple) => (topic, tuple._1 * tuple._2)} + .map{case (topic, happinessValue) => (happinessValue, topic)} + .transform(_.sortByKey(false)) + + // Print hash tags with the most positive sentiment values + happiest60.foreachRDD(rdd => { + val topList = rdd.take(10) + println("\nHappiest topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (happiness, tag) => println("%s (%s happiness)".format(tag, happiness))} + }) + + happiest10.foreachRDD(rdd => { + val topList = rdd.take(10) + println("\nHappiest topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (happiness, tag) => println("%s (%s happiness)".format(tag, happiness))} + }) + + ssc.start() + ssc.awaitTermination() + } +} +// scalastyle:on println From 2782818287a71925523c1320291db6cb25221e9f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 18 Dec 2015 09:49:08 -0800 Subject: [PATCH 1292/2089] [SPARK-12350][CORE] Don't log errors when requested stream is not found. If a client requests a non-existent stream, just send a failure message back, without logging any error on the server side (since it's not a server error). On the executor side, avoid error logs by translating any errors during transfer to a `ClassNotFoundException`, so that loading the class is retried on a the parent class loader. This can mask IO errors during transmission, but the most common cause is that the class is not served by the remote end. Author: Marcelo Vanzin Closes #10337 from vanzin/SPARK-12350. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 17 +++++++-------- .../spark/rpc/netty/NettyStreamManager.scala | 7 +++++-- .../spark/network/server/StreamManager.java | 1 + .../server/TransportRequestHandler.java | 7 ++++++- .../spark/repl/ExecutorClassLoader.scala | 21 +++++++++++++++++-- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index de3db6ba624f8..975ea1a1ab2a8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -363,15 +363,14 @@ private[netty] class NettyRpcEnv( } override def read(dst: ByteBuffer): Int = { - val result = if (error == null) { - Try(source.read(dst)) - } else { - Failure(error) - } - - result match { + Try(source.read(dst)) match { case Success(bytesRead) => bytesRead - case Failure(error) => throw error + case Failure(readErr) => + if (error != null) { + throw error + } else { + throw readErr + } } } @@ -397,7 +396,7 @@ private[netty] class NettyRpcEnv( } override def onFailure(streamId: String, cause: Throwable): Unit = { - logError(s"Error downloading stream $streamId.", cause) + logDebug(s"Error downloading stream $streamId.", cause) source.setError(cause) sink.close() } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index 394cde4fa076d..afcb023a99daa 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -58,8 +58,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) new File(dir, fname) } - require(file != null && file.isFile(), s"File not found: $streamId") - new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + if (file != null && file.isFile()) { + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } else { + null + } } override def addFile(file: File): String = { diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 3f0155957a140..07f161a29cfb8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -54,6 +54,7 @@ public abstract class StreamManager { * {@link #getChunk(long, int)} method. * * @param streamId id of a stream that has been previously registered with the StreamManager. + * @return A managed buffer for the stream, or null if the stream was not found. */ public ManagedBuffer openStream(String streamId) { throw new UnsupportedOperationException(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index c864d7ce16bd3..105f53883167a 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -141,7 +141,12 @@ private void processStreamRequest(final StreamRequest req) { return; } - respond(new StreamResponse(req.streamId, buf.size(), buf)); + if (buf != null) { + respond(new StreamResponse(req.streamId, buf.size(), buf)); + } else { + respond(new StreamFailure(req.streamId, String.format( + "Stream '%s' was not found.", req.streamId))); + } } private void processRpcRequest(final RpcRequest req) { diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index da8f0aa1e3360..de7b831adc736 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,7 +17,7 @@ package org.apache.spark.repl -import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.io.{FilterInputStream, ByteArrayOutputStream, InputStream, IOException} import java.net.{HttpURLConnection, URI, URL, URLEncoder} import java.nio.channels.Channels @@ -96,7 +96,24 @@ class ExecutorClassLoader( private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { val channel = env.rpcEnv.openChannel(s"$classUri/$path") - Channels.newInputStream(channel) + new FilterInputStream(Channels.newInputStream(channel)) { + + override def read(): Int = toClassNotFound(super.read()) + + override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b)) + + override def read(b: Array[Byte], offset: Int, len: Int) = + toClassNotFound(super.read(b, offset, len)) + + private def toClassNotFound(fn: => Int): Int = { + try { + fn + } catch { + case e: Exception => + throw new ClassNotFoundException(path, e) + } + } + } } private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { From ee444fe4b8c9f382524e1fa346c67ba6da8104d8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 18 Dec 2015 09:54:30 -0800 Subject: [PATCH 1293/2089] [SPARK-11619][SQL] cannot use UDTF in DataFrame.selectExpr Description of the problem from cloud-fan Actually this line: https://github.com/apache/spark/blob/branch-1.5/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala#L689 When we use `selectExpr`, we pass in `UnresolvedFunction` to `DataFrame.select` and fall in the last case. A workaround is to do special handling for UDTF like we did for `explode`(and `json_tuple` in 1.6), wrap it with `MultiAlias`. Another workaround is using `expr`, for example, `df.select(expr("explode(a)").as(Nil))`, I think `selectExpr` is no longer needed after we have the `expr` function.... Author: Dilip Biswal Closes #9981 from dilipbiswal/spark-11619. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 ++++++------ .../spark/sql/catalyst/analysis/unresolved.scala | 6 +++++- .../src/main/scala/org/apache/spark/sql/Column.scala | 12 +++++++----- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ .../org/apache/spark/sql/JsonFunctionsSuite.scala | 4 ++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 7 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 64dd83a915711..c396546b4c005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -149,12 +149,12 @@ class Analyzer( exprs.zipWithIndex.map { case (expr, i) => expr transform { - case u @ UnresolvedAlias(child) => child match { + case u @ UnresolvedAlias(child, optionalAliasName) => child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case other => Alias(other, s"_c$i")() + case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -287,7 +287,7 @@ class Analyzer( } } val newGroupByExprs = groupByExprs.map { - case UnresolvedAlias(e) => e + case UnresolvedAlias(e, _) => e case e => e } Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) @@ -352,19 +352,19 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => val newChildren = expandStarExpressions(args, child) UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => val newChildren = expandStarExpressions(args, child) Alias(child = f.copy(children = newChildren), name)() :: Nil - case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => + case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => + case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child, resolver) case o => o :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 4f89b462a6ce3..64cad6ee787d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -284,8 +284,12 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) /** * Holds the expression that has yet to be aliased. + * + * @param child The computation that is needs to be resolved during analysis. + * @param aliasName The name if specified to be asoosicated with the result of computing [[child]] + * */ -case class UnresolvedAlias(child: Expression) +case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 297ef2299cb36..5026c0d6d12b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression - import scala.language.implicitConversions -import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.sql.functions.lit +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.SqlParser._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ - private[sql] object Column { def apply(colName: String): Column = new Column(colName) @@ -130,8 +129,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { // Leave an unaliased generator with an empty list of names since the analyzer will generate // the correct defaults after the nested expression's type has been resolved. case explode: Explode => MultiAlias(explode, Nil) + case jt: JsonTuple => MultiAlias(jt, Nil) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString)) + case expr: Expression => Alias(expr, expr.prettyString)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 79b4244ac0cd1..d201d65238523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -450,7 +450,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def groupBy(cols: Column*): GroupedDataset[Row, T] = { - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias) + val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = sqlContext.executePlan(withKey) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0644bdaaa35ce..4c3e12af7203d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -176,6 +176,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select("key").collect().toSeq) } + test("selectExpr with udtf") { + val df = Seq((Map("1" -> 1), 1)).toDF("a", "b") + checkAnswer( + df.selectExpr("explode(a)"), + Row("1", 1) :: Nil) + } + test("filterExpr") { val res = testData.collect().filter(_.getInt(0) > 90).toSeq checkAnswer(testData.filter("key > 90"), res) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 1f384edf321b0..1391c9d57ff7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -73,6 +73,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), expected) + + checkAnswer( + df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), + expected) } test("json_tuple filter and group") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index da41b659e3fce..0e89928cb636d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1107,7 +1107,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) + select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => From 4af647c77ded6a0d3087ceafb2e30e01d97e7a06 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 18 Dec 2015 10:09:17 -0800 Subject: [PATCH 1294/2089] [SPARK-12054] [SQL] Consider nullability of expression in codegen This could simplify the generated code for expressions that is not nullable. This PR fix lots of bugs about nullability. Author: Davies Liu Closes #10333 from davies/skip_nullable. --- .../catalyst/expressions/BoundAttribute.scala | 17 +++- .../spark/sql/catalyst/expressions/Cast.scala | 28 +++--- .../sql/catalyst/expressions/Expression.scala | 95 ++++++++++++------- .../aggregate/CentralMomentAgg.scala | 2 +- .../catalyst/expressions/aggregate/Corr.scala | 3 +- .../expressions/aggregate/Count.scala | 19 +++- .../catalyst/expressions/aggregate/Sum.scala | 24 +++-- .../expressions/codegen/CodegenFallback.scala | 27 ++++-- .../codegen/GenerateMutableProjection.scala | 65 ++++++++----- .../codegen/GenerateUnsafeProjection.scala | 21 ++-- .../expressions/complexTypeExtractors.scala | 19 ++-- .../expressions/datetimeExpressions.scala | 4 + .../expressions/decimalExpressions.scala | 1 + .../expressions/jsonExpressions.scala | 15 +-- .../expressions/mathExpressions.scala | 5 + .../spark/sql/catalyst/expressions/misc.scala | 1 + .../expressions/stringExpressions.scala | 1 + .../expressions/windowExpressions.scala | 4 +- .../plans/logical/basicOperators.scala | 9 +- .../sql/catalyst/expressions/CastSuite.scala | 10 +- .../expressions/ComplexTypeSuite.scala | 1 - .../org/apache/spark/sql/DataFrame.scala | 21 ++-- .../apache/spark/sql/execution/Window.scala | 9 +- .../apache/spark/sql/execution/commands.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../sql/execution/joins/HashOuterJoin.scala | 80 +--------------- .../execution/joins/SortMergeOuterJoin.scala | 2 +- 27 files changed, 261 insertions(+), 226 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ff1f28ddbbf35..7293d5d4472af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + if (nullable) { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + ev.isNull = "false" + s""" + $javaType ${ev.value} = $value; + """ + } } } @@ -92,7 +99,7 @@ object BindReferences extends Logging { sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } } else { - BoundReference(ordinal, a.dataType, a.nullable) + BoundReference(ordinal, a.dataType, input(ordinal).nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index cb60d5958d535..b18f49f3203fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -87,18 +87,22 @@ object Cast { private def resolvableNullability(from: Boolean, to: Boolean) = !from || to private def forceNullable(from: DataType, to: DataType) = (from, to) match { - case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true - case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null + case (NullType, _) => true + case (_, _) if from == to => false + + case (StringType, BinaryType) => false + case (StringType, _) => true + case (_, StringType) => false + + case (FloatType | DoubleType, TimestampType) => true + case (TimestampType, DateType) => false + case (_, DateType) => true + case (DateType, TimestampType) => false + case (DateType, _) => true + case (_, CalendarIntervalType) => true + + case (_, _: DecimalType) => true // overflow + case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6d807c9ecf302..6a9c12127d367 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -340,14 +340,21 @@ abstract class UnaryExpression extends Expression { ev: GeneratedExpressionCode, f: String => String): String = { val eval = child.gen(ctx) - val resultCode = f(eval.value) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $resultCode - } - """ + if (nullable) { + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${eval.isNull}) { + ${f(eval.value)} + } + """ + } else { + ev.isNull = "false" + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${f(eval.value)} + """ + } } } @@ -424,19 +431,30 @@ abstract class BinaryExpression extends Expression { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.value, eval2.value) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; + if (nullable) { + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } } - } - """ + """ + + } else { + ev.isNull = "false" + s""" + ${eval1.code} + ${eval2.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } @@ -548,20 +566,31 @@ abstract class TernaryExpression extends Expression { f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - s""" - ${evals(0).code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode + if (nullable) { + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } } } - } - """ + """ + } else { + ev.isNull = "false" + s""" + ${evals(0).code} + ${evals(1).code} + ${evals(2).code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index d07d4c338cdfe..30f602227b17d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 00d7436b710d2..d25f3335ffd93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +41,7 @@ case class Corr( override def children: Seq[Expression] = Seq(left, right) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 441f52ab5ca58..663c69e799fbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -31,7 +31,7 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) - private lazy val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -39,15 +39,24 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) - ) + override lazy val updateExpressions = { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } + } override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override lazy val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index cfb042e0aa782..08a67ea3df51d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - // TODO: Remove this line once we remove the NullType from inputTypes. - case NullType => IntegerType case _ => child.dataType } @@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate { /* sum = */ Literal.create(null, sumDataType) ) - override lazy val updateExpressions: Seq[Expression] = Seq( - /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) - ) + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + } else { + Seq( + /* sum = */ + Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + ) + } + } override lazy val mergeExpressions: Seq[Expression] = { - val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ - Coalesce(Seq(add, sum.left)) + Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) ) } - override lazy val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 26fb143d1e45c..80c5e41baa927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -32,14 +32,23 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this.toCommentSafeString} */ - java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + if (nullable) { + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } else { + ev.isNull = "false" + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 40189f0877764..a6ec242589fa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,38 +44,55 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; - """ + if (e.nullable) { + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } else { + val value = s"value_$i" + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$value = ${evaluationCode.value}; + """ + } } val updates = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ + if (e.nullable) { + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + s""" + if (this.isNull_$i) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } else { + s""" + if (this.isNull_$i) { + mutableRow.setNullAt($i); + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } } else { s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; """ } + } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 68005afb21d2e..c1defe12b0b91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -135,14 +135,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$rowWriter.write($index, ${input.value});" } - s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { + if (input.isNull == "false") { + s""" + ${input.code} ${writeField.trim} - } - """ + """ + } else { + s""" + ${input.code} + if (${input.isNull}) { + ${setNull.trim} + } else { + ${writeField.trim} + } + """ + } } s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 58f6a7ec8a5f5..c5ed173eeb9dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -115,13 +115,19 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { + if (nullable) { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + } else { + s""" ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ + """ + } }) } } @@ -139,7 +145,6 @@ case class GetArrayStructFields( containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable || containsNull || field.nullable override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03c39f8404e78..311540e33576e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -340,6 +340,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { Seq(TypeCollection(StringType, DateType, TimestampType), StringType) override def dataType: DataType = LongType + override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] @@ -455,6 +456,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) @@ -561,6 +563,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def nullSafeEval(start: Any, dayOfW: Any): Any = { val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) @@ -832,6 +835,7 @@ case class TruncDate(date: Expression, format: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def prettyName: String = "trunc" private lazy val truncLevel: Int = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 78f6631e46474..c54bcdd774021 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -47,6 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" protected override def nullSafeEval(input: Any): Any = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 4991b9cb54e5e..72b323587c63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{StringWriter, ByteArrayOutputStream} +import java.io.{ByteArrayOutputStream, StringWriter} + +import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import scala.util.parsing.combinator.RegexParsers - private[this] sealed trait PathInstruction private[this] object PathInstruction { private[expressions] case object Subscript extends PathInstruction @@ -108,15 +109,17 @@ private[this] object SharedFactory { case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - import SharedFactory._ + import com.fasterxml.jackson.core.JsonToken._ + import PathInstruction._ + import SharedFactory._ import WriteStyle._ - import com.fasterxml.jackson.core.JsonToken._ override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def prettyName: String = "get_json_object" @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 28f616fbb9ca5..9c1a3294def24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -75,6 +75,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) abstract class UnaryLogExpression(f: Double => Double, name: String) extends UnaryMathExpression(f, name) { + override def nullable: Boolean = true + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity protected val yAsymptote: Double = 0.0 @@ -194,6 +196,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { NumberConverter.convert( @@ -621,6 +624,8 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + override def nullable: Boolean = true + protected override def nullSafeEval(input1: Any, input2: Any): Any = { val dLeft = input1.asInstanceOf[Double] val dRight = input2.asInstanceOf[Double] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0f6d02f2e00c2..5baab4f7e8c51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -57,6 +57,7 @@ case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8770c4b76c2e5..50c8b9d59847e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -924,6 +924,7 @@ case class FormatNumber(x: Expression, d: Expression) override def left: Expression = x override def right: Expression = d override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) // Associated with the pattern, for the last d value, and we will update the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 06252ac4fc616..91f169e7eac4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -329,7 +329,7 @@ abstract class OffsetWindowFunction */ override def foldable: Boolean = input.foldable && (default == null || default.foldable) - override def nullable: Boolean = input.nullable && (default == null || default.nullable) + override def nullable: Boolean = default == null || default.nullable override lazy val frame = { // This will be triggered by the Analyzer. @@ -381,7 +381,7 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF self: Product => override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) override def dataType: DataType = IntegerType - override def nullable: Boolean = false + override def nullable: Boolean = true override def supportsPartial: Boolean = false override lazy val mergeExpressions = throw new UnsupportedOperationException("Window Functions do not support merging.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5665fd7e5f419..ec42b763f18ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -293,7 +293,14 @@ private[sql] object Expand { Literal.create(bitmask, IntegerType) }) } - Expand(projections, child.output :+ gid, child) + val output = child.output.map { attr => + if (groupByExprs.exists(_.semanticEquals(attr))) { + attr.withNullability(true) + } else { + attr + } + } + Expand(projections, output :+ gid, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a98e16c253214..c99a4ac9645ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from string") { assert(cast("abcdef", StringType).nullable === false) assert(cast("abcdef", BinaryType).nullable === false) - assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === true) assert(cast("abcdef", TimestampType).nullable === true) assert(cast("abcdef", LongType).nullable === true) assert(cast("abcdef", IntegerType).nullable === true) @@ -547,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Seq(null, true, false)) } @@ -606,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { @@ -713,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("a", BooleanType, nullable = true), StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, InternalRow(null, true, false)) } @@ -754,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructType(Seq( StructField("l", LongType, nullable = true))))))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Row( Seq(123, null, null), Map("a" -> null, "b" -> true, "c" -> false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 9f1b19253e7c2..9c1688b261aa8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d74131231499d..965eaa9efec41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -25,22 +25,20 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -455,7 +453,8 @@ class DataFrame private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + .analyzed.asInstanceOf[Join] val condition = usingColumns.map { col => catalyst.expressions.EqualTo( @@ -473,15 +472,15 @@ class DataFrame private[sql]( usingColumns.map(col => withPlan(joined.right).resolve(col)) case FullOuter => usingColumns.map { col => - val leftCol = withPlan(joined.left).resolve(col) - val rightCol = withPlan(joined.right).resolve(col) + val leftCol = withPlan(joined.left).resolve(col).toAttribute.withNullability(true) + val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) Alias(Coalesce(Seq(leftCol, rightCol)), col)() } } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId - val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil) - val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId)) + val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references)) + val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_)) withPlan { Project( resultCols, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 9852b6e7beeba..c941d673c7248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -440,16 +440,17 @@ private[execution] final class OffsetWindowFunctionFrame( /** Create the projection. */ private[this] val projection = { // Collect the expressions and bind them. - val numInputAttributes = inputSchema.size + val inputAttrs = inputSchema.map(_.withNullability(true)) + val numInputAttributes = inputAttrs.size val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { case e: OffsetWindowFunction => - val input = BindReferences.bindReference(e.input, inputSchema) + val input = BindReferences.bindReference(e.input, inputAttrs) if (e.default == null || e.default.foldable && e.default.eval() == null) { // Without default value. input } else { // With default value. - val default = BindReferences.bindReference(e.default, inputSchema).transform { + val default = BindReferences.bindReference(e.default, inputAttrs).transform { // Shift the input reference to its default version. case BoundReference(o, dataType, nullable) => BoundReference(o + numInputAttributes, dataType, nullable) @@ -457,7 +458,7 @@ private[execution] final class OffsetWindowFunctionFrame( org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) } case e => - BindReferences.bindReference(e, inputSchema) + BindReferences.bindReference(e, inputAttrs) } // Create the projection. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 24a79f289aa81..e2dc13d66c61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -232,7 +232,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = - Seq(AttributeReference("plan", StringType, nullable = false)()), + Seq(AttributeReference("plan", StringType, nullable = true)()), extended: Boolean = false) extends RunnableCommand { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e7deeff13dc4d..e759c011e75d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -42,7 +42,7 @@ case class DescribeCommand( new MetadataBuilder().putString("comment", "name of the column").build())(), AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = false, + AttributeReference("comment", StringType, nullable = true, new MetadataBuilder().putString("comment", "comment of the column").build())() ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index ed626fef56af7..c6e5868187518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -75,7 +75,7 @@ trait HashOuterJoin { UnsafeProjection.create(streamedKeys, streamedPlan.output) protected[this] def resultProjection: InternalRow => InternalRow = - UnsafeProjection.create(self.schema) + UnsafeProjection.create(output, output) @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -151,82 +151,4 @@ trait HashOuterJoin { } ret.iterator } - - protected[this] def fullOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - rightIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[InternalRow] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if boundCondition(joinedRow.withRight(r)) => - numOutputRows += 1 - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - resultProjection(joinedRow) - - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if !rightMatchedSet.contains(idx) => - numOutputRows += 1 - resultProjection(joinedRow(leftNullRow, r)) - } - } else { - leftIter.iterator.map[InternalRow] { l => - numOutputRows += 1 - resultProjection(joinedRow(l, rightNullRow)) - } ++ rightIter.iterator.map[InternalRow] { r => - numOutputRows += 1 - resultProjection(joinedRow(leftNullRow, r)) - } - } - } - - // This is only used by FullOuter - protected[this] def buildHashTable( - iter: Iterator[InternalRow], - numIterRows: LongSQLMetric, - keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = { - val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]() - while (iter.hasNext) { - val currentRow = iter.next() - numIterRows += 1 - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[InternalRow]() - hashTable.put(rowKey.copy(), existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index efaa69c1d3227..7ce38ebdb3413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -114,7 +114,7 @@ case class SortMergeOuterJoin( (r: InternalRow) => true } } - val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema) + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) joinType match { case LeftOuter => From 41ee7c57abd9f52065fd7ffb71a8af229603371d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 18 Dec 2015 10:52:14 -0800 Subject: [PATCH 1295/2089] [SPARK-12218][SQL] Invalid splitting of nested AND expressions in Data Source filter API JIRA: https://issues.apache.org/jira/browse/SPARK-12218 When creating filters for Parquet/ORC, we should not push nested AND expressions partially. Author: Yin Huai Closes #10362 from yhuai/SPARK-12218. --- .../datasources/parquet/ParquetFilters.scala | 12 +++++++++- .../parquet/ParquetFilterSuite.scala | 19 ++++++++++++++++ .../spark/sql/hive/orc/OrcFilters.scala | 22 +++++++++---------- .../hive/orc/OrcHadoopFsRelationSuite.scala | 20 +++++++++++++++++ 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 07714329370a5..883013bf1bfc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -257,7 +257,17 @@ private[sql] object ParquetFilters { makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) case sources.And(lhs, rhs) => - (createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and) + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + lhsFilter <- createFilter(schema, lhs) + rhsFilter <- createFilter(schema, rhs) + } yield FilterApi.and(lhsFilter, rhsFilter) case sources.Or(lhs, rhs) => for { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 6178e37d2a585..045425f282ad0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -362,4 +362,23 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-12218: 'Not' is included in Parquet filter pushdown") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + + checkAnswer( + sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 27193f54d3a91..ebfb1759b8b96 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -74,22 +74,20 @@ private[orc] object OrcFilters extends Logging { expression match { case And(left, right) => - val tryLeft = buildSearchArgument(left, newBuilder) - val tryRight = buildSearchArgument(right, newBuilder) - - val conjunction = for { - _ <- tryLeft - _ <- tryRight + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) lhs <- buildSearchArgument(left, builder.startAnd()) rhs <- buildSearchArgument(right, lhs) } yield rhs.end() - // For filter `left AND right`, we can still push down `left` even if `right` is not - // convertible, and vice versa. - conjunction - .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) - .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) - case Or(left, right) => for { _ <- buildSearchArgument(left, newBuilder) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 92043d66c914f..e8a61123d18b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ @@ -60,4 +61,23 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } } + + test("SPARK-12218: 'Not' is included in ORC filter pushdown") { + import testImplicits._ + + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path) + + checkAnswer( + sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + + checkAnswer( + sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + } + } + } } From 6eba655259d2bcea27d0147b37d5d1e476e85422 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 18 Dec 2015 14:05:06 -0800 Subject: [PATCH 1296/2089] [SPARK-12404][SQL] Ensure objects passed to StaticInvoke is Serializable Now `StaticInvoke` receives `Any` as a object and `StaticInvoke` can be serialized but sometimes the object passed is not serializable. For example, following code raises Exception because `RowEncoder#extractorsFor` invoked indirectly makes `StaticInvoke`. ``` case class TimestampContainer(timestamp: java.sql.Timestamp) val rdd = sc.parallelize(1 to 2).map(_ => TimestampContainer(System.currentTimeMillis)) val df = rdd.toDF val ds = df.as[TimestampContainer] val rdd2 = ds.rdd <----------------- invokes extractorsFor indirectory ``` I'll add test cases. Author: Kousuke Saruta Author: Michael Armbrust Closes #10357 from sarutak/SPARK-12404. --- .../sql/catalyst/JavaTypeInference.scala | 12 ++--- .../spark/sql/catalyst/ScalaReflection.scala | 16 +++--- .../sql/catalyst/encoders/RowEncoder.scala | 14 ++--- .../sql/catalyst/expressions/objects.scala | 8 ++- .../apache/spark/sql/JavaDatasetSuite.java | 52 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++++ 6 files changed, 88 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index c8ee87e8819f2..f566d1b3caebf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -194,7 +194,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaDate", getPath :: Nil, @@ -202,7 +202,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", getPath :: Nil, @@ -276,7 +276,7 @@ object JavaTypeInference { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", keyData :: valueData :: Nil) @@ -341,21 +341,21 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ecff8605706de..c1b1d5cd2dee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -223,7 +223,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", getPath :: Nil, @@ -231,7 +231,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", getPath :: Nil, @@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) @@ -548,28 +548,28 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d34ec9408ae1b..63bdf05ca7c28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -61,21 +61,21 @@ object RowEncoder { case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case _: DecimalType => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) @@ -172,14 +172,14 @@ object RowEncoder { case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", input :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", input :: Nil) @@ -197,7 +197,7 @@ object RowEncoder { "array", ObjectType(classOf[Array[_]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -210,7 +210,7 @@ object RowEncoder { val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 10ec75eca37f2..492cc9bf4146c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -42,16 +42,14 @@ import org.apache.spark.sql.types._ * of calling the function. */ case class StaticInvoke( - staticObject: Any, + staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, propagateNull: Boolean = true) extends Expression { - val objectName = staticObject match { - case c: Class[_] => c.getName - case other => other.getClass.getName.stripSuffix("$") - } + val objectName = staticObject.getName.stripSuffix("$") + override def nullable: Boolean = true override def children: Seq[Expression] = arguments diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 383a2d0badb53..0dbaeb81c7ec9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -39,6 +39,7 @@ import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; @@ -608,6 +609,44 @@ public int hashCode() { } } + public class SimpleJavaBean2 implements Serializable { + private Timestamp a; + private Date b; + private java.math.BigDecimal c; + + public Timestamp getA() { return a; } + + public void setA(Timestamp a) { this.a = a; } + + public Date getB() { return b; } + + public void setB(Date b) { this.b = b; } + + public java.math.BigDecimal getC() { return c; } + + public void setC(java.math.BigDecimal c) { this.c = c; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (!a.equals(that.a)) return false; + if (!b.equals(that.b)) return false; + return c.equals(that.c); + } + + @Override + public int hashCode() { + int result = a.hashCode(); + result = 31 * result + b.hashCode(); + result = 31 * result + c.hashCode(); + return result; + } + } + public class NestedJavaBean implements Serializable { private SimpleJavaBean a; @@ -689,4 +728,17 @@ public void testJavaBeanEncoder() { .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } + + @Test + public void testJavaBeanEncoder2() { + // This is a regression test of SPARK-12404 + OuterScopes.addOuterScope(this); + SimpleJavaBean2 obj = new SimpleJavaBean2(); + obj.setA(new Timestamp(0)); + obj.setB(new Date(0)); + obj.setC(java.math.BigDecimal.valueOf(1)); + Dataset ds = + context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + ds.collect(); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f1b6b98dc160c..de012a9a56454 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.sql.{Date, Timestamp} import scala.language.postfixOps @@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } + + test("SPARK-12404: Datatype Helper Serializablity") { + val ds = sparkContext.parallelize(( + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() + + ds.collect() + } + test("collect, first, and take should use encoders for serialization") { val item = NonSerializableCaseClass("abcd") val ds = Seq(item).toDS() From 2377b707f25449f4557bf048bb384c743d9008e5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 18 Dec 2015 15:24:41 -0800 Subject: [PATCH 1297/2089] [SPARK-11985][STREAMING][KINESIS][DOCS] Update Kinesis docs - Provide example on `message handler` - Provide bit on KPL record de-aggregation - Fix typos Author: Burak Yavuz Closes #9970 from brkyvz/kinesis-docs. --- docs/streaming-kinesis-integration.md | 54 ++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 238a911a9199f..07194b0a6b758 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -23,7 +23,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m **Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.** -2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream as follows: +2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream of byte array as follows:
    @@ -36,7 +36,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the [Running the Example](#running-the-example) subsection for instructions on how to run the example.
    @@ -49,7 +49,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
    @@ -60,18 +60,47 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
    - - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + You may also provide a "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key. This is currently only supported in Scala and Java. - - `[Kineiss app name]`: The application name that will be used to checkpoint the Kinesis +
    +
    + + import org.apache.spark.streaming.Duration + import org.apache.spark.streaming.kinesis._ + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + val kinesisStream = KinesisUtils.createStream[T]( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2, + [message handler]) + +
    +
    + + import org.apache.spark.streaming.Duration; + import org.apache.spark.streaming.kinesis.*; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + + JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2, + [message handler], [class T]); + +
    +
    + + - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + + - `[Kinesis app name]`: The application name that will be used to checkpoint the Kinesis sequence numbers in DynamoDB table. - The application name must be unique for a given account and region. - If the table exists but has incorrect checkpoint information (for a different stream, or - old expired sequenced numbers), then there may be temporary errors. + old expired sequenced numbers), then there may be temporary errors. - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. @@ -83,6 +112,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + - `[message handler]`: A function that takes a Kinesis `Record` and outputs generic `T`. + In other versions of the API, you can also specify the AWS access key and secret key directly. 3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). @@ -99,7 +130,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m Spark Streaming Kinesis Architecture

    @@ -165,11 +196,16 @@ To run the example, This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. +#### Record De-aggregation + +When data is generated using the [Kinesis Producer Library (KPL)](http://docs.aws.amazon.com/kinesis/latest/dev/developing-producers-with-kpl.html), messages may be aggregated for cost savings. Spark Streaming will automatically +de-aggregate records during consumption. + #### Kinesis Checkpointing - Each Kinesis input DStream periodically stores the current position of the stream in the backing DynamoDB table. This allows the system to recover from failures and continue processing where the DStream left off. - Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. - If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable. -- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). +- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). - InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. From 60da0e11f6724d86df16795a7a1166879215d547 Mon Sep 17 00:00:00 2001 From: Grace Date: Fri, 18 Dec 2015 16:04:42 -0800 Subject: [PATCH 1298/2089] [SPARK-9552] Return "false" while nothing to kill in killExecutors In discussion (SPARK-9552), we proposed a force kill in `killExecutors`. But if there is nothing to kill, it will return back with true (acknowledgement). And then, it causes the certain executor(s) (which is not eligible to kill) adding to pendingToRemove list for further actions. In this patch, we'd like to change the return semantics. If there is nothing to kill, we will return "false". and therefore all those non-eligible executors won't be added to the pendingToRemove list. vanzin andrewor14 As the follow up of PR#7888, please let me know your comments. Author: Grace Author: Jie Huang Author: Andrew Or Closes #9796 from GraceH/emptyPendingToRemove. --- .../spark/ExecutorAllocationManager.scala | 4 +-- .../CoarseGrainedSchedulerBackend.scala | 8 +++-- .../StandaloneDynamicAllocationSuite.scala | 29 +++++++++++-------- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6176e258989db..4926cafaed1b0 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -423,7 +423,8 @@ private[spark] class ExecutorAllocationManager( executorsPendingToRemove.add(executorId) true } else { - logWarning(s"Unable to reach the cluster manager to kill executor $executorId!") + logWarning(s"Unable to reach the cluster manager to kill executor $executorId," + + s"or no executor eligible to kill!") false } } @@ -524,7 +525,6 @@ private[spark] class ExecutorAllocationManager( private def onExecutorBusy(executorId: String): Unit = synchronized { logDebug(s"Clearing idle timer for $executorId because it is now running a task") removeTimes.remove(executorId) - executorsPendingToRemove.remove(executorId) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7efe16749e59d..2279e8cad7bcf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -471,7 +471,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. - * @return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. If list to kill is empty, it will return + * false. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { killExecutors(executorIds, replace = false, force = false) @@ -487,7 +488,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones * @param force whether to force kill busy executors - * @return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. If list to kill is empty, it will return + * false. */ final def killExecutors( executorIds: Seq[String], @@ -516,7 +518,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp numPendingExecutors += knownExecutors.size } - doKillExecutors(executorsToKill) + !executorsToKill.isEmpty && doKillExecutors(executorsToKill) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 2fa795f846667..314517d296049 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -365,7 +365,7 @@ class StandaloneDynamicAllocationSuite val executors = getExecutorIds(sc) assert(executors.size === 2) assert(sc.killExecutor(executors.head)) - assert(sc.killExecutor(executors.head)) + assert(!sc.killExecutor(executors.head)) val apps = getApplications() assert(apps.head.executors.size === 1) // The limit should not be lowered twice @@ -386,23 +386,28 @@ class StandaloneDynamicAllocationSuite // the driver refuses to kill executors it does not know about syncExecutors(sc) val executors = getExecutorIds(sc) + val executorIdsBefore = executors.toSet assert(executors.size === 2) - // kill executor 1, and replace it + // kill and replace an executor assert(sc.killAndReplaceExecutor(executors.head)) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.head.executors.size === 2) + val executorIdsAfter = getExecutorIds(sc).toSet + // make sure the executor was killed and replaced + assert(executorIdsBefore != executorIdsAfter) } - var apps = getApplications() - // kill executor 1 - assert(sc.killExecutor(executors.head)) - apps = getApplications() - assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === 2) - // kill executor 2 - assert(sc.killExecutor(executors(1))) - apps = getApplications() + // kill old executor (which is killedAndReplaced) should fail + assert(!sc.killExecutor(executors.head)) + + // refresh executors list + val newExecutors = getExecutorIds(sc) + syncExecutors(sc) + + // kill newly created executor and do not replace it + assert(sc.killExecutor(newExecutors(1))) + val apps = getApplications() assert(apps.head.executors.size === 1) assert(apps.head.getExecutorLimit === 1) } @@ -430,7 +435,7 @@ class StandaloneDynamicAllocationSuite val executorIdToTaskCount = taskScheduler invokePrivate getMap() executorIdToTaskCount(executors.head) = 1 // kill the busy executor without force; this should fail - assert(killExecutor(sc, executors.head, force = false)) + assert(!killExecutor(sc, executors.head, force = false)) apps = getApplications() assert(apps.head.executors.size === 2) From 0514e8d4b69615ba8918649e7e3c46b5713b6540 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 18 Dec 2015 16:05:18 -0800 Subject: [PATCH 1299/2089] [SPARK-12411][CORE] Decrease executor heartbeat timeout to match heartbeat interval Previously, the rpc timeout was the default network timeout, which is the same value the driver uses to determine dead executors. This means if there is a network issue, the executor is determined dead after one heartbeat attempt. There is a separate config for the heartbeat interval which is a better value to use for the heartbeat RPC. With this change, the executor will make multiple heartbeat attempts even with RPC issues. Author: Nong Li Closes #10365 from nongli/spark-12411. --- core/src/main/scala/org/apache/spark/executor/Executor.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 552b644d13aaf..9b14184364246 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} @@ -445,7 +446,8 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) if (response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") env.blockManager.reregister() From 007a32f90af1065bfa3ca4cdb194c40c06e87abf Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 18 Dec 2015 16:06:37 -0800 Subject: [PATCH 1300/2089] [SPARK-11097][CORE] Add channelActive callback to RpcHandler to monitor the new connections Added `channelActive` to `RpcHandler` so that `NettyRpcHandler` doesn't need `clients` any more. Author: Shixiong Zhu Closes #10301 from zsxwing/network-events. --- .../mesos/MesosExternalShuffleService.scala | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 17 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 148 ++++++++++-------- .../rpc/netty/NettyRpcHandlerSuite.scala | 6 +- .../client/TransportResponseHandler.java | 6 +- .../spark/network/sasl/SaslRpcHandler.java | 9 +- .../spark/network/server/MessageHandler.java | 7 +- .../spark/network/server/RpcHandler.java | 9 +- .../server/TransportChannelHandler.java | 21 ++- .../server/TransportRequestHandler.java | 9 +- .../spark/network/sasl/SparkSaslSuite.java | 6 +- 11 files changed, 148 insertions(+), 92 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 8ffcfc0878a42..4172d924c802d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -65,7 +65,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo /** * On connection termination, clean up shuffle files written by the associated application. */ - override def connectionTerminated(client: TransportClient): Unit = { + override def channelInactive(client: TransportClient): Unit = { val address = client.getSocketAddress if (connectedApps.contains(address)) { val appId = connectedApps(address) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 975ea1a1ab2a8..090a1b9f6e366 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -548,10 +548,6 @@ private[netty] class NettyRpcHandler( nettyEnv: NettyRpcEnv, streamManager: StreamManager) extends RpcHandler with Logging { - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. - private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() - // A variable to track the remote RpcEnv addresses of all clients private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() @@ -574,9 +570,6 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { - dispatcher.postToAll(RemoteProcessConnected(clientAddr)) - } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. @@ -613,10 +606,16 @@ private[netty] class NettyRpcHandler( } } - override def connectionTerminated(client: TransportClient): Unit = { + override def channelActive(client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) + } + + override def channelInactive(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - clients.remove(client) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 49e3e0191c38a..7b3a17c17233a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -484,10 +484,16 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("network events") { + /** + * Setup an [[RpcEndpoint]] to collect all network events. + * @return the [[RpcEndpointRef]] and an `Seq` that contains network events. + */ + private def setupNetworkEndpoint( + _env: RpcEnv, + name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => @@ -507,83 +513,97 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) + (ref, events) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - // anotherEnv is connected in client mode, so the remote address may be unknown depending on - // the implementation. Account for that when doing checks. - if (remoteAddress != null) { - assert(events === List(("onConnected", remoteAddress))) - } else { - assert(events.size === 1) - assert(events(0)._1 === "onConnected") + test("network events in sever RpcEnv when another RpcEnv is in server mode") { + val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") + val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") + try { + val serverRefInServer2 = + serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name) + // Send a message to set up the connection + serverRefInServer2.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) } - } - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - // Account for anotherEnv not having an address due to running in client mode. - if (remoteAddress != null) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) - } else { - val eventNames = events.map(_._1) - assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || - eventNames === List("onConnected", "onDisconnected")) + serverEnv2.shutdown() + serverEnv2.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) + assert(events.contains(("onDisconnected", serverEnv2.address))) } + } finally { + serverEnv1.shutdown() + serverEnv2.shutdown() + serverEnv1.awaitTermination() + serverEnv2.awaitTermination() } } - test("network events between non-client-mode RpcEnvs") { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + test("network events in sever RpcEnv when another RpcEnv is in client mode") { + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + try { + val serverRefInClient = + clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - override def receive: PartialFunction[Any, Unit] = { - case "hello" => - case m => events += "receive" -> m + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.map(_._1).contains("onConnected")) } - override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress - } + clientEnv.shutdown() + clientEnv.awaitTermination() - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.map(_._1).contains("onConnected")) + assert(events.map(_._1).contains("onDisconnected")) } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress - } + test("network events in client RpcEnv when another RpcEnv is in server mode") { + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") + val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") + try { + val serverRefInClient = + clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - }) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events-non-client") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - } + serverEnv.shutdown() + serverEnv.awaitTermination() - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - assert(events.contains(("onDisconnected", remoteAddress))) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + assert(events.contains(("onDisconnected", serverEnv.address))) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ebd6f700710bd..d4aebe9fd915e 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -43,7 +43,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } @@ -55,10 +55,10 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.connectionTerminated(client) + nettyRpcHandler.channelInactive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 23a8dba593442..f0e2004d2de2e 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -116,7 +116,11 @@ private void failOutstandingRequests(Throwable cause) { } @Override - public void channelUnregistered() { + public void channelActive() { + } + + @Override + public void channelInactive() { if (numOutstandingRequests() > 0) { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c215bd9d15045..c41f5b6873f6c 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -135,9 +135,14 @@ public StreamManager getStreamManager() { } @Override - public void connectionTerminated(TransportClient client) { + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { try { - delegate.connectionTerminated(client); + delegate.channelInactive(client); } finally { if (saslServer != null) { saslServer.dispose(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index 3843406b27403..4a1f28e9ffb31 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -28,9 +28,12 @@ public abstract class MessageHandler { /** Handles the receipt of a single message. */ public abstract void handle(T message) throws Exception; + /** Invoked when the channel this MessageHandler is on is active. */ + public abstract void channelActive(); + /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); - /** Invoked when the channel this MessageHandler is on has been unregistered. */ - public abstract void channelUnregistered(); + /** Invoked when the channel this MessageHandler is on is inactive. */ + public abstract void channelInactive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index ee1c683699478..c6ed0f459ad71 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -69,10 +69,15 @@ public void receive(TransportClient client, ByteBuffer message) { } /** - * Invoked when the connection associated with the given client has been invalidated. + * Invoked when the channel associated with the given client is active. + */ + public void channelActive(TransportClient client) { } + + /** + * Invoked when the channel associated with the given client is inactive. * No further requests will come from this client. */ - public void connectionTerminated(TransportClient client) { } + public void channelInactive(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 09435bcbab35e..18a9b7887ec28 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -84,14 +84,29 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + public void channelActive(ChannelHandlerContext ctx) throws Exception { try { - requestHandler.channelUnregistered(); + requestHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while registering channel", e); + } + try { + responseHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while registering channel", e); + } + super.channelRegistered(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelInactive(); } catch (RuntimeException e) { logger.error("Exception from request handler while unregistering channel", e); } try { - responseHandler.channelUnregistered(); + responseHandler.channelInactive(); } catch (RuntimeException e) { logger.error("Exception from response handler while unregistering channel", e); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 105f53883167a..296ced3db093f 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -83,7 +83,12 @@ public void exceptionCaught(Throwable cause) { } @Override - public void channelUnregistered() { + public void channelActive() { + rpcHandler.channelActive(reverseClient); + } + + @Override + public void channelInactive() { if (streamManager != null) { try { streamManager.connectionTerminated(channel); @@ -91,7 +96,7 @@ public void channelUnregistered() { logger.error("StreamManager connectionTerminated() callback failed.", e); } } - rpcHandler.connectionTerminated(reverseClient); + rpcHandler.channelInactive(reverseClient); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 751516b9d82a1..045773317a78b 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -160,7 +160,7 @@ public Void answer(InvocationOnMock invocation) { long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); while (deadline > System.nanoTime()) { try { - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class)); error = null; break; } catch (Throwable t) { @@ -362,8 +362,8 @@ public void testRpcHandlerDelegate() throws Exception { saslHandler.getStreamManager(); verify(handler).getStreamManager(); - saslHandler.connectionTerminated(null); - verify(handler).connectionTerminated(any(TransportClient.class)); + saslHandler.channelInactive(null); + verify(handler).channelInactive(any(TransportClient.class)); saslHandler.exceptionCaught(null, null); verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); From ba9332edd889730c906404041bc83b1643d80961 Mon Sep 17 00:00:00 2001 From: Luc Bourlier Date: Fri, 18 Dec 2015 16:21:01 -0800 Subject: [PATCH 1301/2089] [SPARK-12345][CORE] Do not send SPARK_HOME through Spark submit REST interface It is usually an invalid location on the remote machine executing the job. It is picked up by the Mesos support in cluster mode, and most of the time causes the job to fail. Fixes SPARK-12345 Author: Luc Bourlier Closes #10329 from skyluc/issue/SPARK_HOME. --- .../org/apache/spark/deploy/rest/RestSubmissionClient.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index f0dd667ea1b26..0744c64d5e944 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -428,8 +428,10 @@ private[spark] object RestSubmissionClient { * Filter non-spark environment variables from any environment. */ private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { - env.filter { case (k, _) => - (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_") + env.filterKeys { k => + // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || + k.startsWith("MESOS_") } } } From 14be5dece291c900bf97d87b850d906717645fd4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Dec 2015 16:22:33 -0800 Subject: [PATCH 1302/2089] Revert "[SPARK-12413] Fix Mesos ZK persistence" This reverts commit 2bebaa39d9da33bc93ef682959cd42c1968a6a3e. --- .../apache/spark/deploy/rest/mesos/MesosRestServer.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 87d0fa8b52fe1..c0b93596508f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -99,11 +99,7 @@ private[mesos] class MesosSubmitRequestServlet( // cause spark-submit script to look for files in SPARK_HOME instead. // We only need the ability to specify where to find spark-submit script // which user can user spark.executor.home or spark.home configurations. - // - // Do not use `filterKeys` here to avoid SI-6654, which breaks ZK persistence - val environmentVariables = request.environmentVariables.filter { case (k, _) => - k != "SPARK_HOME" - } + val environmentVariables = request.environmentVariables.filterKeys(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description From 8a9417bc4b227c57bdb0a5a38a225dbdf6e69f64 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Dec 2015 16:22:41 -0800 Subject: [PATCH 1303/2089] Revert "[SPARK-12345][MESOS] Properly filter out SPARK_HOME in the Mesos REST server" This reverts commit 8184568810e8a2e7d5371db2c6a0366ef4841f70. --- .../org/apache/spark/deploy/rest/mesos/MesosRestServer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index c0b93596508f1..24510db2bd0ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -99,7 +99,7 @@ private[mesos] class MesosSubmitRequestServlet( // cause spark-submit script to look for files in SPARK_HOME instead. // We only need the ability to specify where to find spark-submit script // which user can user spark.executor.home or spark.home configurations. - val environmentVariables = request.environmentVariables.filterKeys(!_.equals("SPARK_HOME")) + val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description From a78a91f4d7239c14bd5d0b18cdc87d55594a8d8a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Dec 2015 16:22:51 -0800 Subject: [PATCH 1304/2089] Revert "[SPARK-12345][MESOS] Filter SPARK_HOME when submitting Spark jobs with Mesos cluster mode." This reverts commit ad8c1f0b840284d05da737fb2cc5ebf8848f4490. --- .../apache/spark/deploy/rest/mesos/MesosRestServer.scala | 7 +------ .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 24510db2bd0ba..868cc35d06ef3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -94,12 +94,7 @@ private[mesos] class MesosSubmitRequestServlet( val driverMemory = sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") val appArgs = request.appArgs - // We don't want to pass down SPARK_HOME when launching Spark apps - // with Mesos cluster mode since it's populated by default on the client and it will - // cause spark-submit script to look for files in SPARK_HOME instead. - // We only need the ability to specify where to find spark-submit script - // which user can user spark.executor.home or spark.home configurations. - val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) + val environmentVariables = request.environmentVariables val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 573355ba58132..721861fbbc517 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper - * methods the Mesos scheduler will use. + * methods and Mesos scheduler will use. */ private[mesos] trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered From 499ac3e69a102f9b10a1d7e14382fa191516f7b5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 18 Dec 2015 20:06:05 -0800 Subject: [PATCH 1305/2089] [SPARK-12091] [PYSPARK] Deprecate the JAVA-specific deserialized storage levels The current default storage level of Python persist API is MEMORY_ONLY_SER. This is different from the default level MEMORY_ONLY in the official document and RDD APIs. davies Is this inconsistency intentional? Thanks! Updates: Since the data is always serialized on the Python side, the storage levels of JAVA-specific deserialization are not removed, such as MEMORY_ONLY. Updates: Based on the reviewers' feedback. In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, `DISK_ONLY_2` and `OFF_HEAP`. Author: gatorsmile Closes #10092 from gatorsmile/persistStorageLevel. --- docs/configuration.md | 7 ++++--- docs/programming-guide.md | 10 ++++++---- python/pyspark/rdd.py | 8 ++++---- python/pyspark/sql/dataframe.py | 6 +++--- python/pyspark/storagelevel.py | 31 +++++++++++++++++++---------- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/dstream.py | 4 ++-- python/pyspark/streaming/flume.py | 4 ++-- python/pyspark/streaming/kafka.py | 2 +- python/pyspark/streaming/mqtt.py | 2 +- 10 files changed, 45 insertions(+), 31 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 38d3d059f9d31..85e7d1202d2ab 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -687,9 +687,10 @@ Apart from these, the following properties are also available, and may be useful spark.rdd.compress false - Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some - extra CPU time. + Whether to compress serialized RDD partitions (e.g. for + StorageLevel.MEMORY_ONLY_SER in Java + and Scala or StorageLevel.MEMORY_ONLY in Python). + Can save substantial space at the cost of some extra CPU time. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f823b89a4b5e9..c5e2a1cd7b8aa 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1196,14 +1196,14 @@ storage levels is: partitions that don't fit on disk, and read them from there when they're needed. - MEMORY_ONLY_SER + MEMORY_ONLY_SER
    (Java and Scala) Store RDD as serialized Java objects (one byte array per partition). This is generally more space-efficient than deserialized objects, especially when using a fast serializer, but more CPU-intensive to read. - MEMORY_AND_DISK_SER + MEMORY_AND_DISK_SER
    (Java and Scala) Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of recomputing them on the fly each time they're needed. @@ -1230,7 +1230,9 @@ storage levels is: -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, so it does not matter whether you choose a serialized level.* +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, `DISK_ONLY_2` and `OFF_HEAP`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1243,7 +1245,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. (Java and Scala) * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 00bb9a62e904a..a019c05862549 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -220,18 +220,18 @@ def context(self): def cache(self): """ - Persist this RDD with the default storage level (C{MEMORY_ONLY_SER}). + Persist this RDD with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True - self.persist(StorageLevel.MEMORY_ONLY_SER) + self.persist(StorageLevel.MEMORY_ONLY) return self - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """ Set this RDD's storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + If no storage level is specified defaults to (C{MEMORY_ONLY}). >>> rdd = sc.parallelize(["b", "a", "c"]) >>> rdd.persist().is_cached diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 78ab475eb466b..24fc29199924a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -371,18 +371,18 @@ def foreachPartition(self, f): @since(1.3) def cache(self): - """ Persists with the default storage level (C{MEMORY_ONLY_SER}). + """ Persists with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True self._jdf.cache() return self @since(1.3) - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + If no storage level is specified defaults to (C{MEMORY_ONLY}). """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 676aa0f7144aa..d4f184a85d764 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -23,8 +23,10 @@ class StorageLevel(object): """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory - in a serialized format, and whether to replicate the RDD partitions on multiple nodes. - Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY. + in a JAVA-specific serialized format, and whether to replicate the RDD partitions on multiple + nodes. Also contains static constants for some commonly used storage levels, MEMORY_ONLY. + Since the data is always serialized on the Python side, all the constants use the serialized + formats. """ def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication=1): @@ -49,12 +51,21 @@ def __str__(self): StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False) StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2) -StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True) -StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, True, 2) -StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False, False) -StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, False, 2) -StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, True) -StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2) -StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False) -StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2) +StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, False) +StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, False, 2) +StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False) +StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2) StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) + +""" +.. note:: The following four storage level constants are deprecated in 2.0, since the records \ +will always be serialized in Python. +""" +StorageLevel.MEMORY_ONLY_SER = StorageLevel.MEMORY_ONLY +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY`` instead.""" +StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel.MEMORY_ONLY_2 +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY_2`` instead.""" +StorageLevel.MEMORY_AND_DISK_SER = StorageLevel.MEMORY_AND_DISK +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK`` instead.""" +StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel.MEMORY_AND_DISK_2 +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK_2`` instead.""" diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 1388b6d044e04..3deed52be0be2 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -258,7 +258,7 @@ def checkpoint(self, directory): """ self._jssc.checkpoint(directory) - def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2): """ Create an input from TCP source hostname:port. Data is received using a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index b994a53bf2b85..adc2651740007 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -208,10 +208,10 @@ def func(iterator): def cache(self): """ Persist the RDDs of this DStream with the default storage level - (C{MEMORY_ONLY_SER}). + (C{MEMORY_ONLY}). """ self.is_cached = True - self.persist(StorageLevel.MEMORY_ONLY_SER) + self.persist(StorageLevel.MEMORY_ONLY) return self def persist(self, storageLevel): diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index b3d1905365925..b1fff0a5c7d6b 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -40,7 +40,7 @@ class FlumeUtils(object): @staticmethod def createStream(ssc, hostname, port, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, enableDecompression=False, bodyDecoder=utf8_decoder): """ @@ -70,7 +70,7 @@ def createStream(ssc, hostname, port, @staticmethod def createPollingStream(ssc, addresses, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, maxBatchSize=1000, parallelism=5, bodyDecoder=utf8_decoder): diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index cdf97ec73aaf9..13f8f9578e62a 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -40,7 +40,7 @@ class KafkaUtils(object): @staticmethod def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ Create an input stream that pulls messages from a Kafka Broker. diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index 1ce4093196e63..3a515ea4996f4 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -28,7 +28,7 @@ class MQTTUtils(object): @staticmethod def createStream(ssc, brokerUrl, topic, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + storageLevel=StorageLevel.MEMORY_AND_DISK_2): """ Create an input stream that pulls messages from a Mqtt Broker. From a073a73a561e78c734119c8b764d37a4e5e70da4 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 19 Dec 2015 00:34:30 -0800 Subject: [PATCH 1306/2089] [SQL] Fix mistake doc of join type for dataframe.join Fix mistake doc of join type for ```dataframe.join```. Author: Yanbo Liang Closes #10378 from yanboliang/leftsemi. --- python/pyspark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 24fc29199924a..4b3791e1b8864 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -610,7 +610,7 @@ def join(self, other, on=None, how=None): If `on` is a string or a list of string indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an inner equi-join. :param how: str, default 'inner'. - One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] From f496031bd2d09691f9d494a08d990b5d0f14b2a0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Dec 2015 15:13:05 -0800 Subject: [PATCH 1307/2089] Bump master version to 2.0.0-SNAPSHOT. Author: Reynold Xin Closes #10387 from rxin/version-bump. --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- .../main/scala/org/apache/spark/package.scala | 2 +- docker-integration-tests/pom.xml | 2 +- docs/_config.yml | 4 +- examples/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-assembly/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt-assembly/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/kinesis-asl-assembly/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 136 ++++++++++++++++++ repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tags/README.md | 1 + tags/pom.xml | 2 +- tools/pom.xml | 2 +- unsafe/pom.xml | 2 +- yarn/pom.xml | 2 +- 40 files changed, 176 insertions(+), 39 deletions(-) create mode 100644 tags/README.md diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 369714f7b99c2..465bc37788e5d 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.6.0 +Version: 2.0.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman diff --git a/assembly/pom.xml b/assembly/pom.xml index 4b60ee00ffbe5..c3ab92f993d07 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 672e9469aec92..d45224cc8078d 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 61744bb5c7bf5..34ecb19654f1a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 7515aad09db73..cc5e7ef3ae008 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.6.0-SNAPSHOT" + val SPARK_VERSION = "2.0.0-SNAPSHOT" } diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index 39d3f344615e1..78b638ecfa638 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/docs/_config.yml b/docs/_config.yml index 2c70b76be8b7a..dc25ff2c16c5e 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.6.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.6.0 +SPARK_VERSION: 2.0.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.0.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.5" MESOS_VERSION: 0.21.0 diff --git a/examples/pom.xml b/examples/pom.xml index f5ab2a7fdc098..d27a5096a6511 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index dceedcf23ed5b..b2c377fe4cc9b 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 75113ff753e7a..4b6485ee0a71a 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 57f83607365d6..a79656c6f7d96 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index a9ed39ef8c9a0..0c466b3c4ac37 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 79258c126e043..5180ab6dbafbd 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index 89713a28ca6a8..c4a1ae26ea699 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 59fba8b826b4f..b3ba72a0087ad 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 087270de90b3f..7b628b09ea6a5 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 02d6b81281576..a725988449075 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4ce90e75fd359..4dfe3b654df1a 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 61ba4787fbf90..601080c2e6fbd 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 519a920279c97..3c5722502e5c1 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 87a4f05a05961..b046a10a04d5b 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8cd66c5b2e826..388a0ef06a2b0 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 5739bfc16958f..135866cea2e74 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index df50aca1a3f76..42af2b8b3e411 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/network/common/pom.xml b/network/common/pom.xml index 9af6cc5e925f9..32c34c63a45c5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 70ba5cb1995bb..f9aa7e2dd1f43 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index e2360eff5cfe1..a19cbb04b18c6 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index c560e13641c6e..1f570727dc4cc 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index edae59d882668..a3cfcd20fe690 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -33,6 +33,142 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("2.0") => + // When 1.6 is officially released, update this exclusion list. + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("unsafe"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // SQL columnar is considered private. + excludePackage("org.apache.spark.sql.columnar"), + // The shuffle package is considered private. + excludePackage("org.apache.spark.shuffle"), + // The collections utlities are considered pricate. + excludePackage("org.apache.spark.util.collection") + ) ++ + MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + Seq( + // MiMa does not deal properly with sealed traits + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") + ) ++ Seq( + // SPARK-11530 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.SQLContext$SQLSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.detachSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.tlSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.defaultSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.currentSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.openSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.setSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.createSession") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.preferredNodeLocationData_="), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") + ) ++ Seq( + // SPARK-11485 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), + // SPARK-11541 mark various JDBC dialects as private + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") + ) ++ Seq ( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.StageData.this") + ) ++ Seq( + // SPARK-11766 add toJson to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toJson") + ) ++ Seq( + // SPARK-9065 Support message handler in Kafka Python API + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") + ) ++ Seq( + // SPARK-11996 Make the executor thread dump work again + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") + ) ++ Seq( + // SPARK-3580 Add getNumPartitions method to JavaRDD + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") + ) ++ + // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a + // private class. + MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") case v if v.startsWith("1.6") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/repl/pom.xml b/repl/pom.xml index 154c99d23c7f4..67f9866509337 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 61d6fc63554bb..cfa520b7b9db2 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 06841b0945624..6db7a8a2dc526 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index b5b2143292a69..435e565f63458 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index d96f3e2b9f62b..e9885f6682028 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 435e16db13ab4..39cbd0d00f951 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/tags/README.md b/tags/README.md new file mode 100644 index 0000000000000..01e5126945eb7 --- /dev/null +++ b/tags/README.md @@ -0,0 +1 @@ +This module includes annotations in Java that are used to annotate test suites. diff --git a/tags/pom.xml b/tags/pom.xml index ca93722e73345..9e4610dae7a65 100644 --- a/tags/pom.xml +++ b/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 1e64f280e5bed..30cbb6a5a59c7 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index a1c1111364ee8..21fef3415adce 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 989b820bec9ef..a8c122fd40a1f 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml From 6ad31e79bfe4fc1d03bd8eb4472b8bd448ee3daf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Dec 2015 15:30:31 -0800 Subject: [PATCH 1308/2089] HOTFIX: Disable Java style test. --- dev/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index e7e10f1d8c725..20d493ca8bff4 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -528,7 +528,7 @@ def main(): if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() if not changed_files or any(f.endswith(".java") for f in changed_files): - run_java_style_checks() + # run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): From 0c4d6ad87389286280209b3f84a7fdc4d4be1441 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Dec 2015 16:55:25 -0800 Subject: [PATCH 1309/2089] HOTFIX for the previous hot fix. --- dev/run-tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/run-tests.py b/dev/run-tests.py index 20d493ca8bff4..2d4e04c4684de 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -529,6 +529,7 @@ def main(): run_scala_style_checks() if not changed_files or any(f.endswith(".java") for f in changed_files): # run_java_style_checks() + pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): From 284e29a870bbb62f59988a5d88cd12f1b0b6f9d3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 19 Dec 2015 22:40:35 -0800 Subject: [PATCH 1310/2089] [SPARK-11808] Remove Bagel. Author: Reynold Xin Closes #10395 from rxin/SPARK-11808. --- assembly/pom.xml | 5 - bagel/pom.xml | 64 ---- .../scala/org/apache/spark/bagel/Bagel.scala | 318 ------------------ .../org/apache/spark/bagel/package-info.java | 21 -- .../org/apache/spark/bagel/package.scala | 23 -- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../spark/deploy/SparkSubmitUtilsSuite.scala | 2 +- dev/audit-release/audit_release.py | 2 +- docs/_layouts/global.html | 1 - docs/bagel-programming-guide.md | 159 --------- examples/pom.xml | 6 - .../launcher/AbstractCommandBuilder.java | 2 +- pom.xml | 3 +- project/SparkBuild.scala | 8 +- repl/pom.xml | 6 - 15 files changed, 9 insertions(+), 613 deletions(-) delete mode 100644 bagel/pom.xml delete mode 100644 bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala delete mode 100644 bagel/src/main/scala/org/apache/spark/bagel/package-info.java delete mode 100644 bagel/src/main/scala/org/apache/spark/bagel/package.scala delete mode 100644 docs/bagel-programming-guide.md diff --git a/assembly/pom.xml b/assembly/pom.xml index c3ab92f993d07..6c79f9189787d 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -44,11 +44,6 @@ spark-core_${scala.binary.version} ${project.version} - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - org.apache.spark spark-mllib_${scala.binary.version} diff --git a/bagel/pom.xml b/bagel/pom.xml deleted file mode 100644 index d45224cc8078d..0000000000000 --- a/bagel/pom.xml +++ /dev/null @@ -1,64 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 2.0.0-SNAPSHOT - ../pom.xml - - - org.apache.spark - spark-bagel_2.10 - - bagel - - jar - Spark Project Bagel - http://spark.apache.org/ - - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala deleted file mode 100644 index 8399033ac61ec..0000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ /dev/null @@ -1,318 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -object Bagel extends Logging { - val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK - - /** - * Runs a Bagel program. - * @param sc org.apache.spark.SparkContext to use for the program. - * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the - * Key will be the vertex id. - * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often - * this will be an empty array, i.e. sc.parallelize(Array[K, Message]()). - * @param combiner [[org.apache.spark.bagel.Combiner]] combines multiple individual messages to a - * given vertex into one message before sending (which often involves network - * I/O). - * @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices - * after each superstep and provides the result to each vertex in the next - * superstep. - * @param partitioner org.apache.spark.Partitioner partitions values by key - * @param numPartitions number of partitions across which to split the graph. - * Default is the default parallelism of the SparkContext - * @param storageLevel org.apache.spark.storage.StorageLevel to use for caching of - * intermediate RDDs in each superstep. Defaults to caching in memory. - * @param compute function that takes a Vertex, optional set of (possibly combined) messages to - * the Vertex, optional Aggregator and the current superstep, - * and returns a set of (Vertex, outgoing Messages) pairs - * @tparam K key - * @tparam V vertex type - * @tparam M message type - * @tparam C combiner - * @tparam A aggregator - * @return an RDD of (K, V) pairs representing the graph after completion of the program - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C: Manifest, A: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - aggregator: Option[Aggregator[V, A]], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL - )( - compute: (V, Option[C], Option[A], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism - - var superstep = 0 - var verts = vertices - var msgs = messages - var noActivity = false - var lastRDD: RDD[(K, (V, Array[M]))] = null - do { - logInfo("Starting superstep " + superstep + ".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKeyWithClassTag( - combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) - val grouped = combinedMsgs.groupWith(verts) - val superstep_ = superstep // Create a read-only copy of superstep for capture in closure - val (processed, numMsgs, numActiveVerts) = - comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) - if (lastRDD != null) { - lastRDD.unpersist(false) - } - lastRDD = processed - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - verts = processed.mapValues { case (vert, msgs) => vert } - msgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - superstep += 1 - - noActivity = numMsgs == 0 && numActiveVerts == 0 - } while (!noActivity) - - verts - } - - /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default - * storage level */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M])): RDD[(K, V)] = run(sc, vertices, messages, - combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default - * org.apache.spark.HashPartitioner and default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, - DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the - * default org.apache.spark.HashPartitioner - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * default org.apache.spark.HashPartitioner, - * [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * the default org.apache.spark.HashPartitioner - * and [[org.apache.spark.bagel.DefaultCombiner]] - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, Array[M], Nothing]( - sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, Array[M]](compute)) - } - - /** - * Aggregates the given vertices using the given aggregator, if it - * is specified. - */ - private def agg[K, V <: Vertex, A: Manifest]( - verts: RDD[(K, V)], - aggregator: Option[Aggregator[V, A]] - ): Option[A] = aggregator match { - case Some(a) => - Some(verts.map { - case (id, vert) => a.createAggregator(vert) - }.reduce(a.mergeAggregators(_, _))) - case None => None - } - - /** - * Processes the given vertex-message RDD using the compute - * function. Returns the processed RDD, the number of messages - * created, and the number of active vertices. - */ - private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( - sc: SparkContext, - grouped: RDD[(K, (Iterable[C], Iterable[V]))], - compute: (V, Option[C]) => (V, Array[M]), - storageLevel: StorageLevel - ): (RDD[(K, (V, Array[M]))], Int, Int) = { - var numMsgs = sc.accumulator(0) - var numActiveVerts = sc.accumulator(0) - val processed = grouped.mapValues(x => (x._1.iterator, x._2.iterator)) - .flatMapValues { - case (_, vs) if !vs.hasNext => None - case (c, vs) => { - val (newVert, newMsgs) = - compute(vs.next, - c.hasNext match { - case true => Some(c.next) - case false => None - } - ) - - numMsgs += newMsgs.size - if (newVert.active) { - numActiveVerts += 1 - } - - Some((newVert, newMsgs)) - } - }.persist(storageLevel) - - // Force evaluation of processed RDD for accurate performance measurements - processed.foreach(x => {}) - - (processed, numMsgs.value, numActiveVerts.value) - } - - /** - * Converts a compute function that doesn't take an aggregator to - * one that does, so it can be passed to Bagel.run. - */ - private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( - compute: (V, Option[C], Int) => (V, Array[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { - (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => - compute(vert, msgs, superstep) - } -} - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { - def createCombiner(msg: M): Array[M] = - Array(msg) - def mergeMsg(combiner: Array[M], msg: M): Array[M] = - combiner :+ msg - def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = - a ++ b -} - -/** - * Represents a Bagel vertex. - * - * Subclasses may store state along with each vertex and must - * inherit from java.io.Serializable or scala.Serializable. - */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Vertex { - def active: Boolean -} - -/** - * Represents a Bagel message to a target vertex. - * - * Subclasses may contain a payload to deliver to the target vertex - * and must inherit from java.io.Serializable or scala.Serializable. - */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Message[K] { - def targetId: K -} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/package-info.java b/bagel/src/main/scala/org/apache/spark/bagel/package-info.java deleted file mode 100644 index 81f26f276549f..0000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Bagel: An implementation of Pregel in Spark. THIS IS DEPRECATED - use Spark's GraphX library. - */ -package org.apache.spark.bagel; \ No newline at end of file diff --git a/bagel/src/main/scala/org/apache/spark/bagel/package.scala b/bagel/src/main/scala/org/apache/spark/bagel/package.scala deleted file mode 100644 index 2fb1934579781..0000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -/** - * Bagel: An implementation of Pregel in Spark. THIS IS DEPRECATED - use Spark's GraphX library. - */ -package object bagel diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 52d3ab34c1784..669b6b614e38c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -965,7 +965,7 @@ private[spark] object SparkSubmitUtils { // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and // other spark-streaming utility components. Underscore is there to differentiate between // spark-streaming_2.1x and spark-streaming-kafka-assembly_2.1x - val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") components.foreach { comp => diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 63c346c1b8908..4b5039b668a46 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -171,7 +171,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("neglects Spark and Spark's dependencies") { - val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") val coordinates = diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 27d1dd784ce2e..972be30da1eb6 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -115,7 +115,7 @@ def ensure_path_not_present(path): # maven that links against them. This will catch issues with messed up # dependencies within those projects. modules = [ - "spark-core", "spark-bagel", "spark-mllib", "spark-streaming", "spark-repl", + "spark-core", "spark-mllib", "spark-streaming", "spark-repl", "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 3089474c13385..62d75eff71057 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -75,7 +75,6 @@
  • DataFrames, Datasets and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • -
  • Bagel (Pregel on Spark)
  • SparkR (R on Spark)
  • diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md deleted file mode 100644 index 347ca4a7af989..0000000000000 --- a/docs/bagel-programming-guide.md +++ /dev/null @@ -1,159 +0,0 @@ ---- -layout: global -displayTitle: Bagel Programming Guide -title: Bagel ---- - -**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** - -Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. - -In the Pregel programming model, jobs run as a sequence of iterations called _supersteps_. In each superstep, each vertex in the graph runs a user-specified function that can update state associated with the vertex and send messages to other vertices for use in the *next* iteration. - -This guide shows the programming model and features of Bagel by walking through an example implementation of PageRank on Bagel. - -# Linking with Bagel - -To use Bagel in your program, add the following SBT or Maven dependency: - - groupId = org.apache.spark - artifactId = spark-bagel_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION}} - -# Programming Model - -Bagel operates on a graph represented as a [distributed dataset](programming-guide.html) of (K, V) pairs, where keys are vertex IDs and values are vertices plus their associated state. In each superstep, Bagel runs a user-specified compute function on each vertex that takes as input the current vertex state and a list of messages sent to that vertex during the previous superstep, and returns the new vertex state and a list of outgoing messages. - -For example, we can use Bagel to implement PageRank. Here, vertices represent pages, edges represent links between pages, and messages represent shares of PageRank sent to the pages that a particular page links to. - -We first extend the default `Vertex` class to store a `Double` -representing the current PageRank of the vertex, and similarly extend -the `Message` and `Edge` classes. Note that these need to be marked `@serializable` to allow Spark to transfer them across machines. We also import the Bagel types and implicit conversions. - -{% highlight scala %} -import org.apache.spark.bagel._ -import org.apache.spark.bagel.Bagel._ - -@serializable class PREdge(val targetId: String) extends Edge - -@serializable class PRVertex( - val id: String, val rank: Double, val outEdges: Seq[Edge], - val active: Boolean) extends Vertex - -@serializable class PRMessage( - val targetId: String, val rankShare: Double) extends Message -{% endhighlight %} - -Next, we load a sample graph from a text file as a distributed dataset and package it into `PRVertex` objects. We also cache the distributed dataset because Bagel will use it multiple times and we'd like to avoid recomputing it. - -{% highlight scala %} -val input = sc.textFile("data/mllib/pagerank_data.txt") - -val numVerts = input.count() - -val verts = input.map(line => { - val fields = line.split('\t') - val (id, linksStr) = (fields(0), fields(1)) - val links = linksStr.split(',').map(new PREdge(_)) - (id, new PRVertex(id, 1.0 / numVerts, links, true)) -}).cache -{% endhighlight %} - -We run the Bagel job, passing in `verts`, an empty distributed dataset of messages, and a custom compute function that runs PageRank for 10 iterations. - -{% highlight scala %} -val emptyMsgs = sc.parallelize(List[(String, PRMessage)]()) - -def compute(self: PRVertex, msgs: Option[Seq[PRMessage]], superstep: Int) -: (PRVertex, Iterable[PRMessage]) = { - val msgSum = msgs.getOrElse(List()).map(_.rankShare).sum - val newRank = - if (msgSum != 0) - 0.15 / numVerts + 0.85 * msgSum - else - self.rank - val halt = superstep >= 10 - val msgsOut = - if (!halt) - self.outEdges.map(edge => - new PRMessage(edge.targetId, newRank / self.outEdges.size)) - else - List() - (new PRVertex(self.id, newRank, self.outEdges, !halt), msgsOut) -} -{% endhighlight %} - -val result = Bagel.run(sc, verts, emptyMsgs)()(compute) - -Finally, we print the results. - -{% highlight scala %} -println(result.map(v => "%s\t%s\n".format(v.id, v.rank)).collect.mkString) -{% endhighlight %} - -## Combiners - -Sending a message to another vertex generally involves expensive communication over the network. For certain algorithms, it's possible to reduce the amount of communication using _combiners_. For example, if the compute function receives integer messages and only uses their sum, it's possible for Bagel to combine multiple messages to the same vertex by summing them. - -For combiner support, Bagel can optionally take a set of combiner functions that convert messages to their combined form. - -_Example: PageRank with combiners_ - -## Aggregators - -Aggregators perform a reduce across all vertices after each superstep, and provide the result to each vertex in the next superstep. - -For aggregator support, Bagel can optionally take an aggregator function that reduces across each vertex. - -_Example_ - -## Operations - -Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details. - -### Actions - -{% highlight scala %} -/*** Full form ***/ - -Bagel.run(sc, vertices, messages, combiner, aggregator, partitioner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], aggregated: Option[A], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -/*** Abbreviated forms ***/ - -Bagel.run(sc, vertices, messages, combiner, partitioner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -Bagel.run(sc, vertices, messages, combiner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -Bagel.run(sc, vertices, messages, numSplits)(compute) -// where compute takes (vertex: V, messages: Option[Array[M]], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) -{% endhighlight %} - -### Types - -{% highlight scala %} -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -trait Vertex { - def active: Boolean -} - -trait Message[K] { - def targetId: K -} -{% endhighlight %} diff --git a/examples/pom.xml b/examples/pom.xml index d27a5096a6511..1a0d5e5854642 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -53,12 +53,6 @@ ${project.version} provided
    - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-hive_${scala.binary.version} diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 55fe156cf665f..68af14397ba81 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -146,7 +146,7 @@ List buildClassPath(String appClassPath) throws IOException { boolean isTesting = "1".equals(getenv("SPARK_TESTING")); if (prependClasses || isTesting) { String scala = getScalaVersion(); - List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", + List projects = Arrays.asList("core", "repl", "mllib", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { diff --git a/pom.xml b/pom.xml index 1f570727dc4cc..32918d6a74af9 100644 --- a/pom.xml +++ b/pom.xml @@ -88,7 +88,6 @@ tags core - bagel graphx mllib tools @@ -194,7 +193,7 @@ declared in the projects that build assemblies. For other projects the scope should remain as "compile", otherwise they are not available - during compilation if the dependency is transivite (e.g. "bagel/" depending on "core/" and + during compilation if the dependency is transivite (e.g. "graphx/" depending on "core/" and needing Hadoop classes in the classpath to compile). --> compile diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b1dcaedcba75e..c3d53f835f395 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,10 +34,10 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, + val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) @@ -352,7 +352,7 @@ object OldDeps { scalaVersion := "2.10.5", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-streaming", "spark-mllib", "spark-graphx", "spark-core").map(versionArtifact(_).get intransitive()) ) } @@ -556,7 +556,7 @@ object Unidoc { unidocProjectFilter in(ScalaUnidoc, unidoc) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn, testTags), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. diff --git a/repl/pom.xml b/repl/pom.xml index 67f9866509337..efc3dd452e329 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -50,12 +50,6 @@ test-jar test - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - runtime - org.apache.spark spark-mllib_${scala.binary.version} From ce1798b3af8de326bf955b51ed955a924b019b4e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 20 Dec 2015 09:08:23 +0000 Subject: [PATCH 1311/2089] [SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs Added catch for casting Long to Int exception when PySpark ALS Ratings are serialized. It is easy to accidentally use Long IDs for user/product and before, it would fail with a somewhat cryptic "ClassCastException: java.lang.Long cannot be cast to java.lang.Integer." Now if this is done, a more descriptive error is shown, e.g. "PickleException: Ratings id 1205640308657491975 exceeds max integer value of 2147483647." Author: Bryan Cutler Closes #9361 from BryanCutler/als-pyspark-long-id-error-SPARK-10158. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 12 +++++++++++- python/pyspark/mllib/tests.py | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 29160a10e16b3..f6826ddbfabfe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1438,9 +1438,19 @@ private[spark] object SerDe extends Serializable { if (args.length != 3) { throw new PickleException("should be 3") } - new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], + new Rating(ratingsIdCheckLong(args(0)), ratingsIdCheckLong(args(1)), args(2).asInstanceOf[Double]) } + + private def ratingsIdCheckLong(obj: Object): Int = { + try { + obj.asInstanceOf[Int] + } catch { + case ex: ClassCastException => + throw new PickleException(s"Ratings id ${obj.toString} exceeds " + + s"max integer value of ${Int.MaxValue}", ex) + } + } } var initialized = false diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f8e8e0e0adbea..6ed03e35828ed 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics @@ -1539,6 +1540,22 @@ def test_load_vectors(self): shutil.rmtree(load_vectors_path) +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") From d0f695089e4627273133c5f49ef7a83c1840c8f5 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Dec 2015 10:21:22 +0000 Subject: [PATCH 1312/2089] [SPARK-12349][ML] Make spark.ml PCAModel load backwards compatible Only load explainedVariance in PCAModel if it was written with Spark > 1.6.x jkbradley is this kind of what you had in mind? Author: Sean Owen Closes #10327 from srowen/SPARK-12349. --- .../org/apache/spark/ml/feature/PCA.scala | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 53d33ea2b8f76..759be813eea6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -167,14 +167,37 @@ object PCAModel extends MLReadable[PCAModel] { private val className = classOf[PCAModel].getName + /** + * Loads a [[PCAModel]] from data located at the input path. Note that the model includes an + * `explainedVariance` member that is not recorded by Spark 1.6 and earlier. A model + * can be loaded from such older data but will have an empty vector for + * `explainedVariance`. + * + * @param path path to serialized model data + * @return a [[PCAModel]] + */ override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + // explainedVariance field is not present in Spark <= 1.6 + val versionRegex = "([0-9]+)\\.([0-9])+.*".r + val hasExplainedVariance = metadata.sparkVersion match { + case versionRegex(major, minor) => + (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)) + case _ => false + } + val dataPath = new Path(path, "data").toString - val Row(pc: DenseMatrix, explainedVariance: DenseVector) = - sqlContext.read.parquet(dataPath) - .select("pc", "explainedVariance") - .head() - val model = new PCAModel(metadata.uid, pc, explainedVariance) + val model = if (hasExplainedVariance) { + val Row(pc: DenseMatrix, explainedVariance: DenseVector) = + sqlContext.read.parquet(dataPath) + .select("pc", "explainedVariance") + .head() + new PCAModel(metadata.uid, pc, explainedVariance) + } else { + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head() + new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) + } DefaultParamsReader.getAndSetParams(model, metadata) model } From 1920d72a1f7b9844323d06e8094818347f413df6 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 21 Dec 2015 08:53:46 -0800 Subject: [PATCH 1313/2089] [PYSPARK] Pyspark typo & Add missing abstractmethod annotation No jira is created since this is a trivial change. davies Please help review it Author: Jeff Zhang Closes #10143 from zjffdu/pyspark_typo. --- python/pyspark/ml/pipeline.py | 2 +- python/pyspark/ml/wrapper.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 4475451edb781..9f5f6ac8fa4e2 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -86,7 +86,7 @@ class Transformer(Params): @abstractmethod def _transform(self, dataset): """ - Transforms the input dataset with optional parameters. + Transforms the input dataset. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 4bcb4aaec89de..dd1d4b076eddd 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -15,7 +15,7 @@ # limitations under the License. # -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from pyspark import SparkContext from pyspark.sql import DataFrame @@ -110,6 +110,7 @@ class JavaEstimator(Estimator, JavaWrapper): __metaclass__ = ABCMeta + @abstractmethod def _create_model(self, java_model): """ Creates a model from the input Java model reference. From 474eb21a30f7ee898f76a625a5470c8245af1d22 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 21 Dec 2015 12:46:06 -0800 Subject: [PATCH 1314/2089] [SPARK-12398] Smart truncation of DataFrame / Dataset toString When a DataFrame or Dataset has a long schema, we should intelligently truncate to avoid flooding the screen with unreadable information. // Standard output [a: int, b: int] // Truncate many top level fields [a: int, b, string ... 10 more fields] // Truncate long inner structs [a: struct] Author: Dilip Biswal Closes #10373 from dilipbiswal/spark-12398. --- .../org/apache/spark/sql/types/DataType.scala | 3 ++ .../apache/spark/sql/types/StructType.scala | 17 ++++++++ .../spark/sql/execution/Queryable.scala | 15 ++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 39 +++++++++++++++++++ 4 files changed, 73 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4b54c31dcc27a..b0c43c4100d08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -66,6 +66,9 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type. */ def simpleString: String = typeName + /** Readable string representation for the type with truncation */ + private[sql] def simpleString(maxNumberFields: Int): String = simpleString + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9778df271ddd5..d56802276558a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -278,6 +278,23 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + private[sql] override def simpleString(maxNumberFields: Int): String = { + val builder = new StringBuilder + val fieldTypes = fields.take(maxNumberFields).map { + case f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" + } + builder.append("struct<") + builder.append(fieldTypes.mkString(", ")) + if (fields.length > 2) { + if (fields.length - fieldTypes.size == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (fields.length - 2) + " more fields") + } + } + builder.append(">").toString() + } + /** * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field * B from `that`, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index b397d42612cf0..3f391fd9a9ddb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -31,7 +31,20 @@ private[sql] trait Queryable { override def toString: String = { try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + val builder = new StringBuilder + val fields = schema.take(2).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("[") + builder.append(fields.mkString(", ")) + if (schema.length > 2) { + if (schema.length - fields.size == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (schema.length - 2) + " more fields") + } + } + builder.append("]").toString() } catch { case NonFatal(e) => s"Invalid tree; ${e.getMessage}:\n$queryExecution" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4c3e12af7203d..1a0f1b61cb3c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1177,4 +1177,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) } + + test("SPARK-12398 truncated toString") { + val df1 = Seq((1L, "row1")).toDF("id", "name") + assert(df1.toString() === "[id: bigint, name: string]") + + val df2 = Seq((1L, "c2", false)).toDF("c1", "c2", "c3") + assert(df2.toString === "[c1: bigint, c2: string ... 1 more field]") + + val df3 = Seq((1L, "c2", false, 10)).toDF("c1", "c2", "c3", "c4") + assert(df3.toString === "[c1: bigint, c2: string ... 2 more fields]") + + val df4 = Seq((1L, Tuple2(1L, "val"))).toDF("c1", "c2") + assert(df4.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>]") + + val df5 = Seq((1L, Tuple2(1L, "val"), 20.0)).toDF("c1", "c2", "c3") + assert(df5.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 1 more field]") + + val df6 = Seq((1L, Tuple2(1L, "val"), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert(df6.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 2 more fields]") + + val df7 = Seq((1L, Tuple3(1L, "val", 2), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df7.toString === + "[c1: bigint, c2: struct<_1: bigint, _2: string ... 1 more field> ... 2 more fields]") + + val df8 = Seq((1L, Tuple7(1L, "val", 2, 3, 4, 5, 6), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df8.toString === + "[c1: bigint, c2: struct<_1: bigint, _2: string ... 5 more fields> ... 2 more fields]") + + val df9 = + Seq((1L, Tuple4(1L, Tuple4(1L, 2L, 3L, 4L), 2L, 3L), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df9.toString === + "[c1: bigint, c2: struct<_1: bigint," + + " _2: struct<_1: bigint," + + " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]") + + } } From 7634fe9511e1a8fb94979624b1b617b495b48ad3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 21 Dec 2015 12:47:07 -0800 Subject: [PATCH 1315/2089] [SPARK-12321][SQL] JSON format for TreeNode (use reflection) An alternative solution for https://github.com/apache/spark/pull/10295 , instead of implementing json format for all logical/physical plans and expressions, use reflection to implement it in `TreeNode`. Here I use pre-order traversal to flattern a plan tree to a plan list, and add an extra field `num-children` to each plan node, so that we can reconstruct the tree from the list. example json: logical plan tree: ``` [ { "class" : "org.apache.spark.sql.catalyst.plans.logical.Sort", "num-children" : 1, "order" : [ [ { "class" : "org.apache.spark.sql.catalyst.expressions.SortOrder", "num-children" : 1, "child" : 0, "direction" : "Ascending" }, { "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference", "num-children" : 0, "name" : "i", "dataType" : "integer", "nullable" : true, "metadata" : { }, "exprId" : { "id" : 10, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" }, "qualifiers" : [ ] } ] ], "global" : false, "child" : 0 }, { "class" : "org.apache.spark.sql.catalyst.plans.logical.Project", "num-children" : 1, "projectList" : [ [ { "class" : "org.apache.spark.sql.catalyst.expressions.Alias", "num-children" : 1, "child" : 0, "name" : "i", "exprId" : { "id" : 10, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" }, "qualifiers" : [ ] }, { "class" : "org.apache.spark.sql.catalyst.expressions.Add", "num-children" : 2, "left" : 0, "right" : 1 }, { "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference", "num-children" : 0, "name" : "a", "dataType" : "integer", "nullable" : true, "metadata" : { }, "exprId" : { "id" : 0, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" }, "qualifiers" : [ ] }, { "class" : "org.apache.spark.sql.catalyst.expressions.Literal", "num-children" : 0, "value" : "1", "dataType" : "integer" } ], [ { "class" : "org.apache.spark.sql.catalyst.expressions.Alias", "num-children" : 1, "child" : 0, "name" : "j", "exprId" : { "id" : 11, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" }, "qualifiers" : [ ] }, { "class" : "org.apache.spark.sql.catalyst.expressions.Multiply", "num-children" : 2, "left" : 0, "right" : 1 }, { "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference", "num-children" : 0, "name" : "a", "dataType" : "integer", "nullable" : true, "metadata" : { }, "exprId" : { "id" : 0, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" }, "qualifiers" : [ ] }, { "class" : "org.apache.spark.sql.catalyst.expressions.Literal", "num-children" : 0, "value" : "2", "dataType" : "integer" } ] ], "child" : 0 }, { "class" : "org.apache.spark.sql.catalyst.plans.logical.LocalRelation", "num-children" : 0, "output" : [ [ { "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference", "num-children" : 0, "name" : "a", "dataType" : "integer", "nullable" : true, "metadata" : { }, "exprId" : { "id" : 0, "jvmId" : "cd1313c7-3f66-4ed7-a320-7d91e4633ac6" }, "qualifiers" : [ ] } ] ], "data" : [ ] } ] ``` Author: Wenchen Fan Closes #10311 from cloud-fan/toJson-reflection. --- .../spark/sql/catalyst/ScalaReflection.scala | 114 ++++---- .../expressions/aggregate/interfaces.scala | 1 - .../sql/catalyst/expressions/literals.scala | 41 +++ .../expressions/namedExpressions.scala | 4 + .../spark/sql/catalyst/plans/QueryPlan.scala | 2 + .../spark/sql/catalyst/trees/TreeNode.scala | 258 +++++++++++++++++- .../org/apache/spark/sql/types/DataType.scala | 6 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 102 ++++++- .../spark/sql/UserDefinedTypeSuite.scala | 5 + .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 + .../hive/execution/ScriptTransformation.scala | 2 +- 13 files changed, 472 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c1b1d5cd2dee0..cc9e6af1818f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -68,7 +68,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = tpe arrayClassFor(elementType) case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + val clazz = getClassFromType(tpe) ObjectType(clazz) } } @@ -321,29 +321,11 @@ object ScalaReflection extends ScalaReflection { keyData :: valueData :: Nil) case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) + val params = getConstructorParameters(t) - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val cls = getClassFromType(tpe) - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val dataType = schemaFor(fieldType).dataType val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath @@ -477,27 +459,9 @@ object ScalaReflection extends ScalaReflection { } case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val params = getConstructorParameters(t) - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath @@ -595,6 +559,21 @@ object ScalaReflection extends ScalaReflection { } } } + + /** + * Returns the parameter names and types for the primary constructor of this class. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + getConstructorParameters(t) + } + + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } /** @@ -668,26 +647,11 @@ trait ScalaReflection { Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( - s => s.isMethod && s.asMethod.isPrimaryConstructor) - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val params = getConstructorParameters(t) Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) + params.map { case (fieldName, fieldType) => + val Schema(dataType, nullable) = schemaFor(fieldType) + StructField(fieldName, dataType, nullable) }), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) @@ -740,4 +704,32 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } + + /** + * Returns the parameter names and types for the primary constructor of this type. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { + val formalTypeArgs = tpe.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = tpe + val constructorSymbol = tpe.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( + s => s.isMethod && s.asMethod.isPrimaryConstructor) + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + params.flatten.map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b6d2ddc5b1364..b616d6953baa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 68ec688c99f93..e3573b4947379 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.json4s.JsonAST._ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -55,6 +56,34 @@ object Literal { */ def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def fromJSON(json: JValue): Literal = { + val dataType = DataType.parseDataType(json \ "dataType") + json \ "value" match { + case JNull => Literal.create(null, dataType) + case JString(str) => + val value = dataType match { + case BooleanType => str.toBoolean + case ByteType => str.toByte + case ShortType => str.toShort + case IntegerType => str.toInt + case LongType => str.toLong + case FloatType => str.toFloat + case DoubleType => str.toDouble + case StringType => UTF8String.fromString(str) + case DateType => java.sql.Date.valueOf(str) + case TimestampType => java.sql.Timestamp.valueOf(str) + case CalendarIntervalType => CalendarInterval.fromString(str) + case t: DecimalType => + val d = Decimal(str) + assert(d.changePrecision(t.precision, t.scale)) + d + case _ => null + } + Literal.create(value, dataType) + case other => sys.error(s"$other is not a valid Literal json value") + } + } + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } @@ -123,6 +152,18 @@ case class Literal protected (value: Any, dataType: DataType) case _ => false } + override protected def jsonFields: List[JField] = { + // Turns all kinds of literal values to string in json field, as the type info is hard to + // retain in json format, e.g. {"a": 123} can be a int, or double, or decimal, etc. + val jsonValue = (value, dataType) match { + case (null, _) => JNull + case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString) + case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString) + case (other, _) => JString(other.toString) + } + ("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil + } + override def eval(input: InternalRow): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 26b6aca79971e..eefd9c7482553 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -262,6 +262,10 @@ case class AttributeReference( } } + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifiers :: Nil + } + override def toString: String = s"$name#${exprId.id}$typeSuffix" // Since the expression id is not in the first constructor it is missing from the default diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b9db7838db08a..d2626440b9434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -88,6 +88,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other + case null => null } val newArgs = productIterator.map(recursiveTransform).toArray @@ -120,6 +121,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other + case null => null } val newArgs = productIterator.map(recursiveTransform).toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index d838d845d20fd..c97dc2d8be7e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,9 +17,25 @@ package org.apache.spark.sql.catalyst.trees +import java.util.UUID import scala.collection.Map - +import scala.collection.mutable.Stack +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.ScalaReflection._ +import org.apache.spark.sql.catalyst.{TableIdentifier, ScalaReflectionLock} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{StructType, DataType} /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ @@ -463,4 +479,244 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } s"$nodeName(${args.mkString(",")})" } + + def toJSON: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) + + private def jsonValue: JValue = { + val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue] + + def collectJsonValue(tn: BaseType): Unit = { + val jsonFields = ("class" -> JString(tn.getClass.getName)) :: + ("num-children" -> JInt(tn.children.length)) :: tn.jsonFields + jsonValues += JObject(jsonFields) + tn.children.foreach(collectJsonValue) + } + + collectJsonValue(this) + jsonValues + } + + protected def jsonFields: List[JField] = { + val fieldNames = getConstructorParameters(getClass).map(_._1) + val fieldValues = productIterator.toSeq ++ otherCopyArgs + assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " + + fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", ")) + + fieldNames.zip(fieldValues).map { + // If the field value is a child, then use an int to encode it, represents the index of + // this child in all children. + case (name, value: TreeNode[_]) if containsChild(value) => + name -> JInt(children.indexOf(value)) + case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) => + name -> JArray( + value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList + ) + case (name, value) => name -> parseToJson(value) + }.toList + } + + private def parseToJson(obj: Any): JValue = obj match { + case b: Boolean => JBool(b) + case b: Byte => JInt(b.toInt) + case s: Short => JInt(s.toInt) + case i: Int => JInt(i) + case l: Long => JInt(l) + case f: Float => JDouble(f) + case d: Double => JDouble(d) + case b: BigInt => JInt(b) + case null => JNull + case s: String => JString(s) + case u: UUID => JString(u.toString) + case dt: DataType => dt.jsonValue + case m: Metadata => m.jsonValue + case s: StorageLevel => + ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ + ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) + case n: TreeNode[_] => n.jsonValue + case o: Option[_] => o.map(parseToJson) + case t: Seq[_] => JArray(t.map(parseToJson).toList) + case m: Map[_, _] => + val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) } + JObject(fields) + case r: RDD[_] => JNothing + // if it's a scala object, we can simply keep the full class path. + // TODO: currently if the class name ends with "$", we think it's a scala object, there is + // probably a better way to check it. + case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName + // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper + case p: Product => try { + val fieldNames = getConstructorParameters(p.getClass).map(_._1) + val fieldValues = p.productIterator.toSeq + assert(fieldNames.length == fieldValues.length) + ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { + case (name, value) => name -> parseToJson(value) + }.toList + } catch { + case _: RuntimeException => null + } + case _ => JNull + } +} + +object TreeNode { + def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = { + val jsonAST = parse(json) + assert(jsonAST.isInstanceOf[JArray]) + reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType] + } + + private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = { + assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject])) + val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*) + + def parseNextNode(): TreeNode[_] = { + val nextNode = jsonNodes.pop() + + val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s) + if (cls == classOf[Literal]) { + Literal.fromJSON(nextNode) + } else if (cls.getName.endsWith("$")) { + cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]] + } else { + val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt + + val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode()) + val fields = getConstructorParameters(cls) + + val parameters: Array[AnyRef] = fields.map { + case (fieldName, fieldType) => + parseFromJson(nextNode \ fieldName, fieldType, children, sc) + }.toArray + + val maybeCtor = cls.getConstructors.find { p => + val expectedTypes = p.getParameterTypes + expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall { + case (cls, tpe) => cls == getClassFromType(tpe) + } + } + if (maybeCtor.isEmpty) { + sys.error(s"No valid constructor for ${cls.getName}") + } else { + try { + maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]] + } catch { + case e: java.lang.IllegalArgumentException => + throw new RuntimeException( + s""" + |Failed to construct tree node: ${cls.getName} + |ctor: ${maybeCtor.get} + |types: ${parameters.map(_.getClass).mkString(", ")} + |args: ${parameters.mkString(", ")} + """.stripMargin, e) + } + } + } + } + + parseNextNode() + } + + import universe._ + + private def parseFromJson( + value: JValue, + expectedType: Type, + children: Seq[TreeNode[_]], + sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized { + if (value == JNull) return null + + expectedType match { + case t if t <:< definitions.BooleanTpe => + value.asInstanceOf[JBool].value: java.lang.Boolean + case t if t <:< definitions.ByteTpe => + value.asInstanceOf[JInt].num.toByte: java.lang.Byte + case t if t <:< definitions.ShortTpe => + value.asInstanceOf[JInt].num.toShort: java.lang.Short + case t if t <:< definitions.IntTpe => + value.asInstanceOf[JInt].num.toInt: java.lang.Integer + case t if t <:< definitions.LongTpe => + value.asInstanceOf[JInt].num.toLong: java.lang.Long + case t if t <:< definitions.FloatTpe => + value.asInstanceOf[JDouble].num.toFloat: java.lang.Float + case t if t <:< definitions.DoubleTpe => + value.asInstanceOf[JDouble].num: java.lang.Double + + case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num + case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s + case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s) + case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value) + case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject]) + case t if t <:< localTypeOf[StorageLevel] => + val JBool(useDisk) = value \ "useDisk" + val JBool(useMemory) = value \ "useMemory" + val JBool(useOffHeap) = value \ "useOffHeap" + val JBool(deserialized) = value \ "deserialized" + val JInt(replication) = value \ "replication" + StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt) + case t if t <:< localTypeOf[TreeNode[_]] => value match { + case JInt(i) => children(i.toInt) + case arr: JArray => reconstruct(arr, sc) + case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.") + } + case t if t <:< localTypeOf[Option[_]] => + if (value == JNothing) { + None + } else { + val TypeRef(_, _, Seq(optType)) = t + Option(parseFromJson(value, optType, children, sc)) + } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val JArray(elements) = value + elements.map(parseFromJson(_, elementType, children, sc)).toSeq + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val JObject(fields) = value + fields.map { + case (name, value) => name -> parseFromJson(value, valueType, children, sc) + }.toMap + case t if t <:< localTypeOf[RDD[_]] => + new EmptyRDD[Any](sc) + case _ if isScalaObject(value) => + val JString(clsName) = value \ "object" + val cls = Utils.classForName(clsName) + cls.getField("MODULE$").get(cls) + case t if t <:< localTypeOf[Product] => + val fields = getConstructorParameters(t) + val clsName = getClassNameFromType(t) + parseToProduct(clsName, fields, value, children, sc) + // There maybe some cases that the parameter type signature is not Product but the value is, + // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here. + case _ if isScalaProduct(value) => + val JString(clsName) = value \ "product-class" + val fields = getConstructorParameters(Utils.classForName(clsName)) + parseToProduct(clsName, fields, value, children, sc) + case _ => sys.error(s"Do not support type $expectedType with json $value.") + } + } + + private def parseToProduct( + clsName: String, + fields: Seq[(String, Type)], + value: JValue, + children: Seq[TreeNode[_]], + sc: SparkContext): AnyRef = { + val parameters: Array[AnyRef] = fields.map { + case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc) + }.toArray + val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size) + ctor.newInstance(parameters: _*).asInstanceOf[AnyRef] + } + + private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match { + case JString(str) if str.endsWith("$") => true + case _ => false + } + + private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match { + case _: JString => true + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index b0c43c4100d08..f8d71c5f02372 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -107,8 +107,8 @@ object DataType { def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) private val nonDecimalNameToType = { - Seq(NullType, DateType, TimestampType, BinaryType, - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, + DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) .map(t => t.typeName -> t).toMap } @@ -130,7 +130,7 @@ object DataType { } // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. - private def parseDataType(json: JValue): DataType = json match { + private[sql] def parseDataType(json: JValue): DataType = json match { case JString(name) => nameToType(name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index b8a43025882e5..ea5a9afe03b00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -74,9 +74,7 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil - override protected final def otherCopyArgs: Seq[AnyRef] = { - sqlContext :: Nil - } + override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 3c5a8cb2aa935..4afa5f8ec1035 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -61,9 +61,9 @@ private[sql] case class InMemoryRelation( storageLevel: StorageLevel, @transient child: SparkPlan, tableName: Option[String])( - @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, - @transient private var _statistics: Statistics = null, - private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) + @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private[sql] var _statistics: Statistics = null, + private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bc22fb8b7bdb4..9246f55020fc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -21,10 +21,15 @@ import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.Queryable +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.{LogicalRDD, Queryable} abstract class QueryTest extends PlanTest { @@ -123,6 +128,8 @@ abstract class QueryTest extends PlanTest { |""".stripMargin) } + checkJsonFormat(analyzedDF) + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -177,6 +184,97 @@ abstract class QueryTest extends PlanTest { s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + planWithCaching) } + + private def checkJsonFormat(df: DataFrame): Unit = { + val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. + logicalPlan.transform { + case _: MapPartitions[_, _] => return + case _: MapGroups[_, _, _] => return + case _: AppendColumns[_, _] => return + case _: CoGroup[_, _, _, _] => return + case _: LogicalRelation => return + }.transformAllExpressions { + case a: ImperativeAggregate => return + } + + val jsonString = try { + logicalPlan.toJSON + } catch { + case e => + fail( + s""" + |Failed to parse logical plan to JSON: + |${logicalPlan.treeString} + """.stripMargin, e) + } + + // bypass hive tests before we fix all corner cases in hive module. + if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return + + // scala function is not serializable to JSON, use null to replace them so that we can compare + // the plans later. + val normalized1 = logicalPlan.transformAllExpressions { + case udf: ScalaUDF => udf.copy(function = null) + case gen: UserDefinedGenerator => gen.copy(function = null) + } + + // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains + // these non-serializable stuff, and use these original ones to replace the null-placeholders + // in the logical plans parsed from JSON. + var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } + var localRelations = logicalPlan.collect { case l: LocalRelation => l } + var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + + val jsonBackPlan = try { + TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + } catch { + case e => + fail( + s""" + |Failed to rebuild the logical plan from JSON: + |${logicalPlan.treeString} + | + |${logicalPlan.prettyJson} + """.stripMargin, e) + } + + val normalized2 = jsonBackPlan transformDown { + case l: LogicalRDD => + val origin = logicalRDDs.head + logicalRDDs = logicalRDDs.drop(1) + LogicalRDD(l.output, origin.rdd)(sqlContext) + case l: LocalRelation => + val origin = localRelations.head + localRelations = localRelations.drop(1) + l.copy(data = origin.data) + case l: InMemoryRelation => + val origin = inMemoryRelations.head + inMemoryRelations = inMemoryRelations.drop(1) + InMemoryRelation( + l.output, + l.useCompression, + l.batchSize, + l.storageLevel, + origin.child, + l.tableName)( + origin.cachedColumnBuffers, + l._statistics, + origin._batchStats) + } + + assert(logicalRDDs.isEmpty) + assert(localRelations.isEmpty) + assert(inMemoryRelations.isEmpty) + + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: the logical plan parsed from json does not match the original one === + |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } } object QueryTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f602f2fb89ca5..2a1117318ad14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -65,6 +65,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this + + override def equals(other: Any): Boolean = other match { + case _: MyDenseVectorUDT => true + case _ => false + } } class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 08b291e088238..f099e146d1e37 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -728,6 +728,8 @@ private[hive] case class MetastoreRelation Objects.hashCode(databaseName, tableName, alias, output) } + override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil + @transient val hiveQlTable: Table = { // We start by constructing an API table as Hive performs several important transformations // internally when converting an API table to a QL table. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index b30117f0de997..d9b9ba4bfdfed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -58,7 +58,7 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { - override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) From 4883a5087d481d4de5d3beabbd709853de01399a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 21 Dec 2015 13:46:58 -0800 Subject: [PATCH 1316/2089] [SPARK-12374][SPARK-12150][SQL] Adding logical/physical operators for Range Based on the suggestions from marmbrus , added logical/physical operators for Range for improving the performance. Also added another API for resolving the JIRA Spark-12150. Could you take a look at my implementation, marmbrus ? If not good, I can rework it. : ) Thank you very much! Author: gatorsmile Closes #10335 from gatorsmile/rangeOperators. --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../plans/logical/basicOperators.scala | 32 ++++++++++ .../org/apache/spark/sql/SQLContext.scala | 23 ++++--- .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/basicOperators.scala | 62 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++ .../execution/ExchangeCoordinatorSuite.scala | 1 + 7 files changed, 119 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 194ecc0a0434e..81a4d0a4d6e9a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -759,7 +759,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val numElements: BigInt = { val safeStart = BigInt(start) val safeEnd = BigInt(end) - if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) { (safeEnd - safeStart) / step } else { // the remainder has the same sign with range, could add 1 more diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ec42b763f18ee..64ef4d799659f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -210,6 +210,38 @@ case class Sort( override def output: Seq[Attribute] = child.output } +/** Factory for constructing new `Range` nodes. */ +object Range { + def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes + new Range(start, end, step, numSlices, output) + } +} + +case class Range( + start: Long, + end: Long, + step: Long, + numSlices: Int, + output: Seq[Attribute]) extends LeafNode { + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + + override def statistics: Statistics = { + val sizeInBytes = LongType.defaultSize * numElements + Statistics( sizeInBytes = sizeInBytes ) + } +} + case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index db286ea8700b6..eadf5cba6d9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ @@ -785,9 +785,20 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long): DataFrame = { - createDataFrame( - sparkContext.range(start, end).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value. + * + * @since 2.0.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long, step: Long): DataFrame = { + range(start, end, step, numPartitions = sparkContext.defaultParallelism) } /** @@ -801,9 +812,7 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { - createDataFrame( - sparkContext.range(start, end, step, numPartitions).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + DataFrame(this, Range(start, end, step, numPartitions)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 688555cf136e8..183d9b65023b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -358,6 +358,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil + case r @ logical.Range(start, end, step, numSlices, output) => + execution.Range(start, step, numSlices, r.numElements, output) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => execution.Exchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b3e4688557ba0..21325beb1c8c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler import org.apache.spark.{HashPartitioner, SparkEnv} @@ -126,6 +127,67 @@ case class Sample( } } +case class Range( + start: Long, + step: Long, + numSlices: Int, + numElements: BigInt, + output: Seq[Attribute]) + extends LeafNode { + + override def outputsUnsafeRows: Boolean = true + + protected override def doExecute(): RDD[InternalRow] = { + sqlContext + .sparkContext + .parallelize(0 until numSlices, numSlices) + .mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize + val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1) + + new Iterator[InternalRow] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + + unsafeRow.setLong(0, ret) + unsafeRow + } + } + }) + } +} + /** * Union two plans, without a distinct. This is UNION ALL in SQL. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1a0f1b61cb3c7..ad478b0511095 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -769,6 +769,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) + + // using the default slice number + val res12 = sqlContext.range(3, 15, 3).select("id") + assert(res12.count == 4) + assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } test("SPARK-8621: support empty string column name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 180050bdac00f..101cf50d807ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -260,6 +260,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .set("spark.driver.allowMultipleContexts", "true") .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") .set( SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, targetNumPostShufflePartitions.toString) From 935f46630685306edbdec91f71710703317fe129 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 21 Dec 2015 14:02:40 -0800 Subject: [PATCH 1317/2089] [SPARK-12392][CORE] Optimize a location order of broadcast blocks by considering preferred local hosts When multiple workers exist in a host, we can bypass unnecessary remote access for broadcasts; block managers fetch broadcast blocks from the same host instead of remote hosts. Author: Takeshi YAMAMURO Closes #10346 from maropu/OptimizeBlockLocationOrder. --- .../apache/spark/storage/BlockManager.scala | 12 +++++++++++- .../spark/storage/BlockManagerSuite.scala | 19 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 6074fc58d70db..b5b7804d54ce2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -578,9 +578,19 @@ private[spark] class BlockManager( doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } + /** + * Return a list of locations for the given block, prioritizing the local machine since + * multiple block managers can share the same host. + */ + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + val locs = Random.shuffle(master.getLocations(blockId)) + val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } + preferredLocs ++ otherLocs + } + private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { require(blockId != null, "BlockId is null") - val locations = Random.shuffle(master.getLocations(blockId)) + val locations = getLocations(blockId) var numFetchFailures = 0 for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 53991d8a1aede..bf49be3d4c4fd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -26,6 +26,7 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.mockito.Mockito.{mock, when} +import org.mockito.{Matchers => mc} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -66,7 +67,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, - name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + name: String = SparkContext.DRIVER_IDENTIFIER, + master: BlockManagerMaster = this.master): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, @@ -451,6 +453,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } + test("optimize a location order of blocks") { + val localHost = Utils.localHostName() + val otherHost = "otherHost" + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1) + val bmId2 = BlockManagerId("id2", localHost, 2) + val bmId3 = BlockManagerId("id3", otherHost, 3) + when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + test("SPARK-9591: getRemoteBytes from another location when Exception throw") { val origTimeoutOpt = conf.getOption("spark.network.timeout") try { From 1eb90bc9cac33780890567343dab75fc14f9110a Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 21 Dec 2015 14:04:23 -0800 Subject: [PATCH 1318/2089] [SPARK-5882][GRAPHX] Add a test for GraphLoader.edgeListFile Author: Takeshi YAMAMURO Closes #4674 from maropu/AddGraphLoaderSuite. --- .../spark/graphx/GraphLoaderSuite.scala | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala new file mode 100644 index 0000000000000..bff9f328d4907 --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +import java.io.File +import java.io.FileOutputStream +import java.io.OutputStreamWriter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +class GraphLoaderSuite extends SparkFunSuite with LocalSparkContext { + + test("GraphLoader.edgeListFile") { + withSpark { sc => + val tmpDir = Utils.createTempDir() + val graphFile = new File(tmpDir.getAbsolutePath, "graph.txt") + val writer = new OutputStreamWriter(new FileOutputStream(graphFile)) + for (i <- (1 until 101)) writer.write(s"$i 0\n") + writer.close() + try { + val graph = GraphLoader.edgeListFile(sc, tmpDir.getAbsolutePath) + val neighborAttrSums = graph.aggregateMessages[Int]( + ctx => ctx.sendToDst(ctx.srcAttr), + _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, 100))) + } finally { + Utils.deleteRecursively(tmpDir) + } + } + } +} From fc6dbcc7038c2b030ef6a2dc8be5848499ccee1c Mon Sep 17 00:00:00 2001 From: pshearer Date: Mon, 21 Dec 2015 14:04:59 -0800 Subject: [PATCH 1319/2089] Doc typo: ltrim = trim from left end, not right Author: pshearer Closes #10414 from pshearer/patch-1. --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 90625949f747a..25594d79c2141 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1053,7 +1053,7 @@ def sha2(col, numBits): 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', 'reverse': 'Reverses the string column and returns it as a new string column.', - 'ltrim': 'Trim the spaces from right end for the specified string value.', + 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', } From b0849b8aeafa801bb0561f1f6e46dc1d56c37c19 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Mon, 21 Dec 2015 14:06:36 -0800 Subject: [PATCH 1320/2089] [SPARK-12339][SPARK-11206][WEBUI] Added a null check that was removed in Updates made in SPARK-11206 missed an edge case which cause's a NullPointerException when a task is killed. In some cases when a task ends in failure taskMetrics is initialized as null (see JobProgressListener.onTaskEnd()). To address this a null check was added. Before the changes in SPARK-11206 this null check was called at the start of the updateTaskAccumulatorValues() function. Author: Alex Bozarth Closes #10405 from ajbozarth/spark12339. --- .../spark/sql/execution/ui/SQLListener.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index e19a1e3e5851f..622e01c46e1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -160,12 +160,14 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskMetrics.accumulatorUpdates(), - finishTask = true) + if (taskEnd.taskMetrics != null) { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskMetrics.accumulatorUpdates(), + finishTask = true) + } } /** From a820ca19de1fb4daa01939a4b8bde8d874a7f3fc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 21 Dec 2015 14:07:48 -0800 Subject: [PATCH 1321/2089] [SPARK-2331] SparkContext.emptyRDD should return RDD[T] not EmptyRDD[T] Author: Reynold Xin Closes #10394 from rxin/SPARK-2331. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- project/MimaExcludes.scala | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 81a4d0a4d6e9a..c4541aa3766a8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1248,7 +1248,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: RDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a3cfcd20fe690..ad878c1892e99 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,9 @@ import com.typesafe.tools.mima.core.ProblemFilters._ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("2.0") => + Seq( + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD") + ) ++ // When 1.6 is officially released, update this exclusion list. Seq( MimaBuild.excludeSparkPackage("deploy"), From d655d37ddf59d7fb6db529324ac8044d53b2622a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 21 Dec 2015 14:09:04 -0800 Subject: [PATCH 1322/2089] [SPARK-12466] Fix harmless NPE in tests ``` [info] ReplayListenerSuite: [info] - Simple replay (58 milliseconds) java.lang.NullPointerException at org.apache.spark.deploy.master.Master$$anonfun$asyncRebuildSparkUI$1.applyOrElse(Master.scala:982) at org.apache.spark.deploy.master.Master$$anonfun$asyncRebuildSparkUI$1.applyOrElse(Master.scala:980) ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark-QA-Test/job/Spark-Master-SBT/4316/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/consoleFull This was introduced in #10284. It's harmless because the NPE is caused by a race that occurs mainly in `local-cluster` tests (but don't actually fail the tests). Tested locally to verify that the NPE is gone. Author: Andrew Or Closes #10417 from andrewor14/fix-harmless-npe. --- .../main/scala/org/apache/spark/deploy/master/Master.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fc42bf06e40a2..5d97c63918856 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -979,7 +979,11 @@ private[deploy] class Master( futureUI.onSuccess { case Some(ui) => appIdToUI.put(app.id, ui) - self.send(AttachCompletedRebuildUI(app.id)) + // `self` can be null if we are already in the process of shutting down + // This happens frequently in tests where `local-cluster` is used + if (self != null) { + self.send(AttachCompletedRebuildUI(app.id)) + } // Application UI is successfully rebuilt, so link the Master UI to it // NOTE - app.appUIUrlAtHistoryServer is volatile app.appUIUrlAtHistoryServer = Some(ui.basePath) From 29cecd4a42f6969613e5b2a40f2724f99e7eec01 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 21 Dec 2015 14:21:43 -0800 Subject: [PATCH 1323/2089] [SPARK-12388] change default compression to lz4 According the benchmark [1], LZ4-java could be 80% (or 30%) faster than Snappy. After changing the compressor to LZ4, I saw 20% improvement on end-to-end time for a TPCDS query (Q4). [1] https://github.com/ning/jvm-compressor-benchmark/wiki cc rxin Author: Davies Liu Closes #10342 from davies/lz4. --- .rat-excludes | 1 + .../apache/spark/io/CompressionCodec.scala | 12 +- .../apache/spark/io/LZ4BlockInputStream.java | 263 ++++++++++++++++++ .../spark/io/CompressionCodecSuite.scala | 8 +- docs/configuration.md | 2 +- .../execution/ExchangeCoordinatorSuite.scala | 4 +- 6 files changed, 276 insertions(+), 14 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java diff --git a/.rat-excludes b/.rat-excludes index 7262c960ed6bb..3544c0fc3d910 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -84,3 +84,4 @@ gen-java.* org.apache.spark.sql.sources.DataSourceRegister org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet +LZ4BlockInputStream.java diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ca74eedf89be5..717804626f852 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,10 +17,10 @@ package org.apache.spark.io -import java.io.{IOException, InputStream, OutputStream} +import java.io._ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} +import net.jpountz.lz4.LZ4BlockOutputStream import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf @@ -49,7 +49,8 @@ private[spark] object CompressionCodec { private val configKey = "spark.io.compression.codec" private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { - codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + || codec.isInstanceOf[LZ4CompressionCodec]) } private val shortCompressionCodecNames = Map( @@ -92,12 +93,11 @@ private[spark] object CompressionCodec { } } - val FALLBACK_COMPRESSION_CODEC = "lzf" - val DEFAULT_COMPRESSION_CODEC = "snappy" + val FALLBACK_COMPRESSION_CODEC = "snappy" + val DEFAULT_COMPRESSION_CODEC = "lz4" val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq } - /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. diff --git a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java new file mode 100644 index 0000000000000..27b6f0d4a3885 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java @@ -0,0 +1,263 @@ +package org.apache.spark.io; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.EOFException; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.Checksum; + +import net.jpountz.lz4.LZ4BlockOutputStream; +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; +import net.jpountz.util.SafeUtils; +import net.jpountz.xxhash.StreamingXXHash32; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; + +/** + * {@link InputStream} implementation to decode data written with + * {@link LZ4BlockOutputStream}. This class is not thread-safe and does not + * support {@link #mark(int)}/{@link #reset()}. + * @see LZ4BlockOutputStream + * + * This is based on net.jpountz.lz4.LZ4BlockInputStream + * + * changes: https://github.com/davies/lz4-java/commit/cc1fa940ac57cc66a0b937300f805d37e2bf8411 + * + * TODO: merge this into upstream + */ +public final class LZ4BlockInputStream extends FilterInputStream { + + // Copied from net.jpountz.lz4.LZ4BlockOutputStream + static final byte[] MAGIC = new byte[] { 'L', 'Z', '4', 'B', 'l', 'o', 'c', 'k' }; + static final int MAGIC_LENGTH = MAGIC.length; + + static final int HEADER_LENGTH = + MAGIC_LENGTH // magic bytes + + 1 // token + + 4 // compressed length + + 4 // decompressed length + + 4; // checksum + + static final int COMPRESSION_LEVEL_BASE = 10; + + static final int COMPRESSION_METHOD_RAW = 0x10; + static final int COMPRESSION_METHOD_LZ4 = 0x20; + + static final int DEFAULT_SEED = 0x9747b28c; + + private final LZ4FastDecompressor decompressor; + private final Checksum checksum; + private byte[] buffer; + private byte[] compressedBuffer; + private int originalLen; + private int o; + private boolean finished; + + /** + * Create a new {@link InputStream}. + * + * @param in the {@link InputStream} to poll + * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to + * use + * @param checksum the {@link Checksum} instance to use, must be + * equivalent to the instance which has been used to + * write the stream + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { + super(in); + this.decompressor = decompressor; + this.checksum = checksum; + this.buffer = new byte[0]; + this.compressedBuffer = new byte[HEADER_LENGTH]; + o = originalLen = 0; + finished = false; + } + + /** + * Create a new instance using {@link XXHash32} for checksuming. + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) + * @see StreamingXXHash32#asChecksum() + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { + this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); + } + + /** + * Create a new instance which uses the fastest {@link LZ4FastDecompressor} available. + * @see LZ4Factory#fastestInstance() + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor) + */ + public LZ4BlockInputStream(InputStream in) { + this(in, LZ4Factory.fastestInstance().fastDecompressor()); + } + + @Override + public int available() throws IOException { + refill(); + return originalLen - o; + } + + @Override + public int read() throws IOException { + refill(); + if (finished) { + return -1; + } + return buffer[o++] & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + SafeUtils.checkRange(b, off, len); + refill(); + if (finished) { + return -1; + } + len = Math.min(len, originalLen - o); + System.arraycopy(buffer, o, b, off, len); + o += len; + return len; + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + @Override + public long skip(long n) throws IOException { + refill(); + if (finished) { + return -1; + } + final int skipped = (int) Math.min(n, originalLen - o); + o += skipped; + return skipped; + } + + private void refill() throws IOException { + if (finished || o < originalLen) { + return; + } + try { + readFully(compressedBuffer, HEADER_LENGTH); + } catch (EOFException e) { + finished = true; + return; + } + for (int i = 0; i < MAGIC_LENGTH; ++i) { + if (compressedBuffer[i] != MAGIC[i]) { + throw new IOException("Stream is corrupted"); + } + } + final int token = compressedBuffer[MAGIC_LENGTH] & 0xFF; + final int compressionMethod = token & 0xF0; + final int compressionLevel = COMPRESSION_LEVEL_BASE + (token & 0x0F); + if (compressionMethod != COMPRESSION_METHOD_RAW && compressionMethod != COMPRESSION_METHOD_LZ4) + { + throw new IOException("Stream is corrupted"); + } + final int compressedLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 1); + originalLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 5); + final int check = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 9); + assert HEADER_LENGTH == MAGIC_LENGTH + 13; + if (originalLen > 1 << compressionLevel + || originalLen < 0 + || compressedLen < 0 + || (originalLen == 0 && compressedLen != 0) + || (originalLen != 0 && compressedLen == 0) + || (compressionMethod == COMPRESSION_METHOD_RAW && originalLen != compressedLen)) { + throw new IOException("Stream is corrupted"); + } + if (originalLen == 0 && compressedLen == 0) { + if (check != 0) { + throw new IOException("Stream is corrupted"); + } + refill(); + return; + } + if (buffer.length < originalLen) { + buffer = new byte[Math.max(originalLen, buffer.length * 3 / 2)]; + } + switch (compressionMethod) { + case COMPRESSION_METHOD_RAW: + readFully(buffer, originalLen); + break; + case COMPRESSION_METHOD_LZ4: + if (compressedBuffer.length < originalLen) { + compressedBuffer = new byte[Math.max(compressedLen, compressedBuffer.length * 3 / 2)]; + } + readFully(compressedBuffer, compressedLen); + try { + final int compressedLen2 = + decompressor.decompress(compressedBuffer, 0, buffer, 0, originalLen); + if (compressedLen != compressedLen2) { + throw new IOException("Stream is corrupted"); + } + } catch (LZ4Exception e) { + throw new IOException("Stream is corrupted", e); + } + break; + default: + throw new AssertionError(); + } + checksum.reset(); + checksum.update(buffer, 0, originalLen); + if ((int) checksum.getValue() != check) { + throw new IOException("Stream is corrupted"); + } + o = 0; + } + + private void readFully(byte[] b, int len) throws IOException { + int read = 0; + while (read < len) { + final int r = in.read(b, read, len - read); + if (r < 0) { + throw new EOFException("Stream ended prematurely"); + } + read += r; + } + assert len == read; + } + + @Override + public boolean markSupported() { + return false; + } + + @SuppressWarnings("sync-override") + @Override + public void mark(int readlimit) { + // unsupported + } + + @SuppressWarnings("sync-override") + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } + + @Override + public String toString() { + return getClass().getSimpleName() + "(in=" + in + + ", decompressor=" + decompressor + ", checksum=" + checksum + ")"; + } + +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 1553ab60bddaa..9e9c2b0165e13 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -46,7 +46,7 @@ class CompressionCodecSuite extends SparkFunSuite { test("default compression codec") { val codec = CompressionCodec.createCodec(conf) - assert(codec.getClass === classOf[SnappyCompressionCodec]) + assert(codec.getClass === classOf[LZ4CompressionCodec]) testCodec(codec) } @@ -62,12 +62,10 @@ class CompressionCodecSuite extends SparkFunSuite { testCodec(codec) } - test("lz4 does not support concatenation of serialized streams") { + test("lz4 supports concatenation of serialized streams") { val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) assert(codec.getClass === classOf[LZ4CompressionCodec]) - intercept[Exception] { - testConcatenationOfSerializedStreams(codec) - } + testConcatenationOfSerializedStreams(codec) } test("lzf compression codec") { diff --git a/docs/configuration.md b/docs/configuration.md index 85e7d1202d2ab..a9ef37a9b1cd9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -595,7 +595,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec - snappy + lz4 The codec used to compress internal data such as RDD partitions, broadcast variables and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 101cf50d807ce..2715179e8500c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -319,7 +319,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 1536, minNumPostShufflePartitions) + withSQLContext(test, 2000, minNumPostShufflePartitions) } test(s"determining the number of reducers: join operator$testNameNote") { @@ -422,7 +422,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 6144, minNumPostShufflePartitions) + withSQLContext(test, 6644, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 2$testNameNote") { From 0a38637d05d2338503ecceacfb911a6da6d49538 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 21 Dec 2015 22:15:52 -0800 Subject: [PATCH 1324/2089] [SPARK-11807] Remove support for Hadoop < 2.2 i.e. Hadoop 1 and Hadoop 2.0 Author: Reynold Xin Closes #10404 from rxin/SPARK-11807. --- .../deploy/history/FsHistoryProvider.scala | 10 +--------- .../mapreduce/SparkHadoopMapReduceUtil.scala | 17 ++--------------- dev/create-release/release-build.sh | 3 --- dev/run-tests-jenkins.py | 4 ---- dev/run-tests.py | 2 -- docs/building-spark.md | 18 ++++-------------- make-distribution.sh | 2 +- pom.xml | 13 ------------- sql/README.md | 2 +- 9 files changed, 9 insertions(+), 62 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 718efc4f3bd5e..6e91d73b6e0fd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -663,16 +663,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - val hadoop1Class = "org.apache.hadoop.hdfs.protocol.FSConstants$SafeModeAction" val hadoop2Class = "org.apache.hadoop.hdfs.protocol.HdfsConstants$SafeModeAction" - val actionClass: Class[_] = - try { - getClass().getClassLoader().loadClass(hadoop2Class) - } catch { - case _: ClassNotFoundException => - getClass().getClassLoader().loadClass(hadoop1Class) - } - + val actionClass: Class[_] = getClass().getClassLoader().loadClass(hadoop2Class) val action = actionClass.getField("SAFEMODE_GET").get(null) val method = dfs.getClass().getMethod("setSafeMode", action.getClass()) method.invoke(dfs, action).asInstanceOf[Boolean] diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 943ebcb7bd0a1..82d807fad8938 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -26,17 +26,13 @@ import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { - val klass = firstAvailableClass( - "org.apache.hadoop.mapreduce.task.JobContextImpl", // hadoop2, hadoop2-yarn - "org.apache.hadoop.mapreduce.JobContext") // hadoop1 + val klass = Utils.classForName("org.apache.hadoop.mapreduce.task.JobContextImpl") val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID]) ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass( - "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl", // hadoop2, hadoop2-yarn - "org.apache.hadoop.mapreduce.TaskAttemptContext") // hadoop1 + val klass = Utils.classForName("org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl") val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID]) ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } @@ -69,13 +65,4 @@ trait SparkHadoopMapReduceUtil { } } } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - Utils.classForName(first) - } catch { - case e: ClassNotFoundException => - Utils.classForName(second) - } - } } diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index cb79e9eba06e2..b1895b16b1b61 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -166,9 +166,6 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 7aecea25b2099..42afca0e52448 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -163,10 +163,6 @@ def main(): if "test-maven" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_TOOL"] = "maven" # Switch the Hadoop profile based on the PR title: - if "test-hadoop1.0" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" - if "test-hadoop2.0" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" if "test-hadoop2.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" if "test-hadoop2.3" in ghprb_pull_title: diff --git a/dev/run-tests.py b/dev/run-tests.py index 2d4e04c4684de..17ceba052b8cd 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -301,8 +301,6 @@ def get_hadoop_profiles(hadoop_version): """ sbt_maven_hadoop_profiles = { - "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.2.1"], - "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], diff --git a/docs/building-spark.md b/docs/building-spark.md index 3d38edbdad4bc..785988902da8e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -33,13 +33,13 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ # Building a Runnable Distribution -To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as -to be runnable, use `make-distribution.sh` in the project root directory. It can be configured +To create a Spark distribution like those distributed by the +[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +to be runnable, use `make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: ./make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn - + For more information on usage, run `./make-distribution.sh --help` # Setting up Maven's Memory Usage @@ -74,7 +74,6 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro Hadoop versionProfile required - 1.x to 2.1.xhadoop-1 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 @@ -82,15 +81,6 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro -For Apache Hadoop versions 1.x, Cloudera CDH "mr1" distributions, and other Hadoop versions without YARN, use: - -{% highlight bash %} -# Apache Hadoop 1.2.1 -mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package - -# Cloudera CDH 4.2.0 with MapReduce v1 -mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package -{% endhighlight %} You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. diff --git a/make-distribution.sh b/make-distribution.sh index e64ceb802464c..351b9e7d89a32 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -58,7 +58,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) diff --git a/pom.xml b/pom.xml index 32918d6a74af9..284c219519bca 100644 --- a/pom.xml +++ b/pom.xml @@ -2442,19 +2442,6 @@ http://hadoop.apache.org/docs/ra.b.c/hadoop-project-dist/hadoop-common/dependency-analysis.html --> - - hadoop-1 - - 1.2.1 - 2.4.1 - 0.98.7-hadoop1 - hadoop1 - 1.8.8 - org.spark-project.akka - 2.3.4-spark - - - hadoop-2.2 diff --git a/sql/README.md b/sql/README.md index 63d4dac9829e0..a13bdab6d457f 100644 --- a/sql/README.md +++ b/sql/README.md @@ -20,7 +20,7 @@ If you are working with Hive 0.12.0, you will need to set several environmental ``` export HIVE_HOME="/hive/build/dist" export HIVE_DEV_HOME="/hive/" -export HADOOP_HOME="/hadoop-1.0.4" +export HADOOP_HOME="/hadoop" ``` If you are working with Hive 0.13.1, the following steps are needed: From 93da8565fea42d8ac978df411daced4a9ea3a9c8 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 21 Dec 2015 22:28:18 -0800 Subject: [PATCH 1325/2089] [MINOR] Fix typos in JavaStreamingContext Author: Shixiong Zhu Closes #10424 from zsxwing/typo. --- .../spark/streaming/api/java/JavaStreamingContext.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 8f21c79a760c1..7a50135025463 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -695,9 +695,9 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ - @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -718,7 +718,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -744,7 +744,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( From 2235cd44407e3b6b401fb84a2096ade042c51d36 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Dec 2015 23:12:05 -0800 Subject: [PATCH 1326/2089] [SPARK-11823][SQL] Fix flaky JDBC cancellation test in HiveThriftBinaryServerSuite This patch fixes a flaky "test jdbc cancel" test in HiveThriftBinaryServerSuite. This test is prone to a race-condition which causes it to block indefinitely with while waiting for an extremely slow query to complete, which caused many Jenkins builds to time out. For more background, see my comments on #6207 (the PR which introduced this test). Author: Josh Rosen Closes #10425 from JoshRosen/SPARK-11823. --- .../HiveThriftServer2Suites.scala | 85 ++++++++++++------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 139d8e897ba1d..ebb2575416b72 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -23,9 +23,8 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise, future} +import scala.concurrent.{Await, ExecutionContext, Promise, future} import scala.io.Source import scala.util.{Random, Try} @@ -43,7 +42,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.{Logging, SparkFunSuite} object TestData { @@ -356,31 +355,54 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") queries.foreach(statement.execute) - - val largeJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(10)("join test_map").mkString(" ") - val f = future { Thread.sleep(100); statement.cancel(); } - val e = intercept[SQLException] { - statement.executeQuery(largeJoin) + implicit val ec = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("test-jdbc-cancel")) + try { + // Start a very-long-running query that will take hours to finish, then cancel it in order + // to demonstrate that cancellation works. + val f = future { + statement.executeQuery( + "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ")) + } + // Note that this is slightly race-prone: if the cancel is issued before the statement + // begins executing then we'll fail with a timeout. As a result, this fixed delay is set + // slightly more conservatively than may be strictly necessary. + Thread.sleep(1000) + statement.cancel() + val e = intercept[SQLException] { + Await.result(f, 3.minute) + } + assert(e.getMessage.contains("cancelled")) + + // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + try { + val sf = future { + statement.executeQuery( + "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + ) + } + // Similarly, this is also slightly race-prone on fast machines where the query above + // might race and complete before we issue the cancel. + Thread.sleep(1000) + statement.cancel() + val rs1 = Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } finally { + statement.executeQuery("SET spark.sql.hive.thriftServer.async=true") + } + } finally { + ec.shutdownNow() } - assert(e.getMessage contains "cancelled") - Await.result(f, 3.minute) - - // cancel is a noop - statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") - val sf = future { Thread.sleep(100); statement.cancel(); } - val smallJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(4)("join test_map").mkString(" ") - val rs1 = statement.executeQuery(smallJoin) - Await.result(sf, 3.minute) - rs1.next() - assert(rs1.getInt(1) === math.pow(5, 5)) - rs1.close() - - val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") - rs2.next() - assert(rs2.getInt(1) === 5) - rs2.close() } } @@ -817,6 +839,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } override protected def beforeAll(): Unit = { + super.beforeAll() // Chooses a random port between 10000 and 19999 listeningPort = 10000 + Random.nextInt(10000) diagnosisBuffer.clear() @@ -838,7 +861,11 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } override protected def afterAll(): Unit = { - stopThriftServer() - logInfo("HiveThriftServer2 stopped") + try { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } finally { + super.afterAll() + } } } From 969d5665bb1806703f948e8e7ab6133fca38c086 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Dec 2015 09:14:12 +0200 Subject: [PATCH 1327/2089] [SPARK-12296][PYSPARK][MLLIB] Feature parity for pyspark mllib standard scaler model Some methods are missing, such as ways to access the std, mean, etc. This PR is for feature parity for pyspark.mllib.feature.StandardScaler & StandardScalerModel. Author: Holden Karau Closes #10298 from holdenk/SPARK-12296-feature-parity-pyspark-mllib-StandardScalerModel. --- python/pyspark/mllib/feature.py | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index acd7ec57d69da..612935352575f 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -172,6 +172,38 @@ def setWithStd(self, withStd): self.call("setWithStd", withStd) return self + @property + @since('2.0.0') + def withStd(self): + """ + Returns if the model scales the data to unit standard deviation. + """ + return self.call("withStd") + + @property + @since('2.0.0') + def withMean(self): + """ + Returns if the model centers the data before scaling. + """ + return self.call("withMean") + + @property + @since('2.0.0') + def std(self): + """ + Return the column standard deviation values. + """ + return self.call("std") + + @property + @since('2.0.0') + def mean(self): + """ + Return the column mean values. + """ + return self.call("mean") + class StandardScaler(object): """ @@ -196,6 +228,14 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + >>> int(model.std[0]) + 4 + >>> int(model.mean[0]*10) + 9 + >>> model.withStd + True + >>> model.withMean + True .. versionadded:: 1.2.0 """ From 8c1b867cee816d0943184c7b485cd11e255d8130 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Tue, 22 Dec 2015 00:50:05 -0800 Subject: [PATCH 1328/2089] [SPARK-12446][SQL] Add unit tests for JDBCRDD internal functions No tests done for JDBCRDD#compileFilter. Author: Takeshi YAMAMURO Closes #10409 from maropu/AddTestsInJdbcRdd. --- .../execution/datasources/jdbc/JDBCRDD.scala | 63 ++++++++++--------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 24 ++++++- 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 2d38562e0901a..fc0f86cb1813f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -163,8 +163,37 @@ private[sql] object JDBCRDD extends Logging { * @return A Catalyst schema corresponding to columns in the given order. */ private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { - val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*) - new StructType(columns map { name => fieldMap(name) }) + val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*) + new StructType(columns.map(name => fieldMap(name))) + } + + /** + * Converts value to SQL expression. + */ + private def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" + case _ => value + } + + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + + /** + * Turns a single Filter into a String representing a SQL expression. + * Returns null for an unhandled filter. + */ + private def compileFilter(f: Filter): String = f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" + case _ => null } /** @@ -262,40 +291,12 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } - /** - * Converts value to SQL expression. - */ - private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" - case timestampValue: Timestamp => "'" + timestampValue + "'" - case dateValue: Date => "'" + dateValue + "'" - case _ => value - } - - private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") - - /** - * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. - */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case IsNull(attr) => s"$attr IS NULL" - case IsNotNull(attr) => s"$attr IS NOT NULL" - case _ => null - } /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ private val filterWhereClause: String = { - val filterStrings = filters map compileFilter filter (_ != null) + val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") filterStrings.foreach(x => sb.append(x).append(" AND ")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2b91f62c2fa22..7975c5df6c0bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.DriverManager +import java.sql.{Date, DriverManager, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class JDBCSuite extends SparkFunSuite + with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ val url = "jdbc:h2:mem:testdb0" @@ -429,6 +433,22 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } + test("compile filters") { + val compileFilter = PrivateMethod[String]('compileFilter) + def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) + assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") + assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "col1 != 'abc'") + assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5") + assert(doCompileFilter(LessThan("col3", + Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'") + assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'") + assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5") + assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3") + assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3") + assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") + assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") + } + test("Dialect unregister") { JdbcDialects.registerDialect(testH2Dialect) JdbcDialects.unregisterDialect(testH2Dialect) From 42bfde29836529251a4337ea8cfc539c9c8b04b8 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 22 Dec 2015 19:41:44 +0800 Subject: [PATCH 1329/2089] [SPARK-12371][SQL] Runtime nullability check for NewInstance This PR adds a new expression `AssertNotNull` to ensure non-nullable fields of products and case classes don't receive null values at runtime. Author: Cheng Lian Closes #10331 from liancheng/dataset-nullability-check. --- .../sql/catalyst/JavaTypeInference.scala | 9 +- .../spark/sql/catalyst/ScalaReflection.scala | 10 +- .../sql/catalyst/expressions/objects.scala | 40 ++++++ .../encoders/EncoderResolutionSuite.scala | 21 ++- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../apache/spark/sql/JavaDatasetSuite.java | 126 +++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 33 +++++ 7 files changed, 232 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index f566d1b3caebf..a1500cbc305d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -288,7 +288,14 @@ object JavaTypeInference { val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType - p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + val (_, nullable) = inferDataType(fieldType) + val constructor = constructorFor(fieldType, Some(addToPath(fieldName))) + val setter = if (nullable) { + constructor + } else { + AssertNotNull(constructor, other.getName, fieldName, fieldType.toString) + } + p.getWriteMethod.getName -> setter }.toMap val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index cc9e6af1818f2..becd019caeca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -326,7 +326,7 @@ object ScalaReflection extends ScalaReflection { val cls = getClassFromType(tpe) val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => - val dataType = schemaFor(fieldType).dataType + val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. @@ -336,10 +336,16 @@ object ScalaReflection extends ScalaReflection { Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - constructorFor( + val constructor = constructorFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + + if (!nullable) { + AssertNotNull(constructor, t.toString, fieldName, fieldType.toString) + } else { + constructor + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 492cc9bf4146c..d40cd96905732 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -624,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp """ } } + +/** + * Asserts that input values of a non-nullable child expression are not null. + * + * Note that there are cases where `child.nullable == true`, while we still needs to add this + * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable + * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all + * non-null `s`, `s.i` can't be null. + */ +case class AssertNotNull( + child: Expression, parentType: String, fieldName: String, fieldType: String) + extends UnaryExpression { + + override def dataType: DataType = child.dataType + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childGen = child.gen(ctx) + + ev.isNull = "false" + ev.value = childGen.value + + s""" + ${childGen.code} + + if (${childGen.isNull}) { + throw new RuntimeException( + "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + ); + } + """ + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 815a03f7c1a89..764ffdc0947c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -36,12 +36,16 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[StringLongClass] val cls = classOf[StringLongClass] + { val attrs = Seq('a.string, 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, - toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + Seq( + toExternalString('a.string), + AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") + ), false, ObjectType(cls)) compareExpressions(fromRowExpr, expected) @@ -52,7 +56,10 @@ class EncoderResolutionSuite extends PlanTest { val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression val expected = NewInstance( cls, - toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + Seq( + toExternalString('a.int.cast(StringType)), + AssertNotNull('b.long, cls.getName, "b", "Long") + ), false, ObjectType(cls)) compareExpressions(fromRowExpr, expected) @@ -69,7 +76,7 @@ class EncoderResolutionSuite extends PlanTest { val expected: Expression = NewInstance( cls, Seq( - 'a.int.cast(LongType), + AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"), If( 'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), @@ -78,7 +85,9 @@ class EncoderResolutionSuite extends PlanTest { Seq( toExternalString( GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), - GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), + AssertNotNull( + GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), + innerCls.getName, "b", "Long")), false, ObjectType(innerCls)) )), @@ -102,7 +111,9 @@ class EncoderResolutionSuite extends PlanTest { cls, Seq( toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), - GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), + AssertNotNull( + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), + cls.getName, "b", "Long")), false, ObjectType(cls)), 'b.int.cast(LongType)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d201d65238523..a763a951440cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.rdd.RDD @@ -64,7 +65,7 @@ import org.apache.spark.util.Utils class Dataset[T] private[sql]( @transient override val sqlContext: SQLContext, @transient override val queryExecution: QueryExecution, - tEncoder: Encoder[T]) extends Queryable with Serializable { + tEncoder: Encoder[T]) extends Queryable with Serializable with Logging { /** * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0dbaeb81c7ec9..9f8db39e33d7e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,8 @@ import java.sql.Timestamp; import java.util.*; +import com.google.common.base.Objects; +import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -39,7 +41,6 @@ import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; @@ -741,4 +742,127 @@ public void testJavaBeanEncoder2() { context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } + + public class SmallBean implements Serializable { + private String a; + + private int b; + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public String getA() { + return a; + } + + public void setA(String a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SmallBean smallBean = (SmallBean) o; + return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a); + } + + @Override + public int hashCode() { + return Objects.hashCode(a, b); + } + } + + public class NestedSmallBean implements Serializable { + private SmallBean f; + + public SmallBean getF() { + return f; + } + + public void setF(SmallBean f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NestedSmallBean that = (NestedSmallBean) o; + return Objects.equal(f, that.f); + } + + @Override + public int hashCode() { + return Objects.hashCode(f); + } + } + + @Rule + public transient ExpectedException nullabilityCheck = ExpectedException.none(); + + @Test + public void testRuntimeNullabilityCheck() { + OuterScopes.addOuterScope(this); + + StructType schema = new StructType() + .add("f", new StructType() + .add("a", StringType, true) + .add("b", IntegerType, true), true); + + // Shouldn't throw runtime exception since it passes nullability check. + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", 1 + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + SmallBean smallBean = new SmallBean(); + smallBean.setA("hello"); + smallBean.setB(1); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + nestedSmallBean.setF(smallBean); + + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + { + Row row = new GenericRow(new Object[] { null }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + nullabilityCheck.expect(RuntimeException.class); + nullabilityCheck.expectMessage( + "Null value appeared in non-nullable field " + + "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", null + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + ds.collect(); + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index de012a9a56454..3337996309d4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,6 +24,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DatasetSuite extends QueryTest with SharedSQLContext { @@ -515,12 +516,44 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) } + + test("runtime nullability check") { + val schema = StructType(Seq( + StructField("f", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = true) + )) + + def buildDataset(rows: Row*): Dataset[NestedStruct] = { + val rowRDD = sqlContext.sparkContext.parallelize(rows) + sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + } + + checkAnswer( + buildDataset(Row(Row("hello", 1))), + NestedStruct(ClassData("hello", 1)) + ) + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null))) + + val message = intercept[RuntimeException] { + buildDataset(Row(Row("hello", null))).collect() + }.getMessage + + assert(message.contains( + "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." + )) + } } case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) +case class NestedStruct(f: ClassData) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. From 364d244a50aab9169ec1abe7e327004e681f8a71 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 23 Dec 2015 00:39:49 +0800 Subject: [PATCH 1330/2089] [SPARK-11677][SQL][FOLLOW-UP] Add tests for checking the ORC filter creation against pushed down filters. https://issues.apache.org/jira/browse/SPARK-11677 Although it checks correctly the filters by the number of results if ORC filter-push-down is enabled, the filters themselves are not being tested. So, this PR includes the test similarly with `ParquetFilterSuite`. Since the results are checked by `OrcQuerySuite`, this `OrcFilterSuite` only checks if the appropriate filters are created. One thing different with `ParquetFilterSuite` here is, it does not check the results because that is checked in `OrcQuerySuite`. Author: hyukjinkwon Closes #10341 from HyukjinKwon/SPARK-11677-followup. --- .../spark/sql/hive/orc/OrcFilterSuite.scala | 236 ++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala new file mode 100644 index 0000000000000..7b61b635bdb48 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, PredicateLeaf} + +import org.apache.spark.sql.{Column, DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} + +/** + * A test suite that tests ORC filter API based filter pushdown optimization. + */ +class OrcFilterSuite extends QueryTest with OrcTest { + private def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[OrcRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: OrcRelation, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(selectedFilters.toArray) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") + checker(maybeFilter.get) + } + + private def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala.head.getOperator + assert(operator === filterOperator) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + private def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + test("filter pushdown - boolean") { + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + } + } + + test("filter pushdown - integer") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - long") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - float") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - double") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - string") { + withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes("UTF-8") + } + + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + } + } + + test("filter pushdown - combinations with logical operators") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked + // in string form in order to check filter creation including logical operators + // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()` + // to produce string expression and then compare it to given string expression below. + // This might have to be changed after Hive version is upgraded. + checkFilterPredicate( + '_1.isNotNull, + """leaf-0 = (IS_NULL _1) + |expr = (not leaf-0)""".stripMargin.trim + ) + checkFilterPredicate( + '_1 !== 1, + """leaf-0 = (EQUALS _1 1) + |expr = (not leaf-0)""".stripMargin.trim + ) + checkFilterPredicate( + !('_1 < 4), + """leaf-0 = (LESS_THAN _1 4) + |expr = (not leaf-0)""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 || '_1 > 3, + """leaf-0 = (LESS_THAN _1 2) + |leaf-1 = (LESS_THAN_EQUALS _1 3) + |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 && '_1 > 3, + """leaf-0 = (LESS_THAN _1 2) + |leaf-1 = (LESS_THAN_EQUALS _1 3) + |expr = (and leaf-0 (not leaf-1))""".stripMargin.trim + ) + } + } +} From bc0f30d0f5d01424d2f886adf3ffeaa1fc83a8af Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Dec 2015 10:23:21 -0800 Subject: [PATCH 1331/2089] [SPARK-12475][BUILD] Upgrade Zinc from 0.3.5.3 to 0.3.9 We should update to the latest version of Zinc in order to match our SBT version. Author: Josh Rosen Closes #10426 from JoshRosen/update-zinc. --- build/mvn | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build/mvn b/build/mvn index 7603ea03deb73..63ca9c98067d0 100755 --- a/build/mvn +++ b/build/mvn @@ -81,11 +81,11 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.5.3/bin/zinc" + local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ - "http://downloads.typesafe.com/zinc/0.3.5.3" \ - "zinc-0.3.5.3.tgz" \ + "http://downloads.typesafe.com/zinc/0.3.9" \ + "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } From b5ce84a1bb8be26d67a2e44011a0c36375de399b Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Tue, 22 Dec 2015 10:44:01 -0800 Subject: [PATCH 1332/2089] [SPARK-12456][SQL] Add ExpressionDescription to misc functions First try, not sure how much information we need to provide in the usage part. Author: Xiu Guo Closes #10423 from xguo27/SPARK-12456. --- .../sql/catalyst/expressions/InputFileName.scala | 3 +++ .../expressions/MonotonicallyIncreasingID.scala | 8 ++++++++ .../catalyst/expressions/SparkPartitionID.scala | 3 +++ .../spark/sql/catalyst/expressions/misc.scala | 15 +++++++++++++++ 4 files changed, 29 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index bf215783fc27d..50ec1d0cccfba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -26,6 +26,9 @@ import org.apache.spark.unsafe.types.UTF8String /** * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the name of the current file being read if available", + extended = "> SELECT _FUNC_();\n ''") case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 2d7679fdfe043..6b5aebc428a23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -32,6 +32,14 @@ import org.apache.spark.sql.types.{LongType, DataType} * * Since this expression is stateful, it cannot be a case object. */ +@ExpressionDescription( + usage = + """_FUNC_() - Returns monotonically increasing 64-bit integers. + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + represent the record number within each partition. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records.""", + extended = "> SELECT _FUNC_();\n 0") private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8bff173d64eb9..63ec8c64c14e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current partition id of the Spark task", + extended = "> SELECT _FUNC_();\n 0") private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5baab4f7e8c51..97f276d49f08c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -30,6 +30,9 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns an MD5 128-bit checksum as a hex string of the input", + extended = "> SELECT _FUNC_('Spark');\n '8cde774d6f7333752ed72cacddb05126'") case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -53,6 +56,12 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or * the hash length is not one of the permitted values, the return value is NULL. */ +@ExpressionDescription( + usage = "_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the " + + "input. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent " + + "to 256", + extended = "> SELECT _FUNC_('Spark', 0);\n " + + "'529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b'") case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { @@ -118,6 +127,9 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns a sha1 hash value as a hex string of the input", + extended = "> SELECT _FUNC_('Spark');\n '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c'") case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -138,6 +150,9 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns a cyclic redundancy check value as a bigint of the input", + extended = "> SELECT _FUNC_('Spark');\n '1557323817'") case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType From 7c970f9093bda0a789d7d6e43c72a6d317fc3723 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 22 Dec 2015 10:47:10 -0800 Subject: [PATCH 1333/2089] Minor corrections, i.e. typo fixes and follow deprecated Author: Jacek Laskowski Closes #10432 from jaceklaskowski/minor-corrections. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 2 +- .../scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala | 2 +- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 4 ++-- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c4541aa3766a8..67230f4207b83 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2095,7 +2095,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") - def defaultMinSplits: Int = math.min(defaultParallelism, 2) + def defaultMinSplits: Int = defaultMinPartitions /** * Default min number of partitions for Hadoop RDDs when not given by user diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c2ebf30596215..77c88baa9be20 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -257,7 +257,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // scalastyle:off println System.err.println( """ - |"Usage: CoarseGrainedExecutorBackend [options] + |Usage: CoarseGrainedExecutorBackend [options] | | Options are: | --driver-url diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala index d2e94f943aba5..cd6f00cc08e6c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -26,7 +26,7 @@ import org.apache.spark.rpc.RpcAddress * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only * connection and can only be reached via the client that sent the endpoint reference. * - * @param rpcAddress The socket address of the endpint. + * @param rpcAddress The socket address of the endpoint. * @param name Name of the endpoint. */ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a02f3017cb6e9..380301f1c9aec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -608,7 +608,7 @@ private[spark] class TaskSetManager( } /** - * Marks the task as successful and notifies the DAGScheduler that a task has ended. + * Marks a task as successful and notifies the DAGScheduler that the task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) @@ -705,7 +705,7 @@ private[spark] class TaskSetManager( ef.exception case e: ExecutorLostFailure if !e.exitCausedByApp => - logInfo(s"Task $tid failed because while it was being computed, its executor" + + logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2279e8cad7bcf..f222007a38c9b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,7 +30,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** - * A scheduler backend that waits for coarse grained executors to connect to it through Akka. + * A scheduler backend that waits for coarse-grained executors to connect. * This backend holds onto each executor for the duration of the Spark job rather than relinquishing * executors whenever a task is done and asking the scheduler to launch a new executor for * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the From 575a1327976202614a6d3268918ae8dad49fcd72 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 22 Dec 2015 13:27:28 -0800 Subject: [PATCH 1334/2089] [SPARK-12471][CORE] Spark daemons will log their pid on start up. Author: Nong Li Closes #10422 from nongli/12471-pids. --- .../spark/deploy/ExternalShuffleService.scala | 5 +++-- .../spark/deploy/history/HistoryServer.scala | 8 +++----- .../apache/spark/deploy/master/Master.scala | 4 ++-- .../deploy/mesos/MesosClusterDispatcher.scala | 4 ++-- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 7 ++----- .../spark/executor/MesosExecutorBackend.scala | 6 +++--- .../scala/org/apache/spark/util/Utils.scala | 18 ++++++++++++++++++ .../hive/thriftserver/HiveThriftServer2.scala | 1 + 9 files changed, 35 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 7fc96e4f764b7..c514a1a86bab8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -21,11 +21,11 @@ import java.util.concurrent.CountDownLatch import scala.collection.JavaConverters._ -import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap -import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} +import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -108,6 +108,7 @@ object ExternalShuffleService extends Logging { private[spark] def main( args: Array[String], newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = { + Utils.initDaemon(log) val sparkConf = new SparkConf Utils.loadDefaultSparkProperties(sparkConf) val securityManager = new SecurityManager(sparkConf) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index f31fef0eccc3b..0bc0cb1c15eb2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -23,14 +23,12 @@ import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} - import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, - UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -223,7 +221,7 @@ object HistoryServer extends Logging { val UI_PATH_PREFIX = "/history" def main(argStrings: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) new HistoryServerArguments(conf, argStrings) initSecurity() val securityManager = new SecurityManager(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 5d97c63918856..bd3d981ce08b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -45,7 +45,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, Utils} private[deploy] class Master( override val rpcEnv: RpcEnv, @@ -1087,7 +1087,7 @@ private[deploy] object Master extends Logging { val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 389eff5e0645b..89f1a8671fdb6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{ShutdownHookManager, SignalLogger} +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SecurityManager, SparkConf} /* @@ -92,7 +92,7 @@ private[mesos] class MesosClusterDispatcher( private[mesos] object MesosClusterDispatcher extends Logging { def main(args: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) conf.setMaster(dispatcherArgs.masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f41efb097b4be..84e7b366bc965 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -686,7 +686,7 @@ private[deploy] object Worker extends Logging { val ENDPOINT_NAME = "Worker" def main(argStrings: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 77c88baa9be20..edbd7225ca06a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -20,11 +20,8 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer -import org.apache.hadoop.conf.Configuration - import scala.collection.mutable import scala.util.{Failure, Success} - import org.apache.spark.rpc._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -33,7 +30,7 @@ import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -146,7 +143,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { workerUrl: Option[String], userClassPath: Seq[URL]) { - SignalLogger.register(log) + Utils.initDaemon(log) SparkHadoopUtil.get.runAsSparkUser { () => // Debug code diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index c9f18ebc7f0ea..d85465eb25683 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -25,11 +25,11 @@ import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} -import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} +import org.apache.spark.{Logging, SparkConf, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.Utils private[spark] class MesosExecutorBackend extends MesosExecutor @@ -121,7 +121,7 @@ private[spark] class MesosExecutorBackend */ private[spark] object MesosExecutorBackend extends Logging { def main(args: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) // Create a new Executor and start it running val runner = new MesosExecutorBackend() new MesosExecutorDriver(runner).run() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fce89dfccfe23..1a07f7ca7eaf5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -43,6 +43,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ +import org.slf4j.Logger import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} @@ -2221,6 +2222,23 @@ private[spark] object Utils extends Logging { def tempFileWith(path: File): File = { new File(path.getAbsolutePath + "." + UUID.randomUUID()) } + + /** + * Returns the name of this JVM process. This is OS dependent but typically (OSX, Linux, Windows), + * this is formatted as PID@hostname. + */ + def getProcessName(): String = { + ManagementFactory.getRuntimeMXBean().getName() + } + + /** + * Utility function that should be called early in `main()` for daemons to set up some common + * diagnostic state. + */ + def initDaemon(log: Logger): Unit = { + log.info(s"Started daemon with process name: ${Utils.getProcessName()}") + SignalLogger.register(log) + } } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index a4fd0c3ce9702..3e3f0382f6a3b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -67,6 +67,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { + Utils.initDaemon(log) val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) From b374a25831af031f461716c52b615665aa5392c2 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 22 Dec 2015 15:21:49 -0800 Subject: [PATCH 1335/2089] [SPARK-12102][SQL] Cast a non-nullable struct field to a nullable field during analysis Compare both left and right side of the case expression ignoring nullablity when checking for type equality. Author: Dilip Biswal Closes #10156 from dilipbiswal/spark-12102. --- .../sql/catalyst/expressions/conditionalExpressions.scala | 4 +++- .../apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 40b1eec63e551..f79c8676fb58c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -91,7 +91,9 @@ trait CaseWhenLike extends Expression { // both then and else expressions should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { + case Seq(dt1, dt2) => dt1.sameType(dt2) + } override def checkInputDataTypes(): TypeCheckResult = { if (valueTypesEqual) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index aeeca802d8bb3..fa823e3021835 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -274,4 +274,10 @@ class AnalysisSuite extends AnalysisTest { assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } + + test("SPARK-12102: Ignore nullablity when comparing two sides of case") { + val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) + val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) + assertAnalysisSuccess(plan) + } } From 93db50d1c2ff97e6eb9200a995e4601f752968ae Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Dec 2015 15:33:30 -0800 Subject: [PATCH 1336/2089] [SPARK-12487][STREAMING][DOCUMENT] Add docs for Kafka message handler Author: Shixiong Zhu Closes #10439 from zsxwing/kafka-message-handler-doc. --- docs/streaming-kafka-integration.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 5be73c42560f5..9454714eeb9cb 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -104,6 +104,7 @@ Next, we discuss how to use this approach in your streaming application. [key class], [value class], [key decoder class], [value decoder class] ]( streamingContext, [map of Kafka parameters], [set of topics to consume]) + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
    @@ -115,6 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [key class], [value class], [key decoder class], [value decoder class], [map of Kafka parameters], [set of topics to consume]); + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). @@ -123,6 +125,7 @@ Next, we discuss how to use this approach in your streaming application. from pyspark.streaming.kafka import KafkaUtils directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
    From 20591afd790799327f99485c5a969ed7412eca45 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 22 Dec 2015 16:39:10 -0800 Subject: [PATCH 1337/2089] [SPARK-12429][STREAMING][DOC] Add Accumulator and Broadcast example for Streaming This PR adds Scala, Java and Python examples to show how to use Accumulator and Broadcast in Spark Streaming to support checkpointing. Author: Shixiong Zhu Closes #10385 from zsxwing/accumulator-broadcast-example. --- docs/programming-guide.md | 6 +- docs/streaming-programming-guide.md | 165 ++++++++++++++++++ .../JavaRecoverableNetworkWordCount.java | 71 +++++++- .../recoverable_network_wordcount.py | 30 +++- .../RecoverableNetworkWordCount.scala | 66 ++++++- 5 files changed, 325 insertions(+), 13 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index c5e2a1cd7b8aa..bad25e63e89e6 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -806,7 +806,7 @@ However, in `cluster` mode, what happens is more complicated, and the above may What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#AccumLink). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1091,7 +1091,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1338,7 +1338,7 @@ run on the cluster so that `v` is not shipped to the nodes more than once. In ad `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later). -## Accumulators +## Accumulators Accumulators are variables that are only "added" to through an associative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ed6b28c282135..3b071c7da5596 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1415,6 +1415,171 @@ Note that the connections in the pool should be lazily created on demand and tim *** +## Accumulators and Broadcast Variables + +[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. If you enable checkpointing and use [Accumulators](programming-guide.html#accumulators) or [Broadcast variables](programming-guide.html#broadcast-variables) as well, you'll have to create lazily instantiated singleton instances for [Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example. + +
    +
    +{% highlight scala %} + +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +object DroppedWordsCounter { + + @volatile private var instance: Accumulator[Long] = null + + def getInstance(sc: SparkContext): Accumulator[Long] = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.accumulator(0L, "WordsInBlacklistCounter") + } + } + } + instance + } +} + +wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter += count + false + } else { + true + } + }.collect() + val output = "Counts at time " + time + " " + counts +}) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +
    +
    +{% highlight java %} + +class JavaWordBlacklist { + + private static volatile Broadcast> instance = null; + + public static Broadcast> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +class JavaDroppedWordsCounter { + + private static volatile Accumulator instance = null; + + public static Accumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +wordCounts.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaPairRDD rdd, Time time) throws IOException { + // Get or register the blacklist Broadcast + final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final Accumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 wordCount) throws Exception { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + } +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). +
    +
    +{% highlight python %} + +def getWordBlacklist(sparkContext): + if ('wordBlacklist' not in globals()): + globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"]) + return globals()['wordBlacklist'] + +def getDroppedWordsCounter(sparkContext): + if ('droppedWordsCounter' not in globals()): + globals()['droppedWordsCounter'] = sparkContext.accumulator(0) + return globals()['droppedWordsCounter'] + +def echo(time, rdd): + # Get or register the blacklist Broadcast + blacklist = getWordBlacklist(rdd.context) + # Get or register the droppedWordsCounter Accumulator + droppedWordsCounter = getDroppedWordsCounter(rdd.context) + + # Use blacklist to drop words and use droppedWordsCounter to count them + def filterFunc(wordCount): + if wordCount[0] in blacklist.value: + droppedWordsCounter.add(wordCount[1]) + False + else: + True + + counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) + +wordCounts.foreachRDD(echo) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). + +
    +
    + +*** + ## DataFrame and SQL Operations You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index bceda97f058ea..90d473703ec5a 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -21,17 +21,22 @@ import java.io.IOException; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.List; import java.util.regex.Pattern; import scala.Tuple2; import com.google.common.collect.Lists; import com.google.common.io.Files; +import org.apache.spark.Accumulator; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.JavaDStream; @@ -41,7 +46,48 @@ import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; /** - * Counts words in text encoded with UTF8 received from the network every second. + * Use this singleton to get or register a Broadcast variable. + */ +class JavaWordBlacklist { + + private static volatile Broadcast> instance = null; + + public static Broadcast> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +/** + * Use this singleton to get or register an Accumulator. + */ +class JavaDroppedWordsCounter { + + private static volatile Accumulator instance = null; + + public static Accumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +/** + * Counts words in text encoded with UTF8 received from the network every second. This example also + * shows how to use lazily instantiated singleton instances for Accumulator and Broadcast so that + * they can be registered on driver failures. * * Usage: JavaRecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive @@ -111,10 +157,27 @@ public Integer call(Integer i1, Integer i2) { wordCounts.foreachRDD(new Function2, Time, Void>() { @Override public Void call(JavaPairRDD rdd, Time time) throws IOException { - String counts = "Counts at time " + time + " " + rdd.collect(); - System.out.println(counts); + // Get or register the blacklist Broadcast + final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final Accumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 wordCount) throws Exception { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + System.out.println(output); + System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); System.out.println("Appending to " + outputFile.getAbsolutePath()); - Files.append(counts + "\n", outputFile, Charset.defaultCharset()); + Files.append(output + "\n", outputFile, Charset.defaultCharset()); return null; } }); diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index ac91f0a06b172..52b2639cdf55c 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -44,6 +44,20 @@ from pyspark.streaming import StreamingContext +# Get or register a Broadcast variable +def getWordBlacklist(sparkContext): + if ('wordBlacklist' not in globals()): + globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"]) + return globals()['wordBlacklist'] + + +# Get or register an Accumulator +def getDroppedWordsCounter(sparkContext): + if ('droppedWordsCounter' not in globals()): + globals()['droppedWordsCounter'] = sparkContext.accumulator(0) + return globals()['droppedWordsCounter'] + + def createContext(host, port, outputPath): # If you do not see this printed, that means the StreamingContext has been loaded # from the new checkpoint @@ -60,8 +74,22 @@ def createContext(host, port, outputPath): wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) def echo(time, rdd): - counts = "Counts at time %s %s" % (time, rdd.collect()) + # Get or register the blacklist Broadcast + blacklist = getWordBlacklist(rdd.context) + # Get or register the droppedWordsCounter Accumulator + droppedWordsCounter = getDroppedWordsCounter(rdd.context) + + # Use blacklist to drop words and use droppedWordsCounter to count them + def filterFunc(wordCount): + if wordCount[0] in blacklist.value: + droppedWordsCounter.add(wordCount[1]) + False + else: + True + + counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) print(counts) + print("Dropped %d word(s) totally" % droppedWordsCounter.value) print("Appending to " + os.path.abspath(outputPath)) with open(outputPath, 'a') as f: f.write(counts + "\n") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 9916882e4f94a..38d4fd11f97d1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -23,13 +23,55 @@ import java.nio.charset.Charset import com.google.common.io.Files -import org.apache.spark.SparkConf +import org.apache.spark.{Accumulator, SparkConf, SparkContext} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, Seconds, StreamingContext} import org.apache.spark.util.IntParam /** - * Counts words in text encoded with UTF8 received from the network every second. + * Use this singleton to get or register a Broadcast variable. + */ +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +/** + * Use this singleton to get or register an Accumulator. + */ +object DroppedWordsCounter { + + @volatile private var instance: Accumulator[Long] = null + + def getInstance(sc: SparkContext): Accumulator[Long] = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.accumulator(0L, "WordsInBlacklistCounter") + } + } + } + instance + } +} + +/** + * Counts words in text encoded with UTF8 received from the network every second. This example also + * shows how to use lazily instantiated singleton instances for Accumulator and Broadcast so that + * they can be registered on driver failures. * * Usage: RecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive @@ -75,10 +117,24 @@ object RecoverableNetworkWordCount { val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { - val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]") - println(counts) + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter += count + false + } else { + true + } + }.collect().mkString("[", ", ", "]") + val output = "Counts at time " + time + " " + counts + println(output) + println("Dropped " + droppedWordsCounter.value + " word(s) totally") println("Appending to " + outputFile.getAbsolutePath) - Files.append(counts + "\n", outputFile, Charset.defaultCharset()) + Files.append(output + "\n", outputFile, Charset.defaultCharset()) }) ssc } From 86761e10e145b6867cbe86b1e924ec237ba408af Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 23 Dec 2015 10:21:00 +0800 Subject: [PATCH 1338/2089] [SPARK-12478][SQL] Bugfix: Dataset fields of product types can't be null When creating extractors for product types (i.e. case classes and tuples), a null check is missing, thus we always assume input product values are non-null. This PR adds a null check in the extractor expression for product types. The null check is stripped off for top level product fields, which are mapped to the outermost `Row`s, since they can't be null. Thanks cloud-fan for helping investigating this issue! Author: Cheng Lian Closes #10431 from liancheng/spark-12478.top-level-null-field. --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++---- .../scala/org/apache/spark/sql/DatasetSuite.scala | 11 +++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index becd019caeca4..8a22b37d07fc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -380,7 +380,7 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil extractorFor(inputObject, tpe, walkedTypePath) match { - case s: CreateNamedStruct => s + case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } @@ -466,14 +466,14 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Product] => val params = getConstructorParameters(t) - - CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3337996309d4d..7fe66e461c140 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -546,6 +546,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." )) } + + test("SPARK-12478: top level null field") { + val ds0 = Seq(NestedStruct(null)).toDS() + checkAnswer(ds0, NestedStruct(null)) + checkAnswer(ds0.toDF(), Row(null)) + + val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS() + checkAnswer(ds1, DeepNestedStruct(NestedStruct(null))) + checkAnswer(ds1.toDF(), Row(Row(null))) + } } case class ClassData(a: String, b: Int) @@ -553,6 +563,7 @@ case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) case class NestedStruct(f: ClassData) +case class DeepNestedStruct(f: NestedStruct) /** * A class used to test serialization using encoders. This class throws exceptions when using From 50301c0a28b64c5348b0f2c2d828589c0833c70c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 Dec 2015 14:08:29 +0800 Subject: [PATCH 1339/2089] [SPARK-11164][SQL] Add InSet pushdown filter back for Parquet When the filter is ```"b in ('1', '2')"```, the filter is not pushed down to Parquet. Thanks! Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10278 from gatorsmile/parquetFilterNot. --- .../BooleanSimplificationSuite.scala | 20 ++++++++----- .../datasources/parquet/ParquetFilters.scala | 3 ++ .../parquet/ParquetFilterSuite.scala | 30 +++++++++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index cde346e99eb17..a0c71d83d7e39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -86,23 +86,27 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition( ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), - ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) + 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) } test("a && (!a || b)") { - checkCondition(('a && (!('a) || 'b )), ('a && 'b)) + checkCondition('a && (!'a || 'b ), 'a && 'b) - checkCondition(('a && ('b || !('a) )), ('a && 'b)) + checkCondition('a && ('b || !'a ), 'a && 'b) - checkCondition(((!('a) || 'b ) && 'a), ('b && 'a)) + checkCondition((!'a || 'b ) && 'a, 'b && 'a) - checkCondition((('b || !('a) ) && 'a), ('b && 'a)) + checkCondition(('b || !'a ) && 'a, 'b && 'a) } - test("!(a && b) , !(a || b)") { - checkCondition((!('a && 'b)), (!('a) || !('b))) + test("DeMorgan's law") { + checkCondition(!('a && 'b), !'a || !'b) - checkCondition(!('a || 'b), (!('a) && !('b))) + checkCondition(!('a || 'b), !'a && !'b) + + checkCondition(!(('a && 'b) || ('c && 'd)), (!'a || !'b) && (!'c || !'d)) + + checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } private val caseInsensitiveAnalyzer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 883013bf1bfc2..ac9b65b66d986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -256,6 +256,9 @@ private[sql] object ParquetFilters { case sources.GreaterThanOrEqual(name, value) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.In(name, valueSet) => + makeInSet.lift(dataTypeOf(name)).map(_(name, valueSet.toSet)) + case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side if we do not understand the // other side. Here is an example used to explain the reason. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 045425f282ad0..9197b8b5637eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -381,4 +381,34 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-11164: test the parquet filter in") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path) + + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + val df = sqlContext.read.parquet(path).where("b in (0,2)") + assert(stripSparkFilter(df).count == 3) + + val df1 = sqlContext.read.parquet(path).where("not (b in (1))") + assert(stripSparkFilter(df1).count == 3) + + val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)") + assert(stripSparkFilter(df2).count == 2) + + val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)") + assert(stripSparkFilter(df3).count == 4) + + val df4 = sqlContext.read.parquet(path).where("not (a <= 2)") + assert(stripSparkFilter(df4).count == 3) + } + } + } + } } From 43b2a6390087b7ce262a54dc8ab8dd825db62e21 Mon Sep 17 00:00:00 2001 From: pierre-borckmans Date: Tue, 22 Dec 2015 23:00:42 -0800 Subject: [PATCH 1340/2089] [SPARK-12477][SQL] - Tungsten projection fails for null values in array fields Accessing null elements in an array field fails when tungsten is enabled. It works in Spark 1.3.1, and in Spark > 1.5 with Tungsten disabled. This PR solves this by checking if the accessed element in the array field is null, in the generated code. Example: ``` // Array of String case class AS( as: Seq[String] ) val dfAS = sc.parallelize( Seq( AS ( Seq("a",null,"b") ) ) ).toDF dfAS.registerTempTable("T_AS") for (i <- 0 to 2) { println(i + " = " + sqlContext.sql(s"select as[$i] from T_AS").collect.mkString(","))} ``` With Tungsten disabled: ``` 0 = [a] 1 = [null] 2 = [b] ``` With Tungsten enabled: ``` 0 = [a] 15/12/22 09:32:50 ERROR Executor: Exception in task 7.0 in stage 1.0 (TID 15) java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.UnsafeRowWriters$UTF8StringWriter.getSize(UnsafeRowWriters.java:90) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source) at org.apache.spark.sql.execution.TungstenProject$$anonfun$3$$anonfun$apply$3.apply(basicOperators.scala:90) at org.apache.spark.sql.execution.TungstenProject$$anonfun$3$$anonfun$apply$3.apply(basicOperators.scala:88) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) ``` Author: pierre-borckmans Closes #10429 from pierre-borckmans/SPARK-12477_Tungsten-Projection-Null-Element-In-Array. --- .../sql/catalyst/expressions/complexTypeExtractors.scala | 2 +- .../org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index c5ed173eeb9dd..91c275b1aa1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -227,7 +227,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0) { + if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) { ${ev.isNull} = true; } else { ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 09f7b507670c9..b76fc73b7fa0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -43,4 +43,13 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() } + + test("SPARK-12477 accessing null element in array field") { + val df = sparkContext.parallelize(Seq((Seq("val1", null, "val2"), + Seq(Some(1), None, Some(2))))).toDF("s", "i") + val nullStringRow = df.selectExpr("s[1]").collect()(0) + assert(nullStringRow == org.apache.spark.sql.Row(null)) + val nullIntRow = df.selectExpr("i[1]").collect()(0) + assert(nullIntRow == org.apache.spark.sql.Row(null)) + } } From ae1f54aa0ed69f9daa1f32766ca234bda9320452 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 23 Dec 2015 13:24:06 -0800 Subject: [PATCH 1341/2089] [SPARK-12500][CORE] Fix Tachyon deprecations; pull Tachyon dependency into one class Fix Tachyon deprecations; pull Tachyon dependency into `TachyonBlockManager` only CC calvinjia as I probably need a double-check that the usage of the new API is correct. Author: Sean Owen Closes #10449 from srowen/SPARK-12500. --- .../spark/storage/TachyonBlockManager.scala | 135 ++++++++++++++---- .../spark/util/ShutdownHookManager.scala | 42 ------ .../scala/org/apache/spark/util/Utils.scala | 11 -- 3 files changed, 104 insertions(+), 84 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index d14fe4613528a..7f88f2fe6d503 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -26,13 +26,17 @@ import scala.util.control.NonFatal import com.google.common.io.ByteStreams -import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} +import tachyon.{Constants, TachyonURI} +import tachyon.client.ClientContext +import tachyon.client.file.{TachyonFile, TachyonFileSystem} +import tachyon.client.file.TachyonFileSystem.TachyonFileSystemFactory +import tachyon.client.file.options.DeleteOptions import tachyon.conf.TachyonConf -import tachyon.TachyonURI +import tachyon.exception.{FileAlreadyExistsException, FileDoesNotExistException} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.Utils /** @@ -44,15 +48,15 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log var rootDirs: String = _ var master: String = _ - var client: tachyon.client.TachyonFS = _ + var client: TachyonFileSystem = _ private var subDirsPerTachyonDir: Int = _ // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; // then, inside this directory, create multiple subdirectories that we will hash files into, // in order to avoid having really large inodes at the top level in Tachyon. private var tachyonDirs: Array[TachyonFile] = _ - private var subDirs: Array[Array[tachyon.client.TachyonFile]] = _ - + private var subDirs: Array[Array[TachyonFile]] = _ + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() override def init(blockManager: BlockManager, executorId: String): Unit = { super.init(blockManager, executorId) @@ -62,7 +66,10 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log rootDirs = s"$storeDir/$appFolderName/$executorId" master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") client = if (master != null && master != "") { - TachyonFS.get(new TachyonURI(master), new TachyonConf()) + val tachyonConf = new TachyonConf() + tachyonConf.set(Constants.MASTER_ADDRESS, master) + ClientContext.reset(tachyonConf) + TachyonFileSystemFactory.get } else { null } @@ -80,7 +87,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(registerShutdownDeleteDir) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -89,6 +96,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val file = getFile(blockId) if (fileExists(file)) { removeFile(file) + true } else { false } @@ -101,7 +109,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { val file = getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) + val os = client.getOutStream(new TachyonURI(client.getInfo(file).getPath)) try { Utils.writeByteBuffer(bytes, os) } catch { @@ -115,7 +123,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { val file = getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) + val os = client.getOutStream(new TachyonURI(client.getInfo(file).getPath)) try { blockManager.dataSerializeStream(blockId, os, values) } catch { @@ -129,12 +137,17 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val file = getFile(blockId) - if (file == null || file.getLocationHosts.size == 0) { + if (file == null) { return None } - val is = file.getInStream(ReadType.CACHE) + val is = try { + client.getInStream(file) + } catch { + case _: FileDoesNotExistException => + return None + } try { - val size = file.length + val size = client.getInfo(file).length val bs = new Array[Byte](size.asInstanceOf[Int]) ByteStreams.readFully(is, bs) Some(ByteBuffer.wrap(bs)) @@ -149,25 +162,37 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def getValues(blockId: BlockId): Option[Iterator[_]] = { val file = getFile(blockId) - if (file == null || file.getLocationHosts().size() == 0) { + if (file == null) { return None } - val is = file.getInStream(ReadType.CACHE) - Option(is).map { is => - blockManager.dataDeserializeStream(blockId, is) + val is = try { + client.getInStream(file) + } catch { + case _: FileDoesNotExistException => + return None + } + try { + Some(blockManager.dataDeserializeStream(blockId, is)) + } finally { + is.close() } } override def getSize(blockId: BlockId): Long = { - getFile(blockId.name).length + client.getInfo(getFile(blockId.name)).length } - def removeFile(file: TachyonFile): Boolean = { - client.delete(new TachyonURI(file.getPath()), false) + def removeFile(file: TachyonFile): Unit = { + client.delete(file) } def fileExists(file: TachyonFile): Boolean = { - client.exist(new TachyonURI(file.getPath())) + try { + client.getInfo(file) + true + } catch { + case _: FileDoesNotExistException => false + } } def getFile(filename: String): TachyonFile = { @@ -186,18 +211,18 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log } else { val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) - val newDir = client.getFile(path) + val newDir = client.loadMetadata(path) subDirs(dirId)(subDirId) = newDir newDir } } } val filePath = new TachyonURI(s"$subDir/$filename") - if(!client.exist(filePath)) { - client.createFile(filePath) + try { + client.create(filePath) + } catch { + case _: FileAlreadyExistsException => client.loadMetadata(filePath) } - val file = client.getFile(filePath) - file } def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name) @@ -217,9 +242,11 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") - if (!client.exist(path)) { + try { foundLocalDir = client.mkdir(path) - tachyonDir = client.getFile(path) + tachyonDir = client.loadMetadata(path) + } catch { + case _: FileAlreadyExistsException => // continue } } catch { case NonFatal(e) => @@ -240,14 +267,60 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) + if (!hasRootAsShutdownDeleteDir(tachyonDir)) { + deleteRecursively(tachyonDir, client) } } catch { case NonFatal(e) => logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } - client.close() } + + /** + * Delete a file or directory and its contents recursively. + */ + private def deleteRecursively(dir: TachyonFile, client: TachyonFileSystem) { + client.delete(dir, new DeleteOptions.Builder(ClientContext.getConf).setRecursive(true).build()) + } + + // Register the tachyon path to be deleted via shutdown hook + private def registerShutdownDeleteDir(file: TachyonFile) { + val absolutePath = client.getInfo(file).getPath + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the tachyon path to be deleted via shutdown hook + private def removeShutdownDeleteDir(file: TachyonFile) { + val absolutePath = client.getInfo(file).getPath + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths -= absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + private def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = client.getInfo(file).getPath + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. + private def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = client.getInfo(file).getPath + val hasRoot = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists( + path => !absolutePath.equals(path) && absolutePath.startsWith(path)) + } + if (hasRoot) { + logInfo(s"path = $absolutePath, already present as root for deletion.") + } + hasRoot + } + } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 1a0f3b477ba3f..0065b1fc660b0 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -21,7 +21,6 @@ import java.io.File import java.util.PriorityQueue import scala.util.{Failure, Success, Try} -import tachyon.client.TachyonFile import org.apache.hadoop.fs.FileSystem import org.apache.spark.Logging @@ -52,7 +51,6 @@ private[spark] object ShutdownHookManager extends Logging { } private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => @@ -77,14 +75,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - // Remove the path to be deleted via shutdown hook def removeShutdownDeleteDir(file: File) { val absolutePath = file.getAbsolutePath() @@ -93,14 +83,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Remove the tachyon path to be deleted via shutdown hook - def removeShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.remove(absolutePath) - } - } - // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteDir(file: File): Boolean = { val absolutePath = file.getAbsolutePath() @@ -109,14 +91,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - // Note: if file is child of some registered path, while not equal to it, then return true; // else false. This is to ensure that two shutdown hooks do not try to delete each others // paths - resulting in IOException and incomplete cleanup. @@ -133,22 +107,6 @@ private[spark] object ShutdownHookManager extends Logging { retval } - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * Detect whether this thread might be executing a shutdown hook. Will always return true if * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1a07f7ca7eaf5..b8ca6b07e4198 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -44,8 +44,6 @@ import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ import org.slf4j.Logger -import tachyon.TachyonURI -import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -946,15 +944,6 @@ private[spark] object Utils extends Logging { } } - /** - * Delete a file or directory and its contents recursively. - */ - def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(new TachyonURI(dir.getPath()), true)) { - throw new IOException("Failed to delete the tachyon dir: " + dir) - } - } - /** * Check to see if file is a symbolic link. */ From ead6abf7e7fc14b451214951d4991d497aa65e63 Mon Sep 17 00:00:00 2001 From: Adrian Bridgett Date: Wed, 23 Dec 2015 16:00:03 -0800 Subject: [PATCH 1342/2089] [SPARK-12499][BUILD] don't force MAVEN_OPTS allow the user to override MAVEN_OPTS (2GB wasn't sufficient for me) Author: Adrian Bridgett Closes #10448 from abridgett/feature/do_not_force_maven_opts. --- make-distribution.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/make-distribution.sh b/make-distribution.sh index 351b9e7d89a32..a38fd8df17206 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -159,7 +159,7 @@ fi # Build uber fat JAR cd "$SPARK_HOME" -export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" +export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m}" # Store the command as an array because $MVN variable might have spaces in it. # Normal quoting tricks don't work. From 9e85bb71ad2d7d3a9da0cb8853f3216d37e6ff47 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 24 Dec 2015 21:27:55 +0900 Subject: [PATCH 1343/2089] [SPARK-12502][BUILD][PYTHON] Script /dev/run-tests fails when IBM Java is used fix an exception with IBM JDK by removing update field from a JavaVersion tuple. This is because IBM JDK does not have information on update '_xx' Author: Kazuaki Ishizaki Closes #10463 from kiszk/SPARK-12502. --- dev/run-tests.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 17ceba052b8cd..6129f87cf8503 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -148,7 +148,7 @@ def determine_java_executable(): return java_exe if java_exe else which("java") -JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch', 'update']) +JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch']) def determine_java_version(java_exe): @@ -164,14 +164,13 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - match = re.search('(\d+)\.(\d+)\.(\d+)_(\d+)', raw_version_str) + match = re.search('(\d+)\.(\d+)\.(\d+)', raw_version_str) major = int(match.group(1)) minor = int(match.group(2)) patch = int(match.group(3)) - update = int(match.group(4)) - return JavaVersion(major, minor, patch, update) + return JavaVersion(major, minor, patch) # ------------------------------------------------------------------------------------------------- # Functions for running the other build and test scripts From 392046611837a3a740ff97fa8177ca7c12316fb7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 24 Dec 2015 13:37:28 +0000 Subject: [PATCH 1344/2089] [SPARK-12311][CORE] Restore previous value of "os.arch" property in test suites after forcing to set specific value to "os.arch" property Restore the original value of os.arch property after each test Since some of tests forced to set the specific value to os.arch property, we need to set the original value. Author: Kazuaki Ishizaki Closes #10289 from kiszk/SPARK-12311. --- .../sort/IndexShuffleBlockResolverSuite.scala | 7 +++- .../org/apache/spark/CheckpointSuite.scala | 7 ++-- .../spark/ExternalShuffleServiceSuite.scala | 7 +++- .../org/apache/spark/FileServerSuite.scala | 7 ++-- .../scala/org/apache/spark/FileSuite.scala | 7 ++-- .../org/apache/spark/HashShuffleSuite.scala | 1 + .../apache/spark/HeartbeatReceiverSuite.scala | 1 + .../apache/spark/JobCancellationSuite.scala | 7 ++-- .../org/apache/spark/LocalSparkContext.scala | 9 +++-- .../apache/spark/SecurityManagerSuite.scala | 4 +-- .../org/apache/spark/SharedSparkContext.scala | 11 +++--- .../org/apache/spark/ShuffleNettySuite.scala | 1 + .../org/apache/spark/SortShuffleSuite.scala | 2 ++ .../StandaloneDynamicAllocationSuite.scala | 21 ++++++----- .../spark/deploy/client/AppClientSuite.scala | 21 ++++++----- .../deploy/history/HistoryServerSuite.scala | 3 +- .../rest/StandaloneRestSubmitSuite.scala | 8 +++-- .../WholeTextFileRecordReaderSuite.scala | 7 +++- .../NettyBlockTransferServiceSuite.scala | 18 ++++++---- .../spark/rdd/AsyncRDDActionsSuite.scala | 9 +++-- .../spark/rdd/LocalCheckpointSuite.scala | 1 + .../org/apache/spark/rpc/RpcEnvSuite.scala | 11 ++++-- .../SerializationDebuggerSuite.scala | 1 + .../BypassMergeSortShuffleWriterSuite.scala | 11 ++++-- .../spark/storage/BlockManagerSuite.scala | 35 +++++++++++-------- .../spark/storage/DiskBlockManagerSuite.scala | 16 ++++++--- .../storage/DiskBlockObjectWriterSuite.scala | 7 +++- .../org/apache/spark/ui/UISeleniumSuite.scala | 9 +++-- .../spark/util/ClosureCleanerSuite2.scala | 11 ++++-- .../spark/util/SizeEstimatorSuite.scala | 4 +++ .../source/libsvm/LibSVMRelationSuite.scala | 7 ++-- .../apache/spark/ml/util/TempDirectory.scala | 7 ++-- .../mllib/util/LocalClusterSparkContext.scala | 11 +++--- .../mllib/util/MLlibTestSparkContext.scala | 15 ++++---- .../spark/repl/ExecutorClassLoaderSuite.scala | 15 ++++---- .../spark/streaming/CheckpointSuite.scala | 15 +++++--- .../spark/streaming/DStreamClosureSuite.scala | 9 +++-- .../spark/streaming/DStreamScopeSuite.scala | 7 +++- .../spark/streaming/MapWithStateSuite.scala | 9 +++-- .../streaming/ReceiverInputDStreamSuite.scala | 6 +++- .../spark/streaming/UISeleniumSuite.scala | 9 +++-- .../streaming/rdd/MapWithStateRDDSuite.scala | 11 ++++-- .../WriteAheadLogBackedBlockRDDSuite.scala | 16 +++++++-- .../streaming/util/WriteAheadLogSuite.scala | 9 +++-- .../deploy/yarn/BaseYarnClusterSuite.scala | 13 +++++-- .../spark/deploy/yarn/ClientSuite.scala | 18 ++++++++-- .../deploy/yarn/YarnAllocatorSuite.scala | 7 +++- .../yarn/YarnSparkHadoopUtilSuite.scala | 5 +-- .../yarn/YarnShuffleServiceSuite.scala | 27 ++++++++------ 49 files changed, 338 insertions(+), 142 deletions(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 0b19861fc41ee..f200ff36c7dd5 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -42,6 +42,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa private val conf: SparkConf = new SparkConf(loadDefaults = false) override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() MockitoAnnotations.initMocks(this) @@ -55,7 +56,11 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("commit shuffle files multiple times") { diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 553d46285ac03..390764ba242fd 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -256,8 +256,11 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS } override def afterEach(): Unit = { - super.afterEach() - Utils.deleteRecursively(checkpointDir) + try { + Utils.deleteRecursively(checkpointDir) + } finally { + super.afterEach() + } } override def sparkContext: SparkContext = sc diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 1c775bcb3d9c1..eb3fb99747d12 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,6 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { + super.beforeAll() val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) @@ -46,7 +47,11 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } override def afterAll() { - server.close() + try { + server.close() + } finally { + super.afterAll() + } } // This test ensures that the external shuffle service is actually in use for the other tests. diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 1255e71af6c0b..2c32b69715484 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -75,8 +75,11 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } override def afterAll() { - super.afterAll() - Utils.deleteRecursively(tmpDir) + try { + Utils.deleteRecursively(tmpDir) + } finally { + super.afterAll() + } } test("Distributing files locally") { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index fdb00aafc4a48..f6a7f4375fac8 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -44,8 +44,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } override def afterEach() { - super.afterEach() - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("text files") { diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala index 19180e88ebe0a..10794235ed392 100644 --- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -24,6 +24,7 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with hash-based shuffle. override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.manager", "hash") } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 3cd80c0f7d171..9b43341576a8a 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -66,6 +66,7 @@ class HeartbeatReceiverSuite * that uses a manual clock. */ override def beforeEach(): Unit = { + super.beforeEach() val conf = new SparkConf() .setMaster("local[2]") .setAppName("test") diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 1168eb0b802f2..e13a442463e8d 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -38,8 +38,11 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft with LocalSparkContext { override def afterEach() { - super.afterEach() - resetSparkContext() + try { + resetSparkContext() + } finally { + super.afterEach() + } } test("local mode, FIFO scheduler") { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 8bf2e55defd02..214681970acbf 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -28,13 +28,16 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) } override def afterEach() { - resetSparkContext() - super.afterEach() + try { + resetSparkContext() + } finally { + super.afterEach() + } } def resetSparkContext(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 26b95c06789f7..e0226803bb1cf 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark import java.io.File -import org.apache.spark.util.{SparkConfWithEnv, Utils} +import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} -class SecurityManagerSuite extends SparkFunSuite { +class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { test("set security with conf") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 3d2700b7e6be4..858bc742e07cf 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -30,13 +30,16 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => var conf = new SparkConf(false) override def beforeAll() { - _sc = new SparkContext("local[4]", "test", conf) super.beforeAll() + _sc = new SparkContext("local[4]", "test", conf) } override def afterAll() { - LocalSparkContext.stop(_sc) - _sc = null - super.afterAll() + try { + LocalSparkContext.stop(_sc) + _sc = null + } finally { + super.afterAll() + } } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index d78c99c2e1e06..73638d9b131ea 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,6 +24,7 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.blockTransferService", "netty") } } diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index b8ab227517cc4..5354731465a4a 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -37,10 +37,12 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { private var tempDir: File = _ override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.manager", "sort") } override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() conf.set("spark.local.dir", tempDir.getAbsolutePath) } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 314517d296049..85c1c1bbf3dc1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -71,15 +71,18 @@ class StandaloneDynamicAllocationSuite } override def afterAll(): Unit = { - masterRpcEnv.shutdown() - workerRpcEnvs.foreach(_.shutdown()) - master.stop() - workers.foreach(_.stop()) - masterRpcEnv = null - workerRpcEnvs = null - master = null - workers = null - super.afterAll() + try { + masterRpcEnv.shutdown() + workerRpcEnvs.foreach(_.shutdown()) + master.stop() + workers.foreach(_.stop()) + masterRpcEnv = null + workerRpcEnvs = null + master = null + workers = null + } finally { + super.afterAll() + } } test("dynamic allocation default behavior") { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 1e5c05a73f8aa..415e2b37dbbdc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -63,15 +63,18 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } override def afterAll(): Unit = { - workerRpcEnvs.foreach(_.shutdown()) - masterRpcEnv.shutdown() - workers.foreach(_.stop()) - master.stop() - workerRpcEnvs = null - masterRpcEnv = null - workers = null - master = null - super.afterAll() + try { + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnv.shutdown() + workers.foreach(_.stop()) + master.stop() + workerRpcEnvs = null + masterRpcEnv = null + workers = null + master = null + } finally { + super.afterAll() + } } test("interface methods of AppClient using local Master") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4b7fd4f13b692..18659fc0c18de 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.mock.MockitoSugar import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.ui.{SparkUI, UIUtils} +import org.apache.spark.util.ResetSystemProperties /** * A collection of tests against the historyserver, including comparing responses from the json @@ -43,7 +44,7 @@ import org.apache.spark.ui.{SparkUI, UIUtils} * are considered part of Spark's public api. */ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar - with JsonTestUtils { + with JsonTestUtils with ResetSystemProperties { private val logDir = new File("src/test/resources/spark-events") private val expRoot = new File("src/test/resources/HistoryServerExpectations/") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 9693e32bf6af6..fa39aa2cb1311 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -43,8 +43,12 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { private var server: Option[RestSubmissionServer] = None override def afterEach() { - rpcEnv.foreach(_.shutdown()) - server.foreach(_.stop()) + try { + rpcEnv.foreach(_.shutdown()) + server.foreach(_.stop()) + } finally { + super.afterEach() + } } test("construct submit request") { diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 8a199459c1ddf..24184b02cb4c1 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -47,6 +47,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl // hard-to-reproduce test failures, since any suites that were run after this one would inherit // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, // we disable FileSystem caching in this suite. + super.beforeAll() val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") sc = new SparkContext("local", "test", conf) @@ -59,7 +60,11 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl } override def afterAll() { - sc.stop() + try { + sc.stop() + } finally { + super.afterAll() + } } private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 6f8e8a7ac6033..92daf4e6a2169 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -31,14 +31,18 @@ class NettyBlockTransferServiceSuite private var service1: NettyBlockTransferService = _ override def afterEach() { - if (service0 != null) { - service0.close() - service0 = null - } + try { + if (service0 != null) { + service0.close() + service0 = null + } - if (service1 != null) { - service1.close() - service1 = null + if (service1 != null) { + service1.close() + service1 = null + } + } finally { + super.afterEach() } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index de015ebd5d237..d18bde790b40a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -34,12 +34,17 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim @transient private var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() sc = new SparkContext("local[2]", "test") } override def afterAll() { - LocalSparkContext.stop(sc) - sc = null + try { + LocalSparkContext.stop(sc) + sc = null + } finally { + super.afterAll() + } } lazy val zeroPartRdd = new EmptyRDD[Int](sc) diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 5103eb74b2457..3a22a9850a096 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.{RDDBlockId, StorageLevel} class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { override def beforeEach(): Unit = { + super.beforeEach() sc = new SparkContext("local[2]", "test") } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 7b3a17c17233a..9c850c0da52a3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -44,6 +44,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { var env: RpcEnv = _ override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) @@ -53,10 +54,14 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } override def afterAll(): Unit = { - if (env != null) { - env.shutdown() + try { + if (env != null) { + env.shutdown() + } + SparkEnv.set(null) + } finally { + super.afterAll() } - SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index 2d5e9d66b2e15..683aaa3aab1ba 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -29,6 +29,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { import SerializationDebugger.find override def beforeEach(): Unit = { + super.beforeEach() SerializationDebugger.enableDebugging = true } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index d3b1b2b620b4d..bb331bb385df3 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -55,6 +55,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics @@ -119,9 +120,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) - blockIdToFileMap.clear() - temporaryFilesCreated.clear() + try { + Utils.deleteRecursively(tempDir) + blockIdToFileMap.clear() + temporaryFilesCreated.clear() + } finally { + super.afterEach() + } } test("write empty iterator") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bf49be3d4c4fd..2224a444c7b54 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -79,6 +79,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } override def beforeEach(): Unit = { + super.beforeEach() rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case @@ -97,22 +98,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } override def afterEach(): Unit = { - if (store != null) { - store.stop() - store = null - } - if (store2 != null) { - store2.stop() - store2 = null - } - if (store3 != null) { - store3.stop() - store3 = null + try { + if (store != null) { + store.stop() + store = null + } + if (store2 != null) { + store2.stop() + store2 = null + } + if (store3 != null) { + store3.stop() + store3 = null + } + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null + master = null + } finally { + super.afterEach() } - rpcEnv.shutdown() - rpcEnv.awaitTermination() - rpcEnv = null - master = null } test("StorageLevel object caching") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 688f56f4665f3..69e17461df755 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -45,19 +45,27 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B } override def afterAll() { - super.afterAll() - Utils.deleteRecursively(rootDir0) - Utils.deleteRecursively(rootDir1) + try { + Utils.deleteRecursively(rootDir0) + Utils.deleteRecursively(rootDir1) + } finally { + super.afterAll() + } } override def beforeEach() { + super.beforeEach() val conf = testConf.clone conf.set("spark.local.dir", rootDirs) diskBlockManager = new DiskBlockManager(blockManager, conf) } override def afterEach() { - diskBlockManager.stop() + try { + diskBlockManager.stop() + } finally { + super.afterEach() + } } test("basic block creation") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7c19531c18802..5d36617cfc447 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -30,11 +30,16 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("verify write metrics") { diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index ceecfd665bf87..0e36d7fda430d 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -76,14 +76,19 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B override def beforeAll(): Unit = { + super.beforeAll() webDriver = new HtmlUnitDriver { getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) } } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index a829b099025e9..934385fbcad1b 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -38,14 +38,19 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri private var closureSerializer: SerializerInstance = null override def beforeAll(): Unit = { + super.beforeAll() sc = new SparkContext("local", "test") closureSerializer = sc.env.closureSerializer.newInstance() } override def afterAll(): Unit = { - sc.stop() - sc = null - closureSerializer = null + try { + sc.stop() + sc = null + closureSerializer = null + } finally { + super.afterAll() + } } // Some fields and methods to reference in inner closures later diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 101610e38014e..fbe7b956682d5 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -79,6 +79,10 @@ class SizeEstimatorSuite System.setProperty("spark.test.useCompressedOops", "true") } + override def afterEach(): Unit = { + super.afterEach() + } + test("simple classes") { assertResult(16)(SizeEstimator.estimate(new DummyClass1)) assertResult(16)(SizeEstimator.estimate(new DummyClass2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 997f574e51f6a..5f4d5f11bdd68 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -46,8 +46,11 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) - super.afterAll() + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterAll() + } } test("select as sparse vector") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index c8a0bb16247b4..8f11bbc8e47af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -39,7 +39,10 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => } override def afterAll(): Unit = { - Utils.deleteRecursively(_tempDir) - super.afterAll() + try { + Utils.deleteRecursively(_tempDir) + } finally { + super.afterAll() + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 525ab68c7921a..4f73b0809dca4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -25,18 +25,21 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() val conf = new SparkConf() .setMaster("local-cluster[2, 1, 1024]") .setAppName("test-cluster") .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data sc = new SparkContext(conf) - super.beforeAll() } override def afterAll() { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + } finally { + super.afterAll() } - super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 378139593b26f..ebcd591465cb5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -38,12 +38,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => } override def afterAll() { - sqlContext = null - SQLContext.clearActive() - if (sc != null) { - sc.stop() + try { + sqlContext = null + SQLContext.clearActive() + if (sc != null) { + sc.stop() + } + sc = null + } finally { + super.afterAll() } - sc = null - super.afterAll() } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 1360f09e7fa1f..05bf7a3aaefbf 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -72,13 +72,16 @@ class ExecutorClassLoaderSuite } override def afterAll() { - super.afterAll() - if (classServer != null) { - classServer.stop() + try { + if (classServer != null) { + classServer.stop() + } + Utils.deleteRecursively(tempDir1) + Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) + } finally { + super.afterAll() } - Utils.deleteRecursively(tempDir1) - Utils.deleteRecursively(tempDir2) - SparkEnv.set(null) } test("child first") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index f5f446f14a0da..4d04138da01f7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils} +import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties, + Utils} /** * A input stream that records the times of restore() invoked @@ -196,7 +197,8 @@ trait DStreamCheckpointTester { self: SparkFunSuite => * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester + with ResetSystemProperties { var ssc: StreamingContext = null @@ -208,9 +210,12 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { } override def afterFunction() { - super.afterFunction() - if (ssc != null) { ssc.stop() } - Utils.deleteRecursively(new File(checkpointDir)) + try { + if (ssc != null) { ssc.stop() } + Utils.deleteRecursively(new File(checkpointDir)) + } finally { + super.afterFunction() + } } test("basic rdd checkpoints + dstream graph checkpoint recovery") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 9b5e4dc819a2b..e897de3cba6d2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -33,13 +33,18 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private var ssc: StreamingContext = null override def beforeAll(): Unit = { + super.beforeAll() val sc = new SparkContext("local", "test") ssc = new StreamingContext(sc, Seconds(1)) } override def afterAll(): Unit = { - ssc.stop(stopSparkContext = true) - ssc = null + try { + ssc.stop(stopSparkContext = true) + ssc = null + } finally { + super.afterAll() + } } test("user provided closures are actually cleaned") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index bc223e648a417..4c12ecc399e41 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -35,13 +35,18 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd private val batchDuration: Duration = Seconds(1) override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf().setMaster("local").setAppName("test") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) ssc = new StreamingContext(new SparkContext(conf), batchDuration) } override def afterAll(): Unit = { - ssc.stop(stopSparkContext = true) + try { + ssc.stop(stopSparkContext = true) + } finally { + super.afterAll() + } } before { assertPropertiesNotSet() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 6b21433f1781b..62d75a9e0e7aa 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -49,14 +49,19 @@ class MapWithStateSuite extends SparkFunSuite } override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) sc = new SparkContext(conf) } override def afterAll(): Unit = { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + } finally { + super.afterAll() } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala index 6d388d9624d92..e6d8fbd4d7c57 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -33,7 +33,11 @@ import org.apache.spark.{SparkConf, SparkEnv} class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { override def afterAll(): Unit = { - StreamingContext.getActive().map { _.stop() } + try { + StreamingContext.getActive().map { _.stop() } + } finally { + super.afterAll() + } } testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index a5744a9009c1c..c4ecebcacf3c8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -38,14 +38,19 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ override def beforeAll(): Unit = { + super.beforeAll() webDriver = new HtmlUnitDriver { getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) } } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index aa95bd33dda9f..1640b9e6b7a6c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -36,6 +36,7 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B private var checkpointDir: File = _ override def beforeAll(): Unit = { + super.beforeAll() sc = new SparkContext( new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite")) checkpointDir = Utils.createTempDir() @@ -43,10 +44,14 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B } override def afterAll(): Unit = { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + Utils.deleteRecursively(checkpointDir) + } finally { + super.afterAll() } - Utils.deleteRecursively(checkpointDir) } override def sparkContext: SparkContext = sc diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index cb017b798b2a4..43833c4361473 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -42,22 +42,32 @@ class WriteAheadLogBackedBlockRDDSuite var dir: File = null override def beforeEach(): Unit = { + super.beforeEach() dir = Utils.createTempDir() } override def afterEach(): Unit = { - Utils.deleteRecursively(dir) + try { + Utils.deleteRecursively(dir) + } finally { + super.afterEach() + } } override def beforeAll(): Unit = { + super.beforeAll() sparkContext = new SparkContext(conf) blockManager = sparkContext.env.blockManager } override def afterAll(): Unit = { // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests. - sparkContext.stop() - System.clearProperty("spark.driver.port") + try { + sparkContext.stop() + System.clearProperty("spark.driver.port") + } finally { + super.afterAll() + } } test("Read data available in both block manager and write ahead log") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index ef1e89df31305..beaae34535fd6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -432,6 +432,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( private val queueLength = PrivateMethod[Int]('getQueueLength) override def beforeEach(): Unit = { + super.beforeEach() wal = mock[WriteAheadLog] walHandle = mock[WriteAheadLogRecordHandle] walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") @@ -439,8 +440,12 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } override def afterEach(): Unit = { - if (walBatchingExecutionContext != null) { - walBatchingExecutionContext.shutdownNow() + try { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() + } + } finally { + super.afterEach() } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 12494b01054ba..cd24c704ece5b 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -27,6 +27,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} @@ -59,10 +60,13 @@ abstract class BaseYarnClusterSuite protected var hadoopConfDir: File = _ private var logConfDir: File = _ + var oldSystemProperties: Properties = null + def newYarnConfig(): YarnConfiguration override def beforeAll() { super.beforeAll() + oldSystemProperties = SerializationUtils.clone(System.getProperties) tempDir = Utils.createTempDir() logConfDir = new File(tempDir, "log4j") @@ -115,9 +119,12 @@ abstract class BaseYarnClusterSuite } override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() + try { + yarnCluster.stop() + } finally { + System.setProperties(oldSystemProperties) + super.afterAll() + } } protected def runSpark( diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index e7f2501e7899f..7709c2f6e4f5f 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI +import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap => MutableHashMap} import scala.reflect.ClassTag import scala.util.Try +import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig @@ -39,16 +41,26 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, ResetSystemProperties} -class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll + with ResetSystemProperties { + + var oldSystemProperties: Properties = null override def beforeAll(): Unit = { + super.beforeAll() + oldSystemProperties = SerializationUtils.clone(System.getProperties) System.setProperty("SPARK_YARN_MODE", "true") } override def afterAll(): Unit = { - System.clearProperty("SPARK_YARN_MODE") + try { + System.setProperties(oldSystemProperties) + oldSystemProperties = null + } finally { + super.afterAll() + } } test("default Yarn application classpath") { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index bd80036c5cfa7..57edbd67253d4 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -72,13 +72,18 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var containerNum = 0 override def beforeEach() { + super.beforeEach() rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() } override def afterEach() { - rmClient.stop() + try { + rmClient.stop() + } finally { + super.afterEach() + } } class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 3fafc91a166aa..c2861c9d7fbc7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -34,10 +34,11 @@ import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, ResetSystemProperties} -class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging + with ResetSystemProperties { val hasBash = try { diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index 6aa8c814cd4f0..5a426b86d10e0 100644 --- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -34,6 +34,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration override def beforeEach(): Unit = { + super.beforeEach() yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), classOf[YarnShuffleService].getCanonicalName) @@ -54,17 +55,21 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd var s3: YarnShuffleService = null override def afterEach(): Unit = { - if (s1 != null) { - s1.stop() - s1 = null - } - if (s2 != null) { - s2.stop() - s2 = null - } - if (s3 != null) { - s3.stop() - s3 = null + try { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } finally { + super.afterEach() } } From 502476e45c314a1229b3bce1c61f5cb94a9fc04b Mon Sep 17 00:00:00 2001 From: CK50 Date: Thu, 24 Dec 2015 13:39:11 +0000 Subject: [PATCH 1345/2089] [SPARK-12010][SQL] Spark JDBC requires support for column-name-free INSERT syntax In the past Spark JDBC write only worked with technologies which support the following INSERT statement syntax (JdbcUtils.scala: insertStatement()): INSERT INTO $table VALUES ( ?, ?, ..., ? ) But some technologies require a list of column names: INSERT INTO $table ( $colNameList ) VALUES ( ?, ?, ..., ? ) This was blocking the use of e.g. the Progress JDBC Driver for Cassandra. Another limitation is that syntax 1 relies no the dataframe field ordering match that of the target table. This works fine, as long as the target table has been created by writer.jdbc(). If the target table contains more columns (not created by writer.jdbc()), then the insert fails due mismatch of number of columns or their data types. This PR switches to the recommended second INSERT syntax. Column names are taken from datafram field names. Author: CK50 Closes #10380 from CK50/master-SPARK-12010-2. --- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 252f1cfd5d9c5..28cd688ef7d7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -63,14 +63,10 @@ object JdbcUtils extends Logging { * Returns a PreparedStatement that inserts a row into table via conn. */ def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString()) + val columns = rddSchema.fields.map(_.name).mkString(",") + val placeholders = rddSchema.fields.map(_ => "?").mkString(",") + val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)" + conn.prepareStatement(sql) } /** From ea4aab7e87fbcf9ac90f93af79cc892b56508aa0 Mon Sep 17 00:00:00 2001 From: pierre-borckmans Date: Thu, 24 Dec 2015 13:48:21 +0000 Subject: [PATCH 1346/2089] [SPARK-12440][CORE] Avoid setCheckpoint warning when directory is not local In SparkContext method `setCheckpointDir`, a warning is issued when spark master is not local and the passed directory for the checkpoint dir appears to be local. In practice, when relying on HDFS configuration file and using a relative path for the checkpoint directory (using an incomplete URI without HDFS scheme, ...), this warning should not be issued and might be confusing. In fact, in this case, the checkpoint directory is successfully created, and the checkpointing mechanism works as expected. This PR uses the `FileSystem` instance created with the given directory, and checks whether it is local or not. (The rationale is that since this same `FileSystem` instance is used to create the checkpoint dir anyway and can therefore be reliably used to determine if it is local or not). The warning is only issued if the directory is not local, on top of the existing conditions. Author: pierre-borckmans Closes #10392 from pierre-borckmans/SPARK-12440_CheckpointDir_Warning_NonLocal. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 67230f4207b83..d506782b73c43 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2073,8 +2073,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // its own local file system, which is incorrect because the checkpoint files // are actually on the executor machines. if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { - logWarning("Checkpoint directory must be non-local " + - "if Spark is running on a cluster: " + directory) + logWarning("Spark is not running in local mode, therefore the checkpoint directory " + + s"must not be on the local filesystem. Directory '$directory' " + + "appears to be on the local filesystem.") } checkpointDir = Option(directory).map { dir => From 1e97813951674aa5419744b455a4c7340462ac59 Mon Sep 17 00:00:00 2001 From: echo2mei <534384876@qq.com> Date: Fri, 25 Dec 2015 17:42:24 -0800 Subject: [PATCH 1347/2089] [SPARK-12396][CORE] Modify the function scheduleAtFixedRate to schedule. Instead of just cancel the registrationRetryTimer to avoid driver retry connect to master, change the function to schedule. It is no need to register to master iteratively. Author: echo2mei <534384876@qq.com> Closes #10447 from echoTomei/master. --- .../main/scala/org/apache/spark/deploy/client/AppClient.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 1e2f469214b84..a5753e1053649 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -124,7 +124,7 @@ private[spark] class AppClient( */ private def registerWithMaster(nthRetry: Int) { registerMasterFutures.set(tryRegisterAllMasters()) - registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable { + registrationRetryTimer.set(registrationRetryThread.schedule(new Runnable { override def run(): Unit = { Utils.tryOrExit { if (registered.get) { @@ -138,7 +138,7 @@ private[spark] class AppClient( } } } - }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) + }, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) } /** From 9ab296ecdceef88ebca523ed62848fbeb5df353b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 27 Dec 2015 23:18:48 -0800 Subject: [PATCH 1348/2089] [SPARK-12520] [PYSPARK] Correct Descriptions and Add Use Cases in Equi-Join After reading the JIRA https://issues.apache.org/jira/browse/SPARK-12520, I double checked the code. For example, users can do the Equi-Join like ```df.join(df2, 'name', 'outer').select('name', 'height').collect()``` - There exists a bug in 1.5 and 1.4. The code just ignores the third parameter (join type) users pass. However, the join type we called is `Inner`, even if the user-specified type is the other type (e.g., `Outer`). - After a PR: https://github.com/apache/spark/pull/8600, the 1.6 does not have such an issue, but the description has not been updated. Plan to submit another PR to fix 1.5 and issue an error message if users specify a non-inner join type when using Equi-Join. Author: gatorsmile Closes #10477 from gatorsmile/pyOuterJoin. --- python/pyspark/sql/dataframe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4b3791e1b8864..ad621df91064c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -608,13 +608,16 @@ def join(self, other, on=None, how=None): :param on: a string for join column name, a list of column names, , a join expression (Column) or a list of Columns. If `on` is a string or a list of string indicating the name of the join column(s), - the column(s) must exist on both sides, and this performs an inner equi-join. + the column(s) must exist on both sides, and this performs an equi-join. :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() + [Row(name=u'Tom', height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] From 5aa2710c1e587c340a66ed36af75214aea85a97a Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 28 Dec 2015 10:22:45 +0000 Subject: [PATCH 1349/2089] [SPARK-12515][SQL][DOC] minor doc update for read.jdbc Author: felixcheung Closes #10465 from felixcheung/dfreaderjdbcdoc. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c1a8f19313a7d..0acea95344c22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -154,13 +154,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash * your external database systems. * - * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param url JDBC database url of the form `jdbc:subprotocol:subname`. * @param table Name of the table in the external database. * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions + * @param lowerBound the minimum value of `columnName` used to decide partition stride. + * @param upperBound the maximum value of `columnName` used to decide partition stride. + * @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive), + * `upperBound` (exclusive), form partition strides for generated WHERE + * clause expressions used to split the column `columnName` evenly. * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. From 8d4940092141c0909b673607f393cdd53f093ed6 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 28 Dec 2015 10:43:23 +0000 Subject: [PATCH 1350/2089] [SPARK-12353][STREAMING][PYSPARK] Fix countByValue inconsistent output in Python API The semantics of Python countByValue is different from Scala API, it is more like countDistinctValue, so here change to make it consistent with Scala/Java API. Author: jerryshao Closes #10350 from jerryshao/SPARK-12353. --- python/pyspark/streaming/dstream.py | 4 ++-- python/pyspark/streaming/tests.py | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index adc2651740007..86447f5e58ecb 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -247,7 +247,7 @@ def countByValue(self): Return a new DStream in which each RDD contains the counts of each distinct value in each RDD of this DStream. """ - return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + return self.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x+y) def saveAsTextFiles(self, prefix, suffix=None): """ @@ -493,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda kv: kv[1] > 0).count() + return counted.filter(lambda kv: kv[1] > 0) def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 4949cd68e3212..86b05d9fd2424 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -279,8 +279,10 @@ def test_countByValue(self): def func(dstream): return dstream.countByValue() - expected = [[4], [4], [3]] - self._test_func(input, func, expected) + expected = [[(1, 2), (2, 2), (3, 2), (4, 2)], + [(5, 2), (6, 2), (7, 1), (8, 1)], + [("a", 2), ("b", 1), ("", 1)]] + self._test_func(input, func, expected, sort=True) def test_groupByKey(self): """Basic operation test for DStream.groupByKey.""" @@ -651,7 +653,16 @@ def test_count_by_value_and_window(self): def func(dstream): return dstream.countByValueAndWindow(2.5, .5) - expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + expected = [[(0, 1)], + [(0, 2), (1, 1)], + [(0, 3), (1, 2), (2, 1)], + [(0, 4), (1, 3), (2, 2), (3, 1)], + [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1)], + [(0, 5), (1, 5), (2, 4), (3, 3), (4, 2), (5, 1)], + [(0, 4), (1, 4), (2, 4), (3, 3), (4, 2), (5, 1)], + [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1)], + [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 1)], + [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]] self._test_func(input, func, expected) def test_group_by_key_and_window(self): From 8e23d8db7f28a97e2f4394cdf9d4c4260abbd750 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 28 Dec 2015 08:48:44 -0800 Subject: [PATCH 1351/2089] [SPARK-12218] Fixes ORC conjunction predicate push down This PR is a follow-up of PR #10362. Two major changes: 1. The fix introduced in #10362 is OK for Parquet, but may disable ORC PPD in many cases PR #10362 stops converting an `AND` predicate if any branch is inconvertible. On the other hand, `OrcFilters` combines all filters into a single big conjunction first and then tries to convert it into ORC `SearchArgument`. This means, if any filter is inconvertible, no filters can be pushed down. This PR fixes this issue by finding out all convertible filters first before doing the actual conversion. The reason behind the current implementation is mostly due to the limitation of ORC `SearchArgument` builder, which is documented in this PR in detail. 1. Copied the `AND` predicate fix for ORC from #10362 to avoid merge conflict. Same as #10362, this PR targets master (2.0.0-SNAPSHOT), branch-1.6, and branch-1.5. Author: Cheng Lian Closes #10377 from liancheng/spark-12218.fix-orc-conjunction-ppd. --- .../parquet/ParquetFilterSuite.scala | 42 +++++++++++- .../spark/sql/hive/orc/OrcFilters.scala | 68 +++++++++++-------- .../spark/sql/hive/orc/OrcSourceSuite.scala | 32 ++++++++- 3 files changed, 112 insertions(+), 30 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9197b8b5637eb..a0abfe458d498 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.parquet.filter2.predicate.Operators._ +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -382,6 +384,42 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("SPARK-12218 Converting conjunctions into Parquet filter predicates") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = true), + StructField("c", DoubleType, nullable = true) + )) + + assertResult(Some(and( + lt(intColumn("a"), 10: Integer), + gt(doubleColumn("c"), 1.5: java.lang.Double))) + ) { + ParquetFilters.createFilter( + schema, + sources.And( + sources.LessThan("a", 10), + sources.GreaterThan("c", 1.5D))) + } + + assertResult(None) { + ParquetFilters.createFilter( + schema, + sources.And( + sources.LessThan("a", 10), + sources.StringContains("b", "prefix"))) + } + + assertResult(None) { + ParquetFilters.createFilter( + schema, + sources.Not( + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix")))) + } + } + test("SPARK-11164: test the parquet filter in") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index ebfb1759b8b96..165210f9ff301 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -26,15 +26,47 @@ import org.apache.spark.Logging import org.apache.spark.sql.sources._ /** - * It may be optimized by push down partial filters. But we are conservative here. - * Because if some filters fail to be parsed, the tree may be corrupted, - * and cannot be used anymore. + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. */ private[orc] object OrcFilters extends Logging { def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + // First, tries to convert each filter individually to see whether it's convertible, and then + // collect all convertible ones to build the final `SearchArgument`. + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(filter, SearchArgumentFactory.newBuilder()) + } yield filter + for { - // Combines all filters with `And`s to produce a single conjunction predicate - conjunction <- filters.reduceOption(And) + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- convertibleFilters.reduceOption(And) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) } yield builder.build() @@ -50,28 +82,6 @@ private[orc] object OrcFilters extends Logging { case _ => false } - // lian: I probably missed something here, and had to end up with a pretty weird double-checking - // pattern when converting `And`/`Or`/`Not` filters. - // - // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, - // and `startNot()` mutate internal state of the builder instance. This forces us to translate - // all convertible filters with a single builder instance. However, before actually converting a - // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible - // filter is found, we may already end up with a builder whose internal state is inconsistent. - // - // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and - // then try to convert its children. Say we convert `left` child successfully, but find that - // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is - // inconsistent now. - // - // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their - // children with brand new builders, and only do the actual conversion with the right builder - // instance when the children are proven to be convertible. - // - // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. - // Usage of builder methods mentioned above can only be found in test code, where all tested - // filters are known to be convertible. - expression match { case And(left, right) => // At here, it is not safe to just convert one side if we do not understand the @@ -102,6 +112,10 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). + case EqualTo(attribute, value) if isSearchableLiteral(value) => Some(builder.startAnd().equals(attribute, value).end()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 7a34cf731b4c5..47e73b4006fa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -21,8 +21,9 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.{QueryTest, Row} case class OrcData(intField: Int, stringField: String) @@ -174,4 +175,33 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + // The `LessThan` should be converted while the `StringContains` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } } From ab6bedd85dc29906ac2f175f603ae3b43ab03535 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 28 Dec 2015 10:40:03 -0800 Subject: [PATCH 1352/2089] [SPARK-12508][PROJECT-INFRA] Fix minor bugs in dev/tests/pr_public_classes.sh script This patch fixes a handful of minor bugs in the `dev/tests/pr_public_classes.sh` script, which is used by the `run_tests_jenkins` script to detect the addition of new public classes: - Account for differences between BSD and GNU `sed` in order to allow the script to run on OS X. - Diff `$ghprbActualCommit^...$ghprbActualCommit ` instead of `master...$ghprbActualCommit`: since `ghprbActualCommit` is a merge commit which results from merging the PR into the target branch, this will give us the desired diff and will avoid certain race-conditions which could lead to false-positives. - Use `echo -e` instead of `echo` so that newline characters are handled correctly in output. This should fix a formatting glitch which caused the output to appear on a single line in the GitHub comment (see [the SC2028 page](https://github.com/koalaman/shellcheck/wiki/SC2028) on the Shellcheck wiki for more details). Author: Josh Rosen Closes #10455 from JoshRosen/fix-pr-public-classes-test. --- dev/tests/pr_public_classes.sh | 42 ++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 927295b88c963..41c5d3ee8cb3c 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -24,36 +24,44 @@ # # Arg1: The Github Pull Request Actual Commit #+ known as `ghprbActualCommit` in `run-tests-jenkins` -# Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` -# - -# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR -#+ and not anything else added to master since the PR was branched. ghprbActualCommit="$1" -sha1="$2" + +# $ghprbActualCommit is an automatic merge commit generated by GitHub; its parents are some Spark +# master commit and the tip of the pull request branch. + +# By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only +# non-test files, we can gets us changes introduced in the PR and not anything else added to master +# since the PR was branched. + +# Handle differences between GNU and BSD sed +if [[ $(uname) == "Darwin" ]]; then + SED='sed -E' +else + SED='sed -r' +fi source_files=$( - git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + git diff $ghprbActualCommit^...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ | grep -v -e "\/test" `# ignore files in test directories` \ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ | tr "\n" " " ) + new_public_classes=$( - git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + git diff $ghprbActualCommit^...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ + | $SED -e "s/^\+//g" `# remove the leading +` \ | grep -e "trait " -e "class " `# filter in lines with these key words` \ | grep -e "{" -e "(" `# filter in lines with these key words, too` \ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | $SED -e "s/\{.*//g" `# remove from the { onwards` \ + | $SED -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | $SED -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | $SED -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | $SED -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | $SED -e "s/$/\\\n/g" `# append newline to end of line` \ | tr -d "\n" `# remove actual LF characters` ) @@ -61,5 +69,5 @@ if [ -z "$new_public_classes" ]; then echo " * This patch adds no public classes." else public_classes_note=" * This patch adds the following public classes _(experimental)_:" - echo "${public_classes_note}\n${new_public_classes}" + echo -e "${public_classes_note}\n${new_public_classes}" fi From 8543997f2daa60dfa0509f149fab207de98145a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 28 Dec 2015 11:45:44 -0800 Subject: [PATCH 1353/2089] [HOT-FIX] bypass hive test when parse logical plan to json https://github.com/apache/spark/pull/10311 introduces some rare, non-deterministic flakiness for hive udf tests, see https://github.com/apache/spark/pull/10311#issuecomment-166548851 I can't reproduce it locally, and may need more time to investigate, a quick solution is: bypass hive tests for json serialization. Author: Wenchen Fan Closes #10430 from cloud-fan/hot-fix. --- .../src/test/scala/org/apache/spark/sql/QueryTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9246f55020fc4..442ae79f4f86f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -198,6 +198,9 @@ abstract class QueryTest extends PlanTest { case a: ImperativeAggregate => return } + // bypass hive tests before we fix all corner cases in hive module. + if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return + val jsonString = try { logicalPlan.toJSON } catch { @@ -209,9 +212,6 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - // bypass hive tests before we fix all corner cases in hive module. - if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return - // scala function is not serializable to JSON, use null to replace them so that we can compare // the plans later. val normalized1 = logicalPlan.transformAllExpressions { From fd50df413fbb3b7528cdff311cc040a6212340b9 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 28 Dec 2015 11:58:33 -0800 Subject: [PATCH 1354/2089] [SPARK-12231][SQL] create a combineFilters' projection when we call buildPartitionedTableScan Hello Michael & All: We have some issues to submit the new codes in the other PR(#10299), so we closed that PR and open this one with the fix. The reason for the previous failure is that the projection for the scan when there is a filter that is not pushed down (the "left-over" filter) could be different, in elements or ordering, from the original projection. With this new codes, the approach to solve this problem is: Insert a new Project if the "left-over" filter is nonempty and (the original projection is not empty and the projection for the scan has more than one elements which could otherwise cause different ordering in projection). We create 3 test cases to cover the otherwise failure cases. Author: Kevin Yu Closes #10388 from kevinyu98/spark-12231. --- .../datasources/DataSourceStrategy.scala | 28 ++++++++++--- .../parquet/ParquetFilterSuite.scala | 41 +++++++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8a15a51d825ef..3741a9cb32fd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -77,7 +77,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) // Predicates with both partition keys and attributes - val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet + val partitionAndNormalColumnFilters = + filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray @@ -88,16 +89,33 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." } + // need to add projections from "partitionAndNormalColumnAttrs" in if it is not empty + val partitionAndNormalColumnAttrs = AttributeSet(partitionAndNormalColumnFilters) + val partitionAndNormalColumnProjs = if (partitionAndNormalColumnAttrs.isEmpty) { + projects + } else { + (partitionAndNormalColumnAttrs ++ projects).toSeq + } + val scan = buildPartitionedTableScan( l, - projects, + partitionAndNormalColumnProjs, pushedFilters, t.partitionSpec.partitionColumns, selectedPartitions) - combineFilters - .reduceLeftOption(expressions.And) - .map(execution.Filter(_, scan)).getOrElse(scan) :: Nil + // Add a Projection to guarantee the original projection: + // this is because "partitionAndNormalColumnAttrs" may be different + // from the original "projects", in elements or their ordering + + partitionAndNormalColumnFilters.reduceLeftOption(expressions.And).map(cf => + if (projects.isEmpty || projects == partitionAndNormalColumnProjs) { + // if the original projection is empty, no need for the additional Project either + execution.Filter(cf, scan) + } else { + execution.Project(projects, execution.Filter(cf, scan)) + } + ).getOrElse(scan) :: Nil // Scanning non-partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index a0abfe458d498..f42f173b2a863 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -325,6 +325,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("SPARK-12231: test the filter and empty project in partitioned DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}" + (1 to 3).map(i => (i, i + 1, i + 2, i + 3)).toDF("a", "b", "c", "d"). + write.partitionBy("a").parquet(path) + + // The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty, + // this query will throw an exception since the project from combinedFilter expect + // two projection while the + val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + + assert(df1.filter("a > 1 or b < 2").count() == 2) + } + } + } + + test("SPARK-12231: test the new projection in partitioned DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}" + (1 to 3).map(i => (i, i + 1, i + 2, i + 3)).toDF("a", "b", "c", "d"). + write.partitionBy("a").parquet(path) + + // test the generate new projection case + // when projects != partitionAndNormalColumnProjs + + val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + + checkAnswer( + df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"), + (2 to 3).map(i => Row(i, i + 1, i + 2, i + 3))) + } + } + } + + test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") { import testImplicits._ From 73b70f076d4e22396b7e145f2ce5974fbf788048 Mon Sep 17 00:00:00 2001 From: Yaron Weinsberg Date: Tue, 29 Dec 2015 05:19:11 +0900 Subject: [PATCH 1355/2089] [SPARK-12517] add default RDD name for one created via sc.textFile The feature was first added at commit: 7b877b27053bfb7092e250e01a3b887e1b50a109 but was later removed (probably by mistake) at commit: fc8b58195afa67fbb75b4c8303e022f703cbf007. This change sets the default path of RDDs created via sc.textFile(...) to the path argument. Here is the symptom: * Using spark-1.5.2-bin-hadoop2.6: scala> sc.textFile("/home/root/.bashrc").name res5: String = null scala> sc.binaryFiles("/home/root/.bashrc").name res6: String = /home/root/.bashrc * while using Spark 1.3.1: scala> sc.textFile("/home/root/.bashrc").name res0: String = /home/root/.bashrc scala> sc.binaryFiles("/home/root/.bashrc").name res1: String = /home/root/.bashrc Author: Yaron Weinsberg Author: yaron Closes #10456 from wyaron/master. --- .../scala/org/apache/spark/SparkContext.scala | 4 +-- .../org/apache/spark/SparkContextSuite.scala | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d506782b73c43..bbdc9158d8e2b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -836,7 +836,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[String] = withScope { assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], - minPartitions).map(pair => pair._2.toString) + minPartitions).map(pair => pair._2.toString).setName(path) } /** @@ -885,7 +885,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[Text], classOf[Text], updateConf, - minPartitions).setName(path).map(record => (record._1.toString, record._2.toString)) + minPartitions).map(record => (record._1.toString, record._2.toString)).setName(path) } /** diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index d4f2ea87650a9..172ef050cc275 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -274,6 +274,31 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("Default path for file based RDDs is properly set (SPARK-12517)") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + // Test filetextFile, wholeTextFiles, binaryFiles, hadoopFile and + // newAPIHadoopFile for setting the default path as the RDD name + val mockPath = "default/path/for/" + + var targetPath = mockPath + "textFile" + assert(sc.textFile(targetPath).name === targetPath) + + targetPath = mockPath + "wholeTextFiles" + assert(sc.wholeTextFiles(targetPath).name === targetPath) + + targetPath = mockPath + "binaryFiles" + assert(sc.binaryFiles(targetPath).name === targetPath) + + targetPath = mockPath + "hadoopFile" + assert(sc.hadoopFile(targetPath).name === targetPath) + + targetPath = mockPath + "newAPIHadoopFile" + assert(sc.newAPIHadoopFile(targetPath).name === targetPath) + + sc.stop() + } + test("calling multiple sc.stop() must not throw any exception") { noException should be thrownBy { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) From e01c6c8664d74d434e9b6b3c8c70570f01d4a0a4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 28 Dec 2015 12:23:28 -0800 Subject: [PATCH 1356/2089] [SPARK-12287][SQL] Support UnsafeRow in MapPartitions/MapGroups/CoGroup Support Unsafe Row in MapPartitions/MapGroups/CoGroup. Added a test case for MapPartitions. Since MapGroups and CoGroup are built on AppendColumns, all the related dataset test cases already can verify the correctness when MapGroups and CoGroup processing unsafe rows. davies cloud-fan Not sure if my understanding is right, please correct me. Thank you! Author: gatorsmile Closes #10398 from gatorsmile/unsafeRowMapGroup. --- .../apache/spark/sql/execution/basicOperators.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 21325beb1c8c0..6b7b3bbbf6058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -370,6 +370,10 @@ case class MapPartitions[T, U]( output: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def canProcessSafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) @@ -391,6 +395,7 @@ case class AppendColumns[T, U]( // We are using an unsafe combiner. override def canProcessSafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true override def output: Seq[Attribute] = child.output ++ newColumns @@ -420,6 +425,10 @@ case class MapGroups[K, T, U]( output: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def canProcessSafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -459,6 +468,10 @@ case class CoGroup[Key, Left, Right, Result]( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def canProcessSafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil From 07165ca06fe0866677525f85fec25e4dbd336674 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 29 Dec 2015 05:33:19 +0900 Subject: [PATCH 1357/2089] [SPARK-12424][ML] The implementation of ParamMap#filter is wrong. ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKey` is collection.Map, not mutable.Map but the result is casted to mutable.Map using `asInstanceOf` so we get `ClassCastException`. Also, the return type of Map#filterKeys is not Serializable. It's the issue of Scala (https://issues.scala-lang.org/browse/SI-6654). Author: Kousuke Saruta Closes #10381 from sarutak/SPARK-12424. --- .../org/apache/spark/ml/param/params.scala | 8 ++++-- .../apache/spark/ml/param/ParamsSuite.scala | 28 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ee7e89edd8798..c0546695e487b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Filters this param map for the given parent. */ def filter(parent: Params): ParamMap = { - val filtered = map.filterKeys(_.parent == parent) - new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + // Don't use filterKeys because mutable.Map#filterKeys + // returns the instance of collections.Map, not mutable.Map. + // Otherwise, we get ClassCastException. + // Not using filterKeys also avoid SI-6654 + val filtered = map.filter { case (k, _) => k.parent == parent.uid } + new ParamMap(filtered) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a1878be747ceb..748868554fe65 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.ml.param +import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MyParams import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite { val t3 = t.copy(ParamMap(t.maxIter -> 20)) assert(t3.isSet(t3.maxIter)) } + + test("Filtering ParamMap") { + val params1 = new MyParams("my_params1") + val params2 = new MyParams("my_params2") + val paramMap = ParamMap( + params1.intParam -> 1, + params2.intParam -> 1, + params1.doubleParam -> 0.2, + params2.doubleParam -> 0.2) + val filteredParamMap = paramMap.filter(params1) + + assert(filteredParamMap.size === 2) + filteredParamMap.toSeq.foreach { + case ParamPair(p, _) => + assert(p.parent === params1.uid) + } + + // At the previous implementation of ParamMap#filter, + // mutable.Map#filterKeys was used internally but + // the return type of the method is not serializable (see SI-6654). + // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable. + // So let's ensure serializability. + val objOut = new ObjectOutputStream(new ByteArrayOutputStream()) + objOut.writeObject(filteredParamMap) + } } object ParamsSuite extends SparkFunSuite { From a6a4812434c6f43cd4742437f957fecd86220255 Mon Sep 17 00:00:00 2001 From: Stephan Kessler Date: Mon, 28 Dec 2015 12:46:20 -0800 Subject: [PATCH 1358/2089] [SPARK-7727][SQL] Avoid inner classes in RuleExecutor Moved (case) classes Strategy, Once, FixedPoint and Batch to the companion object. This is necessary if we want to have the Optimizer easily extendable in the following sense: Usually a user wants to add additional rules, and just take the ones that are already there. However, inner classes made that impossible since the code did not compile This allows easy extension of existing Optimizers see the DefaultOptimizerExtendableSuite for a corresponding test case. Author: Stephan Kessler Closes #10174 from stephankessler/SPARK-7727. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 ++++-- .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../optimizer/OptimizerExtendableSuite.scala | 58 +++++++++++++++++++ 3 files changed, 74 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f6088695a9276..0b1c74293bb8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -28,10 +28,12 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ -abstract class Optimizer extends RuleExecutor[LogicalPlan] - -object DefaultOptimizer extends Optimizer { - val batches = +/** + * Abstract class all optimizers should inherit of, contains the standard batches (extending + * Optimizers can override this. + */ +abstract class Optimizer extends RuleExecutor[LogicalPlan] { + def batches: Seq[Batch] = { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -66,8 +68,17 @@ object DefaultOptimizer extends Optimizer { DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Nil + } } +/** + * Non-abstract representation of the standard Spark optimizing strategies + * + * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while + * specific rules go to the subclasses + */ +object DefaultOptimizer extends Optimizer + /** * Pushes operations down into a Sample. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index f80d2a93241d1..62ea731ab5f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -59,7 +59,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) /** Defines a sequence of rule batches, to be overridden by the implementation. */ - protected val batches: Seq[Batch] + protected def batches: Seq[Batch] /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala new file mode 100644 index 0000000000000..7e3da6bea75e3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * This is a test for SPARK-7727 if the Optimizer is kept being extendable + */ +class OptimizerExtendableSuite extends SparkFunSuite { + + /** + * Dummy rule for test batches + */ + object DummyRule extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + + /** + * This class represents a dummy extended optimizer that takes the batches of the + * Optimizer and adds custom ones. + */ + class ExtendedOptimizer extends Optimizer { + + // rules set to DummyRule, would not be executed anyways + val myBatches: Seq[Batch] = { + Batch("once", Once, + DummyRule) :: + Batch("fixedPoint", FixedPoint(100), + DummyRule) :: Nil + } + + override def batches: Seq[Batch] = super.batches ++ myBatches + } + + test("Extending batches possible") { + // test simply instantiates the new extended optimizer + val extendedOptimizer = new ExtendedOptimizer() + } +} From 01ba95d8bfc16a2542c67b066b0a1d1e465f91da Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 28 Dec 2015 12:48:30 -0800 Subject: [PATCH 1359/2089] [SPARK-12441][SQL] Fixing missingInput in Generate/MapPartitions/AppendColumns/MapGroups/CoGroup When explain any plan with Generate, we will see an exclamation mark in the plan. Normally, when we see this mark, it means the plan has an error. This PR is to correct the `missingInput` in `Generate`. For example, ```scala val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = df.explode('letters) { case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq } df2.explain(true) ``` Before the fix, the plan is like ``` == Parsed Logical Plan == 'Generate UserDefinedGenerator('letters), true, false, None +- Project [_1#0 AS number#2,_2#1 AS letters#3] +- LocalRelation [_1#0,_2#1], [[1,a b c],[2,a b],[3,a]] == Analyzed Logical Plan == number: int, letters: string, _1: string Generate UserDefinedGenerator(letters#3), true, false, None, [_1#8] +- Project [_1#0 AS number#2,_2#1 AS letters#3] +- LocalRelation [_1#0,_2#1], [[1,a b c],[2,a b],[3,a]] == Optimized Logical Plan == Generate UserDefinedGenerator(letters#3), true, false, None, [_1#8] +- LocalRelation [number#2,letters#3], [[1,a b c],[2,a b],[3,a]] == Physical Plan == !Generate UserDefinedGenerator(letters#3), true, false, [number#2,letters#3,_1#8] +- LocalTableScan [number#2,letters#3], [[1,a b c],[2,a b],[3,a]] ``` **Updates**: The same issues are also found in the other four Dataset operators: `MapPartitions`/`AppendColumns`/`MapGroups`/`CoGroup`. Fixed all these four. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10393 from gatorsmile/generateExplain. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 11 ++++++----- .../sql/catalyst/plans/logical/LocalRelation.scala | 4 ++-- .../sql/catalyst/plans/logical/LogicalPlan.scala | 1 + .../catalyst/plans/logical/basicOperators.scala | 8 ++++---- .../apache/spark/sql/execution/ExistingRDD.scala | 8 +++++--- .../org/apache/spark/sql/execution/Generate.scala | 2 ++ .../org/apache/spark/sql/execution/SparkPlan.scala | 1 + .../execution/aggregate/SortBasedAggregate.scala | 9 +++++++++ .../execution/aggregate/TungstenAggregate.scala | 5 +++++ .../spark/sql/execution/basicOperators.scala | 4 ++++ .../columnar/InMemoryColumnarTableScan.scala | 2 ++ .../apache/spark/sql/ExtraStrategiesSuite.scala | 7 +++---- .../scala/org/apache/spark/sql/QueryTest.scala | 14 ++++++++++++++ .../spark/sql/hive/execution/HiveTableScan.scala | 3 +++ .../sql/hive/execution/ScriptTransformation.scala | 2 ++ 15 files changed, 63 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d2626440b9434..b43b7ee71e7aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -43,16 +43,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def inputSet: AttributeSet = AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + /** + * The set of all attributes that are produced by this node. + */ + def producedAttributes: AttributeSet = AttributeSet.empty + /** * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. - * - * Note that virtual columns should be excluded. Currently, we only support the grouping ID - * virtual column. */ - def missingInput: AttributeSet = - (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) + def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** * Runs [[transform]] with `rule` on all expressions present in this query operator. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e3e7a11dba973..572d7d2f0b537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet} +import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8f8747e105932..6d859551f8c52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -295,6 +295,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ abstract class LeafNode extends LogicalPlan { override def children: Seq[LogicalPlan] = Nil + override def producedAttributes: AttributeSet = outputSet } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 64ef4d799659f..5f34d4a4eb73c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -526,7 +526,7 @@ case class MapPartitions[T, U]( uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def missingInput: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = outputSet } /** Factory for constructing new `AppendColumn` nodes. */ @@ -552,7 +552,7 @@ case class AppendColumns[T, U]( newColumns: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns - override def missingInput: AttributeSet = super.missingInput -- newColumns + override def producedAttributes: AttributeSet = AttributeSet(newColumns) } /** Factory for constructing new `MapGroups` nodes. */ @@ -587,7 +587,7 @@ case class MapGroups[K, T, U]( groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def missingInput: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = outputSet } /** Factory for constructing new `CoGroup` nodes. */ @@ -630,5 +630,5 @@ case class CoGroup[Key, Left, Right, Result]( rightGroup: Seq[Attribute], left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def missingInput: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = outputSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index ea5a9afe03b00..5c01af011d306 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation} +import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} @@ -84,6 +84,8 @@ private[sql] case class LogicalRDD( case _ => false } + override def producedAttributes: AttributeSet = outputSet + @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 54b8cb58285c2..0c613e91b979f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -54,6 +54,8 @@ case class Generate( child: SparkPlan) extends UnaryNode { + override def expressions: Seq[Expression] = generator :: Nil + val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ec98f81041343..fe9b2ad4a0bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -279,6 +279,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[sql] trait LeafNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil + override def producedAttributes: AttributeSet = outputSet } private[sql] trait UnaryNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index c5470a6989de7..c4587ba677b2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -36,6 +36,15 @@ case class SortBasedAggregate( child: SparkPlan) extends UnaryNode { + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index b8849c827048a..9d758eb3b7c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -55,6 +55,11 @@ case class TungstenAggregate( override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.length == 0 => AllTuples :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6b7b3bbbf6058..f19d72f067218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -369,6 +369,7 @@ case class MapPartitions[T, U]( uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = outputSet override def canProcessSafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -391,6 +392,7 @@ case class AppendColumns[T, U]( uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = AttributeSet(newColumns) // We are using an unsafe combiner. override def canProcessSafeRows: Boolean = false @@ -424,6 +426,7 @@ case class MapGroups[K, T, U]( groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = outputSet override def canProcessSafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -467,6 +470,7 @@ case class CoGroup[Key, Left, Right, Result]( rightGroup: Seq[Attribute], left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def producedAttributes: AttributeSet = outputSet override def canProcessSafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 4afa5f8ec1035..aa7a668e0e938 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -66,6 +66,8 @@ private[sql] case class InMemoryRelation( private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { + override def producedAttributes: AttributeSet = outputSet + private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = if (_batchStats == null) { child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 78a98798eff64..359a1e7f8424a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -15,16 +15,14 @@ * limitations under the License. */ -package test.org.apache.spark.sql +package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.{Row, Strategy, QueryTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.types.UTF8String case class FastOperator(output: Seq[Attribute]) extends SparkPlan { @@ -34,6 +32,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { sparkContext.parallelize(Seq(row)) } + override def producedAttributes: AttributeSet = outputSet override def children: Seq[SparkPlan] = Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 442ae79f4f86f..815372f19233b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -130,6 +130,8 @@ abstract class QueryTest extends PlanTest { checkJsonFormat(analyzedDF) + assertEmptyMissingInput(df) + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -275,6 +277,18 @@ abstract class QueryTest extends PlanTest { """.stripMargin) } } + + /** + * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans. + */ + def assertEmptyMissingInput(query: Queryable): Unit = { + assert(query.queryExecution.analyzed.missingInput.isEmpty, + s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") + assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, + s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}") + assert(query.queryExecution.executedPlan.missingInput.isEmpty, + s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}") + } } object QueryTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 806d2b9b0b7d4..8141136de5311 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -51,6 +51,9 @@ case class HiveTableScan( require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + override def producedAttributes: AttributeSet = outputSet ++ + AttributeSet(partitionPruningPred.flatMap(_.references)) + // Retrieve the original attributes based on expression ID so that capitalization matches. val attributes = requestedAttributes.map(relation.attributeMap) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index d9b9ba4bfdfed..a61e162f48f1b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -60,6 +60,8 @@ case class ScriptTransformation( override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override def producedAttributes: AttributeSet = outputSet -- inputSet + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) protected override def doExecute(): RDD[InternalRow] = { From a6d385322e7dfaff600465fa5302010a5f122c6b Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 29 Dec 2015 07:02:30 +0900 Subject: [PATCH 1360/2089] [SPARK-12222][CORE] Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception Since we only need to implement `def skipBytes(n: Int)`, code in #10213 could be simplified. davies scwf Author: Daoyuan Wang Closes #10253 from adrian-wang/kryo. --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index cb2ac5ea167ec..eed9937b3046f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -401,12 +401,7 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readInt(): Int = input.readInt() override def readUnsignedShort(): Int = input.readShortUnsigned() override def skipBytes(n: Int): Int = { - var remaining: Long = n - while (remaining > 0) { - val skip = Math.min(Integer.MAX_VALUE, remaining).asInstanceOf[Int] - input.skip(skip) - remaining -= skip - } + input.skip(n) n } override def readFully(b: Array[Byte]): Unit = input.read(b) From fb572c6e4b0645c8084aa013d0c93bb21a79977b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 28 Dec 2015 14:51:22 -0800 Subject: [PATCH 1361/2089] [SPARK-12525] Fix fatal compiler warnings in Kinesis ASL due to @transient annotations The Scala 2.11 SBT build currently fails for Spark 1.6.0 and master due to warnings about the `transient` annotation: ``` [error] [warn] /Users/joshrosen/Documents/spark/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala:73: no valid targets for annotation on value sc - it is discarded unused. You may specify targets with meta-annotations, e.g. (transient param) [error] [warn] transient sc: SparkContext, ``` This fix implemented here is the same as what we did in #8433: remove the `transient` annotations when they are not necessary and replace use `transient private val` in the remaining cases. Author: Josh Rosen Closes #10479 from JoshRosen/fix-sbt-2.11. --- .../streaming/kinesis/KinesisBackedBlockRDD.scala | 14 +++++++------- .../streaming/kinesis/KinesisInputDStream.scala | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 691c1790b207f..3996f168e69ee 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -70,26 +70,26 @@ class KinesisBackedBlockRDDPartition( */ private[kinesis] class KinesisBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, val regionName: String, val endpointUrl: String, - @transient blockIds: Array[BlockId], + @transient private val _blockIds: Array[BlockId], @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, val awsCredentialsOption: Option[SerializableAWSCredentials] = None - ) extends BlockRDD[T](sc, blockIds) { + ) extends BlockRDD[T](sc, _blockIds) { - require(blockIds.length == arrayOfseqNumberRanges.length, + require(_blockIds.length == arrayOfseqNumberRanges.length, "Number of blockIds is not equal to the number of sequence number ranges") override def isValid(): Boolean = true override def getPartitions: Array[Partition] = { - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + new KinesisBackedBlockRDDPartition(i, _blockIds(i), isValid, arrayOfseqNumberRanges(i)) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 72ab6357a53b0..3321c7527edb4 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -30,7 +30,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.{Duration, StreamingContext, Time} private[kinesis] class KinesisInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, streamName: String, endpointUrl: String, regionName: String, From 710b41172958a0b3a2b70c48821aefc81893731b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 28 Dec 2015 15:01:51 -0800 Subject: [PATCH 1362/2089] [SPARK-12489][CORE][SQL][MLIB] Fix minor issues found by FindBugs Include the following changes: 1. Close `java.sql.Statement` 2. Fix incorrect `asInstanceOf`. 3. Remove unnecessary `synchronized` and `ReentrantLock`. Author: Shixiong Zhu Closes #10440 from zsxwing/findbugs. --- .../cluster/mesos/MesosClusterScheduler.scala | 3 +- .../apache/spark/launcher/LauncherServer.java | 4 +- .../java/org/apache/spark/launcher/Main.java | 2 +- .../scala/org/apache/spark/ml/tree/Node.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 7 ++- .../execution/datasources/jdbc/JDBCRDD.scala | 47 ++++++++++--------- .../datasources/jdbc/JdbcUtils.scala | 16 ++++++- 7 files changed, 51 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index a6d9374eb9e8c..16815d51d4c67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, Date, List => JList} import scala.collection.JavaConverters._ @@ -126,7 +125,7 @@ private[spark] class MesosClusterScheduler( private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute private val schedulerState = engineFactory.createEngine("scheduler") - private val stateLock = new ReentrantLock() + private val stateLock = new Object() private val finishedDrivers = new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) private var frameworkId: String = null diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index d099ee9aa9dae..414ffc2c84e52 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -293,9 +293,7 @@ private class ServerConnection extends LauncherConnection { protected void handle(Message msg) throws IOException { try { if (msg instanceof Hello) { - synchronized (timeout) { - timeout.cancel(); - } + timeout.cancel(); timeout = null; Hello hello = (Hello) msg; ChildProcAppHandle handle = pending.remove(hello.secret); diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index a4e3acc674f36..e751e948e3561 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -151,7 +151,7 @@ private static class MainClassOptionParser extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { - if (opt == CLASS) { + if (CLASS.equals(opt)) { className = value; } return false; diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d89682611e3f5..9cfd466294b95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -386,9 +386,9 @@ private[tree] object LearningNode { var levelsToGo = indexToLevel(nodeIndex) while (levelsToGo > 0) { if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { - tmpNode = tmpNode.leftChild.asInstanceOf[LearningNode] + tmpNode = tmpNode.leftChild.get } else { - tmpNode = tmpNode.rightChild.asInstanceOf[LearningNode] + tmpNode = tmpNode.rightChild.get } levelsToGo -= 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 03867beb78224..ab362539e2982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -297,7 +297,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { if (!tableExists) { val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" - conn.createStatement.executeUpdate(sql) + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index fc0f86cb1813f..4e2f5059be4e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -122,30 +122,35 @@ private[sql] object JDBCRDD extends Logging { val dialect = JdbcDialects.get(url) val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() try { - val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() + val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") try { - val rsmd = rs.getMetaData - val ncols = rsmd.getColumnCount - val fields = new Array[StructField](ncols) - var i = 0 - while (i < ncols) { - val columnName = rsmd.getColumnLabel(i + 1) - val dataType = rsmd.getColumnType(i + 1) - val typeName = rsmd.getColumnTypeName(i + 1) - val fieldSize = rsmd.getPrecision(i + 1) - val fieldScale = rsmd.getScale(i + 1) - val isSigned = rsmd.isSigned(i + 1) - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) - val columnType = - dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale, isSigned)) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) - i = i + 1 + val rs = statement.executeQuery() + try { + val rsmd = rs.getMetaData + val ncols = rsmd.getColumnCount + val fields = new Array[StructField](ncols) + var i = 0 + while (i < ncols) { + val columnName = rsmd.getColumnLabel(i + 1) + val dataType = rsmd.getColumnType(i + 1) + val typeName = rsmd.getColumnTypeName(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) + val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val metadata = new MetadataBuilder().putString("name", columnName) + val columnType = + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) + fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + i = i + 1 + } + return new StructType(fields) + } finally { + rs.close() } - return new StructType(fields) } finally { - rs.close() + statement.close() } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 28cd688ef7d7a..46f2670eee010 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -49,14 +49,26 @@ object JdbcUtils extends Logging { // Somewhat hacky, but there isn't a good way to identify whether a table exists for all // SQL database systems using JDBC meta data calls, considering "table" could also include // the database name. Query used to find table exists can be overriden by the dialects. - Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess + Try { + val statement = conn.prepareStatement(dialect.getTableExistsQuery(table)) + try { + statement.executeQuery() + } finally { + statement.close() + } + }.isSuccess } /** * Drops a table from the JDBC database. */ def dropTable(conn: Connection, table: String): Unit = { - conn.createStatement.executeUpdate(s"DROP TABLE $table") + val statement = conn.createStatement + try { + statement.executeUpdate(s"DROP TABLE $table") + } finally { + statement.close() + } } /** From 124a3a5e4eece3aabca44fbdd2f8c4c086d6eec3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 28 Dec 2015 16:42:11 -0800 Subject: [PATCH 1363/2089] [SPARK-12490] Don't use Javascript for web UI's paginated table controls The web UI's paginated table uses Javascript to implement certain navigation controls, such as table sorting and the "go to page" form. This is unnecessary and should be simplified to use plain HTML form controls and links. /cc zsxwing, who wrote this original code, and yhuai. Author: Josh Rosen Closes #10441 from JoshRosen/simplify-paginated-table-sorting. --- .../org/apache/spark/ui/static/webui.css | 11 +- .../org/apache/spark/ui/PagedTable.scala | 100 ++++++++++++------ .../org/apache/spark/ui/jobs/StagePage.scala | 80 +++++++++----- .../org/apache/spark/ui/storage/RDDPage.scala | 76 +++++++------ .../org/apache/spark/ui/PagedTableSuite.scala | 8 +- 5 files changed, 178 insertions(+), 97 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b54e33a96fa23..dd708ef2c29b5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -225,4 +225,13 @@ a.expandbutton { background-color: #49535a !important; color: white; cursor:pointer; -} \ No newline at end of file +} + +th a, th a:hover { + /* Make the entire header clickable, not just the text label */ + display: block; + width: 100%; + /* Suppress the default link styling */ + color: #333; + text-decoration: none; +} diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 6e2375477a688..9b6ed8cbbef10 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -17,8 +17,15 @@ package org.apache.spark.ui +import java.net.URLDecoder + +import scala.collection.JavaConverters._ import scala.xml.{Node, Unparsed} +import com.google.common.base.Splitter + +import org.apache.spark.util.Utils + /** * A data source that provides data for a page. * @@ -71,6 +78,12 @@ private[ui] trait PagedTable[T] { def tableCssClass: String + def pageSizeFormField: String + + def prevPageSizeFormField: String + + def pageNumberFormField: String + def dataSource: PagedDataSource[T] def headers: Seq[Node] @@ -95,7 +108,12 @@ private[ui] trait PagedTable[T] { val PageData(totalPages, _) = _dataSource.pageData(1)
    {pageNavigation(1, _dataSource.pageSize, totalPages)} -
    {e.getMessage}
    +
    +

    Error while rendering table:

    +
    +              {Utils.exceptionString(e)}
    +            
    +
    } } @@ -151,36 +169,56 @@ private[ui] trait PagedTable[T] { // The current page should be disabled so that it cannot be clicked.
  • {p}
  • } else { -
  • {p}
  • +
  • {p}
  • + } + } + + val hiddenFormFields = { + if (goButtonFormPath.contains('?')) { + val querystring = goButtonFormPath.split("\\?", 2)(1) + Splitter + .on('&') + .trimResults() + .withKeyValueSeparator("=") + .split(querystring) + .asScala + .filterKeys(_ != pageSizeFormField) + .filterKeys(_ != prevPageSizeFormField) + .filterKeys(_ != pageNumberFormField) + .mapValues(URLDecoder.decode(_, "UTF-8")) + .map { case (k, v) => + + } + } else { + Seq.empty } } - val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction - // When clicking the "Go" button, it will call this javascript method and then call - // "goButtonJsFuncName" - val formJs = - s"""$$(function(){ - | $$( "#form-$tableId-page" ).submit(function(event) { - | var page = $$("#form-$tableId-page-no").val() - | var pageSize = $$("#form-$tableId-page-size").val() - | pageSize = pageSize ? pageSize: 100; - | if (page != "") { - | ${goButtonJsFuncName}(page, pageSize); - | } - | event.preventDefault(); - | }); - |}); - """.stripMargin
    + method="get" + action={Unparsed(goButtonFormPath)} + class="form-inline pull-right" + style="margin-bottom: 0px;"> + + {hiddenFormFields} - + + + id={s"form-$tableId-page-size"} + name={pageSizeFormField} + value={pageSize.toString} + class="span1" /> +
    @@ -189,7 +227,7 @@ private[ui] trait PagedTable[T] {
    -
    } } @@ -239,10 +272,7 @@ private[ui] trait PagedTable[T] { def pageLink(page: Int): String /** - * Only the implementation knows how to create the url with a page number and the page size, so we - * leave this one to the implementation. The implementation should create a JavaScript method that - * accepts a page number along with the page size and jumps to the page. The return value is this - * method name and its JavaScript codes. + * Returns the submission path for the "go to page #" form. */ - def goButtonJavascriptFunction: (String, String) + def goButtonFormPath: String } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 1b34ba9f03c44..b02b99a6fc7aa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -97,11 +97,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterTaskSortColumn = request.getParameter("task.sort") val parameterTaskSortDesc = request.getParameter("task.desc") val parameterTaskPageSize = request.getParameter("task.pageSize") + val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) // If this is set, expand the dag visualization by default val expandDagVizParam = request.getParameter("expandDagViz") @@ -274,6 +276,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { accumulableRow, externalAccumulables.toSeq) + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (taskPageSize <= taskPrevPageSize) { + taskPage + } else { + 1 + } + } val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( @@ -292,10 +303,17 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { sortColumn = taskSortColumn, desc = taskSortDesc ) - (_taskTable, _taskTable.table(taskPage)) + (_taskTable, _taskTable.table(page)) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - (null,
    {e.getMessage}
    ) + val errorMessage = +
    +

    Error while rendering stage table:

    +
    +                {Utils.exceptionString(e)}
    +              
    +
    + (null, errorMessage) } val jsForScrollingDownToTaskTable = @@ -1217,6 +1235,12 @@ private[ui] class TaskPagedTable( override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def pageSizeFormField: String = "task.pageSize" + + override def prevPageSizeFormField: String = "task.prevPageSize" + + override def pageNumberFormField: String = "task.page" + override val dataSource: TaskDataSource = new TaskDataSource( data, hasAccumulators, @@ -1232,24 +1256,16 @@ private[ui] class TaskPagedTable( override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + - s"&task.pageSize=${pageSize}" + basePath + + s"&$pageNumberFormField=$page" + + s"&task.sort=$encodedSortColumn" + + s"&task.desc=$desc" + + s"&$pageSizeFormField=$pageSize" } - override def goButtonJavascriptFunction: (String, String) = { - val jsFuncName = "goToTaskPage" + override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - val jsFunc = s""" - |currentTaskPageSize = ${pageSize} - |function goToTaskPage(page, pageSize) { - | // Set page to 1 if the page size changes - | page = pageSize == currentTaskPageSize ? page : 1; - | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + - | "&task.page=" + page + "&task.pageSize=" + pageSize; - | window.location.href = url; - |} - """.stripMargin - (jsFuncName, jsFunc) + s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" } def headers: Seq[Node] = { @@ -1298,21 +1314,27 @@ private[ui] class TaskPagedTable( val headerRow: Seq[Node] = { taskHeadersAndCssClasses.map { case (header, cssClass) => if (header == sortColumn) { - val headerLink = - s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + - s"&task.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") + val headerLink = Unparsed( + basePath + + s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&task.desc=${!desc}" + + s"&task.pageSize=$pageSize") val arrow = if (desc) "▾" else "▴" // UP or DOWN - - {header} -  {Unparsed(arrow)} + +
    + {header} +  {Unparsed(arrow)} + } else { - val headerLink = - s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") - - {header} + val headerLink = Unparsed( + basePath + + s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&task.pageSize=$pageSize") + + + {header} + } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index fd6cc3ed759b3..3e51ce2e97994 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -38,11 +38,13 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val parameterBlockSortColumn = request.getParameter("block.sort") val parameterBlockSortDesc = request.getParameter("block.desc") val parameterBlockPageSize = request.getParameter("block.pageSize") + val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize") val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) + val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize) val rddId = parameterId.toInt val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) @@ -56,17 +58,26 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table")) // Block table - val (blockTable, blockTableHTML) = try { + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (blockPageSize <= blockPrevPageSize) { + blockPage + } else { + 1 + } + } + val blockTableHTML = try { val _blockTable = new BlockPagedTable( UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, blockSortDesc) - (_blockTable, _blockTable.table(blockPage)) + _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - (null,
    {e.getMessage}
    ) +
    {e.getMessage}
    } val jsForScrollingDownToBlockTable = @@ -228,6 +239,12 @@ private[ui] class BlockPagedTable( override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def pageSizeFormField: String = "block.pageSize" + + override def prevPageSizeFormField: String = "block.prevPageSize" + + override def pageNumberFormField: String = "block.page" + override val dataSource: BlockDataSource = new BlockDataSource( rddPartitions, pageSize, @@ -236,24 +253,16 @@ private[ui] class BlockPagedTable( override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"${basePath}&block.page=$page&block.sort=${encodedSortColumn}&block.desc=${desc}" + - s"&block.pageSize=${pageSize}" + basePath + + s"&$pageNumberFormField=$page" + + s"&block.sort=$encodedSortColumn" + + s"&block.desc=$desc" + + s"&$pageSizeFormField=$pageSize" } - override def goButtonJavascriptFunction: (String, String) = { - val jsFuncName = "goToBlockPage" + override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - val jsFunc = s""" - |currentBlockPageSize = ${pageSize} - |function goToBlockPage(page, pageSize) { - | // Set page to 1 if the page size changes - | page = pageSize == currentBlockPageSize ? page : 1; - | var url = "${basePath}&block.sort=${encodedSortColumn}&block.desc=${desc}" + - | "&block.page=" + page + "&block.pageSize=" + pageSize; - | window.location.href = url; - |} - """.stripMargin - (jsFuncName, jsFunc) + s"$basePath&block.sort=$encodedSortColumn&block.desc=$desc" } override def headers: Seq[Node] = { @@ -271,22 +280,27 @@ private[ui] class BlockPagedTable( val headerRow: Seq[Node] = { blockHeaders.map { header => if (header == sortColumn) { - val headerLink = - s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}&block.desc=${!desc}" + - s"&block.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") + val headerLink = Unparsed( + basePath + + s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.desc=${!desc}" + + s"&block.pageSize=$pageSize") val arrow = if (desc) "▾" else "▴" // UP or DOWN - - {header} -  {Unparsed(arrow)} + + + {header} +  {Unparsed(arrow)} + } else { - val headerLink = - s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&block.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") - - {header} + val headerLink = Unparsed( + basePath + + s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.pageSize=$pageSize") + + + {header} + } } diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala index cc76c141c53cc..74eeca282882a 100644 --- a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -64,7 +64,13 @@ class PagedTableSuite extends SparkFunSuite { override def row(t: Int): Seq[Node] = Nil - override def goButtonJavascriptFunction: (String, String) = ("", "") + override def pageSizeFormField: String = "pageSize" + + override def prevPageSizeFormField: String = "prevPageSize" + + override def pageNumberFormField: String = "page" + + override def goButtonFormPath: String = "" } assert(pagedTable.pageNavigation(1, 10, 1) === Nil) From 043135819c487abe9657c11006ce468a6e1f262e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 28 Dec 2015 17:22:18 -0800 Subject: [PATCH 1364/2089] [SPARK-12522][SQL][MINOR] Add the missing document strings for the SQL configuration Fixing the missing the document for the configuration. We can see the missing messages "TODO" when issuing the command "SET -V". ``` spark.sql.columnNameOfCorruptRecord spark.sql.hive.verifyPartitionPath spark.sql.sources.parallelPartitionDiscovery.threshold spark.sql.hive.convertMetastoreParquet.mergeSchema spark.sql.hive.convertCTAS spark.sql.hive.thriftServer.async ``` Author: gatorsmile Closes #10471 from gatorsmile/commandDesc. --- .../src/main/scala/org/apache/spark/sql/SQLConf.scala | 8 +++++--- .../scala/org/apache/spark/sql/execution/commands.scala | 2 -- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 9 ++++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 3d819262859f8..b58a3739912bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -334,7 +334,8 @@ private[spark] object SQLConf { val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", defaultValue = Some(false), - doc = "") + doc = "When true, check all the partition paths under the table\'s root directory " + + "when reading data stored in HDFS.") val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", defaultValue = Some(false), @@ -352,7 +353,7 @@ private[spark] object SQLConf { val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), - doc = "") + doc = "The name of internal column for storing raw/un-parsed JSON records that fail to parse.") val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", defaultValue = Some(5 * 60), @@ -413,7 +414,8 @@ private[spark] object SQLConf { val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = intConf( key = "spark.sql.sources.parallelPartitionDiscovery.threshold", defaultValue = Some(32), - doc = "") + doc = "The degree of parallelism for schema merging and partition discovery of " + + "Parquet data sources.") // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e2dc13d66c61e..6ec4cadeeb072 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -148,8 +148,6 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) - (keyValueOutput, runFunc) - case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) => val runFunc = (sqlContext: SQLContext) => { logWarning( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 0eeb62ca2cb3f..384ea211df843 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -692,11 +692,14 @@ private[hive] object HiveContext { val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( "spark.sql.hive.convertMetastoreParquet.mergeSchema", defaultValue = Some(false), - doc = "TODO") + doc = "When true, also tries to merge possibly different but compatible Parquet schemas in " + + "different Parquet data files. This configuration is only effective " + + "when \"spark.sql.hive.convertMetastoreParquet\" is true.") val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", defaultValue = Some(false), - doc = "TODO") + doc = "When true, a table created by a Hive CTAS statement (no USING clause) will be " + + "converted to a data source table, using the data source set by spark.sql.sources.default.") val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", defaultValue = Some(jdbcPrefixes), @@ -717,7 +720,7 @@ private[hive] object HiveContext { val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", defaultValue = Some(true), - doc = "TODO") + doc = "When set to true, Hive Thrift server executes SQL queries in an asynchronous way.") /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(useInMemoryDerby: Boolean): Map[String, String] = { From 1a91be807802ec88c068f1090dafe8fbfb1c6d5c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Dec 2015 20:43:06 -0800 Subject: [PATCH 1365/2089] [SPARK-12547][SQL] Tighten scala style checker enforcement for UDF registration We use scalastyle:off to turn off style checks in certain places where it is not possible to follow the style guide. This is usually ok. However, in udf registration, we disable the checker for a large amount of code simply because some of them exceed 100 char line limit. It is better to just disable the line limit check rather than everything. In this pull request, I only disabled line length check, and fixed a problem (lack explicit types for public methods). Author: Reynold Xin Closes #10501 from rxin/SPARK-12547. --- .../apache/spark/sql/UDFRegistration.scala | 53 +++++++++---------- .../org/apache/spark/sql/functions.scala | 6 ++- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 051694c0d43a6..42c373ea723dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.{List => JList, Map => JMap} - import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -69,7 +67,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } - // scalastyle:off + // scalastyle:off line.size.limit /* register 0-22 were generated by this script @@ -102,7 +100,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { | * Register a user-defined function with ${i} arguments. | * @since 1.3.0 | */ - |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { + |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { | functionRegistry.registerFunction( | name, | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) @@ -416,7 +414,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 1 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF1[_, _], returnType: DataType) = { + def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) @@ -426,7 +424,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 2 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { + def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) @@ -436,7 +434,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 3 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) @@ -446,7 +444,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 4 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -456,7 +454,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 5 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -466,7 +464,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 6 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -476,7 +474,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 7 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -486,7 +484,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 8 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -496,7 +494,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 9 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -506,7 +504,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 10 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -516,7 +514,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 11 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -526,7 +524,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 12 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -536,7 +534,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 13 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -546,7 +544,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 14 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -556,7 +554,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 15 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -566,7 +564,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 16 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -576,7 +574,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 17 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -586,7 +584,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 18 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -596,7 +594,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 19 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -606,7 +604,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 20 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -616,7 +614,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 21 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -626,11 +624,12 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 22 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } - // scalastyle:on + // scalastyle:on line.size.limit + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 65733dcf83e76..487191638f5c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2512,7 +2512,8 @@ object functions extends LegacyFunctions { ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off + // scalastyle:off line.size.limit + // scalastyle:off parameter.number /* Use the following code to generate: (0 to 10).map { x => @@ -2839,7 +2840,8 @@ object functions extends LegacyFunctions { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } - // scalastyle:on + // scalastyle:on parameter.number + // scalastyle:on line.size.limit /** * Call an user-defined function. From 73862a1eb9744c3c32458c9c6f6431c23783786a Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 28 Dec 2015 21:28:32 -0800 Subject: [PATCH 1366/2089] [SPARK-11394][SQL] Throw IllegalArgumentException for unsupported types in postgresql If DataFrame has BYTE types, throws an exception: org.postgresql.util.PSQLException: ERROR: type "byte" does not exist Author: Takeshi YAMAMURO Closes #9350 from maropu/FixBugInPostgreJdbc. --- .../scala/org/apache/spark/sql/jdbc/PostgresDialect.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3cf80f576e92c..ad9e31690b2d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -64,6 +64,7 @@ private object PostgresDialect extends JdbcDialect { getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) + case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt"); case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7975c5df6c0bb..4044a10ce70cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -482,6 +482,10 @@ class JDBCSuite extends SparkFunSuite val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + val errMsg = intercept[IllegalArgumentException] { + Postgres.getJDBCType(ByteType) + } + assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType") } test("DerbyDialect jdbc type mapping") { From d80cc90b5545cff82cd9b340f12d01eafc9ca524 Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Tue, 29 Dec 2015 12:45:24 +0530 Subject: [PATCH 1367/2089] [SPARK-12526][SPARKR] ifelse`, `when`, `otherwise` unable to take Column as value `ifelse`, `when`, `otherwise` is unable to take `Column` typed S4 object as values. For example: ```r ifelse(lit(1) == lit(1), lit(2), lit(3)) ifelse(df$mpg > 0, df$mpg, 0) ``` will both fail with ```r attempt to replicate an object of type 'environment' ``` The PR replaces `ifelse` calls with `if ... else ...` inside the function implementations to avoid attempt to vectorize(i.e. `rep()`). It remains to be discussed whether we should instead support vectorization in these functions for consistency because `ifelse` in base R is vectorized but I cannot foresee any scenarios these functions will want to be vectorized in SparkR. For reference, added test cases which trigger failures: ```r . Error: when(), otherwise() and ifelse() with column on a DataFrame ---------- error in evaluating the argument 'x' in selecting a method for function 'collect': error in evaluating the argument 'col' in selecting a method for function 'select': attempt to replicate an object of type 'environment' Calls: when -> when -> ifelse -> ifelse 1: withCallingHandlers(eval(code, new_test_environment), error = capture_calls, message = function(c) invokeRestart("muffleMessage")) 2: eval(code, new_test_environment) 3: eval(expr, envir, enclos) 4: expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) at test_sparkSQL.R:1126 5: expect_that(object, equals(expected, label = expected.label, ...), info = info, label = label) 6: condition(object) 7: compare(actual, expected, ...) 8: collect(select(df, when(df$a > 1 & df$b > 2, lit(1)))) Error: Test failures Execution halted ``` Author: Forest Fang Closes #10481 from saurfang/spark-12526. --- R/pkg/R/column.R | 4 ++-- R/pkg/R/functions.R | 13 ++++++++----- R/pkg/inst/tests/testthat/test_sparkSQL.R | 8 ++++++++ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 7bb8ef2595b59..356bcee3cf5c6 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -215,7 +215,7 @@ setMethod("%in%", #' otherwise #' -#' If values in the specified column are null, returns the value. +#' If values in the specified column are null, returns the value. #' Can be used in conjunction with `when` to specify a default value for expressions. #' #' @rdname otherwise @@ -225,7 +225,7 @@ setMethod("%in%", setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 09e4e04335a33..df36bc869acb4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -37,7 +37,7 @@ setMethod("lit", signature("ANY"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "lit", - ifelse(class(x) == "Column", x@jc, x)) + if (class(x) == "Column") { x@jc } else { x }) column(jc) }) @@ -2262,7 +2262,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) column(jc) }) @@ -2277,13 +2277,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @name ifelse #' @seealso \link{when} #' @export -#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} +#' @examples \dontrun{ +#' ifelse(df$a > 1 & df$b > 2, 0, 1) +#' ifelse(df$a > 1, df$a, 1) +#' } setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { test <- test@jc - yes <- ifelse(class(yes) == "Column", yes@jc, yes) - no <- ifelse(class(no) == "Column", no@jc, no) + yes <- if (class(yes) == "Column") { yes@jc } else { yes } + no <- if (class(no) == "Column") { no@jc } else { no } jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", "when", test, yes), diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 135c7576e5291..c2b6adbe3ae01 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1120,6 +1120,14 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) +test_that("when(), otherwise() and ifelse() with column on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) +}) + test_that("group by, agg functions", { df <- read.json(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") From 8e629b10cb5167926356e2f23d3c35610aa87ffe Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 29 Dec 2015 10:35:23 -0800 Subject: [PATCH 1368/2089] [SPARK-12530][BUILD] Fix build break at Spark-Master-Maven-Snapshots from #1293 Compilation error caused due to string concatenations that are not a constant Use raw string literal to avoid string concatenations https://amplab.cs.berkeley.edu/jenkins/view/Spark-Packaging/job/Spark-Master-Maven-Snapshots/1293/ Author: Kazuaki Ishizaki Closes #10488 from kiszk/SPARK-12530. --- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 97f276d49f08c..d0ec99b2320df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -57,9 +57,10 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput * the hash length is not one of the permitted values, the return value is NULL. */ @ExpressionDescription( - usage = "_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the " + - "input. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent " + - "to 256", + usage = + """_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the input. + SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.""" + , extended = "> SELECT _FUNC_('Spark', 0);\n " + "'529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b'") case class Sha2(left: Expression, right: Expression) From f6ecf143335d734b8f22c59649c6bbd4d5401745 Mon Sep 17 00:00:00 2001 From: Hossein Date: Tue, 29 Dec 2015 11:44:20 -0800 Subject: [PATCH 1369/2089] [SPARK-11199][SPARKR] Improve R context management story and add getOrCreate * Changes api.r.SQLUtils to use ```SQLContext.getOrCreate``` instead of creating a new context. * Adds a simple test [SPARK-11199] #comment link with JIRA Author: Hossein Closes #9185 from falaki/SPARK-11199. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ .../src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index c2b6adbe3ae01..7b508b860efb2 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -62,6 +62,10 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) +test_that("calling sparkRSQL.init returns existing SQL context", { + expect_equal(sparkRSQL.init(sc), sqlContext) +}) + test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index b3f134614c6bb..67da7b808bf59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -32,7 +32,7 @@ private[r] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) def createSQLContext(jsc: JavaSparkContext): SQLContext = { - new SQLContext(jsc) + SQLContext.getOrCreate(jsc.sc) } def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { From be86268eb54e3fa0a9ce7a07359b3e67731ed8b5 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 29 Dec 2015 16:32:26 -0800 Subject: [PATCH 1370/2089] [SPARK-12349][SPARK-12349][ML] Fix typo in Spark version regex introduced in / PR 10327 Sorry jkbradley Ref: https://github.com/apache/spark/pull/10327#discussion_r48502942 Author: Sean Owen Closes #10508 from srowen/SPARK-12349.2. --- mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 759be813eea6b..f653798b46043 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -180,7 +180,7 @@ object PCAModel extends MLReadable[PCAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) // explainedVariance field is not present in Spark <= 1.6 - val versionRegex = "([0-9]+)\\.([0-9])+.*".r + val versionRegex = "([0-9]+)\\.([0-9]+).*".r val hasExplainedVariance = metadata.sparkVersion match { case versionRegex(major, minor) => (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)) From 270a659584b6c1c304a9f9a331c56287672e00b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 29 Dec 2015 16:58:23 -0800 Subject: [PATCH 1371/2089] [SPARK-12549][SQL] Take Option[Seq[DataType]] in UDF input type specification. In Spark we allow UDFs to declare its expected input types in order to apply type coercion. The expected input type parameter takes a Seq[DataType] and uses Nil when no type coercion is applied. It makes more sense to take Option[Seq[DataType]] instead, so we can differentiate a no-arg function vs function with no expected input type specified. Author: Reynold Xin Closes #10504 from rxin/SPARK-12549. --- .../sql/catalyst/expressions/ScalaUDF.scala | 12 ++- .../apache/spark/sql/UDFRegistration.scala | 96 +++++++++---------- .../spark/sql/UserDefinedFunction.scala | 4 +- .../datasources/WriterContainer.scala | 5 +- .../org/apache/spark/sql/functions.scala | 26 ++--- 5 files changed, 75 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 85faa19bbf5ec..64d397bf848a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -30,7 +30,10 @@ import org.apache.spark.sql.types.DataType * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. * @param children The input expressions of this UDF. - * @param inputTypes The expected input types of this UDF. + * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do + * not want to perform coercion, simply use "Nil". Note that it would've been + * better to use Option of Seq[DataType] so we can use "None" as the case for no + * type coercion. However, that would require more refactoring of the codebase. */ case class ScalaUDF( function: AnyRef, @@ -43,7 +46,7 @@ case class ScalaUDF( override def toString: String = s"UDF(${children.mkString(",")})" - // scalastyle:off + // scalastyle:off line.size.limit /** This method has been generated by this script @@ -969,7 +972,7 @@ case class ScalaUDF( } } - // scalastyle:on + // scalastyle:on line.size.limit // Generate codes used to convert the arguments to Scala type for user-defined funtions private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { @@ -1010,7 +1013,7 @@ case class ScalaUDF( // This must be called before children expressions' codegen // because ctx.references is used in genCodeForConverter - val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _)) + val converterTerms = children.indices.map(genCodeForConverter(ctx, _)) // Initialize user-defined function val funcClassName = s"scala.Function${children.size}" @@ -1054,5 +1057,6 @@ case class ScalaUDF( } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 42c373ea723dc..f87a88d49744f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -83,8 +83,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try($inputTypes).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try($inputTypes).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) }""") @@ -115,8 +115,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -128,8 +128,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -141,8 +141,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -154,8 +154,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -167,8 +167,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -180,8 +180,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -193,8 +193,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -206,8 +206,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -219,8 +219,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -232,8 +232,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -245,8 +245,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -258,8 +258,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -271,8 +271,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -284,8 +284,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -297,8 +297,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -310,8 +310,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -323,8 +323,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -336,8 +336,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -349,8 +349,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -362,8 +362,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -375,8 +375,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -388,8 +388,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -401,8 +401,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 0f8cd280b5acb..2fb3bf07aa60b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -44,10 +44,10 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Seq[DataType] = Nil) { + inputTypes: Option[Seq[DataType]]) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index ad55367258890..983f4df1de369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -332,7 +332,10 @@ private[sql] class DynamicPartitionWriterContainer( val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( - PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + PartitioningUtils.escapePathName _, + StringType, + Seq(Cast(c, StringType)), + Seq(StringType)) val str = If(IsNull(c), Literal(defaultPartitionName), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 487191638f5c3..97c5aed6da9c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2529,7 +2529,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val inputTypes = Try($inputTypes).getOrElse(Nil) + val inputTypes = Try($inputTypes).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) }""") } @@ -2549,7 +2549,7 @@ object functions extends LegacyFunctions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr { - ScalaUDF(f, returnType, Seq($argsInUDF)) + ScalaUDF(f, returnType, Option(Seq($argsInUDF))) }""") } */ @@ -2561,7 +2561,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - val inputTypes = Try(Nil).getOrElse(Nil) + val inputTypes = Try(Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2573,7 +2573,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2585,7 +2585,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2597,7 +2597,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2609,7 +2609,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2621,7 +2621,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2633,7 +2633,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2645,7 +2645,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2657,7 +2657,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2669,7 +2669,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2681,7 +2681,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } From b600bccf41a7b1958e33d8301a19214e6517e388 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 29 Dec 2015 18:47:41 -0800 Subject: [PATCH 1372/2089] [SPARK-12362][SQL][WIP] Inline Hive Parser This is a WIP. The PR has been taken over from nongli (see https://github.com/apache/spark/pull/10420). I have removed some additional dead code, and fixed a few issues which were caused by the fact that the inlined Hive parser is newer than the Hive parser we currently use in Spark. I am submitting this PR in order to get some feedback and testing done. There is quite a bit of work to do: - [ ] Get it to pass jenkins build/test. - [ ] Aknowledge Hive-project for using their parser. - [ ] Refactorings between HiveQl and the java classes. - [ ] Create our own ASTNode and integrate the current implicit extentions. - [ ] Move remaining ```SemanticAnalyzer``` and ```ParseUtils``` functionality to ```HiveQl```. - [ ] Removing Hive dependencies from the parser. This will require some edits in the grammar files. - [ ] Introduce our own context which needs to contain a ```TokenRewriteStream```. - [ ] Add ```useSQL11ReservedKeywordsForIdentifier``` and ```allowQuotedId``` to the catalyst or sql configuration. - [ ] Remove ```HiveConf``` from grammar files &HiveQl, and pass in our own configuration. - [ ] Moving the parser into sql/core. cc nongli rxin Author: Herman van Hovell Author: Nong Li Author: Nong Li Closes #10509 from hvanhovell/SPARK-12362. --- pom.xml | 5 + project/SparkBuild.scala | 2 +- project/plugins.sbt | 4 + .../execution/HiveCompatibilitySuite.scala | 10 +- sql/hive/pom.xml | 22 + .../spark/sql/parser/FromClauseParser.g | 330 +++ .../spark/sql/parser/IdentifiersParser.g | 697 +++++ .../spark/sql/parser/SelectClauseParser.g | 226 ++ .../apache/spark/sql/parser/SparkSqlLexer.g | 474 ++++ .../apache/spark/sql/parser/SparkSqlParser.g | 2457 +++++++++++++++++ .../apache/spark/sql/parser/ASTErrorNode.java | 49 + .../org/apache/spark/sql/parser/ASTNode.java | 245 ++ .../apache/spark/sql/parser/ParseDriver.java | 213 ++ .../apache/spark/sql/parser/ParseError.java | 54 + .../spark/sql/parser/ParseException.java | 51 + .../apache/spark/sql/parser/ParseUtils.java | 96 + .../spark/sql/parser/SemanticAnalyzer.java | 406 +++ .../org/apache/spark/sql/hive/HiveQl.scala | 133 +- 18 files changed, 5402 insertions(+), 72 deletions(-) create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java diff --git a/pom.xml b/pom.xml index 284c219519bca..73ba8d555a90c 100644 --- a/pom.xml +++ b/pom.xml @@ -1951,6 +1951,11 @@ + + org.antlr + antlr3-maven-plugin + 3.5.2 + org.apache.maven.plugins diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c3d53f835f395..df21d3eb636f0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -415,7 +415,7 @@ object Hive { // in order to generate golden files. This is only required for developers who are adding new // new query tests. fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } - ) + ) ++ sbtantlr.SbtAntlrPlugin.antlrSettings } diff --git a/project/plugins.sbt b/project/plugins.sbt index 5e23224cf8aa5..f172dc9c1f0e3 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -4,6 +4,8 @@ resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/release resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" +resolvers += "stefri" at "http://stefri.github.io/repo/releases" + addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") @@ -24,6 +26,8 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") +addSbtPlugin("com.github.stefri" % "sbt-antlr" % "0.5.3") + libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2d0d7b8af3581..2b0e48dbfcf28 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -308,7 +308,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The difference between the double numbers generated by Hive and Spark // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) - "udaf_corr" + "udaf_corr", + + // Feature removed in HIVE-11145 + "alter_partition_protect_mode", + "drop_partitions_ignore_protection", + "protectmode" ) /** @@ -328,7 +333,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_index", "alter_merge_2", "alter_partition_format_loc", - "alter_partition_protect_mode", "alter_partition_with_whitelist", "alter_rename_partition", "alter_table_serde", @@ -460,7 +464,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_filter", "drop_partitions_filter2", "drop_partitions_filter3", - "drop_partitions_ignore_protection", "drop_table", "drop_table2", "drop_table_removes_partition_dirs", @@ -778,7 +781,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", - "protectmode", "push_or", "query_with_semi", "quote1", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e9885f6682028..ffabb92179a18 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -232,6 +232,7 @@ v${hive.version.short}/src/main/scala + ${project.build.directory/generated-sources/antlr @@ -260,6 +261,27 @@ + + + + org.antlr + antlr3-maven-plugin + + + + antlr + + + + + ${basedir}/src/main/antlr3 + + **/SparkSqlLexer.g + **/SparkSqlParser.g + + + + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g new file mode 100644 index 0000000000000..e4a80f0ce8ebf --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g @@ -0,0 +1,330 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar FromClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +tableAllColumns + : STAR + -> ^(TOK_ALLCOLREF) + | tableName DOT STAR + -> ^(TOK_ALLCOLREF tableName) + ; + +// (table|column) +tableOrColumn +@init { gParent.pushMsg("table or column identifier", state); } +@after { gParent.popMsg(state); } + : + identifier -> ^(TOK_TABLE_OR_COL identifier) + ; + +expressionList +@init { gParent.pushMsg("expression list", state); } +@after { gParent.popMsg(state); } + : + expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) + ; + +aliasList +@init { gParent.pushMsg("alias list", state); } +@after { gParent.popMsg(state); } + : + identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) + ; + +//----------------------- Rules for parsing fromClause ------------------------------ +// from [col1, col2, col3] table1, [col4, col5] table2 +fromClause +@init { gParent.pushMsg("from clause", state); } +@after { gParent.popMsg(state); } + : + KW_FROM joinSource -> ^(TOK_FROM joinSource) + ; + +joinSource +@init { gParent.pushMsg("join source", state); } +@after { gParent.popMsg(state); } + : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )* + | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ + ; + +uniqueJoinSource +@init { gParent.pushMsg("unique join source", state); } +@after { gParent.popMsg(state); } + : KW_PRESERVE? fromSource uniqueJoinExpr + ; + +uniqueJoinExpr +@init { gParent.pushMsg("unique join expression list", state); } +@after { gParent.popMsg(state); } + : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN + -> ^(TOK_EXPLIST $e1*) + ; + +uniqueJoinToken +@init { gParent.pushMsg("unique join", state); } +@after { gParent.popMsg(state); } + : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; + +joinToken +@init { gParent.pushMsg("join type specifier", state); } +@after { gParent.popMsg(state); } + : + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + ; + +lateralView +@init {gParent.pushMsg("lateral view", state); } +@after {gParent.popMsg(state); } + : + (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + | + KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + ; + +tableAlias +@init {gParent.pushMsg("table alias", state); } +@after {gParent.popMsg(state); } + : + identifier -> ^(TOK_TABALIAS identifier) + ; + +fromSource +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + (LPAREN KW_VALUES) => fromSource0 + | (LPAREN) => LPAREN joinSource RPAREN -> joinSource + | fromSource0 + ; + + +fromSource0 +@init { gParent.pushMsg("from source 0", state); } +@after { gParent.popMsg(state); } + : + ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* + ; + +tableBucketSample +@init { gParent.pushMsg("table bucket sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) + ; + +splitSample +@init { gParent.pushMsg("table split sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN + -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) + -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) + | + KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN + -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) + ; + +tableSample +@init { gParent.pushMsg("table sample specification", state); } +@after { gParent.popMsg(state); } + : + tableBucketSample | + splitSample + ; + +tableSource +@init { gParent.pushMsg("table source", state); } +@after { gParent.popMsg(state); } + : tabname=tableName + ((tableProperties) => props=tableProperties)? + ((tableSample) => ts=tableSample)? + ((KW_AS) => (KW_AS alias=Identifier) + | + (Identifier) => (alias=Identifier))? + -> ^(TOK_TABREF $tabname $props? $ts? $alias?) + ; + +tableName +@init { gParent.pushMsg("table name", state); } +@after { gParent.popMsg(state); } + : + db=identifier DOT tab=identifier + -> ^(TOK_TABNAME $db $tab) + | + tab=identifier + -> ^(TOK_TABNAME $tab) + ; + +viewName +@init { gParent.pushMsg("view name", state); } +@after { gParent.popMsg(state); } + : + (db=identifier DOT)? view=identifier + -> ^(TOK_TABNAME $db? $view) + ; + +subQuerySource +@init { gParent.pushMsg("subquery source", state); } +@after { gParent.popMsg(state); } + : + LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) + ; + +//---------------------- Rules for parsing PTF clauses ----------------------------- +partitioningSpec +@init { gParent.pushMsg("partitioningSpec clause", state); } +@after { gParent.popMsg(state); } + : + partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | + orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | + distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) | + sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | + clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) + ; + +partitionTableFunctionSource +@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } +@after { gParent.popMsg(state); } + : + subQuerySource | + tableSource | + partitionedTableFunction + ; + +partitionedTableFunction +@init { gParent.pushMsg("ptf clause", state); } +@after { gParent.popMsg(state); } + : + name=Identifier LPAREN KW_ON + ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) + ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? + ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? + -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) + ; + +//----------------------- Rules for parsing whereClause ----------------------------- +// where a=b and ... +whereClause +@init { gParent.pushMsg("where clause", state); } +@after { gParent.popMsg(state); } + : + KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) + ; + +searchCondition +@init { gParent.pushMsg("search condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +//----------------------------------------------------------------------------------- + +//-------- Row Constructor ---------------------------------------------------------- +//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and +// INSERT INTO (col1,col2,...) VALUES(...),(...),... +// INSERT INTO
    (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) +valueRowConstructor +@init { gParent.pushMsg("value row constructor", state); } +@after { gParent.popMsg(state); } + : + LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) + ; + +valuesTableConstructor +@init { gParent.pushMsg("values table constructor", state); } +@after { gParent.popMsg(state); } + : + valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) + ; + +/* +VALUES(1),(2) means 2 rows, 1 column each. +VALUES(1,2),(3,4) means 2 rows, 2 columns each. +VALUES(1,2,3) means 1 row, 3 columns +*/ +valuesClause +@init { gParent.pushMsg("values clause", state); } +@after { gParent.popMsg(state); } + : + KW_VALUES valuesTableConstructor -> valuesTableConstructor + ; + +/* +This represents a clause like this: +(VALUES(1,2),(2,3)) as VirtTable(col1,col2) +*/ +virtualTableSource +@init { gParent.pushMsg("virtual table source", state); } +@after { gParent.popMsg(state); } + : + LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) + ; +/* +e.g. as VirtTable(col1,col2) +Note that we only want literals as column names +*/ +tableNameColList +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) + ; + +//----------------------------------------------------------------------------------- diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g new file mode 100644 index 0000000000000..5c3d7ef866240 --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g @@ -0,0 +1,697 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar IdentifiersParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +// group by a,b +groupByClause +@init { gParent.pushMsg("group by clause", state); } +@after { gParent.popMsg(state); } + : + KW_GROUP KW_BY + expression + ( COMMA expression)* + ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? + (sets=KW_GROUPING KW_SETS + LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? + -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) + -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) + -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) + -> ^(TOK_GROUPBY expression+) + ; + +groupingSetExpression +@init {gParent.pushMsg("grouping set expression", state); } +@after {gParent.popMsg(state); } + : + (LPAREN) => groupingSetExpressionMultiple + | + groupingExpressionSingle + ; + +groupingSetExpressionMultiple +@init {gParent.pushMsg("grouping set part expression", state); } +@after {gParent.popMsg(state); } + : + LPAREN + expression? (COMMA expression)* + RPAREN + -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) + ; + +groupingExpressionSingle +@init { gParent.pushMsg("groupingExpression expression", state); } +@after { gParent.popMsg(state); } + : + expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) + ; + +havingClause +@init { gParent.pushMsg("having clause", state); } +@after { gParent.popMsg(state); } + : + KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) + ; + +havingCondition +@init { gParent.pushMsg("having condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +expressionsInParenthese + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +expressionsNotInParenthese + : + expression (COMMA expression)* -> expression+ + ; + +columnRefOrderInParenthese + : + LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ + ; + +columnRefOrderNotInParenthese + : + columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ + ; + +// order by a,b +orderByClause +@init { gParent.pushMsg("order by clause", state); } +@after { gParent.popMsg(state); } + : + KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) + ; + +clusterByClause +@init { gParent.pushMsg("cluster by clause", state); } +@after { gParent.popMsg(state); } + : + KW_CLUSTER KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) + ) + ; + +partitionByClause +@init { gParent.pushMsg("partition by clause", state); } +@after { gParent.popMsg(state); } + : + KW_PARTITION KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +distributeByClause +@init { gParent.pushMsg("distribute by clause", state); } +@after { gParent.popMsg(state); } + : + KW_DISTRIBUTE KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +sortByClause +@init { gParent.pushMsg("sort by clause", state); } +@after { gParent.popMsg(state); } + : + KW_SORT KW_BY + ( + (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) + | + columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) + ) + ; + +// fun(par1, par2, par3) +function +@init { gParent.pushMsg("function specification", state); } +@after { gParent.popMsg(state); } + : + functionName + LPAREN + ( + (STAR) => (star=STAR) + | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? + ) + RPAREN (KW_OVER ws=window_specification)? + -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) + -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) + -> ^(TOK_FUNCTIONDI functionName (selectExpression+)?) + ; + +functionName +@init { gParent.pushMsg("function name", state); } +@after { gParent.popMsg(state); } + : // Keyword IF is also a function name + (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) + | + (functionIdentifier) => functionIdentifier + | + {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] + ; + +castExpression +@init { gParent.pushMsg("cast expression", state); } +@after { gParent.popMsg(state); } + : + KW_CAST + LPAREN + expression + KW_AS + primitiveType + RPAREN -> ^(TOK_FUNCTION primitiveType expression) + ; + +caseExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE expression + (KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_CASE expression*) + ; + +whenExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE + ( KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) + ; + +constant +@init { gParent.pushMsg("constant", state); } +@after { gParent.popMsg(state); } + : + Number + | dateLiteral + | timestampLiteral + | intervalLiteral + | StringLiteral + | stringLiteralSequence + | BigintLiteral + | SmallintLiteral + | TinyintLiteral + | DecimalLiteral + | charSetStringLiteral + | booleanValue + ; + +stringLiteralSequence + : + StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) + ; + +charSetStringLiteral +@init { gParent.pushMsg("character string literal", state); } +@after { gParent.popMsg(state); } + : + csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) + ; + +dateLiteral + : + KW_DATE StringLiteral -> + { + // Create DateLiteral token, but with the text of the string value + // This makes the dateLiteral more consistent with the other type literals. + adaptor.create(TOK_DATELITERAL, $StringLiteral.text) + } + | + KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) + ; + +timestampLiteral + : + KW_TIMESTAMP StringLiteral -> + { + adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) + } + | + KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) + ; + +intervalLiteral + : + KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> + { + adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + } + ; + +intervalQualifiers + : + KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL + | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL + | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL + | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL + | KW_DAY -> TOK_INTERVAL_DAY_LITERAL + | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL + | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL + | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + ; + +expression +@init { gParent.pushMsg("expression specification", state); } +@after { gParent.popMsg(state); } + : + precedenceOrExpression + ; + +atomExpression + : + (KW_NULL) => KW_NULL -> TOK_NULL + | (constant) => constant + | castExpression + | caseExpression + | whenExpression + | (functionName LPAREN) => function + | tableOrColumn + | LPAREN! expression RPAREN! + ; + + +precedenceFieldExpression + : + atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* + ; + +precedenceUnaryOperator + : + PLUS | MINUS | TILDE + ; + +nullCondition + : + KW_NULL -> ^(TOK_ISNULL) + | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) + ; + +precedenceUnaryPrefixExpression + : + (precedenceUnaryOperator^)* precedenceFieldExpression + ; + +precedenceUnarySuffixExpression + : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) + -> precedenceUnaryPrefixExpression + ; + + +precedenceBitwiseXorOperator + : + BITWISEXOR + ; + +precedenceBitwiseXorExpression + : + precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* + ; + + +precedenceStarOperator + : + STAR | DIVIDE | MOD | DIV + ; + +precedenceStarExpression + : + precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* + ; + + +precedencePlusOperator + : + PLUS | MINUS + ; + +precedencePlusExpression + : + precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* + ; + + +precedenceAmpersandOperator + : + AMPERSAND + ; + +precedenceAmpersandExpression + : + precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* + ; + + +precedenceBitwiseOrOperator + : + BITWISEOR + ; + +precedenceBitwiseOrExpression + : + precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* + ; + + +// Equal operators supporting NOT prefix +precedenceEqualNegatableOperator + : + KW_LIKE | KW_RLIKE | KW_REGEXP + ; + +precedenceEqualOperator + : + precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +subQueryExpression + : + LPAREN! selectStatement[true] RPAREN! + ; + +precedenceEqualExpression + : + (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple + | + precedenceEqualExpressionSingle + ; + +precedenceEqualExpressionSingle + : + (left=precedenceBitwiseOrExpression -> $left) + ( + (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) + -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) + | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) + -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) + | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) + -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) + | (KW_NOT KW_IN expressions) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) + | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) + | (KW_IN expressions) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) + | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) + | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) + )* + | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) + ; + +expressions + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) +precedenceEqualExpressionMutiple + : + (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) + ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) + | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) + ; + +expressionsToStruct + : + LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) + ; + +precedenceNotOperator + : + KW_NOT + ; + +precedenceNotExpression + : + (precedenceNotOperator^)* precedenceEqualExpression + ; + + +precedenceAndOperator + : + KW_AND + ; + +precedenceAndExpression + : + precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* + ; + + +precedenceOrOperator + : + KW_OR + ; + +precedenceOrExpression + : + precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* + ; + + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) + ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : db=identifier DOT fn=identifier + -> Identifier[$db.text + "." + $fn.text] + | + identifier + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. +//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. +sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g new file mode 100644 index 0000000000000..48bc8b0a300af --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g @@ -0,0 +1,226 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar SelectClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------- Rules for parsing selectClause ----------------------------- +// select a,b,c ... +selectClause +@init { gParent.pushMsg("select clause", state); } +@after { gParent.popMsg(state); } + : + KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) + | (transform=KW_TRANSFORM selectTrfmClause)) + -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) + -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) + -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) ) + | + trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) + ; + +selectList +@init { gParent.pushMsg("select list", state); } +@after { gParent.popMsg(state); } + : + selectItem ( COMMA selectItem )* -> selectItem+ + ; + +selectTrfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + LPAREN selectExpressionList RPAREN + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +hintClause +@init { gParent.pushMsg("hint clause", state); } +@after { gParent.popMsg(state); } + : + DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) + ; + +hintList +@init { gParent.pushMsg("hint list", state); } +@after { gParent.popMsg(state); } + : + hintItem (COMMA hintItem)* -> hintItem+ + ; + +hintItem +@init { gParent.pushMsg("hint item", state); } +@after { gParent.popMsg(state); } + : + hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) + ; + +hintName +@init { gParent.pushMsg("hint name", state); } +@after { gParent.popMsg(state); } + : + KW_MAPJOIN -> TOK_MAPJOIN + | KW_STREAMTABLE -> TOK_STREAMTABLE + ; + +hintArgs +@init { gParent.pushMsg("hint arguments", state); } +@after { gParent.popMsg(state); } + : + hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) + ; + +hintArgName +@init { gParent.pushMsg("hint argument name", state); } +@after { gParent.popMsg(state); } + : + identifier + ; + +selectItem +@init { gParent.pushMsg("selection target", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) + | + ( expression + ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? + ) -> ^(TOK_SELEXPR expression identifier*) + ; + +trfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + ( KW_MAP selectExpressionList + | KW_REDUCE selectExpressionList ) + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +selectExpression +@init { gParent.pushMsg("select expression", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns + | + expression + ; + +selectExpressionList +@init { gParent.pushMsg("select expression list", state); } +@after { gParent.popMsg(state); } + : + selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) + ; + +//---------------------- Rules for windowing clauses ------------------------------- +window_clause +@init { gParent.pushMsg("window_clause", state); } +@after { gParent.popMsg(state); } +: + KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) +; + +window_defn +@init { gParent.pushMsg("window_defn", state); } +@after { gParent.popMsg(state); } +: + Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) +; + +window_specification +@init { gParent.pushMsg("window_specification", state); } +@after { gParent.popMsg(state); } +: + (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) +; + +window_frame : + window_range_expression | + window_value_expression +; + +window_range_expression +@init { gParent.pushMsg("window_range_expression", state); } +@after { gParent.popMsg(state); } +: + KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | + KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) +; + +window_value_expression +@init { gParent.pushMsg("window_value_expression", state); } +@after { gParent.popMsg(state); } +: + KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | + KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) +; + +window_frame_start_boundary +@init { gParent.pushMsg("windowframestartboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number KW_PRECEDING -> ^(KW_PRECEDING Number) +; + +window_frame_boundary +@init { gParent.pushMsg("windowframeboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) +; + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g new file mode 100644 index 0000000000000..ee1b8989b5aff --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g @@ -0,0 +1,474 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +lexer grammar SparkSqlLexer; + +@lexer::header { +package org.apache.spark.sql.parser; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +} + +@lexer::members { + private Configuration hiveConf; + + public void setHiveConf(Configuration hiveConf) { + this.hiveConf = hiveConf; + } + + protected boolean allowQuotedId() { + String supportedQIds = HiveConf.getVar(hiveConf, HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT); + return !"none".equals(supportedQIds); + } +} + +// Keywords + +KW_TRUE : 'TRUE'; +KW_FALSE : 'FALSE'; +KW_ALL : 'ALL'; +KW_NONE: 'NONE'; +KW_AND : 'AND'; +KW_OR : 'OR'; +KW_NOT : 'NOT' | '!'; +KW_LIKE : 'LIKE'; + +KW_IF : 'IF'; +KW_EXISTS : 'EXISTS'; + +KW_ASC : 'ASC'; +KW_DESC : 'DESC'; +KW_ORDER : 'ORDER'; +KW_GROUP : 'GROUP'; +KW_BY : 'BY'; +KW_HAVING : 'HAVING'; +KW_WHERE : 'WHERE'; +KW_FROM : 'FROM'; +KW_AS : 'AS'; +KW_SELECT : 'SELECT'; +KW_DISTINCT : 'DISTINCT'; +KW_INSERT : 'INSERT'; +KW_OVERWRITE : 'OVERWRITE'; +KW_OUTER : 'OUTER'; +KW_UNIQUEJOIN : 'UNIQUEJOIN'; +KW_PRESERVE : 'PRESERVE'; +KW_JOIN : 'JOIN'; +KW_LEFT : 'LEFT'; +KW_RIGHT : 'RIGHT'; +KW_FULL : 'FULL'; +KW_ANTI : 'ANTI'; +KW_ON : 'ON'; +KW_PARTITION : 'PARTITION'; +KW_PARTITIONS : 'PARTITIONS'; +KW_TABLE: 'TABLE'; +KW_TABLES: 'TABLES'; +KW_COLUMNS: 'COLUMNS'; +KW_INDEX: 'INDEX'; +KW_INDEXES: 'INDEXES'; +KW_REBUILD: 'REBUILD'; +KW_FUNCTIONS: 'FUNCTIONS'; +KW_SHOW: 'SHOW'; +KW_MSCK: 'MSCK'; +KW_REPAIR: 'REPAIR'; +KW_DIRECTORY: 'DIRECTORY'; +KW_LOCAL: 'LOCAL'; +KW_TRANSFORM : 'TRANSFORM'; +KW_USING: 'USING'; +KW_CLUSTER: 'CLUSTER'; +KW_DISTRIBUTE: 'DISTRIBUTE'; +KW_SORT: 'SORT'; +KW_UNION: 'UNION'; +KW_LOAD: 'LOAD'; +KW_EXPORT: 'EXPORT'; +KW_IMPORT: 'IMPORT'; +KW_REPLICATION: 'REPLICATION'; +KW_METADATA: 'METADATA'; +KW_DATA: 'DATA'; +KW_INPATH: 'INPATH'; +KW_IS: 'IS'; +KW_NULL: 'NULL'; +KW_CREATE: 'CREATE'; +KW_EXTERNAL: 'EXTERNAL'; +KW_ALTER: 'ALTER'; +KW_CHANGE: 'CHANGE'; +KW_COLUMN: 'COLUMN'; +KW_FIRST: 'FIRST'; +KW_AFTER: 'AFTER'; +KW_DESCRIBE: 'DESCRIBE'; +KW_DROP: 'DROP'; +KW_RENAME: 'RENAME'; +KW_TO: 'TO'; +KW_COMMENT: 'COMMENT'; +KW_BOOLEAN: 'BOOLEAN'; +KW_TINYINT: 'TINYINT'; +KW_SMALLINT: 'SMALLINT'; +KW_INT: 'INT'; +KW_BIGINT: 'BIGINT'; +KW_FLOAT: 'FLOAT'; +KW_DOUBLE: 'DOUBLE'; +KW_DATE: 'DATE'; +KW_DATETIME: 'DATETIME'; +KW_TIMESTAMP: 'TIMESTAMP'; +KW_INTERVAL: 'INTERVAL'; +KW_DECIMAL: 'DECIMAL'; +KW_STRING: 'STRING'; +KW_CHAR: 'CHAR'; +KW_VARCHAR: 'VARCHAR'; +KW_ARRAY: 'ARRAY'; +KW_STRUCT: 'STRUCT'; +KW_MAP: 'MAP'; +KW_UNIONTYPE: 'UNIONTYPE'; +KW_REDUCE: 'REDUCE'; +KW_PARTITIONED: 'PARTITIONED'; +KW_CLUSTERED: 'CLUSTERED'; +KW_SORTED: 'SORTED'; +KW_INTO: 'INTO'; +KW_BUCKETS: 'BUCKETS'; +KW_ROW: 'ROW'; +KW_ROWS: 'ROWS'; +KW_FORMAT: 'FORMAT'; +KW_DELIMITED: 'DELIMITED'; +KW_FIELDS: 'FIELDS'; +KW_TERMINATED: 'TERMINATED'; +KW_ESCAPED: 'ESCAPED'; +KW_COLLECTION: 'COLLECTION'; +KW_ITEMS: 'ITEMS'; +KW_KEYS: 'KEYS'; +KW_KEY_TYPE: '$KEY$'; +KW_LINES: 'LINES'; +KW_STORED: 'STORED'; +KW_FILEFORMAT: 'FILEFORMAT'; +KW_INPUTFORMAT: 'INPUTFORMAT'; +KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; +KW_INPUTDRIVER: 'INPUTDRIVER'; +KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; +KW_ENABLE: 'ENABLE'; +KW_DISABLE: 'DISABLE'; +KW_LOCATION: 'LOCATION'; +KW_TABLESAMPLE: 'TABLESAMPLE'; +KW_BUCKET: 'BUCKET'; +KW_OUT: 'OUT'; +KW_OF: 'OF'; +KW_PERCENT: 'PERCENT'; +KW_CAST: 'CAST'; +KW_ADD: 'ADD'; +KW_REPLACE: 'REPLACE'; +KW_RLIKE: 'RLIKE'; +KW_REGEXP: 'REGEXP'; +KW_TEMPORARY: 'TEMPORARY'; +KW_FUNCTION: 'FUNCTION'; +KW_MACRO: 'MACRO'; +KW_FILE: 'FILE'; +KW_JAR: 'JAR'; +KW_EXPLAIN: 'EXPLAIN'; +KW_EXTENDED: 'EXTENDED'; +KW_FORMATTED: 'FORMATTED'; +KW_PRETTY: 'PRETTY'; +KW_DEPENDENCY: 'DEPENDENCY'; +KW_LOGICAL: 'LOGICAL'; +KW_SERDE: 'SERDE'; +KW_WITH: 'WITH'; +KW_DEFERRED: 'DEFERRED'; +KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; +KW_DBPROPERTIES: 'DBPROPERTIES'; +KW_LIMIT: 'LIMIT'; +KW_SET: 'SET'; +KW_UNSET: 'UNSET'; +KW_TBLPROPERTIES: 'TBLPROPERTIES'; +KW_IDXPROPERTIES: 'IDXPROPERTIES'; +KW_VALUE_TYPE: '$VALUE$'; +KW_ELEM_TYPE: '$ELEM$'; +KW_DEFINED: 'DEFINED'; +KW_CASE: 'CASE'; +KW_WHEN: 'WHEN'; +KW_THEN: 'THEN'; +KW_ELSE: 'ELSE'; +KW_END: 'END'; +KW_MAPJOIN: 'MAPJOIN'; +KW_STREAMTABLE: 'STREAMTABLE'; +KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; +KW_UTC: 'UTC'; +KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; +KW_LONG: 'LONG'; +KW_DELETE: 'DELETE'; +KW_PLUS: 'PLUS'; +KW_MINUS: 'MINUS'; +KW_FETCH: 'FETCH'; +KW_INTERSECT: 'INTERSECT'; +KW_VIEW: 'VIEW'; +KW_IN: 'IN'; +KW_DATABASE: 'DATABASE'; +KW_DATABASES: 'DATABASES'; +KW_MATERIALIZED: 'MATERIALIZED'; +KW_SCHEMA: 'SCHEMA'; +KW_SCHEMAS: 'SCHEMAS'; +KW_GRANT: 'GRANT'; +KW_REVOKE: 'REVOKE'; +KW_SSL: 'SSL'; +KW_UNDO: 'UNDO'; +KW_LOCK: 'LOCK'; +KW_LOCKS: 'LOCKS'; +KW_UNLOCK: 'UNLOCK'; +KW_SHARED: 'SHARED'; +KW_EXCLUSIVE: 'EXCLUSIVE'; +KW_PROCEDURE: 'PROCEDURE'; +KW_UNSIGNED: 'UNSIGNED'; +KW_WHILE: 'WHILE'; +KW_READ: 'READ'; +KW_READS: 'READS'; +KW_PURGE: 'PURGE'; +KW_RANGE: 'RANGE'; +KW_ANALYZE: 'ANALYZE'; +KW_BEFORE: 'BEFORE'; +KW_BETWEEN: 'BETWEEN'; +KW_BOTH: 'BOTH'; +KW_BINARY: 'BINARY'; +KW_CROSS: 'CROSS'; +KW_CONTINUE: 'CONTINUE'; +KW_CURSOR: 'CURSOR'; +KW_TRIGGER: 'TRIGGER'; +KW_RECORDREADER: 'RECORDREADER'; +KW_RECORDWRITER: 'RECORDWRITER'; +KW_SEMI: 'SEMI'; +KW_LATERAL: 'LATERAL'; +KW_TOUCH: 'TOUCH'; +KW_ARCHIVE: 'ARCHIVE'; +KW_UNARCHIVE: 'UNARCHIVE'; +KW_COMPUTE: 'COMPUTE'; +KW_STATISTICS: 'STATISTICS'; +KW_USE: 'USE'; +KW_OPTION: 'OPTION'; +KW_CONCATENATE: 'CONCATENATE'; +KW_SHOW_DATABASE: 'SHOW_DATABASE'; +KW_UPDATE: 'UPDATE'; +KW_RESTRICT: 'RESTRICT'; +KW_CASCADE: 'CASCADE'; +KW_SKEWED: 'SKEWED'; +KW_ROLLUP: 'ROLLUP'; +KW_CUBE: 'CUBE'; +KW_DIRECTORIES: 'DIRECTORIES'; +KW_FOR: 'FOR'; +KW_WINDOW: 'WINDOW'; +KW_UNBOUNDED: 'UNBOUNDED'; +KW_PRECEDING: 'PRECEDING'; +KW_FOLLOWING: 'FOLLOWING'; +KW_CURRENT: 'CURRENT'; +KW_CURRENT_DATE: 'CURRENT_DATE'; +KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +KW_LESS: 'LESS'; +KW_MORE: 'MORE'; +KW_OVER: 'OVER'; +KW_GROUPING: 'GROUPING'; +KW_SETS: 'SETS'; +KW_TRUNCATE: 'TRUNCATE'; +KW_NOSCAN: 'NOSCAN'; +KW_PARTIALSCAN: 'PARTIALSCAN'; +KW_USER: 'USER'; +KW_ROLE: 'ROLE'; +KW_ROLES: 'ROLES'; +KW_INNER: 'INNER'; +KW_EXCHANGE: 'EXCHANGE'; +KW_URI: 'URI'; +KW_SERVER : 'SERVER'; +KW_ADMIN: 'ADMIN'; +KW_OWNER: 'OWNER'; +KW_PRINCIPALS: 'PRINCIPALS'; +KW_COMPACT: 'COMPACT'; +KW_COMPACTIONS: 'COMPACTIONS'; +KW_TRANSACTIONS: 'TRANSACTIONS'; +KW_REWRITE : 'REWRITE'; +KW_AUTHORIZATION: 'AUTHORIZATION'; +KW_CONF: 'CONF'; +KW_VALUES: 'VALUES'; +KW_RELOAD: 'RELOAD'; +KW_YEAR: 'YEAR'; +KW_MONTH: 'MONTH'; +KW_DAY: 'DAY'; +KW_HOUR: 'HOUR'; +KW_MINUTE: 'MINUTE'; +KW_SECOND: 'SECOND'; +KW_START: 'START'; +KW_TRANSACTION: 'TRANSACTION'; +KW_COMMIT: 'COMMIT'; +KW_ROLLBACK: 'ROLLBACK'; +KW_WORK: 'WORK'; +KW_ONLY: 'ONLY'; +KW_WRITE: 'WRITE'; +KW_ISOLATION: 'ISOLATION'; +KW_LEVEL: 'LEVEL'; +KW_SNAPSHOT: 'SNAPSHOT'; +KW_AUTOCOMMIT: 'AUTOCOMMIT'; + +// Operators +// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. + +DOT : '.'; // generated as a part of Number rule +COLON : ':' ; +COMMA : ',' ; +SEMICOLON : ';' ; + +LPAREN : '(' ; +RPAREN : ')' ; +LSQUARE : '[' ; +RSQUARE : ']' ; +LCURLY : '{'; +RCURLY : '}'; + +EQUAL : '=' | '=='; +EQUAL_NS : '<=>'; +NOTEQUAL : '<>' | '!='; +LESSTHANOREQUALTO : '<='; +LESSTHAN : '<'; +GREATERTHANOREQUALTO : '>='; +GREATERTHAN : '>'; + +DIVIDE : '/'; +PLUS : '+'; +MINUS : '-'; +STAR : '*'; +MOD : '%'; +DIV : 'DIV'; + +AMPERSAND : '&'; +TILDE : '~'; +BITWISEOR : '|'; +BITWISEXOR : '^'; +QUESTION : '?'; +DOLLAR : '$'; + +// LITERALS +fragment +Letter + : 'a'..'z' | 'A'..'Z' + ; + +fragment +HexDigit + : 'a'..'f' | 'A'..'F' + ; + +fragment +Digit + : + '0'..'9' + ; + +fragment +Exponent + : + ('e' | 'E') ( PLUS|MINUS )? (Digit)+ + ; + +fragment +RegexComponent + : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' + | PLUS | STAR | QUESTION | MINUS | DOT + | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY + | BITWISEXOR | BITWISEOR | DOLLAR | '!' + ; + +StringLiteral + : + ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + )+ + ; + +CharSetLiteral + : + StringLiteral + | '0' 'X' (HexDigit|Digit)+ + ; + +BigintLiteral + : + (Digit)+ 'L' + ; + +SmallintLiteral + : + (Digit)+ 'S' + ; + +TinyintLiteral + : + (Digit)+ 'Y' + ; + +DecimalLiteral + : + Number 'B' 'D' + ; + +ByteLengthLiteral + : + (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') + ; + +Number + : + (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? + ; + +/* +An Identifier can be: +- tableName +- columnName +- select expr alias +- lateral view aliases +- database name +- view name +- subquery alias +- function name +- ptf argument identifier +- index name +- property name for: db,tbl,partition... +- fileFormat +- role name +- privilege name +- principal name +- macro name +- hint name +- window name +*/ +Identifier + : + (Letter | Digit) (Letter | Digit | '_')* + | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; + at the API level only columns are allowed to be of this form */ + | '`' RegexComponent+ '`' + ; + +fragment +QuotedIdentifier + : + '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + ; + +CharSetName + : + '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ + ; + +WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} + ; + +COMMENT + : '--' (~('\n'|'\r'))* + { $channel=HIDDEN; } + ; + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g new file mode 100644 index 0000000000000..69574d713d0be --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g @@ -0,0 +1,2457 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar SparkSqlParser; + +options +{ +tokenVocab=SparkSqlLexer; +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} +import SelectClauseParser, FromClauseParser, IdentifiersParser; + +tokens { +TOK_INSERT; +TOK_QUERY; +TOK_SELECT; +TOK_SELECTDI; +TOK_SELEXPR; +TOK_FROM; +TOK_TAB; +TOK_PARTSPEC; +TOK_PARTVAL; +TOK_DIR; +TOK_TABREF; +TOK_SUBQUERY; +TOK_INSERT_INTO; +TOK_DESTINATION; +TOK_ALLCOLREF; +TOK_TABLE_OR_COL; +TOK_FUNCTION; +TOK_FUNCTIONDI; +TOK_FUNCTIONSTAR; +TOK_WHERE; +TOK_OP_EQ; +TOK_OP_NE; +TOK_OP_LE; +TOK_OP_LT; +TOK_OP_GE; +TOK_OP_GT; +TOK_OP_DIV; +TOK_OP_ADD; +TOK_OP_SUB; +TOK_OP_MUL; +TOK_OP_MOD; +TOK_OP_BITAND; +TOK_OP_BITNOT; +TOK_OP_BITOR; +TOK_OP_BITXOR; +TOK_OP_AND; +TOK_OP_OR; +TOK_OP_NOT; +TOK_OP_LIKE; +TOK_TRUE; +TOK_FALSE; +TOK_TRANSFORM; +TOK_SERDE; +TOK_SERDENAME; +TOK_SERDEPROPS; +TOK_EXPLIST; +TOK_ALIASLIST; +TOK_GROUPBY; +TOK_ROLLUP_GROUPBY; +TOK_CUBE_GROUPBY; +TOK_GROUPING_SETS; +TOK_GROUPING_SETS_EXPRESSION; +TOK_HAVING; +TOK_ORDERBY; +TOK_CLUSTERBY; +TOK_DISTRIBUTEBY; +TOK_SORTBY; +TOK_UNIONALL; +TOK_UNIONDISTINCT; +TOK_JOIN; +TOK_LEFTOUTERJOIN; +TOK_RIGHTOUTERJOIN; +TOK_FULLOUTERJOIN; +TOK_UNIQUEJOIN; +TOK_CROSSJOIN; +TOK_LOAD; +TOK_EXPORT; +TOK_IMPORT; +TOK_REPLICATION; +TOK_METADATA; +TOK_NULL; +TOK_ISNULL; +TOK_ISNOTNULL; +TOK_TINYINT; +TOK_SMALLINT; +TOK_INT; +TOK_BIGINT; +TOK_BOOLEAN; +TOK_FLOAT; +TOK_DOUBLE; +TOK_DATE; +TOK_DATELITERAL; +TOK_DATETIME; +TOK_TIMESTAMP; +TOK_TIMESTAMPLITERAL; +TOK_INTERVAL_YEAR_MONTH; +TOK_INTERVAL_YEAR_MONTH_LITERAL; +TOK_INTERVAL_DAY_TIME; +TOK_INTERVAL_DAY_TIME_LITERAL; +TOK_INTERVAL_YEAR_LITERAL; +TOK_INTERVAL_MONTH_LITERAL; +TOK_INTERVAL_DAY_LITERAL; +TOK_INTERVAL_HOUR_LITERAL; +TOK_INTERVAL_MINUTE_LITERAL; +TOK_INTERVAL_SECOND_LITERAL; +TOK_STRING; +TOK_CHAR; +TOK_VARCHAR; +TOK_BINARY; +TOK_DECIMAL; +TOK_LIST; +TOK_STRUCT; +TOK_MAP; +TOK_UNIONTYPE; +TOK_COLTYPELIST; +TOK_CREATEDATABASE; +TOK_CREATETABLE; +TOK_TRUNCATETABLE; +TOK_CREATEINDEX; +TOK_CREATEINDEX_INDEXTBLNAME; +TOK_DEFERRED_REBUILDINDEX; +TOK_DROPINDEX; +TOK_LIKETABLE; +TOK_DESCTABLE; +TOK_DESCFUNCTION; +TOK_ALTERTABLE; +TOK_ALTERTABLE_RENAME; +TOK_ALTERTABLE_ADDCOLS; +TOK_ALTERTABLE_RENAMECOL; +TOK_ALTERTABLE_RENAMEPART; +TOK_ALTERTABLE_REPLACECOLS; +TOK_ALTERTABLE_ADDPARTS; +TOK_ALTERTABLE_DROPPARTS; +TOK_ALTERTABLE_PARTCOLTYPE; +TOK_ALTERTABLE_MERGEFILES; +TOK_ALTERTABLE_TOUCH; +TOK_ALTERTABLE_ARCHIVE; +TOK_ALTERTABLE_UNARCHIVE; +TOK_ALTERTABLE_SERDEPROPERTIES; +TOK_ALTERTABLE_SERIALIZER; +TOK_ALTERTABLE_UPDATECOLSTATS; +TOK_TABLE_PARTITION; +TOK_ALTERTABLE_FILEFORMAT; +TOK_ALTERTABLE_LOCATION; +TOK_ALTERTABLE_PROPERTIES; +TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; +TOK_ALTERTABLE_DROPPROPERTIES; +TOK_ALTERTABLE_SKEWED; +TOK_ALTERTABLE_EXCHANGEPARTITION; +TOK_ALTERTABLE_SKEWED_LOCATION; +TOK_ALTERTABLE_BUCKETS; +TOK_ALTERTABLE_CLUSTER_SORT; +TOK_ALTERTABLE_COMPACT; +TOK_ALTERINDEX_REBUILD; +TOK_ALTERINDEX_PROPERTIES; +TOK_MSCK; +TOK_SHOWDATABASES; +TOK_SHOWTABLES; +TOK_SHOWCOLUMNS; +TOK_SHOWFUNCTIONS; +TOK_SHOWPARTITIONS; +TOK_SHOW_CREATEDATABASE; +TOK_SHOW_CREATETABLE; +TOK_SHOW_TABLESTATUS; +TOK_SHOW_TBLPROPERTIES; +TOK_SHOWLOCKS; +TOK_SHOWCONF; +TOK_LOCKTABLE; +TOK_UNLOCKTABLE; +TOK_LOCKDB; +TOK_UNLOCKDB; +TOK_SWITCHDATABASE; +TOK_DROPDATABASE; +TOK_DROPTABLE; +TOK_DATABASECOMMENT; +TOK_TABCOLLIST; +TOK_TABCOL; +TOK_TABLECOMMENT; +TOK_TABLEPARTCOLS; +TOK_TABLEROWFORMAT; +TOK_TABLEROWFORMATFIELD; +TOK_TABLEROWFORMATCOLLITEMS; +TOK_TABLEROWFORMATMAPKEYS; +TOK_TABLEROWFORMATLINES; +TOK_TABLEROWFORMATNULL; +TOK_TABLEFILEFORMAT; +TOK_FILEFORMAT_GENERIC; +TOK_OFFLINE; +TOK_ENABLE; +TOK_DISABLE; +TOK_READONLY; +TOK_NO_DROP; +TOK_STORAGEHANDLER; +TOK_NOT_CLUSTERED; +TOK_NOT_SORTED; +TOK_TABCOLNAME; +TOK_TABLELOCATION; +TOK_PARTITIONLOCATION; +TOK_TABLEBUCKETSAMPLE; +TOK_TABLESPLITSAMPLE; +TOK_PERCENT; +TOK_LENGTH; +TOK_ROWCOUNT; +TOK_TMP_FILE; +TOK_TABSORTCOLNAMEASC; +TOK_TABSORTCOLNAMEDESC; +TOK_STRINGLITERALSEQUENCE; +TOK_CHARSETLITERAL; +TOK_CREATEFUNCTION; +TOK_DROPFUNCTION; +TOK_RELOADFUNCTION; +TOK_CREATEMACRO; +TOK_DROPMACRO; +TOK_TEMPORARY; +TOK_CREATEVIEW; +TOK_DROPVIEW; +TOK_ALTERVIEW; +TOK_ALTERVIEW_PROPERTIES; +TOK_ALTERVIEW_DROPPROPERTIES; +TOK_ALTERVIEW_ADDPARTS; +TOK_ALTERVIEW_DROPPARTS; +TOK_ALTERVIEW_RENAME; +TOK_VIEWPARTCOLS; +TOK_EXPLAIN; +TOK_EXPLAIN_SQ_REWRITE; +TOK_TABLESERIALIZER; +TOK_TABLEPROPERTIES; +TOK_TABLEPROPLIST; +TOK_INDEXPROPERTIES; +TOK_INDEXPROPLIST; +TOK_TABTYPE; +TOK_LIMIT; +TOK_TABLEPROPERTY; +TOK_IFEXISTS; +TOK_IFNOTEXISTS; +TOK_ORREPLACE; +TOK_HINTLIST; +TOK_HINT; +TOK_MAPJOIN; +TOK_STREAMTABLE; +TOK_HINTARGLIST; +TOK_USERSCRIPTCOLNAMES; +TOK_USERSCRIPTCOLSCHEMA; +TOK_RECORDREADER; +TOK_RECORDWRITER; +TOK_LEFTSEMIJOIN; +TOK_ANTIJOIN; +TOK_LATERAL_VIEW; +TOK_LATERAL_VIEW_OUTER; +TOK_TABALIAS; +TOK_ANALYZE; +TOK_CREATEROLE; +TOK_DROPROLE; +TOK_GRANT; +TOK_REVOKE; +TOK_SHOW_GRANT; +TOK_PRIVILEGE_LIST; +TOK_PRIVILEGE; +TOK_PRINCIPAL_NAME; +TOK_USER; +TOK_GROUP; +TOK_ROLE; +TOK_RESOURCE_ALL; +TOK_GRANT_WITH_OPTION; +TOK_GRANT_WITH_ADMIN_OPTION; +TOK_ADMIN_OPTION_FOR; +TOK_GRANT_OPTION_FOR; +TOK_PRIV_ALL; +TOK_PRIV_ALTER_METADATA; +TOK_PRIV_ALTER_DATA; +TOK_PRIV_DELETE; +TOK_PRIV_DROP; +TOK_PRIV_INDEX; +TOK_PRIV_INSERT; +TOK_PRIV_LOCK; +TOK_PRIV_SELECT; +TOK_PRIV_SHOW_DATABASE; +TOK_PRIV_CREATE; +TOK_PRIV_OBJECT; +TOK_PRIV_OBJECT_COL; +TOK_GRANT_ROLE; +TOK_REVOKE_ROLE; +TOK_SHOW_ROLE_GRANT; +TOK_SHOW_ROLES; +TOK_SHOW_SET_ROLE; +TOK_SHOW_ROLE_PRINCIPALS; +TOK_SHOWINDEXES; +TOK_SHOWDBLOCKS; +TOK_INDEXCOMMENT; +TOK_DESCDATABASE; +TOK_DATABASEPROPERTIES; +TOK_DATABASELOCATION; +TOK_DBPROPLIST; +TOK_ALTERDATABASE_PROPERTIES; +TOK_ALTERDATABASE_OWNER; +TOK_TABNAME; +TOK_TABSRC; +TOK_RESTRICT; +TOK_CASCADE; +TOK_TABLESKEWED; +TOK_TABCOLVALUE; +TOK_TABCOLVALUE_PAIR; +TOK_TABCOLVALUES; +TOK_SKEWED_LOCATIONS; +TOK_SKEWED_LOCATION_LIST; +TOK_SKEWED_LOCATION_MAP; +TOK_STOREDASDIRS; +TOK_PARTITIONINGSPEC; +TOK_PTBLFUNCTION; +TOK_WINDOWDEF; +TOK_WINDOWSPEC; +TOK_WINDOWVALUES; +TOK_WINDOWRANGE; +TOK_SUBQUERY_EXPR; +TOK_SUBQUERY_OP; +TOK_SUBQUERY_OP_NOTIN; +TOK_SUBQUERY_OP_NOTEXISTS; +TOK_DB_TYPE; +TOK_TABLE_TYPE; +TOK_CTE; +TOK_ARCHIVE; +TOK_FILE; +TOK_JAR; +TOK_RESOURCE_URI; +TOK_RESOURCE_LIST; +TOK_SHOW_COMPACTIONS; +TOK_SHOW_TRANSACTIONS; +TOK_DELETE_FROM; +TOK_UPDATE_TABLE; +TOK_SET_COLUMNS_CLAUSE; +TOK_VALUE_ROW; +TOK_VALUES_TABLE; +TOK_VIRTUAL_TABLE; +TOK_VIRTUAL_TABREF; +TOK_ANONYMOUS; +TOK_COL_NAME; +TOK_URI_TYPE; +TOK_SERVER_TYPE; +TOK_START_TRANSACTION; +TOK_ISOLATION_LEVEL; +TOK_ISOLATION_SNAPSHOT; +TOK_TXN_ACCESS_MODE; +TOK_TXN_READ_ONLY; +TOK_TXN_READ_WRITE; +TOK_COMMIT; +TOK_ROLLBACK; +TOK_SET_AUTOCOMMIT; +} + + +// Package headers +@header { +package org.apache.spark.sql.parser; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +} + + +@members { + ArrayList errors = new ArrayList(); + Stack msgs = new Stack(); + + private static HashMap xlateMap; + static { + //this is used to support auto completion in CLI + xlateMap = new HashMap(); + + // Keywords + xlateMap.put("KW_TRUE", "TRUE"); + xlateMap.put("KW_FALSE", "FALSE"); + xlateMap.put("KW_ALL", "ALL"); + xlateMap.put("KW_NONE", "NONE"); + xlateMap.put("KW_AND", "AND"); + xlateMap.put("KW_OR", "OR"); + xlateMap.put("KW_NOT", "NOT"); + xlateMap.put("KW_LIKE", "LIKE"); + + xlateMap.put("KW_ASC", "ASC"); + xlateMap.put("KW_DESC", "DESC"); + xlateMap.put("KW_ORDER", "ORDER"); + xlateMap.put("KW_BY", "BY"); + xlateMap.put("KW_GROUP", "GROUP"); + xlateMap.put("KW_WHERE", "WHERE"); + xlateMap.put("KW_FROM", "FROM"); + xlateMap.put("KW_AS", "AS"); + xlateMap.put("KW_SELECT", "SELECT"); + xlateMap.put("KW_DISTINCT", "DISTINCT"); + xlateMap.put("KW_INSERT", "INSERT"); + xlateMap.put("KW_OVERWRITE", "OVERWRITE"); + xlateMap.put("KW_OUTER", "OUTER"); + xlateMap.put("KW_JOIN", "JOIN"); + xlateMap.put("KW_LEFT", "LEFT"); + xlateMap.put("KW_RIGHT", "RIGHT"); + xlateMap.put("KW_FULL", "FULL"); + xlateMap.put("KW_ON", "ON"); + xlateMap.put("KW_PARTITION", "PARTITION"); + xlateMap.put("KW_PARTITIONS", "PARTITIONS"); + xlateMap.put("KW_TABLE", "TABLE"); + xlateMap.put("KW_TABLES", "TABLES"); + xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_SHOW", "SHOW"); + xlateMap.put("KW_MSCK", "MSCK"); + xlateMap.put("KW_DIRECTORY", "DIRECTORY"); + xlateMap.put("KW_LOCAL", "LOCAL"); + xlateMap.put("KW_TRANSFORM", "TRANSFORM"); + xlateMap.put("KW_USING", "USING"); + xlateMap.put("KW_CLUSTER", "CLUSTER"); + xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); + xlateMap.put("KW_SORT", "SORT"); + xlateMap.put("KW_UNION", "UNION"); + xlateMap.put("KW_LOAD", "LOAD"); + xlateMap.put("KW_DATA", "DATA"); + xlateMap.put("KW_INPATH", "INPATH"); + xlateMap.put("KW_IS", "IS"); + xlateMap.put("KW_NULL", "NULL"); + xlateMap.put("KW_CREATE", "CREATE"); + xlateMap.put("KW_EXTERNAL", "EXTERNAL"); + xlateMap.put("KW_ALTER", "ALTER"); + xlateMap.put("KW_DESCRIBE", "DESCRIBE"); + xlateMap.put("KW_DROP", "DROP"); + xlateMap.put("KW_RENAME", "RENAME"); + xlateMap.put("KW_TO", "TO"); + xlateMap.put("KW_COMMENT", "COMMENT"); + xlateMap.put("KW_BOOLEAN", "BOOLEAN"); + xlateMap.put("KW_TINYINT", "TINYINT"); + xlateMap.put("KW_SMALLINT", "SMALLINT"); + xlateMap.put("KW_INT", "INT"); + xlateMap.put("KW_BIGINT", "BIGINT"); + xlateMap.put("KW_FLOAT", "FLOAT"); + xlateMap.put("KW_DOUBLE", "DOUBLE"); + xlateMap.put("KW_DATE", "DATE"); + xlateMap.put("KW_DATETIME", "DATETIME"); + xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); + xlateMap.put("KW_STRING", "STRING"); + xlateMap.put("KW_BINARY", "BINARY"); + xlateMap.put("KW_ARRAY", "ARRAY"); + xlateMap.put("KW_MAP", "MAP"); + xlateMap.put("KW_REDUCE", "REDUCE"); + xlateMap.put("KW_PARTITIONED", "PARTITIONED"); + xlateMap.put("KW_CLUSTERED", "CLUSTERED"); + xlateMap.put("KW_SORTED", "SORTED"); + xlateMap.put("KW_INTO", "INTO"); + xlateMap.put("KW_BUCKETS", "BUCKETS"); + xlateMap.put("KW_ROW", "ROW"); + xlateMap.put("KW_FORMAT", "FORMAT"); + xlateMap.put("KW_DELIMITED", "DELIMITED"); + xlateMap.put("KW_FIELDS", "FIELDS"); + xlateMap.put("KW_TERMINATED", "TERMINATED"); + xlateMap.put("KW_COLLECTION", "COLLECTION"); + xlateMap.put("KW_ITEMS", "ITEMS"); + xlateMap.put("KW_KEYS", "KEYS"); + xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); + xlateMap.put("KW_LINES", "LINES"); + xlateMap.put("KW_STORED", "STORED"); + xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); + xlateMap.put("KW_TEXTFILE", "TEXTFILE"); + xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); + xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); + xlateMap.put("KW_LOCATION", "LOCATION"); + xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); + xlateMap.put("KW_BUCKET", "BUCKET"); + xlateMap.put("KW_OUT", "OUT"); + xlateMap.put("KW_OF", "OF"); + xlateMap.put("KW_CAST", "CAST"); + xlateMap.put("KW_ADD", "ADD"); + xlateMap.put("KW_REPLACE", "REPLACE"); + xlateMap.put("KW_COLUMNS", "COLUMNS"); + xlateMap.put("KW_RLIKE", "RLIKE"); + xlateMap.put("KW_REGEXP", "REGEXP"); + xlateMap.put("KW_TEMPORARY", "TEMPORARY"); + xlateMap.put("KW_FUNCTION", "FUNCTION"); + xlateMap.put("KW_EXPLAIN", "EXPLAIN"); + xlateMap.put("KW_EXTENDED", "EXTENDED"); + xlateMap.put("KW_SERDE", "SERDE"); + xlateMap.put("KW_WITH", "WITH"); + xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); + xlateMap.put("KW_LIMIT", "LIMIT"); + xlateMap.put("KW_SET", "SET"); + xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); + xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); + xlateMap.put("KW_DEFINED", "DEFINED"); + xlateMap.put("KW_SUBQUERY", "SUBQUERY"); + xlateMap.put("KW_REWRITE", "REWRITE"); + xlateMap.put("KW_UPDATE", "UPDATE"); + xlateMap.put("KW_VALUES", "VALUES"); + xlateMap.put("KW_PURGE", "PURGE"); + + + // Operators + xlateMap.put("DOT", "."); + xlateMap.put("COLON", ":"); + xlateMap.put("COMMA", ","); + xlateMap.put("SEMICOLON", ");"); + + xlateMap.put("LPAREN", "("); + xlateMap.put("RPAREN", ")"); + xlateMap.put("LSQUARE", "["); + xlateMap.put("RSQUARE", "]"); + + xlateMap.put("EQUAL", "="); + xlateMap.put("NOTEQUAL", "<>"); + xlateMap.put("EQUAL_NS", "<=>"); + xlateMap.put("LESSTHANOREQUALTO", "<="); + xlateMap.put("LESSTHAN", "<"); + xlateMap.put("GREATERTHANOREQUALTO", ">="); + xlateMap.put("GREATERTHAN", ">"); + + xlateMap.put("DIVIDE", "/"); + xlateMap.put("PLUS", "+"); + xlateMap.put("MINUS", "-"); + xlateMap.put("STAR", "*"); + xlateMap.put("MOD", "\%"); + + xlateMap.put("AMPERSAND", "&"); + xlateMap.put("TILDE", "~"); + xlateMap.put("BITWISEOR", "|"); + xlateMap.put("BITWISEXOR", "^"); + xlateMap.put("CharSetLiteral", "\\'"); + } + + public static Collection getKeywords() { + return xlateMap.values(); + } + + private static String xlate(String name) { + + String ret = xlateMap.get(name); + if (ret == null) { + ret = name; + } + + return ret; + } + + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + errors.add(new ParseError(this, e, tokenNames)); + } + + @Override + public String getErrorHeader(RecognitionException e) { + String header = null; + if (e.charPositionInLine < 0 && input.LT(-1) != null) { + Token t = input.LT(-1); + header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); + } else { + header = super.getErrorHeader(e); + } + + return header; + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + // Translate the token names to something that the user can understand + String[] xlateNames = new String[tokenNames.length]; + for (int i = 0; i < tokenNames.length; ++i) { + xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); + } + + if (e instanceof NoViableAltException) { + @SuppressWarnings("unused") + NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and "(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "cannot recognize input near" + + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") + + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") + + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); + } else if (e instanceof MismatchedTokenException) { + MismatchedTokenException mte = (MismatchedTokenException) e; + msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; + } else if (e instanceof FailedPredicateException) { + FailedPredicateException fpe = (FailedPredicateException) e; + msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; + } else { + msg = super.getErrorMessage(e, xlateNames); + } + + if (msgs.size() > 0) { + msg = msg + " in " + msgs.peek(); + } + return msg; + } + + public void pushMsg(String msg, RecognizerSharedState state) { + // ANTLR generated code does not wrap the @init code wit this backtracking check, + // even if the matching @after has it. If we have parser rules with that are doing + // some lookahead with syntactic predicates this can cause the push() and pop() calls + // to become unbalanced, so make sure both push/pop check the backtracking state. + if (state.backtracking == 0) { + msgs.push(msg); + } + } + + public void popMsg(RecognizerSharedState state) { + if (state.backtracking == 0) { + Object o = msgs.pop(); + } + } + + // counter to generate unique union aliases + private int aliasCounter; + private String generateUnionAlias() { + return "_u" + (++aliasCounter); + } + private char [] excludedCharForColumnName = {'.', ':'}; + private boolean containExcludedCharForCreateTableColumnName(String input) { + for(char c : excludedCharForColumnName) { + if(input.indexOf(c)>-1) { + return true; + } + } + return false; + } + private CommonTree throwSetOpException() throws RecognitionException { + throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); + } + private CommonTree throwColumnNameException() throws RecognitionException { + throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); + } + private Configuration hiveConf; + public void setHiveConf(Configuration hiveConf) { + this.hiveConf = hiveConf; + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + if(hiveConf==null){ + return false; + } + return !HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS); + } +} + +@rulecatch { +catch (RecognitionException e) { + reportError(e); + throw e; +} +} + +// starting rule +statement + : explainStatement EOF + | execStatement EOF + ; + +explainStatement +@init { pushMsg("explain statement", state); } +@after { popMsg(state); } + : KW_EXPLAIN ( + explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) + | + KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) + ; + +explainOption +@init { msgs.push("explain option"); } +@after { msgs.pop(); } + : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION + ; + +execStatement +@init { pushMsg("statement", state); } +@after { popMsg(state); } + : queryStatementExpression[true] + | loadStatement + | exportStatement + | importStatement + | ddlStatement + | deleteStatement + | updateStatement + | sqlTransactionStatement + ; + +loadStatement +@init { pushMsg("load statement", state); } +@after { popMsg(state); } + : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) + -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) + ; + +replicationClause +@init { pushMsg("replication clause", state); } +@after { popMsg(state); } + : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN + -> ^(TOK_REPLICATION $replId $isMetadataOnly?) + ; + +exportStatement +@init { pushMsg("export statement", state); } +@after { popMsg(state); } + : KW_EXPORT + KW_TABLE (tab=tableOrPartition) + KW_TO (path=StringLiteral) + replicationClause? + -> ^(TOK_EXPORT $tab $path replicationClause?) + ; + +importStatement +@init { pushMsg("import statement", state); } +@after { popMsg(state); } + : KW_IMPORT + ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? + KW_FROM (path=StringLiteral) + tableLocation? + -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) + ; + +ddlStatement +@init { pushMsg("ddl statement", state); } +@after { popMsg(state); } + : createDatabaseStatement + | switchDatabaseStatement + | dropDatabaseStatement + | createTableStatement + | dropTableStatement + | truncateTableStatement + | alterStatement + | descStatement + | showStatement + | metastoreCheck + | createViewStatement + | dropViewStatement + | createFunctionStatement + | createMacroStatement + | createIndexStatement + | dropIndexStatement + | dropFunctionStatement + | reloadFunctionStatement + | dropMacroStatement + | analyzeStatement + | lockStatement + | unlockStatement + | lockDatabase + | unlockDatabase + | createRoleStatement + | dropRoleStatement + | (grantPrivileges) => grantPrivileges + | (revokePrivileges) => revokePrivileges + | showGrants + | showRoleGrants + | showRolePrincipals + | showRoles + | grantRole + | revokeRole + | setRole + | showCurrentRole + ; + +ifExists +@init { pushMsg("if exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_EXISTS + -> ^(TOK_IFEXISTS) + ; + +restrictOrCascade +@init { pushMsg("restrict or cascade clause", state); } +@after { popMsg(state); } + : KW_RESTRICT + -> ^(TOK_RESTRICT) + | KW_CASCADE + -> ^(TOK_CASCADE) + ; + +ifNotExists +@init { pushMsg("if not exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_NOT KW_EXISTS + -> ^(TOK_IFNOTEXISTS) + ; + +storedAsDirs +@init { pushMsg("stored as directories", state); } +@after { popMsg(state); } + : KW_STORED KW_AS KW_DIRECTORIES + -> ^(TOK_STOREDASDIRS) + ; + +orReplace +@init { pushMsg("or replace clause", state); } +@after { popMsg(state); } + : KW_OR KW_REPLACE + -> ^(TOK_ORREPLACE) + ; + +createDatabaseStatement +@init { pushMsg("create database statement", state); } +@after { popMsg(state); } + : KW_CREATE (KW_DATABASE|KW_SCHEMA) + ifNotExists? + name=identifier + databaseComment? + dbLocation? + (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? + -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) + ; + +dbLocation +@init { pushMsg("database location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) + ; + +dbProperties +@init { pushMsg("dbproperties", state); } +@after { popMsg(state); } + : + LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) + ; + +dbPropertiesList +@init { pushMsg("database properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) + ; + + +switchDatabaseStatement +@init { pushMsg("switch database statement", state); } +@after { popMsg(state); } + : KW_USE identifier + -> ^(TOK_SWITCHDATABASE identifier) + ; + +dropDatabaseStatement +@init { pushMsg("drop database statement", state); } +@after { popMsg(state); } + : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? + -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) + ; + +databaseComment +@init { pushMsg("database's comment", state); } +@after { popMsg(state); } + : KW_COMMENT comment=StringLiteral + -> ^(TOK_DATABASECOMMENT $comment) + ; + +createTableStatement +@init { pushMsg("create table statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName + ( like=KW_LIKE likeName=tableName + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + | (LPAREN columnNameTypeList RPAREN)? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + (KW_AS selectStatementWithCTE)? + ) + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + columnNameTypeList? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + selectStatementWithCTE? + ) + ; + +truncateTableStatement +@init { pushMsg("truncate table statement", state); } +@after { popMsg(state); } + : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); + +createIndexStatement +@init { pushMsg("create index statement", state);} +@after {popMsg(state);} + : KW_CREATE KW_INDEX indexName=identifier + KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN + KW_AS typeName=StringLiteral + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment? + ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment?) + ; + +indexComment +@init { pushMsg("comment on an index", state);} +@after {popMsg(state);} + : + KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) + ; + +autoRebuild +@init { pushMsg("auto rebuild index", state);} +@after {popMsg(state);} + : KW_WITH KW_DEFERRED KW_REBUILD + ->^(TOK_DEFERRED_REBUILDINDEX) + ; + +indexTblName +@init { pushMsg("index table name", state);} +@after {popMsg(state);} + : KW_IN KW_TABLE indexTbl=tableName + ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) + ; + +indexPropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_IDXPROPERTIES! indexProperties + ; + +indexProperties +@init { pushMsg("index properties", state); } +@after { popMsg(state); } + : + LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) + ; + +indexPropertiesList +@init { pushMsg("index properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) + ; + +dropIndexStatement +@init { pushMsg("drop index statement", state);} +@after {popMsg(state);} + : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName + ->^(TOK_DROPINDEX $indexName $tab ifExists?) + ; + +dropTableStatement +@init { pushMsg("drop statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? + -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) + ; + +alterStatement +@init { pushMsg("alter statement", state); } +@after { popMsg(state); } + : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) + | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) + | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix + | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix + ; + +alterTableStatementSuffix +@init { pushMsg("alter table statement", state); } +@after { popMsg(state); } + : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] + | alterStatementSuffixDropPartitions[true] + | alterStatementSuffixAddPartitions[true] + | alterStatementSuffixTouch + | alterStatementSuffixArchive + | alterStatementSuffixUnArchive + | alterStatementSuffixProperties + | alterStatementSuffixSkewedby + | alterStatementSuffixExchangePartition + | alterStatementPartitionKeyType + | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? + ; + +alterTblPartitionStatementSuffix +@init {pushMsg("alter table partition statement suffix", state);} +@after {popMsg(state);} + : alterStatementSuffixFileFormat + | alterStatementSuffixLocation + | alterStatementSuffixMergeFiles + | alterStatementSuffixSerdeProperties + | alterStatementSuffixRenamePart + | alterStatementSuffixBucketNum + | alterTblPartitionStatementSuffixSkewedLocation + | alterStatementSuffixClusterbySortby + | alterStatementSuffixCompact + | alterStatementSuffixUpdateStatsCol + | alterStatementSuffixRenameCol + | alterStatementSuffixAddCol + ; + +alterStatementPartitionKeyType +@init {msgs.push("alter partition key type"); } +@after {msgs.pop();} + : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN + -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) + ; + +alterViewStatementSuffix +@init { pushMsg("alter view statement", state); } +@after { popMsg(state); } + : alterViewSuffixProperties + | alterStatementSuffixRename[false] + | alterStatementSuffixAddPartitions[false] + | alterStatementSuffixDropPartitions[false] + | selectStatementWithCTE + ; + +alterIndexStatementSuffix +@init { pushMsg("alter index statement", state); } +@after { popMsg(state); } + : indexName=identifier KW_ON tableName partitionSpec? + ( + KW_REBUILD + ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) + | + KW_SET KW_IDXPROPERTIES + indexProperties + ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) + ) + ; + +alterDatabaseStatementSuffix +@init { pushMsg("alter database statement", state); } +@after { popMsg(state); } + : alterDatabaseSuffixProperties + | alterDatabaseSuffixSetOwner + ; + +alterDatabaseSuffixProperties +@init { pushMsg("alter database properties statement", state); } +@after { popMsg(state); } + : name=identifier KW_SET KW_DBPROPERTIES dbProperties + -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) + ; + +alterDatabaseSuffixSetOwner +@init { pushMsg("alter database set owner", state); } +@after { popMsg(state); } + : dbName=identifier KW_SET KW_OWNER principalName + -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) + ; + +alterStatementSuffixRename[boolean table] +@init { pushMsg("rename statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO tableName + -> { table }? ^(TOK_ALTERTABLE_RENAME tableName) + -> ^(TOK_ALTERVIEW_RENAME tableName) + ; + +alterStatementSuffixAddCol +@init { pushMsg("add column statement", state); } +@after { popMsg(state); } + : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? + -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) + -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) + ; + +alterStatementSuffixRenameCol +@init { pushMsg("rename column name", state); } +@after { popMsg(state); } + : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? + ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) + ; + +alterStatementSuffixUpdateStatsCol +@init { pushMsg("update column statistics", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementChangeColPosition + : first=KW_FIRST|KW_AFTER afterCol=identifier + ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) + -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) + ; + +alterStatementSuffixAddPartitions[boolean table] +@init { pushMsg("add partition statement", state); } +@after { popMsg(state); } + : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ + -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + ; + +alterStatementSuffixAddPartitionsElement + : partitionSpec partitionLocation? + ; + +alterStatementSuffixTouch +@init { pushMsg("touch statement", state); } +@after { popMsg(state); } + : KW_TOUCH (partitionSpec)* + -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) + ; + +alterStatementSuffixArchive +@init { pushMsg("archive statement", state); } +@after { popMsg(state); } + : KW_ARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) + ; + +alterStatementSuffixUnArchive +@init { pushMsg("unarchive statement", state); } +@after { popMsg(state); } + : KW_UNARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) + ; + +partitionLocation +@init { pushMsg("partition location", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) + ; + +alterStatementSuffixDropPartitions[boolean table] +@init { pushMsg("drop partition statement", state); } +@after { popMsg(state); } + : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? + -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) + -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) + ; + +alterStatementSuffixProperties +@init { pushMsg("alter properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) + ; + +alterViewSuffixProperties +@init { pushMsg("alter view properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) + ; + +alterStatementSuffixSerdeProperties +@init { pushMsg("alter serdes statement", state); } +@after { popMsg(state); } + : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? + -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) + | KW_SET KW_SERDEPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) + ; + +tablePartitionPrefix +@init {pushMsg("table partition prefix", state);} +@after {popMsg(state);} + : tableName partitionSpec? + ->^(TOK_TABLE_PARTITION tableName partitionSpec?) + ; + +alterStatementSuffixFileFormat +@init {pushMsg("alter fileformat statement", state); } +@after {popMsg(state);} + : KW_SET KW_FILEFORMAT fileFormat + -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) + ; + +alterStatementSuffixClusterbySortby +@init {pushMsg("alter partition cluster by sort by statement", state);} +@after {popMsg(state);} + : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) + | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) + | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) + ; + +alterTblPartitionStatementSuffixSkewedLocation +@init {pushMsg("alter partition skewed location", state);} +@after {popMsg(state);} + : KW_SET KW_SKEWED KW_LOCATION skewedLocations + -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) + ; + +skewedLocations +@init { pushMsg("skewed locations", state); } +@after { popMsg(state); } + : + LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) + ; + +skewedLocationsList +@init { pushMsg("skewed locations list", state); } +@after { popMsg(state); } + : + skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) + ; + +skewedLocationMap +@init { pushMsg("specifying skewed location map", state); } +@after { popMsg(state); } + : + key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) + ; + +alterStatementSuffixLocation +@init {pushMsg("alter location", state);} +@after {popMsg(state);} + : KW_SET KW_LOCATION newLoc=StringLiteral + -> ^(TOK_ALTERTABLE_LOCATION $newLoc) + ; + + +alterStatementSuffixSkewedby +@init {pushMsg("alter skewed by statement", state);} +@after{popMsg(state);} + : tableSkewed + ->^(TOK_ALTERTABLE_SKEWED tableSkewed) + | + KW_NOT KW_SKEWED + ->^(TOK_ALTERTABLE_SKEWED) + | + KW_NOT storedAsDirs + ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) + ; + +alterStatementSuffixExchangePartition +@init {pushMsg("alter exchange partition", state);} +@after{popMsg(state);} + : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName + -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) + ; + +alterStatementSuffixRenamePart +@init { pushMsg("alter table rename partition statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO partitionSpec + ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) + ; + +alterStatementSuffixStatsPart +@init { pushMsg("alter table stats partition statement", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementSuffixMergeFiles +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_CONCATENATE + -> ^(TOK_ALTERTABLE_MERGEFILES) + ; + +alterStatementSuffixBucketNum +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $num) + ; + +alterStatementSuffixCompact +@init { msgs.push("compaction request"); } +@after { msgs.pop(); } + : KW_COMPACT compactType=StringLiteral + -> ^(TOK_ALTERTABLE_COMPACT $compactType) + ; + + +fileFormat +@init { pushMsg("file format specification", state); } +@after { popMsg(state); } + : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?) + | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tabTypeExpr +@init { pushMsg("specifying table types", state); } +@after { popMsg(state); } + : identifier (DOT^ identifier)? + (identifier (DOT^ + ( + (KW_ELEM_TYPE) => KW_ELEM_TYPE + | + (KW_KEY_TYPE) => KW_KEY_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE + | identifier + ))* + )? + ; + +partTypeExpr +@init { pushMsg("specifying table partitions", state); } +@after { popMsg(state); } + : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) + ; + +tabPartColTypeExpr +@init { pushMsg("specifying table partitions columnName", state); } +@after { popMsg(state); } + : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) + ; + +descStatement +@init { pushMsg("describe statement", state); } +@after { popMsg(state); } + : + (KW_DESCRIBE|KW_DESC) + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) + | + (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) + | + (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) + | + parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) + ) + ; + +analyzeStatement +@init { pushMsg("analyze statement", state); } +@after { popMsg(state); } + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? + -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) + ; + +showStatement +@init { pushMsg("show statement", state); } +@after { popMsg(state); } + : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) + | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) + | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWCOLUMNS tableName $db_name?) + | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_CREATE ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) + | + KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) + ) + | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? + -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) + | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) + | KW_SHOW KW_LOCKS + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) + | + (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) + ) + | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) + | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) + | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) + | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) + ; + +lockStatement +@init { pushMsg("lock statement", state); } +@after { popMsg(state); } + : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) + ; + +lockDatabase +@init { pushMsg("lock database statement", state); } +@after { popMsg(state); } + : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) + ; + +lockMode +@init { pushMsg("lock mode", state); } +@after { popMsg(state); } + : KW_SHARED | KW_EXCLUSIVE + ; + +unlockStatement +@init { pushMsg("unlock statement", state); } +@after { popMsg(state); } + : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) + ; + +unlockDatabase +@init { pushMsg("unlock database statement", state); } +@after { popMsg(state); } + : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) + ; + +createRoleStatement +@init { pushMsg("create role", state); } +@after { popMsg(state); } + : KW_CREATE KW_ROLE roleName=identifier + -> ^(TOK_CREATEROLE $roleName) + ; + +dropRoleStatement +@init {pushMsg("drop role", state);} +@after {popMsg(state);} + : KW_DROP KW_ROLE roleName=identifier + -> ^(TOK_DROPROLE $roleName) + ; + +grantPrivileges +@init {pushMsg("grant privileges", state);} +@after {popMsg(state);} + : KW_GRANT privList=privilegeList + privilegeObject? + KW_TO principalSpecification + withGrantOption? + -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) + ; + +revokePrivileges +@init {pushMsg("revoke privileges", state);} +@afer {popMsg(state);} + : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification + -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) + ; + +grantRole +@init {pushMsg("grant role", state);} +@after {popMsg(state);} + : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? + -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) + ; + +revokeRole +@init {pushMsg("revoke role", state);} +@after {popMsg(state);} + : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification + -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) + ; + +showRoleGrants +@init {pushMsg("show role grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLE KW_GRANT principalName + -> ^(TOK_SHOW_ROLE_GRANT principalName) + ; + + +showRoles +@init {pushMsg("show roles", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLES + -> ^(TOK_SHOW_ROLES) + ; + +showCurrentRole +@init {pushMsg("show current role", state);} +@after {popMsg(state);} + : KW_SHOW KW_CURRENT KW_ROLES + -> ^(TOK_SHOW_SET_ROLE) + ; + +setRole +@init {pushMsg("set role", state);} +@after {popMsg(state);} + : KW_SET KW_ROLE + ( + (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) + | + (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) + | + identifier -> ^(TOK_SHOW_SET_ROLE identifier) + ) + ; + +showGrants +@init {pushMsg("show grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? + -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) + ; + +showRolePrincipals +@init {pushMsg("show role principals", state);} +@after {popMsg(state);} + : KW_SHOW KW_PRINCIPALS roleName=identifier + -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) + ; + + +privilegeIncludeColObject +@init {pushMsg("privilege object including columns", state);} +@after {popMsg(state);} + : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) + | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) + ; + +privilegeObject +@init {pushMsg("privilege object", state);} +@after {popMsg(state);} + : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) + ; + +// database or table type. Type is optional, default type is table +privObject + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privObjectCols + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privilegeList +@init {pushMsg("grant privilege list", state);} +@after {popMsg(state);} + : privlegeDef (COMMA privlegeDef)* + -> ^(TOK_PRIVILEGE_LIST privlegeDef+) + ; + +privlegeDef +@init {pushMsg("grant privilege", state);} +@after {popMsg(state);} + : privilegeType (LPAREN cols=columnNameList RPAREN)? + -> ^(TOK_PRIVILEGE privilegeType $cols?) + ; + +privilegeType +@init {pushMsg("privilege type", state);} +@after {popMsg(state);} + : KW_ALL -> ^(TOK_PRIV_ALL) + | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) + | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) + | KW_CREATE -> ^(TOK_PRIV_CREATE) + | KW_DROP -> ^(TOK_PRIV_DROP) + | KW_INDEX -> ^(TOK_PRIV_INDEX) + | KW_LOCK -> ^(TOK_PRIV_LOCK) + | KW_SELECT -> ^(TOK_PRIV_SELECT) + | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) + | KW_INSERT -> ^(TOK_PRIV_INSERT) + | KW_DELETE -> ^(TOK_PRIV_DELETE) + ; + +principalSpecification +@init { pushMsg("user/group/role name list", state); } +@after { popMsg(state); } + : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) + ; + +principalName +@init {pushMsg("user|group|role name", state);} +@after {popMsg(state);} + : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) + | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) + | KW_ROLE identifier -> ^(TOK_ROLE identifier) + ; + +withGrantOption +@init {pushMsg("with grant option", state);} +@after {popMsg(state);} + : KW_WITH KW_GRANT KW_OPTION + -> ^(TOK_GRANT_WITH_OPTION) + ; + +grantOptionFor +@init {pushMsg("grant option for", state);} +@after {popMsg(state);} + : KW_GRANT KW_OPTION KW_FOR + -> ^(TOK_GRANT_OPTION_FOR) +; + +adminOptionFor +@init {pushMsg("admin option for", state);} +@after {popMsg(state);} + : KW_ADMIN KW_OPTION KW_FOR + -> ^(TOK_ADMIN_OPTION_FOR) +; + +withAdminOption +@init {pushMsg("with admin option", state);} +@after {popMsg(state);} + : KW_WITH KW_ADMIN KW_OPTION + -> ^(TOK_GRANT_WITH_ADMIN_OPTION) + ; + +metastoreCheck +@init { pushMsg("metastore check statement", state); } +@after { popMsg(state); } + : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? + -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) + ; + +resourceList +@init { pushMsg("resource list", state); } +@after { popMsg(state); } + : + resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) + ; + +resource +@init { pushMsg("resource", state); } +@after { popMsg(state); } + : + resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) + ; + +resourceType +@init { pushMsg("resource type", state); } +@after { popMsg(state); } + : + KW_JAR -> ^(TOK_JAR) + | + KW_FILE -> ^(TOK_FILE) + | + KW_ARCHIVE -> ^(TOK_ARCHIVE) + ; + +createFunctionStatement +@init { pushMsg("create function statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral + (KW_USING rList=resourceList)? + -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) + -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) + ; + +dropFunctionStatement +@init { pushMsg("drop function statement", state); } +@after { popMsg(state); } + : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier + -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) + -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) + ; + +reloadFunctionStatement +@init { pushMsg("reload function statement", state); } +@after { popMsg(state); } + : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); + +createMacroStatement +@init { pushMsg("create macro statement", state); } +@after { popMsg(state); } + : KW_CREATE KW_TEMPORARY KW_MACRO Identifier + LPAREN columnNameTypeList? RPAREN expression + -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression) + ; + +dropMacroStatement +@init { pushMsg("drop macro statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier + -> ^(TOK_DROPMACRO Identifier ifExists?) + ; + +createViewStatement +@init { + pushMsg("create view statement", state); +} +@after { popMsg(state); } + : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName + (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? + tablePropertiesPrefixed? + KW_AS + selectStatementWithCTE + -> ^(TOK_CREATEVIEW $name orReplace? + ifNotExists? + columnNameCommentList? + tableComment? + viewPartition? + tablePropertiesPrefixed? + selectStatementWithCTE + ) + ; + +viewPartition +@init { pushMsg("view partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN + -> ^(TOK_VIEWPARTCOLS columnNameList) + ; + +dropViewStatement +@init { pushMsg("drop view statement", state); } +@after { popMsg(state); } + : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) + ; + +showFunctionIdentifier +@init { pushMsg("identifier for show function statement", state); } +@after { popMsg(state); } + : functionIdentifier + | StringLiteral + ; + +showStmtIdentifier +@init { pushMsg("identifier for show statement", state); } +@after { popMsg(state); } + : identifier + | StringLiteral + ; + +tableComment +@init { pushMsg("table's comment", state); } +@after { popMsg(state); } + : + KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) + ; + +tablePartition +@init { pushMsg("table partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN + -> ^(TOK_TABLEPARTCOLS columnNameTypeList) + ; + +tableBuckets +@init { pushMsg("table buckets specification", state); } +@after { popMsg(state); } + : + KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) + ; + +tableSkewed +@init { pushMsg("table skewed specification", state); } +@after { popMsg(state); } + : + KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? + -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) + ; + +rowFormat +@init { pushMsg("serde specification", state); } +@after { popMsg(state); } + : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) + | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) + | -> ^(TOK_SERDE) + ; + +recordReader +@init { pushMsg("record reader specification", state); } +@after { popMsg(state); } + : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) + | -> ^(TOK_RECORDREADER) + ; + +recordWriter +@init { pushMsg("record writer specification", state); } +@after { popMsg(state); } + : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) + | -> ^(TOK_RECORDWRITER) + ; + +rowFormatSerde +@init { pushMsg("serde format specification", state); } +@after { popMsg(state); } + : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_SERDENAME $name $serdeprops?) + ; + +rowFormatDelimited +@init { pushMsg("serde properties specification", state); } +@after { popMsg(state); } + : + KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? + -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) + ; + +tableRowFormat +@init { pushMsg("table row format specification", state); } +@after { popMsg(state); } + : + rowFormatDelimited + -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) + | rowFormatSerde + -> ^(TOK_TABLESERIALIZER rowFormatSerde) + ; + +tablePropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_TBLPROPERTIES! tableProperties + ; + +tableProperties +@init { pushMsg("table properties", state); } +@after { popMsg(state); } + : + LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) + ; + +tablePropertiesList +@init { pushMsg("table properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) + | + keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) + ; + +keyValueProperty +@init { pushMsg("specifying key/value property", state); } +@after { popMsg(state); } + : + key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) + ; + +keyProperty +@init { pushMsg("specifying key property", state); } +@after { popMsg(state); } + : + key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) + ; + +tableRowFormatFieldIdentifier +@init { pushMsg("table row format's field separator", state); } +@after { popMsg(state); } + : + KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? + -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) + ; + +tableRowFormatCollItemsIdentifier +@init { pushMsg("table row format's column separator", state); } +@after { popMsg(state); } + : + KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) + ; + +tableRowFormatMapKeysIdentifier +@init { pushMsg("table row format's map key separator", state); } +@after { popMsg(state); } + : + KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) + ; + +tableRowFormatLinesIdentifier +@init { pushMsg("table row format's line separator", state); } +@after { popMsg(state); } + : + KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) + ; + +tableRowNullFormat +@init { pushMsg("table row format's null specifier", state); } +@after { popMsg(state); } + : + KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) + ; +tableFileFormat +@init { pushMsg("table file format specification", state); } +@after { popMsg(state); } + : + (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) + | KW_STORED KW_BY storageHandler=StringLiteral + (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) + | KW_STORED KW_AS genericSpec=identifier + -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tableLocation +@init { pushMsg("table location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) + ; + +columnNameTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) + ; + +columnNameColonTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) + ; + +columnNameList +@init { pushMsg("column name list", state); } +@after { popMsg(state); } + : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) + ; + +columnName +@init { pushMsg("column name", state); } +@after { popMsg(state); } + : + identifier + ; + +extColumnName +@init { pushMsg("column name for complex types", state); } +@after { popMsg(state); } + : + identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* + ; + +columnNameOrderList +@init { pushMsg("column name order list", state); } +@after { popMsg(state); } + : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) + ; + +skewedValueElement +@init { pushMsg("skewed value element", state); } +@after { popMsg(state); } + : + skewedColumnValues + | skewedColumnValuePairList + ; + +skewedColumnValuePairList +@init { pushMsg("column value pair list", state); } +@after { popMsg(state); } + : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) + ; + +skewedColumnValuePair +@init { pushMsg("column value pair", state); } +@after { popMsg(state); } + : + LPAREN colValues=skewedColumnValues RPAREN + -> ^(TOK_TABCOLVALUES $colValues) + ; + +skewedColumnValues +@init { pushMsg("column values", state); } +@after { popMsg(state); } + : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) + ; + +skewedColumnValue +@init { pushMsg("column value", state); } +@after { popMsg(state); } + : + constant + ; + +skewedValueLocationElement +@init { pushMsg("skewed value location element", state); } +@after { popMsg(state); } + : + skewedColumnValue + | skewedColumnValuePair + ; + +columnNameOrder +@init { pushMsg("column name order", state); } +@after { popMsg(state); } + : identifier (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) + -> ^(TOK_TABSORTCOLNAMEDESC identifier) + ; + +columnNameCommentList +@init { pushMsg("column name comment list", state); } +@after { popMsg(state); } + : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) + ; + +columnNameComment +@init { pushMsg("column name comment", state); } +@after { popMsg(state); } + : colName=identifier (KW_COMMENT comment=StringLiteral)? + -> ^(TOK_TABCOL $colName TOK_NULL $comment?) + ; + +columnRefOrder +@init { pushMsg("column order", state); } +@after { popMsg(state); } + : expression (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) + -> ^(TOK_TABSORTCOLNAMEDESC expression) + ; + +columnNameType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier colType (KW_COMMENT comment=StringLiteral)? + -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()} + -> {$comment == null}? ^(TOK_TABCOL $colName colType) + -> ^(TOK_TABCOL $colName colType $comment) + ; + +columnNameColonType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)? + -> {$comment == null}? ^(TOK_TABCOL $colName colType) + -> ^(TOK_TABCOL $colName colType $comment) + ; + +colType +@init { pushMsg("column type", state); } +@after { popMsg(state); } + : type + ; + +colTypeList +@init { pushMsg("column type list", state); } +@after { popMsg(state); } + : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+) + ; + +type + : primitiveType + | listType + | structType + | mapType + | unionType; + +primitiveType +@init { pushMsg("primitive type specification", state); } +@after { popMsg(state); } + : KW_TINYINT -> TOK_TINYINT + | KW_SMALLINT -> TOK_SMALLINT + | KW_INT -> TOK_INT + | KW_BIGINT -> TOK_BIGINT + | KW_BOOLEAN -> TOK_BOOLEAN + | KW_FLOAT -> TOK_FLOAT + | KW_DOUBLE -> TOK_DOUBLE + | KW_DATE -> TOK_DATE + | KW_DATETIME -> TOK_DATETIME + | KW_TIMESTAMP -> TOK_TIMESTAMP + // Uncomment to allow intervals as table column types + //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH + //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME + | KW_STRING -> TOK_STRING + | KW_BINARY -> TOK_BINARY + | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?) + | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length) + | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length) + ; + +listType +@init { pushMsg("list type", state); } +@after { popMsg(state); } + : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type) + ; + +structType +@init { pushMsg("struct type", state); } +@after { popMsg(state); } + : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList) + ; + +mapType +@init { pushMsg("map type", state); } +@after { popMsg(state); } + : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN + -> ^(TOK_MAP $left $right) + ; + +unionType +@init { pushMsg("uniontype type", state); } +@after { popMsg(state); } + : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) + ; + +setOperator +@init { pushMsg("set operator", state); } +@after { popMsg(state); } + : KW_UNION KW_ALL -> ^(TOK_UNIONALL) + | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) + ; + +queryStatementExpression[boolean topLevel] + : + /* Would be nice to do this as a gated semantic perdicate + But the predicate gets pushed as a lookahead decision. + Calling rule doesnot know about topLevel + */ + (w=withClause {topLevel}?)? + queryStatementExpressionBody[topLevel] { + if ($w.tree != null) { + $queryStatementExpressionBody.tree.insertChild(0, $w.tree); + } + } + -> queryStatementExpressionBody + ; + +queryStatementExpressionBody[boolean topLevel] + : + fromStatement[topLevel] + | regularBody[topLevel] + ; + +withClause + : + KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+) +; + +cteStatement + : + identifier KW_AS LPAREN queryStatementExpression[false] RPAREN + -> ^(TOK_SUBQUERY queryStatementExpression identifier) +; + +fromStatement[boolean topLevel] +: (singleFromStatement -> singleFromStatement) + (u=setOperator r=singleFromStatement + -> ^($u {$fromStatement.tree} $r) + )* + -> {u != null && topLevel}? ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + {$fromStatement.tree} + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> {$fromStatement.tree} + ; + + +singleFromStatement + : + fromClause + ( b+=body )+ -> ^(TOK_QUERY fromClause body+) + ; + +/* +The valuesClause rule below ensures that the parse tree for +"insert into table FOO values (1,2),(3,4)" looks the same as +"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look +very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name +is implicit, it's represented as TOK_ANONYMOUS. +*/ +regularBody[boolean topLevel] + : + i=insertClause + ( + s=selectStatement[topLevel] + {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree} + | + valuesClause + -> ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause) + ) + ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))) + ) + ) + | + selectStatement[topLevel] + ; + +selectStatement[boolean topLevel] + : + ( + s=selectClause + f=fromClause? + w=whereClause? + g=groupByClause? + h=havingClause? + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + $s $w? $g? $h? $o? $c? + $d? $sort? $win? $l?)) + ) + (set=setOpSelectStatement[$selectStatement.tree, topLevel])? + -> {set == null}? + {$selectStatement.tree} + -> {o==null && c==null && d==null && sort==null && l==null}? + {$set.tree} + -> {throwSetOpException()} + ; + +setOpSelectStatement[CommonTree t, boolean topLevel] + : + (u=setOperator b=simpleSelectStatement + -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + ^(TOK_UNIONALL {$t} $b) + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> ^(TOK_UNIONALL {$t} $b) + )+ + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}? + {$setOpSelectStatement.tree} + -> ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + {$setOpSelectStatement.tree} + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) + $o? $c? $d? $sort? $win? $l? + ) + ) + ; + +simpleSelectStatement + : + selectClause + fromClause? + whereClause? + groupByClause? + havingClause? + ((window_clause) => window_clause)? + -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + selectClause whereClause? groupByClause? havingClause? window_clause?)) + ; + +selectStatementWithCTE + : + (w=withClause)? + selectStatement[true] { + if ($w.tree != null) { + $selectStatement.tree.insertChild(0, $w.tree); + } + } + -> selectStatement + ; + +body + : + insertClause + selectClause + lateralView? + whereClause? + groupByClause? + havingClause? + orderByClause? + clusterByClause? + distributeByClause? + sortByClause? + window_clause? + limitClause? -> ^(TOK_INSERT insertClause + selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? + distributeByClause? sortByClause? window_clause? limitClause?) + | + selectClause + lateralView? + whereClause? + groupByClause? + havingClause? + orderByClause? + clusterByClause? + distributeByClause? + sortByClause? + window_clause? + limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? + distributeByClause? sortByClause? window_clause? limitClause?) + ; + +insertClause +@init { pushMsg("insert clause", state); } +@after { popMsg(state); } + : + KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?) + | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)? + -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?) + ; + +destination +@init { pushMsg("destination specification", state); } +@after { popMsg(state); } + : + (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat? + -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?) + | KW_TABLE tableOrPartition -> tableOrPartition + ; + +limitClause +@init { pushMsg("limit clause", state); } +@after { popMsg(state); } + : + KW_LIMIT num=Number -> ^(TOK_LIMIT $num) + ; + +//DELETE FROM WHERE ...; +deleteStatement +@init { pushMsg("delete statement", state); } +@after { popMsg(state); } + : + KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?) + ; + +/*SET = (3 + col2)*/ +columnAssignmentClause + : + tableOrColumn EQUAL^ precedencePlusExpression + ; + +/*SET col1 = 5, col2 = (4 + col4), ...*/ +setColumnsClause + : + KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) + ; + +/* + UPDATE
    + SET col1 = val1, col2 = val2... WHERE ... +*/ +updateStatement +@init { pushMsg("update statement", state); } +@after { popMsg(state); } + : + KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) + ; + +/* +BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of +"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. +*/ +sqlTransactionStatement +@init { pushMsg("transaction statement", state); } +@after { popMsg(state); } + : + startTransactionStatement + | commitStatement + | rollbackStatement + | setAutoCommitStatement + ; + +startTransactionStatement + : + KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) + ; + +transactionMode + : + isolationLevel + | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) + ; + +transactionAccessMode + : + KW_READ KW_ONLY -> TOK_TXN_READ_ONLY + | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE + ; + +isolationLevel + : + KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) + ; + +/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ +levelOfIsolation + : + KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT + ; + +commitStatement + : + KW_COMMIT ( KW_WORK )? -> TOK_COMMIT + ; + +rollbackStatement + : + KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK + ; +setAutoCommitStatement + : + KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) + ; +/* +END user defined transaction boundaries +*/ diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java new file mode 100644 index 0000000000000..35ecdc5ad10a9 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.antlr.runtime.RecognitionException; +import org.antlr.runtime.Token; +import org.antlr.runtime.TokenStream; +import org.antlr.runtime.tree.CommonErrorNode; + +public class ASTErrorNode extends ASTNode { + + /** + * + */ + private static final long serialVersionUID = 1L; + CommonErrorNode delegate; + + public ASTErrorNode(TokenStream input, Token start, Token stop, + RecognitionException e){ + delegate = new CommonErrorNode(input,start,stop,e); + } + + @Override + public boolean isNil() { return delegate.isNil(); } + + @Override + public int getType() { return delegate.getType(); } + + @Override + public String getText() { return delegate.getText(); } + @Override + public String toString() { return delegate.toString(); } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java new file mode 100644 index 0000000000000..33d9322b628ec --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java @@ -0,0 +1,245 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.antlr.runtime.Token; +import org.antlr.runtime.tree.CommonTree; +import org.antlr.runtime.tree.Tree; +import org.apache.hadoop.hive.ql.lib.Node; + +public class ASTNode extends CommonTree implements Node, Serializable { + private static final long serialVersionUID = 1L; + private transient StringBuffer astStr; + private transient int startIndx = -1; + private transient int endIndx = -1; + private transient ASTNode rootNode; + private transient boolean isValidASTStr; + + public ASTNode() { + } + + /** + * Constructor. + * + * @param t + * Token for the CommonTree Node + */ + public ASTNode(Token t) { + super(t); + } + + public ASTNode(ASTNode node) { + super(node); + } + + @Override + public Tree dupNode() { + return new ASTNode(this); + } + + /* + * (non-Javadoc) + * + * @see org.apache.hadoop.hive.ql.lib.Node#getChildren() + */ + @Override + public ArrayList getChildren() { + if (super.getChildCount() == 0) { + return null; + } + + ArrayList ret_vec = new ArrayList(); + for (int i = 0; i < super.getChildCount(); ++i) { + ret_vec.add((Node) super.getChild(i)); + } + + return ret_vec; + } + + /* + * (non-Javadoc) + * + * @see org.apache.hadoop.hive.ql.lib.Node#getName() + */ + @Override + public String getName() { + return (Integer.valueOf(super.getToken().getType())).toString(); + } + + public String dump() { + StringBuilder sb = new StringBuilder("\n"); + dump(sb, ""); + return sb.toString(); + } + + private StringBuilder dump(StringBuilder sb, String ws) { + sb.append(ws); + sb.append(toString()); + sb.append("\n"); + + ArrayList children = getChildren(); + if (children != null) { + for (Node node : getChildren()) { + if (node instanceof ASTNode) { + ((ASTNode) node).dump(sb, ws + " "); + } else { + sb.append(ws); + sb.append(" NON-ASTNODE!!"); + sb.append("\n"); + } + } + } + return sb; + } + + private ASTNode getRootNodeWithValidASTStr(boolean useMemoizedRoot) { + if (useMemoizedRoot && rootNode != null && rootNode.parent == null && + rootNode.hasValidMemoizedString()) { + return rootNode; + } + ASTNode retNode = this; + while (retNode.parent != null) { + retNode = (ASTNode) retNode.parent; + } + rootNode=retNode; + if (!rootNode.isValidASTStr) { + rootNode.astStr = new StringBuffer(); + rootNode.toStringTree(rootNode); + rootNode.isValidASTStr = true; + } + return retNode; + } + + private boolean hasValidMemoizedString() { + return isValidASTStr && astStr != null; + } + + private void resetRootInformation() { + // Reset the previously stored rootNode string + if (rootNode != null) { + rootNode.astStr = null; + rootNode.isValidASTStr = false; + } + } + + private int getMemoizedStringLen() { + return astStr == null ? 0 : astStr.length(); + } + + private String getMemoizedSubString(int start, int end) { + return (astStr == null || start < 0 || end > astStr.length() || start >= end) ? null : + astStr.subSequence(start, end).toString(); + } + + private void addtoMemoizedString(String string) { + if (astStr == null) { + astStr = new StringBuffer(); + } + astStr.append(string); + } + + @Override + public void setParent(Tree t) { + super.setParent(t); + resetRootInformation(); + } + + @Override + public void addChild(Tree t) { + super.addChild(t); + resetRootInformation(); + } + + @Override + public void addChildren(List kids) { + super.addChildren(kids); + resetRootInformation(); + } + + @Override + public void setChild(int i, Tree t) { + super.setChild(i, t); + resetRootInformation(); + } + + @Override + public void insertChild(int i, Object t) { + super.insertChild(i, t); + resetRootInformation(); + } + + @Override + public Object deleteChild(int i) { + Object ret = super.deleteChild(i); + resetRootInformation(); + return ret; + } + + @Override + public void replaceChildren(int startChildIndex, int stopChildIndex, Object t) { + super.replaceChildren(startChildIndex, stopChildIndex, t); + resetRootInformation(); + } + + @Override + public String toStringTree() { + + // The root might have changed because of tree modifications. + // Compute the new root for this tree and set the astStr. + getRootNodeWithValidASTStr(true); + + // If rootNotModified is false, then startIndx and endIndx will be stale. + if (startIndx >= 0 && endIndx <= rootNode.getMemoizedStringLen()) { + return rootNode.getMemoizedSubString(startIndx, endIndx); + } + return toStringTree(rootNode); + } + + private String toStringTree(ASTNode rootNode) { + this.rootNode = rootNode; + startIndx = rootNode.getMemoizedStringLen(); + // Leaf node + if ( children==null || children.size()==0 ) { + rootNode.addtoMemoizedString(this.toString()); + endIndx = rootNode.getMemoizedStringLen(); + return this.toString(); + } + if ( !isNil() ) { + rootNode.addtoMemoizedString("("); + rootNode.addtoMemoizedString(this.toString()); + rootNode.addtoMemoizedString(" "); + } + for (int i = 0; children!=null && i < children.size(); i++) { + ASTNode t = (ASTNode)children.get(i); + if ( i>0 ) { + rootNode.addtoMemoizedString(" "); + } + t.toStringTree(rootNode); + } + if ( !isNil() ) { + rootNode.addtoMemoizedString(")"); + } + endIndx = rootNode.getMemoizedStringLen(); + return rootNode.getMemoizedSubString(startIndx, endIndx); + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java new file mode 100644 index 0000000000000..c77198b087cbd --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.util.ArrayList; +import org.antlr.runtime.ANTLRStringStream; +import org.antlr.runtime.CharStream; +import org.antlr.runtime.NoViableAltException; +import org.antlr.runtime.RecognitionException; +import org.antlr.runtime.Token; +import org.antlr.runtime.TokenRewriteStream; +import org.antlr.runtime.TokenStream; +import org.antlr.runtime.tree.CommonTree; +import org.antlr.runtime.tree.CommonTreeAdaptor; +import org.antlr.runtime.tree.TreeAdaptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.hadoop.hive.ql.Context; + +/** + * ParseDriver. + * + */ +public class ParseDriver { + + private static final Logger LOG = LoggerFactory.getLogger("hive.ql.parse.ParseDriver"); + + /** + * ANTLRNoCaseStringStream. + * + */ + //This class provides and implementation for a case insensitive token checker + //for the lexical analysis part of antlr. By converting the token stream into + //upper case at the time when lexical rules are checked, this class ensures that the + //lexical rules need to just match the token with upper case letters as opposed to + //combination of upper case and lower case characters. This is purely used for matching lexical + //rules. The actual token text is stored in the same way as the user input without + //actually converting it into an upper case. The token values are generated by the consume() + //function of the super class ANTLRStringStream. The LA() function is the lookahead function + //and is purely used for matching lexical rules. This also means that the grammar will only + //accept capitalized tokens in case it is run from other tools like antlrworks which + //do not have the ANTLRNoCaseStringStream implementation. + public class ANTLRNoCaseStringStream extends ANTLRStringStream { + + public ANTLRNoCaseStringStream(String input) { + super(input); + } + + @Override + public int LA(int i) { + + int returnChar = super.LA(i); + if (returnChar == CharStream.EOF) { + return returnChar; + } else if (returnChar == 0) { + return returnChar; + } + + return Character.toUpperCase((char) returnChar); + } + } + + /** + * HiveLexerX. + * + */ + public class HiveLexerX extends SparkSqlLexer { + + private final ArrayList errors; + + public HiveLexerX(CharStream input) { + super(input); + errors = new ArrayList(); + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + errors.add(new ParseError(this, e, tokenNames)); + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + if (e instanceof NoViableAltException) { + // @SuppressWarnings("unused") + // NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and "(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "character " + getCharErrorDisplay(e.c) + " not supported here"; + } else { + msg = super.getErrorMessage(e, tokenNames); + } + + return msg; + } + + public ArrayList getErrors() { + return errors; + } + + } + + /** + * Tree adaptor for making antlr return ASTNodes instead of CommonTree nodes + * so that the graph walking algorithms and the rules framework defined in + * ql.lib can be used with the AST Nodes. + */ + public static final TreeAdaptor adaptor = new CommonTreeAdaptor() { + /** + * Creates an ASTNode for the given token. The ASTNode is a wrapper around + * antlr's CommonTree class that implements the Node interface. + * + * @param payload + * The token. + * @return Object (which is actually an ASTNode) for the token. + */ + @Override + public Object create(Token payload) { + return new ASTNode(payload); + } + + @Override + public Object dupNode(Object t) { + + return create(((CommonTree)t).token); + }; + + @Override + public Object errorNode(TokenStream input, Token start, Token stop, RecognitionException e) { + return new ASTErrorNode(input, start, stop, e); + }; + }; + + public ASTNode parse(String command) throws ParseException { + return parse(command, null); + } + + public ASTNode parse(String command, Context ctx) + throws ParseException { + return parse(command, ctx, true); + } + + /** + * Parses a command, optionally assigning the parser's token stream to the + * given context. + * + * @param command + * command to parse + * + * @param ctx + * context with which to associate this parser's token stream, or + * null if either no context is available or the context already has + * an existing stream + * + * @return parsed AST + */ + public ASTNode parse(String command, Context ctx, boolean setTokenRewriteStream) + throws ParseException { + LOG.info("Parsing command: " + command); + + HiveLexerX lexer = new HiveLexerX(new ANTLRNoCaseStringStream(command)); + TokenRewriteStream tokens = new TokenRewriteStream(lexer); + if (ctx != null) { + if ( setTokenRewriteStream) { + ctx.setTokenRewriteStream(tokens); + } + lexer.setHiveConf(ctx.getConf()); + } + SparkSqlParser parser = new SparkSqlParser(tokens); + if (ctx != null) { + parser.setHiveConf(ctx.getConf()); + } + parser.setTreeAdaptor(adaptor); + SparkSqlParser.statement_return r = null; + try { + r = parser.statement(); + } catch (RecognitionException e) { + e.printStackTrace(); + throw new ParseException(parser.errors); + } + + if (lexer.getErrors().size() == 0 && parser.errors.size() == 0) { + LOG.info("Parse Completed"); + } else if (lexer.getErrors().size() != 0) { + throw new ParseException(lexer.getErrors()); + } else { + throw new ParseException(parser.errors); + } + + ASTNode tree = (ASTNode) r.getTree(); + tree.setUnknownTokenBoundaries(); + return tree; + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java new file mode 100644 index 0000000000000..b47bcfb2914df --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.antlr.runtime.BaseRecognizer; +import org.antlr.runtime.RecognitionException; + +/** + * + */ +public class ParseError { + private final BaseRecognizer br; + private final RecognitionException re; + private final String[] tokenNames; + + ParseError(BaseRecognizer br, RecognitionException re, String[] tokenNames) { + this.br = br; + this.re = re; + this.tokenNames = tokenNames; + } + + BaseRecognizer getBaseRecognizer() { + return br; + } + + RecognitionException getRecognitionException() { + return re; + } + + String[] getTokenNames() { + return tokenNames; + } + + String getMessage() { + return br.getErrorHeader(re) + " " + br.getErrorMessage(re, tokenNames); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java new file mode 100644 index 0000000000000..fff891ced5550 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.util.ArrayList; + +/** + * ParseException. + * + */ +public class ParseException extends Exception { + + private static final long serialVersionUID = 1L; + ArrayList errors; + + public ParseException(ArrayList errors) { + super(); + this.errors = errors; + } + + @Override + public String getMessage() { + + StringBuilder sb = new StringBuilder(); + for (ParseError err : errors) { + if (sb.length() > 0) { + sb.append('\n'); + } + sb.append(err.getMessage()); + } + + return sb.toString(); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java new file mode 100644 index 0000000000000..a5c2998f86cc1 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + + +/** + * Library of utility functions used in the parse code. + * + */ +public final class ParseUtils { + /** + * Performs a descent of the leftmost branch of a tree, stopping when either a + * node with a non-null token is found or the leaf level is encountered. + * + * @param tree + * candidate node from which to start searching + * + * @return node at which descent stopped + */ + public static ASTNode findRootNonNullToken(ASTNode tree) { + while ((tree.getToken() == null) && (tree.getChildCount() > 0)) { + tree = (org.apache.spark.sql.parser.ASTNode) tree.getChild(0); + } + return tree; + } + + private ParseUtils() { + // prevent instantiation + } + + public static VarcharTypeInfo getVarcharTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() != 1) { + throw new SemanticException("Bad params for type varchar"); + } + + String lengthStr = node.getChild(0).getText(); + return TypeInfoFactory.getVarcharTypeInfo(Integer.valueOf(lengthStr)); + } + + public static CharTypeInfo getCharTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() != 1) { + throw new SemanticException("Bad params for type char"); + } + + String lengthStr = node.getChild(0).getText(); + return TypeInfoFactory.getCharTypeInfo(Integer.valueOf(lengthStr)); + } + + public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() > 2) { + throw new SemanticException("Bad params for type decimal"); + } + + int precision = HiveDecimal.USER_DEFAULT_PRECISION; + int scale = HiveDecimal.USER_DEFAULT_SCALE; + + if (node.getChildCount() >= 1) { + String precStr = node.getChild(0).getText(); + precision = Integer.valueOf(precStr); + } + + if (node.getChildCount() == 2) { + String scaleStr = node.getChild(1).getText(); + scale = Integer.valueOf(scaleStr); + } + + return TypeInfoFactory.getDecimalTypeInfo(precision, scale); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java new file mode 100644 index 0000000000000..4b2015e0df84e --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java @@ -0,0 +1,406 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.antlr.runtime.tree.Tree; +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.ql.ErrorMsg; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + +/** + * SemanticAnalyzer. + * + */ +public abstract class SemanticAnalyzer { + public static String charSetString(String charSetName, String charSetString) + throws SemanticException { + try { + // The character set name starts with a _, so strip that + charSetName = charSetName.substring(1); + if (charSetString.charAt(0) == '\'') { + return new String(unescapeSQLString(charSetString).getBytes(), + charSetName); + } else // hex input is also supported + { + assert charSetString.charAt(0) == '0'; + assert charSetString.charAt(1) == 'x'; + charSetString = charSetString.substring(2); + + byte[] bArray = new byte[charSetString.length() / 2]; + int j = 0; + for (int i = 0; i < charSetString.length(); i += 2) { + int val = Character.digit(charSetString.charAt(i), 16) * 16 + + Character.digit(charSetString.charAt(i + 1), 16); + if (val > 127) { + val = val - 256; + } + bArray[j++] = (byte)val; + } + + String res = new String(bArray, charSetName); + return res; + } + } catch (UnsupportedEncodingException e) { + throw new SemanticException(e); + } + } + + /** + * Remove the encapsulating "`" pair from the identifier. We allow users to + * use "`" to escape identifier for table names, column names and aliases, in + * case that coincide with Hive language keywords. + */ + public static String unescapeIdentifier(String val) { + if (val == null) { + return null; + } + if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { + val = val.substring(1, val.length() - 1); + } + return val; + } + + /** + * Converts parsed key/value properties pairs into a map. + * + * @param prop ASTNode parent of the key/value pairs + * + * @param mapProp property map which receives the mappings + */ + public static void readProps( + ASTNode prop, Map mapProp) { + + for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { + String key = unescapeSQLString(prop.getChild(propChild).getChild(0) + .getText()); + String value = null; + if (prop.getChild(propChild).getChild(1) != null) { + value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); + } + mapProp.put(key, value); + } + } + + private static final int[] multiplier = new int[] {1000, 100, 10, 1}; + + @SuppressWarnings("nls") + public static String unescapeSQLString(String b) { + Character enclosure = null; + + // Some of the strings can be passed in as unicode. For example, the + // delimiter can be passed in as \002 - So, we first check if the + // string is a unicode number, else go back to the old behavior + StringBuilder sb = new StringBuilder(b.length()); + for (int i = 0; i < b.length(); i++) { + + char currentChar = b.charAt(i); + if (enclosure == null) { + if (currentChar == '\'' || b.charAt(i) == '\"') { + enclosure = currentChar; + } + // ignore all other chars outside the enclosure + continue; + } + + if (enclosure.equals(currentChar)) { + enclosure = null; + continue; + } + + if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { + int code = 0; + int base = i + 2; + for (int j = 0; j < 4; j++) { + int digit = Character.digit(b.charAt(j + base), 16); + code += digit * multiplier[j]; + } + sb.append((char)code); + i += 5; + continue; + } + + if (currentChar == '\\' && (i + 4 < b.length())) { + char i1 = b.charAt(i + 1); + char i2 = b.charAt(i + 2); + char i3 = b.charAt(i + 3); + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') + && (i3 >= '0' && i3 <= '7')) { + byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); + byte[] bValArr = new byte[1]; + bValArr[0] = bVal; + String tmp = new String(bValArr); + sb.append(tmp); + i += 3; + continue; + } + } + + if (currentChar == '\\' && (i + 2 < b.length())) { + char n = b.charAt(i + 1); + switch (n) { + case '0': + sb.append("\0"); + break; + case '\'': + sb.append("'"); + break; + case '"': + sb.append("\""); + break; + case 'b': + sb.append("\b"); + break; + case 'n': + sb.append("\n"); + break; + case 'r': + sb.append("\r"); + break; + case 't': + sb.append("\t"); + break; + case 'Z': + sb.append("\u001A"); + break; + case '\\': + sb.append("\\"); + break; + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } + + /** + * Get the list of FieldSchema out of the ASTNode. + */ + public static List getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { + List colList = new ArrayList(); + int numCh = ast.getChildCount(); + for (int i = 0; i < numCh; i++) { + FieldSchema col = new FieldSchema(); + ASTNode child = (ASTNode) ast.getChild(i); + Tree grandChild = child.getChild(0); + if(grandChild != null) { + String name = grandChild.getText(); + if(lowerCase) { + name = name.toLowerCase(); + } + // child 0 is the name of the column + col.setName(unescapeIdentifier(name)); + // child 1 is the type of the column + ASTNode typeChild = (ASTNode) (child.getChild(1)); + col.setType(getTypeStringFromAST(typeChild)); + + // child 2 is the optional comment of the column + if (child.getChildCount() == 3) { + col.setComment(unescapeSQLString(child.getChild(2).getText())); + } + } + colList.add(col); + } + return colList; + } + + protected static String getTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + switch (typeNode.getType()) { + case SparkSqlParser.TOK_LIST: + return serdeConstants.LIST_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; + case SparkSqlParser.TOK_MAP: + return serdeConstants.MAP_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," + + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; + case SparkSqlParser.TOK_STRUCT: + return getStructTypeStringFromAST(typeNode); + case SparkSqlParser.TOK_UNIONTYPE: + return getUnionTypeStringFromAST(typeNode); + default: + return getTypeName(typeNode); + } + } + + private static String getStructTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.STRUCT_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty struct not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + ASTNode child = (ASTNode) typeNode.getChild(i); + buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); + buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); + if (i < children - 1) { + buffer.append(","); + } + } + + buffer.append(">"); + return buffer.toString(); + } + + private static String getUnionTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty union not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); + if (i < children - 1) { + buffer.append(","); + } + } + buffer.append(">"); + typeStr = buffer.toString(); + return typeStr; + } + + public static String getAstNodeText(ASTNode tree) { + return tree.getChildCount() == 0?tree.getText() : + getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); + } + + public static String generateErrorMessage(ASTNode ast, String message) { + StringBuilder sb = new StringBuilder(); + if (ast == null) { + sb.append(message).append(". Cannot tell the position of null AST."); + return sb.toString(); + } + sb.append(ast.getLine()); + sb.append(":"); + sb.append(ast.getCharPositionInLine()); + sb.append(" "); + sb.append(message); + sb.append(". Error encountered near token '"); + sb.append(getAstNodeText(ast)); + sb.append("'"); + return sb.toString(); + } + + private static final Map TokenToTypeName = new HashMap(); + + static { + TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); + } + + public static String getTypeName(ASTNode node) throws SemanticException { + int token = node.getType(); + String typeName; + + // datetime type isn't currently supported + if (token == SparkSqlParser.TOK_DATETIME) { + throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); + } + + switch (token) { + case SparkSqlParser.TOK_CHAR: + CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); + typeName = charTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_VARCHAR: + VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); + typeName = varcharTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_DECIMAL: + DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); + typeName = decTypeInfo.getQualifiedName(); + break; + default: + typeName = TokenToTypeName.get(token); + } + return typeName; + } + + public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { + boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); + if (testMode) { + URI uri = new Path(location).toUri(); + String scheme = uri.getScheme(); + String authority = uri.getAuthority(); + String path = uri.getPath(); + if (!path.startsWith("/")) { + path = (new Path(System.getProperty("test.tmp.dir"), + path)).toUri().getPath(); + } + if (StringUtils.isEmpty(scheme)) { + scheme = "pfile"; + } + try { + uri = new URI(scheme, authority, path, null, null); + } catch (URISyntaxException e) { + throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); + } + return uri.toString(); + } else { + //no-op for non-test mode for now + return location; + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0e89928cb636d..b1d841d1b5543 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -27,28 +27,28 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse._ +import org.apache.hadoop.hive.ql.parse.SemanticException import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, catalyst} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} +import org.apache.spark.sql.parser._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -227,7 +227,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren.asJava) + newChildren.foreach(n.addChild(_)) n } @@ -273,7 +273,8 @@ private[hive] object HiveQl extends Logging { private def createContext(): Context = new Context(hiveConf) private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) + ParseUtils.findRootNonNullToken( + (new ParseDriver).parse(sql, context)) /** * Returns the HiveConf @@ -312,7 +313,7 @@ private[hive] object HiveQl extends Logging { context.clear() plan } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => + case pe: ParseException => pe.getMessage match { case errorRegEx(line, start, message) => throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) @@ -337,7 +338,8 @@ private[hive] object HiveQl extends Logging { val tree = try { ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) + (new ParseDriver) + .parse(ddl, null /* no context required for parsing alone */)) } catch { case pe: org.apache.hadoop.hive.ql.parse.ParseException => throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) @@ -598,12 +600,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { tableType match { - case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { - nameParts.head match { + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => { + nameParts match { case Token(".", dbName :: tableName :: Nil) => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts.head) + val tableIdent = extractTableIdent(nameParts) DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => @@ -662,7 +664,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { val schema = maybeColumns.map { cols => - BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + SemanticAnalyzer.getColumns(cols, true).asScala.map { field => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. HiveColumn(field.getName, null, field.getComment) @@ -678,7 +680,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeComment.foreach { case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + val comment = SemanticAnalyzer.unescapeSQLString(child.getText) if (comment ne null) { properties += ("comment" -> comment) } @@ -750,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C children.collect { case list @ Token("TOK_TABCOLLIST", _) => - val cols = BaseSemanticAnalyzer.getColumns(list, true) + val cols = SemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( schema = cols.asScala.map { field => @@ -758,11 +760,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }) } case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + val comment = SemanticAnalyzer.unescapeSQLString(child.getText) // TODO support the sql text tableDesc = tableDesc.copy(viewText = Option(comment)) case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = BaseSemanticAnalyzer.getColumns(list(0), false) + val cols = SemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( partitionColumns = cols.asScala.map { field => @@ -773,21 +775,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) + val fieldDelim = SemanticAnalyzer.unescapeSQLString (rowChild1.getText()) serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) if (rowChild2.length > 1) { - val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) + val fieldEscape = SemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val collItemDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val mapKeyDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val lineDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) if (!(lineDelim == "\n") && !(lineDelim == "10")) { throw new AnalysisException( SemanticAnalyzer.generateErrorMessage( @@ -796,22 +798,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val nullFormat = SemanticAnalyzer.unescapeSQLString(rowChild.getText) // TODO support the nullFormat case _ => assert(false) } tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => - var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - location = EximUtil.relativeToAbsolutePath(hiveConf, location) + var location = SemanticAnalyzer.unescapeSQLString(child.getText) + location = SemanticAnalyzer.relativeToAbsolutePath(hiveConf, location) tableDesc = tableDesc.copy(location = Option(location)) case Token("TOK_TABLESERIALIZER", child :: Nil) => tableDesc = tableDesc.copy( - serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) + serde = Option(SemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) if (child.getChildCount == 2) { val serdeParams = new java.util.HashMap[String, String]() - BaseSemanticAnalyzer.readProps( + SemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) @@ -891,9 +893,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), + Option(SemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), outputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) + Option(SemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) case Token("TOK_STORAGEHANDLER", _) => throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) case _ => // Unsupport features @@ -909,24 +911,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) - if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + if Seq("TOK_CTE", "TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - // check if has CTE - insertClauses.last match { - case Token("TOK_CTE", cteClauses) => - val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - (relation.alias, relation) - }).toMap - (Some(args.head), insertClauses.init, Some(cteRelations)) - - case _ => (Some(args.head), insertClauses, None) + case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => + val cteRelations = ctes.map { node => + val relation = nodeToRelation(node, context).asInstanceOf[Subquery] + relation.alias -> relation } - - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + (Some(from.head), inserts, Some(cteRelations.toMap)) + case Token("TOK_FROM", from) :: inserts => + (Some(from.head), inserts, None) + case Token("TOK_INSERT", _) :: Nil => + (None, queryArgs, None) } // Return one query for each insert clause. @@ -1025,20 +1023,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) + (Nil, Some(SemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (BaseSemanticAnalyzer.unescapeSQLString(name), - BaseSemanticAnalyzer.unescapeSQLString(value)) + (SemanticAnalyzer.unescapeSQLString(name), + SemanticAnalyzer.unescapeSQLString(value)) } // SPARK-10310: Special cases LazySimpleSerDe // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val unescapedSerDeClass = SemanticAnalyzer.unescapeSQLString(serdeClass) val useDefaultRecordReaderWriter = unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) @@ -1055,7 +1053,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = matchSerDe(outputSerdeClause) - val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + val unescapedScript = SemanticAnalyzer.unescapeSQLString(script) // TODO Adds support for user-defined record reader/writer classes val recordReaderClass = if (useDefaultRecordReader) { @@ -1361,6 +1359,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTOUTERJOIN" => LeftOuter case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi + case "TOK_ANTIJOIN" => throw new NotImplementedError("Anti join not supported") } Join(nodeToRelation(relation1, context), nodeToRelation(relation2, context), @@ -1475,11 +1474,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val numericAstTypes = Seq( - HiveParser.Number, - HiveParser.TinyintLiteral, - HiveParser.SmallintLiteral, - HiveParser.BigintLiteral, - HiveParser.DecimalLiteral) + SparkSqlParser.Number, + SparkSqlParser.TinyintLiteral, + SparkSqlParser.SmallintLiteral, + SparkSqlParser.BigintLiteral, + SparkSqlParser.DecimalLiteral) /* Case insensitive matches */ val COUNT = "(?i)COUNT".r @@ -1649,7 +1648,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(TRUE(), Nil) => Literal.create(true, BooleanType) case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) + Literal(strings.map(s => SemanticAnalyzer.unescapeSQLString(s.getText)).mkString) // This code is adapted from // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 @@ -1684,37 +1683,37 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C v } - case ast: ASTNode if ast.getType == HiveParser.StringLiteral => - Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == SparkSqlParser.StringLiteral => + Literal(SemanticAnalyzer.unescapeSQLString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => - Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_CHARSETLITERAL => + Literal(SemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => Literal(CalendarInterval.fromYearMonthString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => Literal(CalendarInterval.fromDayTimeString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) case a: ASTNode => From 7ab0e2289df39eed03b0b8c8b2b350faf2b4e4ee Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 29 Dec 2015 19:54:10 -0800 Subject: [PATCH 1373/2089] [SPARK-12490][CORE] Limit the css style scope to fix the Streaming UI #10441 broke the Streaming UI because of the new CSS style. screen shot 2015-12-29 at 4 49 04 pm This PR just added a class for the new style and only applied them to the paged tables. Author: Shixiong Zhu Closes #10517 from zsxwing/fix-streaming-ui. --- core/src/main/resources/org/apache/spark/ui/static/webui.css | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 3 ++- core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index dd708ef2c29b5..48f86d1536c99 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -227,7 +227,7 @@ a.expandbutton { cursor:pointer; } -th a, th a:hover { +.table-head-clickable th a, .table-head-clickable th a:hover { /* Make the entire header clickable, not just the text label */ display: block; width: 100%; diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index b02b99a6fc7aa..08e7576b0c08e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1233,7 +1233,8 @@ private[ui] class TaskPagedTable( override def tableId: String = "task-table" - override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def tableCssClass: String = + "table table-bordered table-condensed table-striped table-head-clickable" override def pageSizeFormField: String = "task.pageSize" diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 3e51ce2e97994..606d15d599e81 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -237,7 +237,8 @@ private[ui] class BlockPagedTable( override def tableId: String = "rdd-storage-by-block-table" - override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def tableCssClass: String = + "table table-bordered table-condensed table-striped table-head-clickable" override def pageSizeFormField: String = "block.pageSize" From 4f75f785df0e59ca5ae48e86f3dfc00b45d96b18 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 29 Dec 2015 22:28:59 -0800 Subject: [PATCH 1374/2089] [SPARK-12564][SQL] Improve missing column AnalysisException ``` org.apache.spark.sql.AnalysisException: cannot resolve 'value' given input columns text; ``` lets put a `:` after `columns` and put the columns in `[]` so that they match the toString of DataFrame. Author: gatorsmile Closes #10518 from gatorsmile/improveAnalysisExceptionMsg. --- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 440f679913802..a1be1473cc80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -57,7 +57,7 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7fe66e461c140..c19b5a4d98a85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -514,7 +514,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { ds.as[ClassData2].collect() } - assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) + assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) } test("runtime nullability check") { From 27af6157f9cceeb9aa74eec54c8898d3e0749ed0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Dec 2015 00:08:44 -0800 Subject: [PATCH 1375/2089] Revert "[SPARK-12362][SQL][WIP] Inline Hive Parser" This reverts commit b600bccf41a7b1958e33d8301a19214e6517e388 due to non-deterministic build breaks. --- pom.xml | 5 - project/SparkBuild.scala | 2 +- project/plugins.sbt | 4 - .../execution/HiveCompatibilitySuite.scala | 10 +- sql/hive/pom.xml | 22 - .../spark/sql/parser/FromClauseParser.g | 330 --- .../spark/sql/parser/IdentifiersParser.g | 697 ----- .../spark/sql/parser/SelectClauseParser.g | 226 -- .../apache/spark/sql/parser/SparkSqlLexer.g | 474 ---- .../apache/spark/sql/parser/SparkSqlParser.g | 2457 ----------------- .../apache/spark/sql/parser/ASTErrorNode.java | 49 - .../org/apache/spark/sql/parser/ASTNode.java | 245 -- .../apache/spark/sql/parser/ParseDriver.java | 213 -- .../apache/spark/sql/parser/ParseError.java | 54 - .../spark/sql/parser/ParseException.java | 51 - .../apache/spark/sql/parser/ParseUtils.java | 96 - .../spark/sql/parser/SemanticAnalyzer.java | 406 --- .../org/apache/spark/sql/hive/HiveQl.scala | 133 +- 18 files changed, 72 insertions(+), 5402 deletions(-) delete mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g delete mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g delete mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g delete mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g delete mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java diff --git a/pom.xml b/pom.xml index 73ba8d555a90c..284c219519bca 100644 --- a/pom.xml +++ b/pom.xml @@ -1951,11 +1951,6 @@ - - org.antlr - antlr3-maven-plugin - 3.5.2 - org.apache.maven.plugins diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index df21d3eb636f0..c3d53f835f395 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -415,7 +415,7 @@ object Hive { // in order to generate golden files. This is only required for developers who are adding new // new query tests. fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } - ) ++ sbtantlr.SbtAntlrPlugin.antlrSettings + ) } diff --git a/project/plugins.sbt b/project/plugins.sbt index f172dc9c1f0e3..5e23224cf8aa5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -4,8 +4,6 @@ resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/release resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" -resolvers += "stefri" at "http://stefri.github.io/repo/releases" - addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") @@ -26,8 +24,6 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") -addSbtPlugin("com.github.stefri" % "sbt-antlr" % "0.5.3") - libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2b0e48dbfcf28..2d0d7b8af3581 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -308,12 +308,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The difference between the double numbers generated by Hive and Spark // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) - "udaf_corr", - - // Feature removed in HIVE-11145 - "alter_partition_protect_mode", - "drop_partitions_ignore_protection", - "protectmode" + "udaf_corr" ) /** @@ -333,6 +328,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_index", "alter_merge_2", "alter_partition_format_loc", + "alter_partition_protect_mode", "alter_partition_with_whitelist", "alter_rename_partition", "alter_table_serde", @@ -464,6 +460,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_filter", "drop_partitions_filter2", "drop_partitions_filter3", + "drop_partitions_ignore_protection", "drop_table", "drop_table2", "drop_table_removes_partition_dirs", @@ -781,6 +778,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", + "protectmode", "push_or", "query_with_semi", "quote1", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ffabb92179a18..e9885f6682028 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -232,7 +232,6 @@ v${hive.version.short}/src/main/scala - ${project.build.directory/generated-sources/antlr @@ -261,27 +260,6 @@ - - - - org.antlr - antlr3-maven-plugin - - - - antlr - - - - - ${basedir}/src/main/antlr3 - - **/SparkSqlLexer.g - **/SparkSqlParser.g - - - - diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g deleted file mode 100644 index e4a80f0ce8ebf..0000000000000 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g +++ /dev/null @@ -1,330 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -parser grammar FromClauseParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------------------------------------------------------------------- - -tableAllColumns - : STAR - -> ^(TOK_ALLCOLREF) - | tableName DOT STAR - -> ^(TOK_ALLCOLREF tableName) - ; - -// (table|column) -tableOrColumn -@init { gParent.pushMsg("table or column identifier", state); } -@after { gParent.popMsg(state); } - : - identifier -> ^(TOK_TABLE_OR_COL identifier) - ; - -expressionList -@init { gParent.pushMsg("expression list", state); } -@after { gParent.popMsg(state); } - : - expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) - ; - -aliasList -@init { gParent.pushMsg("alias list", state); } -@after { gParent.popMsg(state); } - : - identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) - ; - -//----------------------- Rules for parsing fromClause ------------------------------ -// from [col1, col2, col3] table1, [col4, col5] table2 -fromClause -@init { gParent.pushMsg("from clause", state); } -@after { gParent.popMsg(state); } - : - KW_FROM joinSource -> ^(TOK_FROM joinSource) - ; - -joinSource -@init { gParent.pushMsg("join source", state); } -@after { gParent.popMsg(state); } - : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )* - | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ - ; - -uniqueJoinSource -@init { gParent.pushMsg("unique join source", state); } -@after { gParent.popMsg(state); } - : KW_PRESERVE? fromSource uniqueJoinExpr - ; - -uniqueJoinExpr -@init { gParent.pushMsg("unique join expression list", state); } -@after { gParent.popMsg(state); } - : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN - -> ^(TOK_EXPLIST $e1*) - ; - -uniqueJoinToken -@init { gParent.pushMsg("unique join", state); } -@after { gParent.popMsg(state); } - : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; - -joinToken -@init { gParent.pushMsg("join type specifier", state); } -@after { gParent.popMsg(state); } - : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN - ; - -lateralView -@init {gParent.pushMsg("lateral view", state); } -@after {gParent.popMsg(state); } - : - (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? - -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) - | - KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? - -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) - ; - -tableAlias -@init {gParent.pushMsg("table alias", state); } -@after {gParent.popMsg(state); } - : - identifier -> ^(TOK_TABALIAS identifier) - ; - -fromSource -@init { gParent.pushMsg("from source", state); } -@after { gParent.popMsg(state); } - : - (LPAREN KW_VALUES) => fromSource0 - | (LPAREN) => LPAREN joinSource RPAREN -> joinSource - | fromSource0 - ; - - -fromSource0 -@init { gParent.pushMsg("from source 0", state); } -@after { gParent.popMsg(state); } - : - ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* - ; - -tableBucketSample -@init { gParent.pushMsg("table bucket sample specification", state); } -@after { gParent.popMsg(state); } - : - KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) - ; - -splitSample -@init { gParent.pushMsg("table split sample specification", state); } -@after { gParent.popMsg(state); } - : - KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN - -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) - -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) - | - KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN - -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) - ; - -tableSample -@init { gParent.pushMsg("table sample specification", state); } -@after { gParent.popMsg(state); } - : - tableBucketSample | - splitSample - ; - -tableSource -@init { gParent.pushMsg("table source", state); } -@after { gParent.popMsg(state); } - : tabname=tableName - ((tableProperties) => props=tableProperties)? - ((tableSample) => ts=tableSample)? - ((KW_AS) => (KW_AS alias=Identifier) - | - (Identifier) => (alias=Identifier))? - -> ^(TOK_TABREF $tabname $props? $ts? $alias?) - ; - -tableName -@init { gParent.pushMsg("table name", state); } -@after { gParent.popMsg(state); } - : - db=identifier DOT tab=identifier - -> ^(TOK_TABNAME $db $tab) - | - tab=identifier - -> ^(TOK_TABNAME $tab) - ; - -viewName -@init { gParent.pushMsg("view name", state); } -@after { gParent.popMsg(state); } - : - (db=identifier DOT)? view=identifier - -> ^(TOK_TABNAME $db? $view) - ; - -subQuerySource -@init { gParent.pushMsg("subquery source", state); } -@after { gParent.popMsg(state); } - : - LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) - ; - -//---------------------- Rules for parsing PTF clauses ----------------------------- -partitioningSpec -@init { gParent.pushMsg("partitioningSpec clause", state); } -@after { gParent.popMsg(state); } - : - partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | - orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | - distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) | - sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | - clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) - ; - -partitionTableFunctionSource -@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } -@after { gParent.popMsg(state); } - : - subQuerySource | - tableSource | - partitionedTableFunction - ; - -partitionedTableFunction -@init { gParent.pushMsg("ptf clause", state); } -@after { gParent.popMsg(state); } - : - name=Identifier LPAREN KW_ON - ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) - ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? - ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? - -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) - ; - -//----------------------- Rules for parsing whereClause ----------------------------- -// where a=b and ... -whereClause -@init { gParent.pushMsg("where clause", state); } -@after { gParent.popMsg(state); } - : - KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) - ; - -searchCondition -@init { gParent.pushMsg("search condition", state); } -@after { gParent.popMsg(state); } - : - expression - ; - -//----------------------------------------------------------------------------------- - -//-------- Row Constructor ---------------------------------------------------------- -//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and -// INSERT INTO
    (col1,col2,...) VALUES(...),(...),... -// INSERT INTO
    (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) -valueRowConstructor -@init { gParent.pushMsg("value row constructor", state); } -@after { gParent.popMsg(state); } - : - LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) - ; - -valuesTableConstructor -@init { gParent.pushMsg("values table constructor", state); } -@after { gParent.popMsg(state); } - : - valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) - ; - -/* -VALUES(1),(2) means 2 rows, 1 column each. -VALUES(1,2),(3,4) means 2 rows, 2 columns each. -VALUES(1,2,3) means 1 row, 3 columns -*/ -valuesClause -@init { gParent.pushMsg("values clause", state); } -@after { gParent.popMsg(state); } - : - KW_VALUES valuesTableConstructor -> valuesTableConstructor - ; - -/* -This represents a clause like this: -(VALUES(1,2),(2,3)) as VirtTable(col1,col2) -*/ -virtualTableSource -@init { gParent.pushMsg("virtual table source", state); } -@after { gParent.popMsg(state); } - : - LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) - ; -/* -e.g. as VirtTable(col1,col2) -Note that we only want literals as column names -*/ -tableNameColList -@init { gParent.pushMsg("from source", state); } -@after { gParent.popMsg(state); } - : - KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) - ; - -//----------------------------------------------------------------------------------- diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g deleted file mode 100644 index 5c3d7ef866240..0000000000000 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g +++ /dev/null @@ -1,697 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -parser grammar IdentifiersParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------------------------------------------------------------------- - -// group by a,b -groupByClause -@init { gParent.pushMsg("group by clause", state); } -@after { gParent.popMsg(state); } - : - KW_GROUP KW_BY - expression - ( COMMA expression)* - ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? - (sets=KW_GROUPING KW_SETS - LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? - -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) - -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) - -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) - -> ^(TOK_GROUPBY expression+) - ; - -groupingSetExpression -@init {gParent.pushMsg("grouping set expression", state); } -@after {gParent.popMsg(state); } - : - (LPAREN) => groupingSetExpressionMultiple - | - groupingExpressionSingle - ; - -groupingSetExpressionMultiple -@init {gParent.pushMsg("grouping set part expression", state); } -@after {gParent.popMsg(state); } - : - LPAREN - expression? (COMMA expression)* - RPAREN - -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) - ; - -groupingExpressionSingle -@init { gParent.pushMsg("groupingExpression expression", state); } -@after { gParent.popMsg(state); } - : - expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) - ; - -havingClause -@init { gParent.pushMsg("having clause", state); } -@after { gParent.popMsg(state); } - : - KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) - ; - -havingCondition -@init { gParent.pushMsg("having condition", state); } -@after { gParent.popMsg(state); } - : - expression - ; - -expressionsInParenthese - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -expressionsNotInParenthese - : - expression (COMMA expression)* -> expression+ - ; - -columnRefOrderInParenthese - : - LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ - ; - -columnRefOrderNotInParenthese - : - columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ - ; - -// order by a,b -orderByClause -@init { gParent.pushMsg("order by clause", state); } -@after { gParent.popMsg(state); } - : - KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) - ; - -clusterByClause -@init { gParent.pushMsg("cluster by clause", state); } -@after { gParent.popMsg(state); } - : - KW_CLUSTER KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) - ) - ; - -partitionByClause -@init { gParent.pushMsg("partition by clause", state); } -@after { gParent.popMsg(state); } - : - KW_PARTITION KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) - ) - ; - -distributeByClause -@init { gParent.pushMsg("distribute by clause", state); } -@after { gParent.popMsg(state); } - : - KW_DISTRIBUTE KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) - ) - ; - -sortByClause -@init { gParent.pushMsg("sort by clause", state); } -@after { gParent.popMsg(state); } - : - KW_SORT KW_BY - ( - (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) - | - columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) - ) - ; - -// fun(par1, par2, par3) -function -@init { gParent.pushMsg("function specification", state); } -@after { gParent.popMsg(state); } - : - functionName - LPAREN - ( - (STAR) => (star=STAR) - | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? - ) - RPAREN (KW_OVER ws=window_specification)? - -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) - -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) - -> ^(TOK_FUNCTIONDI functionName (selectExpression+)?) - ; - -functionName -@init { gParent.pushMsg("function name", state); } -@after { gParent.popMsg(state); } - : // Keyword IF is also a function name - (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) - | - (functionIdentifier) => functionIdentifier - | - {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] - ; - -castExpression -@init { gParent.pushMsg("cast expression", state); } -@after { gParent.popMsg(state); } - : - KW_CAST - LPAREN - expression - KW_AS - primitiveType - RPAREN -> ^(TOK_FUNCTION primitiveType expression) - ; - -caseExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE expression - (KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_CASE expression*) - ; - -whenExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE - ( KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) - ; - -constant -@init { gParent.pushMsg("constant", state); } -@after { gParent.popMsg(state); } - : - Number - | dateLiteral - | timestampLiteral - | intervalLiteral - | StringLiteral - | stringLiteralSequence - | BigintLiteral - | SmallintLiteral - | TinyintLiteral - | DecimalLiteral - | charSetStringLiteral - | booleanValue - ; - -stringLiteralSequence - : - StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) - ; - -charSetStringLiteral -@init { gParent.pushMsg("character string literal", state); } -@after { gParent.popMsg(state); } - : - csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) - ; - -dateLiteral - : - KW_DATE StringLiteral -> - { - // Create DateLiteral token, but with the text of the string value - // This makes the dateLiteral more consistent with the other type literals. - adaptor.create(TOK_DATELITERAL, $StringLiteral.text) - } - | - KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) - ; - -timestampLiteral - : - KW_TIMESTAMP StringLiteral -> - { - adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) - } - | - KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) - ; - -intervalLiteral - : - KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> - { - adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) - } - ; - -intervalQualifiers - : - KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL - | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL - | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL - | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL - | KW_DAY -> TOK_INTERVAL_DAY_LITERAL - | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL - | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL - | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL - ; - -expression -@init { gParent.pushMsg("expression specification", state); } -@after { gParent.popMsg(state); } - : - precedenceOrExpression - ; - -atomExpression - : - (KW_NULL) => KW_NULL -> TOK_NULL - | (constant) => constant - | castExpression - | caseExpression - | whenExpression - | (functionName LPAREN) => function - | tableOrColumn - | LPAREN! expression RPAREN! - ; - - -precedenceFieldExpression - : - atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* - ; - -precedenceUnaryOperator - : - PLUS | MINUS | TILDE - ; - -nullCondition - : - KW_NULL -> ^(TOK_ISNULL) - | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) - ; - -precedenceUnaryPrefixExpression - : - (precedenceUnaryOperator^)* precedenceFieldExpression - ; - -precedenceUnarySuffixExpression - : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? - -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) - -> precedenceUnaryPrefixExpression - ; - - -precedenceBitwiseXorOperator - : - BITWISEXOR - ; - -precedenceBitwiseXorExpression - : - precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* - ; - - -precedenceStarOperator - : - STAR | DIVIDE | MOD | DIV - ; - -precedenceStarExpression - : - precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* - ; - - -precedencePlusOperator - : - PLUS | MINUS - ; - -precedencePlusExpression - : - precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* - ; - - -precedenceAmpersandOperator - : - AMPERSAND - ; - -precedenceAmpersandExpression - : - precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* - ; - - -precedenceBitwiseOrOperator - : - BITWISEOR - ; - -precedenceBitwiseOrExpression - : - precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* - ; - - -// Equal operators supporting NOT prefix -precedenceEqualNegatableOperator - : - KW_LIKE | KW_RLIKE | KW_REGEXP - ; - -precedenceEqualOperator - : - precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -subQueryExpression - : - LPAREN! selectStatement[true] RPAREN! - ; - -precedenceEqualExpression - : - (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple - | - precedenceEqualExpressionSingle - ; - -precedenceEqualExpressionSingle - : - (left=precedenceBitwiseOrExpression -> $left) - ( - (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) - -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) - | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) - -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) - | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) - -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) - | (KW_NOT KW_IN expressions) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) - | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) - -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) - | (KW_IN expressions) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) - | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) - | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) - )* - | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) - ; - -expressions - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) -precedenceEqualExpressionMutiple - : - (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) - ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) - | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) - ; - -expressionsToStruct - : - LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) - ; - -precedenceNotOperator - : - KW_NOT - ; - -precedenceNotExpression - : - (precedenceNotOperator^)* precedenceEqualExpression - ; - - -precedenceAndOperator - : - KW_AND - ; - -precedenceAndExpression - : - precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* - ; - - -precedenceOrOperator - : - KW_OR - ; - -precedenceOrExpression - : - precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* - ; - - -booleanValue - : - KW_TRUE^ | KW_FALSE^ - ; - -booleanValueTok - : - KW_TRUE -> TOK_TRUE - | KW_FALSE -> TOK_FALSE - ; - -tableOrPartition - : - tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) - ; - -partitionSpec - : - KW_PARTITION - LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) - ; - -partitionVal - : - identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) - ; - -dropPartitionSpec - : - KW_PARTITION - LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) - ; - -dropPartitionVal - : - identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) - ; - -dropPartitionOperator - : - EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -sysFuncNames - : - KW_AND - | KW_OR - | KW_NOT - | KW_LIKE - | KW_IF - | KW_CASE - | KW_WHEN - | KW_TINYINT - | KW_SMALLINT - | KW_INT - | KW_BIGINT - | KW_FLOAT - | KW_DOUBLE - | KW_BOOLEAN - | KW_STRING - | KW_BINARY - | KW_ARRAY - | KW_MAP - | KW_STRUCT - | KW_UNIONTYPE - | EQUAL - | EQUAL_NS - | NOTEQUAL - | LESSTHANOREQUALTO - | LESSTHAN - | GREATERTHANOREQUALTO - | GREATERTHAN - | DIVIDE - | PLUS - | MINUS - | STAR - | MOD - | DIV - | AMPERSAND - | TILDE - | BITWISEOR - | BITWISEXOR - | KW_RLIKE - | KW_REGEXP - | KW_IN - | KW_BETWEEN - ; - -descFuncNames - : - (sysFuncNames) => sysFuncNames - | StringLiteral - | functionIdentifier - ; - -identifier - : - Identifier - | nonReserved -> Identifier[$nonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -functionIdentifier -@init { gParent.pushMsg("function identifier", state); } -@after { gParent.popMsg(state); } - : db=identifier DOT fn=identifier - -> Identifier[$db.text + "." + $fn.text] - | - identifier - ; - -principalIdentifier -@init { gParent.pushMsg("identifier for principal spec", state); } -@after { gParent.popMsg(state); } - : identifier - | QuotedIdentifier - ; - -//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved -//Non reserved keywords are basically the keywords that can be used as identifiers. -//All the KW_* are automatically not only keywords, but also reserved keywords. -//That means, they can NOT be used as identifiers. -//If you would like to use them as identifiers, put them in the nonReserved list below. -//If you are not sure, please refer to the SQL2011 column in -//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html -nonReserved - : - KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS - | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS - | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY - | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY - | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE - | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT - | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE - | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR - | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG - | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE - | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY - | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER - | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE - | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED - | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED - | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED - | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET - | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR - | KW_WORK - | KW_TRANSACTION - | KW_WRITE - | KW_ISOLATION - | KW_LEVEL - | KW_SNAPSHOT - | KW_AUTOCOMMIT - | KW_ANTI -; - -//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. -sql11ReservedKeywordsUsedAsCastFunctionName - : - KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP - ; - -//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. -//We are planning to remove the following whole list after several releases. -//Thus, please do not change the following list unless you know what to do. -sql11ReservedKeywordsUsedAsIdentifier - : - KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN - | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE - | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT - | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL - | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION - | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT - | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE - | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH -//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. - | KW_REGEXP | KW_RLIKE - ; diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g deleted file mode 100644 index 48bc8b0a300af..0000000000000 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g +++ /dev/null @@ -1,226 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -parser grammar SelectClauseParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------- Rules for parsing selectClause ----------------------------- -// select a,b,c ... -selectClause -@init { gParent.pushMsg("select clause", state); } -@after { gParent.popMsg(state); } - : - KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) - | (transform=KW_TRANSFORM selectTrfmClause)) - -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) - -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) - -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) ) - | - trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) - ; - -selectList -@init { gParent.pushMsg("select list", state); } -@after { gParent.popMsg(state); } - : - selectItem ( COMMA selectItem )* -> selectItem+ - ; - -selectTrfmClause -@init { gParent.pushMsg("transform clause", state); } -@after { gParent.popMsg(state); } - : - LPAREN selectExpressionList RPAREN - inSerde=rowFormat inRec=recordWriter - KW_USING StringLiteral - ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? - outSerde=rowFormat outRec=recordReader - -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) - ; - -hintClause -@init { gParent.pushMsg("hint clause", state); } -@after { gParent.popMsg(state); } - : - DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) - ; - -hintList -@init { gParent.pushMsg("hint list", state); } -@after { gParent.popMsg(state); } - : - hintItem (COMMA hintItem)* -> hintItem+ - ; - -hintItem -@init { gParent.pushMsg("hint item", state); } -@after { gParent.popMsg(state); } - : - hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) - ; - -hintName -@init { gParent.pushMsg("hint name", state); } -@after { gParent.popMsg(state); } - : - KW_MAPJOIN -> TOK_MAPJOIN - | KW_STREAMTABLE -> TOK_STREAMTABLE - ; - -hintArgs -@init { gParent.pushMsg("hint arguments", state); } -@after { gParent.popMsg(state); } - : - hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) - ; - -hintArgName -@init { gParent.pushMsg("hint argument name", state); } -@after { gParent.popMsg(state); } - : - identifier - ; - -selectItem -@init { gParent.pushMsg("selection target", state); } -@after { gParent.popMsg(state); } - : - (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) - | - ( expression - ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? - ) -> ^(TOK_SELEXPR expression identifier*) - ; - -trfmClause -@init { gParent.pushMsg("transform clause", state); } -@after { gParent.popMsg(state); } - : - ( KW_MAP selectExpressionList - | KW_REDUCE selectExpressionList ) - inSerde=rowFormat inRec=recordWriter - KW_USING StringLiteral - ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? - outSerde=rowFormat outRec=recordReader - -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) - ; - -selectExpression -@init { gParent.pushMsg("select expression", state); } -@after { gParent.popMsg(state); } - : - (tableAllColumns) => tableAllColumns - | - expression - ; - -selectExpressionList -@init { gParent.pushMsg("select expression list", state); } -@after { gParent.popMsg(state); } - : - selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) - ; - -//---------------------- Rules for windowing clauses ------------------------------- -window_clause -@init { gParent.pushMsg("window_clause", state); } -@after { gParent.popMsg(state); } -: - KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) -; - -window_defn -@init { gParent.pushMsg("window_defn", state); } -@after { gParent.popMsg(state); } -: - Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) -; - -window_specification -@init { gParent.pushMsg("window_specification", state); } -@after { gParent.popMsg(state); } -: - (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) -; - -window_frame : - window_range_expression | - window_value_expression -; - -window_range_expression -@init { gParent.pushMsg("window_range_expression", state); } -@after { gParent.popMsg(state); } -: - KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | - KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) -; - -window_value_expression -@init { gParent.pushMsg("window_value_expression", state); } -@after { gParent.popMsg(state); } -: - KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | - KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) -; - -window_frame_start_boundary -@init { gParent.pushMsg("windowframestartboundary", state); } -@after { gParent.popMsg(state); } -: - KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | - KW_CURRENT KW_ROW -> ^(KW_CURRENT) | - Number KW_PRECEDING -> ^(KW_PRECEDING Number) -; - -window_frame_boundary -@init { gParent.pushMsg("windowframeboundary", state); } -@after { gParent.popMsg(state); } -: - KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | - KW_CURRENT KW_ROW -> ^(KW_CURRENT) | - Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) -; - diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g deleted file mode 100644 index ee1b8989b5aff..0000000000000 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g +++ /dev/null @@ -1,474 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -lexer grammar SparkSqlLexer; - -@lexer::header { -package org.apache.spark.sql.parser; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; -} - -@lexer::members { - private Configuration hiveConf; - - public void setHiveConf(Configuration hiveConf) { - this.hiveConf = hiveConf; - } - - protected boolean allowQuotedId() { - String supportedQIds = HiveConf.getVar(hiveConf, HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT); - return !"none".equals(supportedQIds); - } -} - -// Keywords - -KW_TRUE : 'TRUE'; -KW_FALSE : 'FALSE'; -KW_ALL : 'ALL'; -KW_NONE: 'NONE'; -KW_AND : 'AND'; -KW_OR : 'OR'; -KW_NOT : 'NOT' | '!'; -KW_LIKE : 'LIKE'; - -KW_IF : 'IF'; -KW_EXISTS : 'EXISTS'; - -KW_ASC : 'ASC'; -KW_DESC : 'DESC'; -KW_ORDER : 'ORDER'; -KW_GROUP : 'GROUP'; -KW_BY : 'BY'; -KW_HAVING : 'HAVING'; -KW_WHERE : 'WHERE'; -KW_FROM : 'FROM'; -KW_AS : 'AS'; -KW_SELECT : 'SELECT'; -KW_DISTINCT : 'DISTINCT'; -KW_INSERT : 'INSERT'; -KW_OVERWRITE : 'OVERWRITE'; -KW_OUTER : 'OUTER'; -KW_UNIQUEJOIN : 'UNIQUEJOIN'; -KW_PRESERVE : 'PRESERVE'; -KW_JOIN : 'JOIN'; -KW_LEFT : 'LEFT'; -KW_RIGHT : 'RIGHT'; -KW_FULL : 'FULL'; -KW_ANTI : 'ANTI'; -KW_ON : 'ON'; -KW_PARTITION : 'PARTITION'; -KW_PARTITIONS : 'PARTITIONS'; -KW_TABLE: 'TABLE'; -KW_TABLES: 'TABLES'; -KW_COLUMNS: 'COLUMNS'; -KW_INDEX: 'INDEX'; -KW_INDEXES: 'INDEXES'; -KW_REBUILD: 'REBUILD'; -KW_FUNCTIONS: 'FUNCTIONS'; -KW_SHOW: 'SHOW'; -KW_MSCK: 'MSCK'; -KW_REPAIR: 'REPAIR'; -KW_DIRECTORY: 'DIRECTORY'; -KW_LOCAL: 'LOCAL'; -KW_TRANSFORM : 'TRANSFORM'; -KW_USING: 'USING'; -KW_CLUSTER: 'CLUSTER'; -KW_DISTRIBUTE: 'DISTRIBUTE'; -KW_SORT: 'SORT'; -KW_UNION: 'UNION'; -KW_LOAD: 'LOAD'; -KW_EXPORT: 'EXPORT'; -KW_IMPORT: 'IMPORT'; -KW_REPLICATION: 'REPLICATION'; -KW_METADATA: 'METADATA'; -KW_DATA: 'DATA'; -KW_INPATH: 'INPATH'; -KW_IS: 'IS'; -KW_NULL: 'NULL'; -KW_CREATE: 'CREATE'; -KW_EXTERNAL: 'EXTERNAL'; -KW_ALTER: 'ALTER'; -KW_CHANGE: 'CHANGE'; -KW_COLUMN: 'COLUMN'; -KW_FIRST: 'FIRST'; -KW_AFTER: 'AFTER'; -KW_DESCRIBE: 'DESCRIBE'; -KW_DROP: 'DROP'; -KW_RENAME: 'RENAME'; -KW_TO: 'TO'; -KW_COMMENT: 'COMMENT'; -KW_BOOLEAN: 'BOOLEAN'; -KW_TINYINT: 'TINYINT'; -KW_SMALLINT: 'SMALLINT'; -KW_INT: 'INT'; -KW_BIGINT: 'BIGINT'; -KW_FLOAT: 'FLOAT'; -KW_DOUBLE: 'DOUBLE'; -KW_DATE: 'DATE'; -KW_DATETIME: 'DATETIME'; -KW_TIMESTAMP: 'TIMESTAMP'; -KW_INTERVAL: 'INTERVAL'; -KW_DECIMAL: 'DECIMAL'; -KW_STRING: 'STRING'; -KW_CHAR: 'CHAR'; -KW_VARCHAR: 'VARCHAR'; -KW_ARRAY: 'ARRAY'; -KW_STRUCT: 'STRUCT'; -KW_MAP: 'MAP'; -KW_UNIONTYPE: 'UNIONTYPE'; -KW_REDUCE: 'REDUCE'; -KW_PARTITIONED: 'PARTITIONED'; -KW_CLUSTERED: 'CLUSTERED'; -KW_SORTED: 'SORTED'; -KW_INTO: 'INTO'; -KW_BUCKETS: 'BUCKETS'; -KW_ROW: 'ROW'; -KW_ROWS: 'ROWS'; -KW_FORMAT: 'FORMAT'; -KW_DELIMITED: 'DELIMITED'; -KW_FIELDS: 'FIELDS'; -KW_TERMINATED: 'TERMINATED'; -KW_ESCAPED: 'ESCAPED'; -KW_COLLECTION: 'COLLECTION'; -KW_ITEMS: 'ITEMS'; -KW_KEYS: 'KEYS'; -KW_KEY_TYPE: '$KEY$'; -KW_LINES: 'LINES'; -KW_STORED: 'STORED'; -KW_FILEFORMAT: 'FILEFORMAT'; -KW_INPUTFORMAT: 'INPUTFORMAT'; -KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; -KW_INPUTDRIVER: 'INPUTDRIVER'; -KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; -KW_ENABLE: 'ENABLE'; -KW_DISABLE: 'DISABLE'; -KW_LOCATION: 'LOCATION'; -KW_TABLESAMPLE: 'TABLESAMPLE'; -KW_BUCKET: 'BUCKET'; -KW_OUT: 'OUT'; -KW_OF: 'OF'; -KW_PERCENT: 'PERCENT'; -KW_CAST: 'CAST'; -KW_ADD: 'ADD'; -KW_REPLACE: 'REPLACE'; -KW_RLIKE: 'RLIKE'; -KW_REGEXP: 'REGEXP'; -KW_TEMPORARY: 'TEMPORARY'; -KW_FUNCTION: 'FUNCTION'; -KW_MACRO: 'MACRO'; -KW_FILE: 'FILE'; -KW_JAR: 'JAR'; -KW_EXPLAIN: 'EXPLAIN'; -KW_EXTENDED: 'EXTENDED'; -KW_FORMATTED: 'FORMATTED'; -KW_PRETTY: 'PRETTY'; -KW_DEPENDENCY: 'DEPENDENCY'; -KW_LOGICAL: 'LOGICAL'; -KW_SERDE: 'SERDE'; -KW_WITH: 'WITH'; -KW_DEFERRED: 'DEFERRED'; -KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; -KW_DBPROPERTIES: 'DBPROPERTIES'; -KW_LIMIT: 'LIMIT'; -KW_SET: 'SET'; -KW_UNSET: 'UNSET'; -KW_TBLPROPERTIES: 'TBLPROPERTIES'; -KW_IDXPROPERTIES: 'IDXPROPERTIES'; -KW_VALUE_TYPE: '$VALUE$'; -KW_ELEM_TYPE: '$ELEM$'; -KW_DEFINED: 'DEFINED'; -KW_CASE: 'CASE'; -KW_WHEN: 'WHEN'; -KW_THEN: 'THEN'; -KW_ELSE: 'ELSE'; -KW_END: 'END'; -KW_MAPJOIN: 'MAPJOIN'; -KW_STREAMTABLE: 'STREAMTABLE'; -KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; -KW_UTC: 'UTC'; -KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; -KW_LONG: 'LONG'; -KW_DELETE: 'DELETE'; -KW_PLUS: 'PLUS'; -KW_MINUS: 'MINUS'; -KW_FETCH: 'FETCH'; -KW_INTERSECT: 'INTERSECT'; -KW_VIEW: 'VIEW'; -KW_IN: 'IN'; -KW_DATABASE: 'DATABASE'; -KW_DATABASES: 'DATABASES'; -KW_MATERIALIZED: 'MATERIALIZED'; -KW_SCHEMA: 'SCHEMA'; -KW_SCHEMAS: 'SCHEMAS'; -KW_GRANT: 'GRANT'; -KW_REVOKE: 'REVOKE'; -KW_SSL: 'SSL'; -KW_UNDO: 'UNDO'; -KW_LOCK: 'LOCK'; -KW_LOCKS: 'LOCKS'; -KW_UNLOCK: 'UNLOCK'; -KW_SHARED: 'SHARED'; -KW_EXCLUSIVE: 'EXCLUSIVE'; -KW_PROCEDURE: 'PROCEDURE'; -KW_UNSIGNED: 'UNSIGNED'; -KW_WHILE: 'WHILE'; -KW_READ: 'READ'; -KW_READS: 'READS'; -KW_PURGE: 'PURGE'; -KW_RANGE: 'RANGE'; -KW_ANALYZE: 'ANALYZE'; -KW_BEFORE: 'BEFORE'; -KW_BETWEEN: 'BETWEEN'; -KW_BOTH: 'BOTH'; -KW_BINARY: 'BINARY'; -KW_CROSS: 'CROSS'; -KW_CONTINUE: 'CONTINUE'; -KW_CURSOR: 'CURSOR'; -KW_TRIGGER: 'TRIGGER'; -KW_RECORDREADER: 'RECORDREADER'; -KW_RECORDWRITER: 'RECORDWRITER'; -KW_SEMI: 'SEMI'; -KW_LATERAL: 'LATERAL'; -KW_TOUCH: 'TOUCH'; -KW_ARCHIVE: 'ARCHIVE'; -KW_UNARCHIVE: 'UNARCHIVE'; -KW_COMPUTE: 'COMPUTE'; -KW_STATISTICS: 'STATISTICS'; -KW_USE: 'USE'; -KW_OPTION: 'OPTION'; -KW_CONCATENATE: 'CONCATENATE'; -KW_SHOW_DATABASE: 'SHOW_DATABASE'; -KW_UPDATE: 'UPDATE'; -KW_RESTRICT: 'RESTRICT'; -KW_CASCADE: 'CASCADE'; -KW_SKEWED: 'SKEWED'; -KW_ROLLUP: 'ROLLUP'; -KW_CUBE: 'CUBE'; -KW_DIRECTORIES: 'DIRECTORIES'; -KW_FOR: 'FOR'; -KW_WINDOW: 'WINDOW'; -KW_UNBOUNDED: 'UNBOUNDED'; -KW_PRECEDING: 'PRECEDING'; -KW_FOLLOWING: 'FOLLOWING'; -KW_CURRENT: 'CURRENT'; -KW_CURRENT_DATE: 'CURRENT_DATE'; -KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; -KW_LESS: 'LESS'; -KW_MORE: 'MORE'; -KW_OVER: 'OVER'; -KW_GROUPING: 'GROUPING'; -KW_SETS: 'SETS'; -KW_TRUNCATE: 'TRUNCATE'; -KW_NOSCAN: 'NOSCAN'; -KW_PARTIALSCAN: 'PARTIALSCAN'; -KW_USER: 'USER'; -KW_ROLE: 'ROLE'; -KW_ROLES: 'ROLES'; -KW_INNER: 'INNER'; -KW_EXCHANGE: 'EXCHANGE'; -KW_URI: 'URI'; -KW_SERVER : 'SERVER'; -KW_ADMIN: 'ADMIN'; -KW_OWNER: 'OWNER'; -KW_PRINCIPALS: 'PRINCIPALS'; -KW_COMPACT: 'COMPACT'; -KW_COMPACTIONS: 'COMPACTIONS'; -KW_TRANSACTIONS: 'TRANSACTIONS'; -KW_REWRITE : 'REWRITE'; -KW_AUTHORIZATION: 'AUTHORIZATION'; -KW_CONF: 'CONF'; -KW_VALUES: 'VALUES'; -KW_RELOAD: 'RELOAD'; -KW_YEAR: 'YEAR'; -KW_MONTH: 'MONTH'; -KW_DAY: 'DAY'; -KW_HOUR: 'HOUR'; -KW_MINUTE: 'MINUTE'; -KW_SECOND: 'SECOND'; -KW_START: 'START'; -KW_TRANSACTION: 'TRANSACTION'; -KW_COMMIT: 'COMMIT'; -KW_ROLLBACK: 'ROLLBACK'; -KW_WORK: 'WORK'; -KW_ONLY: 'ONLY'; -KW_WRITE: 'WRITE'; -KW_ISOLATION: 'ISOLATION'; -KW_LEVEL: 'LEVEL'; -KW_SNAPSHOT: 'SNAPSHOT'; -KW_AUTOCOMMIT: 'AUTOCOMMIT'; - -// Operators -// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. - -DOT : '.'; // generated as a part of Number rule -COLON : ':' ; -COMMA : ',' ; -SEMICOLON : ';' ; - -LPAREN : '(' ; -RPAREN : ')' ; -LSQUARE : '[' ; -RSQUARE : ']' ; -LCURLY : '{'; -RCURLY : '}'; - -EQUAL : '=' | '=='; -EQUAL_NS : '<=>'; -NOTEQUAL : '<>' | '!='; -LESSTHANOREQUALTO : '<='; -LESSTHAN : '<'; -GREATERTHANOREQUALTO : '>='; -GREATERTHAN : '>'; - -DIVIDE : '/'; -PLUS : '+'; -MINUS : '-'; -STAR : '*'; -MOD : '%'; -DIV : 'DIV'; - -AMPERSAND : '&'; -TILDE : '~'; -BITWISEOR : '|'; -BITWISEXOR : '^'; -QUESTION : '?'; -DOLLAR : '$'; - -// LITERALS -fragment -Letter - : 'a'..'z' | 'A'..'Z' - ; - -fragment -HexDigit - : 'a'..'f' | 'A'..'F' - ; - -fragment -Digit - : - '0'..'9' - ; - -fragment -Exponent - : - ('e' | 'E') ( PLUS|MINUS )? (Digit)+ - ; - -fragment -RegexComponent - : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' - | PLUS | STAR | QUESTION | MINUS | DOT - | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY - | BITWISEXOR | BITWISEOR | DOLLAR | '!' - ; - -StringLiteral - : - ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' - )+ - ; - -CharSetLiteral - : - StringLiteral - | '0' 'X' (HexDigit|Digit)+ - ; - -BigintLiteral - : - (Digit)+ 'L' - ; - -SmallintLiteral - : - (Digit)+ 'S' - ; - -TinyintLiteral - : - (Digit)+ 'Y' - ; - -DecimalLiteral - : - Number 'B' 'D' - ; - -ByteLengthLiteral - : - (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') - ; - -Number - : - (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? - ; - -/* -An Identifier can be: -- tableName -- columnName -- select expr alias -- lateral view aliases -- database name -- view name -- subquery alias -- function name -- ptf argument identifier -- index name -- property name for: db,tbl,partition... -- fileFormat -- role name -- privilege name -- principal name -- macro name -- hint name -- window name -*/ -Identifier - : - (Letter | Digit) (Letter | Digit | '_')* - | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; - at the API level only columns are allowed to be of this form */ - | '`' RegexComponent+ '`' - ; - -fragment -QuotedIdentifier - : - '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } - ; - -CharSetName - : - '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ - ; - -WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} - ; - -COMMENT - : '--' (~('\n'|'\r'))* - { $channel=HIDDEN; } - ; - diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g deleted file mode 100644 index 69574d713d0be..0000000000000 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g +++ /dev/null @@ -1,2457 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -parser grammar SparkSqlParser; - -options -{ -tokenVocab=SparkSqlLexer; -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} -import SelectClauseParser, FromClauseParser, IdentifiersParser; - -tokens { -TOK_INSERT; -TOK_QUERY; -TOK_SELECT; -TOK_SELECTDI; -TOK_SELEXPR; -TOK_FROM; -TOK_TAB; -TOK_PARTSPEC; -TOK_PARTVAL; -TOK_DIR; -TOK_TABREF; -TOK_SUBQUERY; -TOK_INSERT_INTO; -TOK_DESTINATION; -TOK_ALLCOLREF; -TOK_TABLE_OR_COL; -TOK_FUNCTION; -TOK_FUNCTIONDI; -TOK_FUNCTIONSTAR; -TOK_WHERE; -TOK_OP_EQ; -TOK_OP_NE; -TOK_OP_LE; -TOK_OP_LT; -TOK_OP_GE; -TOK_OP_GT; -TOK_OP_DIV; -TOK_OP_ADD; -TOK_OP_SUB; -TOK_OP_MUL; -TOK_OP_MOD; -TOK_OP_BITAND; -TOK_OP_BITNOT; -TOK_OP_BITOR; -TOK_OP_BITXOR; -TOK_OP_AND; -TOK_OP_OR; -TOK_OP_NOT; -TOK_OP_LIKE; -TOK_TRUE; -TOK_FALSE; -TOK_TRANSFORM; -TOK_SERDE; -TOK_SERDENAME; -TOK_SERDEPROPS; -TOK_EXPLIST; -TOK_ALIASLIST; -TOK_GROUPBY; -TOK_ROLLUP_GROUPBY; -TOK_CUBE_GROUPBY; -TOK_GROUPING_SETS; -TOK_GROUPING_SETS_EXPRESSION; -TOK_HAVING; -TOK_ORDERBY; -TOK_CLUSTERBY; -TOK_DISTRIBUTEBY; -TOK_SORTBY; -TOK_UNIONALL; -TOK_UNIONDISTINCT; -TOK_JOIN; -TOK_LEFTOUTERJOIN; -TOK_RIGHTOUTERJOIN; -TOK_FULLOUTERJOIN; -TOK_UNIQUEJOIN; -TOK_CROSSJOIN; -TOK_LOAD; -TOK_EXPORT; -TOK_IMPORT; -TOK_REPLICATION; -TOK_METADATA; -TOK_NULL; -TOK_ISNULL; -TOK_ISNOTNULL; -TOK_TINYINT; -TOK_SMALLINT; -TOK_INT; -TOK_BIGINT; -TOK_BOOLEAN; -TOK_FLOAT; -TOK_DOUBLE; -TOK_DATE; -TOK_DATELITERAL; -TOK_DATETIME; -TOK_TIMESTAMP; -TOK_TIMESTAMPLITERAL; -TOK_INTERVAL_YEAR_MONTH; -TOK_INTERVAL_YEAR_MONTH_LITERAL; -TOK_INTERVAL_DAY_TIME; -TOK_INTERVAL_DAY_TIME_LITERAL; -TOK_INTERVAL_YEAR_LITERAL; -TOK_INTERVAL_MONTH_LITERAL; -TOK_INTERVAL_DAY_LITERAL; -TOK_INTERVAL_HOUR_LITERAL; -TOK_INTERVAL_MINUTE_LITERAL; -TOK_INTERVAL_SECOND_LITERAL; -TOK_STRING; -TOK_CHAR; -TOK_VARCHAR; -TOK_BINARY; -TOK_DECIMAL; -TOK_LIST; -TOK_STRUCT; -TOK_MAP; -TOK_UNIONTYPE; -TOK_COLTYPELIST; -TOK_CREATEDATABASE; -TOK_CREATETABLE; -TOK_TRUNCATETABLE; -TOK_CREATEINDEX; -TOK_CREATEINDEX_INDEXTBLNAME; -TOK_DEFERRED_REBUILDINDEX; -TOK_DROPINDEX; -TOK_LIKETABLE; -TOK_DESCTABLE; -TOK_DESCFUNCTION; -TOK_ALTERTABLE; -TOK_ALTERTABLE_RENAME; -TOK_ALTERTABLE_ADDCOLS; -TOK_ALTERTABLE_RENAMECOL; -TOK_ALTERTABLE_RENAMEPART; -TOK_ALTERTABLE_REPLACECOLS; -TOK_ALTERTABLE_ADDPARTS; -TOK_ALTERTABLE_DROPPARTS; -TOK_ALTERTABLE_PARTCOLTYPE; -TOK_ALTERTABLE_MERGEFILES; -TOK_ALTERTABLE_TOUCH; -TOK_ALTERTABLE_ARCHIVE; -TOK_ALTERTABLE_UNARCHIVE; -TOK_ALTERTABLE_SERDEPROPERTIES; -TOK_ALTERTABLE_SERIALIZER; -TOK_ALTERTABLE_UPDATECOLSTATS; -TOK_TABLE_PARTITION; -TOK_ALTERTABLE_FILEFORMAT; -TOK_ALTERTABLE_LOCATION; -TOK_ALTERTABLE_PROPERTIES; -TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; -TOK_ALTERTABLE_DROPPROPERTIES; -TOK_ALTERTABLE_SKEWED; -TOK_ALTERTABLE_EXCHANGEPARTITION; -TOK_ALTERTABLE_SKEWED_LOCATION; -TOK_ALTERTABLE_BUCKETS; -TOK_ALTERTABLE_CLUSTER_SORT; -TOK_ALTERTABLE_COMPACT; -TOK_ALTERINDEX_REBUILD; -TOK_ALTERINDEX_PROPERTIES; -TOK_MSCK; -TOK_SHOWDATABASES; -TOK_SHOWTABLES; -TOK_SHOWCOLUMNS; -TOK_SHOWFUNCTIONS; -TOK_SHOWPARTITIONS; -TOK_SHOW_CREATEDATABASE; -TOK_SHOW_CREATETABLE; -TOK_SHOW_TABLESTATUS; -TOK_SHOW_TBLPROPERTIES; -TOK_SHOWLOCKS; -TOK_SHOWCONF; -TOK_LOCKTABLE; -TOK_UNLOCKTABLE; -TOK_LOCKDB; -TOK_UNLOCKDB; -TOK_SWITCHDATABASE; -TOK_DROPDATABASE; -TOK_DROPTABLE; -TOK_DATABASECOMMENT; -TOK_TABCOLLIST; -TOK_TABCOL; -TOK_TABLECOMMENT; -TOK_TABLEPARTCOLS; -TOK_TABLEROWFORMAT; -TOK_TABLEROWFORMATFIELD; -TOK_TABLEROWFORMATCOLLITEMS; -TOK_TABLEROWFORMATMAPKEYS; -TOK_TABLEROWFORMATLINES; -TOK_TABLEROWFORMATNULL; -TOK_TABLEFILEFORMAT; -TOK_FILEFORMAT_GENERIC; -TOK_OFFLINE; -TOK_ENABLE; -TOK_DISABLE; -TOK_READONLY; -TOK_NO_DROP; -TOK_STORAGEHANDLER; -TOK_NOT_CLUSTERED; -TOK_NOT_SORTED; -TOK_TABCOLNAME; -TOK_TABLELOCATION; -TOK_PARTITIONLOCATION; -TOK_TABLEBUCKETSAMPLE; -TOK_TABLESPLITSAMPLE; -TOK_PERCENT; -TOK_LENGTH; -TOK_ROWCOUNT; -TOK_TMP_FILE; -TOK_TABSORTCOLNAMEASC; -TOK_TABSORTCOLNAMEDESC; -TOK_STRINGLITERALSEQUENCE; -TOK_CHARSETLITERAL; -TOK_CREATEFUNCTION; -TOK_DROPFUNCTION; -TOK_RELOADFUNCTION; -TOK_CREATEMACRO; -TOK_DROPMACRO; -TOK_TEMPORARY; -TOK_CREATEVIEW; -TOK_DROPVIEW; -TOK_ALTERVIEW; -TOK_ALTERVIEW_PROPERTIES; -TOK_ALTERVIEW_DROPPROPERTIES; -TOK_ALTERVIEW_ADDPARTS; -TOK_ALTERVIEW_DROPPARTS; -TOK_ALTERVIEW_RENAME; -TOK_VIEWPARTCOLS; -TOK_EXPLAIN; -TOK_EXPLAIN_SQ_REWRITE; -TOK_TABLESERIALIZER; -TOK_TABLEPROPERTIES; -TOK_TABLEPROPLIST; -TOK_INDEXPROPERTIES; -TOK_INDEXPROPLIST; -TOK_TABTYPE; -TOK_LIMIT; -TOK_TABLEPROPERTY; -TOK_IFEXISTS; -TOK_IFNOTEXISTS; -TOK_ORREPLACE; -TOK_HINTLIST; -TOK_HINT; -TOK_MAPJOIN; -TOK_STREAMTABLE; -TOK_HINTARGLIST; -TOK_USERSCRIPTCOLNAMES; -TOK_USERSCRIPTCOLSCHEMA; -TOK_RECORDREADER; -TOK_RECORDWRITER; -TOK_LEFTSEMIJOIN; -TOK_ANTIJOIN; -TOK_LATERAL_VIEW; -TOK_LATERAL_VIEW_OUTER; -TOK_TABALIAS; -TOK_ANALYZE; -TOK_CREATEROLE; -TOK_DROPROLE; -TOK_GRANT; -TOK_REVOKE; -TOK_SHOW_GRANT; -TOK_PRIVILEGE_LIST; -TOK_PRIVILEGE; -TOK_PRINCIPAL_NAME; -TOK_USER; -TOK_GROUP; -TOK_ROLE; -TOK_RESOURCE_ALL; -TOK_GRANT_WITH_OPTION; -TOK_GRANT_WITH_ADMIN_OPTION; -TOK_ADMIN_OPTION_FOR; -TOK_GRANT_OPTION_FOR; -TOK_PRIV_ALL; -TOK_PRIV_ALTER_METADATA; -TOK_PRIV_ALTER_DATA; -TOK_PRIV_DELETE; -TOK_PRIV_DROP; -TOK_PRIV_INDEX; -TOK_PRIV_INSERT; -TOK_PRIV_LOCK; -TOK_PRIV_SELECT; -TOK_PRIV_SHOW_DATABASE; -TOK_PRIV_CREATE; -TOK_PRIV_OBJECT; -TOK_PRIV_OBJECT_COL; -TOK_GRANT_ROLE; -TOK_REVOKE_ROLE; -TOK_SHOW_ROLE_GRANT; -TOK_SHOW_ROLES; -TOK_SHOW_SET_ROLE; -TOK_SHOW_ROLE_PRINCIPALS; -TOK_SHOWINDEXES; -TOK_SHOWDBLOCKS; -TOK_INDEXCOMMENT; -TOK_DESCDATABASE; -TOK_DATABASEPROPERTIES; -TOK_DATABASELOCATION; -TOK_DBPROPLIST; -TOK_ALTERDATABASE_PROPERTIES; -TOK_ALTERDATABASE_OWNER; -TOK_TABNAME; -TOK_TABSRC; -TOK_RESTRICT; -TOK_CASCADE; -TOK_TABLESKEWED; -TOK_TABCOLVALUE; -TOK_TABCOLVALUE_PAIR; -TOK_TABCOLVALUES; -TOK_SKEWED_LOCATIONS; -TOK_SKEWED_LOCATION_LIST; -TOK_SKEWED_LOCATION_MAP; -TOK_STOREDASDIRS; -TOK_PARTITIONINGSPEC; -TOK_PTBLFUNCTION; -TOK_WINDOWDEF; -TOK_WINDOWSPEC; -TOK_WINDOWVALUES; -TOK_WINDOWRANGE; -TOK_SUBQUERY_EXPR; -TOK_SUBQUERY_OP; -TOK_SUBQUERY_OP_NOTIN; -TOK_SUBQUERY_OP_NOTEXISTS; -TOK_DB_TYPE; -TOK_TABLE_TYPE; -TOK_CTE; -TOK_ARCHIVE; -TOK_FILE; -TOK_JAR; -TOK_RESOURCE_URI; -TOK_RESOURCE_LIST; -TOK_SHOW_COMPACTIONS; -TOK_SHOW_TRANSACTIONS; -TOK_DELETE_FROM; -TOK_UPDATE_TABLE; -TOK_SET_COLUMNS_CLAUSE; -TOK_VALUE_ROW; -TOK_VALUES_TABLE; -TOK_VIRTUAL_TABLE; -TOK_VIRTUAL_TABREF; -TOK_ANONYMOUS; -TOK_COL_NAME; -TOK_URI_TYPE; -TOK_SERVER_TYPE; -TOK_START_TRANSACTION; -TOK_ISOLATION_LEVEL; -TOK_ISOLATION_SNAPSHOT; -TOK_TXN_ACCESS_MODE; -TOK_TXN_READ_ONLY; -TOK_TXN_READ_WRITE; -TOK_COMMIT; -TOK_ROLLBACK; -TOK_SET_AUTOCOMMIT; -} - - -// Package headers -@header { -package org.apache.spark.sql.parser; - -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; -} - - -@members { - ArrayList errors = new ArrayList(); - Stack msgs = new Stack(); - - private static HashMap xlateMap; - static { - //this is used to support auto completion in CLI - xlateMap = new HashMap(); - - // Keywords - xlateMap.put("KW_TRUE", "TRUE"); - xlateMap.put("KW_FALSE", "FALSE"); - xlateMap.put("KW_ALL", "ALL"); - xlateMap.put("KW_NONE", "NONE"); - xlateMap.put("KW_AND", "AND"); - xlateMap.put("KW_OR", "OR"); - xlateMap.put("KW_NOT", "NOT"); - xlateMap.put("KW_LIKE", "LIKE"); - - xlateMap.put("KW_ASC", "ASC"); - xlateMap.put("KW_DESC", "DESC"); - xlateMap.put("KW_ORDER", "ORDER"); - xlateMap.put("KW_BY", "BY"); - xlateMap.put("KW_GROUP", "GROUP"); - xlateMap.put("KW_WHERE", "WHERE"); - xlateMap.put("KW_FROM", "FROM"); - xlateMap.put("KW_AS", "AS"); - xlateMap.put("KW_SELECT", "SELECT"); - xlateMap.put("KW_DISTINCT", "DISTINCT"); - xlateMap.put("KW_INSERT", "INSERT"); - xlateMap.put("KW_OVERWRITE", "OVERWRITE"); - xlateMap.put("KW_OUTER", "OUTER"); - xlateMap.put("KW_JOIN", "JOIN"); - xlateMap.put("KW_LEFT", "LEFT"); - xlateMap.put("KW_RIGHT", "RIGHT"); - xlateMap.put("KW_FULL", "FULL"); - xlateMap.put("KW_ON", "ON"); - xlateMap.put("KW_PARTITION", "PARTITION"); - xlateMap.put("KW_PARTITIONS", "PARTITIONS"); - xlateMap.put("KW_TABLE", "TABLE"); - xlateMap.put("KW_TABLES", "TABLES"); - xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); - xlateMap.put("KW_SHOW", "SHOW"); - xlateMap.put("KW_MSCK", "MSCK"); - xlateMap.put("KW_DIRECTORY", "DIRECTORY"); - xlateMap.put("KW_LOCAL", "LOCAL"); - xlateMap.put("KW_TRANSFORM", "TRANSFORM"); - xlateMap.put("KW_USING", "USING"); - xlateMap.put("KW_CLUSTER", "CLUSTER"); - xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); - xlateMap.put("KW_SORT", "SORT"); - xlateMap.put("KW_UNION", "UNION"); - xlateMap.put("KW_LOAD", "LOAD"); - xlateMap.put("KW_DATA", "DATA"); - xlateMap.put("KW_INPATH", "INPATH"); - xlateMap.put("KW_IS", "IS"); - xlateMap.put("KW_NULL", "NULL"); - xlateMap.put("KW_CREATE", "CREATE"); - xlateMap.put("KW_EXTERNAL", "EXTERNAL"); - xlateMap.put("KW_ALTER", "ALTER"); - xlateMap.put("KW_DESCRIBE", "DESCRIBE"); - xlateMap.put("KW_DROP", "DROP"); - xlateMap.put("KW_RENAME", "RENAME"); - xlateMap.put("KW_TO", "TO"); - xlateMap.put("KW_COMMENT", "COMMENT"); - xlateMap.put("KW_BOOLEAN", "BOOLEAN"); - xlateMap.put("KW_TINYINT", "TINYINT"); - xlateMap.put("KW_SMALLINT", "SMALLINT"); - xlateMap.put("KW_INT", "INT"); - xlateMap.put("KW_BIGINT", "BIGINT"); - xlateMap.put("KW_FLOAT", "FLOAT"); - xlateMap.put("KW_DOUBLE", "DOUBLE"); - xlateMap.put("KW_DATE", "DATE"); - xlateMap.put("KW_DATETIME", "DATETIME"); - xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); - xlateMap.put("KW_STRING", "STRING"); - xlateMap.put("KW_BINARY", "BINARY"); - xlateMap.put("KW_ARRAY", "ARRAY"); - xlateMap.put("KW_MAP", "MAP"); - xlateMap.put("KW_REDUCE", "REDUCE"); - xlateMap.put("KW_PARTITIONED", "PARTITIONED"); - xlateMap.put("KW_CLUSTERED", "CLUSTERED"); - xlateMap.put("KW_SORTED", "SORTED"); - xlateMap.put("KW_INTO", "INTO"); - xlateMap.put("KW_BUCKETS", "BUCKETS"); - xlateMap.put("KW_ROW", "ROW"); - xlateMap.put("KW_FORMAT", "FORMAT"); - xlateMap.put("KW_DELIMITED", "DELIMITED"); - xlateMap.put("KW_FIELDS", "FIELDS"); - xlateMap.put("KW_TERMINATED", "TERMINATED"); - xlateMap.put("KW_COLLECTION", "COLLECTION"); - xlateMap.put("KW_ITEMS", "ITEMS"); - xlateMap.put("KW_KEYS", "KEYS"); - xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); - xlateMap.put("KW_LINES", "LINES"); - xlateMap.put("KW_STORED", "STORED"); - xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); - xlateMap.put("KW_TEXTFILE", "TEXTFILE"); - xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); - xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); - xlateMap.put("KW_LOCATION", "LOCATION"); - xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); - xlateMap.put("KW_BUCKET", "BUCKET"); - xlateMap.put("KW_OUT", "OUT"); - xlateMap.put("KW_OF", "OF"); - xlateMap.put("KW_CAST", "CAST"); - xlateMap.put("KW_ADD", "ADD"); - xlateMap.put("KW_REPLACE", "REPLACE"); - xlateMap.put("KW_COLUMNS", "COLUMNS"); - xlateMap.put("KW_RLIKE", "RLIKE"); - xlateMap.put("KW_REGEXP", "REGEXP"); - xlateMap.put("KW_TEMPORARY", "TEMPORARY"); - xlateMap.put("KW_FUNCTION", "FUNCTION"); - xlateMap.put("KW_EXPLAIN", "EXPLAIN"); - xlateMap.put("KW_EXTENDED", "EXTENDED"); - xlateMap.put("KW_SERDE", "SERDE"); - xlateMap.put("KW_WITH", "WITH"); - xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); - xlateMap.put("KW_LIMIT", "LIMIT"); - xlateMap.put("KW_SET", "SET"); - xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); - xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); - xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); - xlateMap.put("KW_DEFINED", "DEFINED"); - xlateMap.put("KW_SUBQUERY", "SUBQUERY"); - xlateMap.put("KW_REWRITE", "REWRITE"); - xlateMap.put("KW_UPDATE", "UPDATE"); - xlateMap.put("KW_VALUES", "VALUES"); - xlateMap.put("KW_PURGE", "PURGE"); - - - // Operators - xlateMap.put("DOT", "."); - xlateMap.put("COLON", ":"); - xlateMap.put("COMMA", ","); - xlateMap.put("SEMICOLON", ");"); - - xlateMap.put("LPAREN", "("); - xlateMap.put("RPAREN", ")"); - xlateMap.put("LSQUARE", "["); - xlateMap.put("RSQUARE", "]"); - - xlateMap.put("EQUAL", "="); - xlateMap.put("NOTEQUAL", "<>"); - xlateMap.put("EQUAL_NS", "<=>"); - xlateMap.put("LESSTHANOREQUALTO", "<="); - xlateMap.put("LESSTHAN", "<"); - xlateMap.put("GREATERTHANOREQUALTO", ">="); - xlateMap.put("GREATERTHAN", ">"); - - xlateMap.put("DIVIDE", "/"); - xlateMap.put("PLUS", "+"); - xlateMap.put("MINUS", "-"); - xlateMap.put("STAR", "*"); - xlateMap.put("MOD", "\%"); - - xlateMap.put("AMPERSAND", "&"); - xlateMap.put("TILDE", "~"); - xlateMap.put("BITWISEOR", "|"); - xlateMap.put("BITWISEXOR", "^"); - xlateMap.put("CharSetLiteral", "\\'"); - } - - public static Collection getKeywords() { - return xlateMap.values(); - } - - private static String xlate(String name) { - - String ret = xlateMap.get(name); - if (ret == null) { - ret = name; - } - - return ret; - } - - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - errors.add(new ParseError(this, e, tokenNames)); - } - - @Override - public String getErrorHeader(RecognitionException e) { - String header = null; - if (e.charPositionInLine < 0 && input.LT(-1) != null) { - Token t = input.LT(-1); - header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); - } else { - header = super.getErrorHeader(e); - } - - return header; - } - - @Override - public String getErrorMessage(RecognitionException e, String[] tokenNames) { - String msg = null; - - // Translate the token names to something that the user can understand - String[] xlateNames = new String[tokenNames.length]; - for (int i = 0; i < tokenNames.length; ++i) { - xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); - } - - if (e instanceof NoViableAltException) { - @SuppressWarnings("unused") - NoViableAltException nvae = (NoViableAltException) e; - // for development, can add - // "decision=<<"+nvae.grammarDecisionDescription+">>" - // and "(decision="+nvae.decisionNumber+") and - // "state "+nvae.stateNumber - msg = "cannot recognize input near" - + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") - + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") - + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); - } else if (e instanceof MismatchedTokenException) { - MismatchedTokenException mte = (MismatchedTokenException) e; - msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; - } else if (e instanceof FailedPredicateException) { - FailedPredicateException fpe = (FailedPredicateException) e; - msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; - } else { - msg = super.getErrorMessage(e, xlateNames); - } - - if (msgs.size() > 0) { - msg = msg + " in " + msgs.peek(); - } - return msg; - } - - public void pushMsg(String msg, RecognizerSharedState state) { - // ANTLR generated code does not wrap the @init code wit this backtracking check, - // even if the matching @after has it. If we have parser rules with that are doing - // some lookahead with syntactic predicates this can cause the push() and pop() calls - // to become unbalanced, so make sure both push/pop check the backtracking state. - if (state.backtracking == 0) { - msgs.push(msg); - } - } - - public void popMsg(RecognizerSharedState state) { - if (state.backtracking == 0) { - Object o = msgs.pop(); - } - } - - // counter to generate unique union aliases - private int aliasCounter; - private String generateUnionAlias() { - return "_u" + (++aliasCounter); - } - private char [] excludedCharForColumnName = {'.', ':'}; - private boolean containExcludedCharForCreateTableColumnName(String input) { - for(char c : excludedCharForColumnName) { - if(input.indexOf(c)>-1) { - return true; - } - } - return false; - } - private CommonTree throwSetOpException() throws RecognitionException { - throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); - } - private CommonTree throwColumnNameException() throws RecognitionException { - throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); - } - private Configuration hiveConf; - public void setHiveConf(Configuration hiveConf) { - this.hiveConf = hiveConf; - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - if(hiveConf==null){ - return false; - } - return !HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS); - } -} - -@rulecatch { -catch (RecognitionException e) { - reportError(e); - throw e; -} -} - -// starting rule -statement - : explainStatement EOF - | execStatement EOF - ; - -explainStatement -@init { pushMsg("explain statement", state); } -@after { popMsg(state); } - : KW_EXPLAIN ( - explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) - | - KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) - ; - -explainOption -@init { msgs.push("explain option"); } -@after { msgs.pop(); } - : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION - ; - -execStatement -@init { pushMsg("statement", state); } -@after { popMsg(state); } - : queryStatementExpression[true] - | loadStatement - | exportStatement - | importStatement - | ddlStatement - | deleteStatement - | updateStatement - | sqlTransactionStatement - ; - -loadStatement -@init { pushMsg("load statement", state); } -@after { popMsg(state); } - : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) - -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) - ; - -replicationClause -@init { pushMsg("replication clause", state); } -@after { popMsg(state); } - : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN - -> ^(TOK_REPLICATION $replId $isMetadataOnly?) - ; - -exportStatement -@init { pushMsg("export statement", state); } -@after { popMsg(state); } - : KW_EXPORT - KW_TABLE (tab=tableOrPartition) - KW_TO (path=StringLiteral) - replicationClause? - -> ^(TOK_EXPORT $tab $path replicationClause?) - ; - -importStatement -@init { pushMsg("import statement", state); } -@after { popMsg(state); } - : KW_IMPORT - ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? - KW_FROM (path=StringLiteral) - tableLocation? - -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) - ; - -ddlStatement -@init { pushMsg("ddl statement", state); } -@after { popMsg(state); } - : createDatabaseStatement - | switchDatabaseStatement - | dropDatabaseStatement - | createTableStatement - | dropTableStatement - | truncateTableStatement - | alterStatement - | descStatement - | showStatement - | metastoreCheck - | createViewStatement - | dropViewStatement - | createFunctionStatement - | createMacroStatement - | createIndexStatement - | dropIndexStatement - | dropFunctionStatement - | reloadFunctionStatement - | dropMacroStatement - | analyzeStatement - | lockStatement - | unlockStatement - | lockDatabase - | unlockDatabase - | createRoleStatement - | dropRoleStatement - | (grantPrivileges) => grantPrivileges - | (revokePrivileges) => revokePrivileges - | showGrants - | showRoleGrants - | showRolePrincipals - | showRoles - | grantRole - | revokeRole - | setRole - | showCurrentRole - ; - -ifExists -@init { pushMsg("if exists clause", state); } -@after { popMsg(state); } - : KW_IF KW_EXISTS - -> ^(TOK_IFEXISTS) - ; - -restrictOrCascade -@init { pushMsg("restrict or cascade clause", state); } -@after { popMsg(state); } - : KW_RESTRICT - -> ^(TOK_RESTRICT) - | KW_CASCADE - -> ^(TOK_CASCADE) - ; - -ifNotExists -@init { pushMsg("if not exists clause", state); } -@after { popMsg(state); } - : KW_IF KW_NOT KW_EXISTS - -> ^(TOK_IFNOTEXISTS) - ; - -storedAsDirs -@init { pushMsg("stored as directories", state); } -@after { popMsg(state); } - : KW_STORED KW_AS KW_DIRECTORIES - -> ^(TOK_STOREDASDIRS) - ; - -orReplace -@init { pushMsg("or replace clause", state); } -@after { popMsg(state); } - : KW_OR KW_REPLACE - -> ^(TOK_ORREPLACE) - ; - -createDatabaseStatement -@init { pushMsg("create database statement", state); } -@after { popMsg(state); } - : KW_CREATE (KW_DATABASE|KW_SCHEMA) - ifNotExists? - name=identifier - databaseComment? - dbLocation? - (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? - -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) - ; - -dbLocation -@init { pushMsg("database location specification", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) - ; - -dbProperties -@init { pushMsg("dbproperties", state); } -@after { popMsg(state); } - : - LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) - ; - -dbPropertiesList -@init { pushMsg("database properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) - ; - - -switchDatabaseStatement -@init { pushMsg("switch database statement", state); } -@after { popMsg(state); } - : KW_USE identifier - -> ^(TOK_SWITCHDATABASE identifier) - ; - -dropDatabaseStatement -@init { pushMsg("drop database statement", state); } -@after { popMsg(state); } - : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? - -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) - ; - -databaseComment -@init { pushMsg("database's comment", state); } -@after { popMsg(state); } - : KW_COMMENT comment=StringLiteral - -> ^(TOK_DATABASECOMMENT $comment) - ; - -createTableStatement -@init { pushMsg("create table statement", state); } -@after { popMsg(state); } - : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName - ( like=KW_LIKE likeName=tableName - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - | (LPAREN columnNameTypeList RPAREN)? - tableComment? - tablePartition? - tableBuckets? - tableSkewed? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - (KW_AS selectStatementWithCTE)? - ) - -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? - ^(TOK_LIKETABLE $likeName?) - columnNameTypeList? - tableComment? - tablePartition? - tableBuckets? - tableSkewed? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - selectStatementWithCTE? - ) - ; - -truncateTableStatement -@init { pushMsg("truncate table statement", state); } -@after { popMsg(state); } - : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); - -createIndexStatement -@init { pushMsg("create index statement", state);} -@after {popMsg(state);} - : KW_CREATE KW_INDEX indexName=identifier - KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN - KW_AS typeName=StringLiteral - autoRebuild? - indexPropertiesPrefixed? - indexTblName? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - indexComment? - ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols - autoRebuild? - indexPropertiesPrefixed? - indexTblName? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - indexComment?) - ; - -indexComment -@init { pushMsg("comment on an index", state);} -@after {popMsg(state);} - : - KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) - ; - -autoRebuild -@init { pushMsg("auto rebuild index", state);} -@after {popMsg(state);} - : KW_WITH KW_DEFERRED KW_REBUILD - ->^(TOK_DEFERRED_REBUILDINDEX) - ; - -indexTblName -@init { pushMsg("index table name", state);} -@after {popMsg(state);} - : KW_IN KW_TABLE indexTbl=tableName - ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) - ; - -indexPropertiesPrefixed -@init { pushMsg("table properties with prefix", state); } -@after { popMsg(state); } - : - KW_IDXPROPERTIES! indexProperties - ; - -indexProperties -@init { pushMsg("index properties", state); } -@after { popMsg(state); } - : - LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) - ; - -indexPropertiesList -@init { pushMsg("index properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) - ; - -dropIndexStatement -@init { pushMsg("drop index statement", state);} -@after {popMsg(state);} - : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName - ->^(TOK_DROPINDEX $indexName $tab ifExists?) - ; - -dropTableStatement -@init { pushMsg("drop statement", state); } -@after { popMsg(state); } - : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? - -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) - ; - -alterStatement -@init { pushMsg("alter statement", state); } -@after { popMsg(state); } - : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) - | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) - | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix - | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix - ; - -alterTableStatementSuffix -@init { pushMsg("alter table statement", state); } -@after { popMsg(state); } - : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] - | alterStatementSuffixDropPartitions[true] - | alterStatementSuffixAddPartitions[true] - | alterStatementSuffixTouch - | alterStatementSuffixArchive - | alterStatementSuffixUnArchive - | alterStatementSuffixProperties - | alterStatementSuffixSkewedby - | alterStatementSuffixExchangePartition - | alterStatementPartitionKeyType - | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? - ; - -alterTblPartitionStatementSuffix -@init {pushMsg("alter table partition statement suffix", state);} -@after {popMsg(state);} - : alterStatementSuffixFileFormat - | alterStatementSuffixLocation - | alterStatementSuffixMergeFiles - | alterStatementSuffixSerdeProperties - | alterStatementSuffixRenamePart - | alterStatementSuffixBucketNum - | alterTblPartitionStatementSuffixSkewedLocation - | alterStatementSuffixClusterbySortby - | alterStatementSuffixCompact - | alterStatementSuffixUpdateStatsCol - | alterStatementSuffixRenameCol - | alterStatementSuffixAddCol - ; - -alterStatementPartitionKeyType -@init {msgs.push("alter partition key type"); } -@after {msgs.pop();} - : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN - -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) - ; - -alterViewStatementSuffix -@init { pushMsg("alter view statement", state); } -@after { popMsg(state); } - : alterViewSuffixProperties - | alterStatementSuffixRename[false] - | alterStatementSuffixAddPartitions[false] - | alterStatementSuffixDropPartitions[false] - | selectStatementWithCTE - ; - -alterIndexStatementSuffix -@init { pushMsg("alter index statement", state); } -@after { popMsg(state); } - : indexName=identifier KW_ON tableName partitionSpec? - ( - KW_REBUILD - ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) - | - KW_SET KW_IDXPROPERTIES - indexProperties - ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) - ) - ; - -alterDatabaseStatementSuffix -@init { pushMsg("alter database statement", state); } -@after { popMsg(state); } - : alterDatabaseSuffixProperties - | alterDatabaseSuffixSetOwner - ; - -alterDatabaseSuffixProperties -@init { pushMsg("alter database properties statement", state); } -@after { popMsg(state); } - : name=identifier KW_SET KW_DBPROPERTIES dbProperties - -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) - ; - -alterDatabaseSuffixSetOwner -@init { pushMsg("alter database set owner", state); } -@after { popMsg(state); } - : dbName=identifier KW_SET KW_OWNER principalName - -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) - ; - -alterStatementSuffixRename[boolean table] -@init { pushMsg("rename statement", state); } -@after { popMsg(state); } - : KW_RENAME KW_TO tableName - -> { table }? ^(TOK_ALTERTABLE_RENAME tableName) - -> ^(TOK_ALTERVIEW_RENAME tableName) - ; - -alterStatementSuffixAddCol -@init { pushMsg("add column statement", state); } -@after { popMsg(state); } - : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? - -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) - -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) - ; - -alterStatementSuffixRenameCol -@init { pushMsg("rename column name", state); } -@after { popMsg(state); } - : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? - ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) - ; - -alterStatementSuffixUpdateStatsCol -@init { pushMsg("update column statistics", state); } -@after { popMsg(state); } - : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? - ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) - ; - -alterStatementChangeColPosition - : first=KW_FIRST|KW_AFTER afterCol=identifier - ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) - -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) - ; - -alterStatementSuffixAddPartitions[boolean table] -@init { pushMsg("add partition statement", state); } -@after { popMsg(state); } - : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ - -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) - -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) - ; - -alterStatementSuffixAddPartitionsElement - : partitionSpec partitionLocation? - ; - -alterStatementSuffixTouch -@init { pushMsg("touch statement", state); } -@after { popMsg(state); } - : KW_TOUCH (partitionSpec)* - -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) - ; - -alterStatementSuffixArchive -@init { pushMsg("archive statement", state); } -@after { popMsg(state); } - : KW_ARCHIVE (partitionSpec)* - -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) - ; - -alterStatementSuffixUnArchive -@init { pushMsg("unarchive statement", state); } -@after { popMsg(state); } - : KW_UNARCHIVE (partitionSpec)* - -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) - ; - -partitionLocation -@init { pushMsg("partition location", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) - ; - -alterStatementSuffixDropPartitions[boolean table] -@init { pushMsg("drop partition statement", state); } -@after { popMsg(state); } - : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? - -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) - -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) - ; - -alterStatementSuffixProperties -@init { pushMsg("alter properties statement", state); } -@after { popMsg(state); } - : KW_SET KW_TBLPROPERTIES tableProperties - -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) - | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties - -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) - ; - -alterViewSuffixProperties -@init { pushMsg("alter view properties statement", state); } -@after { popMsg(state); } - : KW_SET KW_TBLPROPERTIES tableProperties - -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) - | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties - -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) - ; - -alterStatementSuffixSerdeProperties -@init { pushMsg("alter serdes statement", state); } -@after { popMsg(state); } - : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? - -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) - | KW_SET KW_SERDEPROPERTIES tableProperties - -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) - ; - -tablePartitionPrefix -@init {pushMsg("table partition prefix", state);} -@after {popMsg(state);} - : tableName partitionSpec? - ->^(TOK_TABLE_PARTITION tableName partitionSpec?) - ; - -alterStatementSuffixFileFormat -@init {pushMsg("alter fileformat statement", state); } -@after {popMsg(state);} - : KW_SET KW_FILEFORMAT fileFormat - -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) - ; - -alterStatementSuffixClusterbySortby -@init {pushMsg("alter partition cluster by sort by statement", state);} -@after {popMsg(state);} - : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) - | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) - | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) - ; - -alterTblPartitionStatementSuffixSkewedLocation -@init {pushMsg("alter partition skewed location", state);} -@after {popMsg(state);} - : KW_SET KW_SKEWED KW_LOCATION skewedLocations - -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) - ; - -skewedLocations -@init { pushMsg("skewed locations", state); } -@after { popMsg(state); } - : - LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) - ; - -skewedLocationsList -@init { pushMsg("skewed locations list", state); } -@after { popMsg(state); } - : - skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) - ; - -skewedLocationMap -@init { pushMsg("specifying skewed location map", state); } -@after { popMsg(state); } - : - key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) - ; - -alterStatementSuffixLocation -@init {pushMsg("alter location", state);} -@after {popMsg(state);} - : KW_SET KW_LOCATION newLoc=StringLiteral - -> ^(TOK_ALTERTABLE_LOCATION $newLoc) - ; - - -alterStatementSuffixSkewedby -@init {pushMsg("alter skewed by statement", state);} -@after{popMsg(state);} - : tableSkewed - ->^(TOK_ALTERTABLE_SKEWED tableSkewed) - | - KW_NOT KW_SKEWED - ->^(TOK_ALTERTABLE_SKEWED) - | - KW_NOT storedAsDirs - ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) - ; - -alterStatementSuffixExchangePartition -@init {pushMsg("alter exchange partition", state);} -@after{popMsg(state);} - : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName - -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) - ; - -alterStatementSuffixRenamePart -@init { pushMsg("alter table rename partition statement", state); } -@after { popMsg(state); } - : KW_RENAME KW_TO partitionSpec - ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) - ; - -alterStatementSuffixStatsPart -@init { pushMsg("alter table stats partition statement", state); } -@after { popMsg(state); } - : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? - ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) - ; - -alterStatementSuffixMergeFiles -@init { pushMsg("", state); } -@after { popMsg(state); } - : KW_CONCATENATE - -> ^(TOK_ALTERTABLE_MERGEFILES) - ; - -alterStatementSuffixBucketNum -@init { pushMsg("", state); } -@after { popMsg(state); } - : KW_INTO num=Number KW_BUCKETS - -> ^(TOK_ALTERTABLE_BUCKETS $num) - ; - -alterStatementSuffixCompact -@init { msgs.push("compaction request"); } -@after { msgs.pop(); } - : KW_COMPACT compactType=StringLiteral - -> ^(TOK_ALTERTABLE_COMPACT $compactType) - ; - - -fileFormat -@init { pushMsg("file format specification", state); } -@after { popMsg(state); } - : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? - -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?) - | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) - ; - -tabTypeExpr -@init { pushMsg("specifying table types", state); } -@after { popMsg(state); } - : identifier (DOT^ identifier)? - (identifier (DOT^ - ( - (KW_ELEM_TYPE) => KW_ELEM_TYPE - | - (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE - | identifier - ))* - )? - ; - -partTypeExpr -@init { pushMsg("specifying table partitions", state); } -@after { popMsg(state); } - : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) - ; - -tabPartColTypeExpr -@init { pushMsg("specifying table partitions columnName", state); } -@after { popMsg(state); } - : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) - ; - -descStatement -@init { pushMsg("describe statement", state); } -@after { popMsg(state); } - : - (KW_DESCRIBE|KW_DESC) - ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) - | - (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) - | - (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) - | - parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) - ) - ; - -analyzeStatement -@init { pushMsg("analyze statement", state); } -@after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) - | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? - -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) - ; - -showStatement -@init { pushMsg("show statement", state); } -@after { popMsg(state); } - : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) - | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) - | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? - -> ^(TOK_SHOWCOLUMNS tableName $db_name?) - | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) - | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) - | KW_SHOW KW_CREATE ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) - | - KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) - ) - | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? - -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) - | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS - ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) - | - (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) - ) - | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? - -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) - | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) - | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) - | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) - ; - -lockStatement -@init { pushMsg("lock statement", state); } -@after { popMsg(state); } - : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) - ; - -lockDatabase -@init { pushMsg("lock database statement", state); } -@after { popMsg(state); } - : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) - ; - -lockMode -@init { pushMsg("lock mode", state); } -@after { popMsg(state); } - : KW_SHARED | KW_EXCLUSIVE - ; - -unlockStatement -@init { pushMsg("unlock statement", state); } -@after { popMsg(state); } - : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) - ; - -unlockDatabase -@init { pushMsg("unlock database statement", state); } -@after { popMsg(state); } - : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) - ; - -createRoleStatement -@init { pushMsg("create role", state); } -@after { popMsg(state); } - : KW_CREATE KW_ROLE roleName=identifier - -> ^(TOK_CREATEROLE $roleName) - ; - -dropRoleStatement -@init {pushMsg("drop role", state);} -@after {popMsg(state);} - : KW_DROP KW_ROLE roleName=identifier - -> ^(TOK_DROPROLE $roleName) - ; - -grantPrivileges -@init {pushMsg("grant privileges", state);} -@after {popMsg(state);} - : KW_GRANT privList=privilegeList - privilegeObject? - KW_TO principalSpecification - withGrantOption? - -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) - ; - -revokePrivileges -@init {pushMsg("revoke privileges", state);} -@afer {popMsg(state);} - : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification - -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) - ; - -grantRole -@init {pushMsg("grant role", state);} -@after {popMsg(state);} - : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? - -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) - ; - -revokeRole -@init {pushMsg("revoke role", state);} -@after {popMsg(state);} - : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification - -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) - ; - -showRoleGrants -@init {pushMsg("show role grants", state);} -@after {popMsg(state);} - : KW_SHOW KW_ROLE KW_GRANT principalName - -> ^(TOK_SHOW_ROLE_GRANT principalName) - ; - - -showRoles -@init {pushMsg("show roles", state);} -@after {popMsg(state);} - : KW_SHOW KW_ROLES - -> ^(TOK_SHOW_ROLES) - ; - -showCurrentRole -@init {pushMsg("show current role", state);} -@after {popMsg(state);} - : KW_SHOW KW_CURRENT KW_ROLES - -> ^(TOK_SHOW_SET_ROLE) - ; - -setRole -@init {pushMsg("set role", state);} -@after {popMsg(state);} - : KW_SET KW_ROLE - ( - (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) - | - (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) - | - identifier -> ^(TOK_SHOW_SET_ROLE identifier) - ) - ; - -showGrants -@init {pushMsg("show grants", state);} -@after {popMsg(state);} - : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? - -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) - ; - -showRolePrincipals -@init {pushMsg("show role principals", state);} -@after {popMsg(state);} - : KW_SHOW KW_PRINCIPALS roleName=identifier - -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) - ; - - -privilegeIncludeColObject -@init {pushMsg("privilege object including columns", state);} -@after {popMsg(state);} - : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) - | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) - ; - -privilegeObject -@init {pushMsg("privilege object", state);} -@after {popMsg(state);} - : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) - ; - -// database or table type. Type is optional, default type is table -privObject - : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) - | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) - | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) - | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) - ; - -privObjectCols - : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) - | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) - | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) - | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) - ; - -privilegeList -@init {pushMsg("grant privilege list", state);} -@after {popMsg(state);} - : privlegeDef (COMMA privlegeDef)* - -> ^(TOK_PRIVILEGE_LIST privlegeDef+) - ; - -privlegeDef -@init {pushMsg("grant privilege", state);} -@after {popMsg(state);} - : privilegeType (LPAREN cols=columnNameList RPAREN)? - -> ^(TOK_PRIVILEGE privilegeType $cols?) - ; - -privilegeType -@init {pushMsg("privilege type", state);} -@after {popMsg(state);} - : KW_ALL -> ^(TOK_PRIV_ALL) - | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) - | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) - | KW_CREATE -> ^(TOK_PRIV_CREATE) - | KW_DROP -> ^(TOK_PRIV_DROP) - | KW_INDEX -> ^(TOK_PRIV_INDEX) - | KW_LOCK -> ^(TOK_PRIV_LOCK) - | KW_SELECT -> ^(TOK_PRIV_SELECT) - | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) - | KW_INSERT -> ^(TOK_PRIV_INSERT) - | KW_DELETE -> ^(TOK_PRIV_DELETE) - ; - -principalSpecification -@init { pushMsg("user/group/role name list", state); } -@after { popMsg(state); } - : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) - ; - -principalName -@init {pushMsg("user|group|role name", state);} -@after {popMsg(state);} - : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) - | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) - | KW_ROLE identifier -> ^(TOK_ROLE identifier) - ; - -withGrantOption -@init {pushMsg("with grant option", state);} -@after {popMsg(state);} - : KW_WITH KW_GRANT KW_OPTION - -> ^(TOK_GRANT_WITH_OPTION) - ; - -grantOptionFor -@init {pushMsg("grant option for", state);} -@after {popMsg(state);} - : KW_GRANT KW_OPTION KW_FOR - -> ^(TOK_GRANT_OPTION_FOR) -; - -adminOptionFor -@init {pushMsg("admin option for", state);} -@after {popMsg(state);} - : KW_ADMIN KW_OPTION KW_FOR - -> ^(TOK_ADMIN_OPTION_FOR) -; - -withAdminOption -@init {pushMsg("with admin option", state);} -@after {popMsg(state);} - : KW_WITH KW_ADMIN KW_OPTION - -> ^(TOK_GRANT_WITH_ADMIN_OPTION) - ; - -metastoreCheck -@init { pushMsg("metastore check statement", state); } -@after { popMsg(state); } - : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? - -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) - ; - -resourceList -@init { pushMsg("resource list", state); } -@after { popMsg(state); } - : - resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) - ; - -resource -@init { pushMsg("resource", state); } -@after { popMsg(state); } - : - resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) - ; - -resourceType -@init { pushMsg("resource type", state); } -@after { popMsg(state); } - : - KW_JAR -> ^(TOK_JAR) - | - KW_FILE -> ^(TOK_FILE) - | - KW_ARCHIVE -> ^(TOK_ARCHIVE) - ; - -createFunctionStatement -@init { pushMsg("create function statement", state); } -@after { popMsg(state); } - : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral - (KW_USING rList=resourceList)? - -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) - -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) - ; - -dropFunctionStatement -@init { pushMsg("drop function statement", state); } -@after { popMsg(state); } - : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier - -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) - -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) - ; - -reloadFunctionStatement -@init { pushMsg("reload function statement", state); } -@after { popMsg(state); } - : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); - -createMacroStatement -@init { pushMsg("create macro statement", state); } -@after { popMsg(state); } - : KW_CREATE KW_TEMPORARY KW_MACRO Identifier - LPAREN columnNameTypeList? RPAREN expression - -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression) - ; - -dropMacroStatement -@init { pushMsg("drop macro statement", state); } -@after { popMsg(state); } - : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier - -> ^(TOK_DROPMACRO Identifier ifExists?) - ; - -createViewStatement -@init { - pushMsg("create view statement", state); -} -@after { popMsg(state); } - : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName - (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? - tablePropertiesPrefixed? - KW_AS - selectStatementWithCTE - -> ^(TOK_CREATEVIEW $name orReplace? - ifNotExists? - columnNameCommentList? - tableComment? - viewPartition? - tablePropertiesPrefixed? - selectStatementWithCTE - ) - ; - -viewPartition -@init { pushMsg("view partition specification", state); } -@after { popMsg(state); } - : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN - -> ^(TOK_VIEWPARTCOLS columnNameList) - ; - -dropViewStatement -@init { pushMsg("drop view statement", state); } -@after { popMsg(state); } - : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) - ; - -showFunctionIdentifier -@init { pushMsg("identifier for show function statement", state); } -@after { popMsg(state); } - : functionIdentifier - | StringLiteral - ; - -showStmtIdentifier -@init { pushMsg("identifier for show statement", state); } -@after { popMsg(state); } - : identifier - | StringLiteral - ; - -tableComment -@init { pushMsg("table's comment", state); } -@after { popMsg(state); } - : - KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) - ; - -tablePartition -@init { pushMsg("table partition specification", state); } -@after { popMsg(state); } - : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN - -> ^(TOK_TABLEPARTCOLS columnNameTypeList) - ; - -tableBuckets -@init { pushMsg("table buckets specification", state); } -@after { popMsg(state); } - : - KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS - -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) - ; - -tableSkewed -@init { pushMsg("table skewed specification", state); } -@after { popMsg(state); } - : - KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? - -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) - ; - -rowFormat -@init { pushMsg("serde specification", state); } -@after { popMsg(state); } - : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) - | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) - | -> ^(TOK_SERDE) - ; - -recordReader -@init { pushMsg("record reader specification", state); } -@after { popMsg(state); } - : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) - | -> ^(TOK_RECORDREADER) - ; - -recordWriter -@init { pushMsg("record writer specification", state); } -@after { popMsg(state); } - : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) - | -> ^(TOK_RECORDWRITER) - ; - -rowFormatSerde -@init { pushMsg("serde format specification", state); } -@after { popMsg(state); } - : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? - -> ^(TOK_SERDENAME $name $serdeprops?) - ; - -rowFormatDelimited -@init { pushMsg("serde properties specification", state); } -@after { popMsg(state); } - : - KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? - -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) - ; - -tableRowFormat -@init { pushMsg("table row format specification", state); } -@after { popMsg(state); } - : - rowFormatDelimited - -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) - | rowFormatSerde - -> ^(TOK_TABLESERIALIZER rowFormatSerde) - ; - -tablePropertiesPrefixed -@init { pushMsg("table properties with prefix", state); } -@after { popMsg(state); } - : - KW_TBLPROPERTIES! tableProperties - ; - -tableProperties -@init { pushMsg("table properties", state); } -@after { popMsg(state); } - : - LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) - ; - -tablePropertiesList -@init { pushMsg("table properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) - | - keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) - ; - -keyValueProperty -@init { pushMsg("specifying key/value property", state); } -@after { popMsg(state); } - : - key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) - ; - -keyProperty -@init { pushMsg("specifying key property", state); } -@after { popMsg(state); } - : - key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) - ; - -tableRowFormatFieldIdentifier -@init { pushMsg("table row format's field separator", state); } -@after { popMsg(state); } - : - KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? - -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) - ; - -tableRowFormatCollItemsIdentifier -@init { pushMsg("table row format's column separator", state); } -@after { popMsg(state); } - : - KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) - ; - -tableRowFormatMapKeysIdentifier -@init { pushMsg("table row format's map key separator", state); } -@after { popMsg(state); } - : - KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) - ; - -tableRowFormatLinesIdentifier -@init { pushMsg("table row format's line separator", state); } -@after { popMsg(state); } - : - KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) - ; - -tableRowNullFormat -@init { pushMsg("table row format's null specifier", state); } -@after { popMsg(state); } - : - KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) - ; -tableFileFormat -@init { pushMsg("table file format specification", state); } -@after { popMsg(state); } - : - (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? - -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) - | KW_STORED KW_BY storageHandler=StringLiteral - (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? - -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) - | KW_STORED KW_AS genericSpec=identifier - -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) - ; - -tableLocation -@init { pushMsg("table location specification", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) - ; - -columnNameTypeList -@init { pushMsg("column name type list", state); } -@after { popMsg(state); } - : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) - ; - -columnNameColonTypeList -@init { pushMsg("column name type list", state); } -@after { popMsg(state); } - : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) - ; - -columnNameList -@init { pushMsg("column name list", state); } -@after { popMsg(state); } - : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) - ; - -columnName -@init { pushMsg("column name", state); } -@after { popMsg(state); } - : - identifier - ; - -extColumnName -@init { pushMsg("column name for complex types", state); } -@after { popMsg(state); } - : - identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* - ; - -columnNameOrderList -@init { pushMsg("column name order list", state); } -@after { popMsg(state); } - : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) - ; - -skewedValueElement -@init { pushMsg("skewed value element", state); } -@after { popMsg(state); } - : - skewedColumnValues - | skewedColumnValuePairList - ; - -skewedColumnValuePairList -@init { pushMsg("column value pair list", state); } -@after { popMsg(state); } - : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) - ; - -skewedColumnValuePair -@init { pushMsg("column value pair", state); } -@after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN - -> ^(TOK_TABCOLVALUES $colValues) - ; - -skewedColumnValues -@init { pushMsg("column values", state); } -@after { popMsg(state); } - : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) - ; - -skewedColumnValue -@init { pushMsg("column value", state); } -@after { popMsg(state); } - : - constant - ; - -skewedValueLocationElement -@init { pushMsg("skewed value location element", state); } -@after { popMsg(state); } - : - skewedColumnValue - | skewedColumnValuePair - ; - -columnNameOrder -@init { pushMsg("column name order", state); } -@after { popMsg(state); } - : identifier (asc=KW_ASC | desc=KW_DESC)? - -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) - -> ^(TOK_TABSORTCOLNAMEDESC identifier) - ; - -columnNameCommentList -@init { pushMsg("column name comment list", state); } -@after { popMsg(state); } - : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) - ; - -columnNameComment -@init { pushMsg("column name comment", state); } -@after { popMsg(state); } - : colName=identifier (KW_COMMENT comment=StringLiteral)? - -> ^(TOK_TABCOL $colName TOK_NULL $comment?) - ; - -columnRefOrder -@init { pushMsg("column order", state); } -@after { popMsg(state); } - : expression (asc=KW_ASC | desc=KW_DESC)? - -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) - -> ^(TOK_TABSORTCOLNAMEDESC expression) - ; - -columnNameType -@init { pushMsg("column specification", state); } -@after { popMsg(state); } - : colName=identifier colType (KW_COMMENT comment=StringLiteral)? - -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()} - -> {$comment == null}? ^(TOK_TABCOL $colName colType) - -> ^(TOK_TABCOL $colName colType $comment) - ; - -columnNameColonType -@init { pushMsg("column specification", state); } -@after { popMsg(state); } - : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)? - -> {$comment == null}? ^(TOK_TABCOL $colName colType) - -> ^(TOK_TABCOL $colName colType $comment) - ; - -colType -@init { pushMsg("column type", state); } -@after { popMsg(state); } - : type - ; - -colTypeList -@init { pushMsg("column type list", state); } -@after { popMsg(state); } - : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+) - ; - -type - : primitiveType - | listType - | structType - | mapType - | unionType; - -primitiveType -@init { pushMsg("primitive type specification", state); } -@after { popMsg(state); } - : KW_TINYINT -> TOK_TINYINT - | KW_SMALLINT -> TOK_SMALLINT - | KW_INT -> TOK_INT - | KW_BIGINT -> TOK_BIGINT - | KW_BOOLEAN -> TOK_BOOLEAN - | KW_FLOAT -> TOK_FLOAT - | KW_DOUBLE -> TOK_DOUBLE - | KW_DATE -> TOK_DATE - | KW_DATETIME -> TOK_DATETIME - | KW_TIMESTAMP -> TOK_TIMESTAMP - // Uncomment to allow intervals as table column types - //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH - //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME - | KW_STRING -> TOK_STRING - | KW_BINARY -> TOK_BINARY - | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?) - | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length) - | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length) - ; - -listType -@init { pushMsg("list type", state); } -@after { popMsg(state); } - : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type) - ; - -structType -@init { pushMsg("struct type", state); } -@after { popMsg(state); } - : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList) - ; - -mapType -@init { pushMsg("map type", state); } -@after { popMsg(state); } - : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN - -> ^(TOK_MAP $left $right) - ; - -unionType -@init { pushMsg("uniontype type", state); } -@after { popMsg(state); } - : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) - ; - -setOperator -@init { pushMsg("set operator", state); } -@after { popMsg(state); } - : KW_UNION KW_ALL -> ^(TOK_UNIONALL) - | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) - ; - -queryStatementExpression[boolean topLevel] - : - /* Would be nice to do this as a gated semantic perdicate - But the predicate gets pushed as a lookahead decision. - Calling rule doesnot know about topLevel - */ - (w=withClause {topLevel}?)? - queryStatementExpressionBody[topLevel] { - if ($w.tree != null) { - $queryStatementExpressionBody.tree.insertChild(0, $w.tree); - } - } - -> queryStatementExpressionBody - ; - -queryStatementExpressionBody[boolean topLevel] - : - fromStatement[topLevel] - | regularBody[topLevel] - ; - -withClause - : - KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+) -; - -cteStatement - : - identifier KW_AS LPAREN queryStatementExpression[false] RPAREN - -> ^(TOK_SUBQUERY queryStatementExpression identifier) -; - -fromStatement[boolean topLevel] -: (singleFromStatement -> singleFromStatement) - (u=setOperator r=singleFromStatement - -> ^($u {$fromStatement.tree} $r) - )* - -> {u != null && topLevel}? ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - {$fromStatement.tree} - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) - -> {$fromStatement.tree} - ; - - -singleFromStatement - : - fromClause - ( b+=body )+ -> ^(TOK_QUERY fromClause body+) - ; - -/* -The valuesClause rule below ensures that the parse tree for -"insert into table FOO values (1,2),(3,4)" looks the same as -"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look -very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name -is implicit, it's represented as TOK_ANONYMOUS. -*/ -regularBody[boolean topLevel] - : - i=insertClause - ( - s=selectStatement[topLevel] - {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree} - | - valuesClause - -> ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause) - ) - ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))) - ) - ) - | - selectStatement[topLevel] - ; - -selectStatement[boolean topLevel] - : - ( - s=selectClause - f=fromClause? - w=whereClause? - g=groupByClause? - h=havingClause? - o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - $s $w? $g? $h? $o? $c? - $d? $sort? $win? $l?)) - ) - (set=setOpSelectStatement[$selectStatement.tree, topLevel])? - -> {set == null}? - {$selectStatement.tree} - -> {o==null && c==null && d==null && sort==null && l==null}? - {$set.tree} - -> {throwSetOpException()} - ; - -setOpSelectStatement[CommonTree t, boolean topLevel] - : - (u=setOperator b=simpleSelectStatement - -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? - ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) - -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}? - ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) - -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? - ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - ^(TOK_UNIONALL {$t} $b) - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) - -> ^(TOK_UNIONALL {$t} $b) - )+ - o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}? - {$setOpSelectStatement.tree} - -> ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - {$setOpSelectStatement.tree} - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) - $o? $c? $d? $sort? $win? $l? - ) - ) - ; - -simpleSelectStatement - : - selectClause - fromClause? - whereClause? - groupByClause? - havingClause? - ((window_clause) => window_clause)? - -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - selectClause whereClause? groupByClause? havingClause? window_clause?)) - ; - -selectStatementWithCTE - : - (w=withClause)? - selectStatement[true] { - if ($w.tree != null) { - $selectStatement.tree.insertChild(0, $w.tree); - } - } - -> selectStatement - ; - -body - : - insertClause - selectClause - lateralView? - whereClause? - groupByClause? - havingClause? - orderByClause? - clusterByClause? - distributeByClause? - sortByClause? - window_clause? - limitClause? -> ^(TOK_INSERT insertClause - selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? - distributeByClause? sortByClause? window_clause? limitClause?) - | - selectClause - lateralView? - whereClause? - groupByClause? - havingClause? - orderByClause? - clusterByClause? - distributeByClause? - sortByClause? - window_clause? - limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? - distributeByClause? sortByClause? window_clause? limitClause?) - ; - -insertClause -@init { pushMsg("insert clause", state); } -@after { popMsg(state); } - : - KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?) - | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)? - -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?) - ; - -destination -@init { pushMsg("destination specification", state); } -@after { popMsg(state); } - : - (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat? - -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?) - | KW_TABLE tableOrPartition -> tableOrPartition - ; - -limitClause -@init { pushMsg("limit clause", state); } -@after { popMsg(state); } - : - KW_LIMIT num=Number -> ^(TOK_LIMIT $num) - ; - -//DELETE FROM WHERE ...; -deleteStatement -@init { pushMsg("delete statement", state); } -@after { popMsg(state); } - : - KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?) - ; - -/*SET = (3 + col2)*/ -columnAssignmentClause - : - tableOrColumn EQUAL^ precedencePlusExpression - ; - -/*SET col1 = 5, col2 = (4 + col4), ...*/ -setColumnsClause - : - KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) - ; - -/* - UPDATE
    - SET col1 = val1, col2 = val2... WHERE ... -*/ -updateStatement -@init { pushMsg("update statement", state); } -@after { popMsg(state); } - : - KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) - ; - -/* -BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of -"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. -*/ -sqlTransactionStatement -@init { pushMsg("transaction statement", state); } -@after { popMsg(state); } - : - startTransactionStatement - | commitStatement - | rollbackStatement - | setAutoCommitStatement - ; - -startTransactionStatement - : - KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) - ; - -transactionMode - : - isolationLevel - | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) - ; - -transactionAccessMode - : - KW_READ KW_ONLY -> TOK_TXN_READ_ONLY - | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE - ; - -isolationLevel - : - KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) - ; - -/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ -levelOfIsolation - : - KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT - ; - -commitStatement - : - KW_COMMIT ( KW_WORK )? -> TOK_COMMIT - ; - -rollbackStatement - : - KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK - ; -setAutoCommitStatement - : - KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) - ; -/* -END user defined transaction boundaries -*/ diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java deleted file mode 100644 index 35ecdc5ad10a9..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.antlr.runtime.RecognitionException; -import org.antlr.runtime.Token; -import org.antlr.runtime.TokenStream; -import org.antlr.runtime.tree.CommonErrorNode; - -public class ASTErrorNode extends ASTNode { - - /** - * - */ - private static final long serialVersionUID = 1L; - CommonErrorNode delegate; - - public ASTErrorNode(TokenStream input, Token start, Token stop, - RecognitionException e){ - delegate = new CommonErrorNode(input,start,stop,e); - } - - @Override - public boolean isNil() { return delegate.isNil(); } - - @Override - public int getType() { return delegate.getType(); } - - @Override - public String getText() { return delegate.getText(); } - @Override - public String toString() { return delegate.toString(); } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java deleted file mode 100644 index 33d9322b628ec..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java +++ /dev/null @@ -1,245 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -import org.antlr.runtime.Token; -import org.antlr.runtime.tree.CommonTree; -import org.antlr.runtime.tree.Tree; -import org.apache.hadoop.hive.ql.lib.Node; - -public class ASTNode extends CommonTree implements Node, Serializable { - private static final long serialVersionUID = 1L; - private transient StringBuffer astStr; - private transient int startIndx = -1; - private transient int endIndx = -1; - private transient ASTNode rootNode; - private transient boolean isValidASTStr; - - public ASTNode() { - } - - /** - * Constructor. - * - * @param t - * Token for the CommonTree Node - */ - public ASTNode(Token t) { - super(t); - } - - public ASTNode(ASTNode node) { - super(node); - } - - @Override - public Tree dupNode() { - return new ASTNode(this); - } - - /* - * (non-Javadoc) - * - * @see org.apache.hadoop.hive.ql.lib.Node#getChildren() - */ - @Override - public ArrayList getChildren() { - if (super.getChildCount() == 0) { - return null; - } - - ArrayList ret_vec = new ArrayList(); - for (int i = 0; i < super.getChildCount(); ++i) { - ret_vec.add((Node) super.getChild(i)); - } - - return ret_vec; - } - - /* - * (non-Javadoc) - * - * @see org.apache.hadoop.hive.ql.lib.Node#getName() - */ - @Override - public String getName() { - return (Integer.valueOf(super.getToken().getType())).toString(); - } - - public String dump() { - StringBuilder sb = new StringBuilder("\n"); - dump(sb, ""); - return sb.toString(); - } - - private StringBuilder dump(StringBuilder sb, String ws) { - sb.append(ws); - sb.append(toString()); - sb.append("\n"); - - ArrayList children = getChildren(); - if (children != null) { - for (Node node : getChildren()) { - if (node instanceof ASTNode) { - ((ASTNode) node).dump(sb, ws + " "); - } else { - sb.append(ws); - sb.append(" NON-ASTNODE!!"); - sb.append("\n"); - } - } - } - return sb; - } - - private ASTNode getRootNodeWithValidASTStr(boolean useMemoizedRoot) { - if (useMemoizedRoot && rootNode != null && rootNode.parent == null && - rootNode.hasValidMemoizedString()) { - return rootNode; - } - ASTNode retNode = this; - while (retNode.parent != null) { - retNode = (ASTNode) retNode.parent; - } - rootNode=retNode; - if (!rootNode.isValidASTStr) { - rootNode.astStr = new StringBuffer(); - rootNode.toStringTree(rootNode); - rootNode.isValidASTStr = true; - } - return retNode; - } - - private boolean hasValidMemoizedString() { - return isValidASTStr && astStr != null; - } - - private void resetRootInformation() { - // Reset the previously stored rootNode string - if (rootNode != null) { - rootNode.astStr = null; - rootNode.isValidASTStr = false; - } - } - - private int getMemoizedStringLen() { - return astStr == null ? 0 : astStr.length(); - } - - private String getMemoizedSubString(int start, int end) { - return (astStr == null || start < 0 || end > astStr.length() || start >= end) ? null : - astStr.subSequence(start, end).toString(); - } - - private void addtoMemoizedString(String string) { - if (astStr == null) { - astStr = new StringBuffer(); - } - astStr.append(string); - } - - @Override - public void setParent(Tree t) { - super.setParent(t); - resetRootInformation(); - } - - @Override - public void addChild(Tree t) { - super.addChild(t); - resetRootInformation(); - } - - @Override - public void addChildren(List kids) { - super.addChildren(kids); - resetRootInformation(); - } - - @Override - public void setChild(int i, Tree t) { - super.setChild(i, t); - resetRootInformation(); - } - - @Override - public void insertChild(int i, Object t) { - super.insertChild(i, t); - resetRootInformation(); - } - - @Override - public Object deleteChild(int i) { - Object ret = super.deleteChild(i); - resetRootInformation(); - return ret; - } - - @Override - public void replaceChildren(int startChildIndex, int stopChildIndex, Object t) { - super.replaceChildren(startChildIndex, stopChildIndex, t); - resetRootInformation(); - } - - @Override - public String toStringTree() { - - // The root might have changed because of tree modifications. - // Compute the new root for this tree and set the astStr. - getRootNodeWithValidASTStr(true); - - // If rootNotModified is false, then startIndx and endIndx will be stale. - if (startIndx >= 0 && endIndx <= rootNode.getMemoizedStringLen()) { - return rootNode.getMemoizedSubString(startIndx, endIndx); - } - return toStringTree(rootNode); - } - - private String toStringTree(ASTNode rootNode) { - this.rootNode = rootNode; - startIndx = rootNode.getMemoizedStringLen(); - // Leaf node - if ( children==null || children.size()==0 ) { - rootNode.addtoMemoizedString(this.toString()); - endIndx = rootNode.getMemoizedStringLen(); - return this.toString(); - } - if ( !isNil() ) { - rootNode.addtoMemoizedString("("); - rootNode.addtoMemoizedString(this.toString()); - rootNode.addtoMemoizedString(" "); - } - for (int i = 0; children!=null && i < children.size(); i++) { - ASTNode t = (ASTNode)children.get(i); - if ( i>0 ) { - rootNode.addtoMemoizedString(" "); - } - t.toStringTree(rootNode); - } - if ( !isNil() ) { - rootNode.addtoMemoizedString(")"); - } - endIndx = rootNode.getMemoizedStringLen(); - return rootNode.getMemoizedSubString(startIndx, endIndx); - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java deleted file mode 100644 index c77198b087cbd..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.util.ArrayList; -import org.antlr.runtime.ANTLRStringStream; -import org.antlr.runtime.CharStream; -import org.antlr.runtime.NoViableAltException; -import org.antlr.runtime.RecognitionException; -import org.antlr.runtime.Token; -import org.antlr.runtime.TokenRewriteStream; -import org.antlr.runtime.TokenStream; -import org.antlr.runtime.tree.CommonTree; -import org.antlr.runtime.tree.CommonTreeAdaptor; -import org.antlr.runtime.tree.TreeAdaptor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.apache.hadoop.hive.ql.Context; - -/** - * ParseDriver. - * - */ -public class ParseDriver { - - private static final Logger LOG = LoggerFactory.getLogger("hive.ql.parse.ParseDriver"); - - /** - * ANTLRNoCaseStringStream. - * - */ - //This class provides and implementation for a case insensitive token checker - //for the lexical analysis part of antlr. By converting the token stream into - //upper case at the time when lexical rules are checked, this class ensures that the - //lexical rules need to just match the token with upper case letters as opposed to - //combination of upper case and lower case characters. This is purely used for matching lexical - //rules. The actual token text is stored in the same way as the user input without - //actually converting it into an upper case. The token values are generated by the consume() - //function of the super class ANTLRStringStream. The LA() function is the lookahead function - //and is purely used for matching lexical rules. This also means that the grammar will only - //accept capitalized tokens in case it is run from other tools like antlrworks which - //do not have the ANTLRNoCaseStringStream implementation. - public class ANTLRNoCaseStringStream extends ANTLRStringStream { - - public ANTLRNoCaseStringStream(String input) { - super(input); - } - - @Override - public int LA(int i) { - - int returnChar = super.LA(i); - if (returnChar == CharStream.EOF) { - return returnChar; - } else if (returnChar == 0) { - return returnChar; - } - - return Character.toUpperCase((char) returnChar); - } - } - - /** - * HiveLexerX. - * - */ - public class HiveLexerX extends SparkSqlLexer { - - private final ArrayList errors; - - public HiveLexerX(CharStream input) { - super(input); - errors = new ArrayList(); - } - - @Override - public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - errors.add(new ParseError(this, e, tokenNames)); - } - - @Override - public String getErrorMessage(RecognitionException e, String[] tokenNames) { - String msg = null; - - if (e instanceof NoViableAltException) { - // @SuppressWarnings("unused") - // NoViableAltException nvae = (NoViableAltException) e; - // for development, can add - // "decision=<<"+nvae.grammarDecisionDescription+">>" - // and "(decision="+nvae.decisionNumber+") and - // "state "+nvae.stateNumber - msg = "character " + getCharErrorDisplay(e.c) + " not supported here"; - } else { - msg = super.getErrorMessage(e, tokenNames); - } - - return msg; - } - - public ArrayList getErrors() { - return errors; - } - - } - - /** - * Tree adaptor for making antlr return ASTNodes instead of CommonTree nodes - * so that the graph walking algorithms and the rules framework defined in - * ql.lib can be used with the AST Nodes. - */ - public static final TreeAdaptor adaptor = new CommonTreeAdaptor() { - /** - * Creates an ASTNode for the given token. The ASTNode is a wrapper around - * antlr's CommonTree class that implements the Node interface. - * - * @param payload - * The token. - * @return Object (which is actually an ASTNode) for the token. - */ - @Override - public Object create(Token payload) { - return new ASTNode(payload); - } - - @Override - public Object dupNode(Object t) { - - return create(((CommonTree)t).token); - }; - - @Override - public Object errorNode(TokenStream input, Token start, Token stop, RecognitionException e) { - return new ASTErrorNode(input, start, stop, e); - }; - }; - - public ASTNode parse(String command) throws ParseException { - return parse(command, null); - } - - public ASTNode parse(String command, Context ctx) - throws ParseException { - return parse(command, ctx, true); - } - - /** - * Parses a command, optionally assigning the parser's token stream to the - * given context. - * - * @param command - * command to parse - * - * @param ctx - * context with which to associate this parser's token stream, or - * null if either no context is available or the context already has - * an existing stream - * - * @return parsed AST - */ - public ASTNode parse(String command, Context ctx, boolean setTokenRewriteStream) - throws ParseException { - LOG.info("Parsing command: " + command); - - HiveLexerX lexer = new HiveLexerX(new ANTLRNoCaseStringStream(command)); - TokenRewriteStream tokens = new TokenRewriteStream(lexer); - if (ctx != null) { - if ( setTokenRewriteStream) { - ctx.setTokenRewriteStream(tokens); - } - lexer.setHiveConf(ctx.getConf()); - } - SparkSqlParser parser = new SparkSqlParser(tokens); - if (ctx != null) { - parser.setHiveConf(ctx.getConf()); - } - parser.setTreeAdaptor(adaptor); - SparkSqlParser.statement_return r = null; - try { - r = parser.statement(); - } catch (RecognitionException e) { - e.printStackTrace(); - throw new ParseException(parser.errors); - } - - if (lexer.getErrors().size() == 0 && parser.errors.size() == 0) { - LOG.info("Parse Completed"); - } else if (lexer.getErrors().size() != 0) { - throw new ParseException(lexer.getErrors()); - } else { - throw new ParseException(parser.errors); - } - - ASTNode tree = (ASTNode) r.getTree(); - tree.setUnknownTokenBoundaries(); - return tree; - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java deleted file mode 100644 index b47bcfb2914df..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.antlr.runtime.BaseRecognizer; -import org.antlr.runtime.RecognitionException; - -/** - * - */ -public class ParseError { - private final BaseRecognizer br; - private final RecognitionException re; - private final String[] tokenNames; - - ParseError(BaseRecognizer br, RecognitionException re, String[] tokenNames) { - this.br = br; - this.re = re; - this.tokenNames = tokenNames; - } - - BaseRecognizer getBaseRecognizer() { - return br; - } - - RecognitionException getRecognitionException() { - return re; - } - - String[] getTokenNames() { - return tokenNames; - } - - String getMessage() { - return br.getErrorHeader(re) + " " + br.getErrorMessage(re, tokenNames); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java deleted file mode 100644 index fff891ced5550..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.util.ArrayList; - -/** - * ParseException. - * - */ -public class ParseException extends Exception { - - private static final long serialVersionUID = 1L; - ArrayList errors; - - public ParseException(ArrayList errors) { - super(); - this.errors = errors; - } - - @Override - public String getMessage() { - - StringBuilder sb = new StringBuilder(); - for (ParseError err : errors) { - if (sb.length() > 0) { - sb.append('\n'); - } - sb.append(err.getMessage()); - } - - return sb.toString(); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java deleted file mode 100644 index a5c2998f86cc1..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.apache.hadoop.hive.common.type.HiveDecimal; -import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; - - -/** - * Library of utility functions used in the parse code. - * - */ -public final class ParseUtils { - /** - * Performs a descent of the leftmost branch of a tree, stopping when either a - * node with a non-null token is found or the leaf level is encountered. - * - * @param tree - * candidate node from which to start searching - * - * @return node at which descent stopped - */ - public static ASTNode findRootNonNullToken(ASTNode tree) { - while ((tree.getToken() == null) && (tree.getChildCount() > 0)) { - tree = (org.apache.spark.sql.parser.ASTNode) tree.getChild(0); - } - return tree; - } - - private ParseUtils() { - // prevent instantiation - } - - public static VarcharTypeInfo getVarcharTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() != 1) { - throw new SemanticException("Bad params for type varchar"); - } - - String lengthStr = node.getChild(0).getText(); - return TypeInfoFactory.getVarcharTypeInfo(Integer.valueOf(lengthStr)); - } - - public static CharTypeInfo getCharTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() != 1) { - throw new SemanticException("Bad params for type char"); - } - - String lengthStr = node.getChild(0).getText(); - return TypeInfoFactory.getCharTypeInfo(Integer.valueOf(lengthStr)); - } - - public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() > 2) { - throw new SemanticException("Bad params for type decimal"); - } - - int precision = HiveDecimal.USER_DEFAULT_PRECISION; - int scale = HiveDecimal.USER_DEFAULT_SCALE; - - if (node.getChildCount() >= 1) { - String precStr = node.getChild(0).getText(); - precision = Integer.valueOf(precStr); - } - - if (node.getChildCount() == 2) { - String scaleStr = node.getChild(1).getText(); - scale = Integer.valueOf(scaleStr); - } - - return TypeInfoFactory.getDecimalTypeInfo(precision, scale); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java deleted file mode 100644 index 4b2015e0df84e..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java +++ /dev/null @@ -1,406 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.io.UnsupportedEncodingException; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.antlr.runtime.tree.Tree; -import org.apache.commons.lang.StringUtils; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.metastore.api.FieldSchema; -import org.apache.hadoop.hive.ql.ErrorMsg; -import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.serde.serdeConstants; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; - -/** - * SemanticAnalyzer. - * - */ -public abstract class SemanticAnalyzer { - public static String charSetString(String charSetName, String charSetString) - throws SemanticException { - try { - // The character set name starts with a _, so strip that - charSetName = charSetName.substring(1); - if (charSetString.charAt(0) == '\'') { - return new String(unescapeSQLString(charSetString).getBytes(), - charSetName); - } else // hex input is also supported - { - assert charSetString.charAt(0) == '0'; - assert charSetString.charAt(1) == 'x'; - charSetString = charSetString.substring(2); - - byte[] bArray = new byte[charSetString.length() / 2]; - int j = 0; - for (int i = 0; i < charSetString.length(); i += 2) { - int val = Character.digit(charSetString.charAt(i), 16) * 16 - + Character.digit(charSetString.charAt(i + 1), 16); - if (val > 127) { - val = val - 256; - } - bArray[j++] = (byte)val; - } - - String res = new String(bArray, charSetName); - return res; - } - } catch (UnsupportedEncodingException e) { - throw new SemanticException(e); - } - } - - /** - * Remove the encapsulating "`" pair from the identifier. We allow users to - * use "`" to escape identifier for table names, column names and aliases, in - * case that coincide with Hive language keywords. - */ - public static String unescapeIdentifier(String val) { - if (val == null) { - return null; - } - if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { - val = val.substring(1, val.length() - 1); - } - return val; - } - - /** - * Converts parsed key/value properties pairs into a map. - * - * @param prop ASTNode parent of the key/value pairs - * - * @param mapProp property map which receives the mappings - */ - public static void readProps( - ASTNode prop, Map mapProp) { - - for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { - String key = unescapeSQLString(prop.getChild(propChild).getChild(0) - .getText()); - String value = null; - if (prop.getChild(propChild).getChild(1) != null) { - value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); - } - mapProp.put(key, value); - } - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? - case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } - - /** - * Get the list of FieldSchema out of the ASTNode. - */ - public static List getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { - List colList = new ArrayList(); - int numCh = ast.getChildCount(); - for (int i = 0; i < numCh; i++) { - FieldSchema col = new FieldSchema(); - ASTNode child = (ASTNode) ast.getChild(i); - Tree grandChild = child.getChild(0); - if(grandChild != null) { - String name = grandChild.getText(); - if(lowerCase) { - name = name.toLowerCase(); - } - // child 0 is the name of the column - col.setName(unescapeIdentifier(name)); - // child 1 is the type of the column - ASTNode typeChild = (ASTNode) (child.getChild(1)); - col.setType(getTypeStringFromAST(typeChild)); - - // child 2 is the optional comment of the column - if (child.getChildCount() == 3) { - col.setComment(unescapeSQLString(child.getChild(2).getText())); - } - } - colList.add(col); - } - return colList; - } - - protected static String getTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - switch (typeNode.getType()) { - case SparkSqlParser.TOK_LIST: - return serdeConstants.LIST_TYPE_NAME + "<" - + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; - case SparkSqlParser.TOK_MAP: - return serdeConstants.MAP_TYPE_NAME + "<" - + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," - + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; - case SparkSqlParser.TOK_STRUCT: - return getStructTypeStringFromAST(typeNode); - case SparkSqlParser.TOK_UNIONTYPE: - return getUnionTypeStringFromAST(typeNode); - default: - return getTypeName(typeNode); - } - } - - private static String getStructTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - String typeStr = serdeConstants.STRUCT_TYPE_NAME + "<"; - typeNode = (ASTNode) typeNode.getChild(0); - int children = typeNode.getChildCount(); - if (children <= 0) { - throw new SemanticException("empty struct not allowed."); - } - StringBuilder buffer = new StringBuilder(typeStr); - for (int i = 0; i < children; i++) { - ASTNode child = (ASTNode) typeNode.getChild(i); - buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); - buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); - if (i < children - 1) { - buffer.append(","); - } - } - - buffer.append(">"); - return buffer.toString(); - } - - private static String getUnionTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; - typeNode = (ASTNode) typeNode.getChild(0); - int children = typeNode.getChildCount(); - if (children <= 0) { - throw new SemanticException("empty union not allowed."); - } - StringBuilder buffer = new StringBuilder(typeStr); - for (int i = 0; i < children; i++) { - buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); - if (i < children - 1) { - buffer.append(","); - } - } - buffer.append(">"); - typeStr = buffer.toString(); - return typeStr; - } - - public static String getAstNodeText(ASTNode tree) { - return tree.getChildCount() == 0?tree.getText() : - getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); - } - - public static String generateErrorMessage(ASTNode ast, String message) { - StringBuilder sb = new StringBuilder(); - if (ast == null) { - sb.append(message).append(". Cannot tell the position of null AST."); - return sb.toString(); - } - sb.append(ast.getLine()); - sb.append(":"); - sb.append(ast.getCharPositionInLine()); - sb.append(" "); - sb.append(message); - sb.append(". Error encountered near token '"); - sb.append(getAstNodeText(ast)); - sb.append("'"); - return sb.toString(); - } - - private static final Map TokenToTypeName = new HashMap(); - - static { - TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); - } - - public static String getTypeName(ASTNode node) throws SemanticException { - int token = node.getType(); - String typeName; - - // datetime type isn't currently supported - if (token == SparkSqlParser.TOK_DATETIME) { - throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); - } - - switch (token) { - case SparkSqlParser.TOK_CHAR: - CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); - typeName = charTypeInfo.getQualifiedName(); - break; - case SparkSqlParser.TOK_VARCHAR: - VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); - typeName = varcharTypeInfo.getQualifiedName(); - break; - case SparkSqlParser.TOK_DECIMAL: - DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); - typeName = decTypeInfo.getQualifiedName(); - break; - default: - typeName = TokenToTypeName.get(token); - } - return typeName; - } - - public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { - boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); - if (testMode) { - URI uri = new Path(location).toUri(); - String scheme = uri.getScheme(); - String authority = uri.getAuthority(); - String path = uri.getPath(); - if (!path.startsWith("/")) { - path = (new Path(System.getProperty("test.tmp.dir"), - path)).toUri().getPath(); - } - if (StringUtils.isEmpty(scheme)) { - scheme = "pfile"; - } - try { - uri = new URI(scheme, authority, path, null, null); - } catch (URISyntaxException e) { - throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); - } - return uri.toString(); - } else { - //no-op for non-test mode for now - return location; - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b1d841d1b5543..0e89928cb636d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -27,28 +27,28 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse.SemanticException +import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{logical, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} -import org.apache.spark.sql.parser._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -227,7 +227,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - newChildren.foreach(n.addChild(_)) + n.addChildren(newChildren.asJava) n } @@ -273,8 +273,7 @@ private[hive] object HiveQl extends Logging { private def createContext(): Context = new Context(hiveConf) private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(sql, context)) + ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) /** * Returns the HiveConf @@ -313,7 +312,7 @@ private[hive] object HiveQl extends Logging { context.clear() plan } catch { - case pe: ParseException => + case pe: org.apache.hadoop.hive.ql.parse.ParseException => pe.getMessage match { case errorRegEx(line, start, message) => throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) @@ -338,8 +337,7 @@ private[hive] object HiveQl extends Logging { val tree = try { ParseUtils.findRootNonNullToken( - (new ParseDriver) - .parse(ddl, null /* no context required for parsing alone */)) + (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) } catch { case pe: org.apache.hadoop.hive.ql.parse.ParseException => throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) @@ -600,12 +598,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => { - nameParts match { + case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { + nameParts.head match { case Token(".", dbName :: tableName :: Nil) => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts) + val tableIdent = extractTableIdent(nameParts.head) DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => @@ -664,7 +662,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { val schema = maybeColumns.map { cols => - SemanticAnalyzer.getColumns(cols, true).asScala.map { field => + BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. HiveColumn(field.getName, null, field.getComment) @@ -680,7 +678,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeComment.foreach { case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = SemanticAnalyzer.unescapeSQLString(child.getText) + val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) if (comment ne null) { properties += ("comment" -> comment) } @@ -752,7 +750,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C children.collect { case list @ Token("TOK_TABCOLLIST", _) => - val cols = SemanticAnalyzer.getColumns(list, true) + val cols = BaseSemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( schema = cols.asScala.map { field => @@ -760,11 +758,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }) } case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = SemanticAnalyzer.unescapeSQLString(child.getText) + val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) // TODO support the sql text tableDesc = tableDesc.copy(viewText = Option(comment)) case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = SemanticAnalyzer.getColumns(list(0), false) + val cols = BaseSemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( partitionColumns = cols.asScala.map { field => @@ -775,21 +773,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = SemanticAnalyzer.unescapeSQLString (rowChild1.getText()) + val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) if (rowChild2.length > 1) { - val fieldEscape = SemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) + val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) + val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) + val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) + val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) if (!(lineDelim == "\n") && !(lineDelim == "10")) { throw new AnalysisException( SemanticAnalyzer.generateErrorMessage( @@ -798,22 +796,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = SemanticAnalyzer.unescapeSQLString(rowChild.getText) + val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) // TODO support the nullFormat case _ => assert(false) } tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => - var location = SemanticAnalyzer.unescapeSQLString(child.getText) - location = SemanticAnalyzer.relativeToAbsolutePath(hiveConf, location) + var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + location = EximUtil.relativeToAbsolutePath(hiveConf, location) tableDesc = tableDesc.copy(location = Option(location)) case Token("TOK_TABLESERIALIZER", child :: Nil) => tableDesc = tableDesc.copy( - serde = Option(SemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) + serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) if (child.getChildCount == 2) { val serdeParams = new java.util.HashMap[String, String]() - SemanticAnalyzer.readProps( + BaseSemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) @@ -893,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = - Option(SemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), + Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), outputFormat = - Option(SemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) + Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) case Token("TOK_STORAGEHANDLER", _) => throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) case _ => // Unsupport features @@ -911,20 +909,24 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) - if Seq("TOK_CTE", "TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = queryArgs match { - case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => - val cteRelations = ctes.map { node => - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - relation.alias -> relation + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + // check if has CTE + insertClauses.last match { + case Token("TOK_CTE", cteClauses) => + val cteRelations = cteClauses.map(node => { + val relation = nodeToRelation(node, context).asInstanceOf[Subquery] + (relation.alias, relation) + }).toMap + (Some(args.head), insertClauses.init, Some(cteRelations)) + + case _ => (Some(args.head), insertClauses, None) } - (Some(from.head), inserts, Some(cteRelations.toMap)) - case Token("TOK_FROM", from) :: inserts => - (Some(from.head), inserts, None) - case Token("TOK_INSERT", _) :: Nil => - (None, queryArgs, None) + + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) } // Return one query for each insert clause. @@ -1023,20 +1025,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(SemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (SemanticAnalyzer.unescapeSQLString(name), - SemanticAnalyzer.unescapeSQLString(value)) + (BaseSemanticAnalyzer.unescapeSQLString(name), + BaseSemanticAnalyzer.unescapeSQLString(value)) } // SPARK-10310: Special cases LazySimpleSerDe // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = SemanticAnalyzer.unescapeSQLString(serdeClass) + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) val useDefaultRecordReaderWriter = unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) @@ -1053,7 +1055,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = matchSerDe(outputSerdeClause) - val unescapedScript = SemanticAnalyzer.unescapeSQLString(script) + val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) // TODO Adds support for user-defined record reader/writer classes val recordReaderClass = if (useDefaultRecordReader) { @@ -1359,7 +1361,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTOUTERJOIN" => LeftOuter case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi - case "TOK_ANTIJOIN" => throw new NotImplementedError("Anti join not supported") } Join(nodeToRelation(relation1, context), nodeToRelation(relation2, context), @@ -1474,11 +1475,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val numericAstTypes = Seq( - SparkSqlParser.Number, - SparkSqlParser.TinyintLiteral, - SparkSqlParser.SmallintLiteral, - SparkSqlParser.BigintLiteral, - SparkSqlParser.DecimalLiteral) + HiveParser.Number, + HiveParser.TinyintLiteral, + HiveParser.SmallintLiteral, + HiveParser.BigintLiteral, + HiveParser.DecimalLiteral) /* Case insensitive matches */ val COUNT = "(?i)COUNT".r @@ -1648,7 +1649,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(TRUE(), Nil) => Literal.create(true, BooleanType) case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => SemanticAnalyzer.unescapeSQLString(s.getText)).mkString) + Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) // This code is adapted from // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 @@ -1683,37 +1684,37 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C v } - case ast: ASTNode if ast.getType == SparkSqlParser.StringLiteral => - Literal(SemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == HiveParser.StringLiteral => + Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_DATELITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_CHARSETLITERAL => - Literal(SemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => + Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => Literal(CalendarInterval.fromYearMonthString(ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => Literal(CalendarInterval.fromDayTimeString(ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) case a: ASTNode => From 932cf44248e067ee7cae6fef79ddf2ab9b1c36d8 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Wed, 30 Dec 2015 11:14:13 +0000 Subject: [PATCH 1376/2089] [SPARK-12263][DOCS] IllegalStateException: Memory can't be 0 for SPARK_WORKER_MEMORY without unit Updated the Worker Unit IllegalStateException message to indicate no values less than 1MB instead of 0 to help solve this. Requesting review Author: Neelesh Srinivas Salian Closes #10483 from nssalian/SPARK-12263. --- .../scala/org/apache/spark/deploy/worker/WorkerArguments.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 5181142c5f80e..de3c7cd265d2d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -175,7 +175,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { def checkWorkerMemory(): Unit = { if (memory <= 0) { - val message = "Memory can't be 0, missing a M or G on the end of the memory specification?" + val message = "Memory is below 1MB, or missing a M/G at the end of the memory specification?" throw new IllegalStateException(message) } } From aa48164a43bd9ed9eab53fcacbed92819e84eaf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Dec 2015 10:56:08 -0800 Subject: [PATCH 1377/2089] [SPARK-12495][SQL] use true as default value for propagateNull in NewInstance Most of cases we should propagate null when call `NewInstance`, and so far there is only one case we should stop null propagation: create product/java bean. So I think it makes more sense to propagate null by dafault. This also fixes a bug when encode null array/map, which is firstly discovered in https://github.com/apache/spark/pull/10401 Author: Wenchen Fan Closes #10443 from cloud-fan/encoder. --- .../sql/catalyst/JavaTypeInference.scala | 16 ++++++------- .../spark/sql/catalyst/ScalaReflection.scala | 16 ++++++------- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 -- .../sql/catalyst/expressions/objects.scala | 12 +++++----- .../encoders/EncoderResolutionSuite.scala | 24 +++++++++---------- .../encoders/ExpressionEncoderSuite.scala | 3 +++ 7 files changed, 38 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index a1500cbc305d8..ed153d1f88945 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -178,19 +178,19 @@ object JavaTypeInference { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -298,7 +298,7 @@ object JavaTypeInference { p.getWriteMethod.getName -> setter }.toMap - val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) if (path.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8a22b37d07fc6..9784c969665de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( @@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection { } } - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) if (path.nonEmpty) { expressions.If( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7a4401cf5810e..ad4beda9c4916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -133,7 +133,7 @@ object ExpressionEncoder { } val fromRowExpression = - NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) + NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 63bdf05ca7c28..6f3d5ba84c9ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -55,7 +55,6 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) @@ -166,7 +165,6 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index d40cd96905732..fb404c12d5a04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -165,7 +165,7 @@ case class Invoke( ${obj.code} ${argGen.map(_.code).mkString("\n")} - boolean ${ev.isNull} = ${obj.value} == null; + boolean ${ev.isNull} = ${obj.isNull}; $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value; @@ -178,8 +178,8 @@ object NewInstance { def apply( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = false, - dataType: DataType): NewInstance = + dataType: DataType, + propagateNull: Boolean = true): NewInstance = new NewInstance(cls, arguments, propagateNull, dataType, None) } @@ -231,7 +231,7 @@ case class NewInstance( s"new $className($argString)" } - if (propagateNull) { + if (propagateNull && argGen.nonEmpty) { val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -248,8 +248,8 @@ case class NewInstance( s""" $setup - $javaType ${ev.value} = $constructorCall; - final boolean ${ev.isNull} = ${ev.value} == null; + final $javaType ${ev.value} = $constructorCall; + final boolean ${ev.isNull} = false; """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 764ffdc0947c4..bc36a55ae0ea2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest { toExternalString('a.string), AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") ), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest { toExternalString('a.int.cast(StringType)), AssertNotNull('b.long, cls.getName, "b", "Long") ), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } } @@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest { AssertNotNull( GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), innerCls.getName, "b", "Long")), - false, - ObjectType(innerCls)) + ObjectType(innerCls), + propagateNull = false) )), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest { AssertNotNull( GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), cls.getName, "b", "Long")), - false, - ObjectType(cls)), + ObjectType(cls), + propagateNull = false), 'b.int.cast(LongType)), - false, - ObjectType(classOf[Tuple2[_, _]])) + ObjectType(classOf[Tuple2[_, _]]), + propagateNull = false) compareExpressions(fromRowExpr, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7233e0f1b5baf..666699e18d4a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") + encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + // Kryo encoders encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) encodeDecodeTest(new KryoSerializable(15), "kryo object")( From d1ca634db4ca9db7f0ba7ca38a0e03bcbfec23c9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 30 Dec 2015 11:14:47 -0800 Subject: [PATCH 1378/2089] [SPARK-12300] [SQL] [PYSPARK] fix schema inferance on local collections Current schema inference for local python collections halts as soon as there are no NullTypes. This is different than when we specify a sampling ratio of 1.0 on a distributed collection. This could result in incomplete schema information. Author: Holden Karau Closes #10275 from holdenk/SPARK-12300-fix-schmea-inferance-on-local-collections. --- python/pyspark/sql/context.py | 10 +++------- python/pyspark/sql/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b05aa2f5c4cd7..ba6915a12347e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -18,6 +18,7 @@ import sys import warnings import json +from functools import reduce if sys.version >= '3': basestring = unicode = str @@ -236,14 +237,9 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) + schema = reduce(_merge_type, map(_infer_schema, data)) if _has_nulltype(schema): - for r in data: - schema = _merge_type(schema, _infer_schema(r)) - if not _has_nulltype(schema): - break - else: - raise ValueError("Some of types cannot be determined after inferring") + raise ValueError("Some of types cannot be determined after inferring") return schema def _inferSchema(self, rdd, samplingRatio=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9f5f7cfdf7a69..10b99175ad952 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -353,6 +353,17 @@ def test_apply_schema_to_row(self): df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + df = self.sqlCtx.createDataFrame(input) + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) From 27a42c7108ced48a7f558990de2e4fc7ed340119 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 30 Dec 2015 12:47:42 -0800 Subject: [PATCH 1379/2089] [SPARK-10359] Enumerate dependencies in a file and diff against it for new pull requests This patch adds a new build check which enumerates Spark's resolved runtime classpath and saves it to a file, then diffs against that file to detect whether pull requests have introduced dependency changes. The aim of this check is to make it simpler to reason about whether pull request which modify the build have introduced new dependencies or changed transitive dependencies in a way that affects the final classpath. This supplants the checks added in SPARK-4123 / #5093, which are currently disabled due to bugs. This patch is based on pwendell's work in #8531. Closes #8531. Author: Josh Rosen Author: Patrick Wendell Closes #10461 from JoshRosen/SPARK-10359. --- .rat-excludes | 1 + dev/deps/spark-deps-hadoop-2.3 | 184 ++++++++++++++++++++++++++++++ dev/deps/spark-deps-hadoop-2.4 | 185 +++++++++++++++++++++++++++++++ dev/run-tests-jenkins.py | 2 +- dev/run-tests.py | 8 ++ dev/sparktestsupport/__init__.py | 1 + dev/sparktestsupport/modules.py | 15 ++- dev/test-dependencies.sh | 102 +++++++++++++++++ dev/tests/pr_new_dependencies.sh | 117 ------------------- pom.xml | 17 +++ 10 files changed, 512 insertions(+), 120 deletions(-) create mode 100644 dev/deps/spark-deps-hadoop-2.3 create mode 100644 dev/deps/spark-deps-hadoop-2.4 create mode 100755 dev/test-dependencies.sh delete mode 100755 dev/tests/pr_new_dependencies.sh diff --git a/.rat-excludes b/.rat-excludes index 3544c0fc3d910..bf071eba652b1 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -85,3 +85,4 @@ org.apache.spark.sql.sources.DataSourceRegister org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet LZ4BlockInputStream.java +spark-deps-.* diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 new file mode 100644 index 0000000000000..6014d50c6b6fd --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -0,0 +1,184 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.3.0.jar +hadoop-auth-2.3.0.jar +hadoop-client-2.3.0.jar +hadoop-common-2.3.0.jar +hadoop-hdfs-2.3.0.jar +hadoop-mapreduce-client-app-2.3.0.jar +hadoop-mapreduce-client-common-2.3.0.jar +hadoop-mapreduce-client-core-2.3.0.jar +hadoop-mapreduce-client-jobclient-2.3.0.jar +hadoop-mapreduce-client-shuffle-2.3.0.jar +hadoop-yarn-api-2.3.0.jar +hadoop-yarn-client-2.3.0.jar +hadoop-yarn-common-2.3.0.jar +hadoop-yarn-server-common-2.3.0.jar +hadoop-yarn-server-web-proxy-2.3.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 new file mode 100644 index 0000000000000..f56e6f4393e78 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -0,0 +1,185 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.4.0.jar +hadoop-auth-2.4.0.jar +hadoop-client-2.4.0.jar +hadoop-common-2.4.0.jar +hadoop-hdfs-2.4.0.jar +hadoop-mapreduce-client-app-2.4.0.jar +hadoop-mapreduce-client-common-2.4.0.jar +hadoop-mapreduce-client-core-2.4.0.jar +hadoop-mapreduce-client-jobclient-2.4.0.jar +hadoop-mapreduce-client-shuffle-2.4.0.jar +hadoop-yarn-api-2.4.0.jar +hadoop-yarn-client-2.4.0.jar +hadoop-yarn-common-2.4.0.jar +hadoop-yarn-server-common-2.4.0.jar +hadoop-yarn-server-web-proxy-2.4.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 42afca0e52448..6501721572f98 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -124,6 +124,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', ERROR_CODES["BLOCK_BUILD"]: 'to build', + ERROR_CODES["BLOCK_BUILD_TESTS"]: 'build dependency tests', ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', @@ -193,7 +194,6 @@ def main(): pr_tests = [ "pr_merge_ability", "pr_public_classes" - # DISABLED (pwendell) "pr_new_dependencies" ] # `bind_message_base` returns a function to generate messages for Github posting diff --git a/dev/run-tests.py b/dev/run-tests.py index 6129f87cf8503..706e2d141c27f 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -417,6 +417,11 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_build_tests(): + set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") + run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + + def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") @@ -537,6 +542,9 @@ def main(): # if "DOCS" in changed_modules and test_env == "amplab_jenkins": # build_spark_documentation() + if any(m.should_run_build_tests for m in test_modules): + run_build_tests() + # spark build build_apache_spark(build_tool, hadoop_version) diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 0e8032d13341e..89015f8c4fb9c 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -32,5 +32,6 @@ "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, "BLOCK_JAVA_STYLE": 21, + "BLOCK_BUILD_TESTS": 22, "BLOCK_TIMEOUT": 124 } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d65547e04db4b..4667b289f507a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + test_tags=(), should_run_r_tests=False, should_run_build_tests=False): """ Define a new module. @@ -53,6 +53,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param test_tags A set of tags that will be excluded when running unit tests if the module is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. + :param should_run_build_tests: If true, changes in this module will trigger build tests. """ self.name = name self.dependencies = dependencies @@ -64,6 +65,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.blacklisted_python_implementations = blacklisted_python_implementations self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests + self.should_run_build_tests = should_run_build_tests self.dependent_modules = set() for dep in dependencies: @@ -394,6 +396,14 @@ def contains_file(self, filename): ] ) +build = Module( + name="build", + dependencies=[], + source_file_regexes=[ + ".*pom.xml", + "dev/test-dependencies.sh", + ] +) ec2 = Module( name="ec2", @@ -433,5 +443,6 @@ def contains_file(self, filename): "test", ], python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), - should_run_r_tests=True + should_run_r_tests=True, + should_run_build_tests=True ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh new file mode 100755 index 0000000000000..984e29d1beb88 --- /dev/null +++ b/dev/test-dependencies.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" + +# TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. + +# NOTE: These should match those in the release publishing script +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pyarn -Phive" +MVN="build/mvn --force" +HADOOP_PROFILES=( + hadoop-2.3 + hadoop-2.4 +) + +# We'll switch the version to a temp. one, publish POMs using that new version, then switch back to +# the old version. We need to do this because the `dependency:build-classpath` task needs to +# resolve Spark's internal submodule dependencies. + +# See http://stackoverflow.com/a/3545363 for an explanation of this one-liner: +OLD_VERSION=$(mvn help:evaluate -Dexpression=project.version|grep -Ev '(^\[|Download\w+:)') +TEMP_VERSION="spark-$(date +%s | tail -c6)" + +function reset_version { + # Delete the temporary POMs that we wrote to the local Maven repo: + find "$HOME/.m2/" | grep "$TEMP_VERSION" | xargs rm -rf + + # Restore the original version number: + $MVN -q versions:set -DnewVersion=$OLD_VERSION -DgenerateBackupPoms=false > /dev/null +} +trap reset_version EXIT + +$MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /dev/null + +# Generate manifests for each Hadoop profile: +for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do + echo "Performing Maven install for $HADOOP_PROFILE" + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar install:install -q \ + -pl '!assembly' \ + -pl '!examples' \ + -pl '!external/flume-assembly' \ + -pl '!external/kafka-assembly' \ + -pl '!external/twitter' \ + -pl '!external/flume' \ + -pl '!external/mqtt' \ + -pl '!external/mqtt-assembly' \ + -pl '!external/zeromq' \ + -pl '!external/kafka' \ + -pl '!tags' \ + -DskipTests + + echo "Generating dependency manifest for $HADOOP_PROFILE" + mkdir -p dev/pr-deps + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly \ + | grep "Building Spark Project Assembly" -A 5 \ + | tail -n 1 | tr ":" "\n" | rev | cut -d "/" -f 1 | rev | sort \ + | grep -v spark > dev/pr-deps/spark-deps-$HADOOP_PROFILE +done + +if [[ $@ == **replace-manifest** ]]; then + echo "Replacing manifests and creating new files at dev/deps" + rm -rf dev/deps + mv dev/pr-deps dev/deps + exit 0 +fi + +for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do + set +e + dep_diff="$( + git diff \ + --no-index \ + dev/deps/spark-deps-$HADOOP_PROFILE \ + dev/pr-deps/spark-deps-$HADOOP_PROFILE \ + )" + set -e + if [ "$dep_diff" != "" ]; then + echo "Spark's published dependencies DO NOT MATCH the manifest file (dev/spark-deps)." + echo "To update the manifest file, run './dev/test-dependencies.sh --replace-manifest'." + echo "$dep_diff" + rm -rf dev/pr-deps + exit 1 + fi +done diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh deleted file mode 100755 index fdfb3c62aff58..0000000000000 --- a/dev/tests/pr_new_dependencies.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# -# This script follows the base format for testing pull requests against -# another branch and returning results to be published. More details can be -# found at dev/run-tests-jenkins. -# -# Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` -# Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` -# Arg3: Current PR Commit Hash -#+ the PR hash for the current commit -# - -ghprbActualCommit="$1" -sha1="$2" -current_pr_head="$3" - -MVN_BIN="build/mvn" -CURR_CP_FILE="my-classpath.txt" -MASTER_CP_FILE="master-classpath.txt" - -# First switch over to the master branch -git checkout -f master -# Find and copy all pom.xml files into a *.gate file that we can check -# against through various `git` changes -find -name "pom.xml" -exec cp {} {}.gate \; -# Switch back to the current PR -git checkout -f "${current_pr_head}" - -# Check if any *.pom files from the current branch are different from the master -difference_q="" -for p in $(find -name "pom.xml"); do - [[ -f "${p}" && -f "${p}.gate" ]] && \ - difference_q="${difference_q}$(diff $p.gate $p)" -done - -# If no pom files were changed we can easily say no new dependencies were added -if [ -z "${difference_q}" ]; then - echo " * This patch does not change any dependencies." -else - # Else we need to manually build spark to determine what, if any, dependencies - # were added into the Spark assembly jar - ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ - sed -n -e '/Building Spark Project Assembly/,$p' | \ - grep --context=1 -m 2 "Dependencies classpath:" | \ - head -n 3 | \ - tail -n 1 | \ - tr ":" "\n" | \ - rev | \ - cut -d "/" -f 1 | \ - rev | \ - sort > ${CURR_CP_FILE} - - # Checkout the master branch to compare against - git checkout -f master - - ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ - sed -n -e '/Building Spark Project Assembly/,$p' | \ - grep --context=1 -m 2 "Dependencies classpath:" | \ - head -n 3 | \ - tail -n 1 | \ - tr ":" "\n" | \ - rev | \ - cut -d "/" -f 1 | \ - rev | \ - sort > ${MASTER_CP_FILE} - - DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" - - if [ -z "${DIFF_RESULTS}" ]; then - echo " * This patch does not change any dependencies." - else - # Pretty print the new dependencies - added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') - removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') - added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" - removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" - - # Construct the final returned message with proper - return_mssg="" - [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" - if [ -n "${removed_deps}" ]; then - if [ -n "${return_mssg}" ]; then - return_mssg="${return_mssg}\n${removed_deps_text}" - else - return_mssg="${removed_deps_text}" - fi - fi - echo "${return_mssg}" - fi - - # Remove the files we've left over - [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" - [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" - - # Clean up our mess from the Maven builds just in case - ${MVN_BIN} clean &>/dev/null -fi diff --git a/pom.xml b/pom.xml index 284c219519bca..62ea829b1dbfd 100644 --- a/pom.xml +++ b/pom.xml @@ -2113,6 +2113,23 @@ maven-deploy-plugin 2.8.2 + + org.apache.maven.plugins + maven-dependency-plugin + + + default-cli + + build-classpath + + + + runtime + + + + From 5c2682b0c8fd2aeae2af1adb716ee0d5f8b85135 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 30 Dec 2015 13:34:37 -0800 Subject: [PATCH 1380/2089] [SPARK-12409][SPARK-12387][SPARK-12391][SQL] Support AND/OR/IN/LIKE push-down filters for JDBC This is rework from #10386 and add more tests and LIKE push-down support. Author: Takeshi YAMAMURO Closes #10468 from maropu/SupportMorePushdownInJdbc. --- .../execution/datasources/jdbc/JDBCRDD.scala | 9 +++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 28 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 4e2f5059be4e1..7072ee4b4e3bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -179,6 +179,7 @@ private[sql] object JDBCRDD extends Logging { case stringValue: String => s"'${escapeSql(stringValue)}'" case timestampValue: Timestamp => "'" + timestampValue + "'" case dateValue: Date => "'" + dateValue + "'" + case arrayValue: Array[Object] => arrayValue.map(compileValue).mkString(", ") case _ => value } @@ -191,13 +192,19 @@ private[sql] object JDBCRDD extends Logging { */ private def compileFilter(f: Filter): String = f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" + case Not(f) => s"(NOT (${compileFilter(f)}))" case LessThan(attr, value) => s"$attr < ${compileValue(value)}" case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" + case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" + case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" case IsNull(attr) => s"$attr IS NULL" case IsNotNull(attr) => s"$attr IS NOT NULL" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Or(f1, f2) => s"(${compileFilter(f1)}) OR (${compileFilter(f2)})" + case And(f1, f2) => s"(${compileFilter(f1)}) AND (${compileFilter(f2)})" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 4044a10ce70cc..00e37f107a88b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -186,8 +187,26 @@ class JDBCSuite extends SparkFunSuite assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) + .collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) + .collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) + .collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " + + "AND THEID = 2")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) + + // This is a test to reflect discussion in SPARK-12218. + // The older versions of spark have this kind of bugs in parquet data source. + val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')") + val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") + assert(df1.collect.toSet === Set(Row("mary", 2))) + assert(df2.collect.toSet === Set(Row("mary", 2))) } test("SELECT * WHERE (quoted strings)") { @@ -437,7 +456,11 @@ class JDBCSuite extends SparkFunSuite val compileFilter = PrivateMethod[String]('compileFilter) def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") - assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "col1 != 'abc'") + assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))") + assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) + === "(col0 = 0) AND (col1 = 'def')") + assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi"))) + === "(col0 = 2) OR (col1 = 'ghi')") assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5") assert(doCompileFilter(LessThan("col3", Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'") @@ -445,6 +468,9 @@ class JDBCSuite extends SparkFunSuite assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5") assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3") assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3") + assert(doCompileFilter(In("col1", Array("jkl"))) === "col1 IN ('jkl')") + assert(doCompileFilter(Not(In("col1", Array("mno", "pqr")))) + === "(NOT (col1 IN ('mno', 'pqr')))") assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") } From b244297966be1d09f8e861cfe2d8e69f7bed84da Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Wed, 30 Dec 2015 13:49:10 -0800 Subject: [PATCH 1381/2089] [SPARK-12399] Display correct error message when accessing REST API with an unknown app Id I got an exception when accessing the below REST API with an unknown application Id. `http://:18080/api/v1/applications/xxx/jobs` Instead of an exception, I expect an error message "no such app: xxx" which is a similar error message when I access `/api/v1/applications/xxx` ``` org.spark-project.guava.util.concurrent.UncheckedExecutionException: java.util.NoSuchElementException: no app with key xxx at org.spark-project.guava.cache.LocalCache$Segment.get(LocalCache.java:2263) at org.spark-project.guava.cache.LocalCache.get(LocalCache.java:4000) at org.spark-project.guava.cache.LocalCache.getOrLoad(LocalCache.java:4004) at org.spark-project.guava.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874) at org.apache.spark.deploy.history.HistoryServer.getSparkUI(HistoryServer.scala:116) at org.apache.spark.status.api.v1.UIRoot$class.withSparkUI(ApiRootResource.scala:226) at org.apache.spark.deploy.history.HistoryServer.withSparkUI(HistoryServer.scala:46) at org.apache.spark.status.api.v1.ApiRootResource.getJobs(ApiRootResource.scala:66) ``` Author: Carson Wang Closes #10352 from carsonwang/unknownAppFix. --- .../spark/deploy/history/HistoryServer.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 0bc0cb1c15eb2..6143a33b69344 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -21,6 +21,8 @@ import java.util.NoSuchElementException import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} +import scala.util.control.NonFatal + import com.google.common.cache._ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -113,7 +115,17 @@ class HistoryServer( } def getSparkUI(appKey: String): Option[SparkUI] = { - Option(appCache.get(appKey)) + try { + val ui = appCache.get(appKey) + Some(ui) + } catch { + case NonFatal(e) => e.getCause() match { + case nsee: NoSuchElementException => + None + + case cause: Exception => throw cause + } + } } initialize() @@ -193,7 +205,7 @@ class HistoryServer( appCache.get(appId + attemptId.map { id => s"/$id" }.getOrElse("")) true } catch { - case e: Exception => e.getCause() match { + case NonFatal(e) => e.getCause() match { case nsee: NoSuchElementException => false From f76ee109d87e727710d2721e4be47fdabc21582c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 30 Dec 2015 16:51:07 -0800 Subject: [PATCH 1382/2089] [SPARK-8641][SPARK-12455][SQL] Native Spark Window functions - Follow-up (docs & tests) This PR is a follow-up for PR https://github.com/apache/spark/pull/9819. It adds documentation for the window functions and a couple of NULL tests. The documentation was largely based on the documentation in (the source of) Hive and Presto: * https://prestodb.io/docs/current/functions/window.html * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+WindowingAndAnalytics I am not sure if we need to add the licenses of these two projects to the licenses directory. They are both under the ASL. srowen any thoughts? cc yhuai Author: Herman van Hovell Closes #10402 from hvanhovell/SPARK-8641-docs. --- .../expressions/windowExpressions.scala | 130 +++++++++++++++++- .../spark/sql/DataFrameWindowSuite.scala | 20 +++ .../sql/hive/execution/WindowQuerySuite.scala | 15 ++ 3 files changed, 162 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 91f169e7eac4b..f1a333b8e56a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -314,8 +314,8 @@ abstract class OffsetWindowFunction val offset: Expression /** - * Direction (above = 1/below = -1) of the number of rows between the current row and the row - * where the input expression is evaluated. + * Direction of the number of rows between the current row and the row where the input expression + * is evaluated. */ val direction: SortDirection @@ -327,7 +327,7 @@ abstract class OffsetWindowFunction * both the input and the default expression are foldable, the result is still not foldable due to * the frame. */ - override def foldable: Boolean = input.foldable && (default == null || default.foldable) + override def foldable: Boolean = false override def nullable: Boolean = default == null || default.nullable @@ -353,6 +353,21 @@ abstract class OffsetWindowFunction override def toString: String = s"$prettyName($input, $offset, $default)" } +/** + * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window. + * Offsets start at 0, which is the current row. The offset must be constant integer value. The + * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger + * than the window, the default expression is evaluated. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param input expression to evaluate 'offset' rows after the current row. + * @param offset rows to jump ahead in the partition. + * @param default to use when the input value is null or when the offset is larger than the window. + */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows after the + current row in the window""") case class Lead(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -365,6 +380,21 @@ case class Lead(input: Expression, offset: Expression, default: Expression) override val direction = Ascending } +/** + * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window. + * Offsets start at 0, which is the current row. The offset must be constant integer value. The + * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller + * than the window, the default expression is evaluated. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param input expression to evaluate 'offset' rows before the current row. + * @param offset rows to jump back in the partition. + * @param default to use when the input value is null or when the offset is smaller than the window. + */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows before the + current row in the window""") case class Lag(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -409,10 +439,31 @@ object SizeBasedWindowFunction { val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() } +/** + * The RowNumber function computes a unique, sequential number to each row, starting with one, + * according to the ordering of rows within the window partition. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + */ +@ExpressionDescription(usage = + """_FUNC_() - The ROW_NUMBER() function assigns a unique, sequential + number to each row, starting with one, according to the ordering of rows within the window + partition.""") case class RowNumber() extends RowNumberLike { override val evaluateExpression = rowNumber } +/** + * The CumeDist function computes the position of a value relative to a all values in the partition. + * The result is the number of rows preceding or equal to the current row in the ordering of the + * partition divided by the total number of rows in the window partition. Any tie values in the + * ordering will evaluate to the same position. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + */ +@ExpressionDescription(usage = + """_FUNC_() - The CUME_DIST() function computes the position of a value relative to a all values + in the partition.""") case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { override def dataType: DataType = DoubleType // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must @@ -421,6 +472,30 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) } +/** + * The NTile function divides the rows for each window partition into 'n' buckets ranging from 1 to + * at most 'n'. Bucket values will differ by at most 1. If the number of rows in the partition does + * not divide evenly into the number of buckets, then the remainder values are distributed one per + * bucket, starting with the first bucket. + * + * The NTile function is particularly useful for the calculation of tertiles, quartiles, deciles and + * other common summary statistics + * + * The function calculates two variables during initialization: The size of a regular bucket, and + * the number of buckets that will have one extra row added to it (when the rows do not evenly fit + * into the number of buckets); both variables are based on the size of the current partition. + * During the calculation process the function keeps track of the current row number, the current + * bucket number, and the row number at which the bucket will change (bucketThreshold). When the + * current row number reaches bucket threshold, the bucket value is increased by one and the the + * threshold is increased by the bucket size (plus one extra if the current bucket is padded). + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param buckets number of buckets to divide the rows in. Default value is 1. + */ +@ExpressionDescription(usage = + """_FUNC_(x) - The NTILE(n) function divides the rows for each window partition into 'n' buckets + ranging from 1 to at most 'n'.""") case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { def this() = this(Literal(1)) @@ -474,6 +549,8 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow * the order of the window in which is processed. For instance, when the value of 'x' changes in a * window ordered by 'x' the rank function also changes. The size of the change of the rank function * is (typically) not dependent on the size of the change in 'x'. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. */ abstract class RankLike extends AggregateWindowFunction { override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) @@ -513,11 +590,41 @@ abstract class RankLike extends AggregateWindowFunction { def withOrder(order: Seq[Expression]): RankLike } +/** + * The Rank function computes the rank of a value in a group of values. The result is one plus the + * number of rows preceding or equal to the current row in the ordering of the partition. Tie values + * will produce gaps in the sequence. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - RANK() computes the rank of a value in a group of values. The result is one plus + the number of rows preceding or equal to the current row in the ordering of the partition. Tie + values will produce gaps in the sequence.""") case class Rank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): Rank = Rank(order) } +/** + * The DenseRank function computes the rank of a value in a group of values. The result is one plus + * the previously assigned rank value. Unlike Rank, DenseRank will not produce gaps in the ranking + * sequence. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - The DENSE_RANK() function computes the rank of a value in a group of values. The + result is one plus the previously assigned rank value. Unlike Rank, DenseRank will not produce + gaps in the ranking sequence.""") case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) @@ -527,6 +634,23 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { override val initialValues = zero +: orderInit } +/** + * The PercentRank function computes the percentage ranking of a value in a group of values. The + * result the rank of the minus one divided by the total number of rows in the partitiion minus one: + * (r - 1) / (n - 1). If a partition only contains one row, the function will return 0. + * + * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of + * row counts in the its numerator. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - PERCENT_RANK() The PercentRank function computes the percentage ranking of a value + in a group of values.""") case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index b50d7604e0ec7..3917b9762ba63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -292,4 +292,24 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 3, 8, 32), Row("b", 2, 4, 8))) } + + test("null inputs") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)) + .toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.select( + $"key", + $"value", + avg(lit(null)).over(window), + sum(lit(null)).over(window)), + Seq( + Row("a", 1, null, null), + Row("a", 1, null, null), + Row("a", 2, null, null), + Row("a", 2, null, null), + Row("b", 4, null, null), + Row("b", 3, null, null), + Row("b", 2, null, null))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index c05dbfd7608d9..ea82b8c459695 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -227,4 +227,19 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) // scalastyle:on } + + test("null arguments") { + checkAnswer(sql(""" + |select p_mfgr, p_name, p_size, + |sum(null) over(distribute by p_mfgr sort by p_name) as sum, + |avg(null) over(distribute by p_mfgr sort by p_name) as avg + |from part + """.stripMargin), + sql(""" + |select p_mfgr, p_name, p_size, + |null as sum, + |null as avg + |from part + """.stripMargin)) + } } From ee8f8d318417c514fbb26e57157483d466ddbfae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Dec 2015 18:07:07 -0800 Subject: [PATCH 1383/2089] [SPARK-12588] Remove HttpBroadcast in Spark 2.0. We switched to TorrentBroadcast in Spark 1.1, and HttpBroadcast has been undocumented since then. It's time to remove it in Spark 2.0. Author: Reynold Xin Closes #10531 from rxin/SPARK-12588. --- .../spark/broadcast/BroadcastFactory.scala | 4 +- .../spark/broadcast/BroadcastManager.scala | 13 +- .../spark/broadcast/HttpBroadcast.scala | 269 ------------------ .../broadcast/HttpBroadcastFactory.scala | 47 --- .../spark/broadcast/TorrentBroadcast.scala | 2 +- .../broadcast/TorrentBroadcastFactory.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 2 - .../spark/broadcast/BroadcastSuite.scala | 131 +-------- docs/configuration.md | 19 +- docs/security.md | 13 +- .../apache/spark/examples/BroadcastTest.scala | 8 +- project/MimaExcludes.scala | 3 +- 12 files changed, 22 insertions(+), 491 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala delete mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 6a187b40628a2..7f35ac47479b0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -24,14 +24,12 @@ import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi /** - * :: DeveloperApi :: * An interface for all the broadcast implementations in Spark (to allow * multiple broadcast implementations). SparkContext uses a user-specified * BroadcastFactory implementation to instantiate a particular broadcast for the * entire Spark job. */ -@DeveloperApi -trait BroadcastFactory { +private[spark] trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index fac6666bb3410..61343607a13bc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SecurityManager} + private[spark] class BroadcastManager( val isDriver: Boolean, @@ -39,15 +39,8 @@ private[spark] class BroadcastManager( private def initialize() { synchronized { if (!initialized) { - val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - - broadcastFactory = - Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory = new TorrentBroadcastFactory broadcastFactory.initialize(isDriver, conf, securityManager) - initialized = true } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala deleted file mode 100644 index b69af639f7862..0000000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} -import java.io.{BufferedInputStream, BufferedOutputStream} -import java.net.{URL, URLConnection, URI} -import java.util.concurrent.TimeUnit - -import scala.reflect.ClassTag - -import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} -import org.apache.spark.io.CompressionCodec -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} - -/** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server - * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a - * task) is deserialized in the executor, the broadcasted data is fetched from the driver - * (through a HTTP server running at the driver) and stored in the BlockManager of the - * executor to speed up future accesses. - */ -private[spark] class HttpBroadcast[T: ClassTag]( - @transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) with Logging with Serializable { - - override protected def getValue() = value_ - - private val blockId = BroadcastBlockId(id) - - /* - * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster - * does not need to be told about this block as not only need to know about this data block. - */ - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - } - - if (!isLocal) { - HttpBroadcast.write(id, value_) - } - - /** - * Remove all persisted state associated with this HTTP broadcast on the executors. - */ - override protected def doUnpersist(blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) - } - - /** - * Remove all persisted state associated with this HTTP broadcast on the executors and driver. - */ - override protected def doDestroy(blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) - } - - /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { - assertValid() - out.defaultWriteObject() - } - - /** Used by the JVM when deserializing this object. */ - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => value_ = x.asInstanceOf[T] - case None => { - logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime - value_ = HttpBroadcast.read[T](id) - /* - * We cache broadcast data in the BlockManager so that subsequent tasks using it - * do not need to re-fetch. This data is only used locally and no other node - * needs to fetch this block, so we don't notify the master. - */ - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - } -} - -private[broadcast] object HttpBroadcast extends Logging { - private var initialized = false - private var broadcastDir: File = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - private var serverUri: String = null - private var server: HttpServer = null - private var securityManager: SecurityManager = null - - // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - private val files = new TimeStampedHashSet[File] - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private var compressionCodec: CompressionCodec = null - private var cleaner: MetadataCleaner = null - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - synchronized { - if (!initialized) { - bufferSize = conf.getInt("spark.buffer.size", 65536) - compress = conf.getBoolean("spark.broadcast.compress", true) - securityManager = securityMgr - if (isDriver) { - createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) - } - serverUri = conf.get("spark.httpBroadcast.uri") - cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) - compressionCodec = CompressionCodec.createCodec(conf) - initialized = true - } - } - } - - def stop() { - synchronized { - if (server != null) { - server.stop() - server = null - } - if (cleaner != null) { - cleaner.cancel() - cleaner = null - } - compressionCodec = null - initialized = false - } - } - - private def createServer(conf: SparkConf) { - broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast") - val broadcastPort = conf.getInt("spark.broadcast.port", 0) - server = - new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") - server.start() - serverUri = server.uri - logInfo("Broadcast server started at " + serverUri) - } - - def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) - - private def write(id: Long, value: Any) { - val file = getFile(id) - val fileOutputStream = new FileOutputStream(file) - Utils.tryWithSafeFinally { - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(fileOutputStream) - } else { - new BufferedOutputStream(fileOutputStream, bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - Utils.tryWithSafeFinally { - serOut.writeObject(value) - } { - serOut.close() - } - files += file - } { - fileOutputStream.close() - } - } - - private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) - val url = serverUri + "/" + BroadcastBlockId(id).name - - var uc: URLConnection = null - if (securityManager.isAuthenticationEnabled()) { - logDebug("broadcast security enabled") - val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL.openConnection() - uc.setConnectTimeout(httpReadTimeout) - uc.setAllowUserInteraction(false) - } else { - logDebug("broadcast not using security") - uc = new URL(url).openConnection() - uc.setConnectTimeout(httpReadTimeout) - } - Utils.setupSecureURLConnection(uc, securityManager) - - val in = { - uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream - if (compress) { - compressionCodec.compressedInputStream(inputStream) - } else { - new BufferedInputStream(inputStream, bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serIn = ser.deserializeStream(in) - Utils.tryWithSafeFinally { - serIn.readObject[T]() - } { - serIn.close() - } - } - - /** - * Remove all persisted blocks associated with this HTTP broadcast on the executors. - * If removeFromDriver is true, also remove these persisted blocks on the driver - * and delete the associated broadcast file. - */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) - if (removeFromDriver) { - val file = getFile(id) - files.remove(file) - deleteBroadcastFile(file) - } - } - - /** - * Periodically clean up old broadcasts by removing the associated map entries and - * deleting the associated files. - */ - private def cleanup(cleanupTime: Long) { - val iterator = files.internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val (file, time) = (entry.getKey, entry.getValue) - if (time < cleanupTime) { - iterator.remove() - deleteBroadcastFile(file) - } - } - } - - private def deleteBroadcastFile(file: File) { - try { - if (file.exists) { - if (file.delete()) { - logInfo("Deleted broadcast file: %s".format(file)) - } else { - logWarning("Could not delete broadcast file: %s".format(file)) - } - } - } catch { - case e: Exception => - logError("Exception while deleting broadcast file: %s".format(file), e) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala deleted file mode 100644 index cf3ae36f27949..0000000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import scala.reflect.ClassTag - -import org.apache.spark.{SecurityManager, SparkConf} - -/** - * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a - * HTTP server as the broadcast mechanism. Refer to - * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. - */ -class HttpBroadcastFactory extends BroadcastFactory { - override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = - new HttpBroadcast[T](value_, isLocal, id) - - override def stop() { HttpBroadcast.stop() } - - /** - * Remove all persisted state associated with the HTTP broadcast with the given ID. - * @param removeFromDriver Whether to remove state from the driver - * @param blocking Whether to block until unbroadcasted - */ - override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver, blocking) - } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7e3764d802fe1..9bd69727f6086 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream * BlockManager, ready for other executors to fetch from. * * This prevents the driver from being the bottleneck in sending out multiple copies of the - * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. + * broadcast data (one per executor). * * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. * diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 96d8dd79908c8..b11f9ba171b84 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SecurityManager, SparkConf} * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. */ -class TorrentBroadcastFactory extends BroadcastFactory { +private[spark] class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index eed9937b3046f..1b4538e6afb85 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -34,7 +34,6 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ @@ -107,7 +106,6 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index ba21075ce6be5..88fdbbdaec902 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -45,39 +45,8 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { class BroadcastSuite extends SparkFunSuite with LocalSparkContext { - private val httpConf = broadcastConf("HttpBroadcastFactory") - private val torrentConf = broadcastConf("TorrentBroadcastFactory") - - test("Using HttpBroadcast locally") { - sc = new SparkContext("local", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === Set((1, 10), (2, 10))) - } - - test("Accessing HttpBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) - } - - test("Accessing HttpBroadcast variables in a local cluster") { - val numSlaves = 4 - val conf = httpConf.clone - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - } - test("Using TorrentBroadcast locally") { - sc = new SparkContext("local", "test", torrentConf) + sc = new SparkContext("local", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) @@ -85,7 +54,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { } test("Accessing TorrentBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", torrentConf) + sc = new SparkContext("local[10]", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) @@ -94,7 +63,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - val conf = torrentConf.clone + val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -124,31 +93,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 - val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") val rdd = sc.parallelize(1 to numSlaves) - val results = new DummyBroadcastClass(rdd).doSomething() assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet) } - test("Unpersisting HttpBroadcast on executors only in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) - } - - test("Unpersisting HttpBroadcast on executors only in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) - } - test("Unpersisting TorrentBroadcast on executors only in local mode") { testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) } @@ -179,66 +130,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(thrown.getMessage.toLowerCase.contains("stopped")) } - /** - * Verify the persistence of state associated with an HttpBroadcast in either local mode or - * local-cluster mode (when distributed = true). - * - * This test creates a broadcast variable, uses it on all executors, and then unpersists it. - * In between each step, this test verifies that the broadcast blocks and the broadcast file - * are present only on the expected nodes. - */ - private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { - val numSlaves = if (distributed) 2 else 0 - - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.isDriver, "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } - if (distributed) { - // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") - } - } - - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === numSlaves + 1) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } - - // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver - // is true. In the latter case, also verify that the broadcast file is deleted on the driver. - def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - if (distributed && removeFromDriver) { - // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, - "Broadcast file should%s be deleted".format(possiblyNot)) - } - } - - testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, - afterUsingBroadcast, afterUnpersist, removeFromDriver) - } - /** * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. * @@ -284,7 +175,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -300,7 +191,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private def testUnpersistBroadcast( distributed: Boolean, numSlaves: Int, // used only when distributed = true - broadcastConf: SparkConf, afterCreation: (Long, BlockManagerMaster) => Unit, afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, afterUnpersist: (Long, BlockManagerMaster) => Unit, @@ -308,7 +198,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") // Wait until all salves are up try { _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) @@ -319,7 +209,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { throw e } } else { - new SparkContext("local", "test", broadcastConf) + new SparkContext("local", "test") } val blockManagerMaster = sc.env.blockManager.master val list = List[Int](1, 2, 3, 4) @@ -356,13 +246,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) } } - - /** Helper method to create a SparkConf that uses the given broadcast factory. */ - private def broadcastConf(factoryName: String): SparkConf = { - val conf = new SparkConf - conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) - conf - } } package object testPackage extends Assertions { diff --git a/docs/configuration.md b/docs/configuration.md index a9ef37a9b1cd9..7d743d572b582 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -823,13 +823,6 @@ Apart from these, the following properties are also available, and may be useful too small, BlockManager might take a performance hit. - - - - - @@ -1017,14 +1010,6 @@ Apart from these, the following properties are also available, and may be useful Port for all block managers to listen on. These exist on both the driver and the executors. - - - - - @@ -1444,8 +1429,8 @@ Apart from these, the following properties are also available, and may be useful

    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for particular protocol denoted by YYY. Currently YYY can be - either akka for Akka based connections or fs for broadcast and - file server.

    + either akka for Akka based connections or fs for file + server.

    diff --git a/docs/security.md b/docs/security.md index 0bfc791c5744e..1b7741d4dd93c 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,7 +23,7 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +Spark supports SSL for Akka and HTTP (for file server) protocols. SASL encryption is supported for the block transfer service. Encryption is not yet supported for the WebUI. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle @@ -32,7 +32,7 @@ to configure your cluster manager to store application data on encrypted disks. ### SSL Configuration -Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. @@ -160,15 +160,6 @@ configure those ports. - - - - - - - - diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index d812262fd87dc..3da5236745b51 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,16 +21,14 @@ package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: BroadcastTest [slices] [numElem] [broadcastAlgo] [blockSize] + * Usage: BroadcastTest [slices] [numElem] [blockSize] */ object BroadcastTest { def main(args: Array[String]) { - val bcName = if (args.length > 2) args(2) else "Http" - val blockSize = if (args.length > 3) args(3) else "4096" + val blockSize = if (args.length > 2) args(2) else "4096" val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) @@ -44,7 +42,7 @@ object BroadcastTest { println("===========") val startTime = System.nanoTime val barr1 = sc.broadcast(arr1) - val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size) + val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.length) // Collect the small RDD so we can print the observed sizes locally. observedSizes.collect().foreach(i => println(i)) println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ad878c1892e99..b7d27c9f06666 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -35,7 +35,8 @@ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("2.0") => Seq( - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD") + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ // When 1.6 is officially released, update this exclusion list. Seq( From fd3333313864e21e3a9c95577723c931357d1f16 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 30 Dec 2015 18:20:27 -0800 Subject: [PATCH 1384/2089] [SPARK-3873][YARN] Fix import ordering. Author: Marcelo Vanzin Closes #10536 from vanzin/SPARK-3873-yarn. --- .../spark/deploy/yarn/AMDelegationTokenRenewer.scala | 2 +- .../apache/spark/deploy/yarn/ApplicationMaster.scala | 6 +++--- .../deploy/yarn/ApplicationMasterArguments.scala | 5 +++-- .../scala/org/apache/spark/deploy/yarn/Client.scala | 12 +++++------- .../deploy/yarn/ClientDistributedCacheManager.scala | 2 +- .../deploy/yarn/ExecutorDelegationTokenUpdater.scala | 6 +++--- .../apache/spark/deploy/yarn/ExecutorRunnable.scala | 4 ++-- .../org/apache/spark/deploy/yarn/YarnAllocator.scala | 1 - .../org/apache/spark/deploy/yarn/YarnRMClient.scala | 2 +- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 6 +++--- .../cluster/YarnClientSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/YarnScheduler.scala | 1 - 12 files changed, 23 insertions(+), 26 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index 56e4741b93873..b8daa501af7f0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -24,9 +24,9 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.ThreadUtils /* diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index fc742df73d731..a01bb267d7948 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -17,23 +17,23 @@ package org.apache.spark.deploy.yarn -import scala.util.control.NonFatal - import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference +import scala.util.control.NonFatal + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 17d9943c795e3..5af3941c6023e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.yarn -import org.apache.spark.util.{MemoryParam, IntParam} +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import collection.mutable.ArrayBuffer +import org.apache.spark.util.{IntParam, MemoryParam} class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7742ec92eb4e8..8cf438be587dc 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -28,22 +28,20 @@ import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe -import scala.util.{Try, Success, Failure} +import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files - -import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission -import org.apache.hadoop.io.Text +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier +import org.apache.hadoop.io.{DataOutputBuffer, Text} import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.{TokenIdentifier, Token} +import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -55,8 +53,8 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} -import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} import org.apache.spark.util.Utils private[spark] class Client( diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 3d3a966960e9f..4ef05c5a846d5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.util.{Records, ConverterUtils} +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.Logging diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 94feb6393fd69..9d99c0d93fd1e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -18,16 +18,16 @@ package org.apache.spark.deploy.yarn import java.util.concurrent.{Executors, TimeUnit} +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{ThreadUtils, Utils} -import scala.util.control.NonFatal - private[spark] class ExecutorDelegationTokenUpdater( sparkConf: SparkConf, hadoopConf: Configuration) extends Logging { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 2232ffba473b5..31fa53e24b507 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -25,12 +25,12 @@ import java.util.Collections import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.NMClient import org.apache.hadoop.yarn.conf.YarnConfiguration diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4e044aa4788da..11426eb07c7ed 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver - import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index d2a211f6711ff..af83cf6a77d17 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} -import scala.collection.JavaConverters._ import scala.collection.{Map, Set} +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.hadoop.conf.Configuration diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 36a2d61429887..e286aed9f9781 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,19 +30,19 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.{Master, JobConf} +import org.apache.hadoop.mapred.{JobConf, Master} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.security.token.{Token, TokenIdentifier} -import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} +import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.ConverterUtils +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.launcher.YarnCommandBuilderUtils -import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils /** diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0e27a2665e939..20e2030fce086 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState -import org.apache.spark.{SparkException, Logging, SparkContext} +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala index 4ebf3af12b381..029382133ddf2 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.util.RackResolver - import org.apache.log4j.{Level, Logger} import org.apache.spark._ From 9140d9074379055a0b4b2f5c381362b31c141941 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 30 Dec 2015 18:26:08 -0800 Subject: [PATCH 1385/2089] [SPARK-3873][GRAPHX] Import order fixes. There's one warning left, caused by a bug in the checker. Author: Marcelo Vanzin Closes #10537 from vanzin/SPARK-3873-graphx. --- .../src/main/scala/org/apache/spark/graphx/EdgeRDD.scala | 5 ++--- .../org/apache/spark/graphx/GraphKryoRegistrator.scala | 5 ++--- .../main/scala/org/apache/spark/graphx/GraphLoader.scala | 2 +- .../src/main/scala/org/apache/spark/graphx/GraphOps.scala | 3 +-- .../main/scala/org/apache/spark/graphx/GraphXUtils.scala | 4 +--- .../src/main/scala/org/apache/spark/graphx/Pregel.scala | 2 +- .../main/scala/org/apache/spark/graphx/VertexRDD.scala | 5 ++--- .../apache/spark/graphx/impl/EdgePartitionBuilder.scala | 2 +- .../scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala | 5 ++--- .../scala/org/apache/spark/graphx/impl/GraphImpl.scala | 5 ++--- .../apache/spark/graphx/impl/ReplicatedVertexView.scala | 3 +-- .../apache/spark/graphx/impl/RoutingTablePartition.scala | 8 +++----- .../spark/graphx/impl/ShippableVertexPartition.scala | 3 +-- .../org/apache/spark/graphx/impl/VertexPartition.scala | 3 +-- .../apache/spark/graphx/impl/VertexPartitionBase.scala | 3 +-- .../apache/spark/graphx/impl/VertexPartitionBaseOps.scala | 3 +-- .../org/apache/spark/graphx/impl/VertexRDDImpl.scala | 3 +-- .../org/apache/spark/graphx/lib/LabelPropagation.scala | 1 + .../main/scala/org/apache/spark/graphx/lib/PageRank.scala | 2 +- .../scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala | 2 +- .../scala/org/apache/spark/graphx/lib/ShortestPaths.scala | 3 ++- .../org/apache/spark/graphx/util/GraphGenerators.scala | 7 ++----- .../util/collection/GraphXPrimitiveKeyOpenHashMap.scala | 4 ++-- 23 files changed, 33 insertions(+), 50 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index ee7302a1edbf6..45526bf062fab 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -24,12 +24,11 @@ import org.apache.spark.Dependency import org.apache.spark.Partition import org.apache.spark.SparkContext import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.graphx.impl.EdgePartitionBuilder import org.apache.spark.graphx.impl.EdgeRDDImpl +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index 563c948957ecf..eaa71dad17a04 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -19,12 +19,11 @@ package org.apache.spark.graphx import com.esotericsoftware.kryo.Kryo +import org.apache.spark.graphx.impl._ +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.BitSet - -import org.apache.spark.graphx.impl._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.OpenHashSet /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 21187be7678a6..1672f7d27c401 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -17,9 +17,9 @@ package org.apache.spark.graphx -import org.apache.spark.storage.StorageLevel import org.apache.spark.{Logging, SparkContext} import org.apache.spark.graphx.impl.{EdgePartitionBuilder, GraphImpl} +import org.apache.spark.storage.StorageLevel /** * Provides utilities for loading [[Graph]]s from files. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 9827dfab8684a..fc36e12dd2aed 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -22,9 +22,8 @@ import scala.util.Random import org.apache.spark.SparkException import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - import org.apache.spark.graphx.lib._ +import org.apache.spark.rdd.RDD /** * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala index 2cb07937eaa2a..8ec33e140000e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -18,12 +18,10 @@ package org.apache.spark.graphx import org.apache.spark.SparkConf - import org.apache.spark.graphx.impl._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -import org.apache.spark.util.collection.{OpenHashSet, BitSet} import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.{BitSet, OpenHashSet} object GraphXUtils { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 2ca60d51f8331..b908860310093 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -18,8 +18,8 @@ package org.apache.spark.graphx import scala.reflect.ClassTag -import org.apache.spark.Logging +import org.apache.spark.Logging /** * Implements a Pregel-like bulk-synchronous message-passing API. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 1ef7a78fbcd00..53a9f92b82bc5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -21,13 +21,12 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock import org.apache.spark.graphx.impl.VertexRDDImpl +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 906d42328fcb9..b122969b817fa 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.{SortDataFormat, Sorter, PrimitiveVector} +import org.apache.spark.util.collection.{PrimitiveVector, SortDataFormat, Sorter} /** Constructs an EdgePartition from scratch. */ private[graphx] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index c88b2f65a86cd..6e153b7e803ea 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,12 +19,11 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, HashPartitioner} +import org.apache.spark.{HashPartitioner, OneToOneDependency} +import org.apache.spark.graphx._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx._ - class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( @transient override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index da95314440d86..81182adbc6389 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -21,12 +21,11 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, ShuffledRDD} -import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils - +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.storage.StorageLevel /** * An implementation of [[org.apache.spark.graphx.Graph]] to support computation on graphs. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 1df86449fa0c2..f79f9c7ec448f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -20,9 +20,8 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - import org.apache.spark.graphx._ +import org.apache.spark.rdd.RDD /** * Manages shipping vertex attributes to the edge partitions of an diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 4f1260a5a67b2..3fd76902af646 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import org.apache.spark.Partitioner +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.rdd.RDD import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.util.collection.{BitSet, PrimitiveVector} -import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage - private[graphx] object RoutingTablePartition { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index aa320088f2088..3f203c4eca485 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -19,10 +19,9 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.{BitSet, PrimitiveVector} - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} /** Stores vertex attributes to ship to an edge partition. */ private[graphx] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index fbe53acfc32aa..4512bc17399a9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -19,10 +19,9 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartition { /** Construct a `VertexPartition` from the given vertices. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala index 5ad6390a56c4f..8d608c99b1a1d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala @@ -20,10 +20,9 @@ package org.apache.spark.graphx.impl import scala.language.higherKinds import scala.reflect.ClassTag -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartitionBase { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index b90f9fa327052..f508b483a2f1b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -22,10 +22,9 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet /** * An class containing additional operations for subclasses of VertexPartitionBase that provide diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 7f4e7e9d79d6b..d5accdfbf7e9c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -21,11 +21,10 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.SparkContext._ +import org.apache.spark.graphx._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx._ - class VertexRDDImpl[VD] private[graphx] ( @transient val partitionsRDD: RDD[ShippableVertexPartition[VD]], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index a3ad6bed1c998..7a53eca7eac64 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag + import org.apache.spark.graphx._ /** Label Propagation algorithm. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 52b237fc15093..35b26c998e1d9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphx.lib -import scala.reflect.ClassTag import scala.language.postfixOps +import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.graphx._ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 9cb24ed080e1c..16300e0740790 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -21,8 +21,8 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.rdd._ import org.apache.spark.graphx._ +import org.apache.spark.rdd._ /** Implementation of SVD++ algorithm. */ object SVDPlusPlus { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 179f2843818e0..f0c6bcb93445c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -17,9 +17,10 @@ package org.apache.spark.graphx.lib -import org.apache.spark.graphx._ import scala.reflect.ClassTag +import org.apache.spark.graphx._ + /** * Computes shortest paths to the given set of landmark vertices, returning a graph where each * vertex attribute is a map containing the shortest-path distance to each reachable landmark. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 989e226305265..280b6c5578fe5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -23,14 +23,11 @@ import scala.reflect.ClassTag import scala.util._ import org.apache.spark._ -import org.apache.spark.serializer._ -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.Graph -import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer._ /** A collection of graph generating functions. */ object GraphGenerators extends Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index e2754ea699da9..972237da1cb28 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util.collection -import org.apache.spark.util.collection.OpenHashSet - import scala.reflect._ +import org.apache.spark.util.collection.OpenHashSet + /** * A fast hash map implementation for primitive, non-null keys. This hash map supports * insertions and updates, but not deletions. This map is about an order of magnitude From be33a0cd3def86e0aa64dab411e504abbbdfb03c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Dec 2015 18:28:08 -0800 Subject: [PATCH 1386/2089] [SPARK-12561] Remove JobLogger in Spark 2.0. It was research code and has been deprecated since 1.0.0. No one really uses it since they can just use event logging. Author: Reynold Xin Closes #10530 from rxin/SPARK-12561. --- .../apache/spark/scheduler/JobLogger.scala | 277 ------------------ 1 file changed, 277 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala deleted file mode 100644 index f96eb8ca0ae00..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.io.{File, FileNotFoundException, IOException, PrintWriter} -import java.text.SimpleDateFormat -import java.util.{Date, Properties} - -import scala.collection.mutable.HashMap - -import org.apache.spark._ -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics - -/** - * :: DeveloperApi :: - * A logger class to record runtime information for jobs in Spark. This class outputs one log file - * for each Spark job, containing tasks start/stop and shuffle information. JobLogger is a subclass - * of SparkListener, use addSparkListener to add JobLogger to a SparkContext after the SparkContext - * is created. Note that each JobLogger only works for one SparkContext - * - * NOTE: The functionality of this class is heavily stripped down to accommodate for a general - * refactor of the SparkListener interface. In its place, the EventLoggingListener is introduced - * to log application information as SparkListenerEvents. To enable this functionality, set - * spark.eventLog.enabled to true. - */ -@DeveloperApi -@deprecated("Log application information by setting spark.eventLog.enabled.", "1.0.0") -class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging { - - def this() = this(System.getProperty("user.name", ""), - String.valueOf(System.currentTimeMillis())) - - private val logDir = - if (System.getenv("SPARK_LOG_DIR") != null) { - System.getenv("SPARK_LOG_DIR") - } else { - "/tmp/spark-%s".format(user) - } - - private val jobIdToPrintWriter = new HashMap[Int, PrintWriter] - private val stageIdToJobId = new HashMap[Int, Int] - private val jobIdToStageIds = new HashMap[Int, Seq[Int]] - private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - } - - createLogDir() - - /** Create a folder for log files, the folder's name is the creation time of jobLogger */ - protected def createLogDir() { - val dir = new File(logDir + "/" + logDirName + "/") - if (dir.exists()) { - return - } - if (!dir.mkdirs()) { - // JobLogger should throw a exception rather than continue to construct this object. - throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") - } - } - - /** - * Create a log file for one job - * @param jobId ID of the job - * @throws FileNotFoundException Fail to create log file - */ - protected def createLogWriter(jobId: Int) { - try { - val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobId) - jobIdToPrintWriter += (jobId -> fileWriter) - } catch { - case e: FileNotFoundException => e.printStackTrace() - } - } - - /** - * Close log file, and clean the stage relationship in stageIdToJobId - * @param jobId ID of the job - */ - protected def closeLogWriter(jobId: Int) { - jobIdToPrintWriter.get(jobId).foreach { fileWriter => - fileWriter.close() - jobIdToStageIds.get(jobId).foreach(_.foreach { stageId => - stageIdToJobId -= stageId - }) - jobIdToPrintWriter -= jobId - jobIdToStageIds -= jobId - } - } - - /** - * Build up the maps that represent stage-job relationships - * @param jobId ID of the job - * @param stageIds IDs of the associated stages - */ - protected def buildJobStageDependencies(jobId: Int, stageIds: Seq[Int]) = { - jobIdToStageIds(jobId) = stageIds - stageIds.foreach { stageId => stageIdToJobId(stageId) = jobId } - } - - /** - * Write info into log file - * @param jobId ID of the job - * @param info Info to be recorded - * @param withTime Controls whether to record time stamp before the info, default is true - */ - protected def jobLogInfo(jobId: Int, info: String, withTime: Boolean = true) { - var writeInfo = info - if (withTime) { - val date = new Date(System.currentTimeMillis()) - writeInfo = dateFormat.get.format(date) + ": " + info - } - // scalastyle:off println - jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) - // scalastyle:on println - } - - /** - * Write info into log file - * @param stageId ID of the stage - * @param info Info to be recorded - * @param withTime Controls whether to record time stamp before the info, default is true - */ - protected def stageLogInfo(stageId: Int, info: String, withTime: Boolean = true) { - stageIdToJobId.get(stageId).foreach(jobId => jobLogInfo(jobId, info, withTime)) - } - - /** - * Record task metrics into job log files, including execution info and shuffle metrics - * @param stageId Stage ID of the task - * @param status Status info of the task - * @param taskInfo Task description info - * @param taskMetrics Task running metrics - */ - protected def recordTaskMetrics(stageId: Int, status: String, - taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageId + - " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + - " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname - val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime - val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime - val inputMetrics = taskMetrics.inputMetrics match { - case Some(metrics) => - " READ_METHOD=" + metrics.readMethod.toString + - " INPUT_BYTES=" + metrics.bytesRead - case None => "" - } - val outputMetrics = taskMetrics.outputMetrics match { - case Some(metrics) => - " OUTPUT_BYTES=" + metrics.bytesWritten - case None => "" - } - val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { - case Some(metrics) => - " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + - " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + - " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + - " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + - " LOCAL_BYTES_READ=" + metrics.localBytesRead - case None => "" - } - val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => - " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + - " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime - case None => "" - } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + - shuffleReadMetrics + writeMetrics) - } - - /** - * When stage is submitted, record stage submit info - * @param stageSubmitted Stage submitted event - */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { - val stageInfo = stageSubmitted.stageInfo - stageLogInfo(stageInfo.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( - stageInfo.stageId, stageInfo.numTasks)) - } - - /** - * When stage is completed, record stage completion status - * @param stageCompleted Stage completed event - */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - val stageId = stageCompleted.stageInfo.stageId - if (stageCompleted.stageInfo.failureReason.isEmpty) { - stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED") - } else { - stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED") - } - } - - /** - * When task ends, record task completion status and metrics - * @param taskEnd Task end event - */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val taskInfo = taskEnd.taskInfo - var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType) - val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty - taskEnd.reason match { - case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics) - case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + - " STAGE_ID=" + taskEnd.stageId - stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => - taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + - taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + - mapId + " REDUCE_ID=" + reduceId - stageLogInfo(taskEnd.stageId, taskStatus) - case _ => - } - } - - /** - * When job ends, recording job completion status and close log file - * @param jobEnd Job end event - */ - override def onJobEnd(jobEnd: SparkListenerJobEnd) { - val jobId = jobEnd.jobId - var info = "JOB_ID=" + jobId - jobEnd.jobResult match { - case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception) => - info += " STATUS=FAILED REASON=" - exception.getMessage.split("\\s+").foreach(info += _ + "_") - case _ => - } - jobLogInfo(jobId, info.substring(0, info.length - 1).toUpperCase) - closeLogWriter(jobId) - } - - /** - * Record job properties into job log file - * @param jobId ID of the job - * @param properties Properties of the job - */ - protected def recordJobProperties(jobId: Int, properties: Properties) { - if (properties != null) { - val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobId, description, withTime = false) - } - } - - /** - * When job starts, record job property and stage graph - * @param jobStart Job start event - */ - override def onJobStart(jobStart: SparkListenerJobStart) { - val jobId = jobStart.jobId - val properties = jobStart.properties - createLogWriter(jobId) - recordJobProperties(jobId, properties) - buildJobStageDependencies(jobId, jobStart.stageIds) - jobLogInfo(jobId, "JOB_ID=" + jobId + " STATUS=STARTED") - } -} From 7b4452ba98d53ed646a2e744bb701702fc5371a7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Dec 2015 18:49:17 -0800 Subject: [PATCH 1387/2089] House cleaning: close open pull requests created before June 1st, 2015 Closes #5358 Closes #3744 Closes #3677 Closes #3536 Closes #3249 Closes #3221 Closes #2446 Closes #3794 Closes #3815 Closes #3816 Closes #3866 Closes #4286 Closes #5184 Closes #5170 Closes #5142 Closes #5025 Closes #5005 Closes #4897 Closes #4887 Closes #4849 Closes #4632 Closes #4622 Closes #4456 Closes #4449 Closes #4417 Closes #5483 Closes #5325 Closes #6545 Closes #6449 Closes #6433 Closes #6416 Closes #6403 Closes #6386 Closes #6263 Closes #6245 Closes #6213 Closes #6155 Closes #6133 Closes #6018 Closes #5978 Closes #5869 Closes #5852 Closes #5848 Closes #5754 Closes #5598 Closes #5503 Closes #4380 From c642c3a2102fd016deabcb78bd0292bbf3dd7acf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Dec 2015 18:50:30 -0800 Subject: [PATCH 1388/2089] Closes #10386 since it was superseded by #10468. From 93b52abca708aba90b6d0d36c92b920f6f8338ad Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Dec 2015 18:54:03 -0800 Subject: [PATCH 1389/2089] House cleaning: close old pull requests. Closes #5400 Closes #5408 Closes #5423 Closes #5668 Closes #6757 Closes #6745 Closes #6613 From e6c77874b915691dead91e8d96ad9f58ba3a73db Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Dec 2015 22:16:37 -0800 Subject: [PATCH 1390/2089] [SPARK-12585] [SQL] move numFields to constructor of UnsafeRow Right now, numFields will be passed in by pointTo(), then bitSetWidthInBytes is calculated, making pointTo() a little bit heavy. It should be part of constructor of UnsafeRow. Author: Davies Liu Closes #10528 from davies/numFields. --- .../catalyst/expressions/UnsafeArrayData.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 88 ++++++------------- .../execution/UnsafeExternalRowSorter.java | 16 ++-- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 4 +- .../GenerateUnsafeRowJoinerBitsetSuite.scala | 4 +- .../UnsafeFixedWidthAggregationMap.java | 10 +-- .../sql/execution/UnsafeKVExternalSorter.java | 24 ++--- .../parquet/UnsafeRowParquetRecordReader.java | 32 +++---- .../sql/execution/UnsafeRowSerializer.scala | 6 +- .../sql/execution/columnar/ColumnType.scala | 3 +- .../columnar/GenerateColumnAccessor.scala | 4 +- .../datasources/text/DefaultSource.scala | 4 +- .../execution/joins/CartesianProduct.scala | 5 +- .../sql/execution/joins/HashedRelation.scala | 4 +- .../org/apache/spark/sql/UnsafeRowSuite.scala | 11 ++- 16 files changed, 86 insertions(+), 137 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 3513960b41813..3d80df227151d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -270,8 +270,8 @@ public UnsafeRow getStruct(int ordinal, int numFields) { final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - final UnsafeRow row = new UnsafeRow(); - row.pointTo(baseObject, baseOffset + offset, numFields, size); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(baseObject, baseOffset + offset, size); return row; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index b6979d0c82977..7492b88c471a4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,11 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.io.OutputStream; +import java.io.*; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -30,26 +26,12 @@ import java.util.HashSet; import java.util.Set; -import org.apache.spark.sql.types.ArrayType; -import org.apache.spark.sql.types.BinaryType; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.CalendarIntervalType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.MapType; -import org.apache.spark.sql.types.NullType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.sql.types.StringType; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.TimestampType; -import org.apache.spark.sql.types.UserDefinedType; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -57,23 +39,9 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.BooleanType; -import static org.apache.spark.sql.types.DataTypes.ByteType; -import static org.apache.spark.sql.types.DataTypes.DateType; -import static org.apache.spark.sql.types.DataTypes.DoubleType; -import static org.apache.spark.sql.types.DataTypes.FloatType; -import static org.apache.spark.sql.types.DataTypes.IntegerType; -import static org.apache.spark.sql.types.DataTypes.LongType; -import static org.apache.spark.sql.types.DataTypes.NullType; -import static org.apache.spark.sql.types.DataTypes.ShortType; -import static org.apache.spark.sql.types.DataTypes.TimestampType; +import static org.apache.spark.sql.types.DataTypes.*; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; - /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -167,8 +135,16 @@ private void assertIndexIsValid(int index) { /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, * since the value returned by this constructor is equivalent to a null pointer. + * + * @param numFields the number of fields in this row */ - public UnsafeRow() { } + public UnsafeRow(int numFields) { + this.numFields = numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + } + + // for serializer + public UnsafeRow() {} public Object getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } @@ -182,15 +158,12 @@ public UnsafeRow() { } * * @param baseObject the base object * @param baseOffset the offset within the base object - * @param numFields the number of fields in this row * @param sizeInBytes the size of this row's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; - this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; - this.numFields = numFields; this.sizeInBytes = sizeInBytes; } @@ -198,23 +171,12 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * Update this UnsafeRow to point to the underlying byte array. * * @param buf byte array to point to - * @param numFields the number of fields in this row - * @param sizeInBytes the number of bytes valid in the byte array - */ - public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); - } - - /** - * Updates this UnsafeRow preserving the number of fields. - * @param buf byte array to point to * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int sizeInBytes) { - pointTo(buf, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); } - public void setNotNullAt(int i) { assertIndexIsValid(i); BitSetMethods.unset(baseObject, baseOffset, i); @@ -489,8 +451,8 @@ public UnsafeRow getStruct(int ordinal, int numFields) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - final UnsafeRow row = new UnsafeRow(); - row.pointTo(baseObject, baseOffset + offset, numFields, size); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(baseObject, baseOffset + offset, size); return row; } } @@ -529,7 +491,7 @@ public UnsafeMapData getMap(int ordinal) { */ @Override public UnsafeRow copy() { - UnsafeRow rowCopy = new UnsafeRow(); + UnsafeRow rowCopy = new UnsafeRow(numFields); final byte[] rowDataCopy = new byte[sizeInBytes]; Platform.copyMemory( baseObject, @@ -538,7 +500,7 @@ public UnsafeRow copy() { Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return rowCopy; } @@ -547,8 +509,8 @@ public UnsafeRow copy() { * The returned row is invalid until we call copyFrom on it. */ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { - final UnsafeRow row = new UnsafeRow(); - row.pointTo(new byte[numBytes], numFields, numBytes); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(new byte[numBytes], numBytes); return row; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 352002b3499a2..27ae62f1212f6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -26,10 +26,9 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -123,7 +122,7 @@ Iterator sort() throws IOException { return new AbstractScalaRowIterator() { private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(); + private UnsafeRow row = new UnsafeRow(numFields); @Override public boolean hasNext() { @@ -137,7 +136,6 @@ public UnsafeRow next() { row.pointTo( sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), - numFields, sortedIterator.getRecordLength()); if (!hasNext()) { UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page @@ -173,19 +171,21 @@ public Iterator sort(Iterator inputIterator) throws IOExce private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); + private final UnsafeRow row1; + private final UnsafeRow row2; public RowComparator(Ordering ordering, int numFields) { this.numFields = numFields; + this.row1 = new UnsafeRow(numFields); + this.row2 = new UnsafeRow(numFields); this.ordering = ordering; } @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { // TODO: Why are the sizes -1? - row1.pointTo(baseObj1, baseOff1, numFields, -1); - row2.pointTo(baseObj2, baseOff2, numFields, -1); + row1.pointTo(baseObj1, baseOff1, -1); + row2.pointTo(baseObj2, baseOff2, -1); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index c1defe12b0b91..d0e031f27990c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -289,7 +289,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val exprTypes = expressions.map(_.dataType) val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();") + ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") val bufferHolder = ctx.freshName("bufferHolder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") @@ -303,7 +303,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $subexprReset ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); + $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); """ GeneratedExpressionCode(code, "false", result) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da602d9b4bce1..c9ff357bf3476 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -165,7 +165,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} { | private byte[] buf = new byte[64]; - | private UnsafeRow out = new UnsafeRow(); + | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size}); | | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { | // row1: ${schema1.size} fields, $bitset1Words words in bitset @@ -188,7 +188,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyVariableLengthRow2 | $updateOffset | - | out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction); + | out.pointTo(buf, sizeInBytes - $sizeReduction); | | return out; | } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index 796d60032e1a6..f8342214d9ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -90,13 +90,13 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { } private def createUnsafeRow(numFields: Int): UnsafeRow = { - val row = new UnsafeRow + val row = new UnsafeRow(numFields) val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle. // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, sizeInBytes) row } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index a2f99d566d471..6bf9d7bd0367c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -61,7 +61,7 @@ public final class UnsafeFixedWidthAggregationMap { /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UnsafeRow currentAggregationBuffer; private final boolean enablePerfMetrics; @@ -98,6 +98,7 @@ public UnsafeFixedWidthAggregationMap( long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; + this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; this.map = @@ -147,7 +148,6 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo currentAggregationBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), loc.getValueLength() ); return currentAggregationBuffer; @@ -165,8 +165,8 @@ public KVIterator iterator() { private final BytesToBytesMap.MapIterator mapLocationIterator = map.destructiveIterator(); - private final UnsafeRow key = new UnsafeRow(); - private final UnsafeRow value = new UnsafeRow(); + private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length()); + private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length()); @Override public boolean next() { @@ -177,13 +177,11 @@ public boolean next() { key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), loc.getKeyLength() ); value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), loc.getValueLength() ); return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 8c9b9c85e37fc..0da26bf376a6a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -94,7 +94,7 @@ public UnsafeKVExternalSorter( // The only new memory we are allocating is the pointer/prefix array. BytesToBytesMap.MapIterator iter = map.iterator(); final int numKeyFields = keySchema.size(); - UnsafeRow row = new UnsafeRow(); + UnsafeRow row = new UnsafeRow(numKeyFields); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); final Object baseObject = loc.getKeyAddress().getBaseObject(); @@ -107,7 +107,7 @@ public UnsafeKVExternalSorter( long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8); // Compute prefix - row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); + row.pointTo(baseObject, baseOffset, loc.getKeyLength()); final long prefix = prefixComputer.computePrefix(row); inMemSorter.insertRecord(address, prefix); @@ -194,12 +194,14 @@ public void cleanupResources() { private static final class KVComparator extends RecordComparator { private final BaseOrdering ordering; - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); + private final UnsafeRow row1; + private final UnsafeRow row2; private final int numKeyFields; public KVComparator(BaseOrdering ordering, int numKeyFields) { this.numKeyFields = numKeyFields; + this.row1 = new UnsafeRow(numKeyFields); + this.row2 = new UnsafeRow(numKeyFields); this.ordering = ordering; } @@ -207,17 +209,15 @@ public KVComparator(BaseOrdering ordering, int numKeyFields) { public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { // Note that since ordering doesn't need the total length of the record, we just pass -1 // into the row. - row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); - row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); + row1.pointTo(baseObj1, baseOff1 + 4, -1); + row2.pointTo(baseObj2, baseOff2 + 4, -1); return ordering.compare(row1, row2); } } public class KVSorterIterator extends KVIterator { - private UnsafeRow key = new UnsafeRow(); - private UnsafeRow value = new UnsafeRow(); - private final int numKeyFields = keySchema.size(); - private final int numValueFields = valueSchema.size(); + private UnsafeRow key = new UnsafeRow(keySchema.size()); + private UnsafeRow value = new UnsafeRow(valueSchema.size()); private final UnsafeSorterIterator underlying; private KVSorterIterator(UnsafeSorterIterator underlying) { @@ -237,8 +237,8 @@ public boolean next() throws IOException { // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); - value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); + key.pointTo(baseObj, recordOffset + 4, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen); return true; } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 0cc4566c9cdde..a6758bddfa7d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,35 +21,28 @@ import java.nio.ByteBuffer; import java.util.List; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; -import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - -import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; -import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; -import static org.apache.parquet.column.ValuesType.VALUES; - import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; import org.apache.parquet.column.Encoding; -import org.apache.parquet.column.page.DataPage; -import org.apache.parquet.column.page.DataPageV1; -import org.apache.parquet.column.page.DataPageV2; -import org.apache.parquet.column.page.DictionaryPage; -import org.apache.parquet.column.page.PageReadStore; -import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.page.*; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.*; + /** * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. * @@ -181,12 +174,11 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont rowWriters = new UnsafeRowWriter[rows.length]; for (int i = 0; i < rows.length; ++i) { - rows[i] = new UnsafeRow(); + rows[i] = new UnsafeRow(requestedSchema.getFieldCount()); rowWriters[i] = new UnsafeRowWriter(); BufferHolder holder = new BufferHolder(rowByteSize); rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); - rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), - holder.buffer.length); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 7e981268de392..4730647c4be9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -94,7 +94,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) - private[this] var row: UnsafeRow = new UnsafeRow() + private[this] var row: UnsafeRow = new UnsafeRow(numFields) private[this] var rowTuple: (Int, UnsafeRow) = (0, row) private[this] val EOF: Int = -1 @@ -117,7 +117,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) rowSize = readSize() if (rowSize == EOF) { // We are returning the last row in this stream dIn.close() @@ -152,7 +152,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index c9f2329db4b6d..9c908b2877e79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -574,11 +574,10 @@ private[columnar] case class STRUCT(dataType: StructType) assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + sizeInBytes) - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(numOfFields) unsafeRow.pointTo( buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, - numOfFields, sizeInBytes) unsafeRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index eaafc96e4d2e7..b208425ffc3c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -131,7 +131,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow(); + private UnsafeRow unsafeRow = new UnsafeRow($numFields); private BufferHolder bufferHolder = new BufferHolder(); private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); private MutableUnsafeRow mutableRow = null; @@ -183,7 +183,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera bufferHolder.reset(); rowWriter.initialize(bufferHolder, $numFields); ${extractors.mkString("\n")} - unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()); return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 4a1cbe4c38fa2..41fcb11d84bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -101,14 +101,14 @@ private[sql] class TextRelation( .mapPartitions { iter => val bufferHolder = new BufferHolder val unsafeRowWriter = new UnsafeRowWriter - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(1) iter.map { case (_, line) => // Writes to an UnsafeRow directly bufferHolder.reset() unsafeRowWriter.initialize(bufferHolder, 1) unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize()) + unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()) unsafeRow } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index fa2bc7672131c..81bfe4e67ca73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -56,15 +56,14 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] def createIter(): Iterator[UnsafeRow] = { val iter = sorter.getIterator - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(numFieldsOfRight) new Iterator[UnsafeRow] { override def hasNext: Boolean = { iter.hasNext } override def next(): UnsafeRow = { iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight, - iter.getRecordLength) + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) unsafeRow } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8c7099ab5a34d..c6f56cfaed22c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -245,8 +245,8 @@ private[joins] final class UnsafeHashedRelation( val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 - val row = new UnsafeRow - row.pointTo(base, offset, numFields, sizeInBytes) + val row = new UnsafeRow(numFields) + row.pointTo(base, offset, sizeInBytes) buffer += row offset += sizeInBytes } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 00f1526576cc5..a32763db054f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -34,8 +34,8 @@ class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Java serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data val data = new Array[Byte](1024) - val row = new UnsafeRow - row.pointTo(data, 1, 16) + val row = new UnsafeRow(1) + row.pointTo(data, 16) row.setLong(0, 19285) val ser = new JavaSerializer(new SparkConf).newInstance() @@ -47,8 +47,8 @@ class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Kryo serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data val data = new Array[Byte](1024) - val row = new UnsafeRow - row.pointTo(data, 1, 16) + val row = new UnsafeRow(1) + row.pointTo(data, 16) row.setLong(0, 19285) val ser = new KryoSerializer(new SparkConf).newInstance() @@ -86,11 +86,10 @@ class UnsafeRowSuite extends SparkFunSuite { offheapRowPage.getBaseOffset, arrayBackedUnsafeRow.getSizeInBytes ) - val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + val offheapUnsafeRow: UnsafeRow = new UnsafeRow(3) offheapUnsafeRow.pointTo( offheapRowPage.getBaseObject, offheapRowPage.getBaseOffset, - 3, // num fields arrayBackedUnsafeRow.getSizeInBytes ) assert(offheapUnsafeRow.getBaseObject === null) From 4f5a24d7e73104771f233af041eeba4f41675974 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 31 Dec 2015 00:15:55 -0800 Subject: [PATCH 1391/2089] [SPARK-7995][SPARK-6280][CORE] Remove AkkaRpcEnv and remove systemName from setupEndpointRef ### Remove AkkaRpcEnv Keep `SparkEnv.actorSystem` because Streaming still uses it. Will remove it and AkkaUtils after refactoring Streaming actorStream API. ### Remove systemName There are 2 places using `systemName`: * `RpcEnvConfig.name`. Actually, although it's used as `systemName` in `AkkaRpcEnv`, `NettyRpcEnv` uses it as the service name to output the log `Successfully started service *** on port ***`. Since the service name in log is useful, I keep `RpcEnvConfig.name`. * `def setupEndpointRef(systemName: String, address: RpcAddress, endpointName: String)`. Each `ActorSystem` has a `systemName`. Akka requires `systemName` in its URI and will refuse a connection if `systemName` is not matched. However, `NettyRpcEnv` doesn't use it. So we can remove `systemName` from `setupEndpointRef` since we are removing `AkkaRpcEnv`. ### Remove RpcEnv.uriOf `uriOf` exists because Akka uses different URI formats for with and without authentication, e.g., `akka.ssl.tcp...` and `akka.tcp://...`. But `NettyRpcEnv` uses the same format. So it's not necessary after removing `AkkaRpcEnv`. Author: Shixiong Zhu Closes #10459 from zsxwing/remove-akka-rpc-env. --- .../scala/org/apache/spark/SparkConf.scala | 3 +- .../scala/org/apache/spark/SparkEnv.scala | 12 +- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 3 +- .../apache/spark/deploy/worker/Worker.scala | 11 +- .../rpc/{netty => }/RpcEndpointAddress.scala | 11 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 28 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 404 ------------------ .../apache/spark/rpc/netty/NettyRpcEnv.scala | 5 +- .../cluster/SimrSchedulerBackend.scala | 11 +- .../cluster/SparkDeploySchedulerBackend.scala | 9 +- .../mesos/CoarseMesosSchedulerBackend.scala | 10 +- .../apache/spark/util/ActorLogReceive.scala | 70 --- .../org/apache/spark/util/AkkaUtils.scala | 107 +---- .../org/apache/spark/util/RpcUtils.scala | 11 +- .../apache/spark/MapOutputTrackerSuite.scala | 2 +- .../org/apache/spark/SSLSampleConfigs.scala | 2 - .../StandaloneDynamicAllocationSuite.scala | 2 +- .../spark/deploy/client/AppClientSuite.scala | 2 +- .../spark/deploy/master/MasterSuite.scala | 2 +- .../spark/deploy/worker/WorkerSuite.scala | 8 +- .../deploy/worker/WorkerWatcherSuite.scala | 6 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 30 +- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 71 --- .../rpc/netty/NettyRpcAddressSuite.scala | 1 + .../spark/rpc/netty/NettyRpcEnvSuite.scala | 4 +- .../apache/spark/util/AkkaUtilsSuite.scala | 360 ---------------- project/MimaExcludes.scala | 11 + .../spark/deploy/yarn/ApplicationMaster.scala | 12 +- 29 files changed, 90 insertions(+), 1120 deletions(-) rename core/src/main/scala/org/apache/spark/rpc/{netty => }/RpcEndpointAddress.scala (89%) delete mode 100644 core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala delete mode 100644 core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d3384fb297732..ff2c4c34c0ca7 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -544,7 +544,8 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + - "are no longer accepted. To specify the equivalent now, one may use '64k'.") + "are no longer accepted. To specify the equivalent now, one may use '64k'."), + DeprecatedConfig("spark.rpc", "2.0", "Not used any more.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 52acde1b414eb..b98cc964eda87 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,7 +34,6 @@ import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemor import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer @@ -97,9 +96,7 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { - actorSystem.shutdown() - } + actorSystem.shutdown() rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -248,14 +245,11 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) - // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + // Create the ActorSystem for Akka and get the port it binds to. val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, clientMode = !isDriver) - val actorSystem: ActorSystem = - if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { - rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - } else { + val actorSystem: ActorSystem = { val actorSystemPort = if (port == 0 || rpcEnv.address == null) { port diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index f03875a3e8c89..328a1bb84f5fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -230,7 +230,7 @@ object Client { RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). - map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME)) rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) rpcEnv.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index a5753e1053649..f7c33214c2406 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -104,8 +104,7 @@ private[spark] class AppClient( return } logInfo("Connecting to master " + masterAddress.toSparkURL + "...") - val masterRef = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterRef = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) masterRef.send(RegisterApplication(appDescription, self)) } catch { case ie: InterruptedException => // Cancelled diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 84e7b366bc965..37b94e02cc9d1 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -45,7 +45,6 @@ private[deploy] class Worker( cores: Int, memory: Int, masterRpcAddresses: Array[RpcAddress], - systemName: String, endpointName: String, workDirPath: String = null, val conf: SparkConf, @@ -101,7 +100,7 @@ private[deploy] class Worker( private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString private var registered = false private var connected = false private val workerId = generateWorkerId() @@ -209,8 +208,7 @@ private[deploy] class Worker( override def run(): Unit = { try { logInfo("Connecting to master " + masterAddress + "...") - val masterEndpoint = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled @@ -266,8 +264,7 @@ private[deploy] class Worker( override def run(): Unit = { try { logInfo("Connecting to master " + masterAddress + "...") - val masterEndpoint = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled @@ -711,7 +708,7 @@ private[deploy] object Worker extends Logging { val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory, - masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr)) + masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr)) rpcEnv } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala rename to core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala index cd6f00cc08e6c..b9db60a7797d8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.rpc.netty +package org.apache.spark.rpc import org.apache.spark.SparkException -import org.apache.spark.rpc.RpcAddress /** * An address identifier for an RPC endpoint. @@ -29,7 +28,7 @@ import org.apache.spark.rpc.RpcAddress * @param rpcAddress The socket address of the endpoint. * @param name Name of the endpoint. */ -private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { +private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { require(name != null, "RpcEndpoint name must be provided.") @@ -44,7 +43,11 @@ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val nam } } -private[netty] object RpcEndpointAddress { +private[spark] object RpcEndpointAddress { + + def apply(host: String, port: Int, name: String): RpcEndpointAddress = { + new RpcEndpointAddress(host, port, name) + } def apply(sparkUrl: String): RpcEndpointAddress = { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 64a4a8bf7c5eb..56683771335a6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -23,7 +23,8 @@ import java.nio.channels.ReadableByteChannel import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.rpc.netty.NettyRpcEnvFactory +import org.apache.spark.util.RpcUtils /** @@ -32,15 +33,6 @@ import org.apache.spark.util.{RpcUtils, Utils} */ private[spark] object RpcEnv { - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - val rpcEnvNames = Map( - "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", - "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "netty") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] - } - def create( name: String, host: String, @@ -48,9 +40,8 @@ private[spark] object RpcEnv { conf: SparkConf, securityManager: SecurityManager, clientMode: Boolean = false): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) - getRpcEnvFactory(conf).create(config) + new NettyRpcEnvFactory().create(config) } } @@ -98,12 +89,11 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName`. * This is a blocking action. */ - def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(RpcEndpointAddress(address, endpointName).toString) } /** @@ -124,12 +114,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def awaitTermination(): Unit - /** - * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of - * creating it manually because different [[RpcEnv]] may have different formats. - */ - def uriOf(systemName: String, address: RpcAddress, endpointName: String): String - /** * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala deleted file mode 100644 index 9d098154f7190..0000000000000 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.akka - -import java.io.File -import java.nio.channels.ReadableByteChannel -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Future -import scala.language.postfixOps -import scala.reflect.ClassTag -import scala.util.control.NonFatal - -import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} -import akka.event.Logging.Error -import akka.pattern.{ask => akkaAsk} -import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import akka.serialization.JavaSerializer - -import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} - -/** - * A RpcEnv implementation based on Akka. - * - * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and - * remove Akka from the dependencies. - */ -private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, - val securityManager: SecurityManager, - conf: SparkConf, - boundPort: Int) - extends RpcEnv(conf) with Logging { - - private val defaultAddress: RpcAddress = { - val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress - // In some test case, ActorSystem doesn't bind to any address. - // So just use some default value since they are only some unit tests - RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) - } - - override val address: RpcAddress = defaultAddress - - /** - * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make - * [[RpcEndpoint.self]] work. - */ - private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() - - /** - * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` - */ - private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() - - private val _fileServer = new AkkaFileServer(conf, securityManager) - - private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { - endpointToRef.put(endpoint, endpointRef) - refToEndpoint.put(endpointRef, endpoint) - } - - private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { - val endpoint = refToEndpoint.remove(endpointRef) - if (endpoint != null) { - endpointToRef.remove(endpoint) - } - } - - /** - * Retrieve the [[RpcEndpointRef]] of `endpoint`. - */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) - - override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use defered function because the Actor needs to use `endpointRef`. - // So `actorRef` should be created after assigning `endpointRef`. - val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { - - assert(endpointRef != null) - - override def preStart(): Unit = { - // Listen for remote client network events - context.system.eventStream.subscribe(self, classOf[AssociationEvent]) - safelyCall(endpoint) { - endpoint.onStart() - } - } - - override def receiveWithLogging: Receive = { - case AssociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case DisassociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => - safelyCall(endpoint) { - endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) - } - - case e: AssociationEvent => - // TODO ignore? - - case m: AkkaMessage => - logDebug(s"Received RPC message: $m") - safelyCall(endpoint) { - processMessage(endpoint, m, sender) - } - - case AkkaFailure(e) => - safelyCall(endpoint) { - throw e - } - - case message: Any => { - logWarning(s"Unknown message: $message") - } - - } - - override def postStop(): Unit = { - unregisterEndpoint(endpoint.self) - safelyCall(endpoint) { - endpoint.onStop() - } - } - - }), name = name) - endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) - registerEndpoint(endpoint, endpointRef) - // Now actorRef can be created safely - endpointRef.init() - endpointRef - } - - private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { - val message = m.message - val needReply = m.needReply - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(new RpcCallContext { - override def sendFailure(e: Throwable): Unit = { - _sender ! AkkaFailure(e) - } - - override def reply(response: Any): Unit = { - _sender ! AkkaMessage(response, false) - } - - // Use "lazy" because most of RpcEndpoints don't need "senderAddress" - override lazy val senderAddress: RpcAddress = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address - }) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](message, { message => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - } catch { - case NonFatal(e) => - _sender ! AkkaFailure(e) - if (!needReply) { - // If the sender does not require a reply, it may not handle the exception. So we rethrow - // "e" to make sure it will be processed. - throw e - } - } - } - - /** - * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will - * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. - */ - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) - } - } - } - } - - private def akkaAddressToRpcAddress(address: Address): RpcAddress = { - RpcAddress(address.host.getOrElse(defaultAddress.host), - address.port.getOrElse(defaultAddress.port)) - } - - override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). - // this is just in case there is a timeout from creating the future in resolveOne, we want the - // exception to indicate the conf that determines the timeout - recover(defaultLookupTimeout.addMessageIfTimeout) - } - - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { - AkkaUtils.address( - AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) - } - - override def shutdown(): Unit = { - actorSystem.shutdown() - _fileServer.shutdown() - } - - override def stop(endpoint: RpcEndpointRef): Unit = { - require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) - actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) - } - - override def awaitTermination(): Unit = { - actorSystem.awaitTermination() - } - - override def toString: String = s"${getClass.getSimpleName}($actorSystem)" - - override def deserialize[T](deserializationAction: () => T): T = { - JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { - deserializationAction() - } - } - - override def openChannel(uri: String): ReadableByteChannel = { - throw new UnsupportedOperationException( - "AkkaRpcEnv's files should be retrieved using an HTTP client.") - } - - override def fileServer: RpcEnvFileServer = _fileServer - -} - -private[akka] class AkkaFileServer( - conf: SparkConf, - securityManager: SecurityManager) extends RpcEnvFileServer { - - @volatile private var httpFileServer: HttpFileServer = _ - - override def addFile(file: File): String = { - getFileServer().addFile(file) - } - - override def addJar(file: File): String = { - getFileServer().addJar(file) - } - - override def addDirectory(baseUri: String, path: File): String = { - val fixedBaseUri = validateDirectoryUri(baseUri) - getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) - } - - def shutdown(): Unit = { - if (httpFileServer != null) { - httpFileServer.stop() - } - } - - private def getFileServer(): HttpFileServer = { - if (httpFileServer == null) synchronized { - if (httpFileServer == null) { - httpFileServer = startFileServer() - } - } - httpFileServer - } - - private def startFileServer(): HttpFileServer = { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - server - } - -} - -private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - config.name, config.host, config.port, config.conf, config.securityManager) - actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) - } -} - -/** - * Monitor errors reported by Akka and log them. - */ -private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[Error]) - } - - override def receiveWithLogging: Actor.Receive = { - case Error(cause: Throwable, _, _, message: String) => logDebug(message, cause) - } -} - -private[akka] class AkkaRpcEndpointRef( - @transient private val defaultAddress: RpcAddress, - @transient private val _actorRef: () => ActorRef, - conf: SparkConf, - initInConstructor: Boolean) - extends RpcEndpointRef(conf) with Logging { - - def this( - defaultAddress: RpcAddress, - _actorRef: ActorRef, - conf: SparkConf) = { - this(defaultAddress, () => _actorRef, conf, true) - } - - lazy val actorRef = _actorRef() - - override lazy val address: RpcAddress = { - val akkaAddress = actorRef.path.address - RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), - akkaAddress.port.getOrElse(defaultAddress.port)) - } - - override lazy val name: String = actorRef.path.name - - private[akka] def init(): Unit = { - // Initialize the lazy vals - actorRef - address - name - } - - if (initInConstructor) { - init() - } - - override def send(message: Any): Unit = { - actorRef ! AkkaMessage(message, false) - } - - override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { - // The function will run in the calling thread, so it should be short and never block. - case msg @ AkkaMessage(message, reply) => - if (reply) { - logError(s"Receive $msg but the sender cannot reply") - Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) - } else { - Future.successful(message) - } - case AkkaFailure(e) => - Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T]. - recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) - } - - override def toString: String = s"${getClass.getSimpleName}($actorRef)" - - final override def equals(that: Any): Boolean = that match { - case other: AkkaRpcEndpointRef => actorRef == other.actorRef - case _ => false - } - - final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() -} - -/** - * A wrapper to `message` so that the receiver knows if the sender expects a reply. - * @param message - * @param needReply if the sender expects a reply message - */ -private[akka] case class AkkaMessage(message: Any, needReply: Boolean) - -/** - * A reply with the failure error from the receiver to the sender - */ -private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 090a1b9f6e366..ef876b1d8c15a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -257,9 +257,6 @@ private[netty] class NettyRpcEnv( dispatcher.getRpcEndpointRef(endpoint) } - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new RpcEndpointAddress(address, endpointName).toString - override def shutdown(): Unit = { cleanup() } @@ -427,7 +424,7 @@ private[netty] object NettyRpcEnv extends Logging { } -private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { +private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 641638a77d5f5..781ecfff7e5e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} -import org.apache.spark.rpc.RpcAddress -import org.apache.spark.{Logging, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( @@ -39,9 +39,10 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 5105475c760e2..1209cce6d1a61 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress} import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} @@ -54,9 +54,10 @@ private[spark] class SparkDeploySchedulerBackend( launcherBackend.connect() // The endpoint for executors to talk to us - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 7d08eae0b4871..a4ed85cd2a4a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -31,7 +31,7 @@ import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -215,10 +215,10 @@ private[spark] class CoarseMesosSchedulerBackend( if (conf.contains("spark.testing")) { "driverURL" } else { - sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString } } diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala deleted file mode 100644 index 81a7cbde01ce5..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import akka.actor.Actor -import org.slf4j.Logger - -/** - * A trait to enable logging all Akka actor messages. Here's an example of using this: - * - * {{{ - * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { - * ... - * override def receiveWithLogging = { - * case GetLocations(blockId) => - * sender ! getLocations(blockId) - * ... - * } - * ... - * } - * }}} - * - */ -private[spark] trait ActorLogReceive { - self: Actor => - - override def receive: Actor.Receive = new Actor.Receive { - - private val _receiveWithLogging = receiveWithLogging - - override def isDefinedAt(o: Any): Boolean = { - val handled = _receiveWithLogging.isDefinedAt(o) - if (!handled) { - log.debug(s"Received unexpected actor system event: $o") - } - handled - } - - override def apply(o: Any): Unit = { - if (log.isDebugEnabled) { - log.debug(s"[actor] received message $o from ${self.sender}") - } - val start = System.nanoTime - _receiveWithLogging.apply(o) - val timeTaken = (System.nanoTime - start).toDouble / 1000000 - if (log.isDebugEnabled) { - log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") - } - } - } - - def receiveWithLogging: Actor.Receive - - protected def log: Logger -} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 1738258a0c794..f2d93edd4fd2e 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import scala.collection.JavaConverters._ -import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} -import akka.pattern.ask - +import akka.actor.{ActorSystem, ExtendedActorSystem} import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} -import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. @@ -139,104 +136,4 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - timeout: RpcTimeout): T = { - askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) - } - - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails even after the specified number of retries. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - maxAttempts: Int, - retryInterval: Long, - timeout: RpcTimeout): T = { - // TODO: Consider removing multiple attempts - if (actor == null) { - throw new SparkException(s"Error sending message [message = $message]" + - " as actor is null ") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < maxAttempts) { - attempts += 1 - try { - val future = actor.ask(message)(timeout.duration) - val result = timeout.awaitResult(future) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - if (attempts < maxAttempts) { - Thread.sleep(retryInterval) - } - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - - def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def makeExecutorRef( - name: String, - conf: SparkConf, - host: String, - port: Int, - actorSystem: ActorSystem): ActorRef = { - val executorActorSystemName = SparkEnv.executorActorSystemName - Utils.checkHost(host, "Expected hostname") - val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def protocol(actorSystem: ActorSystem): String = { - val akkaConf = actorSystem.settings.config - val sslProp = "akka.remote.netty.tcp.enable-ssl" - protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp)) - } - - def protocol(ssl: Boolean = false): String = { - if (ssl) { - "akka.ssl.tcp" - } else { - "akka.tcp" - } - } - - def address( - protocol: String, - systemName: String, - host: String, - port: Int, - actorName: String): String = { - s"$protocol://$systemName@$host:$port/user/$actorName" - } - } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 7578a3b1d85f2..a51f30b9c2921 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -20,20 +20,19 @@ package org.apache.spark.util import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps -import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} -object RpcUtils { +private[spark] object RpcUtils { /** * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. */ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } /** Returns the configured number of times to retry connecting */ @@ -47,7 +46,7 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ - private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + def askRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") } @@ -57,7 +56,7 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ - private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7e70308bb360c..5b29d69cd9428 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -125,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 2d14249855c9d..33270bec6247c 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,7 +41,6 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -55,7 +54,6 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 85c1c1bbf3dc1..ab3d4cafebefa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -474,7 +474,7 @@ class StandaloneDynamicAllocationSuite (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 415e2b37dbbdc..eb794b6739d5e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -147,7 +147,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 242bf4b5566eb..10e33a32ba4c3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -98,7 +98,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) + rpcEnv.setupEndpointRef(rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index faed4bdc68447..082d5e86eb512 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -67,7 +67,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -93,7 +93,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -128,7 +128,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -154,7 +154,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 40c24bdecc6ce..0ffd91d8ffc06 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) @@ -36,7 +36,7 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 9c850c0da52a3..924fce7f61c26 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -94,7 +94,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") try { rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { @@ -148,7 +148,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) @@ -176,7 +176,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") try { // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause val e = intercept[SparkException] { @@ -435,7 +435,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) @@ -475,8 +475,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-remotely-error") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { @@ -527,8 +526,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") try { - val serverRefInServer2 = - serverEnv1.setupEndpointRef("server2", serverRef2.address, serverRef2.name) + val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address, serverRef2.name) // Send a message to set up the connection serverRefInServer2.send("hello") @@ -556,8 +554,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) try { - val serverRefInClient = - clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) // Send a message to set up the connection serverRefInClient.send("hello") @@ -588,8 +585,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") try { - val serverRefInClient = - clientEnv.setupEndpointRef("server", serverRef.address, serverRef.name) + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) // Send a message to set up the connection serverRefInClient.send("hello") @@ -623,8 +619,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-unserializable-error") + val rpcEndpointRef = + anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[Exception] { @@ -661,8 +657,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { case msg: String => message = msg } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "send-authentication") rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { assert("hello" === message) @@ -693,8 +688,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala deleted file mode 100644 index 7aac02775e1bf..0000000000000 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.akka - -import org.apache.spark.rpc._ -import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} - -class AkkaRpcEnvSuite extends RpcEnvSuite { - - override def createRpcEnv(conf: SparkConf, - name: String, - port: Int, - clientMode: Boolean = false): RpcEnv = { - new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) - } - - test("setupEndpointRef: systemName, address, endpointName") { - val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { - override val rpcEnv = env - - override def receive = { - case _ => - } - }) - val conf = new SparkConf() - val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) - try { - val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") - assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === - newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) - } finally { - newRpcEnv.shutdown() - } - } - - test("uriOf") { - val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } - - test("uriOf: ssl") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val securityManager = new SecurityManager(conf) - val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) - try { - val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } finally { - rpcEnv.shutdown() - } - } - -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 56743ba650b41..4fcdb619f9300 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.SparkFunSuite +import org.apache.spark.rpc.RpcEndpointAddress class NettyRpcAddressSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index ce83087ec04d6..994a58836bd0d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -33,9 +33,9 @@ class NettyRpcEnvSuite extends RpcEnvSuite { } test("non-existent endpoint") { - val uri = env.uriOf("test", env.address, "nonexist-endpoint") + val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString val e = intercept[RpcEndpointNotFoundException] { - env.setupEndpointRef("test", env.address, "nonexist-endpoint") + env.setupEndpointRef(env.address, "nonexist-endpoint") } assert(e.getMessage.contains(uri)) } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala deleted file mode 100644 index 0af4b6098bb0a..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.ArrayBuffer - -import java.util.concurrent.TimeoutException - -import akka.actor.ActorNotFound - -import org.apache.spark._ -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -import org.apache.spark.SSLSampleConfigs._ - - -/** - * Test the AkkaUtils with various security settings. - */ -class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { - - test("remote fetch security bad password") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security pass") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val goodconf = new SparkConf - goodconf.set("spark.authenticate", "true") - goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf) - - assert(securityManagerGood.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security on and passwords match - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off client") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch ssl on") { - val conf = sparkSSLConfig() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled") { - val conf = sparkSSLConfig() - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled - bad credentials") { - val conf = sparkSSLConfig() - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.rpc", "akka") - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on - untrusted server") { - val conf = sparkSSLConfigUntrusted() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - .set("spark.rpc.askTimeout", "5s") - .set("spark.rpc.lookupTimeout", "5s") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - try { - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - fail("should receive either ActorNotFound or TimeoutException") - } catch { - case e: ActorNotFound => - case e: TimeoutException => - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b7d27c9f06666..59886ab76244a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -35,6 +35,17 @@ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("2.0") => Seq( + // SPARK-7995 Remove AkkaRpcEnv + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaFailure"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaFailure$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEndpointRef$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEnvFactory"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEnv"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaMessage$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEndpointRef"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.ErrorMonitor"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaMessage") + ) ++ Seq( ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index a01bb267d7948..cccc061647a7f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -29,8 +29,7 @@ import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, - SparkException, SparkUserAppException} +import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.rpc._ @@ -281,10 +280,10 @@ private[spark] class ApplicationMaster( .getOrElse("") val _sparkConf = if (sc != null) sc.getConf else sparkConf - val driverUrl = _rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + _sparkConf.get("spark.driver.host"), + _sparkConf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString allocator = client.register(driverUrl, driverRef, yarnConf, @@ -310,7 +309,6 @@ private[spark] class ApplicationMaster( port: String, isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( - SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = From 5cdecb1841f5f1208a6100a673a768c84396633f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 31 Dec 2015 01:33:21 -0800 Subject: [PATCH 1392/2089] [SPARK-12039][SQL] Re-enable HiveSparkSubmitSuite's SPARK-9757 Persist Parquet relation with decimal column https://issues.apache.org/jira/browse/SPARK-12039 since we do not support hadoop1, we can re-enable this test in master. Author: Yin Huai Closes #10533 from yhuai/SPARK-12039-enable. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 53185fd7751e3..71f05f3b00291 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -103,7 +103,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - ignore("SPARK-9757 Persist Parquet relation with decimal column") { + test("SPARK-9757 Persist Parquet relation with decimal column") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SPARK_9757.getClass.getName.stripSuffix("$"), From efb10cc9ad370955cec64e8f63a3b646058a9840 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 31 Dec 2015 01:34:13 -0800 Subject: [PATCH 1393/2089] [SPARK-3873][STREAMING] Import order fixes for streaming. Also included a few miscelaneous other modules that had very few violations. Author: Marcelo Vanzin Closes #10532 from vanzin/SPARK-3873-streaming. --- .../flume/sink/TransactionProcessor.scala | 2 +- .../streaming/flume/EventTransformer.scala | 4 ++-- .../streaming/flume/FlumeInputDStream.scala | 20 +++++++++---------- .../flume/FlumePollingInputDStream.scala | 2 +- .../streaming/flume/FlumeTestUtils.scala | 2 +- .../spark/streaming/flume/FlumeUtils.scala | 3 +-- .../flume/PollingFlumeTestUtils.scala | 4 ++-- .../spark/streaming/kafka/KafkaCluster.scala | 11 ++++++---- .../streaming/kafka/KafkaInputDStream.scala | 2 +- .../spark/streaming/kafka/KafkaRDD.scala | 10 +++++----- .../streaming/kafka/KafkaTestUtils.scala | 4 ++-- .../spark/streaming/kafka/KafkaUtils.scala | 10 +++++----- .../kafka/ReliableKafkaReceiver.scala | 6 +++--- .../twitter/TwitterInputDStream.scala | 6 +++--- .../streaming/twitter/TwitterUtils.scala | 5 +++-- .../spark/streaming/zeromq/ZeroMQUtils.scala | 2 +- .../spark/repl/ExecutorClassLoader.scala | 7 +++---- .../apache/spark/streaming/Checkpoint.scala | 7 +++---- .../apache/spark/streaming/DStreamGraph.scala | 6 ++++-- .../apache/spark/streaming/StateSpec.scala | 3 ++- .../spark/streaming/StreamingContext.scala | 2 +- .../streaming/api/java/JavaDStream.scala | 12 +++++------ .../streaming/api/java/JavaPairDStream.scala | 2 +- .../api/java/JavaPairInputDStream.scala | 4 ++-- .../api/java/JavaStreamingContext.scala | 6 +++--- .../streaming/api/python/PythonDStream.scala | 5 ++--- .../dstream/ConstantInputDStream.scala | 2 +- .../dstream/DStreamCheckpointData.scala | 8 +++++--- .../streaming/dstream/FilteredDStream.scala | 5 +++-- .../dstream/FlatMapValuedDStream.scala | 7 ++++--- .../streaming/dstream/FlatMappedDStream.scala | 5 +++-- .../streaming/dstream/ForEachDStream.scala | 3 ++- .../streaming/dstream/GlommedDStream.scala | 5 +++-- .../dstream/MapPartitionedDStream.scala | 5 +++-- .../streaming/dstream/MapValuedDStream.scala | 7 ++++--- .../dstream/MapWithStateDStream.scala | 2 +- .../streaming/dstream/MappedDStream.scala | 5 +++-- .../dstream/PairDStreamFunctions.scala | 4 ++-- .../dstream/PluggableInputDStream.scala | 3 ++- .../streaming/dstream/QueueInputDStream.scala | 2 +- .../streaming/dstream/RawInputDStream.scala | 15 +++++++------- .../dstream/ReceiverInputDStream.scala | 4 ++-- .../dstream/ReducedWindowedDStream.scala | 11 ++++------ .../streaming/dstream/ShuffledDStream.scala | 5 +++-- .../dstream/SocketInputDStream.scala | 13 ++++++------ .../streaming/dstream/StateDStream.scala | 6 +++--- .../streaming/dstream/UnionDStream.scala | 3 +-- .../streaming/dstream/WindowedDStream.scala | 4 ++-- .../spark/streaming/rdd/MapWithStateRDD.scala | 4 ++-- .../streaming/receiver/BlockGenerator.scala | 2 +- .../receiver/ReceivedBlockHandler.scala | 4 ++-- .../spark/streaming/receiver/Receiver.scala | 2 +- .../receiver/ReceiverSupervisor.scala | 4 ++-- .../receiver/ReceiverSupervisorImpl.scala | 2 +- .../scheduler/InputInfoTracker.scala | 2 +- .../spark/streaming/scheduler/Job.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 4 ++-- .../scheduler/ReceivedBlockTracker.scala | 2 +- .../streaming/scheduler/ReceiverTracker.scala | 4 ++-- .../scheduler/StreamingListener.scala | 2 +- .../apache/spark/streaming/ui/BatchPage.scala | 2 +- .../ui/StreamingJobProgressListener.scala | 12 +++-------- .../spark/streaming/ui/StreamingTab.scala | 6 +++--- .../apache/spark/streaming/ui/UIUtils.scala | 8 ++++---- .../streaming/util/BatchedWriteAheadLog.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 4 ++-- .../util/FileBasedWriteAheadLogReader.scala | 3 ++- .../util/RateLimitedOutputStream.scala | 5 ++--- .../spark/streaming/util/RawTextSender.scala | 2 +- .../streaming/util/WriteAheadLogUtils.scala | 2 +- .../spark/tools/GenerateMIMAIgnore.scala | 2 +- .../tools/JavaAPICompletenessChecker.scala | 6 +++--- .../spark/tools/StoragePerfTester.scala | 2 +- 73 files changed, 181 insertions(+), 180 deletions(-) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index 7ad43b1d7b0a0..b15c2097e550c 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} import scala.util.control.Breaks -import org.apache.flume.{Transaction, Channel} +import org.apache.flume.{Channel, Transaction} // Flume forces transactions to be thread-local (horrible, I know!) // So the sink basically spawns a new thread to pull the events out within a transaction. diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 48df27b26867f..5c773d4b07cf6 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.flume -import java.io.{ObjectOutput, ObjectInput} +import java.io.{ObjectInput, ObjectOutput} import scala.collection.JavaConverters._ -import org.apache.spark.util.Utils import org.apache.spark.Logging +import org.apache.spark.util.Utils /** * A simple object that provides the implementation of readExternal and writeExternal for both diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2b9116eb3c790..1bfa35a8b3d1d 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -17,29 +17,27 @@ package org.apache.spark.streaming.flume +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.net.InetSocketAddress -import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder import org.apache.avro.ipc.NettyServer +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol, Status} +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} +import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ + import org.apache.spark.Logging -import org.apache.spark.util.Utils import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver - -import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} -import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory -import org.jboss.netty.handler.codec.compression._ +import org.apache.spark.util.Utils private[streaming] class FlumeInputDStream[T: ClassTag]( diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 6737750c3d63e..d9c25e86540db 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -32,8 +32,8 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.flume.sink._ +import org.apache.spark.streaming.receiver.Receiver /** * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index fe5dcc8e4b9de..3f87ce46e5952 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -29,7 +29,7 @@ import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index c719b80aca7ed..3e3ed712f0dbf 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume +import java.io.{ByteArrayOutputStream, DataOutputStream} import java.net.InetSocketAddress -import java.io.{DataOutputStream, ByteArrayOutputStream} import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -30,7 +30,6 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index bfe7548d4f50e..9515d07c5ee5b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume -import java.util.concurrent._ import java.util.{Collections, List => JList, Map => JMap} +import java.util.concurrent._ import scala.collection.mutable.ArrayBuffer @@ -28,7 +28,7 @@ import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables -import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} +import org.apache.spark.streaming.flume.sink.{SparkSink, SparkSinkConfig} /** * Share codes for Scala and Python unit tests diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 8465432c5850f..c4e18d92eefa9 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -17,14 +17,17 @@ package org.apache.spark.streaming.kafka -import scala.util.control.NonFatal -import scala.util.Random -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConverters._ import java.util.Properties + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.util.control.NonFatal + import kafka.api._ import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition} import kafka.consumer.{ConsumerConfig, SimpleConsumer} + import org.apache.spark.SparkException /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38730fecf332a..67f2360896b16 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -22,7 +22,7 @@ import java.util.Properties import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} +import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index ea5f842c6cafe..603be22818206 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -20,11 +20,6 @@ package org.apache.spark.streaming.kafka import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} -import org.apache.spark.partial.{PartialResult, BoundedDouble} -import org.apache.spark.rdd.RDD -import org.apache.spark.util.NextIterator - import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} import kafka.consumer.SimpleConsumer @@ -32,6 +27,11 @@ import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties +import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + /** * A batch-oriented interface for consuming from Kafka. * Starting and ending offsets are specified in advance, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 45a6982b9afe5..a76fa6671a4b0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -20,8 +20,8 @@ package org.apache.spark.streaming.kafka import java.io.File import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.concurrent.TimeoutException import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeoutException import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -37,9 +37,9 @@ import kafka.utils.{ZKStringSerializer, ZkUtils} import org.I0Itec.zkclient.ZkClient import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.streaming.Time import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf} /** * This is a helper class for Kafka test suites. This has the functionality to set up diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index fe572220528d5..0cb875c9758f9 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -27,19 +27,19 @@ import scala.reflect.ClassTag import com.google.common.base.Charsets.UTF_8 import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} -import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler} +import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java._ import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.util.WriteAheadLogUtils object KafkaUtils { /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 764d170934aa6..a872781b78eeb 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -18,10 +18,10 @@ package org.apache.spark.streaming.kafka import java.util.Properties -import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} +import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} -import scala.collection.{Map, mutable} -import scala.reflect.{ClassTag, classTag} +import scala.collection.{mutable, Map} +import scala.reflect.{classTag, ClassTag} import kafka.common.TopicAndPartition import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 9a85a6597c27f..a48eec70b9f78 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming.twitter import twitter4j._ import twitter4j.auth.Authorization -import twitter4j.conf.ConfigurationBuilder import twitter4j.auth.OAuthAuthorization +import twitter4j.conf.ConfigurationBuilder +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.Logging import org.apache.spark.streaming.receiver.Receiver /* A stream of Twitter statuses, potentially filtered by one or more keywords. diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala index c6a9a2b73714f..3e843e947da61 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala @@ -19,10 +19,11 @@ package org.apache.spark.streaming.twitter import twitter4j.Status import twitter4j.auth.Authorization + import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{DStream, ReceiverInputDStream} object TwitterUtils { /** diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 4ea218eaa4de1..63cd8a2721f0c 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.zeromq -import scala.reflect.ClassTag import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index de7b831adc736..2bf1be1a582b5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,7 +17,7 @@ package org.apache.spark.repl -import java.io.{FilterInputStream, ByteArrayOutputStream, InputStream, IOException} +import java.io.{ByteArrayOutputStream, FilterInputStream, InputStream, IOException} import java.net.{HttpURLConnection, URI, URL, URLEncoder} import java.nio.channels.Channels @@ -27,10 +27,9 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ -import org.apache.spark.{SparkConf, SparkEnv, Logging} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils -import org.apache.spark.util.ParentClassLoader +import org.apache.spark.util.{ParentClassLoader, Utils} /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index d0046afdeb447..61b230ab6f98a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -21,15 +21,14 @@ import java.io._ import java.util.concurrent.Executors import java.util.concurrent.RejectedExecutionException -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator - +import org.apache.spark.util.{MetadataCleaner, Utils} private[streaming] class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 7829f5e887995..eedb42c0611c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -17,11 +17,13 @@ package org.apache.spark.streaming +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.collection.mutable.ArrayBuffer -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} + import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.streaming.dstream.{DStream, ReceiverInputDStream, InputDStream} import org.apache.spark.util.Utils final private[streaming] class DStreamGraph extends Serializable with Logging { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 9f6f95223f619..0b094558dfd59 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -18,12 +18,13 @@ package org.apache.spark.streaming import com.google.common.base.Optional + +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner -import org.apache.spark.{HashPartitioner, Partitioner} /** * :: Experimental :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b24c0d067bb05..c4a10aa2dd3b9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -29,8 +29,8 @@ import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 01cdcb0574040..a59f4efccb575 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.rdd.RDD - import scala.language.implicitConversions import scala.reflect.ClassTag + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.dstream.DStream /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 42ddd63f0f06c..2bf3ccec6bc55 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong, Iterable => JIterable} +import java.lang.{Iterable => JIterable, Long => JLong} import java.util.{List => JList} import scala.collection.JavaConverters._ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala index e6ff8a0cb545f..da0db02236a1f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.dstream.InputDStream - import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.InputDStream + /** * A Java-friendly interface to [[org.apache.spark.streaming.dstream.InputDStream]] of * key-value pairs. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 7a50135025463..00f9d8a9e8817 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,14 +17,15 @@ package org.apache.spark.streaming.api.java -import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} +import java.lang.{Boolean => JBoolean} import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -37,10 +38,9 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver -import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.scheduler.StreamingListener /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 056248ccc7bcd..953fe95177f02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -30,12 +30,11 @@ import org.apache.spark.SparkException import org.apache.spark.api.java._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Interval, Duration, Time} -import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.{Duration, Interval, Time} import org.apache.spark.streaming.api.java._ +import org.apache.spark.streaming.dstream._ import org.apache.spark.util.Utils - /** * Interface for Python callback function which is used to transform RDDs */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index 4eb92dd8b1053..695384deb32d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} /** * An input stream that always returns the same RDD on each timestep. Useful for testing. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 39fd21342813e..3eff174c2b66c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -17,11 +17,13 @@ package org.apache.spark.streaming.dstream +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import java.io.{ObjectOutputStream, ObjectInputStream, IOException} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.FileSystem + +import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.Logging import org.apache.spark.streaming.Time import org.apache.spark.util.Utils diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index fcd5216f101af..43079880b2352 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FilteredDStream[T: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 9d09a3baf37ca..778d556d2efb9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import scala.reflect.ClassTag +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 475ea2d2d4f38..96a444a7baa5e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FlatMappedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 4410a9977c87b..a0fadee8a9844 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.scheduler.Job -import scala.reflect.ClassTag /** * An internal DStream used to represent output operations like DStream.foreachRDD. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index dbb295fe54f71..9f1252f091a63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 5994bc1e23f2b..bcdf1752e61e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MapPartitionedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 954d2eb4a7b00..855c3dd096f4b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import scala.reflect.ClassTag +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index 706465d4e25d7..36ff9c7e6182f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._ +import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} /** * :: Experimental :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index fa14b2e897c3e..e11d82697af89 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MappedDStream[T: ClassTag, U: ClassTag] ( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index a64a1fe93f40d..babc722709325 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,12 +24,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} -import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 002aac9f43617..2442e4c01a0c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.StreamingContext import scala.reflect.ClassTag + +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index cd073646370d0..a8d108de6c3e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag import org.apache.spark.rdd.{RDD, UnionRDD} -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} private[streaming] class QueueInputDStream[T: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index 5a9eda7c12776..ac73dca05a674 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.StreamingContext - -import scala.reflect.ClassTag - +import java.io.EOFException import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException import java.util.concurrent.ArrayBlockingQueue -import org.apache.spark.streaming.receiver.Receiver +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.receiver.Receiver /** * An input stream that reads blocks of serialized objects from a given network address. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 87c20afd5c13c..a18551fac719a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,12 +21,12 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{RateController, ReceivedBlockInfo, StreamInputInfo} import org.apache.spark.streaming.scheduler.rate.RateEstimator -import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, RateController, StreamInputInfo} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.streaming.{StreamingContext, Time} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 6a583bf2a3626..535954908539e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -17,18 +17,15 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD} +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD, RDD} import org.apache.spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer import org.apache.spark.streaming.{Duration, Interval, Time} -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - private[streaming] class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index e0ffd5d86b435..0fe15440dd44d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.Partitioner -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index de84e0c9a498d..10644b9201918 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,18 +17,17 @@ package org.apache.spark.streaming.dstream -import scala.util.control.NonFatal - -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.NextIterator +import java.io._ +import java.net.{Socket, UnknownHostException} import scala.reflect.ClassTag +import scala.util.control.NonFatal -import java.io._ -import java.net.{UnknownHostException, Socket} import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.NextIterator private[streaming] class SocketInputDStream[T: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 621d6dff788f4..ebbe139a2cdf8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag - private[streaming] class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index d73ffdfd84d2d..2b07dd6185861 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -21,9 +21,8 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.SparkException +import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 4efba039f8959..ee50a8d024e12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -17,13 +17,13 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.Duration -import scala.reflect.ClassTag - private[streaming] class WindowedDStream[T: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index fdf61674a37f2..1d2244eaf22b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -22,11 +22,11 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import org.apache.spark._ import org.apache.spark.rdd.{MapPartitionsRDD, RDD} -import org.apache.spark.streaming.{Time, StateImpl, State} +import org.apache.spark.streaming.{State, StateImpl, Time} import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils -import org.apache.spark._ /** * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index cc7c04bfc9f63..109af32cf4bbd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, SystemClock} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 5f6c5b024085c..43c605af73716 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import scala.language.{existentials, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} -import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 2252e28f22af8..b08152485ab5b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -22,8 +22,8 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 158d1ba2f183a..c42a9ac233f87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{SparkEnv, Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.{Utils, ThreadUtils} +import org.apache.spark.util.{ThreadUtils, Utils} /** * Abstract class that is responsible for supervising a Receiver in the worker. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 167f56aa42281..b774b6b9a55d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.util.RpcUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index deb15d075975c..92da0ced28fbc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index ab1b3565fcc19..7050d7ef45240 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Try} import org.apache.spark.streaming.Time -import org.apache.spark.util.{Utils, CallSite} +import org.apache.spark.util.{CallSite, Utils} /** * Class representing a Spark computation. It may contain multiple Spark jobs. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8dfdc1f57b403..a5a01e77639c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,10 +19,10 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} +import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 4dab64d696b3e..60b5c838e9734 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,11 +27,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index ea5d12b50fcc5..9ddf176aee84c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,14 +20,14 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} import scala.language.existentials import scala.util.{Failure, Success} import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{TaskLocation, ExecutorCacheTaskLocation} +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util.WriteAheadLogUtils diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index d19bdbb443c5e..58fc78d552106 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.Queue -import org.apache.spark.util.Distribution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Distribution /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index bc1711930d3ac..7635f79a3d2d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -25,8 +25,8 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.ui.StreamingJobProgressListener.{OutputOpId, SparkJobId} -import org.apache.spark.ui.jobs.UIData.JobUIData import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.ui.jobs.UIData.JobUIData private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index f6cc6edf2569a..4908be0536353 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,19 +17,13 @@ package org.apache.spark.streaming.ui -import java.util.LinkedHashMap -import java.util.{Map => JMap} -import java.util.Properties +import java.util.{LinkedHashMap, Map => JMap, Properties} -import scala.collection.mutable.{ArrayBuffer, Queue, HashMap, SynchronizedBuffer} +import scala.collection.mutable.{ArrayBuffer, HashMap, Queue, SynchronizedBuffer} import org.apache.spark.scheduler._ -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted -import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted - private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) extends StreamingListener with SparkListener { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index bc53f2a31f6d1..0662c64a0ce9b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -21,14 +21,14 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext import org.apache.spark.ui.{SparkUI, SparkUITab} -import StreamingTab._ - /** * Spark Web UI tab that shows statistics of a streaming job. * This assumes the given SparkContext has enabled its SparkUI. */ private[spark] class StreamingTab(val ssc: StreamingContext) - extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + extends SparkUITab(StreamingTab.getSparkUI(ssc), "streaming") with Logging { + + import StreamingTab._ private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index d89f7ad3e16b7..a485a46937f31 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.ui -import scala.xml.Node - -import org.apache.commons.lang3.StringEscapeUtils - import java.text.SimpleDateFormat import java.util.TimeZone import java.util.concurrent.TimeUnit +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + private[streaming] object UIUtils { /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index b2cd524f28b74..8cb45cdffa5d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -18,8 +18,8 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.LinkedBlockingQueue import java.util.{Iterator => JIterator} +import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index b946e0d8e9271..9418beec0d74a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -29,8 +29,8 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.{CompletionIterator, ThreadUtils} import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{CompletionIterator, ThreadUtils} /** * This class manages write ahead log files. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index a375c0729534b..e79b139bdd037 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -16,10 +16,11 @@ */ package org.apache.spark.streaming.util -import java.io.{IOException, Closeable, EOFException} +import java.io.{Closeable, EOFException, IOException} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration + import org.apache.spark.Logging /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala index a96e2924a0b44..5c3c7a6bf1b39 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala @@ -17,13 +17,12 @@ package org.apache.spark.streaming.util -import scala.annotation.tailrec - import java.io.OutputStream import java.util.concurrent.TimeUnit._ -import org.apache.spark.Logging +import scala.annotation.tailrec +import org.apache.spark.Logging private[streaming] class RateLimitedOutputStream(out: OutputStream, desiredBytesPerSec: Int) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index 6addb96752038..e48eaf7913b12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.io.Source -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.IntParam diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 7f9e2c9734970..ed616d8e810bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -21,8 +21,8 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.util.Utils /** A helper class with utility functions related to the WriteAheadLog interface */ private[streaming] object WriteAheadLogUtils extends Logging { diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 5155daa6d17bf..a947fac1d751d 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -23,8 +23,8 @@ import java.util.jar.JarFile import scala.collection.mutable import scala.collection.JavaConverters._ -import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} +import scala.reflect.runtime.universe.runtimeMirror import scala.util.Try /** diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 856ea177a9a10..6fb7184e877ee 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,16 +17,16 @@ package org.apache.spark.tools -import java.lang.reflect.{Type, Method} +import java.lang.reflect.{Method, Type} import scala.collection.mutable.ArrayBuffer import scala.language.existentials import org.apache.spark._ import org.apache.spark.api.java._ -import org.apache.spark.rdd.{RDD, DoubleRDDFunctions, PairRDDFunctions, OrderedRDDFunctions} +import org.apache.spark.rdd.{DoubleRDDFunctions, OrderedRDDFunctions, PairRDDFunctions, RDD} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.{DStream, PairDStreamFunctions} diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 0dc2861253f17..8a5c7c0e730e6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -20,8 +20,8 @@ package org.apache.spark.tools import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util.Utils From 5adec63a922d6f60cd6cb87ebdab61a17131ac1a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 31 Dec 2015 20:23:19 -0800 Subject: [PATCH 1394/2089] [SPARK-10359][PROJECT-INFRA] Multiple fixes to dev/test-dependencies.sh script This patch includes multiple fixes for the `dev/test-dependencies.sh` script (which was introduced in #10461): - Use `build/mvn --force` instead of `mvn` in one additional place. - Explicitly set a zero exit code on success. - Set `LC_ALL=C` to make `sort` results agree across machines (see https://stackoverflow.com/questions/28881/). - Set `should_run_build_tests=True` for `build` module (this somehow got lost). Author: Josh Rosen Closes #10543 from JoshRosen/dep-script-fixes. --- dev/sparktestsupport/modules.py | 3 ++- dev/test-dependencies.sh | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4667b289f507a..47cd600bd18a4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -402,7 +402,8 @@ def contains_file(self, filename): source_file_regexes=[ ".*pom.xml", "dev/test-dependencies.sh", - ] + ], + should_run_build_tests=True ) ec2 = Module( diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 984e29d1beb88..4e260e2abf042 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -22,6 +22,10 @@ set -e FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script @@ -37,7 +41,7 @@ HADOOP_PROFILES=( # resolve Spark's internal submodule dependencies. # See http://stackoverflow.com/a/3545363 for an explanation of this one-liner: -OLD_VERSION=$(mvn help:evaluate -Dexpression=project.version|grep -Ev '(^\[|Download\w+:)') +OLD_VERSION=$($MVN help:evaluate -Dexpression=project.version|grep -Ev '(^\[|Download\w+:)') TEMP_VERSION="spark-$(date +%s | tail -c6)" function reset_version { @@ -100,3 +104,5 @@ for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do exit 1 fi done + +exit 0 From c9dbfcc653b868fdb28106c1d1bcb6cb6caac6cc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 31 Dec 2015 23:48:05 -0800 Subject: [PATCH 1395/2089] [SPARK-11743][SQL] Move the test for arrayOfUDT A following pr for #9712. Move the test for arrayOfUDT. Author: Liang-Chi Hsieh Closes #10538 from viirya/move-udt-test. --- .../sql/catalyst/encoders/RowEncoderSuite.scala | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 0ea51ece4bc5e..8f4faab7bace5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -108,7 +108,8 @@ class RowEncoderSuite extends SparkFunSuite { .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) .add("arrayOfMap", ArrayType(mapOfString)) - .add("arrayOfStruct", ArrayType(structOfString))) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) encodeDecodeTest( new StructType() @@ -130,18 +131,6 @@ class RowEncoderSuite extends SparkFunSuite { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) - test(s"encode/decode: arrayOfUDT") { - val schema = new StructType() - .add("arrayOfUDT", arrayOfUDT) - - val encoder = RowEncoder(schema) - - val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) - val row = encoder.toRow(input) - val convertedBack = encoder.fromRow(row) - assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) - } - test(s"encode/decode: Product") { val schema = new StructType() .add("structAsProduct", From a59a357cae82ca9b6926b55903ce4f12ae131735 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 31 Dec 2015 23:48:55 -0800 Subject: [PATCH 1396/2089] [SPARK-3873][MLLIB] Import order fixes. A slight adjustment to the checker configuration was needed; there is a handful of warnings still left, but those are because of a bug in the checker that I'll fix separately (before enabling errors for the checker, of course). Author: Marcelo Vanzin Closes #10535 from vanzin/SPARK-3873-mllib. --- .../src/main/scala/org/apache/spark/ml/Pipeline.scala | 6 ++---- .../main/scala/org/apache/spark/ml/Predictor.scala | 2 +- .../main/scala/org/apache/spark/ml/ann/Layer.scala | 4 ++-- .../org/apache/spark/ml/attribute/attributes.scala | 2 +- .../scala/org/apache/spark/ml/attribute/package.scala | 2 +- .../apache/spark/ml/classification/Classifier.scala | 3 +-- .../spark/ml/classification/GBTClassifier.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 10 +++++----- .../ml/classification/ProbabilisticClassifier.scala | 2 +- .../ml/classification/RandomForestClassifier.scala | 4 ++-- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 6 +++--- .../scala/org/apache/spark/ml/clustering/LDA.scala | 7 ++++--- .../MulticlassClassificationEvaluator.scala | 6 +++--- .../scala/org/apache/spark/ml/feature/Binarizer.scala | 2 +- .../org/apache/spark/ml/feature/Bucketizer.scala | 2 +- .../org/apache/spark/ml/feature/CountVectorizer.scala | 2 +- .../main/scala/org/apache/spark/ml/feature/DCT.scala | 4 ++-- .../apache/spark/ml/feature/ElementwiseProduct.scala | 2 +- .../scala/org/apache/spark/ml/feature/HashingTF.scala | 2 +- .../org/apache/spark/ml/feature/Interaction.scala | 4 ++-- .../org/apache/spark/ml/feature/MinMaxScaler.scala | 3 +-- .../scala/org/apache/spark/ml/feature/NGram.scala | 2 +- .../org/apache/spark/ml/feature/Normalizer.scala | 2 +- .../org/apache/spark/ml/feature/OneHotEncoder.scala | 2 +- .../apache/spark/ml/feature/PolynomialExpansion.scala | 4 ++-- .../apache/spark/ml/feature/QuantileDiscretizer.scala | 6 +++--- .../scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../org/apache/spark/ml/feature/SQLTransformer.scala | 6 +++--- .../apache/spark/ml/feature/StopWordsRemover.scala | 2 +- .../scala/org/apache/spark/ml/feature/Tokenizer.scala | 2 +- .../org/apache/spark/ml/feature/VectorAssembler.scala | 4 ++-- .../org/apache/spark/ml/feature/VectorSlicer.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/Word2Vec.scala | 2 +- .../scala/org/apache/spark/ml/r/SparkRWrappers.scala | 4 ++-- .../org/apache/spark/ml/recommendation/ALS.scala | 2 +- .../spark/ml/regression/AFTSurvivalRegression.scala | 8 ++++---- .../spark/ml/regression/IsotonicRegression.scala | 6 +++--- .../apache/spark/ml/regression/LinearRegression.scala | 2 +- .../org/apache/spark/ml/regression/Regressor.scala | 2 +- .../spark/ml/source/libsvm/LibSVMRelation.scala | 2 +- .../main/scala/org/apache/spark/ml/tree/Node.scala | 4 ++-- .../org/apache/spark/ml/tree/impl/NodeIdCache.scala | 2 +- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 4 ++-- .../scala/org/apache/spark/ml/tree/treeModels.scala | 2 +- .../org/apache/spark/ml/tuning/CrossValidator.scala | 2 +- .../apache/spark/ml/tuning/TrainValidationSplit.scala | 2 +- .../org/apache/spark/ml/tuning/ValidatorParams.scala | 2 +- .../python/PowerIterationClusteringModelWrapper.scala | 2 +- .../spark/mllib/api/python/PythonMLLibAPI.scala | 9 ++++----- .../spark/mllib/api/python/Word2VecModelWrapper.scala | 1 + .../mllib/classification/LogisticRegression.scala | 4 ++-- .../spark/mllib/clustering/GaussianMixtureModel.scala | 7 +++---- .../apache/spark/mllib/clustering/KMeansModel.scala | 5 ++--- .../org/apache/spark/mllib/clustering/LDAModel.scala | 2 +- .../apache/spark/mllib/clustering/LDAOptimizer.scala | 4 ++-- .../org/apache/spark/mllib/clustering/LDAUtils.scala | 2 +- .../apache/spark/mllib/clustering/LocalKMeans.scala | 2 +- .../mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../spark/mllib/clustering/StreamingKMeans.scala | 2 +- .../spark/mllib/evaluation/AreaUnderCurve.scala | 2 +- .../spark/mllib/evaluation/RankingMetrics.scala | 2 +- .../spark/mllib/evaluation/RegressionMetrics.scala | 2 +- .../apache/spark/mllib/feature/ChiSqSelector.scala | 2 +- .../org/apache/spark/mllib/feature/Word2Vec.scala | 3 +-- .../scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 2 +- .../spark/mllib/impl/PeriodicCheckpointer.scala | 4 ++-- .../spark/mllib/linalg/EigenValueDecomposition.scala | 2 +- .../org/apache/spark/mllib/linalg/Matrices.scala | 2 +- .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 4 ++-- .../mllib/linalg/distributed/CoordinateMatrix.scala | 2 +- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 2 +- .../spark/mllib/linalg/distributed/RowMatrix.scala | 6 +++--- .../spark/mllib/optimization/GradientDescent.scala | 6 +++--- .../apache/spark/mllib/optimization/Optimizer.scala | 3 +-- .../org/apache/spark/mllib/optimization/Updater.scala | 4 ++-- .../spark/mllib/random/RandomDataGenerator.scala | 4 ++-- .../scala/org/apache/spark/mllib/rdd/RandomRDD.scala | 6 +++--- .../scala/org/apache/spark/mllib/rdd/SlidingRDD.scala | 2 +- .../mllib/regression/GeneralizedLinearAlgorithm.scala | 6 +++--- .../apache/spark/mllib/regression/LabeledPoint.scala | 2 +- .../org/apache/spark/mllib/regression/Lasso.scala | 2 +- .../spark/mllib/regression/LinearRegression.scala | 2 +- .../mllib/stat/MultivariateOnlineSummarizer.scala | 2 +- .../org/apache/spark/mllib/stat/Statistics.scala | 4 ++-- .../stat/distribution/MultivariateGaussian.scala | 4 ++-- .../org/apache/spark/mllib/stat/test/ChiSqTest.scala | 6 +++--- .../mllib/tree/configuration/BoostingStrategy.scala | 2 +- .../spark/mllib/tree/configuration/Strategy.scala | 2 +- .../apache/spark/mllib/tree/impl/NodeIdCache.scala | 6 +++--- .../org/apache/spark/mllib/tree/model/Node.scala | 4 ++-- .../org/apache/spark/mllib/tree/model/Split.scala | 1 - .../mllib/util/LogisticRegressionDataGenerator.scala | 6 +++--- .../org/apache/spark/mllib/util/MFDataGenerator.scala | 2 +- .../scala/org/apache/spark/mllib/util/MLUtils.scala | 11 +++++------ scalastyle-config.xml | 4 ++-- 95 files changed, 160 insertions(+), 169 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 4b2b3f8489fd0..3acc60d6c6d65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -26,11 +26,9 @@ import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.MLReader -import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e0dcd427fae24..6aacffd4f236f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,9 +24,9 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * (private[ml]) Trait for parameters for prediction (regression and classification). diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index b5258ff348477..d02806a6ea227 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.ann -import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, - sum => Bsum} +import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV, + Vector => BV} import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} import org.apache.spark.mllib.linalg.{Vector, Vectors} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index a7c10333c0d53..521d209a8f0ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala index 7ac21d7d563f2..f6964054db839 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml -import org.apache.spark.sql.DataFrame import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.sql.DataFrame /** * ==ML attributes== diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 45df557a89908..8186afc17a53e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -26,7 +26,6 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * (private[spark]) Params for classification. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cda2bca58c50d..74bf07c3f1fff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index a691aa005ef54..719d1076fee8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} +import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index fdd1851ae5508..865614aa5c8a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d6d85ad2533a2..f7d662df2fe5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 71e968497500f..6e5abb29ff0a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -20,15 +20,15 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * Common params for KMeans and KMeansModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 830510b1698d4..af0b3e1835003 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -18,19 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} import org.apache.spark.sql.types.StructType diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index c44db0ec595ea..a921153b9474f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 63c06581482ed..5b17d3483b895 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 324353a96afb3..33abc7c99d4b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index b9e2144c0ad40..1268c87908c62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6bed72164a1da..a6f878151de73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.types.DataType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index a359cb8f37ec3..07a12df320357 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 9e15835429a38..61a78d73c4347 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 2181119f04a5d..7d2a1da990fce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index c2866f5eceff3..559a025265916 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature - import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} @@ -25,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ import org.apache.spark.sql.functions._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 65414ecbefbbd..f8bc7e3f0c031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index c2d514fd9629e..a603b3f833202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index d70164eaf0224..c01e29af478c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 08610593fadda..42b26c8ee836c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 7bf67c6325a35..39de8461dc9c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.Logging -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5c43a41bee3b4..2b578c2a95e16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -21,8 +21,8 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c09f4d076c964..e0ca45b9a6190 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.types.StructType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 318808596dc6a..5d6936dce2c74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 8ad7bbedaab5c..8456a0e915804 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 801096fed27bf..e9d1b57b91d07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 5410a50bc2e47..4813d8a5b5dc0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index f105a983a34f6..59c34cd1703aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 4d82b90bfdf20..551e75dc0a02d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} -import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame private[r] object SparkRWrappers { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index b798aa1fab767..14a28b8d5b51f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,7 +31,7 @@ import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index aedfb48058dc5..3787ca45d5172 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -23,18 +23,18 @@ import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index bbb1c7ac0a51e..e8d361b1a2a8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,18 +21,18 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 5e5850963edc9..dee26337dcdf6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -25,9 +25,9 @@ import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim.WeightedLeastSquares -import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c72ef29680329..cf189e8e96f95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 11b9815ecc832..1bed542c40316 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 9cfd466294b95..6507a8ad7cf3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict, ImpurityStats} +import org.apache.spark.mllib.tree.model.{ImpurityStats, + InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 1ee01131d6334..172ea52820564 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -21,7 +21,7 @@ import java.io.IOException import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4a3b12d1440b8..6e87302c7779b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,10 +26,10 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index b77191156f68f..40ed95773e149 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Abstraction for Decision Tree models. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 40f8857fc5866..477675cad1a90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.jackson.JsonMethods._ import org.json4s.{DefaultFormats, JObject} +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index adf06302047a7..f346ea655ae5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 8897ab0825acd..553f254172410 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.Estimator import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param.{ParamMap, Param, Params} +import org.apache.spark.ml.param.{Param, ParamMap, Params} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala index bc6041b221732..6530870b83a11 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.api.python -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.clustering.PowerIterationClusteringModel +import org.apache.spark.rdd.RDD /** * A Wrapper of PowerIterationClusteringModel to provide helper method for Python diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f6826ddbfabfe..061db56c74938 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -42,18 +42,17 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult} -import org.apache.spark.mllib.stat.{ - KernelDensity, MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} -import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 0f55980481dcb..55dfd973eb25f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + import scala.collection.JavaConverters._ import org.apache.spark.SparkContext diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2d52abc122bf2..2a7697b5a79cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel -import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 74d13e4f77945..5c9bc62cb09bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -26,11 +25,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 91fa9b0d3590d..26c6235fe5907 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,15 +23,14 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SQLContext} /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7384d065a2ea8..2fce3ff641101 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} +import breeze.linalg.{argmax, argtopk, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 17c0609800e90..c19595e6cd21b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum} -import breeze.numerics.{trigamma, abs, exp} +import breeze.linalg.{all, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics.{abs, exp, trigamma} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index a9ba7b60bad08..647d37bd822c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.linalg.{max, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics._ /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index b2f140e1b1352..c9a96c68667af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.clustering import scala.util.Random import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.Vectors /** * An utility object to run K-means locally. This is private to the ML package because it's used diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index bb1804505948b..2ab0920b06363 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.clustering -import org.json4s.JsonDSL._ import org.json4s._ +import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ @@ -30,7 +31,6 @@ import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.{Logging, SparkContext, SparkException} /** * Model produced by [[PowerIterationClustering]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 80843719f50b4..79d217e183c62 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 078fbfbe4f0e1..f0779491e6374 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * Computes the area under the curve (AUC) using the trapezoidal rule. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index cc01936dd34b2..f8de4e2220c4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 1d8f4fe340fb4..34883f2f390d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.sql.DataFrame /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index eaa99cfe82e27..33728bf5d77e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Chi Squared selector model. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 1f400e1430eba..a7e1b76df6a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -36,9 +35,9 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.sql.SQLContext /** * Entry in vocabulary diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 97916daa2e9ad..ed49c9492fdcd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.fpm import java.{lang => jl, util => ju} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import org.apache.spark.Logging diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 72d3aabc9b1f4..57ca4d3464f15 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 863abe86d38d7..bb94745f078e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import com.github.fommil.netlib.ARPACK -import org.netlib.util.{intW, doubleW} +import org.netlib.util.{doubleW, intW} /** * Compute eigen-decomposition. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 8879dcf75c9bf..d7a74db0b1fd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import java.util.{Arrays, Random} -import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet} import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4dcf351df43fa..cecfd067bd874 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.linalg -import java.util import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson} +import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} import org.apache.spark.SparkException import org.apache.spark.annotation.{AlphaComponent, Since} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 8a70f34e70f6a..97b03b340f20e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} +import org.apache.spark.rdd.RDD /** * Represents an entry in an distributed matrix. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 976299124cedd..e8de515211a18 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.rdd.RDD /** * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 2018a678688e1..0a36da4101339 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -21,8 +21,8 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd, MatrixSingularException, inv} +import breeze.linalg.{axpy => brzAxpy, inv, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV, + MatrixSingularException, SparseVector => BSV} import breeze.numerics.{sqrt => brzSqrt} import org.apache.spark.Logging @@ -30,8 +30,8 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD -import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.XORShiftRandom /** * Represents a row-oriented distributed Matrix with no meaningful row indices. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 37bb6f6097f67..5873669b37e36 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,12 +19,12 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, norm} +import breeze.linalg.{norm, DenseVector => BDV} -import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 7f6d94571b5ef..d8e56720967d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.rdd.RDD - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 9f463e0cafb6f..03c01e0553d78 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.optimization import scala.math._ -import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} +import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9eab7efc160da..fa04f8eb5e796 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution._ -import org.apache.spark.annotation.{Since, DeveloperApi} -import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.util.random.{Pseudorandom, XORShiftRandom} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f8cea7ecea6bf..92bc66949ae80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -17,15 +17,15 @@ package org.apache.spark.mllib.rdd +import scala.reflect.ClassTag +import scala.util.Random + import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag -import scala.util.Random - private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, val generator: RandomDataGenerator[T], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index ead8db6344998..adb5e51947f6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.rdd import scala.collection.mutable import scala.reflect.ClassTag -import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD private[mllib] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8f657bfb9c730..e60edc675c83f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.regression +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index c284ad2325374..45540f0c5c4ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index a9aba173fa0e3..d55e5dfdaaf53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 4996ace5df85d..7da82c862a2b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 201333c3690df..98404be2603c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index bcb33a7a04677..f3159f7e724cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} -import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 0724af93088c2..052b5b1d65b07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} +import breeze.linalg.{diag, eigSym, max, DenseMatrix => DBM, DenseVector => DBV, Vector => BV} import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLUtils /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 23c8d7c7c8075..f22f2df320f0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.stat.test +import scala.collection.mutable + import breeze.linalg.{DenseMatrix => BDM} import org.apache.commons.math3.distribution.ChiSquaredDistribution -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import scala.collection.mutable - /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index d2513a9d5c5bb..0b118a76733fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,7 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} +import org.apache.spark.mllib.tree.loss.{LogLoss, Loss, SquaredError} /** * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 372d6617a4014..6c04403f1ad75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -21,9 +21,9 @@ import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since -import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} /** * Stores all the configuration options for tree construction diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 1c611976a9308..fbbec1197404a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.rdd.RDD import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.mllib.tree.model.{Bin, Node, Split} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ea6e5aa5d94e7..66f0908c1250f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging -import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.FeatureType._ /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b85a66c05a81d..783a4acb55ce9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 33477ee20ebbd..68835bc79677f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 906bd30563bd0..8af6750da4ff3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4c9151f0cb4fb..89186de96988f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import org.apache.spark.annotation.Since import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliCellSampler -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.BernoulliCellSampler /** * Helper methods to load, save and pre-process data used in ML Lib. diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 6925e18737b75..16d18b3328ff7 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -219,8 +219,8 @@ This file is divided into 3 sections: java,scala,3rdParty,spark - javax?\..+ - scala\..+ + javax?\..* + scala\..* (?!org\.apache\.spark\.).* org\.apache\.spark\..* From ad5b7cfcca7a5feb83b9ed94b6e725c6d789579b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Jan 2016 00:54:25 -0800 Subject: [PATCH 1397/2089] [SPARK-12409][SPARK-12387][SPARK-12391][SQL] Refactor filter pushdown for JDBCRDD and add few filters This patch refactors the filter pushdown for JDBCRDD and also adds few filters. Added filters are basically from #10468 with some refactoring. Test cases are from #10468. Author: Liang-Chi Hsieh Closes #10470 from viirya/refactor-jdbc-filter. --- .../execution/datasources/jdbc/JDBCRDD.scala | 69 +++++++++++-------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 7 +- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 7072ee4b4e3bc..c74574d280a3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -179,7 +179,7 @@ private[sql] object JDBCRDD extends Logging { case stringValue: String => s"'${escapeSql(stringValue)}'" case timestampValue: Timestamp => "'" + timestampValue + "'" case dateValue: Date => "'" + dateValue + "'" - case arrayValue: Array[Object] => arrayValue.map(compileValue).mkString(", ") + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") case _ => value } @@ -188,24 +188,41 @@ private[sql] object JDBCRDD extends Logging { /** * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. + * Returns None for an unhandled filter. */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case Not(f) => s"(NOT (${compileFilter(f)}))" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" - case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" - case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" - case IsNull(attr) => s"$attr IS NULL" - case IsNotNull(attr) => s"$attr IS NOT NULL" - case In(attr, value) => s"$attr IN (${compileValue(value)})" - case Or(f1, f2) => s"(${compileFilter(f1)}) OR (${compileFilter(f2)})" - case And(f1, f2) => s"(${compileFilter(f1)}) AND (${compileFilter(f2)})" - case _ => null + private def compileFilter(f: Filter): Option[String] = { + Option(f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" + case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" + case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" + case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null) + case Or(f1, f2) => + // We can't compile Or filter unless both sub-filters are compiled successfully. + // It applies too for the following And filter. + // If we can make sure compileFilter supports all filters, we can remove this check. + val or = Seq(f1, f2).map(compileFilter(_)).flatten + if (or.size == 2) { + or.map(p => s"($p)").mkString(" OR ") + } else { + null + } + case And(f1, f2) => + val and = Seq(f1, f2).map(compileFilter(_)).flatten + if (and.size == 2) { + and.map(p => s"($p)").mkString(" AND ") + } else { + null + } + case _ => null + }) } /** @@ -307,25 +324,21 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { - val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } + private val filterWhereClause: String = + filters.map(JDBCRDD.compileFilter).flatten.mkString(" AND ") /** * A WHERE clause representing both `filters`, if any, and the current partition. */ private def getWhereClause(part: JDBCPartition): String = { if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause + "WHERE " + filterWhereClause + " AND " + part.whereClause } else if (part.whereClause != null) { "WHERE " + part.whereClause + } else if (filterWhereClause.length > 0) { + "WHERE " + filterWhereClause } else { - filterWhereClause + "" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 00e37f107a88b..633ae215fc123 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -190,7 +190,7 @@ class JDBCSuite extends SparkFunSuite assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) - .collect().size === 2) + .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) .collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " @@ -453,8 +453,8 @@ class JDBCSuite extends SparkFunSuite } test("compile filters") { - val compileFilter = PrivateMethod[String]('compileFilter) - def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) + val compileFilter = PrivateMethod[Option[String]]('compileFilter) + def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("") assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))") assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) @@ -473,6 +473,7 @@ class JDBCSuite extends SparkFunSuite === "(NOT (col1 IN ('mno', 'pqr')))") assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") + assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) === "") } test("Dialect unregister") { From 01a29866b1da23157fec41e4b037b7dbe4ffda16 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 1 Jan 2016 13:24:09 -0800 Subject: [PATCH 1398/2089] [SPARK-12592][SQL][TEST] Don't mute Spark loggers in TestHive.reset() There's a hack done in `TestHive.reset()`, which intended to mute noisy Hive loggers. However, Spark testing loggers are also muted. Author: Cheng Lian Closes #10540 from liancheng/spark-12592.dont-mute-spark-loggers. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 97792549bb7ab..013fbab0a812b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -410,7 +410,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { try { // HACK: Hive is too noisy by default. org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => - log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + val logger = log.asInstanceOf[org.apache.log4j.Logger] + if (!logger.getName.contains("org.apache.spark")) { + logger.setLevel(org.apache.log4j.Level.WARN) + } } cacheManager.clearCache() From 6c20b3c0871609cbbb1de8e12edd3cad318fc14e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 Jan 2016 13:31:25 -0800 Subject: [PATCH 1399/2089] Disable test-dependencies.sh. --- dev/run-tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 706e2d141c27f..23278d298c22d 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -418,8 +418,9 @@ def run_python_tests(test_modules, parallelism): def run_build_tests(): - set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") - run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + # set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") + # run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + pass def run_sparkr_tests(): From 0da7bd50ddf0fb9e0e8aeadb9c7fb3edf6f0ee6e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Jan 2016 13:39:20 -0800 Subject: [PATCH 1400/2089] [SPARK-12286][SPARK-12290][SPARK-12294][SPARK-12284][SQL] always output UnsafeRow It's confusing that some operator output UnsafeRow but some not, easy to make mistake. This PR change to only output UnsafeRow for all the operators (SparkPlan), removed the rule to insert Unsafe/Safe conversions. For those that can't output UnsafeRow directly, added UnsafeProjection into them. Closes #10330 cc JoshRosen rxin Author: Davies Liu Closes #10511 from davies/unsafe_row. --- .../org/apache/spark/sql/SQLContext.scala | 3 +- .../apache/spark/sql/execution/Exchange.scala | 25 +-- .../spark/sql/execution/ExistingRDD.scala | 15 +- .../apache/spark/sql/execution/Expand.scala | 13 +- .../apache/spark/sql/execution/Generate.scala | 8 +- .../spark/sql/execution/LocalTableScan.scala | 13 +- .../org/apache/spark/sql/execution/Sort.scala | 4 - .../spark/sql/execution/SparkPlan.scala | 23 --- .../apache/spark/sql/execution/Window.scala | 8 +- .../aggregate/SortBasedAggregate.scala | 4 - .../SortBasedAggregationIterator.scala | 8 +- .../aggregate/TungstenAggregate.scala | 4 - .../spark/sql/execution/basicOperators.scala | 58 +------ .../columnar/InMemoryColumnarTableScan.scala | 8 +- .../joins/BroadcastNestedLoopJoin.scala | 7 - .../execution/joins/CartesianProduct.scala | 4 - .../spark/sql/execution/joins/HashJoin.scala | 4 - .../sql/execution/joins/HashOuterJoin.scala | 4 - .../sql/execution/joins/HashSemiJoin.scala | 4 - .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 - .../sql/execution/joins/SortMergeJoin.scala | 4 - .../execution/joins/SortMergeOuterJoin.scala | 4 - .../spark/sql/execution/local/LocalNode.scala | 12 -- .../execution/local/NestedLoopJoinNode.scala | 6 +- .../apache/spark/sql/execution/python.scala | 7 +- .../sql/execution/rowFormatConverters.scala | 108 ------------ .../spark/sql/execution/ExchangeSuite.scala | 2 +- .../spark/sql/execution/ExpandSuite.scala | 54 ------ .../execution/RowFormatConvertersSuite.scala | 164 ------------------ .../spark/sql/execution/SortSuite.scala | 2 +- .../sql/hive/execution/HiveTableScan.scala | 16 +- .../hive/execution/InsertIntoHiveTable.scala | 15 +- .../hive/execution/ScriptTransformation.scala | 3 +- .../sql/sources/hadoopFsRelationSuites.scala | 31 ---- 34 files changed, 74 insertions(+), 574 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index eadf5cba6d9bb..022303239f2af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -904,8 +904,7 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Add row converters", Once, EnsureRowFormats) + Batch("Add exchange", Once, EnsureRequirements(self)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 62cbc518e02af..7b4161930b7d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -50,26 +49,14 @@ case class Exchange( case None => "" } - val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" + val simpleNodeName = "Exchange" s"$simpleNodeName$extraInfo" } - /** - * Returns true iff we can support the data type, and we are not doing range partitioning. - */ - private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] - override def outputPartitioning: Partitioning = newPartitioning override def output: Seq[Attribute] = child.output - // This setting is somewhat counterintuitive: - // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, - // so the planner inserts a converter to convert data into UnsafeRow if needed. - override def outputsUnsafeRows: Boolean = tungstenMode - override def canProcessSafeRows: Boolean = !tungstenMode - override def canProcessUnsafeRows: Boolean = tungstenMode - /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -130,15 +117,7 @@ case class Exchange( } } - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - - private val serializer: Serializer = { - if (tungstenMode) { - new UnsafeRowSerializer(child.output.size) - } else { - new SparkSqlSerializer(sparkConf) - } - } + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) override protected def doPrepare(): Unit = { // If an ExchangeCoordinator is needed, we register this Exchange operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 5c01af011d306..fc508bfafa1c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, AttributeSet, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -99,10 +99,19 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], override val nodeName: String, override val metadata: Map[String, String] = Map.empty, - override val outputsUnsafeRows: Boolean = false) + isUnsafeRow: Boolean = false) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = rdd + protected override def doExecute(): RDD[InternalRow] = { + if (isUnsafeRow) { + rdd + } else { + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } + } + } override def simpleString: String = { val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 91530bd63798a..c3683cc4e7aac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,20 +41,11 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - private[this] val projection = { - if (outputsUnsafeRows) { - (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) - } else { - (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() - } - } + private[this] val projection = + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 0c613e91b979f..4db88a09d8152 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -64,6 +64,7 @@ case class Generate( child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow + val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -77,13 +78,14 @@ case class Generate( } ++ LazyIterator(() => boundGenerator.terminate()).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - joinedRow.withRight(row) + proj(joinedRow.withRight(row)) } } } else { child.execute().mapPartitionsInternal { iter => - iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate()) + val proj = UnsafeProjection.create(output, output) + (iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate())).map(proj) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index ba7f6287ac6c3..59057bf9666ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} /** @@ -29,15 +29,20 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + private val unsafeRows: Array[InternalRow] = { + val proj = UnsafeProjection.create(output, output) + rows.map(r => proj(r).copy()).toArray + } + + private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[InternalRow] = { - rows.toArray + unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - rows.take(limit).toArray + unsafeRows.take(limit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 24207cb46fd29..73dc8cb984471 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -39,10 +39,6 @@ case class Sort( testSpillFrequency: Int = 0) extends UnaryNode { - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fe9b2ad4a0bc3..f20f32aaced2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,17 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute @@ -115,18 +104,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override doExecute instead. */ final def execute(): RDD[InternalRow] = { - if (children.nonEmpty) { - val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) - val hasSafeInputs = children.exists(!_.outputsUnsafeRows) - assert(!(hasSafeInputs && hasUnsafeInputs), - "Child operators should output rows in the same format") - assert(canProcessSafeRows || canProcessUnsafeRows, - "Operator must be able to process at least one row format") - assert(!hasSafeInputs || canProcessSafeRows, - "Operator will receive safe rows as input but cannot process safe rows") - assert(!hasUnsafeInputs || canProcessUnsafeRows, - "Operator will receive unsafe rows as input but cannot process unsafe rows") - } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() doExecute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index c941d673c7248..b79d93d7ca4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -100,8 +100,6 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def canProcessUnsafeRows: Boolean = true - /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -259,16 +257,16 @@ case class Window( * @return the final resulting projection. */ private[this] def createResultProjection( - expressions: Seq[Expression]): MutableProjection = { + expressions: Seq[Expression]): UnsafeProjection = { val references = expressions.zipWithIndex.map{ case (e, i) => // Results of window expressions will be on the right side of child's output BoundReference(child.output.size + i, e.dataType, e.nullable) } val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - newMutableProjection( + UnsafeProjection.create( projectList ++ patchedWindowExpression, - child.output)() + child.output) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index c4587ba677b2f..01d076678f041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -49,10 +49,6 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def requiredChildDistribution: List[Distribution] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index ac920aa8bc7f7..6501634ff998b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,6 +87,10 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be + // compared to MutableRow (aggregation buffer) directly. + private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) + protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -110,7 +114,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -122,7 +126,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, currentRow) + processRow(sortBasedAggregationBuffer, safeProj(currentRow)) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 9d758eb3b7c32..999ebb768af50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -49,10 +49,6 @@ case class TungstenAggregate( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def producedAttributes: AttributeSet = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f19d72f067218..af7237ef25886 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -36,10 +36,6 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = { @@ -80,12 +76,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - - override def canProcessUnsafeRows: Boolean = true - - override def canProcessSafeRows: Boolean = true } /** @@ -108,10 +98,6 @@ case class Sample( { override def output: Seq[Attribute] = child.output - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, @@ -135,8 +121,6 @@ case class Range( output: Seq[Attribute]) extends LeafNode { - override def outputsUnsafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { sqlContext .sparkContext @@ -199,9 +183,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -268,12 +249,14 @@ case class TakeOrderedAndProject( // and this ordering needs to be created on the driver in order to be passed into Spark core code. private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. - @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) - private def collectData(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - projection.map(data.map(_)).getOrElse(data) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } } override def executeCollect(): Array[InternalRow] = { @@ -311,10 +294,6 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { child.execute().coalesce(numPartitions, shuffle = false) } - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -327,10 +306,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -343,10 +318,6 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -371,10 +342,6 @@ case class MapPartitions[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) @@ -394,11 +361,6 @@ case class AppendColumns[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = AttributeSet(newColumns) - // We are using an unsafe combiner. - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { @@ -428,10 +390,6 @@ case class MapGroups[K, T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -472,10 +430,6 @@ case class CoGroup[Key, Left, Right, Result]( right: SparkPlan) extends BinaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index aa7a668e0e938..d80912309bab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -39,9 +39,7 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, - if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), - tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } /** @@ -226,8 +224,6 @@ private[sql] case class InMemoryColumnarTableScan( // The cached version does not change the outputOrdering of the original SparkPlan. override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering - override def outputsUnsafeRows: Boolean = true - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index aab177b2e8427..54275c2cc1134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -46,15 +46,8 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } - override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 81bfe4e67ca73..d9fa4c6b83798 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -81,10 +81,6 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb961d97c3c3c..7f9d9daa5ab20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,10 +44,6 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildSideKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index c6e5868187518..6d464d6946b78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,10 +64,6 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index f23a1830e91c1..3e0f74cd98c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,10 +33,6 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def leftKeyGenerator: Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index efa7b49410edc..82498ee395649 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -42,9 +42,6 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4bf7b521c77d3..812f881d06fb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,10 +53,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 7ce38ebdb3413..c3a2bfc59c7a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -89,10 +89,6 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def createLeftKeyGenerator(): Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 6a882c9234df4..e46217050bad5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,18 +69,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin */ def close(): Unit - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true - /** * Returns the content through the [[Iterator]] interface. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index 7321fc66b4dde..b7fa0c0202221 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -47,11 +47,7 @@ case class NestedLoopJoinNode( } private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } + UnsafeProjection.create(schema) } private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index defcec95fb555..efb4b09c16348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -351,10 +351,6 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) @@ -400,13 +396,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val unpickle = new Unpickler val row = new GenericMutableRow(1) val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - joined(queue.poll(), row) + resultProj(joined(queue.poll(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala deleted file mode 100644 index 5f8fc2de8b46d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Converts Java-object-based rows into [[UnsafeRow]]s. - */ -case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToUnsafe = UnsafeProjection.create(child.schema) - iter.map(convertToUnsafe) - } - } -} - -/** - * Converts [[UnsafeRow]]s back into Java-object-based rows. - */ -case class ConvertToSafe(child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) - iter.map(convertToSafe) - } - } -} - -private[sql] object EnsureRowFormats extends Rule[SparkPlan] { - - private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && !operator.canProcessUnsafeRows - - private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessUnsafeRows && !operator.canProcessSafeRows - - private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && operator.canProcessUnsafeRows - - override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { - case operator: SparkPlan if onlyHandlesSafeRows(operator) => - if (operator.children.exists(_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => - if (operator.children.exists(!_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => - if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows. - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 911d12e93e503..87bff3295f5be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -28,7 +28,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + plan => Exchange(SinglePartition, plan), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala deleted file mode 100644 index faef76d52ae75..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.IntegerType - -class ExpandSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - private def testExpand(f: SparkPlan => SparkPlan): Unit = { - val input = (1 to 1000).map(Tuple1.apply) - val projections = Seq.tabulate(2) { i => - Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil - } - val attributes = projections.head.map(_.toAttribute) - checkAnswer( - input.toDF(), - plan => Expand(projections, attributes, f(plan)), - input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) - ) - } - - test("inheriting child row type") { - val exprs = AttributeReference("a", IntegerType, false)() :: Nil - val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) - assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") - } - - test("expanding UnsafeRows") { - testExpand(ConvertToUnsafe) - } - - test("expanding SafeRows") { - testExpand(identity) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala deleted file mode 100644 index 2328899bb2f8d..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{ArrayType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { - - private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { - case c: ConvertToUnsafe => c - case c: ConvertToSafe => c - } - - private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(outputsUnsafe.outputsUnsafeRows) - - test("planner should insert unsafe->safe conversions when required") { - val plan = Limit(10, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) - } - - test("filter can process unsafe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("filter can process safe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("coalesce can process unsafe rows") { - val plan = Coalesce(1, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except can process unsafe rows") { - val plan = Except(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except requires all of its input rows' formats to agree") { - val plan = Except(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect can process unsafe rows") { - val plan = Intersect(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect requires all of its input rows' formats to agree") { - val plan = Intersect(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("execute() fails an assertion if inputs rows are of different formats") { - val e = intercept[AssertionError] { - Union(Seq(outputsSafe, outputsUnsafe)).execute() - } - assert(e.getMessage.contains("format")) - } - - test("union requires all of its input rows' formats to agree") { - val plan = Union(Seq(outputsSafe, outputsUnsafe)) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("union can process safe rows") { - val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("union can process unsafe rows") { - val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("round trip with ConvertToUnsafe and ConvertToSafe") { - val input = Seq(("hello", 1), ("world", 2)) - checkAnswer( - sqlContext.createDataFrame(input), - plan => ConvertToSafe(ConvertToUnsafe(plan)), - input.map(Row.fromTuple) - ) - } - - test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SQLContext.setActive(sqlContext) - val schema = ArrayType(StringType) - val rows = (1 to 100).map { i => - InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) - } - val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) - - val plan = - DummyPlan( - ConvertToSafe( - ConvertToUnsafe(relation))) - assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) - } -} - -case class DummyPlan(child: SparkPlan) extends UnaryNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some - // values gotten from the incoming rows. - // we cache all strings here to make sure we have deep copied UTF8String inside incoming - // safe InternalRow. - val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] - iter.foreach { row => - strings += row.getArray(0).getUTF8String(0) - } - strings.map(InternalRow(_)).iterator - } - } - - override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index e5d34be4c65e8..af971dfc6faec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -99,7 +99,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { ) checkThatPlansAgree( inputDf, - p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), ReferenceSort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8141136de5311..1588728bdbaa4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -132,11 +132,17 @@ case class HiveTableScan( } } - protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + protected override def doExecute(): RDD[InternalRow] = { + val rdd = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + } + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f936cf565b2bc..44dc68e6ba47f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,18 +28,17 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, Attribute} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.util.SerializableJobConf +import org.apache.spark.{SparkException, TaskContext} private[hive] case class InsertIntoHiveTable( @@ -101,15 +100,17 @@ case class InsertIntoHiveTable( writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + val proj = FromUnsafeProjection(child.schema) iterator.foreach { row => var i = 0 + val safeRow = proj(row) while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) i += 1 } writerContainer - .getLocalFileWriter(row, table.schema) + .getLocalFileWriter(safeRow, table.schema) .write(serializer.serialize(outputData, standardOI)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a61e162f48f1b..6ccd4178190cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -213,7 +213,8 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => if (iter.hasNext) { - processIterator(iter) + val proj = UnsafeProjection.create(schema) + processIterator(iter).map(proj) } else { // If the input iterator has no rows then do not launch the external script. Iterator.empty diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 665e87e3e3355..efbf9988ddc13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -27,7 +27,6 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -689,36 +688,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } - - test("HadoopFsRelation produces UnsafeRow") { - withTempTable("test_unsafe") { - withTempPath { dir => - val path = dir.getCanonicalPath - sqlContext.range(3).write.format(dataSourceName).save(path) - sqlContext.read - .format(dataSourceName) - .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) - .load(path) - .registerTempTable("test_unsafe") - - val df = sqlContext.sql( - """SELECT COUNT(*) - |FROM test_unsafe a JOIN test_unsafe b - |WHERE a.id = b.id - """.stripMargin) - - val plan = df.queryExecution.executedPlan - - assert( - plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, - s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): - |$plan - """.stripMargin) - - checkAnswer(df, Row(3)) - } - } - } } // This class is used to test SPARK-8578. We should not use any custom output committer when From 44ee920fd49d35b421ae562ea99bcc8f2b98ced6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 Jan 2016 19:23:06 -0800 Subject: [PATCH 1401/2089] Revert "[SPARK-12286][SPARK-12290][SPARK-12294][SPARK-12284][SQL] always output UnsafeRow" This reverts commit 0da7bd50ddf0fb9e0e8aeadb9c7fb3edf6f0ee6e. --- .../org/apache/spark/sql/SQLContext.scala | 3 +- .../apache/spark/sql/execution/Exchange.scala | 25 ++- .../spark/sql/execution/ExistingRDD.scala | 15 +- .../apache/spark/sql/execution/Expand.scala | 13 +- .../apache/spark/sql/execution/Generate.scala | 8 +- .../spark/sql/execution/LocalTableScan.scala | 13 +- .../org/apache/spark/sql/execution/Sort.scala | 4 + .../spark/sql/execution/SparkPlan.scala | 23 +++ .../apache/spark/sql/execution/Window.scala | 8 +- .../aggregate/SortBasedAggregate.scala | 4 + .../SortBasedAggregationIterator.scala | 8 +- .../aggregate/TungstenAggregate.scala | 4 + .../spark/sql/execution/basicOperators.scala | 58 ++++++- .../columnar/InMemoryColumnarTableScan.scala | 8 +- .../joins/BroadcastNestedLoopJoin.scala | 7 + .../execution/joins/CartesianProduct.scala | 4 + .../spark/sql/execution/joins/HashJoin.scala | 4 + .../sql/execution/joins/HashOuterJoin.scala | 4 + .../sql/execution/joins/HashSemiJoin.scala | 4 + .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 + .../sql/execution/joins/SortMergeJoin.scala | 4 + .../execution/joins/SortMergeOuterJoin.scala | 4 + .../spark/sql/execution/local/LocalNode.scala | 12 ++ .../execution/local/NestedLoopJoinNode.scala | 6 +- .../apache/spark/sql/execution/python.scala | 7 +- .../sql/execution/rowFormatConverters.scala | 108 ++++++++++++ .../spark/sql/execution/ExchangeSuite.scala | 2 +- .../spark/sql/execution/ExpandSuite.scala | 54 ++++++ .../execution/RowFormatConvertersSuite.scala | 164 ++++++++++++++++++ .../spark/sql/execution/SortSuite.scala | 2 +- .../sql/hive/execution/HiveTableScan.scala | 16 +- .../hive/execution/InsertIntoHiveTable.scala | 15 +- .../hive/execution/ScriptTransformation.scala | 3 +- .../sql/sources/hadoopFsRelationSuites.scala | 31 ++++ 34 files changed, 574 insertions(+), 74 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 022303239f2af..eadf5cba6d9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -904,7 +904,8 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)) + Batch("Add exchange", Once, EnsureRequirements(self)), + Batch("Add row converters", Once, EnsureRowFormats) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7b4161930b7d2..62cbc518e02af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -49,14 +50,26 @@ case class Exchange( case None => "" } - val simpleNodeName = "Exchange" + val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" s"$simpleNodeName$extraInfo" } + /** + * Returns true iff we can support the data type, and we are not doing range partitioning. + */ + private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] + override def outputPartitioning: Partitioning = newPartitioning override def output: Seq[Attribute] = child.output + // This setting is somewhat counterintuitive: + // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, + // so the planner inserts a converter to convert data into UnsafeRow if needed. + override def outputsUnsafeRows: Boolean = tungstenMode + override def canProcessSafeRows: Boolean = !tungstenMode + override def canProcessUnsafeRows: Boolean = tungstenMode + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -117,7 +130,15 @@ case class Exchange( } } - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf + + private val serializer: Serializer = { + if (tungstenMode) { + new UnsafeRowSerializer(child.output.size) + } else { + new SparkSqlSerializer(sparkConf) + } + } override protected def doPrepare(): Unit = { // If an ExchangeCoordinator is needed, we register this Exchange operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index fc508bfafa1c0..5c01af011d306 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, AttributeSet, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -99,19 +99,10 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], override val nodeName: String, override val metadata: Map[String, String] = Map.empty, - isUnsafeRow: Boolean = false) + override val outputsUnsafeRows: Boolean = false) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = { - if (isUnsafeRow) { - rdd - } else { - rdd.mapPartitionsInternal { iter => - val proj = UnsafeProjection.create(schema) - iter.map(proj) - } - } - } + protected override def doExecute(): RDD[InternalRow] = rdd override def simpleString: String = { val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index c3683cc4e7aac..91530bd63798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,11 +41,20 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - private[this] val projection = - (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + private[this] val projection = { + if (outputsUnsafeRows) { + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + } else { + (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() + } + } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 4db88a09d8152..0c613e91b979f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -64,7 +64,6 @@ case class Generate( child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow - val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -78,14 +77,13 @@ case class Generate( } ++ LazyIterator(() => boundGenerator.terminate()).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - proj(joinedRow.withRight(row)) + joinedRow.withRight(row) } } } else { child.execute().mapPartitionsInternal { iter => - val proj = UnsafeProjection.create(output, output) - (iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate())).map(proj) + iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate()) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 59057bf9666ef..ba7f6287ac6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.Attribute /** @@ -29,20 +29,15 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - private val unsafeRows: Array[InternalRow] = { - val proj = UnsafeProjection.create(output, output) - rows.map(r => proj(r).copy()).toArray - } - - private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) + private lazy val rdd = sqlContext.sparkContext.parallelize(rows) protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[InternalRow] = { - unsafeRows + rows.toArray } override def executeTake(limit: Int): Array[InternalRow] = { - unsafeRows.take(limit) + rows.take(limit).toArray } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 73dc8cb984471..24207cb46fd29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -39,6 +39,10 @@ case class Sort( testSpillFrequency: Int = 0) extends UnaryNode { + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index f20f32aaced2e..fe9b2ad4a0bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,6 +97,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute @@ -104,6 +115,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override doExecute instead. */ final def execute(): RDD[InternalRow] = { + if (children.nonEmpty) { + val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) + val hasSafeInputs = children.exists(!_.outputsUnsafeRows) + assert(!(hasSafeInputs && hasUnsafeInputs), + "Child operators should output rows in the same format") + assert(canProcessSafeRows || canProcessUnsafeRows, + "Operator must be able to process at least one row format") + assert(!hasSafeInputs || canProcessSafeRows, + "Operator will receive safe rows as input but cannot process safe rows") + assert(!hasUnsafeInputs || canProcessUnsafeRows, + "Operator will receive unsafe rows as input but cannot process unsafe rows") + } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() doExecute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index b79d93d7ca4c9..c941d673c7248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -100,6 +100,8 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def canProcessUnsafeRows: Boolean = true + /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -257,16 +259,16 @@ case class Window( * @return the final resulting projection. */ private[this] def createResultProjection( - expressions: Seq[Expression]): UnsafeProjection = { + expressions: Seq[Expression]): MutableProjection = { val references = expressions.zipWithIndex.map{ case (e, i) => // Results of window expressions will be on the right side of child's output BoundReference(child.output.size + i, e.dataType, e.nullable) } val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( + newMutableProjection( projectList ++ patchedWindowExpression, - child.output) + child.output)() } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 01d076678f041..c4587ba677b2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -49,6 +49,10 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def requiredChildDistribution: List[Distribution] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 6501634ff998b..ac920aa8bc7f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,10 +87,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer - // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be - // compared to MutableRow (aggregation buffer) directly. - private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -114,7 +110,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -126,7 +122,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, safeProj(currentRow)) + processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 999ebb768af50..9d758eb3b7c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -49,6 +49,10 @@ case class TungstenAggregate( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def producedAttributes: AttributeSet = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index af7237ef25886..f19d72f067218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -36,6 +36,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = { @@ -76,6 +80,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = true } /** @@ -98,6 +108,10 @@ case class Sample( { override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, @@ -121,6 +135,8 @@ case class Range( output: Seq[Attribute]) extends LeafNode { + override def outputsUnsafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { sqlContext .sparkContext @@ -183,6 +199,9 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -249,14 +268,12 @@ case class TakeOrderedAndProject( // and this ordering needs to be created on the driver in order to be passed into Spark core code. private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) + // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. + @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) + private def collectData(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) - data.map(r => proj(r).copy()) - } else { - data - } + projection.map(data.map(_)).getOrElse(data) } override def executeCollect(): Array[InternalRow] = { @@ -294,6 +311,10 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { child.execute().coalesce(numPartitions, shuffle = false) } + + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -306,6 +327,10 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -318,6 +343,10 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -342,6 +371,10 @@ case class MapPartitions[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet + override def canProcessSafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) @@ -361,6 +394,11 @@ case class AppendColumns[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = AttributeSet(newColumns) + // We are using an unsafe combiner. + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { @@ -390,6 +428,10 @@ case class MapGroups[K, T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet + override def canProcessSafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -430,6 +472,10 @@ case class CoGroup[Key, Left, Right, Result]( right: SparkPlan) extends BinaryNode { override def producedAttributes: AttributeSet = outputSet + override def canProcessSafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index d80912309bab9..aa7a668e0e938 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -39,7 +39,9 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, + if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), + tableName)() } /** @@ -224,6 +226,8 @@ private[sql] case class InMemoryColumnarTableScan( // The cached version does not change the outputOrdering of the original SparkPlan. override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + override def outputsUnsafeRows: Boolean = true + private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 54275c2cc1134..aab177b2e8427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -46,8 +46,15 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + private[this] def genResultProjection: InternalRow => InternalRow = { + if (outputsUnsafeRows) { UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index d9fa4c6b83798..81bfe4e67ca73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -81,6 +81,10 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7f9d9daa5ab20..fb961d97c3c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,6 +44,10 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + protected def buildSideKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6d464d6946b78..c6e5868187518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,6 +64,10 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + protected def buildKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 3e0f74cd98c21..f23a1830e91c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,6 +33,10 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + protected def leftKeyGenerator: Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 82498ee395649..efa7b49410edc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -42,6 +42,9 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 812f881d06fb8..4bf7b521c77d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,6 +53,10 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index c3a2bfc59c7a4..7ce38ebdb3413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -89,6 +89,10 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + private def createLeftKeyGenerator(): Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index e46217050bad5..6a882c9234df4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin */ def close(): Unit + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + /** * Returns the content through the [[Iterator]] interface. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index b7fa0c0202221..7321fc66b4dde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -47,7 +47,11 @@ case class NestedLoopJoinNode( } private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } } private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index efb4b09c16348..defcec95fb555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -351,6 +351,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) @@ -396,14 +400,13 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val unpickle = new Unpickler val row = new GenericMutableRow(1) val joined = new JoinedRow - val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - resultProj(joined(queue.poll(), row)) + joined(queue.poll(), row) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala new file mode 100644 index 0000000000000..5f8fc2de8b46d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Converts Java-object-based rows into [[UnsafeRow]]s. + */ +case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val convertToUnsafe = UnsafeProjection.create(child.schema) + iter.map(convertToUnsafe) + } + } +} + +/** + * Converts [[UnsafeRow]]s back into Java-object-based rows. + */ +case class ConvertToSafe(child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) + iter.map(convertToSafe) + } + } +} + +private[sql] object EnsureRowFormats extends Rule[SparkPlan] { + + private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = + operator.canProcessSafeRows && !operator.canProcessUnsafeRows + + private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = + operator.canProcessUnsafeRows && !operator.canProcessSafeRows + + private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = + operator.canProcessSafeRows && operator.canProcessUnsafeRows + + override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { + case operator: SparkPlan if onlyHandlesSafeRows(operator) => + if (operator.children.exists(_.outputsUnsafeRows)) { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } + } + } else { + operator + } + case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => + if (operator.children.exists(!_.outputsUnsafeRows)) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator + } + case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => + if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { + // If this operator's children produce both unsafe and safe rows, + // convert everything unsafe rows. + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 87bff3295f5be..911d12e93e503 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -28,7 +28,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => Exchange(SinglePartition, plan), + plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala new file mode 100644 index 0000000000000..faef76d52ae75 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType + +class ExpandSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private def testExpand(f: SparkPlan => SparkPlan): Unit = { + val input = (1 to 1000).map(Tuple1.apply) + val projections = Seq.tabulate(2) { i => + Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil + } + val attributes = projections.head.map(_.toAttribute) + checkAnswer( + input.toDF(), + plan => Expand(projections, attributes, f(plan)), + input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) + ) + } + + test("inheriting child row type") { + val exprs = AttributeReference("a", IntegerType, false)() :: Nil + val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) + assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") + } + + test("expanding UnsafeRows") { + testExpand(ConvertToUnsafe) + } + + test("expanding SafeRows") { + testExpand(identity) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala new file mode 100644 index 0000000000000..2328899bb2f8d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{ArrayType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { + + private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { + case c: ConvertToUnsafe => c + case c: ConvertToSafe => c + } + + private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + assert(!outputsSafe.outputsUnsafeRows) + private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + assert(outputsUnsafe.outputsUnsafeRows) + + test("planner should insert unsafe->safe conversions when required") { + val plan = Limit(10, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) + } + + test("filter can process unsafe rows") { + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("filter can process safe rows") { + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("coalesce can process unsafe rows") { + val plan = Coalesce(1, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except can process unsafe rows") { + val plan = Except(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except requires all of its input rows' formats to agree") { + val plan = Except(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect can process unsafe rows") { + val plan = Intersect(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect requires all of its input rows' formats to agree") { + val plan = Intersect(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("execute() fails an assertion if inputs rows are of different formats") { + val e = intercept[AssertionError] { + Union(Seq(outputsSafe, outputsUnsafe)).execute() + } + assert(e.getMessage.contains("format")) + } + + test("union requires all of its input rows' formats to agree") { + val plan = Union(Seq(outputsSafe, outputsUnsafe)) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("union can process safe rows") { + val plan = Union(Seq(outputsSafe, outputsSafe)) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("union can process unsafe rows") { + val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("round trip with ConvertToUnsafe and ConvertToSafe") { + val input = Seq(("hello", 1), ("world", 2)) + checkAnswer( + sqlContext.createDataFrame(input), + plan => ConvertToSafe(ConvertToUnsafe(plan)), + input.map(Row.fromTuple) + ) + } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SQLContext.setActive(sqlContext) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index af971dfc6faec..e5d34be4c65e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -99,7 +99,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { ) checkThatPlansAgree( inputDf, - p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), + p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), ReferenceSort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 1588728bdbaa4..8141136de5311 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -132,17 +132,11 @@ case class HiveTableScan( } } - protected override def doExecute(): RDD[InternalRow] = { - val rdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) - } - rdd.mapPartitionsInternal { iter => - val proj = UnsafeProjection.create(schema) - iter.map(proj) - } + protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 44dc68e6ba47f..f936cf565b2bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,17 +28,18 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, Attribute} -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType -import org.apache.spark.util.SerializableJobConf import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.util.SerializableJobConf private[hive] case class InsertIntoHiveTable( @@ -100,17 +101,15 @@ case class InsertIntoHiveTable( writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - val proj = FromUnsafeProjection(child.schema) iterator.foreach { row => var i = 0 - val safeRow = proj(row) while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) i += 1 } writerContainer - .getLocalFileWriter(safeRow, table.schema) + .getLocalFileWriter(row, table.schema) .write(serializer.serialize(outputData, standardOI)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 6ccd4178190cd..a61e162f48f1b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -213,8 +213,7 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => if (iter.hasNext) { - val proj = UnsafeProjection.create(schema) - processIterator(iter).map(proj) + processIterator(iter) } else { // If the input iterator has no rows then do not launch the external script. Iterator.empty diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index efbf9988ddc13..665e87e3e3355 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -27,6 +27,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -688,6 +689,36 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } + + test("HadoopFsRelation produces UnsafeRow") { + withTempTable("test_unsafe") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.format(dataSourceName).save(path) + sqlContext.read + .format(dataSourceName) + .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) + .load(path) + .registerTempTable("test_unsafe") + + val df = sqlContext.sql( + """SELECT COUNT(*) + |FROM test_unsafe a JOIN test_unsafe b + |WHERE a.id = b.id + """.stripMargin) + + val plan = df.queryExecution.executedPlan + + assert( + plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, + s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): + |$plan + """.stripMargin) + + checkAnswer(df, Row(3)) + } + } + } } // This class is used to test SPARK-8578. We should not use any custom output committer when From 970635a9f8d7f64708a9fcae0b231c570f3f2c51 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 1 Jan 2016 23:22:50 -0800 Subject: [PATCH 1402/2089] [SPARK-12362][SQL][WIP] Inline Hive Parser This PR inlines the Hive SQL parser in Spark SQL. The previous (merged) incarnation of this PR passed all tests, but had and still has problems with the build. These problems are caused by a the fact that - for some reason - in some cases the ANTLR generated code is not included in the compilation fase. This PR is a WIP and should not be merged until we have sorted out the build issues. Author: Herman van Hovell Author: Nong Li Author: Nong Li Closes #10525 from hvanhovell/SPARK-12362. --- pom.xml | 5 + project/SparkBuild.scala | 46 +- project/plugins.sbt | 2 + .../execution/HiveCompatibilitySuite.scala | 10 +- sql/hive/pom.xml | 22 + .../spark/sql/parser/FromClauseParser.g | 330 +++ .../spark/sql/parser/IdentifiersParser.g | 697 +++++ .../spark/sql/parser/SelectClauseParser.g | 226 ++ .../apache/spark/sql/parser/SparkSqlLexer.g | 474 ++++ .../apache/spark/sql/parser/SparkSqlParser.g | 2457 +++++++++++++++++ .../apache/spark/sql/parser/ASTErrorNode.java | 49 + .../org/apache/spark/sql/parser/ASTNode.java | 245 ++ .../apache/spark/sql/parser/ParseDriver.java | 213 ++ .../apache/spark/sql/parser/ParseError.java | 54 + .../spark/sql/parser/ParseException.java | 51 + .../apache/spark/sql/parser/ParseUtils.java | 96 + .../spark/sql/parser/SemanticAnalyzer.java | 406 +++ .../org/apache/spark/sql/hive/HiveQl.scala | 133 +- 18 files changed, 5443 insertions(+), 73 deletions(-) create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g create mode 100644 sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java diff --git a/pom.xml b/pom.xml index 62ea829b1dbfd..398fcc92db994 100644 --- a/pom.xml +++ b/pom.xml @@ -1951,6 +1951,11 @@ + + org.antlr + antlr3-maven-plugin + 3.5.2 + org.apache.maven.plugins diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c3d53f835f395..588e97f64e054 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -414,9 +414,51 @@ object Hive { // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new // new query tests. - fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } - ) + fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") }, + // ANTLR code-generation step. + // + // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of + // build errors in the current plugin. + // Create Parser from ANTLR grammar files. + sourceGenerators in Compile += Def.task { + val log = streams.value.log + + val grammarFileNames = Seq( + "SparkSqlLexer.g", + "SparkSqlParser.g") + val sourceDir = (sourceDirectory in Compile).value / "antlr3" + val targetDir = (sourceManaged in Compile).value + + // Create default ANTLR Tool. + val antlr = new org.antlr.Tool + + // Setup input and output directories. + antlr.setInputDirectory(sourceDir.getPath) + antlr.setOutputDirectory(targetDir.getPath) + antlr.setForceRelativeOutput(true) + antlr.setMake(true) + + // Add grammar files. + grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => + val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath + log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) + antlr.addGrammarFile(relGFilePath) + } + // Generate the parser. + antlr.process + if (antlr.getNumErrors > 0) { + log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) + } + + // Return all generated java files. + (targetDir ** "*.java").get.toSeq + }.taskValue, + // Include ANTLR tokens files. + resourceGenerators in Compile += Def.task { + ((sourceManaged in Compile).value ** "*.tokens").get.toSeq + }.taskValue + ) } object Assembly { diff --git a/project/plugins.sbt b/project/plugins.sbt index 5e23224cf8aa5..15ba3a36d51ca 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -27,3 +27,5 @@ addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" + +libraryDependencies += "org.antlr" % "antlr" % "3.5.2" diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2d0d7b8af3581..2b0e48dbfcf28 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -308,7 +308,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The difference between the double numbers generated by Hive and Spark // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) - "udaf_corr" + "udaf_corr", + + // Feature removed in HIVE-11145 + "alter_partition_protect_mode", + "drop_partitions_ignore_protection", + "protectmode" ) /** @@ -328,7 +333,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_index", "alter_merge_2", "alter_partition_format_loc", - "alter_partition_protect_mode", "alter_partition_with_whitelist", "alter_rename_partition", "alter_table_serde", @@ -460,7 +464,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_filter", "drop_partitions_filter2", "drop_partitions_filter3", - "drop_partitions_ignore_protection", "drop_table", "drop_table2", "drop_table_removes_partition_dirs", @@ -778,7 +781,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", - "protectmode", "push_or", "query_with_semi", "quote1", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e9885f6682028..ffabb92179a18 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -232,6 +232,7 @@ v${hive.version.short}/src/main/scala + ${project.build.directory/generated-sources/antlr @@ -260,6 +261,27 @@ + + + + org.antlr + antlr3-maven-plugin + + + + antlr + + + + + ${basedir}/src/main/antlr3 + + **/SparkSqlLexer.g + **/SparkSqlParser.g + + + + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g new file mode 100644 index 0000000000000..e4a80f0ce8ebf --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g @@ -0,0 +1,330 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar FromClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +tableAllColumns + : STAR + -> ^(TOK_ALLCOLREF) + | tableName DOT STAR + -> ^(TOK_ALLCOLREF tableName) + ; + +// (table|column) +tableOrColumn +@init { gParent.pushMsg("table or column identifier", state); } +@after { gParent.popMsg(state); } + : + identifier -> ^(TOK_TABLE_OR_COL identifier) + ; + +expressionList +@init { gParent.pushMsg("expression list", state); } +@after { gParent.popMsg(state); } + : + expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) + ; + +aliasList +@init { gParent.pushMsg("alias list", state); } +@after { gParent.popMsg(state); } + : + identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) + ; + +//----------------------- Rules for parsing fromClause ------------------------------ +// from [col1, col2, col3] table1, [col4, col5] table2 +fromClause +@init { gParent.pushMsg("from clause", state); } +@after { gParent.popMsg(state); } + : + KW_FROM joinSource -> ^(TOK_FROM joinSource) + ; + +joinSource +@init { gParent.pushMsg("join source", state); } +@after { gParent.popMsg(state); } + : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )* + | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ + ; + +uniqueJoinSource +@init { gParent.pushMsg("unique join source", state); } +@after { gParent.popMsg(state); } + : KW_PRESERVE? fromSource uniqueJoinExpr + ; + +uniqueJoinExpr +@init { gParent.pushMsg("unique join expression list", state); } +@after { gParent.popMsg(state); } + : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN + -> ^(TOK_EXPLIST $e1*) + ; + +uniqueJoinToken +@init { gParent.pushMsg("unique join", state); } +@after { gParent.popMsg(state); } + : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; + +joinToken +@init { gParent.pushMsg("join type specifier", state); } +@after { gParent.popMsg(state); } + : + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + ; + +lateralView +@init {gParent.pushMsg("lateral view", state); } +@after {gParent.popMsg(state); } + : + (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + | + KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + ; + +tableAlias +@init {gParent.pushMsg("table alias", state); } +@after {gParent.popMsg(state); } + : + identifier -> ^(TOK_TABALIAS identifier) + ; + +fromSource +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + (LPAREN KW_VALUES) => fromSource0 + | (LPAREN) => LPAREN joinSource RPAREN -> joinSource + | fromSource0 + ; + + +fromSource0 +@init { gParent.pushMsg("from source 0", state); } +@after { gParent.popMsg(state); } + : + ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* + ; + +tableBucketSample +@init { gParent.pushMsg("table bucket sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) + ; + +splitSample +@init { gParent.pushMsg("table split sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN + -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) + -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) + | + KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN + -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) + ; + +tableSample +@init { gParent.pushMsg("table sample specification", state); } +@after { gParent.popMsg(state); } + : + tableBucketSample | + splitSample + ; + +tableSource +@init { gParent.pushMsg("table source", state); } +@after { gParent.popMsg(state); } + : tabname=tableName + ((tableProperties) => props=tableProperties)? + ((tableSample) => ts=tableSample)? + ((KW_AS) => (KW_AS alias=Identifier) + | + (Identifier) => (alias=Identifier))? + -> ^(TOK_TABREF $tabname $props? $ts? $alias?) + ; + +tableName +@init { gParent.pushMsg("table name", state); } +@after { gParent.popMsg(state); } + : + db=identifier DOT tab=identifier + -> ^(TOK_TABNAME $db $tab) + | + tab=identifier + -> ^(TOK_TABNAME $tab) + ; + +viewName +@init { gParent.pushMsg("view name", state); } +@after { gParent.popMsg(state); } + : + (db=identifier DOT)? view=identifier + -> ^(TOK_TABNAME $db? $view) + ; + +subQuerySource +@init { gParent.pushMsg("subquery source", state); } +@after { gParent.popMsg(state); } + : + LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) + ; + +//---------------------- Rules for parsing PTF clauses ----------------------------- +partitioningSpec +@init { gParent.pushMsg("partitioningSpec clause", state); } +@after { gParent.popMsg(state); } + : + partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | + orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | + distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) | + sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | + clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) + ; + +partitionTableFunctionSource +@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } +@after { gParent.popMsg(state); } + : + subQuerySource | + tableSource | + partitionedTableFunction + ; + +partitionedTableFunction +@init { gParent.pushMsg("ptf clause", state); } +@after { gParent.popMsg(state); } + : + name=Identifier LPAREN KW_ON + ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) + ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? + ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? + -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) + ; + +//----------------------- Rules for parsing whereClause ----------------------------- +// where a=b and ... +whereClause +@init { gParent.pushMsg("where clause", state); } +@after { gParent.popMsg(state); } + : + KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) + ; + +searchCondition +@init { gParent.pushMsg("search condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +//----------------------------------------------------------------------------------- + +//-------- Row Constructor ---------------------------------------------------------- +//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and +// INSERT INTO
    spark.broadcast.factoryorg.apache.spark.broadcast.
    TorrentBroadcastFactory
    - Which broadcast implementation to use. -
    spark.cleaner.ttl (infinite)
    spark.broadcast.port(random) - Port for the driver's HTTP broadcast server to listen on. - This is not relevant for torrent broadcast. -
    spark.driver.host (local hostname)
    spark.fileserver.port Jetty-based. Only used if Akka RPC backend is configured.
    ExecutorDriver(random)HTTP Broadcastspark.broadcast.portJetty-based. Not used by TorrentBroadcast, which sends data through the block manager - instead.
    Executor / Driver Executor / Driver
    (col1,col2,...) VALUES(...),(...),... +// INSERT INTO
    (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) +valueRowConstructor +@init { gParent.pushMsg("value row constructor", state); } +@after { gParent.popMsg(state); } + : + LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) + ; + +valuesTableConstructor +@init { gParent.pushMsg("values table constructor", state); } +@after { gParent.popMsg(state); } + : + valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) + ; + +/* +VALUES(1),(2) means 2 rows, 1 column each. +VALUES(1,2),(3,4) means 2 rows, 2 columns each. +VALUES(1,2,3) means 1 row, 3 columns +*/ +valuesClause +@init { gParent.pushMsg("values clause", state); } +@after { gParent.popMsg(state); } + : + KW_VALUES valuesTableConstructor -> valuesTableConstructor + ; + +/* +This represents a clause like this: +(VALUES(1,2),(2,3)) as VirtTable(col1,col2) +*/ +virtualTableSource +@init { gParent.pushMsg("virtual table source", state); } +@after { gParent.popMsg(state); } + : + LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) + ; +/* +e.g. as VirtTable(col1,col2) +Note that we only want literals as column names +*/ +tableNameColList +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) + ; + +//----------------------------------------------------------------------------------- diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g new file mode 100644 index 0000000000000..5c3d7ef866240 --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g @@ -0,0 +1,697 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar IdentifiersParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +// group by a,b +groupByClause +@init { gParent.pushMsg("group by clause", state); } +@after { gParent.popMsg(state); } + : + KW_GROUP KW_BY + expression + ( COMMA expression)* + ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? + (sets=KW_GROUPING KW_SETS + LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? + -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) + -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) + -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) + -> ^(TOK_GROUPBY expression+) + ; + +groupingSetExpression +@init {gParent.pushMsg("grouping set expression", state); } +@after {gParent.popMsg(state); } + : + (LPAREN) => groupingSetExpressionMultiple + | + groupingExpressionSingle + ; + +groupingSetExpressionMultiple +@init {gParent.pushMsg("grouping set part expression", state); } +@after {gParent.popMsg(state); } + : + LPAREN + expression? (COMMA expression)* + RPAREN + -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) + ; + +groupingExpressionSingle +@init { gParent.pushMsg("groupingExpression expression", state); } +@after { gParent.popMsg(state); } + : + expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) + ; + +havingClause +@init { gParent.pushMsg("having clause", state); } +@after { gParent.popMsg(state); } + : + KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) + ; + +havingCondition +@init { gParent.pushMsg("having condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +expressionsInParenthese + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +expressionsNotInParenthese + : + expression (COMMA expression)* -> expression+ + ; + +columnRefOrderInParenthese + : + LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ + ; + +columnRefOrderNotInParenthese + : + columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ + ; + +// order by a,b +orderByClause +@init { gParent.pushMsg("order by clause", state); } +@after { gParent.popMsg(state); } + : + KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) + ; + +clusterByClause +@init { gParent.pushMsg("cluster by clause", state); } +@after { gParent.popMsg(state); } + : + KW_CLUSTER KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) + ) + ; + +partitionByClause +@init { gParent.pushMsg("partition by clause", state); } +@after { gParent.popMsg(state); } + : + KW_PARTITION KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +distributeByClause +@init { gParent.pushMsg("distribute by clause", state); } +@after { gParent.popMsg(state); } + : + KW_DISTRIBUTE KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +sortByClause +@init { gParent.pushMsg("sort by clause", state); } +@after { gParent.popMsg(state); } + : + KW_SORT KW_BY + ( + (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) + | + columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) + ) + ; + +// fun(par1, par2, par3) +function +@init { gParent.pushMsg("function specification", state); } +@after { gParent.popMsg(state); } + : + functionName + LPAREN + ( + (STAR) => (star=STAR) + | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? + ) + RPAREN (KW_OVER ws=window_specification)? + -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) + -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) + -> ^(TOK_FUNCTIONDI functionName (selectExpression+)?) + ; + +functionName +@init { gParent.pushMsg("function name", state); } +@after { gParent.popMsg(state); } + : // Keyword IF is also a function name + (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) + | + (functionIdentifier) => functionIdentifier + | + {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] + ; + +castExpression +@init { gParent.pushMsg("cast expression", state); } +@after { gParent.popMsg(state); } + : + KW_CAST + LPAREN + expression + KW_AS + primitiveType + RPAREN -> ^(TOK_FUNCTION primitiveType expression) + ; + +caseExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE expression + (KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_CASE expression*) + ; + +whenExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE + ( KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) + ; + +constant +@init { gParent.pushMsg("constant", state); } +@after { gParent.popMsg(state); } + : + Number + | dateLiteral + | timestampLiteral + | intervalLiteral + | StringLiteral + | stringLiteralSequence + | BigintLiteral + | SmallintLiteral + | TinyintLiteral + | DecimalLiteral + | charSetStringLiteral + | booleanValue + ; + +stringLiteralSequence + : + StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) + ; + +charSetStringLiteral +@init { gParent.pushMsg("character string literal", state); } +@after { gParent.popMsg(state); } + : + csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) + ; + +dateLiteral + : + KW_DATE StringLiteral -> + { + // Create DateLiteral token, but with the text of the string value + // This makes the dateLiteral more consistent with the other type literals. + adaptor.create(TOK_DATELITERAL, $StringLiteral.text) + } + | + KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) + ; + +timestampLiteral + : + KW_TIMESTAMP StringLiteral -> + { + adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) + } + | + KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) + ; + +intervalLiteral + : + KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> + { + adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + } + ; + +intervalQualifiers + : + KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL + | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL + | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL + | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL + | KW_DAY -> TOK_INTERVAL_DAY_LITERAL + | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL + | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL + | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + ; + +expression +@init { gParent.pushMsg("expression specification", state); } +@after { gParent.popMsg(state); } + : + precedenceOrExpression + ; + +atomExpression + : + (KW_NULL) => KW_NULL -> TOK_NULL + | (constant) => constant + | castExpression + | caseExpression + | whenExpression + | (functionName LPAREN) => function + | tableOrColumn + | LPAREN! expression RPAREN! + ; + + +precedenceFieldExpression + : + atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* + ; + +precedenceUnaryOperator + : + PLUS | MINUS | TILDE + ; + +nullCondition + : + KW_NULL -> ^(TOK_ISNULL) + | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) + ; + +precedenceUnaryPrefixExpression + : + (precedenceUnaryOperator^)* precedenceFieldExpression + ; + +precedenceUnarySuffixExpression + : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) + -> precedenceUnaryPrefixExpression + ; + + +precedenceBitwiseXorOperator + : + BITWISEXOR + ; + +precedenceBitwiseXorExpression + : + precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* + ; + + +precedenceStarOperator + : + STAR | DIVIDE | MOD | DIV + ; + +precedenceStarExpression + : + precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* + ; + + +precedencePlusOperator + : + PLUS | MINUS + ; + +precedencePlusExpression + : + precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* + ; + + +precedenceAmpersandOperator + : + AMPERSAND + ; + +precedenceAmpersandExpression + : + precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* + ; + + +precedenceBitwiseOrOperator + : + BITWISEOR + ; + +precedenceBitwiseOrExpression + : + precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* + ; + + +// Equal operators supporting NOT prefix +precedenceEqualNegatableOperator + : + KW_LIKE | KW_RLIKE | KW_REGEXP + ; + +precedenceEqualOperator + : + precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +subQueryExpression + : + LPAREN! selectStatement[true] RPAREN! + ; + +precedenceEqualExpression + : + (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple + | + precedenceEqualExpressionSingle + ; + +precedenceEqualExpressionSingle + : + (left=precedenceBitwiseOrExpression -> $left) + ( + (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) + -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) + | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) + -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) + | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) + -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) + | (KW_NOT KW_IN expressions) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) + | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) + | (KW_IN expressions) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) + | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) + | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) + )* + | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) + ; + +expressions + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) +precedenceEqualExpressionMutiple + : + (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) + ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) + | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) + ; + +expressionsToStruct + : + LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) + ; + +precedenceNotOperator + : + KW_NOT + ; + +precedenceNotExpression + : + (precedenceNotOperator^)* precedenceEqualExpression + ; + + +precedenceAndOperator + : + KW_AND + ; + +precedenceAndExpression + : + precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* + ; + + +precedenceOrOperator + : + KW_OR + ; + +precedenceOrExpression + : + precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* + ; + + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) + ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : db=identifier DOT fn=identifier + -> Identifier[$db.text + "." + $fn.text] + | + identifier + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. +//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. +sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g new file mode 100644 index 0000000000000..48bc8b0a300af --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g @@ -0,0 +1,226 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar SelectClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.errors.add(new ParseError(gParent, e, tokenNames)); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------- Rules for parsing selectClause ----------------------------- +// select a,b,c ... +selectClause +@init { gParent.pushMsg("select clause", state); } +@after { gParent.popMsg(state); } + : + KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) + | (transform=KW_TRANSFORM selectTrfmClause)) + -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) + -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) + -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) ) + | + trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) + ; + +selectList +@init { gParent.pushMsg("select list", state); } +@after { gParent.popMsg(state); } + : + selectItem ( COMMA selectItem )* -> selectItem+ + ; + +selectTrfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + LPAREN selectExpressionList RPAREN + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +hintClause +@init { gParent.pushMsg("hint clause", state); } +@after { gParent.popMsg(state); } + : + DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) + ; + +hintList +@init { gParent.pushMsg("hint list", state); } +@after { gParent.popMsg(state); } + : + hintItem (COMMA hintItem)* -> hintItem+ + ; + +hintItem +@init { gParent.pushMsg("hint item", state); } +@after { gParent.popMsg(state); } + : + hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) + ; + +hintName +@init { gParent.pushMsg("hint name", state); } +@after { gParent.popMsg(state); } + : + KW_MAPJOIN -> TOK_MAPJOIN + | KW_STREAMTABLE -> TOK_STREAMTABLE + ; + +hintArgs +@init { gParent.pushMsg("hint arguments", state); } +@after { gParent.popMsg(state); } + : + hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) + ; + +hintArgName +@init { gParent.pushMsg("hint argument name", state); } +@after { gParent.popMsg(state); } + : + identifier + ; + +selectItem +@init { gParent.pushMsg("selection target", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) + | + ( expression + ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? + ) -> ^(TOK_SELEXPR expression identifier*) + ; + +trfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + ( KW_MAP selectExpressionList + | KW_REDUCE selectExpressionList ) + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +selectExpression +@init { gParent.pushMsg("select expression", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns + | + expression + ; + +selectExpressionList +@init { gParent.pushMsg("select expression list", state); } +@after { gParent.popMsg(state); } + : + selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) + ; + +//---------------------- Rules for windowing clauses ------------------------------- +window_clause +@init { gParent.pushMsg("window_clause", state); } +@after { gParent.popMsg(state); } +: + KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) +; + +window_defn +@init { gParent.pushMsg("window_defn", state); } +@after { gParent.popMsg(state); } +: + Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) +; + +window_specification +@init { gParent.pushMsg("window_specification", state); } +@after { gParent.popMsg(state); } +: + (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) +; + +window_frame : + window_range_expression | + window_value_expression +; + +window_range_expression +@init { gParent.pushMsg("window_range_expression", state); } +@after { gParent.popMsg(state); } +: + KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | + KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) +; + +window_value_expression +@init { gParent.pushMsg("window_value_expression", state); } +@after { gParent.popMsg(state); } +: + KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | + KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) +; + +window_frame_start_boundary +@init { gParent.pushMsg("windowframestartboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number KW_PRECEDING -> ^(KW_PRECEDING Number) +; + +window_frame_boundary +@init { gParent.pushMsg("windowframeboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) +; + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g new file mode 100644 index 0000000000000..ee1b8989b5aff --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g @@ -0,0 +1,474 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +lexer grammar SparkSqlLexer; + +@lexer::header { +package org.apache.spark.sql.parser; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +} + +@lexer::members { + private Configuration hiveConf; + + public void setHiveConf(Configuration hiveConf) { + this.hiveConf = hiveConf; + } + + protected boolean allowQuotedId() { + String supportedQIds = HiveConf.getVar(hiveConf, HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT); + return !"none".equals(supportedQIds); + } +} + +// Keywords + +KW_TRUE : 'TRUE'; +KW_FALSE : 'FALSE'; +KW_ALL : 'ALL'; +KW_NONE: 'NONE'; +KW_AND : 'AND'; +KW_OR : 'OR'; +KW_NOT : 'NOT' | '!'; +KW_LIKE : 'LIKE'; + +KW_IF : 'IF'; +KW_EXISTS : 'EXISTS'; + +KW_ASC : 'ASC'; +KW_DESC : 'DESC'; +KW_ORDER : 'ORDER'; +KW_GROUP : 'GROUP'; +KW_BY : 'BY'; +KW_HAVING : 'HAVING'; +KW_WHERE : 'WHERE'; +KW_FROM : 'FROM'; +KW_AS : 'AS'; +KW_SELECT : 'SELECT'; +KW_DISTINCT : 'DISTINCT'; +KW_INSERT : 'INSERT'; +KW_OVERWRITE : 'OVERWRITE'; +KW_OUTER : 'OUTER'; +KW_UNIQUEJOIN : 'UNIQUEJOIN'; +KW_PRESERVE : 'PRESERVE'; +KW_JOIN : 'JOIN'; +KW_LEFT : 'LEFT'; +KW_RIGHT : 'RIGHT'; +KW_FULL : 'FULL'; +KW_ANTI : 'ANTI'; +KW_ON : 'ON'; +KW_PARTITION : 'PARTITION'; +KW_PARTITIONS : 'PARTITIONS'; +KW_TABLE: 'TABLE'; +KW_TABLES: 'TABLES'; +KW_COLUMNS: 'COLUMNS'; +KW_INDEX: 'INDEX'; +KW_INDEXES: 'INDEXES'; +KW_REBUILD: 'REBUILD'; +KW_FUNCTIONS: 'FUNCTIONS'; +KW_SHOW: 'SHOW'; +KW_MSCK: 'MSCK'; +KW_REPAIR: 'REPAIR'; +KW_DIRECTORY: 'DIRECTORY'; +KW_LOCAL: 'LOCAL'; +KW_TRANSFORM : 'TRANSFORM'; +KW_USING: 'USING'; +KW_CLUSTER: 'CLUSTER'; +KW_DISTRIBUTE: 'DISTRIBUTE'; +KW_SORT: 'SORT'; +KW_UNION: 'UNION'; +KW_LOAD: 'LOAD'; +KW_EXPORT: 'EXPORT'; +KW_IMPORT: 'IMPORT'; +KW_REPLICATION: 'REPLICATION'; +KW_METADATA: 'METADATA'; +KW_DATA: 'DATA'; +KW_INPATH: 'INPATH'; +KW_IS: 'IS'; +KW_NULL: 'NULL'; +KW_CREATE: 'CREATE'; +KW_EXTERNAL: 'EXTERNAL'; +KW_ALTER: 'ALTER'; +KW_CHANGE: 'CHANGE'; +KW_COLUMN: 'COLUMN'; +KW_FIRST: 'FIRST'; +KW_AFTER: 'AFTER'; +KW_DESCRIBE: 'DESCRIBE'; +KW_DROP: 'DROP'; +KW_RENAME: 'RENAME'; +KW_TO: 'TO'; +KW_COMMENT: 'COMMENT'; +KW_BOOLEAN: 'BOOLEAN'; +KW_TINYINT: 'TINYINT'; +KW_SMALLINT: 'SMALLINT'; +KW_INT: 'INT'; +KW_BIGINT: 'BIGINT'; +KW_FLOAT: 'FLOAT'; +KW_DOUBLE: 'DOUBLE'; +KW_DATE: 'DATE'; +KW_DATETIME: 'DATETIME'; +KW_TIMESTAMP: 'TIMESTAMP'; +KW_INTERVAL: 'INTERVAL'; +KW_DECIMAL: 'DECIMAL'; +KW_STRING: 'STRING'; +KW_CHAR: 'CHAR'; +KW_VARCHAR: 'VARCHAR'; +KW_ARRAY: 'ARRAY'; +KW_STRUCT: 'STRUCT'; +KW_MAP: 'MAP'; +KW_UNIONTYPE: 'UNIONTYPE'; +KW_REDUCE: 'REDUCE'; +KW_PARTITIONED: 'PARTITIONED'; +KW_CLUSTERED: 'CLUSTERED'; +KW_SORTED: 'SORTED'; +KW_INTO: 'INTO'; +KW_BUCKETS: 'BUCKETS'; +KW_ROW: 'ROW'; +KW_ROWS: 'ROWS'; +KW_FORMAT: 'FORMAT'; +KW_DELIMITED: 'DELIMITED'; +KW_FIELDS: 'FIELDS'; +KW_TERMINATED: 'TERMINATED'; +KW_ESCAPED: 'ESCAPED'; +KW_COLLECTION: 'COLLECTION'; +KW_ITEMS: 'ITEMS'; +KW_KEYS: 'KEYS'; +KW_KEY_TYPE: '$KEY$'; +KW_LINES: 'LINES'; +KW_STORED: 'STORED'; +KW_FILEFORMAT: 'FILEFORMAT'; +KW_INPUTFORMAT: 'INPUTFORMAT'; +KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; +KW_INPUTDRIVER: 'INPUTDRIVER'; +KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; +KW_ENABLE: 'ENABLE'; +KW_DISABLE: 'DISABLE'; +KW_LOCATION: 'LOCATION'; +KW_TABLESAMPLE: 'TABLESAMPLE'; +KW_BUCKET: 'BUCKET'; +KW_OUT: 'OUT'; +KW_OF: 'OF'; +KW_PERCENT: 'PERCENT'; +KW_CAST: 'CAST'; +KW_ADD: 'ADD'; +KW_REPLACE: 'REPLACE'; +KW_RLIKE: 'RLIKE'; +KW_REGEXP: 'REGEXP'; +KW_TEMPORARY: 'TEMPORARY'; +KW_FUNCTION: 'FUNCTION'; +KW_MACRO: 'MACRO'; +KW_FILE: 'FILE'; +KW_JAR: 'JAR'; +KW_EXPLAIN: 'EXPLAIN'; +KW_EXTENDED: 'EXTENDED'; +KW_FORMATTED: 'FORMATTED'; +KW_PRETTY: 'PRETTY'; +KW_DEPENDENCY: 'DEPENDENCY'; +KW_LOGICAL: 'LOGICAL'; +KW_SERDE: 'SERDE'; +KW_WITH: 'WITH'; +KW_DEFERRED: 'DEFERRED'; +KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; +KW_DBPROPERTIES: 'DBPROPERTIES'; +KW_LIMIT: 'LIMIT'; +KW_SET: 'SET'; +KW_UNSET: 'UNSET'; +KW_TBLPROPERTIES: 'TBLPROPERTIES'; +KW_IDXPROPERTIES: 'IDXPROPERTIES'; +KW_VALUE_TYPE: '$VALUE$'; +KW_ELEM_TYPE: '$ELEM$'; +KW_DEFINED: 'DEFINED'; +KW_CASE: 'CASE'; +KW_WHEN: 'WHEN'; +KW_THEN: 'THEN'; +KW_ELSE: 'ELSE'; +KW_END: 'END'; +KW_MAPJOIN: 'MAPJOIN'; +KW_STREAMTABLE: 'STREAMTABLE'; +KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; +KW_UTC: 'UTC'; +KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; +KW_LONG: 'LONG'; +KW_DELETE: 'DELETE'; +KW_PLUS: 'PLUS'; +KW_MINUS: 'MINUS'; +KW_FETCH: 'FETCH'; +KW_INTERSECT: 'INTERSECT'; +KW_VIEW: 'VIEW'; +KW_IN: 'IN'; +KW_DATABASE: 'DATABASE'; +KW_DATABASES: 'DATABASES'; +KW_MATERIALIZED: 'MATERIALIZED'; +KW_SCHEMA: 'SCHEMA'; +KW_SCHEMAS: 'SCHEMAS'; +KW_GRANT: 'GRANT'; +KW_REVOKE: 'REVOKE'; +KW_SSL: 'SSL'; +KW_UNDO: 'UNDO'; +KW_LOCK: 'LOCK'; +KW_LOCKS: 'LOCKS'; +KW_UNLOCK: 'UNLOCK'; +KW_SHARED: 'SHARED'; +KW_EXCLUSIVE: 'EXCLUSIVE'; +KW_PROCEDURE: 'PROCEDURE'; +KW_UNSIGNED: 'UNSIGNED'; +KW_WHILE: 'WHILE'; +KW_READ: 'READ'; +KW_READS: 'READS'; +KW_PURGE: 'PURGE'; +KW_RANGE: 'RANGE'; +KW_ANALYZE: 'ANALYZE'; +KW_BEFORE: 'BEFORE'; +KW_BETWEEN: 'BETWEEN'; +KW_BOTH: 'BOTH'; +KW_BINARY: 'BINARY'; +KW_CROSS: 'CROSS'; +KW_CONTINUE: 'CONTINUE'; +KW_CURSOR: 'CURSOR'; +KW_TRIGGER: 'TRIGGER'; +KW_RECORDREADER: 'RECORDREADER'; +KW_RECORDWRITER: 'RECORDWRITER'; +KW_SEMI: 'SEMI'; +KW_LATERAL: 'LATERAL'; +KW_TOUCH: 'TOUCH'; +KW_ARCHIVE: 'ARCHIVE'; +KW_UNARCHIVE: 'UNARCHIVE'; +KW_COMPUTE: 'COMPUTE'; +KW_STATISTICS: 'STATISTICS'; +KW_USE: 'USE'; +KW_OPTION: 'OPTION'; +KW_CONCATENATE: 'CONCATENATE'; +KW_SHOW_DATABASE: 'SHOW_DATABASE'; +KW_UPDATE: 'UPDATE'; +KW_RESTRICT: 'RESTRICT'; +KW_CASCADE: 'CASCADE'; +KW_SKEWED: 'SKEWED'; +KW_ROLLUP: 'ROLLUP'; +KW_CUBE: 'CUBE'; +KW_DIRECTORIES: 'DIRECTORIES'; +KW_FOR: 'FOR'; +KW_WINDOW: 'WINDOW'; +KW_UNBOUNDED: 'UNBOUNDED'; +KW_PRECEDING: 'PRECEDING'; +KW_FOLLOWING: 'FOLLOWING'; +KW_CURRENT: 'CURRENT'; +KW_CURRENT_DATE: 'CURRENT_DATE'; +KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +KW_LESS: 'LESS'; +KW_MORE: 'MORE'; +KW_OVER: 'OVER'; +KW_GROUPING: 'GROUPING'; +KW_SETS: 'SETS'; +KW_TRUNCATE: 'TRUNCATE'; +KW_NOSCAN: 'NOSCAN'; +KW_PARTIALSCAN: 'PARTIALSCAN'; +KW_USER: 'USER'; +KW_ROLE: 'ROLE'; +KW_ROLES: 'ROLES'; +KW_INNER: 'INNER'; +KW_EXCHANGE: 'EXCHANGE'; +KW_URI: 'URI'; +KW_SERVER : 'SERVER'; +KW_ADMIN: 'ADMIN'; +KW_OWNER: 'OWNER'; +KW_PRINCIPALS: 'PRINCIPALS'; +KW_COMPACT: 'COMPACT'; +KW_COMPACTIONS: 'COMPACTIONS'; +KW_TRANSACTIONS: 'TRANSACTIONS'; +KW_REWRITE : 'REWRITE'; +KW_AUTHORIZATION: 'AUTHORIZATION'; +KW_CONF: 'CONF'; +KW_VALUES: 'VALUES'; +KW_RELOAD: 'RELOAD'; +KW_YEAR: 'YEAR'; +KW_MONTH: 'MONTH'; +KW_DAY: 'DAY'; +KW_HOUR: 'HOUR'; +KW_MINUTE: 'MINUTE'; +KW_SECOND: 'SECOND'; +KW_START: 'START'; +KW_TRANSACTION: 'TRANSACTION'; +KW_COMMIT: 'COMMIT'; +KW_ROLLBACK: 'ROLLBACK'; +KW_WORK: 'WORK'; +KW_ONLY: 'ONLY'; +KW_WRITE: 'WRITE'; +KW_ISOLATION: 'ISOLATION'; +KW_LEVEL: 'LEVEL'; +KW_SNAPSHOT: 'SNAPSHOT'; +KW_AUTOCOMMIT: 'AUTOCOMMIT'; + +// Operators +// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. + +DOT : '.'; // generated as a part of Number rule +COLON : ':' ; +COMMA : ',' ; +SEMICOLON : ';' ; + +LPAREN : '(' ; +RPAREN : ')' ; +LSQUARE : '[' ; +RSQUARE : ']' ; +LCURLY : '{'; +RCURLY : '}'; + +EQUAL : '=' | '=='; +EQUAL_NS : '<=>'; +NOTEQUAL : '<>' | '!='; +LESSTHANOREQUALTO : '<='; +LESSTHAN : '<'; +GREATERTHANOREQUALTO : '>='; +GREATERTHAN : '>'; + +DIVIDE : '/'; +PLUS : '+'; +MINUS : '-'; +STAR : '*'; +MOD : '%'; +DIV : 'DIV'; + +AMPERSAND : '&'; +TILDE : '~'; +BITWISEOR : '|'; +BITWISEXOR : '^'; +QUESTION : '?'; +DOLLAR : '$'; + +// LITERALS +fragment +Letter + : 'a'..'z' | 'A'..'Z' + ; + +fragment +HexDigit + : 'a'..'f' | 'A'..'F' + ; + +fragment +Digit + : + '0'..'9' + ; + +fragment +Exponent + : + ('e' | 'E') ( PLUS|MINUS )? (Digit)+ + ; + +fragment +RegexComponent + : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' + | PLUS | STAR | QUESTION | MINUS | DOT + | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY + | BITWISEXOR | BITWISEOR | DOLLAR | '!' + ; + +StringLiteral + : + ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + )+ + ; + +CharSetLiteral + : + StringLiteral + | '0' 'X' (HexDigit|Digit)+ + ; + +BigintLiteral + : + (Digit)+ 'L' + ; + +SmallintLiteral + : + (Digit)+ 'S' + ; + +TinyintLiteral + : + (Digit)+ 'Y' + ; + +DecimalLiteral + : + Number 'B' 'D' + ; + +ByteLengthLiteral + : + (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') + ; + +Number + : + (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? + ; + +/* +An Identifier can be: +- tableName +- columnName +- select expr alias +- lateral view aliases +- database name +- view name +- subquery alias +- function name +- ptf argument identifier +- index name +- property name for: db,tbl,partition... +- fileFormat +- role name +- privilege name +- principal name +- macro name +- hint name +- window name +*/ +Identifier + : + (Letter | Digit) (Letter | Digit | '_')* + | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; + at the API level only columns are allowed to be of this form */ + | '`' RegexComponent+ '`' + ; + +fragment +QuotedIdentifier + : + '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + ; + +CharSetName + : + '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ + ; + +WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} + ; + +COMMENT + : '--' (~('\n'|'\r'))* + { $channel=HIDDEN; } + ; + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g new file mode 100644 index 0000000000000..69574d713d0be --- /dev/null +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g @@ -0,0 +1,2457 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +parser grammar SparkSqlParser; + +options +{ +tokenVocab=SparkSqlLexer; +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} +import SelectClauseParser, FromClauseParser, IdentifiersParser; + +tokens { +TOK_INSERT; +TOK_QUERY; +TOK_SELECT; +TOK_SELECTDI; +TOK_SELEXPR; +TOK_FROM; +TOK_TAB; +TOK_PARTSPEC; +TOK_PARTVAL; +TOK_DIR; +TOK_TABREF; +TOK_SUBQUERY; +TOK_INSERT_INTO; +TOK_DESTINATION; +TOK_ALLCOLREF; +TOK_TABLE_OR_COL; +TOK_FUNCTION; +TOK_FUNCTIONDI; +TOK_FUNCTIONSTAR; +TOK_WHERE; +TOK_OP_EQ; +TOK_OP_NE; +TOK_OP_LE; +TOK_OP_LT; +TOK_OP_GE; +TOK_OP_GT; +TOK_OP_DIV; +TOK_OP_ADD; +TOK_OP_SUB; +TOK_OP_MUL; +TOK_OP_MOD; +TOK_OP_BITAND; +TOK_OP_BITNOT; +TOK_OP_BITOR; +TOK_OP_BITXOR; +TOK_OP_AND; +TOK_OP_OR; +TOK_OP_NOT; +TOK_OP_LIKE; +TOK_TRUE; +TOK_FALSE; +TOK_TRANSFORM; +TOK_SERDE; +TOK_SERDENAME; +TOK_SERDEPROPS; +TOK_EXPLIST; +TOK_ALIASLIST; +TOK_GROUPBY; +TOK_ROLLUP_GROUPBY; +TOK_CUBE_GROUPBY; +TOK_GROUPING_SETS; +TOK_GROUPING_SETS_EXPRESSION; +TOK_HAVING; +TOK_ORDERBY; +TOK_CLUSTERBY; +TOK_DISTRIBUTEBY; +TOK_SORTBY; +TOK_UNIONALL; +TOK_UNIONDISTINCT; +TOK_JOIN; +TOK_LEFTOUTERJOIN; +TOK_RIGHTOUTERJOIN; +TOK_FULLOUTERJOIN; +TOK_UNIQUEJOIN; +TOK_CROSSJOIN; +TOK_LOAD; +TOK_EXPORT; +TOK_IMPORT; +TOK_REPLICATION; +TOK_METADATA; +TOK_NULL; +TOK_ISNULL; +TOK_ISNOTNULL; +TOK_TINYINT; +TOK_SMALLINT; +TOK_INT; +TOK_BIGINT; +TOK_BOOLEAN; +TOK_FLOAT; +TOK_DOUBLE; +TOK_DATE; +TOK_DATELITERAL; +TOK_DATETIME; +TOK_TIMESTAMP; +TOK_TIMESTAMPLITERAL; +TOK_INTERVAL_YEAR_MONTH; +TOK_INTERVAL_YEAR_MONTH_LITERAL; +TOK_INTERVAL_DAY_TIME; +TOK_INTERVAL_DAY_TIME_LITERAL; +TOK_INTERVAL_YEAR_LITERAL; +TOK_INTERVAL_MONTH_LITERAL; +TOK_INTERVAL_DAY_LITERAL; +TOK_INTERVAL_HOUR_LITERAL; +TOK_INTERVAL_MINUTE_LITERAL; +TOK_INTERVAL_SECOND_LITERAL; +TOK_STRING; +TOK_CHAR; +TOK_VARCHAR; +TOK_BINARY; +TOK_DECIMAL; +TOK_LIST; +TOK_STRUCT; +TOK_MAP; +TOK_UNIONTYPE; +TOK_COLTYPELIST; +TOK_CREATEDATABASE; +TOK_CREATETABLE; +TOK_TRUNCATETABLE; +TOK_CREATEINDEX; +TOK_CREATEINDEX_INDEXTBLNAME; +TOK_DEFERRED_REBUILDINDEX; +TOK_DROPINDEX; +TOK_LIKETABLE; +TOK_DESCTABLE; +TOK_DESCFUNCTION; +TOK_ALTERTABLE; +TOK_ALTERTABLE_RENAME; +TOK_ALTERTABLE_ADDCOLS; +TOK_ALTERTABLE_RENAMECOL; +TOK_ALTERTABLE_RENAMEPART; +TOK_ALTERTABLE_REPLACECOLS; +TOK_ALTERTABLE_ADDPARTS; +TOK_ALTERTABLE_DROPPARTS; +TOK_ALTERTABLE_PARTCOLTYPE; +TOK_ALTERTABLE_MERGEFILES; +TOK_ALTERTABLE_TOUCH; +TOK_ALTERTABLE_ARCHIVE; +TOK_ALTERTABLE_UNARCHIVE; +TOK_ALTERTABLE_SERDEPROPERTIES; +TOK_ALTERTABLE_SERIALIZER; +TOK_ALTERTABLE_UPDATECOLSTATS; +TOK_TABLE_PARTITION; +TOK_ALTERTABLE_FILEFORMAT; +TOK_ALTERTABLE_LOCATION; +TOK_ALTERTABLE_PROPERTIES; +TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; +TOK_ALTERTABLE_DROPPROPERTIES; +TOK_ALTERTABLE_SKEWED; +TOK_ALTERTABLE_EXCHANGEPARTITION; +TOK_ALTERTABLE_SKEWED_LOCATION; +TOK_ALTERTABLE_BUCKETS; +TOK_ALTERTABLE_CLUSTER_SORT; +TOK_ALTERTABLE_COMPACT; +TOK_ALTERINDEX_REBUILD; +TOK_ALTERINDEX_PROPERTIES; +TOK_MSCK; +TOK_SHOWDATABASES; +TOK_SHOWTABLES; +TOK_SHOWCOLUMNS; +TOK_SHOWFUNCTIONS; +TOK_SHOWPARTITIONS; +TOK_SHOW_CREATEDATABASE; +TOK_SHOW_CREATETABLE; +TOK_SHOW_TABLESTATUS; +TOK_SHOW_TBLPROPERTIES; +TOK_SHOWLOCKS; +TOK_SHOWCONF; +TOK_LOCKTABLE; +TOK_UNLOCKTABLE; +TOK_LOCKDB; +TOK_UNLOCKDB; +TOK_SWITCHDATABASE; +TOK_DROPDATABASE; +TOK_DROPTABLE; +TOK_DATABASECOMMENT; +TOK_TABCOLLIST; +TOK_TABCOL; +TOK_TABLECOMMENT; +TOK_TABLEPARTCOLS; +TOK_TABLEROWFORMAT; +TOK_TABLEROWFORMATFIELD; +TOK_TABLEROWFORMATCOLLITEMS; +TOK_TABLEROWFORMATMAPKEYS; +TOK_TABLEROWFORMATLINES; +TOK_TABLEROWFORMATNULL; +TOK_TABLEFILEFORMAT; +TOK_FILEFORMAT_GENERIC; +TOK_OFFLINE; +TOK_ENABLE; +TOK_DISABLE; +TOK_READONLY; +TOK_NO_DROP; +TOK_STORAGEHANDLER; +TOK_NOT_CLUSTERED; +TOK_NOT_SORTED; +TOK_TABCOLNAME; +TOK_TABLELOCATION; +TOK_PARTITIONLOCATION; +TOK_TABLEBUCKETSAMPLE; +TOK_TABLESPLITSAMPLE; +TOK_PERCENT; +TOK_LENGTH; +TOK_ROWCOUNT; +TOK_TMP_FILE; +TOK_TABSORTCOLNAMEASC; +TOK_TABSORTCOLNAMEDESC; +TOK_STRINGLITERALSEQUENCE; +TOK_CHARSETLITERAL; +TOK_CREATEFUNCTION; +TOK_DROPFUNCTION; +TOK_RELOADFUNCTION; +TOK_CREATEMACRO; +TOK_DROPMACRO; +TOK_TEMPORARY; +TOK_CREATEVIEW; +TOK_DROPVIEW; +TOK_ALTERVIEW; +TOK_ALTERVIEW_PROPERTIES; +TOK_ALTERVIEW_DROPPROPERTIES; +TOK_ALTERVIEW_ADDPARTS; +TOK_ALTERVIEW_DROPPARTS; +TOK_ALTERVIEW_RENAME; +TOK_VIEWPARTCOLS; +TOK_EXPLAIN; +TOK_EXPLAIN_SQ_REWRITE; +TOK_TABLESERIALIZER; +TOK_TABLEPROPERTIES; +TOK_TABLEPROPLIST; +TOK_INDEXPROPERTIES; +TOK_INDEXPROPLIST; +TOK_TABTYPE; +TOK_LIMIT; +TOK_TABLEPROPERTY; +TOK_IFEXISTS; +TOK_IFNOTEXISTS; +TOK_ORREPLACE; +TOK_HINTLIST; +TOK_HINT; +TOK_MAPJOIN; +TOK_STREAMTABLE; +TOK_HINTARGLIST; +TOK_USERSCRIPTCOLNAMES; +TOK_USERSCRIPTCOLSCHEMA; +TOK_RECORDREADER; +TOK_RECORDWRITER; +TOK_LEFTSEMIJOIN; +TOK_ANTIJOIN; +TOK_LATERAL_VIEW; +TOK_LATERAL_VIEW_OUTER; +TOK_TABALIAS; +TOK_ANALYZE; +TOK_CREATEROLE; +TOK_DROPROLE; +TOK_GRANT; +TOK_REVOKE; +TOK_SHOW_GRANT; +TOK_PRIVILEGE_LIST; +TOK_PRIVILEGE; +TOK_PRINCIPAL_NAME; +TOK_USER; +TOK_GROUP; +TOK_ROLE; +TOK_RESOURCE_ALL; +TOK_GRANT_WITH_OPTION; +TOK_GRANT_WITH_ADMIN_OPTION; +TOK_ADMIN_OPTION_FOR; +TOK_GRANT_OPTION_FOR; +TOK_PRIV_ALL; +TOK_PRIV_ALTER_METADATA; +TOK_PRIV_ALTER_DATA; +TOK_PRIV_DELETE; +TOK_PRIV_DROP; +TOK_PRIV_INDEX; +TOK_PRIV_INSERT; +TOK_PRIV_LOCK; +TOK_PRIV_SELECT; +TOK_PRIV_SHOW_DATABASE; +TOK_PRIV_CREATE; +TOK_PRIV_OBJECT; +TOK_PRIV_OBJECT_COL; +TOK_GRANT_ROLE; +TOK_REVOKE_ROLE; +TOK_SHOW_ROLE_GRANT; +TOK_SHOW_ROLES; +TOK_SHOW_SET_ROLE; +TOK_SHOW_ROLE_PRINCIPALS; +TOK_SHOWINDEXES; +TOK_SHOWDBLOCKS; +TOK_INDEXCOMMENT; +TOK_DESCDATABASE; +TOK_DATABASEPROPERTIES; +TOK_DATABASELOCATION; +TOK_DBPROPLIST; +TOK_ALTERDATABASE_PROPERTIES; +TOK_ALTERDATABASE_OWNER; +TOK_TABNAME; +TOK_TABSRC; +TOK_RESTRICT; +TOK_CASCADE; +TOK_TABLESKEWED; +TOK_TABCOLVALUE; +TOK_TABCOLVALUE_PAIR; +TOK_TABCOLVALUES; +TOK_SKEWED_LOCATIONS; +TOK_SKEWED_LOCATION_LIST; +TOK_SKEWED_LOCATION_MAP; +TOK_STOREDASDIRS; +TOK_PARTITIONINGSPEC; +TOK_PTBLFUNCTION; +TOK_WINDOWDEF; +TOK_WINDOWSPEC; +TOK_WINDOWVALUES; +TOK_WINDOWRANGE; +TOK_SUBQUERY_EXPR; +TOK_SUBQUERY_OP; +TOK_SUBQUERY_OP_NOTIN; +TOK_SUBQUERY_OP_NOTEXISTS; +TOK_DB_TYPE; +TOK_TABLE_TYPE; +TOK_CTE; +TOK_ARCHIVE; +TOK_FILE; +TOK_JAR; +TOK_RESOURCE_URI; +TOK_RESOURCE_LIST; +TOK_SHOW_COMPACTIONS; +TOK_SHOW_TRANSACTIONS; +TOK_DELETE_FROM; +TOK_UPDATE_TABLE; +TOK_SET_COLUMNS_CLAUSE; +TOK_VALUE_ROW; +TOK_VALUES_TABLE; +TOK_VIRTUAL_TABLE; +TOK_VIRTUAL_TABREF; +TOK_ANONYMOUS; +TOK_COL_NAME; +TOK_URI_TYPE; +TOK_SERVER_TYPE; +TOK_START_TRANSACTION; +TOK_ISOLATION_LEVEL; +TOK_ISOLATION_SNAPSHOT; +TOK_TXN_ACCESS_MODE; +TOK_TXN_READ_ONLY; +TOK_TXN_READ_WRITE; +TOK_COMMIT; +TOK_ROLLBACK; +TOK_SET_AUTOCOMMIT; +} + + +// Package headers +@header { +package org.apache.spark.sql.parser; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +} + + +@members { + ArrayList errors = new ArrayList(); + Stack msgs = new Stack(); + + private static HashMap xlateMap; + static { + //this is used to support auto completion in CLI + xlateMap = new HashMap(); + + // Keywords + xlateMap.put("KW_TRUE", "TRUE"); + xlateMap.put("KW_FALSE", "FALSE"); + xlateMap.put("KW_ALL", "ALL"); + xlateMap.put("KW_NONE", "NONE"); + xlateMap.put("KW_AND", "AND"); + xlateMap.put("KW_OR", "OR"); + xlateMap.put("KW_NOT", "NOT"); + xlateMap.put("KW_LIKE", "LIKE"); + + xlateMap.put("KW_ASC", "ASC"); + xlateMap.put("KW_DESC", "DESC"); + xlateMap.put("KW_ORDER", "ORDER"); + xlateMap.put("KW_BY", "BY"); + xlateMap.put("KW_GROUP", "GROUP"); + xlateMap.put("KW_WHERE", "WHERE"); + xlateMap.put("KW_FROM", "FROM"); + xlateMap.put("KW_AS", "AS"); + xlateMap.put("KW_SELECT", "SELECT"); + xlateMap.put("KW_DISTINCT", "DISTINCT"); + xlateMap.put("KW_INSERT", "INSERT"); + xlateMap.put("KW_OVERWRITE", "OVERWRITE"); + xlateMap.put("KW_OUTER", "OUTER"); + xlateMap.put("KW_JOIN", "JOIN"); + xlateMap.put("KW_LEFT", "LEFT"); + xlateMap.put("KW_RIGHT", "RIGHT"); + xlateMap.put("KW_FULL", "FULL"); + xlateMap.put("KW_ON", "ON"); + xlateMap.put("KW_PARTITION", "PARTITION"); + xlateMap.put("KW_PARTITIONS", "PARTITIONS"); + xlateMap.put("KW_TABLE", "TABLE"); + xlateMap.put("KW_TABLES", "TABLES"); + xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_SHOW", "SHOW"); + xlateMap.put("KW_MSCK", "MSCK"); + xlateMap.put("KW_DIRECTORY", "DIRECTORY"); + xlateMap.put("KW_LOCAL", "LOCAL"); + xlateMap.put("KW_TRANSFORM", "TRANSFORM"); + xlateMap.put("KW_USING", "USING"); + xlateMap.put("KW_CLUSTER", "CLUSTER"); + xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); + xlateMap.put("KW_SORT", "SORT"); + xlateMap.put("KW_UNION", "UNION"); + xlateMap.put("KW_LOAD", "LOAD"); + xlateMap.put("KW_DATA", "DATA"); + xlateMap.put("KW_INPATH", "INPATH"); + xlateMap.put("KW_IS", "IS"); + xlateMap.put("KW_NULL", "NULL"); + xlateMap.put("KW_CREATE", "CREATE"); + xlateMap.put("KW_EXTERNAL", "EXTERNAL"); + xlateMap.put("KW_ALTER", "ALTER"); + xlateMap.put("KW_DESCRIBE", "DESCRIBE"); + xlateMap.put("KW_DROP", "DROP"); + xlateMap.put("KW_RENAME", "RENAME"); + xlateMap.put("KW_TO", "TO"); + xlateMap.put("KW_COMMENT", "COMMENT"); + xlateMap.put("KW_BOOLEAN", "BOOLEAN"); + xlateMap.put("KW_TINYINT", "TINYINT"); + xlateMap.put("KW_SMALLINT", "SMALLINT"); + xlateMap.put("KW_INT", "INT"); + xlateMap.put("KW_BIGINT", "BIGINT"); + xlateMap.put("KW_FLOAT", "FLOAT"); + xlateMap.put("KW_DOUBLE", "DOUBLE"); + xlateMap.put("KW_DATE", "DATE"); + xlateMap.put("KW_DATETIME", "DATETIME"); + xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); + xlateMap.put("KW_STRING", "STRING"); + xlateMap.put("KW_BINARY", "BINARY"); + xlateMap.put("KW_ARRAY", "ARRAY"); + xlateMap.put("KW_MAP", "MAP"); + xlateMap.put("KW_REDUCE", "REDUCE"); + xlateMap.put("KW_PARTITIONED", "PARTITIONED"); + xlateMap.put("KW_CLUSTERED", "CLUSTERED"); + xlateMap.put("KW_SORTED", "SORTED"); + xlateMap.put("KW_INTO", "INTO"); + xlateMap.put("KW_BUCKETS", "BUCKETS"); + xlateMap.put("KW_ROW", "ROW"); + xlateMap.put("KW_FORMAT", "FORMAT"); + xlateMap.put("KW_DELIMITED", "DELIMITED"); + xlateMap.put("KW_FIELDS", "FIELDS"); + xlateMap.put("KW_TERMINATED", "TERMINATED"); + xlateMap.put("KW_COLLECTION", "COLLECTION"); + xlateMap.put("KW_ITEMS", "ITEMS"); + xlateMap.put("KW_KEYS", "KEYS"); + xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); + xlateMap.put("KW_LINES", "LINES"); + xlateMap.put("KW_STORED", "STORED"); + xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); + xlateMap.put("KW_TEXTFILE", "TEXTFILE"); + xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); + xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); + xlateMap.put("KW_LOCATION", "LOCATION"); + xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); + xlateMap.put("KW_BUCKET", "BUCKET"); + xlateMap.put("KW_OUT", "OUT"); + xlateMap.put("KW_OF", "OF"); + xlateMap.put("KW_CAST", "CAST"); + xlateMap.put("KW_ADD", "ADD"); + xlateMap.put("KW_REPLACE", "REPLACE"); + xlateMap.put("KW_COLUMNS", "COLUMNS"); + xlateMap.put("KW_RLIKE", "RLIKE"); + xlateMap.put("KW_REGEXP", "REGEXP"); + xlateMap.put("KW_TEMPORARY", "TEMPORARY"); + xlateMap.put("KW_FUNCTION", "FUNCTION"); + xlateMap.put("KW_EXPLAIN", "EXPLAIN"); + xlateMap.put("KW_EXTENDED", "EXTENDED"); + xlateMap.put("KW_SERDE", "SERDE"); + xlateMap.put("KW_WITH", "WITH"); + xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); + xlateMap.put("KW_LIMIT", "LIMIT"); + xlateMap.put("KW_SET", "SET"); + xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); + xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); + xlateMap.put("KW_DEFINED", "DEFINED"); + xlateMap.put("KW_SUBQUERY", "SUBQUERY"); + xlateMap.put("KW_REWRITE", "REWRITE"); + xlateMap.put("KW_UPDATE", "UPDATE"); + xlateMap.put("KW_VALUES", "VALUES"); + xlateMap.put("KW_PURGE", "PURGE"); + + + // Operators + xlateMap.put("DOT", "."); + xlateMap.put("COLON", ":"); + xlateMap.put("COMMA", ","); + xlateMap.put("SEMICOLON", ");"); + + xlateMap.put("LPAREN", "("); + xlateMap.put("RPAREN", ")"); + xlateMap.put("LSQUARE", "["); + xlateMap.put("RSQUARE", "]"); + + xlateMap.put("EQUAL", "="); + xlateMap.put("NOTEQUAL", "<>"); + xlateMap.put("EQUAL_NS", "<=>"); + xlateMap.put("LESSTHANOREQUALTO", "<="); + xlateMap.put("LESSTHAN", "<"); + xlateMap.put("GREATERTHANOREQUALTO", ">="); + xlateMap.put("GREATERTHAN", ">"); + + xlateMap.put("DIVIDE", "/"); + xlateMap.put("PLUS", "+"); + xlateMap.put("MINUS", "-"); + xlateMap.put("STAR", "*"); + xlateMap.put("MOD", "\%"); + + xlateMap.put("AMPERSAND", "&"); + xlateMap.put("TILDE", "~"); + xlateMap.put("BITWISEOR", "|"); + xlateMap.put("BITWISEXOR", "^"); + xlateMap.put("CharSetLiteral", "\\'"); + } + + public static Collection getKeywords() { + return xlateMap.values(); + } + + private static String xlate(String name) { + + String ret = xlateMap.get(name); + if (ret == null) { + ret = name; + } + + return ret; + } + + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + errors.add(new ParseError(this, e, tokenNames)); + } + + @Override + public String getErrorHeader(RecognitionException e) { + String header = null; + if (e.charPositionInLine < 0 && input.LT(-1) != null) { + Token t = input.LT(-1); + header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); + } else { + header = super.getErrorHeader(e); + } + + return header; + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + // Translate the token names to something that the user can understand + String[] xlateNames = new String[tokenNames.length]; + for (int i = 0; i < tokenNames.length; ++i) { + xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); + } + + if (e instanceof NoViableAltException) { + @SuppressWarnings("unused") + NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and "(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "cannot recognize input near" + + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") + + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") + + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); + } else if (e instanceof MismatchedTokenException) { + MismatchedTokenException mte = (MismatchedTokenException) e; + msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; + } else if (e instanceof FailedPredicateException) { + FailedPredicateException fpe = (FailedPredicateException) e; + msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; + } else { + msg = super.getErrorMessage(e, xlateNames); + } + + if (msgs.size() > 0) { + msg = msg + " in " + msgs.peek(); + } + return msg; + } + + public void pushMsg(String msg, RecognizerSharedState state) { + // ANTLR generated code does not wrap the @init code wit this backtracking check, + // even if the matching @after has it. If we have parser rules with that are doing + // some lookahead with syntactic predicates this can cause the push() and pop() calls + // to become unbalanced, so make sure both push/pop check the backtracking state. + if (state.backtracking == 0) { + msgs.push(msg); + } + } + + public void popMsg(RecognizerSharedState state) { + if (state.backtracking == 0) { + Object o = msgs.pop(); + } + } + + // counter to generate unique union aliases + private int aliasCounter; + private String generateUnionAlias() { + return "_u" + (++aliasCounter); + } + private char [] excludedCharForColumnName = {'.', ':'}; + private boolean containExcludedCharForCreateTableColumnName(String input) { + for(char c : excludedCharForColumnName) { + if(input.indexOf(c)>-1) { + return true; + } + } + return false; + } + private CommonTree throwSetOpException() throws RecognitionException { + throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); + } + private CommonTree throwColumnNameException() throws RecognitionException { + throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); + } + private Configuration hiveConf; + public void setHiveConf(Configuration hiveConf) { + this.hiveConf = hiveConf; + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + if(hiveConf==null){ + return false; + } + return !HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS); + } +} + +@rulecatch { +catch (RecognitionException e) { + reportError(e); + throw e; +} +} + +// starting rule +statement + : explainStatement EOF + | execStatement EOF + ; + +explainStatement +@init { pushMsg("explain statement", state); } +@after { popMsg(state); } + : KW_EXPLAIN ( + explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) + | + KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) + ; + +explainOption +@init { msgs.push("explain option"); } +@after { msgs.pop(); } + : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION + ; + +execStatement +@init { pushMsg("statement", state); } +@after { popMsg(state); } + : queryStatementExpression[true] + | loadStatement + | exportStatement + | importStatement + | ddlStatement + | deleteStatement + | updateStatement + | sqlTransactionStatement + ; + +loadStatement +@init { pushMsg("load statement", state); } +@after { popMsg(state); } + : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) + -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) + ; + +replicationClause +@init { pushMsg("replication clause", state); } +@after { popMsg(state); } + : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN + -> ^(TOK_REPLICATION $replId $isMetadataOnly?) + ; + +exportStatement +@init { pushMsg("export statement", state); } +@after { popMsg(state); } + : KW_EXPORT + KW_TABLE (tab=tableOrPartition) + KW_TO (path=StringLiteral) + replicationClause? + -> ^(TOK_EXPORT $tab $path replicationClause?) + ; + +importStatement +@init { pushMsg("import statement", state); } +@after { popMsg(state); } + : KW_IMPORT + ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? + KW_FROM (path=StringLiteral) + tableLocation? + -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) + ; + +ddlStatement +@init { pushMsg("ddl statement", state); } +@after { popMsg(state); } + : createDatabaseStatement + | switchDatabaseStatement + | dropDatabaseStatement + | createTableStatement + | dropTableStatement + | truncateTableStatement + | alterStatement + | descStatement + | showStatement + | metastoreCheck + | createViewStatement + | dropViewStatement + | createFunctionStatement + | createMacroStatement + | createIndexStatement + | dropIndexStatement + | dropFunctionStatement + | reloadFunctionStatement + | dropMacroStatement + | analyzeStatement + | lockStatement + | unlockStatement + | lockDatabase + | unlockDatabase + | createRoleStatement + | dropRoleStatement + | (grantPrivileges) => grantPrivileges + | (revokePrivileges) => revokePrivileges + | showGrants + | showRoleGrants + | showRolePrincipals + | showRoles + | grantRole + | revokeRole + | setRole + | showCurrentRole + ; + +ifExists +@init { pushMsg("if exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_EXISTS + -> ^(TOK_IFEXISTS) + ; + +restrictOrCascade +@init { pushMsg("restrict or cascade clause", state); } +@after { popMsg(state); } + : KW_RESTRICT + -> ^(TOK_RESTRICT) + | KW_CASCADE + -> ^(TOK_CASCADE) + ; + +ifNotExists +@init { pushMsg("if not exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_NOT KW_EXISTS + -> ^(TOK_IFNOTEXISTS) + ; + +storedAsDirs +@init { pushMsg("stored as directories", state); } +@after { popMsg(state); } + : KW_STORED KW_AS KW_DIRECTORIES + -> ^(TOK_STOREDASDIRS) + ; + +orReplace +@init { pushMsg("or replace clause", state); } +@after { popMsg(state); } + : KW_OR KW_REPLACE + -> ^(TOK_ORREPLACE) + ; + +createDatabaseStatement +@init { pushMsg("create database statement", state); } +@after { popMsg(state); } + : KW_CREATE (KW_DATABASE|KW_SCHEMA) + ifNotExists? + name=identifier + databaseComment? + dbLocation? + (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? + -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) + ; + +dbLocation +@init { pushMsg("database location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) + ; + +dbProperties +@init { pushMsg("dbproperties", state); } +@after { popMsg(state); } + : + LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) + ; + +dbPropertiesList +@init { pushMsg("database properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) + ; + + +switchDatabaseStatement +@init { pushMsg("switch database statement", state); } +@after { popMsg(state); } + : KW_USE identifier + -> ^(TOK_SWITCHDATABASE identifier) + ; + +dropDatabaseStatement +@init { pushMsg("drop database statement", state); } +@after { popMsg(state); } + : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? + -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) + ; + +databaseComment +@init { pushMsg("database's comment", state); } +@after { popMsg(state); } + : KW_COMMENT comment=StringLiteral + -> ^(TOK_DATABASECOMMENT $comment) + ; + +createTableStatement +@init { pushMsg("create table statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName + ( like=KW_LIKE likeName=tableName + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + | (LPAREN columnNameTypeList RPAREN)? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + (KW_AS selectStatementWithCTE)? + ) + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + columnNameTypeList? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + selectStatementWithCTE? + ) + ; + +truncateTableStatement +@init { pushMsg("truncate table statement", state); } +@after { popMsg(state); } + : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); + +createIndexStatement +@init { pushMsg("create index statement", state);} +@after {popMsg(state);} + : KW_CREATE KW_INDEX indexName=identifier + KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN + KW_AS typeName=StringLiteral + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment? + ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment?) + ; + +indexComment +@init { pushMsg("comment on an index", state);} +@after {popMsg(state);} + : + KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) + ; + +autoRebuild +@init { pushMsg("auto rebuild index", state);} +@after {popMsg(state);} + : KW_WITH KW_DEFERRED KW_REBUILD + ->^(TOK_DEFERRED_REBUILDINDEX) + ; + +indexTblName +@init { pushMsg("index table name", state);} +@after {popMsg(state);} + : KW_IN KW_TABLE indexTbl=tableName + ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) + ; + +indexPropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_IDXPROPERTIES! indexProperties + ; + +indexProperties +@init { pushMsg("index properties", state); } +@after { popMsg(state); } + : + LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) + ; + +indexPropertiesList +@init { pushMsg("index properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) + ; + +dropIndexStatement +@init { pushMsg("drop index statement", state);} +@after {popMsg(state);} + : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName + ->^(TOK_DROPINDEX $indexName $tab ifExists?) + ; + +dropTableStatement +@init { pushMsg("drop statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? + -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) + ; + +alterStatement +@init { pushMsg("alter statement", state); } +@after { popMsg(state); } + : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) + | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) + | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix + | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix + ; + +alterTableStatementSuffix +@init { pushMsg("alter table statement", state); } +@after { popMsg(state); } + : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] + | alterStatementSuffixDropPartitions[true] + | alterStatementSuffixAddPartitions[true] + | alterStatementSuffixTouch + | alterStatementSuffixArchive + | alterStatementSuffixUnArchive + | alterStatementSuffixProperties + | alterStatementSuffixSkewedby + | alterStatementSuffixExchangePartition + | alterStatementPartitionKeyType + | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? + ; + +alterTblPartitionStatementSuffix +@init {pushMsg("alter table partition statement suffix", state);} +@after {popMsg(state);} + : alterStatementSuffixFileFormat + | alterStatementSuffixLocation + | alterStatementSuffixMergeFiles + | alterStatementSuffixSerdeProperties + | alterStatementSuffixRenamePart + | alterStatementSuffixBucketNum + | alterTblPartitionStatementSuffixSkewedLocation + | alterStatementSuffixClusterbySortby + | alterStatementSuffixCompact + | alterStatementSuffixUpdateStatsCol + | alterStatementSuffixRenameCol + | alterStatementSuffixAddCol + ; + +alterStatementPartitionKeyType +@init {msgs.push("alter partition key type"); } +@after {msgs.pop();} + : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN + -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) + ; + +alterViewStatementSuffix +@init { pushMsg("alter view statement", state); } +@after { popMsg(state); } + : alterViewSuffixProperties + | alterStatementSuffixRename[false] + | alterStatementSuffixAddPartitions[false] + | alterStatementSuffixDropPartitions[false] + | selectStatementWithCTE + ; + +alterIndexStatementSuffix +@init { pushMsg("alter index statement", state); } +@after { popMsg(state); } + : indexName=identifier KW_ON tableName partitionSpec? + ( + KW_REBUILD + ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) + | + KW_SET KW_IDXPROPERTIES + indexProperties + ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) + ) + ; + +alterDatabaseStatementSuffix +@init { pushMsg("alter database statement", state); } +@after { popMsg(state); } + : alterDatabaseSuffixProperties + | alterDatabaseSuffixSetOwner + ; + +alterDatabaseSuffixProperties +@init { pushMsg("alter database properties statement", state); } +@after { popMsg(state); } + : name=identifier KW_SET KW_DBPROPERTIES dbProperties + -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) + ; + +alterDatabaseSuffixSetOwner +@init { pushMsg("alter database set owner", state); } +@after { popMsg(state); } + : dbName=identifier KW_SET KW_OWNER principalName + -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) + ; + +alterStatementSuffixRename[boolean table] +@init { pushMsg("rename statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO tableName + -> { table }? ^(TOK_ALTERTABLE_RENAME tableName) + -> ^(TOK_ALTERVIEW_RENAME tableName) + ; + +alterStatementSuffixAddCol +@init { pushMsg("add column statement", state); } +@after { popMsg(state); } + : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? + -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) + -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) + ; + +alterStatementSuffixRenameCol +@init { pushMsg("rename column name", state); } +@after { popMsg(state); } + : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? + ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) + ; + +alterStatementSuffixUpdateStatsCol +@init { pushMsg("update column statistics", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementChangeColPosition + : first=KW_FIRST|KW_AFTER afterCol=identifier + ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) + -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) + ; + +alterStatementSuffixAddPartitions[boolean table] +@init { pushMsg("add partition statement", state); } +@after { popMsg(state); } + : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ + -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + ; + +alterStatementSuffixAddPartitionsElement + : partitionSpec partitionLocation? + ; + +alterStatementSuffixTouch +@init { pushMsg("touch statement", state); } +@after { popMsg(state); } + : KW_TOUCH (partitionSpec)* + -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) + ; + +alterStatementSuffixArchive +@init { pushMsg("archive statement", state); } +@after { popMsg(state); } + : KW_ARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) + ; + +alterStatementSuffixUnArchive +@init { pushMsg("unarchive statement", state); } +@after { popMsg(state); } + : KW_UNARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) + ; + +partitionLocation +@init { pushMsg("partition location", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) + ; + +alterStatementSuffixDropPartitions[boolean table] +@init { pushMsg("drop partition statement", state); } +@after { popMsg(state); } + : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? + -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) + -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) + ; + +alterStatementSuffixProperties +@init { pushMsg("alter properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) + ; + +alterViewSuffixProperties +@init { pushMsg("alter view properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) + ; + +alterStatementSuffixSerdeProperties +@init { pushMsg("alter serdes statement", state); } +@after { popMsg(state); } + : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? + -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) + | KW_SET KW_SERDEPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) + ; + +tablePartitionPrefix +@init {pushMsg("table partition prefix", state);} +@after {popMsg(state);} + : tableName partitionSpec? + ->^(TOK_TABLE_PARTITION tableName partitionSpec?) + ; + +alterStatementSuffixFileFormat +@init {pushMsg("alter fileformat statement", state); } +@after {popMsg(state);} + : KW_SET KW_FILEFORMAT fileFormat + -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) + ; + +alterStatementSuffixClusterbySortby +@init {pushMsg("alter partition cluster by sort by statement", state);} +@after {popMsg(state);} + : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) + | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) + | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) + ; + +alterTblPartitionStatementSuffixSkewedLocation +@init {pushMsg("alter partition skewed location", state);} +@after {popMsg(state);} + : KW_SET KW_SKEWED KW_LOCATION skewedLocations + -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) + ; + +skewedLocations +@init { pushMsg("skewed locations", state); } +@after { popMsg(state); } + : + LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) + ; + +skewedLocationsList +@init { pushMsg("skewed locations list", state); } +@after { popMsg(state); } + : + skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) + ; + +skewedLocationMap +@init { pushMsg("specifying skewed location map", state); } +@after { popMsg(state); } + : + key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) + ; + +alterStatementSuffixLocation +@init {pushMsg("alter location", state);} +@after {popMsg(state);} + : KW_SET KW_LOCATION newLoc=StringLiteral + -> ^(TOK_ALTERTABLE_LOCATION $newLoc) + ; + + +alterStatementSuffixSkewedby +@init {pushMsg("alter skewed by statement", state);} +@after{popMsg(state);} + : tableSkewed + ->^(TOK_ALTERTABLE_SKEWED tableSkewed) + | + KW_NOT KW_SKEWED + ->^(TOK_ALTERTABLE_SKEWED) + | + KW_NOT storedAsDirs + ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) + ; + +alterStatementSuffixExchangePartition +@init {pushMsg("alter exchange partition", state);} +@after{popMsg(state);} + : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName + -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) + ; + +alterStatementSuffixRenamePart +@init { pushMsg("alter table rename partition statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO partitionSpec + ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) + ; + +alterStatementSuffixStatsPart +@init { pushMsg("alter table stats partition statement", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementSuffixMergeFiles +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_CONCATENATE + -> ^(TOK_ALTERTABLE_MERGEFILES) + ; + +alterStatementSuffixBucketNum +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $num) + ; + +alterStatementSuffixCompact +@init { msgs.push("compaction request"); } +@after { msgs.pop(); } + : KW_COMPACT compactType=StringLiteral + -> ^(TOK_ALTERTABLE_COMPACT $compactType) + ; + + +fileFormat +@init { pushMsg("file format specification", state); } +@after { popMsg(state); } + : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?) + | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tabTypeExpr +@init { pushMsg("specifying table types", state); } +@after { popMsg(state); } + : identifier (DOT^ identifier)? + (identifier (DOT^ + ( + (KW_ELEM_TYPE) => KW_ELEM_TYPE + | + (KW_KEY_TYPE) => KW_KEY_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE + | identifier + ))* + )? + ; + +partTypeExpr +@init { pushMsg("specifying table partitions", state); } +@after { popMsg(state); } + : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) + ; + +tabPartColTypeExpr +@init { pushMsg("specifying table partitions columnName", state); } +@after { popMsg(state); } + : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) + ; + +descStatement +@init { pushMsg("describe statement", state); } +@after { popMsg(state); } + : + (KW_DESCRIBE|KW_DESC) + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) + | + (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) + | + (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) + | + parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) + ) + ; + +analyzeStatement +@init { pushMsg("analyze statement", state); } +@after { popMsg(state); } + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? + -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) + ; + +showStatement +@init { pushMsg("show statement", state); } +@after { popMsg(state); } + : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) + | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) + | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWCOLUMNS tableName $db_name?) + | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_CREATE ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) + | + KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) + ) + | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? + -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) + | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) + | KW_SHOW KW_LOCKS + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) + | + (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) + ) + | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) + | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) + | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) + | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) + ; + +lockStatement +@init { pushMsg("lock statement", state); } +@after { popMsg(state); } + : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) + ; + +lockDatabase +@init { pushMsg("lock database statement", state); } +@after { popMsg(state); } + : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) + ; + +lockMode +@init { pushMsg("lock mode", state); } +@after { popMsg(state); } + : KW_SHARED | KW_EXCLUSIVE + ; + +unlockStatement +@init { pushMsg("unlock statement", state); } +@after { popMsg(state); } + : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) + ; + +unlockDatabase +@init { pushMsg("unlock database statement", state); } +@after { popMsg(state); } + : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) + ; + +createRoleStatement +@init { pushMsg("create role", state); } +@after { popMsg(state); } + : KW_CREATE KW_ROLE roleName=identifier + -> ^(TOK_CREATEROLE $roleName) + ; + +dropRoleStatement +@init {pushMsg("drop role", state);} +@after {popMsg(state);} + : KW_DROP KW_ROLE roleName=identifier + -> ^(TOK_DROPROLE $roleName) + ; + +grantPrivileges +@init {pushMsg("grant privileges", state);} +@after {popMsg(state);} + : KW_GRANT privList=privilegeList + privilegeObject? + KW_TO principalSpecification + withGrantOption? + -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) + ; + +revokePrivileges +@init {pushMsg("revoke privileges", state);} +@afer {popMsg(state);} + : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification + -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) + ; + +grantRole +@init {pushMsg("grant role", state);} +@after {popMsg(state);} + : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? + -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) + ; + +revokeRole +@init {pushMsg("revoke role", state);} +@after {popMsg(state);} + : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification + -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) + ; + +showRoleGrants +@init {pushMsg("show role grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLE KW_GRANT principalName + -> ^(TOK_SHOW_ROLE_GRANT principalName) + ; + + +showRoles +@init {pushMsg("show roles", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLES + -> ^(TOK_SHOW_ROLES) + ; + +showCurrentRole +@init {pushMsg("show current role", state);} +@after {popMsg(state);} + : KW_SHOW KW_CURRENT KW_ROLES + -> ^(TOK_SHOW_SET_ROLE) + ; + +setRole +@init {pushMsg("set role", state);} +@after {popMsg(state);} + : KW_SET KW_ROLE + ( + (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) + | + (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) + | + identifier -> ^(TOK_SHOW_SET_ROLE identifier) + ) + ; + +showGrants +@init {pushMsg("show grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? + -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) + ; + +showRolePrincipals +@init {pushMsg("show role principals", state);} +@after {popMsg(state);} + : KW_SHOW KW_PRINCIPALS roleName=identifier + -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) + ; + + +privilegeIncludeColObject +@init {pushMsg("privilege object including columns", state);} +@after {popMsg(state);} + : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) + | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) + ; + +privilegeObject +@init {pushMsg("privilege object", state);} +@after {popMsg(state);} + : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) + ; + +// database or table type. Type is optional, default type is table +privObject + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privObjectCols + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privilegeList +@init {pushMsg("grant privilege list", state);} +@after {popMsg(state);} + : privlegeDef (COMMA privlegeDef)* + -> ^(TOK_PRIVILEGE_LIST privlegeDef+) + ; + +privlegeDef +@init {pushMsg("grant privilege", state);} +@after {popMsg(state);} + : privilegeType (LPAREN cols=columnNameList RPAREN)? + -> ^(TOK_PRIVILEGE privilegeType $cols?) + ; + +privilegeType +@init {pushMsg("privilege type", state);} +@after {popMsg(state);} + : KW_ALL -> ^(TOK_PRIV_ALL) + | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) + | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) + | KW_CREATE -> ^(TOK_PRIV_CREATE) + | KW_DROP -> ^(TOK_PRIV_DROP) + | KW_INDEX -> ^(TOK_PRIV_INDEX) + | KW_LOCK -> ^(TOK_PRIV_LOCK) + | KW_SELECT -> ^(TOK_PRIV_SELECT) + | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) + | KW_INSERT -> ^(TOK_PRIV_INSERT) + | KW_DELETE -> ^(TOK_PRIV_DELETE) + ; + +principalSpecification +@init { pushMsg("user/group/role name list", state); } +@after { popMsg(state); } + : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) + ; + +principalName +@init {pushMsg("user|group|role name", state);} +@after {popMsg(state);} + : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) + | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) + | KW_ROLE identifier -> ^(TOK_ROLE identifier) + ; + +withGrantOption +@init {pushMsg("with grant option", state);} +@after {popMsg(state);} + : KW_WITH KW_GRANT KW_OPTION + -> ^(TOK_GRANT_WITH_OPTION) + ; + +grantOptionFor +@init {pushMsg("grant option for", state);} +@after {popMsg(state);} + : KW_GRANT KW_OPTION KW_FOR + -> ^(TOK_GRANT_OPTION_FOR) +; + +adminOptionFor +@init {pushMsg("admin option for", state);} +@after {popMsg(state);} + : KW_ADMIN KW_OPTION KW_FOR + -> ^(TOK_ADMIN_OPTION_FOR) +; + +withAdminOption +@init {pushMsg("with admin option", state);} +@after {popMsg(state);} + : KW_WITH KW_ADMIN KW_OPTION + -> ^(TOK_GRANT_WITH_ADMIN_OPTION) + ; + +metastoreCheck +@init { pushMsg("metastore check statement", state); } +@after { popMsg(state); } + : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? + -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) + ; + +resourceList +@init { pushMsg("resource list", state); } +@after { popMsg(state); } + : + resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) + ; + +resource +@init { pushMsg("resource", state); } +@after { popMsg(state); } + : + resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) + ; + +resourceType +@init { pushMsg("resource type", state); } +@after { popMsg(state); } + : + KW_JAR -> ^(TOK_JAR) + | + KW_FILE -> ^(TOK_FILE) + | + KW_ARCHIVE -> ^(TOK_ARCHIVE) + ; + +createFunctionStatement +@init { pushMsg("create function statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral + (KW_USING rList=resourceList)? + -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) + -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) + ; + +dropFunctionStatement +@init { pushMsg("drop function statement", state); } +@after { popMsg(state); } + : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier + -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) + -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) + ; + +reloadFunctionStatement +@init { pushMsg("reload function statement", state); } +@after { popMsg(state); } + : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); + +createMacroStatement +@init { pushMsg("create macro statement", state); } +@after { popMsg(state); } + : KW_CREATE KW_TEMPORARY KW_MACRO Identifier + LPAREN columnNameTypeList? RPAREN expression + -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression) + ; + +dropMacroStatement +@init { pushMsg("drop macro statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier + -> ^(TOK_DROPMACRO Identifier ifExists?) + ; + +createViewStatement +@init { + pushMsg("create view statement", state); +} +@after { popMsg(state); } + : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName + (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? + tablePropertiesPrefixed? + KW_AS + selectStatementWithCTE + -> ^(TOK_CREATEVIEW $name orReplace? + ifNotExists? + columnNameCommentList? + tableComment? + viewPartition? + tablePropertiesPrefixed? + selectStatementWithCTE + ) + ; + +viewPartition +@init { pushMsg("view partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN + -> ^(TOK_VIEWPARTCOLS columnNameList) + ; + +dropViewStatement +@init { pushMsg("drop view statement", state); } +@after { popMsg(state); } + : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) + ; + +showFunctionIdentifier +@init { pushMsg("identifier for show function statement", state); } +@after { popMsg(state); } + : functionIdentifier + | StringLiteral + ; + +showStmtIdentifier +@init { pushMsg("identifier for show statement", state); } +@after { popMsg(state); } + : identifier + | StringLiteral + ; + +tableComment +@init { pushMsg("table's comment", state); } +@after { popMsg(state); } + : + KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) + ; + +tablePartition +@init { pushMsg("table partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN + -> ^(TOK_TABLEPARTCOLS columnNameTypeList) + ; + +tableBuckets +@init { pushMsg("table buckets specification", state); } +@after { popMsg(state); } + : + KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) + ; + +tableSkewed +@init { pushMsg("table skewed specification", state); } +@after { popMsg(state); } + : + KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? + -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) + ; + +rowFormat +@init { pushMsg("serde specification", state); } +@after { popMsg(state); } + : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) + | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) + | -> ^(TOK_SERDE) + ; + +recordReader +@init { pushMsg("record reader specification", state); } +@after { popMsg(state); } + : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) + | -> ^(TOK_RECORDREADER) + ; + +recordWriter +@init { pushMsg("record writer specification", state); } +@after { popMsg(state); } + : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) + | -> ^(TOK_RECORDWRITER) + ; + +rowFormatSerde +@init { pushMsg("serde format specification", state); } +@after { popMsg(state); } + : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_SERDENAME $name $serdeprops?) + ; + +rowFormatDelimited +@init { pushMsg("serde properties specification", state); } +@after { popMsg(state); } + : + KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? + -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) + ; + +tableRowFormat +@init { pushMsg("table row format specification", state); } +@after { popMsg(state); } + : + rowFormatDelimited + -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) + | rowFormatSerde + -> ^(TOK_TABLESERIALIZER rowFormatSerde) + ; + +tablePropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_TBLPROPERTIES! tableProperties + ; + +tableProperties +@init { pushMsg("table properties", state); } +@after { popMsg(state); } + : + LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) + ; + +tablePropertiesList +@init { pushMsg("table properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) + | + keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) + ; + +keyValueProperty +@init { pushMsg("specifying key/value property", state); } +@after { popMsg(state); } + : + key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) + ; + +keyProperty +@init { pushMsg("specifying key property", state); } +@after { popMsg(state); } + : + key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) + ; + +tableRowFormatFieldIdentifier +@init { pushMsg("table row format's field separator", state); } +@after { popMsg(state); } + : + KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? + -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) + ; + +tableRowFormatCollItemsIdentifier +@init { pushMsg("table row format's column separator", state); } +@after { popMsg(state); } + : + KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) + ; + +tableRowFormatMapKeysIdentifier +@init { pushMsg("table row format's map key separator", state); } +@after { popMsg(state); } + : + KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) + ; + +tableRowFormatLinesIdentifier +@init { pushMsg("table row format's line separator", state); } +@after { popMsg(state); } + : + KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) + ; + +tableRowNullFormat +@init { pushMsg("table row format's null specifier", state); } +@after { popMsg(state); } + : + KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) + ; +tableFileFormat +@init { pushMsg("table file format specification", state); } +@after { popMsg(state); } + : + (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) + | KW_STORED KW_BY storageHandler=StringLiteral + (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) + | KW_STORED KW_AS genericSpec=identifier + -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tableLocation +@init { pushMsg("table location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) + ; + +columnNameTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) + ; + +columnNameColonTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) + ; + +columnNameList +@init { pushMsg("column name list", state); } +@after { popMsg(state); } + : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) + ; + +columnName +@init { pushMsg("column name", state); } +@after { popMsg(state); } + : + identifier + ; + +extColumnName +@init { pushMsg("column name for complex types", state); } +@after { popMsg(state); } + : + identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* + ; + +columnNameOrderList +@init { pushMsg("column name order list", state); } +@after { popMsg(state); } + : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) + ; + +skewedValueElement +@init { pushMsg("skewed value element", state); } +@after { popMsg(state); } + : + skewedColumnValues + | skewedColumnValuePairList + ; + +skewedColumnValuePairList +@init { pushMsg("column value pair list", state); } +@after { popMsg(state); } + : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) + ; + +skewedColumnValuePair +@init { pushMsg("column value pair", state); } +@after { popMsg(state); } + : + LPAREN colValues=skewedColumnValues RPAREN + -> ^(TOK_TABCOLVALUES $colValues) + ; + +skewedColumnValues +@init { pushMsg("column values", state); } +@after { popMsg(state); } + : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) + ; + +skewedColumnValue +@init { pushMsg("column value", state); } +@after { popMsg(state); } + : + constant + ; + +skewedValueLocationElement +@init { pushMsg("skewed value location element", state); } +@after { popMsg(state); } + : + skewedColumnValue + | skewedColumnValuePair + ; + +columnNameOrder +@init { pushMsg("column name order", state); } +@after { popMsg(state); } + : identifier (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) + -> ^(TOK_TABSORTCOLNAMEDESC identifier) + ; + +columnNameCommentList +@init { pushMsg("column name comment list", state); } +@after { popMsg(state); } + : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) + ; + +columnNameComment +@init { pushMsg("column name comment", state); } +@after { popMsg(state); } + : colName=identifier (KW_COMMENT comment=StringLiteral)? + -> ^(TOK_TABCOL $colName TOK_NULL $comment?) + ; + +columnRefOrder +@init { pushMsg("column order", state); } +@after { popMsg(state); } + : expression (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) + -> ^(TOK_TABSORTCOLNAMEDESC expression) + ; + +columnNameType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier colType (KW_COMMENT comment=StringLiteral)? + -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()} + -> {$comment == null}? ^(TOK_TABCOL $colName colType) + -> ^(TOK_TABCOL $colName colType $comment) + ; + +columnNameColonType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)? + -> {$comment == null}? ^(TOK_TABCOL $colName colType) + -> ^(TOK_TABCOL $colName colType $comment) + ; + +colType +@init { pushMsg("column type", state); } +@after { popMsg(state); } + : type + ; + +colTypeList +@init { pushMsg("column type list", state); } +@after { popMsg(state); } + : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+) + ; + +type + : primitiveType + | listType + | structType + | mapType + | unionType; + +primitiveType +@init { pushMsg("primitive type specification", state); } +@after { popMsg(state); } + : KW_TINYINT -> TOK_TINYINT + | KW_SMALLINT -> TOK_SMALLINT + | KW_INT -> TOK_INT + | KW_BIGINT -> TOK_BIGINT + | KW_BOOLEAN -> TOK_BOOLEAN + | KW_FLOAT -> TOK_FLOAT + | KW_DOUBLE -> TOK_DOUBLE + | KW_DATE -> TOK_DATE + | KW_DATETIME -> TOK_DATETIME + | KW_TIMESTAMP -> TOK_TIMESTAMP + // Uncomment to allow intervals as table column types + //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH + //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME + | KW_STRING -> TOK_STRING + | KW_BINARY -> TOK_BINARY + | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?) + | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length) + | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length) + ; + +listType +@init { pushMsg("list type", state); } +@after { popMsg(state); } + : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type) + ; + +structType +@init { pushMsg("struct type", state); } +@after { popMsg(state); } + : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList) + ; + +mapType +@init { pushMsg("map type", state); } +@after { popMsg(state); } + : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN + -> ^(TOK_MAP $left $right) + ; + +unionType +@init { pushMsg("uniontype type", state); } +@after { popMsg(state); } + : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) + ; + +setOperator +@init { pushMsg("set operator", state); } +@after { popMsg(state); } + : KW_UNION KW_ALL -> ^(TOK_UNIONALL) + | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) + ; + +queryStatementExpression[boolean topLevel] + : + /* Would be nice to do this as a gated semantic perdicate + But the predicate gets pushed as a lookahead decision. + Calling rule doesnot know about topLevel + */ + (w=withClause {topLevel}?)? + queryStatementExpressionBody[topLevel] { + if ($w.tree != null) { + $queryStatementExpressionBody.tree.insertChild(0, $w.tree); + } + } + -> queryStatementExpressionBody + ; + +queryStatementExpressionBody[boolean topLevel] + : + fromStatement[topLevel] + | regularBody[topLevel] + ; + +withClause + : + KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+) +; + +cteStatement + : + identifier KW_AS LPAREN queryStatementExpression[false] RPAREN + -> ^(TOK_SUBQUERY queryStatementExpression identifier) +; + +fromStatement[boolean topLevel] +: (singleFromStatement -> singleFromStatement) + (u=setOperator r=singleFromStatement + -> ^($u {$fromStatement.tree} $r) + )* + -> {u != null && topLevel}? ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + {$fromStatement.tree} + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> {$fromStatement.tree} + ; + + +singleFromStatement + : + fromClause + ( b+=body )+ -> ^(TOK_QUERY fromClause body+) + ; + +/* +The valuesClause rule below ensures that the parse tree for +"insert into table FOO values (1,2),(3,4)" looks the same as +"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look +very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name +is implicit, it's represented as TOK_ANONYMOUS. +*/ +regularBody[boolean topLevel] + : + i=insertClause + ( + s=selectStatement[topLevel] + {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree} + | + valuesClause + -> ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause) + ) + ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))) + ) + ) + | + selectStatement[topLevel] + ; + +selectStatement[boolean topLevel] + : + ( + s=selectClause + f=fromClause? + w=whereClause? + g=groupByClause? + h=havingClause? + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + $s $w? $g? $h? $o? $c? + $d? $sort? $win? $l?)) + ) + (set=setOpSelectStatement[$selectStatement.tree, topLevel])? + -> {set == null}? + {$selectStatement.tree} + -> {o==null && c==null && d==null && sort==null && l==null}? + {$set.tree} + -> {throwSetOpException()} + ; + +setOpSelectStatement[CommonTree t, boolean topLevel] + : + (u=setOperator b=simpleSelectStatement + -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? + ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + ^(TOK_UNIONALL {$t} $b) + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) + ) + ) + -> ^(TOK_UNIONALL {$t} $b) + )+ + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}? + {$setOpSelectStatement.tree} + -> ^(TOK_QUERY + ^(TOK_FROM + ^(TOK_SUBQUERY + {$setOpSelectStatement.tree} + {adaptor.create(Identifier, generateUnionAlias())} + ) + ) + ^(TOK_INSERT + ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) + $o? $c? $d? $sort? $win? $l? + ) + ) + ; + +simpleSelectStatement + : + selectClause + fromClause? + whereClause? + groupByClause? + havingClause? + ((window_clause) => window_clause)? + -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + selectClause whereClause? groupByClause? havingClause? window_clause?)) + ; + +selectStatementWithCTE + : + (w=withClause)? + selectStatement[true] { + if ($w.tree != null) { + $selectStatement.tree.insertChild(0, $w.tree); + } + } + -> selectStatement + ; + +body + : + insertClause + selectClause + lateralView? + whereClause? + groupByClause? + havingClause? + orderByClause? + clusterByClause? + distributeByClause? + sortByClause? + window_clause? + limitClause? -> ^(TOK_INSERT insertClause + selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? + distributeByClause? sortByClause? window_clause? limitClause?) + | + selectClause + lateralView? + whereClause? + groupByClause? + havingClause? + orderByClause? + clusterByClause? + distributeByClause? + sortByClause? + window_clause? + limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) + selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? + distributeByClause? sortByClause? window_clause? limitClause?) + ; + +insertClause +@init { pushMsg("insert clause", state); } +@after { popMsg(state); } + : + KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?) + | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)? + -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?) + ; + +destination +@init { pushMsg("destination specification", state); } +@after { popMsg(state); } + : + (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat? + -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?) + | KW_TABLE tableOrPartition -> tableOrPartition + ; + +limitClause +@init { pushMsg("limit clause", state); } +@after { popMsg(state); } + : + KW_LIMIT num=Number -> ^(TOK_LIMIT $num) + ; + +//DELETE FROM WHERE ...; +deleteStatement +@init { pushMsg("delete statement", state); } +@after { popMsg(state); } + : + KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?) + ; + +/*SET = (3 + col2)*/ +columnAssignmentClause + : + tableOrColumn EQUAL^ precedencePlusExpression + ; + +/*SET col1 = 5, col2 = (4 + col4), ...*/ +setColumnsClause + : + KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) + ; + +/* + UPDATE
    + SET col1 = val1, col2 = val2... WHERE ... +*/ +updateStatement +@init { pushMsg("update statement", state); } +@after { popMsg(state); } + : + KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) + ; + +/* +BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of +"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. +*/ +sqlTransactionStatement +@init { pushMsg("transaction statement", state); } +@after { popMsg(state); } + : + startTransactionStatement + | commitStatement + | rollbackStatement + | setAutoCommitStatement + ; + +startTransactionStatement + : + KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) + ; + +transactionMode + : + isolationLevel + | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) + ; + +transactionAccessMode + : + KW_READ KW_ONLY -> TOK_TXN_READ_ONLY + | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE + ; + +isolationLevel + : + KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) + ; + +/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ +levelOfIsolation + : + KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT + ; + +commitStatement + : + KW_COMMIT ( KW_WORK )? -> TOK_COMMIT + ; + +rollbackStatement + : + KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK + ; +setAutoCommitStatement + : + KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) + ; +/* +END user defined transaction boundaries +*/ diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java new file mode 100644 index 0000000000000..35ecdc5ad10a9 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.antlr.runtime.RecognitionException; +import org.antlr.runtime.Token; +import org.antlr.runtime.TokenStream; +import org.antlr.runtime.tree.CommonErrorNode; + +public class ASTErrorNode extends ASTNode { + + /** + * + */ + private static final long serialVersionUID = 1L; + CommonErrorNode delegate; + + public ASTErrorNode(TokenStream input, Token start, Token stop, + RecognitionException e){ + delegate = new CommonErrorNode(input,start,stop,e); + } + + @Override + public boolean isNil() { return delegate.isNil(); } + + @Override + public int getType() { return delegate.getType(); } + + @Override + public String getText() { return delegate.getText(); } + @Override + public String toString() { return delegate.toString(); } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java new file mode 100644 index 0000000000000..33d9322b628ec --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java @@ -0,0 +1,245 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.antlr.runtime.Token; +import org.antlr.runtime.tree.CommonTree; +import org.antlr.runtime.tree.Tree; +import org.apache.hadoop.hive.ql.lib.Node; + +public class ASTNode extends CommonTree implements Node, Serializable { + private static final long serialVersionUID = 1L; + private transient StringBuffer astStr; + private transient int startIndx = -1; + private transient int endIndx = -1; + private transient ASTNode rootNode; + private transient boolean isValidASTStr; + + public ASTNode() { + } + + /** + * Constructor. + * + * @param t + * Token for the CommonTree Node + */ + public ASTNode(Token t) { + super(t); + } + + public ASTNode(ASTNode node) { + super(node); + } + + @Override + public Tree dupNode() { + return new ASTNode(this); + } + + /* + * (non-Javadoc) + * + * @see org.apache.hadoop.hive.ql.lib.Node#getChildren() + */ + @Override + public ArrayList getChildren() { + if (super.getChildCount() == 0) { + return null; + } + + ArrayList ret_vec = new ArrayList(); + for (int i = 0; i < super.getChildCount(); ++i) { + ret_vec.add((Node) super.getChild(i)); + } + + return ret_vec; + } + + /* + * (non-Javadoc) + * + * @see org.apache.hadoop.hive.ql.lib.Node#getName() + */ + @Override + public String getName() { + return (Integer.valueOf(super.getToken().getType())).toString(); + } + + public String dump() { + StringBuilder sb = new StringBuilder("\n"); + dump(sb, ""); + return sb.toString(); + } + + private StringBuilder dump(StringBuilder sb, String ws) { + sb.append(ws); + sb.append(toString()); + sb.append("\n"); + + ArrayList children = getChildren(); + if (children != null) { + for (Node node : getChildren()) { + if (node instanceof ASTNode) { + ((ASTNode) node).dump(sb, ws + " "); + } else { + sb.append(ws); + sb.append(" NON-ASTNODE!!"); + sb.append("\n"); + } + } + } + return sb; + } + + private ASTNode getRootNodeWithValidASTStr(boolean useMemoizedRoot) { + if (useMemoizedRoot && rootNode != null && rootNode.parent == null && + rootNode.hasValidMemoizedString()) { + return rootNode; + } + ASTNode retNode = this; + while (retNode.parent != null) { + retNode = (ASTNode) retNode.parent; + } + rootNode=retNode; + if (!rootNode.isValidASTStr) { + rootNode.astStr = new StringBuffer(); + rootNode.toStringTree(rootNode); + rootNode.isValidASTStr = true; + } + return retNode; + } + + private boolean hasValidMemoizedString() { + return isValidASTStr && astStr != null; + } + + private void resetRootInformation() { + // Reset the previously stored rootNode string + if (rootNode != null) { + rootNode.astStr = null; + rootNode.isValidASTStr = false; + } + } + + private int getMemoizedStringLen() { + return astStr == null ? 0 : astStr.length(); + } + + private String getMemoizedSubString(int start, int end) { + return (astStr == null || start < 0 || end > astStr.length() || start >= end) ? null : + astStr.subSequence(start, end).toString(); + } + + private void addtoMemoizedString(String string) { + if (astStr == null) { + astStr = new StringBuffer(); + } + astStr.append(string); + } + + @Override + public void setParent(Tree t) { + super.setParent(t); + resetRootInformation(); + } + + @Override + public void addChild(Tree t) { + super.addChild(t); + resetRootInformation(); + } + + @Override + public void addChildren(List kids) { + super.addChildren(kids); + resetRootInformation(); + } + + @Override + public void setChild(int i, Tree t) { + super.setChild(i, t); + resetRootInformation(); + } + + @Override + public void insertChild(int i, Object t) { + super.insertChild(i, t); + resetRootInformation(); + } + + @Override + public Object deleteChild(int i) { + Object ret = super.deleteChild(i); + resetRootInformation(); + return ret; + } + + @Override + public void replaceChildren(int startChildIndex, int stopChildIndex, Object t) { + super.replaceChildren(startChildIndex, stopChildIndex, t); + resetRootInformation(); + } + + @Override + public String toStringTree() { + + // The root might have changed because of tree modifications. + // Compute the new root for this tree and set the astStr. + getRootNodeWithValidASTStr(true); + + // If rootNotModified is false, then startIndx and endIndx will be stale. + if (startIndx >= 0 && endIndx <= rootNode.getMemoizedStringLen()) { + return rootNode.getMemoizedSubString(startIndx, endIndx); + } + return toStringTree(rootNode); + } + + private String toStringTree(ASTNode rootNode) { + this.rootNode = rootNode; + startIndx = rootNode.getMemoizedStringLen(); + // Leaf node + if ( children==null || children.size()==0 ) { + rootNode.addtoMemoizedString(this.toString()); + endIndx = rootNode.getMemoizedStringLen(); + return this.toString(); + } + if ( !isNil() ) { + rootNode.addtoMemoizedString("("); + rootNode.addtoMemoizedString(this.toString()); + rootNode.addtoMemoizedString(" "); + } + for (int i = 0; children!=null && i < children.size(); i++) { + ASTNode t = (ASTNode)children.get(i); + if ( i>0 ) { + rootNode.addtoMemoizedString(" "); + } + t.toStringTree(rootNode); + } + if ( !isNil() ) { + rootNode.addtoMemoizedString(")"); + } + endIndx = rootNode.getMemoizedStringLen(); + return rootNode.getMemoizedSubString(startIndx, endIndx); + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java new file mode 100644 index 0000000000000..c77198b087cbd --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.util.ArrayList; +import org.antlr.runtime.ANTLRStringStream; +import org.antlr.runtime.CharStream; +import org.antlr.runtime.NoViableAltException; +import org.antlr.runtime.RecognitionException; +import org.antlr.runtime.Token; +import org.antlr.runtime.TokenRewriteStream; +import org.antlr.runtime.TokenStream; +import org.antlr.runtime.tree.CommonTree; +import org.antlr.runtime.tree.CommonTreeAdaptor; +import org.antlr.runtime.tree.TreeAdaptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.hadoop.hive.ql.Context; + +/** + * ParseDriver. + * + */ +public class ParseDriver { + + private static final Logger LOG = LoggerFactory.getLogger("hive.ql.parse.ParseDriver"); + + /** + * ANTLRNoCaseStringStream. + * + */ + //This class provides and implementation for a case insensitive token checker + //for the lexical analysis part of antlr. By converting the token stream into + //upper case at the time when lexical rules are checked, this class ensures that the + //lexical rules need to just match the token with upper case letters as opposed to + //combination of upper case and lower case characters. This is purely used for matching lexical + //rules. The actual token text is stored in the same way as the user input without + //actually converting it into an upper case. The token values are generated by the consume() + //function of the super class ANTLRStringStream. The LA() function is the lookahead function + //and is purely used for matching lexical rules. This also means that the grammar will only + //accept capitalized tokens in case it is run from other tools like antlrworks which + //do not have the ANTLRNoCaseStringStream implementation. + public class ANTLRNoCaseStringStream extends ANTLRStringStream { + + public ANTLRNoCaseStringStream(String input) { + super(input); + } + + @Override + public int LA(int i) { + + int returnChar = super.LA(i); + if (returnChar == CharStream.EOF) { + return returnChar; + } else if (returnChar == 0) { + return returnChar; + } + + return Character.toUpperCase((char) returnChar); + } + } + + /** + * HiveLexerX. + * + */ + public class HiveLexerX extends SparkSqlLexer { + + private final ArrayList errors; + + public HiveLexerX(CharStream input) { + super(input); + errors = new ArrayList(); + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + errors.add(new ParseError(this, e, tokenNames)); + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + if (e instanceof NoViableAltException) { + // @SuppressWarnings("unused") + // NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and "(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "character " + getCharErrorDisplay(e.c) + " not supported here"; + } else { + msg = super.getErrorMessage(e, tokenNames); + } + + return msg; + } + + public ArrayList getErrors() { + return errors; + } + + } + + /** + * Tree adaptor for making antlr return ASTNodes instead of CommonTree nodes + * so that the graph walking algorithms and the rules framework defined in + * ql.lib can be used with the AST Nodes. + */ + public static final TreeAdaptor adaptor = new CommonTreeAdaptor() { + /** + * Creates an ASTNode for the given token. The ASTNode is a wrapper around + * antlr's CommonTree class that implements the Node interface. + * + * @param payload + * The token. + * @return Object (which is actually an ASTNode) for the token. + */ + @Override + public Object create(Token payload) { + return new ASTNode(payload); + } + + @Override + public Object dupNode(Object t) { + + return create(((CommonTree)t).token); + }; + + @Override + public Object errorNode(TokenStream input, Token start, Token stop, RecognitionException e) { + return new ASTErrorNode(input, start, stop, e); + }; + }; + + public ASTNode parse(String command) throws ParseException { + return parse(command, null); + } + + public ASTNode parse(String command, Context ctx) + throws ParseException { + return parse(command, ctx, true); + } + + /** + * Parses a command, optionally assigning the parser's token stream to the + * given context. + * + * @param command + * command to parse + * + * @param ctx + * context with which to associate this parser's token stream, or + * null if either no context is available or the context already has + * an existing stream + * + * @return parsed AST + */ + public ASTNode parse(String command, Context ctx, boolean setTokenRewriteStream) + throws ParseException { + LOG.info("Parsing command: " + command); + + HiveLexerX lexer = new HiveLexerX(new ANTLRNoCaseStringStream(command)); + TokenRewriteStream tokens = new TokenRewriteStream(lexer); + if (ctx != null) { + if ( setTokenRewriteStream) { + ctx.setTokenRewriteStream(tokens); + } + lexer.setHiveConf(ctx.getConf()); + } + SparkSqlParser parser = new SparkSqlParser(tokens); + if (ctx != null) { + parser.setHiveConf(ctx.getConf()); + } + parser.setTreeAdaptor(adaptor); + SparkSqlParser.statement_return r = null; + try { + r = parser.statement(); + } catch (RecognitionException e) { + e.printStackTrace(); + throw new ParseException(parser.errors); + } + + if (lexer.getErrors().size() == 0 && parser.errors.size() == 0) { + LOG.info("Parse Completed"); + } else if (lexer.getErrors().size() != 0) { + throw new ParseException(lexer.getErrors()); + } else { + throw new ParseException(parser.errors); + } + + ASTNode tree = (ASTNode) r.getTree(); + tree.setUnknownTokenBoundaries(); + return tree; + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java new file mode 100644 index 0000000000000..b47bcfb2914df --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.antlr.runtime.BaseRecognizer; +import org.antlr.runtime.RecognitionException; + +/** + * + */ +public class ParseError { + private final BaseRecognizer br; + private final RecognitionException re; + private final String[] tokenNames; + + ParseError(BaseRecognizer br, RecognitionException re, String[] tokenNames) { + this.br = br; + this.re = re; + this.tokenNames = tokenNames; + } + + BaseRecognizer getBaseRecognizer() { + return br; + } + + RecognitionException getRecognitionException() { + return re; + } + + String[] getTokenNames() { + return tokenNames; + } + + String getMessage() { + return br.getErrorHeader(re) + " " + br.getErrorMessage(re, tokenNames); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java new file mode 100644 index 0000000000000..fff891ced5550 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.util.ArrayList; + +/** + * ParseException. + * + */ +public class ParseException extends Exception { + + private static final long serialVersionUID = 1L; + ArrayList errors; + + public ParseException(ArrayList errors) { + super(); + this.errors = errors; + } + + @Override + public String getMessage() { + + StringBuilder sb = new StringBuilder(); + for (ParseError err : errors) { + if (sb.length() > 0) { + sb.append('\n'); + } + sb.append(err.getMessage()); + } + + return sb.toString(); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java new file mode 100644 index 0000000000000..a5c2998f86cc1 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + + +/** + * Library of utility functions used in the parse code. + * + */ +public final class ParseUtils { + /** + * Performs a descent of the leftmost branch of a tree, stopping when either a + * node with a non-null token is found or the leaf level is encountered. + * + * @param tree + * candidate node from which to start searching + * + * @return node at which descent stopped + */ + public static ASTNode findRootNonNullToken(ASTNode tree) { + while ((tree.getToken() == null) && (tree.getChildCount() > 0)) { + tree = (org.apache.spark.sql.parser.ASTNode) tree.getChild(0); + } + return tree; + } + + private ParseUtils() { + // prevent instantiation + } + + public static VarcharTypeInfo getVarcharTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() != 1) { + throw new SemanticException("Bad params for type varchar"); + } + + String lengthStr = node.getChild(0).getText(); + return TypeInfoFactory.getVarcharTypeInfo(Integer.valueOf(lengthStr)); + } + + public static CharTypeInfo getCharTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() != 1) { + throw new SemanticException("Bad params for type char"); + } + + String lengthStr = node.getChild(0).getText(); + return TypeInfoFactory.getCharTypeInfo(Integer.valueOf(lengthStr)); + } + + public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) + throws SemanticException { + if (node.getChildCount() > 2) { + throw new SemanticException("Bad params for type decimal"); + } + + int precision = HiveDecimal.USER_DEFAULT_PRECISION; + int scale = HiveDecimal.USER_DEFAULT_SCALE; + + if (node.getChildCount() >= 1) { + String precStr = node.getChild(0).getText(); + precision = Integer.valueOf(precStr); + } + + if (node.getChildCount() == 2) { + String scaleStr = node.getChild(1).getText(); + scale = Integer.valueOf(scaleStr); + } + + return TypeInfoFactory.getDecimalTypeInfo(precision, scale); + } + +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java new file mode 100644 index 0000000000000..4b2015e0df84e --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java @@ -0,0 +1,406 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser; + +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.antlr.runtime.tree.Tree; +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.ql.ErrorMsg; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; + +/** + * SemanticAnalyzer. + * + */ +public abstract class SemanticAnalyzer { + public static String charSetString(String charSetName, String charSetString) + throws SemanticException { + try { + // The character set name starts with a _, so strip that + charSetName = charSetName.substring(1); + if (charSetString.charAt(0) == '\'') { + return new String(unescapeSQLString(charSetString).getBytes(), + charSetName); + } else // hex input is also supported + { + assert charSetString.charAt(0) == '0'; + assert charSetString.charAt(1) == 'x'; + charSetString = charSetString.substring(2); + + byte[] bArray = new byte[charSetString.length() / 2]; + int j = 0; + for (int i = 0; i < charSetString.length(); i += 2) { + int val = Character.digit(charSetString.charAt(i), 16) * 16 + + Character.digit(charSetString.charAt(i + 1), 16); + if (val > 127) { + val = val - 256; + } + bArray[j++] = (byte)val; + } + + String res = new String(bArray, charSetName); + return res; + } + } catch (UnsupportedEncodingException e) { + throw new SemanticException(e); + } + } + + /** + * Remove the encapsulating "`" pair from the identifier. We allow users to + * use "`" to escape identifier for table names, column names and aliases, in + * case that coincide with Hive language keywords. + */ + public static String unescapeIdentifier(String val) { + if (val == null) { + return null; + } + if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { + val = val.substring(1, val.length() - 1); + } + return val; + } + + /** + * Converts parsed key/value properties pairs into a map. + * + * @param prop ASTNode parent of the key/value pairs + * + * @param mapProp property map which receives the mappings + */ + public static void readProps( + ASTNode prop, Map mapProp) { + + for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { + String key = unescapeSQLString(prop.getChild(propChild).getChild(0) + .getText()); + String value = null; + if (prop.getChild(propChild).getChild(1) != null) { + value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); + } + mapProp.put(key, value); + } + } + + private static final int[] multiplier = new int[] {1000, 100, 10, 1}; + + @SuppressWarnings("nls") + public static String unescapeSQLString(String b) { + Character enclosure = null; + + // Some of the strings can be passed in as unicode. For example, the + // delimiter can be passed in as \002 - So, we first check if the + // string is a unicode number, else go back to the old behavior + StringBuilder sb = new StringBuilder(b.length()); + for (int i = 0; i < b.length(); i++) { + + char currentChar = b.charAt(i); + if (enclosure == null) { + if (currentChar == '\'' || b.charAt(i) == '\"') { + enclosure = currentChar; + } + // ignore all other chars outside the enclosure + continue; + } + + if (enclosure.equals(currentChar)) { + enclosure = null; + continue; + } + + if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { + int code = 0; + int base = i + 2; + for (int j = 0; j < 4; j++) { + int digit = Character.digit(b.charAt(j + base), 16); + code += digit * multiplier[j]; + } + sb.append((char)code); + i += 5; + continue; + } + + if (currentChar == '\\' && (i + 4 < b.length())) { + char i1 = b.charAt(i + 1); + char i2 = b.charAt(i + 2); + char i3 = b.charAt(i + 3); + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') + && (i3 >= '0' && i3 <= '7')) { + byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); + byte[] bValArr = new byte[1]; + bValArr[0] = bVal; + String tmp = new String(bValArr); + sb.append(tmp); + i += 3; + continue; + } + } + + if (currentChar == '\\' && (i + 2 < b.length())) { + char n = b.charAt(i + 1); + switch (n) { + case '0': + sb.append("\0"); + break; + case '\'': + sb.append("'"); + break; + case '"': + sb.append("\""); + break; + case 'b': + sb.append("\b"); + break; + case 'n': + sb.append("\n"); + break; + case 'r': + sb.append("\r"); + break; + case 't': + sb.append("\t"); + break; + case 'Z': + sb.append("\u001A"); + break; + case '\\': + sb.append("\\"); + break; + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } + + /** + * Get the list of FieldSchema out of the ASTNode. + */ + public static List getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { + List colList = new ArrayList(); + int numCh = ast.getChildCount(); + for (int i = 0; i < numCh; i++) { + FieldSchema col = new FieldSchema(); + ASTNode child = (ASTNode) ast.getChild(i); + Tree grandChild = child.getChild(0); + if(grandChild != null) { + String name = grandChild.getText(); + if(lowerCase) { + name = name.toLowerCase(); + } + // child 0 is the name of the column + col.setName(unescapeIdentifier(name)); + // child 1 is the type of the column + ASTNode typeChild = (ASTNode) (child.getChild(1)); + col.setType(getTypeStringFromAST(typeChild)); + + // child 2 is the optional comment of the column + if (child.getChildCount() == 3) { + col.setComment(unescapeSQLString(child.getChild(2).getText())); + } + } + colList.add(col); + } + return colList; + } + + protected static String getTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + switch (typeNode.getType()) { + case SparkSqlParser.TOK_LIST: + return serdeConstants.LIST_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; + case SparkSqlParser.TOK_MAP: + return serdeConstants.MAP_TYPE_NAME + "<" + + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," + + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; + case SparkSqlParser.TOK_STRUCT: + return getStructTypeStringFromAST(typeNode); + case SparkSqlParser.TOK_UNIONTYPE: + return getUnionTypeStringFromAST(typeNode); + default: + return getTypeName(typeNode); + } + } + + private static String getStructTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.STRUCT_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty struct not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + ASTNode child = (ASTNode) typeNode.getChild(i); + buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); + buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); + if (i < children - 1) { + buffer.append(","); + } + } + + buffer.append(">"); + return buffer.toString(); + } + + private static String getUnionTypeStringFromAST(ASTNode typeNode) + throws SemanticException { + String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; + typeNode = (ASTNode) typeNode.getChild(0); + int children = typeNode.getChildCount(); + if (children <= 0) { + throw new SemanticException("empty union not allowed."); + } + StringBuilder buffer = new StringBuilder(typeStr); + for (int i = 0; i < children; i++) { + buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); + if (i < children - 1) { + buffer.append(","); + } + } + buffer.append(">"); + typeStr = buffer.toString(); + return typeStr; + } + + public static String getAstNodeText(ASTNode tree) { + return tree.getChildCount() == 0?tree.getText() : + getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); + } + + public static String generateErrorMessage(ASTNode ast, String message) { + StringBuilder sb = new StringBuilder(); + if (ast == null) { + sb.append(message).append(". Cannot tell the position of null AST."); + return sb.toString(); + } + sb.append(ast.getLine()); + sb.append(":"); + sb.append(ast.getCharPositionInLine()); + sb.append(" "); + sb.append(message); + sb.append(". Error encountered near token '"); + sb.append(getAstNodeText(ast)); + sb.append("'"); + return sb.toString(); + } + + private static final Map TokenToTypeName = new HashMap(); + + static { + TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); + TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); + } + + public static String getTypeName(ASTNode node) throws SemanticException { + int token = node.getType(); + String typeName; + + // datetime type isn't currently supported + if (token == SparkSqlParser.TOK_DATETIME) { + throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); + } + + switch (token) { + case SparkSqlParser.TOK_CHAR: + CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); + typeName = charTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_VARCHAR: + VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); + typeName = varcharTypeInfo.getQualifiedName(); + break; + case SparkSqlParser.TOK_DECIMAL: + DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); + typeName = decTypeInfo.getQualifiedName(); + break; + default: + typeName = TokenToTypeName.get(token); + } + return typeName; + } + + public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { + boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); + if (testMode) { + URI uri = new Path(location).toUri(); + String scheme = uri.getScheme(); + String authority = uri.getAuthority(); + String path = uri.getPath(); + if (!path.startsWith("/")) { + path = (new Path(System.getProperty("test.tmp.dir"), + path)).toUri().getPath(); + } + if (StringUtils.isEmpty(scheme)) { + scheme = "pfile"; + } + try { + uri = new URI(scheme, authority, path, null, null); + } catch (URISyntaxException e) { + throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); + } + return uri.toString(); + } else { + //no-op for non-test mode for now + return location; + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0e89928cb636d..b1d841d1b5543 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -27,28 +27,28 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse._ +import org.apache.hadoop.hive.ql.parse.SemanticException import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, catalyst} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} +import org.apache.spark.sql.parser._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -227,7 +227,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren.asJava) + newChildren.foreach(n.addChild(_)) n } @@ -273,7 +273,8 @@ private[hive] object HiveQl extends Logging { private def createContext(): Context = new Context(hiveConf) private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) + ParseUtils.findRootNonNullToken( + (new ParseDriver).parse(sql, context)) /** * Returns the HiveConf @@ -312,7 +313,7 @@ private[hive] object HiveQl extends Logging { context.clear() plan } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => + case pe: ParseException => pe.getMessage match { case errorRegEx(line, start, message) => throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) @@ -337,7 +338,8 @@ private[hive] object HiveQl extends Logging { val tree = try { ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) + (new ParseDriver) + .parse(ddl, null /* no context required for parsing alone */)) } catch { case pe: org.apache.hadoop.hive.ql.parse.ParseException => throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) @@ -598,12 +600,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { tableType match { - case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { - nameParts.head match { + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => { + nameParts match { case Token(".", dbName :: tableName :: Nil) => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts.head) + val tableIdent = extractTableIdent(nameParts) DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => @@ -662,7 +664,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C NativePlaceholder } else { val schema = maybeColumns.map { cols => - BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + SemanticAnalyzer.getColumns(cols, true).asScala.map { field => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. HiveColumn(field.getName, null, field.getComment) @@ -678,7 +680,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeComment.foreach { case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + val comment = SemanticAnalyzer.unescapeSQLString(child.getText) if (comment ne null) { properties += ("comment" -> comment) } @@ -750,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C children.collect { case list @ Token("TOK_TABCOLLIST", _) => - val cols = BaseSemanticAnalyzer.getColumns(list, true) + val cols = SemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( schema = cols.asScala.map { field => @@ -758,11 +760,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }) } case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + val comment = SemanticAnalyzer.unescapeSQLString(child.getText) // TODO support the sql text tableDesc = tableDesc.copy(viewText = Option(comment)) case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = BaseSemanticAnalyzer.getColumns(list(0), false) + val cols = SemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( partitionColumns = cols.asScala.map { field => @@ -773,21 +775,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) + val fieldDelim = SemanticAnalyzer.unescapeSQLString (rowChild1.getText()) serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) if (rowChild2.length > 1) { - val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) + val fieldEscape = SemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val collItemDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val mapKeyDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val lineDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) if (!(lineDelim == "\n") && !(lineDelim == "10")) { throw new AnalysisException( SemanticAnalyzer.generateErrorMessage( @@ -796,22 +798,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + val nullFormat = SemanticAnalyzer.unescapeSQLString(rowChild.getText) // TODO support the nullFormat case _ => assert(false) } tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => - var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - location = EximUtil.relativeToAbsolutePath(hiveConf, location) + var location = SemanticAnalyzer.unescapeSQLString(child.getText) + location = SemanticAnalyzer.relativeToAbsolutePath(hiveConf, location) tableDesc = tableDesc.copy(location = Option(location)) case Token("TOK_TABLESERIALIZER", child :: Nil) => tableDesc = tableDesc.copy( - serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) + serde = Option(SemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) if (child.getChildCount == 2) { val serdeParams = new java.util.HashMap[String, String]() - BaseSemanticAnalyzer.readProps( + SemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) tableDesc = tableDesc.copy( serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) @@ -891,9 +893,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), + Option(SemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), outputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) + Option(SemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) case Token("TOK_STORAGEHANDLER", _) => throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) case _ => // Unsupport features @@ -909,24 +911,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) - if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + if Seq("TOK_CTE", "TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - // check if has CTE - insertClauses.last match { - case Token("TOK_CTE", cteClauses) => - val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - (relation.alias, relation) - }).toMap - (Some(args.head), insertClauses.init, Some(cteRelations)) - - case _ => (Some(args.head), insertClauses, None) + case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => + val cteRelations = ctes.map { node => + val relation = nodeToRelation(node, context).asInstanceOf[Subquery] + relation.alias -> relation } - - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + (Some(from.head), inserts, Some(cteRelations.toMap)) + case Token("TOK_FROM", from) :: inserts => + (Some(from.head), inserts, None) + case Token("TOK_INSERT", _) :: Nil => + (None, queryArgs, None) } // Return one query for each insert clause. @@ -1025,20 +1023,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) + (Nil, Some(SemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (BaseSemanticAnalyzer.unescapeSQLString(name), - BaseSemanticAnalyzer.unescapeSQLString(value)) + (SemanticAnalyzer.unescapeSQLString(name), + SemanticAnalyzer.unescapeSQLString(value)) } // SPARK-10310: Special cases LazySimpleSerDe // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val unescapedSerDeClass = SemanticAnalyzer.unescapeSQLString(serdeClass) val useDefaultRecordReaderWriter = unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) @@ -1055,7 +1053,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = matchSerDe(outputSerdeClause) - val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + val unescapedScript = SemanticAnalyzer.unescapeSQLString(script) // TODO Adds support for user-defined record reader/writer classes val recordReaderClass = if (useDefaultRecordReader) { @@ -1361,6 +1359,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTOUTERJOIN" => LeftOuter case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi + case "TOK_ANTIJOIN" => throw new NotImplementedError("Anti join not supported") } Join(nodeToRelation(relation1, context), nodeToRelation(relation2, context), @@ -1475,11 +1474,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val numericAstTypes = Seq( - HiveParser.Number, - HiveParser.TinyintLiteral, - HiveParser.SmallintLiteral, - HiveParser.BigintLiteral, - HiveParser.DecimalLiteral) + SparkSqlParser.Number, + SparkSqlParser.TinyintLiteral, + SparkSqlParser.SmallintLiteral, + SparkSqlParser.BigintLiteral, + SparkSqlParser.DecimalLiteral) /* Case insensitive matches */ val COUNT = "(?i)COUNT".r @@ -1649,7 +1648,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(TRUE(), Nil) => Literal.create(true, BooleanType) case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) + Literal(strings.map(s => SemanticAnalyzer.unescapeSQLString(s.getText)).mkString) // This code is adapted from // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 @@ -1684,37 +1683,37 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C v } - case ast: ASTNode if ast.getType == HiveParser.StringLiteral => - Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == SparkSqlParser.StringLiteral => + Literal(SemanticAnalyzer.unescapeSQLString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => - Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_CHARSETLITERAL => + Literal(SemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => Literal(CalendarInterval.fromYearMonthString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => Literal(CalendarInterval.fromDayTimeString(ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) case a: ASTNode => From 94f7a12b3c8e4a6ecd969893e562feb7ffba4c24 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 2 Jan 2016 00:04:48 -0800 Subject: [PATCH 1403/2089] [SPARK-10180][SQL] JDBC datasource are not processing EqualNullSafe filter This PR is followed by https://github.com/apache/spark/pull/8391. Previous PR fixes JDBCRDD to support null-safe equality comparison for JDBC datasource. This PR fixes the problem that it can actually return null as a result of the comparison resulting error as using the value of that comparison. Author: hyukjinkwon Author: HyukjinKwon Closes #8743 from HyukjinKwon/SPARK-10180. --- .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index c74574d280a3e..87d43addd36ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -193,6 +193,9 @@ private[sql] object JDBCRDD extends Logging { private def compileFilter(f: Filter): Option[String] = { Option(f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case EqualNullSafe(attr, value) => + s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " + + s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))" case LessThan(attr, value) => s"$attr < ${compileValue(value)}" case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" @@ -320,7 +323,6 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } - /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 633ae215fc123..dae72e8acb5a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -185,6 +185,7 @@ class JDBCSuite extends SparkFunSuite assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) @@ -473,7 +474,9 @@ class JDBCSuite extends SparkFunSuite === "(NOT (col1 IN ('mno', 'pqr')))") assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") - assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) === "") + assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) + === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) " + + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 = 'def')") } test("Dialect unregister") { From 15bd73627e04591fd13667b4838c9098342db965 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Jan 2016 13:15:53 +0000 Subject: [PATCH 1404/2089] [SPARK-12481][CORE][STREAMING][SQL] Remove usage of Hadoop deprecated APIs and reflection that supported 1.x Remove use of deprecated Hadoop APIs now that 2.2+ is required Author: Sean Owen Closes #10446 from srowen/SPARK-12481. --- .../scala/org/apache/spark/SparkContext.scala | 16 ++--- .../org/apache/spark/SparkHadoopWriter.scala | 20 +++--- .../apache/spark/deploy/SparkHadoopUtil.scala | 41 ++--------- .../deploy/history/FsHistoryProvider.scala | 16 ++--- .../input/FixedLengthBinaryInputFormat.scala | 3 +- .../input/FixedLengthBinaryRecordReader.scala | 7 +- .../spark/input/PortableDataStream.scala | 7 +- .../input/WholeTextFileInputFormat.scala | 2 +- .../input/WholeTextFileRecordReader.scala | 5 +- .../spark/mapred/SparkHadoopMapRedUtil.scala | 51 +------------- .../mapreduce/SparkHadoopMapReduceUtil.scala | 68 ------------------ .../org/apache/spark/rdd/BinaryFileRDD.scala | 4 +- .../org/apache/spark/rdd/HadoopRDD.scala | 3 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 12 ++-- .../apache/spark/rdd/PairRDDFunctions.scala | 26 +++---- .../spark/rdd/ReliableCheckpointRDD.scala | 3 +- .../apache/spark/rdd/WholeTextFileRDD.scala | 3 +- .../scheduler/EventLoggingListener.scala | 12 +--- .../spark/scheduler/InputFormatInfo.scala | 2 +- .../spark/util/ShutdownHookManager.scala | 19 +---- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../scala/org/apache/spark/FileSuite.scala | 9 ++- .../scheduler/EventLoggingListenerSuite.scala | 4 +- .../spark/scheduler/ReplayListenerSuite.scala | 3 +- .../spark/examples/CassandraCQLTest.scala | 4 +- .../apache/spark/examples/CassandraTest.scala | 4 +- project/MimaExcludes.scala | 5 ++ scalastyle-config.xml | 8 --- .../InsertIntoHadoopFsRelation.scala | 2 +- .../datasources/SqlNewHadoopRDD.scala | 18 +++-- .../datasources/WriterContainer.scala | 27 +++---- .../datasources/json/JSONRelation.scala | 12 ++-- .../parquet/CatalystReadSupport.scala | 2 - .../DirectParquetOutputCommitter.scala | 6 +- .../datasources/parquet/ParquetRelation.scala | 20 +++--- .../datasources/text/DefaultSource.scala | 13 ++-- .../apache/spark/sql/sources/interfaces.scala | 8 +-- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/client/ClientWrapper.scala | 70 ------------------- .../spark/sql/hive/hiveWriterContainers.scala | 13 ++-- .../spark/sql/hive/orc/OrcFileOperator.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 16 ++--- .../sql/sources/SimpleTextRelation.scala | 5 +- .../util/FileBasedWriteAheadLog.scala | 3 +- .../util/FileBasedWriteAheadLogWriter.scala | 10 +-- .../streaming/util/WriteAheadLogSuite.scala | 3 +- 46 files changed, 150 insertions(+), 441 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index bbdc9158d8e2b..77e44ee0264af 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -874,11 +874,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = withScope { assertNotStopped() - val job = new NewHadoopJob(hadoopConfiguration) + val job = NewHadoopJob.getInstance(hadoopConfiguration) // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updateConf = job.getConfiguration new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -923,11 +923,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { assertNotStopped() - val job = new NewHadoopJob(hadoopConfiguration) + val job = NewHadoopJob.getInstance(hadoopConfiguration) // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updateConf = job.getConfiguration new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1100,13 +1100,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() - // The call to new NewHadoopJob automatically adds security credentials to conf, + // The call to NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves - val job = new NewHadoopJob(conf) + val job = NewHadoopJob.getInstance(conf) // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updatedConf = job.getConfiguration new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1369,7 +1369,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (!fs.exists(hadoopPath)) { throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") } - val isDir = fs.getFileStatus(hadoopPath).isDir + val isDir = fs.getFileStatus(hadoopPath).isDirectory if (!isLocal && scheme == "file" && isDir) { throw new SparkException(s"addFile does not support local directories when not running " + "local mode.") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ac6eaab20d8d2..dd400b8ae8a16 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -25,6 +25,7 @@ import java.util.Date import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -37,10 +38,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ private[spark] -class SparkHadoopWriter(jobConf: JobConf) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { +class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private val now = new Date() private val conf = new SerializableJobConf(jobConf) @@ -131,7 +129,7 @@ class SparkHadoopWriter(jobConf: JobConf) private def getJobContext(): JobContext = { if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) + jobContext = new JobContextImpl(conf.value, jID.value) } jobContext } @@ -143,6 +141,12 @@ class SparkHadoopWriter(jobConf: JobConf) taskContext } + protected def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + new TaskAttemptContextImpl(conf, attemptId) + } + private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { jobID = jobid splitID = splitid @@ -150,7 +154,7 @@ class SparkHadoopWriter(jobConf: JobConf) jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } } @@ -168,9 +172,9 @@ object SparkHadoopWriter { } val outputPath = new Path(path) val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { + if (fs == null) { throw new IllegalArgumentException("Incorrectly formatted output path") } - outputPath.makeQualified(fs) + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 59e90564b3516..4bd94f13e57e6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -33,9 +33,6 @@ import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.JobContext -import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} -import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.annotation.DeveloperApi @@ -76,9 +73,6 @@ class SparkHadoopUtil extends Logging { } } - @deprecated("use newConfiguration with SparkConf argument", "1.2.0") - def newConfiguration(): Configuration = newConfiguration(null) - /** * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop * subsystems. @@ -190,33 +184,6 @@ class SparkHadoopUtil extends Logging { statisticsDataClass.getDeclaredMethod(methodName) } - /** - * Using reflection to get the Configuration from JobContext/TaskAttemptContext. If we directly - * call `JobContext/TaskAttemptContext.getConfiguration`, it will generate different byte codes - * for Hadoop 1.+ and Hadoop 2.+ because JobContext/TaskAttemptContext is class in Hadoop 1.+ - * while it's interface in Hadoop 2.+. - */ - def getConfigurationFromJobContext(context: JobContext): Configuration = { - // scalastyle:off jobconfig - val method = context.getClass.getMethod("getConfiguration") - // scalastyle:on jobconfig - method.invoke(context).asInstanceOf[Configuration] - } - - /** - * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly - * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes - * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ - * while it's interface in Hadoop 2.+. - */ - def getTaskAttemptIDFromTaskAttemptContext( - context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { - // scalastyle:off jobconfig - val method = context.getClass.getMethod("getTaskAttemptID") - // scalastyle:on jobconfig - method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] - } - /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the * given path points to a file, return a single-element collection containing [[FileStatus]] of @@ -233,11 +200,11 @@ class SparkHadoopUtil extends Logging { */ def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { - val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir) + val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDirectory) leaves ++ directories.flatMap(f => listLeafStatuses(fs, f)) } - if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus) + if (baseStatus.isDirectory) recurse(baseStatus) else Seq(baseStatus) } def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { @@ -246,12 +213,12 @@ class SparkHadoopUtil extends Logging { def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { - val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir) + val (directories, files) = fs.listStatus(status.getPath).partition(_.isDirectory) val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus] leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir)) } - assert(baseStatus.isDir) + assert(baseStatus.isDirectory) recurse(baseStatus) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6e91d73b6e0fd..c93bc8c127f58 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -167,7 +168,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } throw new IllegalArgumentException(msg) } - if (!fs.getFileStatus(path).isDir) { + if (!fs.getFileStatus(path).isDirectory) { throw new IllegalArgumentException( "Logging directory specified is not a directory: %s".format(logDir)) } @@ -304,7 +305,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logError("Exception encountered when attempting to update last scan time", e) lastScanTime } finally { - if (!fs.delete(path)) { + if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") } } @@ -603,7 +604,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * As of Spark 1.3, these files are consolidated into a single one that replaces the directory. * See SPARK-2261 for more detail. */ - private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() + private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDirectory /** * Returns the modification time of the given event log. If the status points at an empty @@ -648,8 +649,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Checks whether HDFS is in safe mode. The API is slightly different between hadoop 1 and 2, - * so we have to resort to ugly reflection (as usual...). + * Checks whether HDFS is in safe mode. * * Note that DistributedFileSystem is a `@LimitedPrivate` class, which for all practical reasons * makes it more public than not. @@ -663,11 +663,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - val hadoop2Class = "org.apache.hadoop.hdfs.protocol.HdfsConstants$SafeModeAction" - val actionClass: Class[_] = getClass().getClassLoader().loadClass(hadoop2Class) - val action = actionClass.getField("SAFEMODE_GET").get(null) - val method = dfs.getClass().getMethod("setSafeMode", action.getClass()) - method.invoke(dfs, action).asInstanceOf[Boolean] + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) } } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index 532850dd57716..30431a9b986bb 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -23,7 +23,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil /** * Custom Input Format for reading and splitting flat binary files that contain records, @@ -36,7 +35,7 @@ private[spark] object FixedLengthBinaryInputFormat { /** Retrieves the record length property from a Hadoop configuration */ def getRecordLength(context: JobContext): Int = { - SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt } } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 67a96925da019..25596a15d93c0 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -24,7 +24,6 @@ import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileSplit -import org.apache.spark.deploy.SparkHadoopUtil /** * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. @@ -83,16 +82,16 @@ private[spark] class FixedLengthBinaryRecordReader // the actual file we will be reading from val file = fileSplit.getPath // job configuration - val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration // check compression - val codec = new CompressionCodecFactory(job).getCodec(file) + val codec = new CompressionCodecFactory(conf).getCodec(file) if (codec != null) { throw new IOException("FixedLengthRecordReader does not support reading compressed files") } // get the record length recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) // get the filesystem - val fs = file.getFileSystem(job) + val fs = file.getFileSystem(conf) // open the File fileInputStream = fs.open(file) // seek to the splitStart position diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 280e7a5fe893c..cb76e3c344fca 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,8 +27,6 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.deploy.SparkHadoopUtil - /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -44,7 +42,7 @@ private[spark] abstract class StreamFileInputFormat[T] */ def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } @@ -135,8 +133,7 @@ class PortableDataStream( private val confBytes = { val baos = new ByteArrayOutputStream() - SparkHadoopUtil.get.getConfigurationFromJobContext(context). - write(new DataOutputStream(baos)) + context.getConfiguration.write(new DataOutputStream(baos)) baos.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 413408723b54d..fa34f1e886c72 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,7 +53,7 @@ private[spark] class WholeTextFileInputFormat */ def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index b56b2aa88a414..998c898a3fc25 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -26,8 +26,6 @@ import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.deploy.SparkHadoopUtil - /** * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface. @@ -52,8 +50,7 @@ private[spark] class WholeTextFileRecordReader( extends RecordReader[Text, Text] with Configurable { private[this] val path = split.getPath(index) - private[this] val fs = path.getFileSystem( - SparkHadoopUtil.get.getConfigurationFromJobContext(context)) + private[this] val fs = path.getFileSystem(context.getConfiguration) // True means the current file has been processed, then skip it. private[this] var processed = false diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f7298e8d5c62c..249bdf5994f8f 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -18,61 +18,12 @@ package org.apache.spark.mapred import java.io.IOException -import java.lang.reflect.Modifier -import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} -import org.apache.spark.util.{Utils => SparkUtils} - -private[spark] -trait SparkHadoopMapRedUtil { - def newJobContext(conf: JobConf, jobId: JobID): JobContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", - "org.apache.hadoop.mapred.JobContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], - classOf[org.apache.hadoop.mapreduce.JobID]) - // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. - // Make it accessible if it's not in order to access it. - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", - "org.apache.hadoop.mapred.TaskAttemptContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) - // See above - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) - } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - SparkUtils.classForName(first) - } catch { - case e: ClassNotFoundException => - SparkUtils.classForName(second) - } - } -} object SparkHadoopMapRedUtil extends Logging { /** @@ -93,7 +44,7 @@ object SparkHadoopMapRedUtil extends Logging { jobId: Int, splitId: Int): Unit = { - val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID // Called after we have decided to commit def performCommit(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala deleted file mode 100644 index 82d807fad8938..0000000000000 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mapreduce - -import java.lang.{Boolean => JBoolean, Integer => JInteger} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -import org.apache.spark.util.Utils - -private[spark] -trait SparkHadoopMapReduceUtil { - def newJobContext(conf: Configuration, jobId: JobID): JobContext = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.task.JobContextImpl") - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID]) - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl") - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID]) - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") - try { - // First, attempt to use the old-style constructor that takes a boolean isMap - // (not available in YARN) - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } catch { - case exc: NoSuchMethodException => { - // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") - .asInstanceOf[Class[Enum[_]]] - val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if (isMap) "MAP" else "REDUCE") - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index aedced7408cde..2bf2337d49fef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{ Configurable, Configuration } import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.JobContextImpl + import org.apache.spark.input.StreamFileInputFormat import org.apache.spark.{ Partition, SparkContext } @@ -40,7 +42,7 @@ private[spark] class BinaryFileRDD[T]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f37c95bedc0a5..920d3bf219ff5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID import org.apache.hadoop.mapred.lib.CombineFileSplit +import org.apache.hadoop.mapreduce.TaskType import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -357,7 +358,7 @@ private[spark] object HadoopRDD extends Logging { def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int, conf: JobConf) { val jobID = new JobID(jobTrackerId, jobId) - val taId = new TaskAttemptID(new TaskID(jobID, true, splitId), attemptId) + val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) conf.set("mapred.tip.id", taId.getTaskID.toString) conf.set("mapred.task.id", taId.toString) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 86f38ae836b2b..8b330a34c3d3a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -26,11 +26,11 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark.annotation.DeveloperApi import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.deploy.SparkHadoopUtil @@ -66,9 +66,7 @@ class NewHadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], @transient private val _conf: Configuration) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[(K, V)](sc, Nil) with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) @@ -109,7 +107,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(_conf, jobId) + val jobContext = new JobContextImpl(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -144,8 +142,8 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 44d195587a081..b87230142532b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskType, TaskAttemptID} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -53,10 +52,7 @@ import org.apache.spark.util.random.StratifiedSamplingUtils */ class PairRDDFunctions[K, V](self: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable -{ + extends Logging with Serializable { /** * :: Experimental :: @@ -985,11 +981,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration jobConfiguration.set("mapred.output.dir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1074,11 +1070,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1091,9 +1087,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, context.attemptNumber) - val hadoopContext = newTaskAttemptContext(config, attemptId) + val hadoopContext = new TaskAttemptContextImpl(config, attemptId) val format = outfmt.newInstance format match { case c: Configurable => c.setConf(config) @@ -1125,8 +1121,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) 1 } : Int - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) + val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) // When speculation is on and output committer class name contains "Direct", we should warn diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index fa71b8c26233d..a9b3d52bbee02 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -174,7 +174,8 @@ private[spark] object ReliableCheckpointRDD extends Logging { fs.create(tempOutputPath, false, bufferSize) } else { // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + fs.create(tempOutputPath, false, bufferSize, + fs.getDefaultReplication(fs.getWorkingDirectory), blockSize) } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index e3f14fe7ef0f8..8e1baae796fc5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.{Text, Writable} import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.WholeTextFileInputFormat @@ -44,7 +45,7 @@ private[spark] class WholeTextFileRDD( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index eaa07acc5132e..68792c58c9b4e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -77,14 +77,6 @@ private[spark] class EventLoggingListener( // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None - // The Hadoop APIs have changed over time, so we use reflection to figure out - // the correct method to use to flush a hadoop data stream. See SPARK-1518 - // for details. - private val hadoopFlushMethod = { - val cls = classOf[FSDataOutputStream] - scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync")) - } - private var writer: Option[PrintWriter] = None // For testing. Keep track of all JSON serialized events that have been logged. @@ -97,7 +89,7 @@ private[spark] class EventLoggingListener( * Creates the log file in the configured log directory. */ def start() { - if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) { + if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) { throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.") } @@ -147,7 +139,7 @@ private[spark] class EventLoggingListener( // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) - hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) + hadoopDataStream.foreach(_.hflush()) } if (testing) { loggedEvents += eventJson diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0e438ab4366d9..8235b10245376 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ org.apache.hadoop.mapreduce.InputFormat[_, _]] - val job = new Job(conf) + val job = Job.getInstance(conf) val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 0065b1fc660b0..acc24ca0fb814 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.util import java.io.File import java.util.PriorityQueue -import scala.util.{Failure, Success, Try} +import scala.util.Try import org.apache.hadoop.fs.FileSystem import org.apache.spark.Logging @@ -177,21 +177,8 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem] - .getField("SHUTDOWN_HOOK_PRIORITY") - .get(null) // static field, the value is not used - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - // scalastyle:off runtimeaddshutdownhook - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - // scalastyle:on runtimeaddshutdownhook - } + org.apache.hadoop.util.ShutdownHookManager.get().addShutdownHook( + hookTask, FileSystem.SHUTDOWN_HOOK_PRIORITY + 30) } def runAll(): Unit = { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 11f1248c24d38..d91948e44694b 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1246,7 +1246,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.newAPIHadoopFile(outputDir, org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, new Job().getConfiguration()); + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index f6a7f4375fac8..2e47801aafd75 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark import java.io.{File, FileWriter} -import org.apache.spark.deploy.SparkHadoopUtil +import scala.io.Source + import org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel -import scala.io.Source - import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} @@ -506,11 +505,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - val job = new Job(sc.hadoopConfiguration) + val job = Job.getInstance(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfig = job.getConfiguration jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 5cb2d4225d281..43da6fc5b5474 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -67,11 +67,11 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) assert(fileSystem.exists(logPath)) val logStatus = fileSystem.getFileStatus(logPath) - assert(!logStatus.isDir) + assert(!logStatus.isDirectory) // Verify log is renamed after stop() eventLogger.stop() - assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDir) + assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDirectory) } test("Basic event logging") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 103fc19369c97..761e82e6cf1ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -23,7 +23,6 @@ import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec @@ -115,7 +114,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applications = fileSystem.listStatus(logDirPath) assert(applications != null && applications.size > 0) val eventLog = applications.sortBy(_.getModificationTime).last - assert(!eventLog.isDir) + assert(!eventLog.isDirectory) // Replay events val logData = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index d1b9b8d398dd8..5a80985a49458 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,7 +16,6 @@ */ // scalastyle:off println - // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -80,7 +79,7 @@ object CassandraCQLTest { val InputColumnFamily = "ordercf" val OutputColumnFamily = "salecount" - val job = new Job() + val job = Job.getInstance() job.setInputFormatClass(classOf[CqlPagingInputFormat]) val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) @@ -137,4 +136,3 @@ object CassandraCQLTest { } } // scalastyle:on println -// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 1e679bfb55343..ad39a012b4ae6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,7 +16,6 @@ */ // scalastyle:off println -// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -59,7 +58,7 @@ object CassandraTest { val sc = new SparkContext(sparkConf) // Build the job configuration with ConfigHelper provided by Cassandra - val job = new Job() + val job = Job.getInstance() job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) val host: String = args(1) @@ -131,7 +130,6 @@ object CassandraTest { } } // scalastyle:on println -// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 59886ab76244a..612ddf86ded3e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -49,6 +49,11 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ + Seq( + // SPARK-12481 + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.mapred.SparkHadoopMapRedUtil") + ) ++ // When 1.6 is officially released, update this exclusion list. Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 16d18b3328ff7..ee855ca0e09cb 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -187,14 +187,6 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods - - - ^getConfiguration$|^getTaskAttemptID$ - Instead of calling .getConfiguration() or .getTaskAttemptID() directly, - use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. - - - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 735d52f808868..758bcd706a8c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -93,7 +93,7 @@ private[sql] case class InsertIntoHadoopFsRelation( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - val job = new Job(hadoopConf) + val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index eea780cbaa7e3..12f8783f846d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -26,10 +26,10 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel @@ -68,16 +68,14 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sqlContext.sparkContext, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[V](sqlContext.sparkContext, Nil) with Logging { protected def getJob(): Job = { - val conf: Configuration = broadcastedConf.value.value + val conf = broadcastedConf.value.value // "new Job" will make a copy of the conf. Then, it is // safe to mutate conf properties with initLocalJobFuncOpt // and initDriverSideJobFuncOpt. - val newJob = new Job(conf) + val newJob = Job.getInstance(conf) initLocalJobFuncOpt.map(f => f(newJob)) newJob } @@ -87,7 +85,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - SparkHadoopUtil.get.getConfigurationFromJobContext(job) + job.getConfiguration } private val jobTrackerId: String = { @@ -110,7 +108,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[SparkPartition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -154,8 +152,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private[this] var reader: RecordReader[Void, V] = null /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 983f4df1de369..8b0b647744559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -24,10 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow @@ -41,14 +41,12 @@ private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, @transient private val job: Job, isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { + extends Logging with Serializable { protected val dataSchema = relation.dataSchema protected val serializableConf = - new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + new SerializableConfiguration(job.getConfiguration) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -90,8 +88,7 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - SparkHadoopUtil.get.getConfigurationFromJobContext(job). - set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -101,7 +98,7 @@ private[sql] abstract class BaseWriterContainer( // committer, since their initialization involve the job configuration, which can be potentially // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) @@ -111,7 +108,7 @@ private[sql] abstract class BaseWriterContainer( def executorSideSetup(taskContext: TaskContext): Unit = { setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupTask(taskAttemptContext) } @@ -166,7 +163,7 @@ private[sql] abstract class BaseWriterContainer( "because spark.speculation is configured to be true.") defaultOutputCommitter } else { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val committerClass = configuration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) @@ -201,10 +198,8 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - // scalastyle:off jobcontext + this.taskId = new TaskID(this.jobId, TaskType.MAP, splitId) this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - // scalastyle:on jobcontext } private def setupConf(): Unit = { @@ -250,7 +245,7 @@ private[sql] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + val configuration = taskAttemptContext.getConfiguration configuration.set("spark.sql.sources.output.path", outputPath) val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) @@ -421,7 +416,7 @@ private[sql] class DynamicPartitionWriterContainer( def newOutputWriter(key: InternalRow): OutputWriter = { val partitionPath = getPartitionString(key).getString(0) val path = new Path(getWorkPath, partitionPath) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + val configuration = taskAttemptContext.getConfiguration configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) val newWriter = super.newOutputWriter(path.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 3e61ba35bea8e..54a8552134c82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -30,8 +30,6 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection @@ -89,8 +87,8 @@ private[sql] class JSONRelation( override val needConversion: Boolean = false private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath) @@ -176,7 +174,7 @@ private[json] class JsonOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with Logging { + extends OutputWriter with Logging { private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records @@ -186,9 +184,9 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index a958373eb769d..e5d8e6088b395 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -58,9 +58,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with */ override def init(context: InitContext): ReadContext = { catalystRequestedSchema = { - // scalastyle:off jobcontext val conf = context.getConfiguration - // scalastyle:on jobcontext val schemaString = conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) assert(schemaString != null, "Parquet requested schema not set.") StructType.fromString(schemaString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1a4e99ff10afb..e54f51e3830f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -54,11 +54,7 @@ private[datasources] class DirectParquetOutputCommitter( override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(jobContext) - // scalastyle:on jobcontext - } + val configuration = ContextUtil.getConfiguration(jobContext) val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 1af2a394f399a..af964b4d35611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName @@ -40,7 +41,6 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -82,9 +82,9 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } @@ -217,11 +217,7 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(job) - // scalastyle:on jobcontext - } + val conf = ContextUtil.getConfiguration(job) // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -340,7 +336,7 @@ private[sql] class ParquetRelation( // URI of the path to create a new Path. val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime, f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) }.toSeq @@ -359,7 +355,7 @@ private[sql] class ParquetRelation( } } - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val jobContext = new JobContextImpl(getConf(isDriverSide = true), jobId) val rawSplits = inputFormat.getSplits(jobContext) Array.tabulate[SparkPartition](rawSplits.size) { i => @@ -564,7 +560,7 @@ private[sql] object ParquetRelation extends Logging { parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val conf = job.getConfiguration conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. @@ -607,7 +603,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + overrideMinSplitSize(parquetBlockSize, job.getConfiguration) } private[parquet] def readSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 41fcb11d84bff..248467abe9f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -26,8 +26,6 @@ import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -88,8 +86,8 @@ private[sql] class TextRelation( filters: Array[Filter], inputPaths: Array[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath).sortBy(_.toUri) if (paths.nonEmpty) { @@ -138,17 +136,16 @@ private[sql] class TextRelation( } class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter - with SparkHadoopMapRedUtil { + extends OutputWriter { private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index fc8ce6901dfca..d6c5d1435702d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -462,7 +462,7 @@ abstract class HadoopFsRelation private[sql]( name.toLowerCase == "_temporary" || name.startsWith(".") } - val (dirs, files) = statuses.partition(_.isDir) + val (dirs, files) = statuses.partition(_.isDirectory) // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) if (dirs.isEmpty) { @@ -858,10 +858,10 @@ private[sql] object HadoopFsRelation extends Logging { val jobConf = new JobConf(fs.getConf, this.getClass()) val pathFilter = FileInputFormat.getInputPathFilter(jobConf) if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory) files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } @@ -896,7 +896,7 @@ private[sql] object HadoopFsRelation extends Logging { FakeFileStatus( status.getPath.toString, status.getLen, - status.isDir, + status.isDirectory, status.getReplication, status.getBlockSize, status.getModificationTime, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 384ea211df843..5d00e7367026f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -380,7 +380,7 @@ class HiveContext private[hive]( def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDir) { + val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath().getName().startsWith(stagingDir)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 598ccdeee4ad2..d3da22aa0ae5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -31,9 +31,7 @@ import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Driver, metadata} -import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkException, Logging} import org.apache.spark.sql.catalyst.expressions.Expression @@ -65,74 +63,6 @@ private[hive] class ClientWrapper( extends ClientInterface with Logging { - overrideHadoopShims() - - // !! HACK ALERT !! - // - // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking - // major version number gathered from Hadoop jar files: - // - // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with - // security. - // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. - // - // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It - // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but - // `Hadoop23Shims` is chosen because the major version number here is 2. - // - // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` - // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is - // not available, we decide whether to override the shims or not by checking for existence of a - // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. - private def overrideHadoopShims(): Unit = { - val hadoopVersion = VersionInfo.getVersion - val VersionPattern = """(\d+)\.(\d+).*""".r - - hadoopVersion match { - case null => - logError("Failed to inspect Hadoop version") - - // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. - val probeMethod = "getPathWithoutSchemeAndAuthority" - if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { - logInfo( - s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + - s"we are probably using Hadoop 1.x or 2.0.x") - loadHadoop20SShims() - } - - case VersionPattern(majorVersion, minorVersion) => - logInfo(s"Inspected Hadoop version: $hadoopVersion") - - // Loads Hadoop20SShims for 1.x and 2.0.x versions - val (major, minor) = (majorVersion.toInt, minorVersion.toInt) - if (major < 2 || (major == 2 && minor == 0)) { - loadHadoop20SShims() - } - } - - // Logs the actual loaded Hadoop shims class - val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName - logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") - } - - private def loadHadoop20SShims(): Unit = { - val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" - logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") - - try { - val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") - // scalastyle:off classforname - val shimsClass = Class.forName(hadoop20SShimsClassName) - // scalastyle:on classforname - val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) - shimsField.setAccessible(true) - shimsField.set(null, shims) - } catch { case cause: Throwable => - throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) - } - } - // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. private val outputBuffer = new CircularBuffer() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 93c016b6c6c7a..777e7857d2db2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -27,9 +27,10 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ -import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} @@ -46,9 +47,7 @@ import org.apache.spark.util.SerializableJobConf private[hive] class SparkHiveWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { + extends Logging with Serializable { private val now = new Date() private val tableDesc: TableDesc = fileSinkConf.getTableInfo @@ -68,8 +67,8 @@ private[hive] class SparkHiveWriterContainer( @transient private var writer: FileSinkOperator.RecordWriter = null @transient protected lazy val committer = conf.value.getOutputCommitter - @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) - @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient protected lazy val jobContext = new JobContextImpl(conf.value, jID.value) + @transient private lazy val taskContext = new TaskAttemptContextImpl(conf.value, taID.value) @transient private lazy val outputFormat = conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] @@ -131,7 +130,7 @@ private[hive] class SparkHiveWriterContainer( jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } private def setConfParams() { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 0f9a1a6ef3b27..b91a14bdbcc48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -95,7 +95,7 @@ private[orc] object OrcFileOperator extends Logging { val fs = origPath.getFileSystem(conf) val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) - .filterNot(_.isDir) + .filterNot(_.isDirectory) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1136670b7a0eb..84ef12a68e1ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -33,8 +33,6 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -67,7 +65,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with HiveInspectors { private val serializer = { val table = new Properties() @@ -77,7 +75,7 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration serde.initialize(configuration, table) serde } @@ -99,9 +97,9 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" @@ -208,7 +206,7 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { + job.getConfiguration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -289,8 +287,8 @@ private[orc] case class OrcTableScan( } def execute(): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 01960fd2901b0..e10d21d5e3682 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -25,7 +25,6 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} @@ -53,9 +52,9 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 9418beec0d74a..15ad2e27d372f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -224,7 +224,8 @@ private[streaming] class FileBasedWriteAheadLog( val logDirectoryPath = new Path(logDirectory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) pastLogs.clear() pastLogs ++= logFileInfo diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index 1185f30265f63..1f5c1d4369b53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -19,10 +19,7 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import scala.util.Try - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FSDataOutputStream import org.apache.spark.util.Utils @@ -34,11 +31,6 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) - private lazy val hadoopFlushMethod = { - // Use reflection to get the right flush operation - val cls = classOf[FSDataOutputStream] - Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption - } private var nextOffset = stream.getPos() private var closed = false @@ -62,7 +54,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: } private def flush() { - hadoopFlushMethod.foreach { _.invoke(stream) } + stream.hflush() // Useful for local file system where hflush/sync does not work (HADOOP-7844) stream.getWrappedStream.flush() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index beaae34535fd6..a670c7d638192 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -705,7 +705,8 @@ object WriteAheadLogSuite { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { fileSystem.listStatus(logDirectoryPath).map { _.getPath() }.sortBy { _.getName().split("-")(1).toLong }.map { From 513e3b092c4f3d58058ff64c861ea35cfec04205 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 2 Jan 2016 22:31:39 -0800 Subject: [PATCH 1405/2089] [SPARK-12599][MLLIB][SQL] Remove the use of callUDF in MLlib callUDF has been deprecated. However, we do not have an alternative for users to specify the output data type without type tags. This pull request introduced a new API for that, and replaces the invocation of the deprecated callUDF with that. Author: Reynold Xin Closes #10547 from rxin/SPARK-12599. --- .../scala/org/apache/spark/ml/Transformer.scala | 4 ++-- .../scala/org/apache/spark/sql/functions.scala | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 3c7bcf7590e6d..1f3325ad09ef1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -115,8 +115,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - dataset.withColumn($(outputCol), - callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) + val transformUDF = udf(this.createTransformFunc, outputDataType) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } override def copy(extra: ParamMap): T = defaultCopy(extra) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c5aed6da9c4..3572f3c3a1f2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2843,6 +2843,20 @@ object functions extends LegacyFunctions { // scalastyle:on parameter.number // scalastyle:on line.size.limit + /** + * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must + * specifcy the output data type, and there is no automatic input type coercion. + * + * @param f A closure in Scala + * @param dataType The output data type of the UDF + * + * @group udf_funcs + * @since 2.0.0 + */ + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + UserDefinedFunction(f, dataType, None) + } + /** * Call an user-defined function. * Example: From 6c5bbd628aaedb6efb44c15f816fea8fb600decc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 2 Jan 2016 22:39:25 -0800 Subject: [PATCH 1406/2089] Revert "Revert "[SPARK-12286][SPARK-12290][SPARK-12294][SPARK-12284][SQL] always output UnsafeRow"" This reverts commit 44ee920fd49d35b421ae562ea99bcc8f2b98ced6. --- .../org/apache/spark/sql/SQLContext.scala | 3 +- .../apache/spark/sql/execution/Exchange.scala | 25 +-- .../spark/sql/execution/ExistingRDD.scala | 15 +- .../apache/spark/sql/execution/Expand.scala | 13 +- .../apache/spark/sql/execution/Generate.scala | 8 +- .../spark/sql/execution/LocalTableScan.scala | 13 +- .../org/apache/spark/sql/execution/Sort.scala | 4 - .../spark/sql/execution/SparkPlan.scala | 23 --- .../apache/spark/sql/execution/Window.scala | 8 +- .../aggregate/SortBasedAggregate.scala | 4 - .../SortBasedAggregationIterator.scala | 8 +- .../aggregate/TungstenAggregate.scala | 4 - .../spark/sql/execution/basicOperators.scala | 58 +------ .../columnar/InMemoryColumnarTableScan.scala | 8 +- .../joins/BroadcastNestedLoopJoin.scala | 7 - .../execution/joins/CartesianProduct.scala | 4 - .../spark/sql/execution/joins/HashJoin.scala | 4 - .../sql/execution/joins/HashOuterJoin.scala | 4 - .../sql/execution/joins/HashSemiJoin.scala | 4 - .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 - .../sql/execution/joins/SortMergeJoin.scala | 4 - .../execution/joins/SortMergeOuterJoin.scala | 4 - .../spark/sql/execution/local/LocalNode.scala | 12 -- .../execution/local/NestedLoopJoinNode.scala | 6 +- .../apache/spark/sql/execution/python.scala | 7 +- .../sql/execution/rowFormatConverters.scala | 108 ------------ .../spark/sql/execution/ExchangeSuite.scala | 2 +- .../spark/sql/execution/ExpandSuite.scala | 54 ------ .../execution/RowFormatConvertersSuite.scala | 164 ------------------ .../spark/sql/execution/SortSuite.scala | 2 +- .../sql/hive/execution/HiveTableScan.scala | 16 +- .../hive/execution/InsertIntoHiveTable.scala | 15 +- .../hive/execution/ScriptTransformation.scala | 3 +- .../sql/sources/hadoopFsRelationSuites.scala | 31 ---- 34 files changed, 74 insertions(+), 574 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index eadf5cba6d9bb..022303239f2af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -904,8 +904,7 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Add row converters", Once, EnsureRowFormats) + Batch("Add exchange", Once, EnsureRequirements(self)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 62cbc518e02af..7b4161930b7d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -50,26 +49,14 @@ case class Exchange( case None => "" } - val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" + val simpleNodeName = "Exchange" s"$simpleNodeName$extraInfo" } - /** - * Returns true iff we can support the data type, and we are not doing range partitioning. - */ - private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] - override def outputPartitioning: Partitioning = newPartitioning override def output: Seq[Attribute] = child.output - // This setting is somewhat counterintuitive: - // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, - // so the planner inserts a converter to convert data into UnsafeRow if needed. - override def outputsUnsafeRows: Boolean = tungstenMode - override def canProcessSafeRows: Boolean = !tungstenMode - override def canProcessUnsafeRows: Boolean = tungstenMode - /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -130,15 +117,7 @@ case class Exchange( } } - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - - private val serializer: Serializer = { - if (tungstenMode) { - new UnsafeRowSerializer(child.output.size) - } else { - new SparkSqlSerializer(sparkConf) - } - } + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) override protected def doPrepare(): Unit = { // If an ExchangeCoordinator is needed, we register this Exchange operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 5c01af011d306..fc508bfafa1c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, AttributeSet, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -99,10 +99,19 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], override val nodeName: String, override val metadata: Map[String, String] = Map.empty, - override val outputsUnsafeRows: Boolean = false) + isUnsafeRow: Boolean = false) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = rdd + protected override def doExecute(): RDD[InternalRow] = { + if (isUnsafeRow) { + rdd + } else { + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } + } + } override def simpleString: String = { val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 91530bd63798a..c3683cc4e7aac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,20 +41,11 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - private[this] val projection = { - if (outputsUnsafeRows) { - (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) - } else { - (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() - } - } + private[this] val projection = + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 0c613e91b979f..4db88a09d8152 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -64,6 +64,7 @@ case class Generate( child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow + val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -77,13 +78,14 @@ case class Generate( } ++ LazyIterator(() => boundGenerator.terminate()).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - joinedRow.withRight(row) + proj(joinedRow.withRight(row)) } } } else { child.execute().mapPartitionsInternal { iter => - iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate()) + val proj = UnsafeProjection.create(output, output) + (iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate())).map(proj) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index ba7f6287ac6c3..59057bf9666ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} /** @@ -29,15 +29,20 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + private val unsafeRows: Array[InternalRow] = { + val proj = UnsafeProjection.create(output, output) + rows.map(r => proj(r).copy()).toArray + } + + private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[InternalRow] = { - rows.toArray + unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - rows.take(limit).toArray + unsafeRows.take(limit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 24207cb46fd29..73dc8cb984471 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -39,10 +39,6 @@ case class Sort( testSpillFrequency: Int = 0) extends UnaryNode { - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fe9b2ad4a0bc3..f20f32aaced2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,17 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute @@ -115,18 +104,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override doExecute instead. */ final def execute(): RDD[InternalRow] = { - if (children.nonEmpty) { - val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) - val hasSafeInputs = children.exists(!_.outputsUnsafeRows) - assert(!(hasSafeInputs && hasUnsafeInputs), - "Child operators should output rows in the same format") - assert(canProcessSafeRows || canProcessUnsafeRows, - "Operator must be able to process at least one row format") - assert(!hasSafeInputs || canProcessSafeRows, - "Operator will receive safe rows as input but cannot process safe rows") - assert(!hasUnsafeInputs || canProcessUnsafeRows, - "Operator will receive unsafe rows as input but cannot process unsafe rows") - } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() doExecute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index c941d673c7248..b79d93d7ca4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -100,8 +100,6 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def canProcessUnsafeRows: Boolean = true - /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -259,16 +257,16 @@ case class Window( * @return the final resulting projection. */ private[this] def createResultProjection( - expressions: Seq[Expression]): MutableProjection = { + expressions: Seq[Expression]): UnsafeProjection = { val references = expressions.zipWithIndex.map{ case (e, i) => // Results of window expressions will be on the right side of child's output BoundReference(child.output.size + i, e.dataType, e.nullable) } val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - newMutableProjection( + UnsafeProjection.create( projectList ++ patchedWindowExpression, - child.output)() + child.output) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index c4587ba677b2f..01d076678f041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -49,10 +49,6 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def requiredChildDistribution: List[Distribution] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index ac920aa8bc7f7..6501634ff998b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,6 +87,10 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be + // compared to MutableRow (aggregation buffer) directly. + private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) + protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -110,7 +114,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -122,7 +126,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, currentRow) + processRow(sortBasedAggregationBuffer, safeProj(currentRow)) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 9d758eb3b7c32..999ebb768af50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -49,10 +49,6 @@ case class TungstenAggregate( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def producedAttributes: AttributeSet = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f19d72f067218..af7237ef25886 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -36,10 +36,6 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = { @@ -80,12 +76,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - - override def canProcessUnsafeRows: Boolean = true - - override def canProcessSafeRows: Boolean = true } /** @@ -108,10 +98,6 @@ case class Sample( { override def output: Seq[Attribute] = child.output - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, @@ -135,8 +121,6 @@ case class Range( output: Seq[Attribute]) extends LeafNode { - override def outputsUnsafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { sqlContext .sparkContext @@ -199,9 +183,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -268,12 +249,14 @@ case class TakeOrderedAndProject( // and this ordering needs to be created on the driver in order to be passed into Spark core code. private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. - @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) - private def collectData(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - projection.map(data.map(_)).getOrElse(data) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } } override def executeCollect(): Array[InternalRow] = { @@ -311,10 +294,6 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { child.execute().coalesce(numPartitions, shuffle = false) } - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -327,10 +306,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -343,10 +318,6 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -371,10 +342,6 @@ case class MapPartitions[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) @@ -394,11 +361,6 @@ case class AppendColumns[T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = AttributeSet(newColumns) - // We are using an unsafe combiner. - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { @@ -428,10 +390,6 @@ case class MapGroups[K, T, U]( child: SparkPlan) extends UnaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -472,10 +430,6 @@ case class CoGroup[Key, Left, Right, Result]( right: SparkPlan) extends BinaryNode { override def producedAttributes: AttributeSet = outputSet - override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index aa7a668e0e938..d80912309bab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -39,9 +39,7 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, - if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), - tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } /** @@ -226,8 +224,6 @@ private[sql] case class InMemoryColumnarTableScan( // The cached version does not change the outputOrdering of the original SparkPlan. override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering - override def outputsUnsafeRows: Boolean = true - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index aab177b2e8427..54275c2cc1134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -46,15 +46,8 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } - override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 81bfe4e67ca73..d9fa4c6b83798 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -81,10 +81,6 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb961d97c3c3c..7f9d9daa5ab20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,10 +44,6 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildSideKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index c6e5868187518..6d464d6946b78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,10 +64,6 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index f23a1830e91c1..3e0f74cd98c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,10 +33,6 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def leftKeyGenerator: Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index efa7b49410edc..82498ee395649 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -42,9 +42,6 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4bf7b521c77d3..812f881d06fb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,10 +53,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 7ce38ebdb3413..c3a2bfc59c7a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -89,10 +89,6 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def createLeftKeyGenerator(): Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 6a882c9234df4..e46217050bad5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,18 +69,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin */ def close(): Unit - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true - /** * Returns the content through the [[Iterator]] interface. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index 7321fc66b4dde..b7fa0c0202221 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -47,11 +47,7 @@ case class NestedLoopJoinNode( } private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } + UnsafeProjection.create(schema) } private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index defcec95fb555..efb4b09c16348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -351,10 +351,6 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) @@ -400,13 +396,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val unpickle = new Unpickler val row = new GenericMutableRow(1) val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - joined(queue.poll(), row) + resultProj(joined(queue.poll(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala deleted file mode 100644 index 5f8fc2de8b46d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Converts Java-object-based rows into [[UnsafeRow]]s. - */ -case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToUnsafe = UnsafeProjection.create(child.schema) - iter.map(convertToUnsafe) - } - } -} - -/** - * Converts [[UnsafeRow]]s back into Java-object-based rows. - */ -case class ConvertToSafe(child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) - iter.map(convertToSafe) - } - } -} - -private[sql] object EnsureRowFormats extends Rule[SparkPlan] { - - private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && !operator.canProcessUnsafeRows - - private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessUnsafeRows && !operator.canProcessSafeRows - - private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && operator.canProcessUnsafeRows - - override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { - case operator: SparkPlan if onlyHandlesSafeRows(operator) => - if (operator.children.exists(_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => - if (operator.children.exists(!_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => - if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows. - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 911d12e93e503..87bff3295f5be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -28,7 +28,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + plan => Exchange(SinglePartition, plan), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala deleted file mode 100644 index faef76d52ae75..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.IntegerType - -class ExpandSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - private def testExpand(f: SparkPlan => SparkPlan): Unit = { - val input = (1 to 1000).map(Tuple1.apply) - val projections = Seq.tabulate(2) { i => - Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil - } - val attributes = projections.head.map(_.toAttribute) - checkAnswer( - input.toDF(), - plan => Expand(projections, attributes, f(plan)), - input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) - ) - } - - test("inheriting child row type") { - val exprs = AttributeReference("a", IntegerType, false)() :: Nil - val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) - assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") - } - - test("expanding UnsafeRows") { - testExpand(ConvertToUnsafe) - } - - test("expanding SafeRows") { - testExpand(identity) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala deleted file mode 100644 index 2328899bb2f8d..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{ArrayType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { - - private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { - case c: ConvertToUnsafe => c - case c: ConvertToSafe => c - } - - private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(outputsUnsafe.outputsUnsafeRows) - - test("planner should insert unsafe->safe conversions when required") { - val plan = Limit(10, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) - } - - test("filter can process unsafe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("filter can process safe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("coalesce can process unsafe rows") { - val plan = Coalesce(1, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except can process unsafe rows") { - val plan = Except(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except requires all of its input rows' formats to agree") { - val plan = Except(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect can process unsafe rows") { - val plan = Intersect(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect requires all of its input rows' formats to agree") { - val plan = Intersect(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("execute() fails an assertion if inputs rows are of different formats") { - val e = intercept[AssertionError] { - Union(Seq(outputsSafe, outputsUnsafe)).execute() - } - assert(e.getMessage.contains("format")) - } - - test("union requires all of its input rows' formats to agree") { - val plan = Union(Seq(outputsSafe, outputsUnsafe)) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("union can process safe rows") { - val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("union can process unsafe rows") { - val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("round trip with ConvertToUnsafe and ConvertToSafe") { - val input = Seq(("hello", 1), ("world", 2)) - checkAnswer( - sqlContext.createDataFrame(input), - plan => ConvertToSafe(ConvertToUnsafe(plan)), - input.map(Row.fromTuple) - ) - } - - test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SQLContext.setActive(sqlContext) - val schema = ArrayType(StringType) - val rows = (1 to 100).map { i => - InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) - } - val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) - - val plan = - DummyPlan( - ConvertToSafe( - ConvertToUnsafe(relation))) - assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) - } -} - -case class DummyPlan(child: SparkPlan) extends UnaryNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some - // values gotten from the incoming rows. - // we cache all strings here to make sure we have deep copied UTF8String inside incoming - // safe InternalRow. - val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] - iter.foreach { row => - strings += row.getArray(0).getUTF8String(0) - } - strings.map(InternalRow(_)).iterator - } - } - - override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index e5d34be4c65e8..af971dfc6faec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -99,7 +99,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { ) checkThatPlansAgree( inputDf, - p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), ReferenceSort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8141136de5311..1588728bdbaa4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -132,11 +132,17 @@ case class HiveTableScan( } } - protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + protected override def doExecute(): RDD[InternalRow] = { + val rdd = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + } + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f936cf565b2bc..44dc68e6ba47f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,18 +28,17 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, Attribute} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.util.SerializableJobConf +import org.apache.spark.{SparkException, TaskContext} private[hive] case class InsertIntoHiveTable( @@ -101,15 +100,17 @@ case class InsertIntoHiveTable( writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + val proj = FromUnsafeProjection(child.schema) iterator.foreach { row => var i = 0 + val safeRow = proj(row) while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) i += 1 } writerContainer - .getLocalFileWriter(row, table.schema) + .getLocalFileWriter(safeRow, table.schema) .write(serializer.serialize(outputData, standardOI)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a61e162f48f1b..6ccd4178190cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -213,7 +213,8 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => if (iter.hasNext) { - processIterator(iter) + val proj = UnsafeProjection.create(schema) + processIterator(iter).map(proj) } else { // If the input iterator has no rows then do not launch the external script. Iterator.empty diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 665e87e3e3355..efbf9988ddc13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -27,7 +27,6 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -689,36 +688,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } - - test("HadoopFsRelation produces UnsafeRow") { - withTempTable("test_unsafe") { - withTempPath { dir => - val path = dir.getCanonicalPath - sqlContext.range(3).write.format(dataSourceName).save(path) - sqlContext.read - .format(dataSourceName) - .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) - .load(path) - .registerTempTable("test_unsafe") - - val df = sqlContext.sql( - """SELECT COUNT(*) - |FROM test_unsafe a JOIN test_unsafe b - |WHERE a.id = b.id - """.stripMargin) - - val plan = df.queryExecution.executedPlan - - assert( - plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, - s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): - |$plan - """.stripMargin) - - checkAnswer(df, Row(3)) - } - } - } } // This class is used to test SPARK-8578. We should not use any custom output committer when From c3d505602de2fd2361633f90e4fff7e041849e28 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 3 Jan 2016 20:53:35 +0530 Subject: [PATCH 1407/2089] [SPARK-12327][SPARKR] fix code for lintr warning for commented code shivaram Author: felixcheung Closes #10408 from felixcheung/rcodecomment. --- R/pkg/.lintr | 2 +- R/pkg/R/RDD.R | 40 +++++++++++++++++++++-- R/pkg/R/deserialize.R | 3 ++ R/pkg/R/pairRDD.R | 30 +++++++++++++++++ R/pkg/R/serialize.R | 2 ++ R/pkg/inst/tests/testthat/test_rdd.R | 4 +-- R/pkg/inst/tests/testthat/test_shuffle.R | 4 +-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 ++++--- R/pkg/inst/tests/testthat/test_utils.R | 2 ++ 9 files changed, 88 insertions(+), 11 deletions(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 39c872663ad44..038236fc149e6 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 00c40c38cabc9..a78fbb714f2be 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -180,7 +180,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), } # Save the serialization flag after we create a RRDD rdd@env$serializedMode <- serializedMode - rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") rdd@env$jrdd_val }) @@ -225,7 +225,7 @@ setMethod("cache", #' #' Persist this RDD with the specified storage level. For details of the #' supported storage levels, refer to -#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' #' @param x The RDD to persist #' @param newLevel The new storage level to be assigned @@ -382,11 +382,13 @@ setMethod("collectPartition", #' \code{collectAsMap} returns a named list as a map that contains all of the elements #' in a key-value pair RDD. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) #' collectAsMap(rdd) # list(`1` = 2, `3` = 4) #'} +# nolint end #' @rdname collect-methods #' @aliases collectAsMap,RDD-method #' @noRd @@ -442,11 +444,13 @@ setMethod("length", #' @return list of (value, count) pairs, where count is number of each unique #' value in rdd. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,3,2,1)) #' countByValue(rdd) # (1,2L), (2,2L), (3,1L) #'} +# nolint end #' @rdname countByValue #' @aliases countByValue,RDD-method #' @noRd @@ -597,11 +601,13 @@ setMethod("mapPartitionsWithIndex", #' @param x The RDD to be filtered. #' @param f A unary predicate function. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) #'} +# nolint end #' @rdname filterRDD #' @aliases filterRDD,RDD,function-method #' @noRd @@ -756,11 +762,13 @@ setMethod("foreachPartition", #' @param x The RDD to take elements from #' @param num Number of elements to take #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' take(rdd, 2L) # list(1, 2) #'} +# nolint end #' @rdname take #' @aliases take,RDD,numeric-method #' @noRd @@ -824,11 +832,13 @@ setMethod("first", #' @param x The RDD to remove duplicates from. #' @param numPartitions Number of partitions to create. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,2,3,3,3)) #' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) #'} +# nolint end #' @rdname distinct #' @aliases distinct,RDD-method #' @noRd @@ -974,11 +984,13 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", #' @param x The RDD. #' @param func The function to be applied. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3)) #' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) #'} +# nolint end #' @rdname keyBy #' @aliases keyBy,RDD #' @noRd @@ -1113,11 +1125,13 @@ setMethod("saveAsTextFile", #' @param numPartitions Number of partitions to create. #' @return An RDD where all elements are sorted. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(3, 2, 1)) #' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) #'} +# nolint end #' @rdname sortBy #' @aliases sortBy,RDD,RDD-method #' @noRd @@ -1188,11 +1202,13 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { #' @param num Number of elements to return. #' @return The first N elements from the RDD in ascending order. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) #' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) #'} +# nolint end #' @rdname takeOrdered #' @aliases takeOrdered,RDD,RDD-method #' @noRd @@ -1209,11 +1225,13 @@ setMethod("takeOrdered", #' @return The top N elements from the RDD. #' @rdname top #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) #' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) #'} +# nolint end #' @aliases top,RDD,RDD-method #' @noRd setMethod("top", @@ -1261,6 +1279,7 @@ setMethod("fold", #' @rdname aggregateRDD #' @seealso reduce #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4)) @@ -1269,6 +1288,7 @@ setMethod("fold", #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } #' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) #'} +# nolint end #' @aliases aggregateRDD,RDD,RDD-method #' @noRd setMethod("aggregateRDD", @@ -1367,12 +1387,14 @@ setMethod("setName", #' @return An RDD with zipped items. #' @seealso zipWithIndex #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) #' collect(zipWithUniqueId(rdd)) #' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #'} +# nolint end #' @rdname zipWithUniqueId #' @aliases zipWithUniqueId,RDD #' @noRd @@ -1408,12 +1430,14 @@ setMethod("zipWithUniqueId", #' @return An RDD with zipped items. #' @seealso zipWithUniqueId #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) #' collect(zipWithIndex(rdd)) #' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) #'} +# nolint end #' @rdname zipWithIndex #' @aliases zipWithIndex,RDD #' @noRd @@ -1454,12 +1478,14 @@ setMethod("zipWithIndex", #' @return An RDD created by coalescing all elements within #' each partition into a list. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, as.list(1:4), 2L) #' collect(glom(rdd)) #' # list(list(1, 2), list(3, 4)) #'} +# nolint end #' @rdname glom #' @aliases glom,RDD #' @noRd @@ -1519,6 +1545,7 @@ setMethod("unionRDD", #' @param other Another RDD to be zipped. #' @return An RDD zipped from the two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 0:4) @@ -1526,6 +1553,7 @@ setMethod("unionRDD", #' collect(zipRDD(rdd1, rdd2)) #' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) #'} +# nolint end #' @rdname zipRDD #' @aliases zipRDD,RDD #' @noRd @@ -1557,12 +1585,14 @@ setMethod("zipRDD", #' @param other An RDD. #' @return A new RDD which is the Cartesian product of these two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:2) #' sortByKey(cartesian(rdd, rdd)) #' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #'} +# nolint end #' @rdname cartesian #' @aliases cartesian,RDD,RDD-method #' @noRd @@ -1587,6 +1617,7 @@ setMethod("cartesian", #' @param numPartitions Number of the partitions in the result RDD. #' @return An RDD with the elements from this that are not in other. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) @@ -1594,6 +1625,7 @@ setMethod("cartesian", #' collect(subtract(rdd1, rdd2)) #' # list(1, 1, 3) #'} +# nolint end #' @rdname subtract #' @aliases subtract,RDD #' @noRd @@ -1619,6 +1651,7 @@ setMethod("subtract", #' @param numPartitions The number of partitions in the result RDD. #' @return An RDD which is the intersection of these two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) @@ -1626,6 +1659,7 @@ setMethod("subtract", #' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) #' # list(1, 2, 3) #'} +# nolint end #' @rdname intersection #' @aliases intersection,RDD #' @noRd @@ -1653,6 +1687,7 @@ setMethod("intersection", #' Assumes that all the RDDs have the *same number of partitions*, but #' does *not* require them to have the same number of elements in each partition. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 @@ -1662,6 +1697,7 @@ setMethod("intersection", #' func = function(x, y, z) { list(list(x, y, z))} )) #' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #'} +# nolint end #' @rdname zipRDD #' @aliases zipPartitions,RDD #' @noRd diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index f7e56e43016ea..d8a0393275390 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -17,6 +17,7 @@ # Utility functions to deserialize objects from Java. +# nolint start # Type mapping from Java to R # # void -> NULL @@ -32,6 +33,8 @@ # # Array[T] -> list() # Object -> jobj +# +# nolint end readObject <- function(con) { # Read type first diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 334c11d2f89a1..f7131140feafb 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -30,12 +30,14 @@ NULL #' @param key The key to look up for #' @return a list of values in this RDD for key key #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) #' rdd <- parallelize(sc, pairs) #' lookup(rdd, 1) # list(1, 3) #'} +# nolint end #' @rdname lookup #' @aliases lookup,RDD-method #' @noRd @@ -58,11 +60,13 @@ setMethod("lookup", #' @param x The RDD to count keys. #' @return list of (key, count) pairs, where count is number of each key in rdd. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) #' countByKey(rdd) # ("a", 2L), ("b", 1L) #'} +# nolint end #' @rdname countByKey #' @aliases countByKey,RDD-method #' @noRd @@ -77,11 +81,13 @@ setMethod("countByKey", #' #' @param x The RDD from which the keys of each tuple is returned. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) #' collect(keys(rdd)) # list(1, 3) #'} +# nolint end #' @rdname keys #' @aliases keys,RDD #' @noRd @@ -98,11 +104,13 @@ setMethod("keys", #' #' @param x The RDD from which the values of each tuple is returned. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) #' collect(values(rdd)) # list(2, 4) #'} +# nolint end #' @rdname values #' @aliases values,RDD #' @noRd @@ -348,6 +356,7 @@ setMethod("reduceByKey", #' @return A list of elements of type list(K, V') where V' is the merged value for each key #' @seealso reduceByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) @@ -355,6 +364,7 @@ setMethod("reduceByKey", #' reduced <- reduceByKeyLocally(rdd, "+") #' reduced # list(list(1, 6), list(1.1, 3)) #'} +# nolint end #' @rdname reduceByKeyLocally #' @aliases reduceByKeyLocally,RDD,integer-method #' @noRd @@ -412,6 +422,7 @@ setMethod("reduceByKeyLocally", #' @return An RDD where each element is list(K, C) where C is the combined type #' @seealso groupByKey, reduceByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) @@ -420,6 +431,7 @@ setMethod("reduceByKeyLocally", #' combined <- collect(parts) #' combined[[1]] # Should be a list(1, 6) #'} +# nolint end #' @rdname combineByKey #' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method #' @noRd @@ -473,6 +485,7 @@ setMethod("combineByKey", #' @return An RDD containing the aggregation result. #' @seealso foldByKey, combineByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -482,6 +495,7 @@ setMethod("combineByKey", #' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) #' # list(list(1, list(3, 2)), list(2, list(7, 2))) #'} +# nolint end #' @rdname aggregateByKey #' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method #' @noRd @@ -509,11 +523,13 @@ setMethod("aggregateByKey", #' @return An RDD containing the aggregation result. #' @seealso aggregateByKey, combineByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) #' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) #'} +# nolint end #' @rdname foldByKey #' @aliases foldByKey,RDD,ANY,ANY,integer-method #' @noRd @@ -540,12 +556,14 @@ setMethod("foldByKey", #' @return a new RDD containing all pairs of elements with matching keys in #' two input RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) #' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) #'} +# nolint end #' @rdname join-methods #' @aliases join,RDD,RDD-method #' @noRd @@ -578,6 +596,7 @@ setMethod("join", #' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) #' if no elements in rdd2 have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) @@ -585,6 +604,7 @@ setMethod("join", #' leftOuterJoin(rdd1, rdd2, 2L) #' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) #'} +# nolint end #' @rdname join-methods #' @aliases leftOuterJoin,RDD,RDD-method #' @noRd @@ -616,6 +636,7 @@ setMethod("leftOuterJoin", #' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) #' if no elements in x have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) @@ -623,6 +644,7 @@ setMethod("leftOuterJoin", #' rightOuterJoin(rdd1, rdd2, 2L) #' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) #'} +# nolint end #' @rdname join-methods #' @aliases rightOuterJoin,RDD,RDD-method #' @noRd @@ -655,6 +677,7 @@ setMethod("rightOuterJoin", #' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements #' in x/y have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) @@ -664,6 +687,7 @@ setMethod("rightOuterJoin", #' # list(2, list(NULL, 4))) #' # list(3, list(3, NULL)), #'} +# nolint end #' @rdname join-methods #' @aliases fullOuterJoin,RDD,RDD-method #' @noRd @@ -688,6 +712,7 @@ setMethod("fullOuterJoin", #' @return a new RDD containing all pairs of elements with values in a list #' in all RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) @@ -695,6 +720,7 @@ setMethod("fullOuterJoin", #' cogroup(rdd1, rdd2, numPartitions = 2L) #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) #'} +# nolint end #' @rdname cogroup #' @aliases cogroup,RDD-method #' @noRd @@ -740,11 +766,13 @@ setMethod("cogroup", #' @param numPartitions Number of partitions to create. #' @return An RDD where all (k, v) pair elements are sorted. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) #' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) #'} +# nolint end #' @rdname sortByKey #' @aliases sortByKey,RDD,RDD-method #' @noRd @@ -805,6 +833,7 @@ setMethod("sortByKey", #' @param numPartitions Number of the partitions in the result RDD. #' @return An RDD with the pairs from x whose keys are not in other. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), @@ -813,6 +842,7 @@ setMethod("sortByKey", #' collect(subtractByKey(rdd1, rdd2)) #' # list(list("b", 4), list("b", 5)) #'} +# nolint end #' @rdname subtractByKey #' @aliases subtractByKey,RDD #' @noRd diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 17082b4e52fcf..095ddb9aed2e7 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -17,6 +17,7 @@ # Utility functions to serialize R objects so they can be read in Java. +# nolint start # Type mapping from R to Java # # NULL -> Void @@ -31,6 +32,7 @@ # list[T] -> Array[T], where T is one of above mentioned types # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +# nolint end getSerdeType <- function(object) { type <- class(object)[[1]] diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 7423b4f2bed1f..1b3a22486e95f 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -223,14 +223,14 @@ test_that("takeSample() on RDDs", { s <- takeSample(data, TRUE, 100L, seed) expect_equal(length(s), 100L) # Chance of getting all distinct elements is astronomically low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } for (seed in 4:5) { s <- takeSample(data, TRUE, 200L, seed) expect_equal(length(s), 200L) # Chance of getting all distinct elements is still quite low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } }) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index adf0b91d25fe9..d3d0f8a24d01c 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -176,8 +176,8 @@ test_that("partitionBy() partitions data correctly", { resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) - expected_first <- list(list(1, 100), list(2, 200)) # key < 3 - expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + expected_first <- list(list(1, 100), list(2, 200)) # key less than 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3 actual_first <- collectPartition(resultRDD, 0L) actual_second <- collectPartition(resultRDD, 1L) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7b508b860efb2..9e5d0ebf60720 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -498,9 +498,11 @@ test_that("table() returns a new DataFrame", { expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + # nolint start # Test base::table is working #a <- letters[1:3] #expect_equal(class(table(a, sample(a))), "table") + # nolint end }) test_that("toRDD() returns an RRDD", { @@ -766,8 +768,10 @@ test_that("sample on a DataFrame", { sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + # nolint start # Test base::sample is working #expect_equal(length(sample(1:12)), 12) + # nolint end }) test_that("select operators", { @@ -1052,8 +1056,8 @@ test_that("string operators", { df2 <- createDataFrame(sqlContext, l2) expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) - expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") - expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint l3 <- list(list(a = "a.b.c.d")) df3 <- createDataFrame(sqlContext, l3) @@ -1259,7 +1263,7 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered6), 2) # Test stats::filter is working - #expect_true(is.ts(filter(1:100, rep(1, 3)))) + #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) test_that("join() and merge() on a DataFrame", { @@ -1659,7 +1663,7 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) # Test stats::cov is working - #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) # nolint }) test_that("freqItems() on a DataFrame", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 12df4cf4f65b7..56f14a3bce61e 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -95,7 +95,9 @@ test_that("cleanClosure on R functions", { # TODO(shivaram): length(ls(env)) is 4 here for some reason and `lapply` is included in `env`. # Disabling this test till we debug this. # + # nolint start # expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + # nolint end expect_true("g" %in% ls(env)) expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) From c82924d564c07e6e6f635b9e263994dedf06268a Mon Sep 17 00:00:00 2001 From: thomastechs Date: Sun, 3 Jan 2016 11:09:30 -0800 Subject: [PATCH 1408/2089] [SPARK-12533][SQL] hiveContext.table() throws the wrong exception Avoiding the the No such table exception and throwing analysis exception as per the bug: SPARK-12533 Author: thomastechs Closes #10529 from thomastechs/topic-branch. --- .../org/apache/spark/sql/catalyst/analysis/Catalog.scala | 2 +- .../test/scala/org/apache/spark/sql/CachedTableSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 8f4ce74a2ea38..3b775c3ca87b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -104,7 +104,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { val tableName = getTableName(tableIdent) val table = tables.get(tableName) if (table == null) { - throw new NoSuchTableException + throw new AnalysisException("Table not found: " + tableName) } val tableWithQualifiers = Subquery(tableName, table) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index d86df4cfb9b4d..6b735bcf16104 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException + import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.PhysicalRDD @@ -289,7 +289,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext testData.select('key).registerTempTable("t1") sqlContext.table("t1") sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) } test("Drops cached temporary table") { @@ -301,7 +301,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(sqlContext.isCached("t2")) sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) assert(!sqlContext.isCached("t2")) } From 7b92922f7f7ba4ff398dcbd734e8305ba03da87b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 3 Jan 2016 16:58:01 -0800 Subject: [PATCH 1409/2089] Update MimaExcludes now Spark 1.6 is in Maven. Author: Reynold Xin Closes #10561 from rxin/update-mima. --- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 158 +++---------------------------------- 2 files changed, 12 insertions(+), 148 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 519052620246f..9ba9f8286f10c 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.5.0" + val previousSparkVersion = "1.6.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 612ddf86ded3e..7a6e5cf4ad39a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -30,165 +30,29 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * It is also possible to exclude Spark classes and packages. This should be used sparingly: * * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") + * + * For a new Spark version, please update MimaBuild.scala to reflect the previous version. */ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("2.0") => Seq( - // SPARK-7995 Remove AkkaRpcEnv - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaFailure"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaFailure$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEndpointRef$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEnvFactory"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEnv"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaMessage$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaRpcEndpointRef"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.ErrorMonitor"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rpc.akka.AkkaMessage") + excludePackage("org.apache.spark.rpc"), + excludePackage("org.spark-project.jetty"), + excludePackage("org.apache.spark.unused"), + excludePackage("org.apache.spark.sql.catalyst"), + excludePackage("org.apache.spark.sql.execution"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this") ) ++ Seq( ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ Seq( - // SPARK-12481 + // SPARK-12481 Remove Hadoop 1.x ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.mapred.SparkHadoopMapRedUtil") - ) ++ - // When 1.6 is officially released, update this exclusion list. - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("network"), - MimaBuild.excludeSparkPackage("unsafe"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution"), - // SQL columnar is considered private. - excludePackage("org.apache.spark.sql.columnar"), - // The shuffle package is considered private. - excludePackage("org.apache.spark.shuffle"), - // The collections utlities are considered pricate. - excludePackage("org.apache.spark.util.collection") - ) ++ - MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ - Seq( - // MiMa does not deal properly with sealed traits - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") - ) ++ Seq( - // SPARK-11530 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this") - ) ++ Seq( - // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. - // This class is marked as `private` but MiMa still seems to be confused by the change. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ Seq( - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.SQLContext$SQLSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.detachSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.tlSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.defaultSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.currentSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.openSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.setSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.createSession") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.preferredNodeLocationData_="), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") - ) ++ Seq( - // SPARK-11485 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), - // SPARK-11541 mark various JDBC dialects as private - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") - ) ++ Seq ( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationInfo.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.StageData.this") - ) ++ Seq( - // SPARK-11766 add toJson to Vector - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toJson") - ) ++ Seq( - // SPARK-9065 Support message handler in Kafka Python API - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") - ) ++ Seq( - // SPARK-4557 Changed foreachRDD to use VoidFunction - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") - ) ++ Seq( - // SPARK-11996 Make the executor thread dump work again - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") - ) ++ Seq( - // SPARK-3580 Add getNumPartitions method to JavaRDD - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") - ) ++ - // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a - // private class. - MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") + ) case v if v.startsWith("1.6") => Seq( MimaBuild.excludeSparkPackage("deploy"), From b8410ff9ce8cef7159a7364272e4c4234c5b474f Mon Sep 17 00:00:00 2001 From: Cazen Date: Sun, 3 Jan 2016 17:01:19 -0800 Subject: [PATCH 1410/2089] [SPARK-12537][SQL] Add option to accept quoting of all character backslash quoting mechanism We can provides the option to choose JSON parser can be enabled to accept quoting of all character or not. Author: Cazen Author: Cazen Lee Author: Cazen Lee Author: cazen.lee Closes #10497 from Cazen/master. --- python/pyspark/sql/readwriter.py | 2 ++ .../apache/spark/sql/DataFrameReader.scala | 2 ++ .../datasources/json/JSONOptions.scala | 9 +++++++-- .../json/JsonParsingOptionsSuite.scala | 19 +++++++++++++++++++ 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a3d7eca04b616..a2771daabe331 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -160,6 +160,8 @@ def json(self, path, schema=None): quotes * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ (e.g. 00012) + * ``allowBackslashEscapingAnyCharacter`` (default ``false``): allows accepting quoting \ + of all character using backslash quoting mechanism >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0acea95344c22..6debb302d9ecf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -258,6 +258,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)
  • + *
  • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all + * character using backslash quoting mechanism
  • * * @since 1.6.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index c132ead20e7d6..f805c0092585c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -31,7 +31,8 @@ case class JSONOptions( allowUnquotedFieldNames: Boolean = false, allowSingleQuotes: Boolean = true, allowNumericLeadingZeros: Boolean = false, - allowNonNumericNumbers: Boolean = false) { + allowNonNumericNumbers: Boolean = false, + allowBackslashEscapingAnyCharacter: Boolean = false) { /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { @@ -40,6 +41,8 @@ case class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + factory.configure(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + allowBackslashEscapingAnyCharacter) } } @@ -59,6 +62,8 @@ object JSONOptions { allowNumericLeadingZeros = parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true), + allowBackslashEscapingAnyCharacter = + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 4cc0a3a9585d9..1742df31bba9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -111,4 +111,23 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) } + + test("allowBackslashEscapingAnyCharacter off") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowBackslashEscapingAnyCharacter on") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.schema.last.name == "price") + assert(df.first().getString(0) == "Cazen Lee") + assert(df.first().getString(1) == "$10") + } } From 13dab9c3862cc454094cd9ba7b4504a2d095028f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 3 Jan 2016 17:04:35 -0800 Subject: [PATCH 1411/2089] [SPARK-12611][SQL][PYSPARK][TESTS] Fix test_infer_schema_to_local Previously (when the PR was first created) not specifying b= explicitly was fine (and treated as default null) - instead be explicit about b being None in the test. Author: Holden Karau Closes #10564 from holdenk/SPARK-12611-fix-test-infer-schema-local. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 10b99175ad952..9ada96601a1cd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -360,7 +360,7 @@ def test_infer_schema_to_local(self): df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) From 84f8492c1555bf8ab44c9818752278f61768eb16 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Sun, 3 Jan 2016 20:48:56 -0800 Subject: [PATCH 1412/2089] [SPARK-12562][SQL] DataFrame.write.format(text) requires the column name to be called value Author: Xiu Guo Closes #10515 from xguo27/SPARK-12562. --- .../sql/execution/datasources/text/DefaultSource.scala | 9 +++++---- .../spark/sql/execution/datasources/text/TextSuite.scala | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 248467abe9f50..fe69c72d28cb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -48,7 +48,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { dataSchema.foreach(verifySchema) - new TextRelation(None, partitionColumns, paths)(sqlContext) + new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext) } override def shortName(): String = "text" @@ -68,15 +68,16 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { private[sql] class TextRelation( val maybePartitionSpec: Option[PartitionSpec], + val textSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - /** Data schema is always a single column, named "value". */ - override def dataSchema: StructType = new StructType().add("value", StringType) - + /** Data schema is always a single column, named "value" if original Data source has no schema. */ + override def dataSchema: StructType = + textSchema.getOrElse(new StructType().add("value", StringType)) /** This is an internal data source that outputs internal row format. */ override val needConversion: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 914e516613f9e..02c416af50cd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -33,8 +33,8 @@ class TextSuite extends QueryTest with SharedSQLContext { verifyFrame(sqlContext.read.text(testFile)) } - test("writing") { - val df = sqlContext.read.text(testFile) + test("SPARK-12562 verify write.text() can handle column name beyond `value`") { + val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() From 0d165ec2050aa06aa545d74c8d7c2ff197fa02de Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Jan 2016 22:05:02 -0800 Subject: [PATCH 1413/2089] [SPARK-12612][PROJECT-INFRA] Add missing Hadoop profiles to dev/run-tests-*.py scripts and dev/deps There are a couple of places in the `dev/run-tests-*.py` scripts which deal with Hadoop profiles, but the set of profiles that they handle does not include all Hadoop profiles defined in our POM. Similarly, the `hadoop-2.2` and `hadoop-2.6` profiles were missing from `dev/deps`. This patch updates these scripts to include all four Hadoop profiles defined in our POM. Author: Josh Rosen Closes #10565 from JoshRosen/add-missing-hadoop-profiles-in-test-scripts. --- dev/deps/spark-deps-hadoop-2.2 | 193 +++++++++++++++++++++++++++++++++ dev/deps/spark-deps-hadoop-2.6 | 192 ++++++++++++++++++++++++++++++++ dev/run-tests-jenkins.py | 4 + dev/run-tests.py | 5 +- dev/test-dependencies.sh | 2 + 5 files changed, 394 insertions(+), 2 deletions(-) create mode 100644 dev/deps/spark-deps-hadoop-2.2 create mode 100644 dev/deps/spark-deps-hadoop-2.6 diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 new file mode 100644 index 0000000000000..44727f9876deb --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -0,0 +1,193 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math-2.1.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +gmbal-api-only-3.0.0-b023.jar +grizzly-framework-2.1.2.jar +grizzly-http-2.1.2.jar +grizzly-http-server-2.1.2.jar +grizzly-http-servlet-2.1.2.jar +grizzly-rcm-2.1.2.jar +groovy-all-2.1.6.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.2.0.jar +hadoop-auth-2.2.0.jar +hadoop-client-2.2.0.jar +hadoop-common-2.2.0.jar +hadoop-hdfs-2.2.0.jar +hadoop-mapreduce-client-app-2.2.0.jar +hadoop-mapreduce-client-common-2.2.0.jar +hadoop-mapreduce-client-core-2.2.0.jar +hadoop-mapreduce-client-jobclient-2.2.0.jar +hadoop-mapreduce-client-shuffle-2.2.0.jar +hadoop-yarn-api-2.2.0.jar +hadoop-yarn-client-2.2.0.jar +hadoop-yarn-common-2.2.0.jar +hadoop-yarn-server-common-2.2.0.jar +hadoop-yarn-server-web-proxy-2.2.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javax.servlet-3.1.jar +javax.servlet-api-3.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-grizzly2-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jersey-test-framework-core-1.9.jar +jersey-test-framework-grizzly2-1.9.jar +jets3t-0.7.1.jar +jettison-1.1.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.1.jar +management-api-3.0.0-b012.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 new file mode 100644 index 0000000000000..e37484473db2e --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -0,0 +1,192 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +apacheds-i18n-2.0.0-M15.jar +apacheds-kerberos-codec-2.0.0-M15.jar +api-asn1-api-1.0.0-M20.jar +api-util-1.0.0-M20.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.6.0.jar +curator-framework-2.6.0.jar +curator-recipes-2.6.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +gson-2.2.4.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.6.0.jar +hadoop-auth-2.6.0.jar +hadoop-client-2.6.0.jar +hadoop-common-2.6.0.jar +hadoop-hdfs-2.6.0.jar +hadoop-mapreduce-client-app-2.6.0.jar +hadoop-mapreduce-client-common-2.6.0.jar +hadoop-mapreduce-client-core-2.6.0.jar +hadoop-mapreduce-client-jobclient-2.6.0.jar +hadoop-mapreduce-client-shuffle-2.6.0.jar +hadoop-yarn-api-2.6.0.jar +hadoop-yarn-client-2.6.0.jar +hadoop-yarn-common-2.6.0.jar +hadoop-yarn-server-common-2.6.0.jar +hadoop-yarn-server-web-proxy-2.6.0.jar +htrace-core-3.0.4.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xercesImpl-2.9.1.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.6.jar diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 6501721572f98..c44e522c0475d 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -168,6 +168,10 @@ def main(): os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" if "test-hadoop2.3" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.3" + if "test-hadoop2.4" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.4" + if "test-hadoop2.6" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.6" build_display_name = os.environ["BUILD_DISPLAY_NAME"] build_url = os.environ["BUILD_URL"] diff --git a/dev/run-tests.py b/dev/run-tests.py index 23278d298c22d..9db728d799022 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -295,13 +295,14 @@ def exec_sbt(sbt_args=()): def get_hadoop_profiles(hadoop_version): """ - For the given Hadoop version tag, return a list of SBT profile flags for + For the given Hadoop version tag, return a list of Maven/SBT profile flags for building and testing against that Hadoop version. """ sbt_maven_hadoop_profiles = { "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], - "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.3": ["-Pyarn", "-Phadoop-2.3"], + "hadoop2.4": ["-Pyarn", "-Phadoop-2.4"], "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], } diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 4e260e2abf042..d6a32717f588f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -32,8 +32,10 @@ export LC_ALL=C HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pyarn -Phive" MVN="build/mvn --force" HADOOP_PROFILES=( + hadoop-2.2 hadoop-2.3 hadoop-2.4 + hadoop-2.6 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to From 9fd7a2f0247ed6cea0e8dbcdd2b24f41200b3e24 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 Jan 2016 01:04:29 -0800 Subject: [PATCH 1414/2089] [SPARK-10359][PROJECT-INFRA] Use more random number in dev/test-dependencies.sh; fix version switching This patch aims to fix another potential source of flakiness in the `dev/test-dependencies.sh` script. pwendell's original patch and my version used `$(date +%s | tail -c6)` to generate a suffix to use when installing temporary Spark versions into the local Maven cache, but this value only changes once per second and thus is highly collision-prone when concurrent builds launch on AMPLab Jenkins. In order to reduce the potential for conflicts, this patch updates the script to call Python's random number generator instead. I also fixed a bug in how we captured the original project version; the bug was causing the exit handler code to fail. Author: Josh Rosen Closes #10558 from JoshRosen/build-dep-tests-round-3. --- dev/run-tests.py | 4 ++-- dev/test-dependencies.sh | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 9db728d799022..8726889cbc777 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -419,8 +419,8 @@ def run_python_tests(test_modules, parallelism): def run_build_tests(): - # set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") - # run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") + run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) pass diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index d6a32717f588f..424ce6ad7663c 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -42,9 +42,19 @@ HADOOP_PROFILES=( # the old version. We need to do this because the `dependency:build-classpath` task needs to # resolve Spark's internal submodule dependencies. -# See http://stackoverflow.com/a/3545363 for an explanation of this one-liner: -OLD_VERSION=$($MVN help:evaluate -Dexpression=project.version|grep -Ev '(^\[|Download\w+:)') -TEMP_VERSION="spark-$(date +%s | tail -c6)" +# From http://stackoverflow.com/a/26514030 +set +e +OLD_VERSION=$($MVN -q \ + -Dexec.executable="echo" \ + -Dexec.args='${project.version}' \ + --non-recursive \ + org.codehaus.mojo:exec-maven-plugin:1.3.1:exec) +if [ $? != 0 ]; then + echo -e "Error while getting version string from Maven:\n$OLD_VERSION" + exit 1 +fi +set -e +TEMP_VERSION="spark-$(python -S -c "import random; print(random.randrange(100000, 999999))")" function reset_version { # Delete the temporary POMs that we wrote to the local Maven repo: From 962aac4db99f3988c07ccb23439327c18ec178f1 Mon Sep 17 00:00:00 2001 From: guoxu1231 Date: Mon, 4 Jan 2016 14:23:07 +0000 Subject: [PATCH 1415/2089] [SPARK-12513][STREAMING] SocketReceiver hang in Netcat example Explicitly close client side socket connection before restart socket receiver. Author: guoxu1231 Author: Shawn Guo Closes #10464 from guoxu1231/SPARK-12513. --- .../dstream/SocketInputDStream.scala | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 10644b9201918..e70fc87c39d95 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import java.io._ -import java.net.{Socket, UnknownHostException} +import java.net.{ConnectException, Socket} import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -51,7 +51,20 @@ class SocketReceiver[T: ClassTag]( storageLevel: StorageLevel ) extends Receiver[T](storageLevel) with Logging { + private var socket: Socket = _ + def onStart() { + + logInfo(s"Connecting to $host:$port") + try { + socket = new Socket(host, port) + } catch { + case e: ConnectException => + restart(s"Error connecting to $host:$port", e) + return + } + logInfo(s"Connected to $host:$port") + // Start the thread that receives data over a connection new Thread("Socket Receiver") { setDaemon(true) @@ -60,20 +73,22 @@ class SocketReceiver[T: ClassTag]( } def onStop() { - // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // in case restart thread close it twice + synchronized { + if (socket != null) { + socket.close() + socket = null + logInfo(s"Closed socket to $host:$port") + } + } } /** Create a socket connection and receive data until receiver is stopped */ def receive() { - var socket: Socket = null try { - logInfo("Connecting to " + host + ":" + port) - socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) val iterator = bytesToObjects(socket.getInputStream()) while(!isStopped && iterator.hasNext) { - store(iterator.next) + store(iterator.next()) } if (!isStopped()) { restart("Socket data stream had no more data") @@ -81,16 +96,11 @@ class SocketReceiver[T: ClassTag]( logInfo("Stopped receiving") } } catch { - case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) case NonFatal(e) => logWarning("Error receiving data", e) restart("Error receiving data", e) } finally { - if (socket != null) { - socket.close() - logInfo("Closed socket to " + host + ":" + port) - } + onStop() } } } From 8f659393b270c46e940c4e98af2d996bd4fd6442 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 4 Jan 2016 10:37:56 -0800 Subject: [PATCH 1416/2089] [SPARK-12486] Worker should kill the executors more forcefully if possible. This patch updates the ExecutorRunner's terminate path to use the new java 8 API to terminate processes more forcefully if possible. If the executor is unhealthy, it would previously ignore the destroy() call. Presumably, the new java API was added to handle cases like this. We could update the termination path in the future to use OS specific commands for older java versions. Author: Nong Li Closes #10438 from nongli/spark-12486-executors. --- .../spark/deploy/worker/ExecutorRunner.scala | 17 ++-- .../scala/org/apache/spark/util/Utils.scala | 24 ++++++ .../org/apache/spark/util/UtilsSuite.scala | 83 +++++++++++++++++-- 3 files changed, 112 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 9a42487bb37aa..9c4b8cdc646b0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -23,13 +23,12 @@ import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files - -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Manages the execution of one executor process. @@ -60,6 +59,9 @@ private[deploy] class ExecutorRunner( private var stdoutAppender: FileAppender = null private var stderrAppender: FileAppender = null + // Timeout to wait for when trying to terminate an executor. + private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000 + // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. private var shutdownHook: AnyRef = null @@ -94,8 +96,11 @@ private[deploy] class ExecutorRunner( if (stderrAppender != null) { stderrAppender.stop() } - process.destroy() - exitCode = Some(process.waitFor()) + exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS) + if (exitCode.isEmpty) { + logWarning("Failed to terminate process: " + process + + ". This process will likely be orphaned.") + } } try { worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b8ca6b07e4198..0c1f9c1ae2496 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1698,6 +1698,30 @@ private[spark] object Utils extends Logging { new File(path).getName } + /** + * Terminates a process waiting for at most the specified duration. Returns whether + * the process terminated. + */ + def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { + try { + // Java8 added a new API which will more forcibly kill the process. Use that if available. + val destroyMethod = process.getClass().getMethod("destroyForcibly"); + destroyMethod.setAccessible(true) + destroyMethod.invoke(process) + } catch { + case NonFatal(e) => + if (!e.isInstanceOf[NoSuchMethodException]) { + logWarning("Exception when attempting to kill process", e) + } + process.destroy() + } + if (waitForProcess(process, timeoutMs)) { + Option(process.exitValue()) + } else { + None + } + } + /** * Wait for a process to terminate for at most the specified duration. * Return whether the process actually terminated after the given timeout. diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index fdb51d440eff6..7de995af512db 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.util -import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols -import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files - +import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path - import org.apache.spark.network.util.ByteUnit -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -745,4 +743,77 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") } + + test("Kill process") { + // Verify that we can terminate a process even if it is in a bad state. This is only run + // on UNIX since it does some OS specific things to verify the correct behavior. + if (SystemUtils.IS_OS_UNIX) { + def getPid(p: Process): Int = { + val f = p.getClass().getDeclaredField("pid") + f.setAccessible(true) + f.get(p).asInstanceOf[Int] + } + + def pidExists(pid: Int): Boolean = { + val p = Runtime.getRuntime.exec(s"kill -0 $pid") + p.waitFor() + p.exitValue() == 0 + } + + def signal(pid: Int, s: String): Unit = { + val p = Runtime.getRuntime.exec(s"kill -$s $pid") + p.waitFor() + } + + // Start up a process that runs 'sleep 10'. Terminate the process and assert it takes + // less time and the process is no longer there. + val startTimeMs = System.currentTimeMillis() + val process = new ProcessBuilder("sleep", "10").start() + val pid = getPid(process) + try { + assert(pidExists(pid)) + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val durationMs = System.currentTimeMillis() - startTimeMs + assert(durationMs < 5000) + assert(!pidExists(pid)) + } finally { + // Forcibly kill the test process just in case. + signal(pid, "SIGKILL") + } + + val v: String = System.getProperty("java.version") + if (v >= "1.8.0") { + // Java8 added a way to forcibly terminate a process. We'll make sure that works by + // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On + // older versions of java, this will *not* terminate. + val file = File.createTempFile("temp-file-name", ".tmp") + val cmd = + s""" + |#!/bin/bash + |trap "" SIGTERM + |sleep 10 + """.stripMargin + Files.write(cmd.getBytes(), file) + file.getAbsoluteFile.setExecutable(true) + + val process = new ProcessBuilder(file.getAbsolutePath).start() + val pid = getPid(process) + assert(pidExists(pid)) + try { + signal(pid, "SIGSTOP") + val start = System.currentTimeMillis() + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val duration = System.currentTimeMillis() - start + assert(duration < 5000) + assert(!pidExists(pid)) + } finally { + signal(pid, "SIGKILL") + } + } + } + } } From 6c83d938cc61bd5fabaf2157fcc3936364a83f02 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 Jan 2016 10:39:42 -0800 Subject: [PATCH 1417/2089] [SPARK-12579][SQL] Force user-specified JDBC driver to take precedence Spark SQL's JDBC data source allows users to specify an explicit JDBC driver to load (using the `driver` argument), but in the current code it's possible that the user-specified driver will not be used when it comes time to actually create a JDBC connection. In a nutshell, the problem is that you might have multiple JDBC drivers on the classpath that claim to be able to handle the same subprotocol, so simply registering the user-provided driver class with the our `DriverRegistry` and JDBC's `DriverManager` is not sufficient to ensure that it's actually used when creating the JDBC connection. This patch addresses this issue by first registering the user-specified driver with the DriverManager, then iterating over the driver manager's loaded drivers in order to obtain the correct driver and use it to create a connection (previously, we just called `DriverManager.getConnection()` directly). If a user did not specify a JDBC driver to use, then we call `DriverManager.getDriver` to figure out the class of the driver to use, then pass that class's name to executors; this guards against corner-case bugs in situations where the driver and executor JVMs might have different sets of JDBC drivers on their classpaths (previously, there was the (rare) potential for `DriverManager.getConnection()` to use different drivers on the driver and executors if the user had not explicitly specified a JDBC driver class and the classpaths were different). This patch is inspired by a similar patch that I made to the `spark-redshift` library (https://github.com/databricks/spark-redshift/pull/143), which contains its own modified fork of some of Spark's JDBC data source code (for cross-Spark-version compatibility reasons). Author: Josh Rosen Closes #10519 from JoshRosen/jdbc-driver-precedence. --- docs/sql-programming-guide.md | 4 +-- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/jdbc/DefaultSource.scala | 3 -- .../datasources/jdbc/DriverRegistry.scala | 5 --- .../execution/datasources/jdbc/JDBCRDD.scala | 33 +++-------------- .../datasources/jdbc/JDBCRelation.scala | 2 -- .../datasources/jdbc/JdbcUtils.scala | 35 +++++++++++++++---- 7 files changed, 34 insertions(+), 50 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f9a831eddc88..b058833616433 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ab362539e2982..9f59c0f3a74db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } // connectionProperties should override settings in extraOptions props.putAll(connectionProperties) - val conn = JdbcUtils.createConnection(url, props) + val conn = JdbcUtils.createConnectionFactory(url, props)() try { var tableExists = JdbcUtils.tableExists(conn, url, table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index f522303be94ad..5ae6cff9b5584 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) val partitionColumn = parameters.getOrElse("partitionColumn", null) val lowerBound = parameters.getOrElse("lowerBound", null) val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) DriverRegistry.register(driver) - if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { sys.error("Partitioning incompletely specified") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7ccd61ed469e9..65af397451c5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -51,10 +51,5 @@ object DriverRegistry extends Logging { } } } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 87d43addd36ce..cb8d9504af997 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} +import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties import scala.util.control.NonFatal @@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par override def index: Int = idx } - private[sql] object JDBCRDD extends Logging { /** @@ -120,7 +119,7 @@ private[sql] object JDBCRDD extends Logging { */ def resolveTable(url: String, table: String, properties: Properties): StructType = { val dialect = JdbcDialects.get(url) - val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() + val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() try { val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") try { @@ -228,36 +227,13 @@ private[sql] object JDBCRDD extends Logging { }) } - /** - * Given a driver string and an url, return a function that loads the - * specified driver string then returns a connection to the JDBC url. - * getConnector is run on the driver code, while the function it returns - * is run on the executor. - * - * @param driver - The class name of the JDBC driver for the given url, or null if the class name - * is not necessary. - * @param url - The JDBC url to connect to. - * - * @return A function that loads the driver and connects to the url. - */ - def getConnector(driver: String, url: String, properties: Properties): () => Connection = { - () => { - try { - if (driver != null) DriverRegistry.register(driver) - } catch { - case e: ClassNotFoundException => - logWarning(s"Couldn't find class $driver", e) - } - DriverManager.getConnection(url, properties) - } - } + /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. - * @param driver - The class name of the JDBC driver for the given url. * @param url - The JDBC url to connect to. * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. * @param requiredColumns - The names of the columns to SELECT. @@ -270,7 +246,6 @@ private[sql] object JDBCRDD extends Logging { def scanTable( sc: SparkContext, schema: StructType, - driver: String, url: String, properties: Properties, fqTable: String, @@ -281,7 +256,7 @@ private[sql] object JDBCRDD extends Logging { val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, - getConnector(driver, url, properties), + JdbcUtils.createConnectionFactory(url, properties), pruneSchema(schema, requiredColumns), fqTable, quotedColumns, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f9300dc2cb529..375266f1be42c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -91,12 +91,10 @@ private[sql] case class JDBCRelation( override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val driver: String = DriverRegistry.getDriverClassName(url) // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, - driver, url, properties, table, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 46f2670eee010..10f650693f288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, PreparedStatement} import java.util.Properties +import scala.collection.JavaConverters._ import scala.util.Try import scala.util.control.NonFatal @@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row} object JdbcUtils extends Logging { /** - * Establishes a JDBC connection. + * Returns a factory for creating connections to the given JDBC URL. + * + * @param url the JDBC url to connect to. + * @param properties JDBC connection properties. */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)() + def createConnectionFactory(url: String, properties: Properties): () => Connection = { + val userSpecifiedDriverClass = Option(properties.getProperty("driver")) + userSpecifiedDriverClass.foreach(DriverRegistry.register) + // Performing this part of the logic on the driver guards against the corner-case where the + // driver returned for a URL is different on the driver and executors due to classpath + // differences. + val driverClass: String = userSpecifiedDriverClass.getOrElse { + DriverManager.getDriver(url).getClass.getCanonicalName + } + () => { + userSpecifiedDriverClass.foreach(DriverRegistry.register) + val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { + case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d + case d if d.getClass.getCanonicalName == driverClass => d + }.getOrElse { + throw new IllegalStateException( + s"Did not find registered driver with class $driverClass") + } + driver.connect(url, properties) + } } /** @@ -242,15 +264,14 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, - properties: Properties = new Properties()) { + properties: Properties) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val getConnection: () => Connection = createConnectionFactory(url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) From b504b6a90a95a723210beb0031ed41a75d702f66 Mon Sep 17 00:00:00 2001 From: Pete Robbins Date: Mon, 4 Jan 2016 10:43:21 -0800 Subject: [PATCH 1418/2089] [SPARK-12470] [SQL] Fix size reduction calculation also only allocate required buffer size Author: Pete Robbins Closes #10421 from robbinspg/master. --- .../expressions/codegen/GenerateUnsafeRowJoiner.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index c9ff357bf3476..037ae83d485d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -61,9 +61,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val outputBitsetWords = (schema1.size + schema2.size + 63) / 64 val bitset1Remainder = schema1.size % 64 - // The number of words we can reduce when we concat two rows together. + // The number of bytes we can reduce when we concat two rows together. // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. - val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords + val sizeReduction = (bitset1Words + bitset2Words - outputBitsetWords) * 8 // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => @@ -171,7 +171,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | // row1: ${schema1.size} fields, $bitset1Words words in bitset | // row2: ${schema2.size}, $bitset2Words words in bitset | // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset - | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes(); + | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes() - $sizeReduction; | if (sizeInBytes > buf.length) { | buf = new byte[sizeInBytes]; | } @@ -188,7 +188,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyVariableLengthRow2 | $updateOffset | - | out.pointTo(buf, sizeInBytes - $sizeReduction); + | out.pointTo(buf, sizeInBytes); | | return out; | } From 43706bf8bdfe08010bb11848788e0718d15363b3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 4 Jan 2016 11:00:15 -0800 Subject: [PATCH 1419/2089] [SPARK-12608][STREAMING] Remove submitJobThreadPool since submitJob doesn't create a separate thread to wait for the job result Before #9264, submitJob would create a separate thread to wait for the job result. `submitJobThreadPool` was a workaround in `ReceiverTracker` to run these waiting-job-result threads. Now #9264 has been merged to master and resolved this blocking issue, `submitJobThreadPool` can be removed now. Author: Shixiong Zhu Closes #10560 from zsxwing/remove-submitJobThreadPool. --- .../apache/spark/streaming/scheduler/ReceiverTracker.scala | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 9ddf176aee84c..678f1dc950ada 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -435,10 +435,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { - // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged - private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool")) - private val walBatchingThreadPool = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) @@ -610,12 +606,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logInfo(s"Restarting Receiver $receiverId") self.send(RestartReceiver(receiver)) } - }(submitJobThreadPool) + }(ThreadUtils.sameThread) logInfo(s"Receiver ${receiver.streamId} started") } override def onStop(): Unit = { - submitJobThreadPool.shutdownNow() active = false walBatchingThreadPool.shutdown() } From 573ac55d7469ea2ea7a5979b4d3eea99c98f6560 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 4 Jan 2016 12:34:04 -0800 Subject: [PATCH 1420/2089] [SPARK-12512][SQL] support column name with dot in withColumn() Author: Xiu Guo Closes #10500 from xguo27/SPARK-12512. --- .../org/apache/spark/sql/DataFrame.scala | 32 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 7 ++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 965eaa9efec41..0763aa4ed99da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1171,13 +1171,17 @@ class DataFrame private[sql]( */ def withColumn(colName: String, col: Column): DataFrame = { val resolver = sqlContext.analyzer.resolver - val replaced = schema.exists(f => resolver(f.name, colName)) - if (replaced) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, colName)) col.as(colName) else Column(name) + val output = queryExecution.analyzed.output + val shouldReplace = output.exists(f => resolver(f.name, colName)) + if (shouldReplace) { + val columns = output.map { field => + if (resolver(field.name, colName)) { + col.as(colName) + } else { + Column(field) + } } - select(colNames : _*) + select(columns : _*) } else { select(Column("*"), col.as(colName)) } @@ -1188,13 +1192,17 @@ class DataFrame private[sql]( */ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { val resolver = sqlContext.analyzer.resolver - val replaced = schema.exists(f => resolver(f.name, colName)) - if (replaced) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + val output = queryExecution.analyzed.output + val shouldReplace = output.exists(f => resolver(f.name, colName)) + if (shouldReplace) { + val columns = output.map { field => + if (resolver(field.name, colName)) { + col.as(colName, metadata) + } else { + Column(field) + } } - select(colNames : _*) + select(columns : _*) } else { select(Column("*"), col.as(colName, metadata)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ad478b0511095..ab02b32f91aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1221,4 +1221,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]") } + + test("SPARK-12512: support `.` in column name for withColumn()") { + val df = Seq("a" -> "b").toDF("col.a", "col.b") + checkAnswer(df.select(df("*")), Row("a", "b")) + checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b")) + checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c")) + } } From 40d03960d79debdff5cef21997417c4f8a8ce2e9 Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 4 Jan 2016 12:38:04 -0800 Subject: [PATCH 1421/2089] [DOC] Adjust coverage for partitionBy() This is the related thread: http://search-hadoop.com/m/q3RTtO3ReeJ1iF02&subj=Re+partitioning+json+data+in+spark Michael suggested fixing the doc. Please review. Author: tedyu Closes #10499 from ted-yu/master. --- .../src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 9f59c0f3a74db..9afa6856907a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -119,7 +119,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * Partitions the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's partitioning scheme. * - * This is only applicable for Parquet at the moment. + * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. * * @since 1.4.0 */ From 0171b71e9511cef512e96a759e407207037f3c49 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 4 Jan 2016 12:41:57 -0800 Subject: [PATCH 1422/2089] [SPARK-12421][SQL] Prevent Internal/External row from exposing state. It is currently possible to change the values of the supposedly immutable ```GenericRow``` and ```GenericInternalRow``` classes. This is caused by the fact that scala's ArrayOps ```toArray``` (returned by calling ```toSeq```) will return the backing array instead of a copy. This PR fixes this problem. This PR was inspired by https://github.com/apache/spark/pull/10374 by apo1. cc apo1 sarutak marmbrus cloud-fan nongli (everyone in the previous conversation). Author: Herman van Hovell Closes #10553 from hvanhovell/SPARK-12421. --- .../spark/sql/catalyst/expressions/rows.scala | 8 ++--- .../scala/org/apache/spark/sql/RowTest.scala | 30 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index cfc68fc00bea8..814b3c22f8806 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -199,9 +199,9 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { override def get(i: Int): Any = values(i) - override def toSeq: Seq[Any] = values.toSeq + override def toSeq: Seq[Any] = values.clone() - override def copy(): Row = this + override def copy(): GenericRow = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -226,11 +226,11 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone() override def numFields: Int = values.length - override def copy(): InternalRow = new GenericInternalRow(values.clone()) + override def copy(): GenericInternalRow = this } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 5c22a72192541..72624e7cbc112 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -104,4 +104,34 @@ class RowTest extends FunSpec with Matchers { internalRow shouldEqual internalRow2 } } + + describe("row immutability") { + val values = Seq(1, 2, "3", "IV", 6L) + val externalRow = Row.fromSeq(values) + val internalRow = InternalRow.fromSeq(values) + + def modifyValues(values: Seq[Any]): Seq[Any] = { + val array = values.toArray + array(2) = "42" + array + } + + it("copy should return same ref for external rows") { + externalRow should be theSameInstanceAs externalRow.copy() + } + + it("copy should return same ref for interal rows") { + internalRow should be theSameInstanceAs internalRow.copy() + } + + it("toSeq should not expose internal state for external rows") { + val modifiedValues = modifyValues(externalRow.toSeq) + externalRow.toSeq should not equal modifiedValues + } + + it("toSeq should not expose internal state for internal rows") { + val modifiedValues = modifyValues(internalRow.toSeq(Seq.empty)) + internalRow.toSeq(Seq.empty) should not equal modifiedValues + } + } } From ba5f81859d6ba37a228a1c43d26c47e64c0382cd Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 4 Jan 2016 13:30:17 -0800 Subject: [PATCH 1423/2089] [SPARK-11259][ML] Params.validateParams() should be called automatically See JIRA: https://issues.apache.org/jira/browse/SPARK-11259 Author: Yanbo Liang Closes #9224 from yanboliang/spark-11259. --- .../scala/org/apache/spark/ml/Pipeline.scala | 2 ++ .../scala/org/apache/spark/ml/Predictor.scala | 1 + .../org/apache/spark/ml/Transformer.scala | 1 + .../apache/spark/ml/clustering/KMeans.scala | 1 + .../org/apache/spark/ml/clustering/LDA.scala | 1 + .../apache/spark/ml/feature/Binarizer.scala | 1 + .../apache/spark/ml/feature/Bucketizer.scala | 1 + .../spark/ml/feature/ChiSqSelector.scala | 2 ++ .../spark/ml/feature/CountVectorizer.scala | 1 + .../apache/spark/ml/feature/HashingTF.scala | 1 + .../org/apache/spark/ml/feature/IDF.scala | 1 + .../spark/ml/feature/MinMaxScaler.scala | 1 + .../spark/ml/feature/OneHotEncoder.scala | 1 + .../org/apache/spark/ml/feature/PCA.scala | 2 ++ .../ml/feature/QuantileDiscretizer.scala | 1 + .../apache/spark/ml/feature/RFormula.scala | 4 ++++ .../spark/ml/feature/SQLTransformer.scala | 1 + .../spark/ml/feature/StandardScaler.scala | 2 ++ .../spark/ml/feature/StopWordsRemover.scala | 1 + .../spark/ml/feature/StringIndexer.scala | 2 ++ .../spark/ml/feature/VectorAssembler.scala | 1 + .../spark/ml/feature/VectorIndexer.scala | 2 ++ .../spark/ml/feature/VectorSlicer.scala | 1 + .../apache/spark/ml/feature/Word2Vec.scala | 1 + .../apache/spark/ml/recommendation/ALS.scala | 2 ++ .../ml/regression/AFTSurvivalRegression.scala | 1 + .../ml/regression/IsotonicRegression.scala | 1 + .../spark/ml/tuning/CrossValidator.scala | 2 ++ .../ml/tuning/TrainValidationSplit.scala | 2 ++ .../org/apache/spark/ml/PipelineSuite.scala | 23 ++++++++++++++++++- 30 files changed, 63 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 3acc60d6c6d65..32570a16e6707 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -165,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M } override def transformSchema(schema: StructType): StructType = { + validateParams() val theStages = $(stages) require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") @@ -296,6 +297,7 @@ class PipelineModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 6aacffd4f236f..d1388b5e2eb5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { + validateParams() // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 1f3325ad09ef1..fdce273193b7c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains($(outputCol))) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 6e5abb29ff0a3..dc6d5d9280970 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index af0b3e1835003..99383e77f7eb0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -263,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 5b17d3483b895..544cf05a30d48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -72,6 +72,7 @@ final class Binarizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 33abc7c99d4b0..0c75317d82703 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index dfec03828f4b7..7b565ef3ed922 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) @@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) val newField = prepOutputField(schema) val outputFields = schema.fields :+ newField diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 1268c87908c62..10dcda2382f48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 61a78d73c4347..8af00581f7e54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -69,6 +69,7 @@ class HashingTF(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f7b0f29a27c2d..9e7eee4f29988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 559a025265916..ad0458d0d0e1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -59,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index c01e29af478c0..342540418fa80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val outputColName = $(outputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index f653798b46043..7020397f3b064 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -130,6 +131,7 @@ class PCAModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 39de8461dc9c7..8fd0ce2f2e26c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields require(inputFields.forall(_.name != $(outputCol)), diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2b578c2a95e16..f9952434d2982 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { + validateParams() if (hasLabelCol(schema)) { StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) } else { @@ -178,6 +179,7 @@ class RFormulaModel private[feature]( } override def transformSchema(schema: StructType): StructType = { + validateParams() checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(withFeatures)) { @@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { } override def transformSchema(schema: StructType): StructType = { + validateParams() StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) } @@ -288,6 +291,7 @@ private class VectorAttributeRewriter( } override def transformSchema(schema: StructType): StructType = { + validateParams() StructType( schema.fields.filter(_.name != vectorCol) ++ schema.fields.filter(_.name == vectorCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index e0ca45b9a6190..af6494b234cee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { + validateParams() val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index d76a9c6275e6b..6a0b6c240ec60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -143,6 +144,7 @@ class StandardScalerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 5d6936dce2c74..b93c9ed382bdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 5c40c35eeaa48..912bd95a2ec70 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], @@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String) final def getLabels: Array[String] = $(labels) override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType.isInstanceOf[NumericType], diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index e9d1b57b91d07..0b215659b3672 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColNames = $(inputCols) val outputColName = $(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index a637a6f2881de..2a5268406ddf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod } override def transformSchema(schema: StructType): StructType = { + validateParams() // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). val dataType = new VectorUDT @@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val dataType = new VectorUDT require(isDefined(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 4813d8a5b5dc0..300d63bd3a0da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) if (schema.fieldNames.contains($(outputCol))) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 59c34cd1703aa..2b6b3c3a0fc58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 14a28b8d5b51f..472c1854d3d1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType @@ -213,6 +214,7 @@ class ALSModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3787ca45d5172..e8a1ff2278a92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params protected def validateAndTransformSchema( schema: StructType, fitting: Boolean): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index e8d361b1a2a8a..1573bb4c1b745 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures protected[ml] def validateAndTransformSchema( schema: StructType, fitting: Boolean): StructType = { + validateParams() if (fitting) { SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) if (hasWeightCol) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 477675cad1a90..3eac616aeaf8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + validateParams() $(estimator).transformSchema(schema) } @@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] ( @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + validateParams() bestModel.transformSchema(schema) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index f346ea655ae5a..4f67e8c21994f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { + validateParams() $(estimator).transformSchema(schema) } @@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { + validateParams() bestModel.transformSchema(schema) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 8c86767456368..f3321fb5a1abc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite -import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } } + + test("pipeline validateParams") { + val df = sqlContext.createDataFrame( + Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "features", "label") + + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("features_scaled") + .setMin(10) + .setMax(0) + val pipeline = new Pipeline().setStages(Array(scaler)) + pipeline.fit(df) + } + } } From 93ef9b6a2aa1830170cb101f191022f2dda62c41 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 4 Jan 2016 13:32:14 -0800 Subject: [PATCH 1424/2089] [SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction DecisionTreeRegressor will provide variance of prediction as a Double column. Author: Yanbo Liang Closes #8866 from yanboliang/spark-9622. --- .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 15 ++++++++ .../ml/regression/DecisionTreeRegressor.scala | 36 +++++++++++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 18 ++++++++++ .../DecisionTreeRegressorSuite.scala | 26 +++++++++++++- 5 files changed, 92 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c7bca1243092c..4aff749ff75af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,6 +44,7 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), + ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index cb2a060a34dd6..c088c16d1b05d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -138,6 +138,21 @@ private[ml] trait HasProbabilityCol extends Params { final def getProbabilityCol: String = $(probabilityCol) } +/** + * Trait for shared param varianceCol. + */ +private[ml] trait HasVarianceCol extends Params { + + /** + * Param for Column name for the biased sample variance of prediction. + * @group param + */ + final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction") + + /** @group getParam */ + final def getVarianceCol: String = $(varianceCol) +} + /** * Trait for shared param threshold (default: 0.5). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 477030d9ea3ee..18c94f36387b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ /** * :: Experimental :: @@ -40,7 +41,7 @@ import org.apache.spark.sql.DataFrame @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams with TreeRegressorParams { + with DecisionTreeRegressorParams { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -73,6 +74,9 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setSeed(value: Long): this.type = super.setSeed(value) + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -113,7 +117,10 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with DecisionTreeRegressorParams with Serializable { + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") @@ -129,6 +136,29 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } + /** We need to update this function if we ever add other impurity measures. */ + protected def predictVariance(features: Vector): Double = { + rootNode.predictImpl(features).impurityStats.calculate() + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + var output = dataset + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + } + output + } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 1da97db9277d8..7443097492d82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** * Parameters for Decision Tree-based algorithms. @@ -256,6 +258,22 @@ private[ml] object TreeRegressorParams { final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) } +private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams + with TreeRegressorParams with HasVarianceCol { + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType) + } else { + newSchema + } + } +} + /** * Parameters for Decision Tree-based ensemble algorithms. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 6999a910c34a4..0b39af5543e93 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -73,6 +74,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex MLTestingUtils.checkCopy(model) } + test("predictVariance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setPredictionCol("") + .setVarianceCol("variance") + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = dt.fit(df) + + val predictions = model.transform(df) + .select(model.getFeaturesCol, model.getVarianceCol) + .collect() + + predictions.foreach { case Row(features: Vector, variance: Double) => + val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() + assert(variance === expectedVariance, + s"Expected variance $expectedVariance but got $variance.") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From d084a2de3271fd8b0d29ee67e1e214e8b9d96d6d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Jan 2016 14:26:56 -0800 Subject: [PATCH 1425/2089] [SPARK-12541] [SQL] support cube/rollup as function This PR enable cube/rollup as function, so they can be used as this: ``` select a, b, sum(c) from t group by rollup(a, b) ``` Author: Davies Liu Closes #10522 from davies/rollup. --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++--- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../catalyst/analysis/FunctionRegistry.scala | 4 ++ .../sql/catalyst/expressions/grouping.scala | 43 +++++++++++++++++++ .../plans/logical/basicOperators.scala | 37 ---------------- .../org/apache/spark/sql/GroupedData.scala | 6 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- 8 files changed, 87 insertions(+), 48 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c396546b4c005..06efcd42aa62e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.types._ /** @@ -208,10 +208,10 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. - case a: Cube => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) - case a: Rollup => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => + GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => + GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a1be1473cc80b..2a2e0d27d9435 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 12c24cc768225..57d1a1107e5f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -285,6 +285,10 @@ object FunctionRegistry { expression[InputFileName]("input_file_name"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + // grouping sets + expression[Cube]("cube"), + expression[Rollup]("rollup"), + // window functions expression[Lead]("lead"), expression[Lag]("lag"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala new file mode 100644 index 0000000000000..2997ee879d479 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ + +/** + * A placeholder expression for cube/rollup, which will be replaced by analyzer + */ +trait GroupingSet extends Expression with CodegenFallback { + + def groupByExprs: Seq[Expression] + override def children: Seq[Expression] = groupByExprs + + // this should be replaced first + override lazy val resolved: Boolean = false + + override def dataType: DataType = throw new UnsupportedOperationException + override def foldable: Boolean = false + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException +} + +case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} + +case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5f34d4a4eb73c..986062e3971c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -397,43 +397,6 @@ case class GroupingSets( this.copy(aggregations = aggs) } -/** - * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, - * and eventually will be transformed to Aggregate(.., Expand) in Analyzer - * - * @param groupByExprs The Group By expressions candidates. - * @param child Child operator - * @param aggregations The Aggregation expressions, those non selected group by expressions - * will be considered as constant null if it appears in the expressions - */ -case class Cube( - groupByExprs: Seq[Expression], - child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { - - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) -} - -/** - * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, - * and eventually will be transformed to Aggregate(.., Expand) in Analyzer - * - * @param groupByExprs The Group By expressions candidates, take effective only if the - * associated bit in the bitmask set to 1. - * @param child Child operator - * @param aggregations The Aggregation expressions, those non selected group by expressions - * will be considered as constant null if it appears in the expressions - */ -case class Rollup( - groupByExprs: Seq[Expression], - child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { - - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) -} - case class Pivot( groupByExprs: Seq[NamedExpression], pivotColumn: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 13341a88a6b74..2aa82f1496ae5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Aggregate} import org.apache.spark.sql.types.NumericType @@ -58,10 +58,10 @@ class GroupedData protected[sql]( df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case GroupedData.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) DataFrame( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bb82b562aaaa2..115b617c211fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2028,4 +2028,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(false) :: Row(true) :: Nil) } + test("rollup") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" + + " order by course, year"), + Row(null, null, 113000.0) :: + Row("Java", null, 50000.0) :: + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("dotNET", null, 63000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b1d841d1b5543..cbfe09b31d380 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1121,12 +1121,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C }), rollupGroupByClause.map(e => e match { case Token("TOK_ROLLUP_GROUPBY", children) => - Rollup(children.map(nodeToExpr), withLateralView, selectExpressions) + Aggregate(Seq(Rollup(children.map(nodeToExpr))), selectExpressions, withLateralView) case _ => sys.error("Expect WITH ROLLUP") }), cubeGroupByClause.map(e => e match { case Token("TOK_CUBE_GROUPBY", children) => - Cube(children.map(nodeToExpr), withLateralView, selectExpressions) + Aggregate(Seq(Cube(children.map(nodeToExpr))), selectExpressions, withLateralView) case _ => sys.error("Expect WITH CUBE") }), Some(Project(selectExpressions, withLateralView))).flatten.head From 34de24abb518e95c4312b77aa107d061ce02c835 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 4 Jan 2016 14:58:24 -0800 Subject: [PATCH 1426/2089] [SPARK-12589][SQL] Fix UnsafeRowParquetRecordReader to properly set the row length. The reader was previously not setting the row length meaning it was wrong if there were variable length columns. This problem does not manifest usually, since the value in the column is correct and projecting the row fixes the issue. Author: Nong Li Closes #10576 from nongli/spark-12589. --- .../sql/catalyst/expressions/UnsafeRow.java | 4 ++++ .../parquet/UnsafeRowParquetRecordReader.java | 9 +++++++ .../datasources/parquet/ParquetIOSuite.scala | 24 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 7492b88c471a4..1a351933a366c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -177,6 +177,10 @@ public void pointTo(byte[] buf, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); } + public void setTotalSize(int sizeInBytes) { + this.sizeInBytes = sizeInBytes; + } + public void setNotNullAt(int i) { assertIndexIsValid(i); BitSetMethods.unset(baseObject, baseOffset, i); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index a6758bddfa7d0..198bfb6d67aee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -256,6 +256,15 @@ private boolean loadBatch() throws IOException { numBatched = num; batchIdx = 0; } + + // Update the total row lengths if the schema contained variable length. We did not maintain + // this as we populated the columns. + if (containsVarLenFields) { + for (int i = 0; i < numBatched; ++i) { + rows[i].setTotalSize(rowWriters[i].holder().totalSize()); + } + } + return true; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0c5d4887ed799..b0581e8b35510 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,6 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -618,6 +619,29 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { readResourceParquetFile("dec-in-fixed-len.parquet"), sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } + + test("SPARK-12589 copy() on rows returned from reader works for strings") { + withTempPath { dir => + val data = (1, "abc") ::(2, "helloabcde") :: Nil + data.toDF().write.parquet(dir.getCanonicalPath) + var hash1: Int = 0 + var hash2: Int = 0 + (false :: true :: Nil).foreach { v => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> v.toString) { + val df = sqlContext.read.parquet(dir.getCanonicalPath) + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) + if (!v) { + hash1 = unsafeRows(0).hashCode() + hash2 = unsafeRows(1).hashCode() + } else { + assert(hash1 == unsafeRows(0).hashCode()) + assert(hash2 == unsafeRows(1).hashCode()) + } + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From fdfac22d08fc4fdc640843dd93a29e2ce4aee2ef Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Mon, 4 Jan 2016 16:14:49 -0800 Subject: [PATCH 1427/2089] [SPARK-12509][SQL] Fixed error messages for DataFrame correlation and covariance Currently, when we call corr or cov on dataframe with invalid input we see these error messages for both corr and cov: - "Currently cov supports calculating the covariance between two columns" - "Covariance calculation for columns with dataType "[DataType Name]" not supported." I've fixed this issue by passing the function name as an argument. We could also do the input checks separately for each function. I avoided doing that because of code duplication. Thanks! Author: Narine Kokhlikyan Closes #10458 from NarineK/sparksqlstatsmessages. --- .../spark/sql/execution/stat/StatFunctions.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 00231d65a7d54..725d6821bf11c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging { /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "correlation") counts.Ck / math.sqrt(counts.MkX * counts.MkY) } @@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging { def cov: Double = Ck / (count - 1) } - private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { - require(cols.length == 2, "Currently cov supports calculating the covariance " + + private def collectStatisticalData(df: DataFrame, cols: Seq[String], + functionName: String): CovarianceCounter = { + require(cols.length == 2, s"Currently $functionName calculation is supported " + "between two columns.") cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") - require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + - s"with dataType ${data.get.dataType} not supported.") + require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + + s"for columns with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( @@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging { * @return the covariance of the two columns. */ private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "covariance") counts.cov } From 77ab49b8575d2ebd678065fa70b0343d532ab9c2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Jan 2016 18:02:38 -0800 Subject: [PATCH 1428/2089] [SPARK-12600][SQL] Remove deprecated methods in Spark SQL Author: Reynold Xin Closes #10559 from rxin/remove-deprecated-sql. --- dev/run-tests.py | 11 +- project/MimaExcludes.scala | 14 +- python/pyspark/__init__.py | 2 +- python/pyspark/sql/__init__.py | 2 +- python/pyspark/sql/column.py | 20 +- python/pyspark/sql/context.py | 111 ------ python/pyspark/sql/dataframe.py | 48 +-- python/pyspark/sql/functions.py | 24 -- python/pyspark/sql/readwriter.py | 20 +- python/pyspark/sql/tests.py | 32 +- .../util/LegacyTypeStringParser.scala | 92 +++++ .../org/apache/spark/sql/types/DataType.scala | 79 ---- .../apache/spark/sql/types/DecimalType.scala | 36 -- .../apache/spark/sql/types/StructType.scala | 11 +- .../scala/org/apache/spark/sql/Column.scala | 12 - .../org/apache/spark/sql/DataFrame.scala | 338 ------------------ .../apache/spark/sql/DataFrameReader.scala | 21 +- .../org/apache/spark/sql/SQLContext.scala | 302 ---------------- .../datasources/parquet/ParquetRelation.scala | 5 +- .../org/apache/spark/sql/functions.scala | 252 ------------- .../scala/org/apache/spark/sql/package.scala | 6 - .../spark/sql/JavaApplySchemaSuite.java | 4 +- .../spark/sql/ColumnExpressionSuite.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 9 - .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- .../hive/execution/HiveResolutionSuite.scala | 2 +- 28 files changed, 174 insertions(+), 1295 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala diff --git a/dev/run-tests.py b/dev/run-tests.py index 8726889cbc777..acc9450586fe3 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -425,12 +425,13 @@ def run_build_tests(): def run_sparkr_tests(): - set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") + # set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") - if which("R"): - run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) - else: - print("Ignoring SparkR tests as R was not found in PATH") + # if which("R"): + # run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) + # else: + # print("Ignoring SparkR tests as R was not found in PATH") + pass def parse_opts(): diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7a6e5cf4ad39a..cf11504b99451 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -43,15 +43,23 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.catalyst"), excludePackage("org.apache.spark.sql.execution"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + // SPARK-12600 Remove SQL deprecated methods + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load") ) ++ Seq( ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ Seq( // SPARK-12481 Remove Hadoop 1.x - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.mapred.SparkHadoopMapRedUtil") + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil") ) case v if v.startsWith("1.6") => Seq( diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 8475dfb1c6ad0..d530723ca9803 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -65,7 +65,7 @@ def deco(f): # for back compatibility -from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row +from pyspark.sql import SQLContext, HiveContext, Row __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 98eaf52866d23..0b06c8339f501 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -47,7 +47,7 @@ from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 81fd4e782628a..900def59d23a5 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -27,8 +27,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.types import * -__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "Column", "DataFrameNaFunctions", "DataFrameStatFunctions"] def _create_column_from_literal(literal): @@ -272,23 +271,6 @@ def substr(self, startPos, length): __getslice__ = substr - @ignore_unicode_prefix - @since(1.3) - def inSet(self, *cols): - """ - A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - - .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. - """ - warnings.warn("inSet is deprecated. Use isin() instead.") - return self.isin(*cols) - @ignore_unicode_prefix @since(1.5) def isin(self, *cols): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index ba6915a12347e..91e27cf16e439 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -274,33 +274,6 @@ def _inferSchema(self, rdd, samplingRatio=None): schema = rdd.map(_infer_schema).reduce(_merge_type) return schema - @ignore_unicode_prefix - def inferSchema(self, rdd, samplingRatio=None): - """ - .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. - """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - return self.createDataFrame(rdd, None, samplingRatio) - - @ignore_unicode_prefix - def applySchema(self, rdd, schema): - """ - .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. - """ - warnings.warn("applySchema is deprecated, please use createDataFrame instead") - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType, but got %s" % type(schema)) - - return self.createDataFrame(rdd, schema) - def _createFromRDD(self, rdd, schema, samplingRatio): """ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. @@ -450,90 +423,6 @@ def dropTempTable(self, tableName): """ self._ssql_ctx.dropTempTable(tableName) - def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. - - >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes - [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] - """ - warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") - gateway = self._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] - jdf = self._ssql_ctx.parquetFile(jpaths) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - - >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes - [('age', 'bigint'), ('name', 'string')] - """ - warnings.warn("jsonFile is deprecated. Use read.json() instead.") - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - @ignore_unicode_prefix - @since(1.0) - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. - - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - - >>> df1 = sqlContext.jsonRDD(json) - >>> df1.first() - Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - - >>> df2 = sqlContext.jsonRDD(json, df1.schema) - >>> df2.first() - Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))])) - ... ]) - >>> df3 = sqlContext.jsonRDD(json, schema) - >>> df3.first() - Row(field2=u'row1', field3=Row(field5=None)) - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def load(self, path=None, source=None, schema=None, **options): - """Returns the dataset in a data source as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. - """ - warnings.warn("load is deprecated. Use read.load() instead.") - return self.read.load(path, source, schema, **options) - @since(1.3) def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad621df91064c..a7bc288e38861 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -36,7 +36,7 @@ from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * -__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] +__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -113,14 +113,6 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - def saveAsParquetFile(self, path): - """Saves the contents as a Parquet file, preserving the schema. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. - """ - warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") - self._jdf.saveAsParquetFile(path) - @since(1.3) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. @@ -135,38 +127,6 @@ def registerTempTable(self, name): """ self._jdf.registerTempTable(name) - def registerAsTable(self, name): - """ - .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. - """ - warnings.warn("Use registerTempTable instead of registerAsTable.") - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this :class:`DataFrame` into the specified table. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. - """ - warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") - self.write.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName, source=None, mode="error", **options): - """Saves the contents of this :class:`DataFrame` to a data source as a table. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. - """ - warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") - self.write.saveAsTable(tableName, source, mode, **options) - - @since(1.3) - def save(self, path=None, source=None, mode="error", **options): - """Saves the contents of the :class:`DataFrame` to a data source. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. - """ - warnings.warn("insertInto is deprecated. Use write.save() instead.") - return self.write.save(path, source, mode, **options) - @property @since(1.4) def write(self): @@ -1388,12 +1348,6 @@ def toPandas(self): drop_duplicates = dropDuplicates -# Having SchemaRDD for backward compatibility (for docs) -class SchemaRDD(DataFrame): - """SchemaRDD is deprecated, please use :class:`DataFrame`. - """ - - def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 25594d79c2141..7c15e38458690 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -149,12 +149,8 @@ def _(): } _window_functions = { - 'rowNumber': - """.. note:: Deprecated in 1.6, use row_number instead.""", 'row_number': """returns a sequential number starting at 1 within a window partition.""", - 'denseRank': - """.. note:: Deprecated in 1.6, use dense_rank instead.""", 'dense_rank': """returns the rank of rows within a window partition, without any gaps. @@ -171,13 +167,9 @@ def _(): place and that the next person came in third. This is equivalent to the RANK function in SQL.""", - 'cumeDist': - """.. note:: Deprecated in 1.6, use cume_dist instead.""", 'cume_dist': """returns the cumulative distribution of values within a window partition, i.e. the fraction of rows that are below the current row.""", - 'percentRank': - """.. note:: Deprecated in 1.6, use percent_rank instead.""", 'percent_rank': """returns the relative rank (i.e. percentile) of rows within a window partition.""", } @@ -318,14 +310,6 @@ def isnull(col): return Column(sc._jvm.functions.isnull(_to_java_column(col))) -@since(1.4) -def monotonicallyIncreasingId(): - """ - .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. - """ - return monotonically_increasing_id() - - @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. @@ -434,14 +418,6 @@ def shiftRightUnsigned(col, numBits): return Column(jc) -@since(1.4) -def sparkPartitionId(): - """ - .. note:: Deprecated in 1.6, use spark_partition_id instead. - """ - return spark_partition_id() - - @since(1.6) def spark_partition_id(): """A column for partition ID of the Spark task. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a2771daabe331..0b20022b14b8d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -130,11 +130,9 @@ def load(self, path=None, format=None, schema=None, **options): self.schema(schema) self.options(**options) if path is not None: - if type(path) == list: - return self._df( - self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) - else: - return self._df(self._jreader.load(path)) + if type(path) != list: + path = [path] + return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load()) @@ -179,7 +177,17 @@ def json(self, path, schema=None): elif type(path) == list: return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): - return self._df(self._jreader.json(path._jrdd)) + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = path.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString()) + return self._df(self._jreader.json(jrdd)) else: raise TypeError("path can be only string or RDD") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9ada96601a1cd..e396cf41f2f7b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -326,7 +326,7 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) + df = self.sqlCtx.read.json(rdd) df.count() df.collect() df.schema @@ -345,7 +345,7 @@ def test_basic_functions(self): df.collect() def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""])) df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) self.assertEqual(df.collect(), df2.collect()) @@ -408,12 +408,12 @@ def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]) - df = self.sqlCtx.inferSchema(nestedRdd1) + df = self.sqlCtx.createDataFrame(nestedRdd1) self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.sqlCtx.inferSchema(nestedRdd2) + df = self.sqlCtx.createDataFrame(nestedRdd2) self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) from collections import namedtuple @@ -421,7 +421,7 @@ def test_infer_nested_schema(self): rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): @@ -581,14 +581,14 @@ def test_parquet_with_udt(self): df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.write.parquet(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') - df1 = self.sqlCtx.parquetFile(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) @@ -763,7 +763,7 @@ def test_save_and_load(self): defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.load(path=tmpPath) + actual = self.sqlCtx.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) @@ -796,7 +796,7 @@ def test_save_and_load_builder(self): defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.load(path=tmpPath) + actual = self.sqlCtx.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) @@ -805,7 +805,7 @@ def test_save_and_load_builder(self): def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) + df = self.sqlCtx.read.json(rdd) # render_doc() reproduces the help() exception without printing output pydoc.render_doc(df) pydoc.render_doc(df.foo) @@ -853,8 +853,8 @@ def test_infer_long_type(self): # this saving as Parquet caused issues as well. output_dir = os.path.join(self.tempdir.name, "infer_long_type") - df.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) + df.write.parquet(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) self.assertEqual('a', df1.first().f1) self.assertEqual(100000000000000, df1.first().f2) @@ -1205,9 +1205,9 @@ def test_window_functions(self): F.max("key").over(w.rowsBetween(0, 1)), F.min("key").over(w.rowsBetween(0, 1)), F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.rowNumber().over(w), + F.row_number().over(w), F.rank().over(w), - F.denseRank().over(w), + F.dense_rank().over(w), F.ntile(2).over(w)) rs = sorted(sel.collect()) expected = [ @@ -1227,9 +1227,9 @@ def test_window_functions_without_partitionBy(self): F.max("key").over(w.rowsBetween(0, 1)), F.min("key").over(w.rowsBetween(0, 1)), F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.rowNumber().over(w), + F.row_number().over(w), F.rank().over(w), - F.denseRank().over(w), + F.dense_rank().over(w), F.ntile(2).over(w)) rs = sorted(sel.collect()) expected = [ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala new file mode 100644 index 0000000000000..e27cf9c1989f3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.sql.types._ + +/** + * Parser that turns case class strings into datatypes. This is only here to maintain compatibility + * with Parquet files written by Spark 1.1 and below. + */ +object LegacyTypeStringParser extends RegexParsers { + + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + */ + def parse(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f8d71c5f02372..301b3a70f68f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.types -import scala.util.Try -import scala.util.parsing.combinator.RegexParsers - import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s._ @@ -94,18 +91,9 @@ abstract class DataType extends AbstractDataType { object DataType { - private[sql] def fromString(raw: String): DataType = { - Try(DataType.fromJson(raw)).getOrElse(DataType.fromCaseClassString(raw)) - } def fromJson(json: String): DataType = parseDataType(parse(json)) - /** - * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` - */ - @deprecated("Use DataType.fromJson instead", "1.2.0") - def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) - private val nonDecimalNameToType = { Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) @@ -184,73 +172,6 @@ object DataType { StructField(name, parseDataType(dataType), nullable) } - private object CaseClassStringParser extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.USER_DEFAULT - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) - } - - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => - StructField(name, tpe, nullable = nullable) - } - - protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) - - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) - - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => - throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") - } - } - protected[types] def buildFormattedString( dataType: DataType, prefix: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ce45245b9f6dd..fdae2e82a01b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -20,25 +20,10 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression -/** Precision parameters for a Decimal */ -@deprecated("Use DecimalType(precision, scale) directly", "1.5") -case class PrecisionInfo(precision: Int, scale: Int) { - if (scale > precision) { - throw new AnalysisException( - s"Decimal scale ($scale) cannot be greater than precision ($precision).") - } - if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException( - s"DecimalType can only support precision up to 38" - ) - } -} - /** * :: DeveloperApi :: * The data type representing `java.math.BigDecimal` values. @@ -58,15 +43,6 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { def this(precision: Int) = this(precision, 0) def this() = this(10) - @deprecated("Use DecimalType(precision, scale) instead", "1.5") - def this(precisionInfo: Option[PrecisionInfo]) { - this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, - precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) - } - - @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5") - val precisionInfo = Some(PrecisionInfo(precision, scale)) - private[sql] type InternalType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = Decimal.DecimalIsFractional @@ -122,9 +98,6 @@ object DecimalType extends AbstractDataType { val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) - @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") - val Unlimited: DecimalType = SYSTEM_DEFAULT - // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) @@ -142,15 +115,6 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } - @deprecated("please specify precision and scale", "1.5") - def apply(): DecimalType = USER_DEFAULT - - @deprecated("Use DecimalType(precision, scale) instead", "1.5") - def apply(precisionInfo: Option[PrecisionInfo]) { - this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, - precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) - } - private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d56802276558a..34382bf124eb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.types import scala.collection.mutable.ArrayBuffer +import scala.util.Try import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.catalyst.util.{LegacyTypeStringParser, DataTypeParser} /** @@ -337,9 +338,11 @@ object StructType extends AbstractDataType { override private[sql] def simpleString: String = "struct" - private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { - case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + private[sql] def fromString(raw: String): StructType = { + Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { + case t: StructType => t + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + } } def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 5026c0d6d12b5..71fa970907f6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -708,18 +708,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def mod(other: Any): Column = this % other - /** - * A boolean expression that is evaluated to true if the value of this expression is contained - * by the evaluated values of the arguments. - * - * @group expr_ops - * @since 1.3.0 - * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. - */ - @deprecated("use isin. This will be removed in Spark 2.0.", "1.5.0") - @scala.annotation.varargs - def in(list: Any*): Column = isin(list : _*) - /** * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 0763aa4ed99da..c42192c83de89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1750,344 +1750,6 @@ class DataFrame private[sql]( } } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. - */ - @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") - def toSchemaRDD: DataFrame = this - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write - w.jdbc(url, table, new Properties) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) - w.jdbc(url, table, new Properties) - } - - /** - * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. - * Files that are written out using this method can be read back in as a [[DataFrame]] - * using the `parquetFile` function in [[SQLContext]]. - * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") - def saveAsParquetFile(path: String): Unit = { - write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - } - - /** - * Creates a table from the the contents of this DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * This will fail if the table already exists. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable(tableName: String): Unit = { - write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) - } - - /** - * Creates a table from the the contents of this DataFrame, using the default data source - * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(mode).saveAsTable(tableName). This will be removed in Spark 2.0.", - "1.4.0") - def saveAsTable(tableName: String, mode: SaveMode): Unit = { - write.mode(mode).saveAsTable(tableName) - } - - /** - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source and a set of options, - * using [[SaveMode.ErrorIfExists]] as the save mode. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", - "1.4.0") - def saveAsTable(tableName: String, source: String): Unit = { - write.format(source).saveAsTable(tableName) - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { - write.format(source).mode(mode).saveAsTable(tableName) - } - - /** - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: java.util.Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).saveAsTable(tableName) - } - - /** - * (Scala-specific) - * Creates a table from the the contents of this DataFrame based on a given data source, - * [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).saveAsTable(tableName) - } - - /** - * Saves the contents of this DataFrame to the given path, - * using the default data source configured by spark.sql.sources.default and - * [[SaveMode.ErrorIfExists]] as the save mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") - def save(path: String): Unit = { - write.save(path) - } - - /** - * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, - * using the default data source configured by spark.sql.sources.default. - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") - def save(path: String, mode: SaveMode): Unit = { - write.mode(mode).save(path) - } - - /** - * Saves the contents of this DataFrame to the given path based on the given data source, - * using [[SaveMode.ErrorIfExists]] as the save mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") - def save(path: String, source: String): Unit = { - write.format(source).save(path) - } - - /** - * Saves the contents of this DataFrame to the given path based on the given data source and - * [[SaveMode]] specified by mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).save(path). " + - "This will be removed in Spark 2.0.", "1.4.0") - def save(path: String, source: String, mode: SaveMode): Unit = { - write.format(source).mode(mode).save(path) - } - - /** - * Saves the contents of this DataFrame based on the given data source, - * [[SaveMode]] specified by mode, and a set of options. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).save(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def save( - source: String, - mode: SaveMode, - options: java.util.Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).save() - } - - /** - * (Scala-specific) - * Saves the contents of this DataFrame based on the given data source, - * [[SaveMode]] specified by mode, and a set of options - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).save(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def save( - source: String, - mode: SaveMode, - options: Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).save() - } - - /** - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def insertInto(tableName: String, overwrite: Boolean): Unit = { - write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) - } - - /** - * Adds the rows from this RDD to the specified table. - * Throws an exception if the table already exists. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().mode(SaveMode.Append).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def insertInto(tableName: String): Unit = { - write.mode(SaveMode.Append).insertInto(tableName) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6debb302d9ecf..d4df913e472a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -98,17 +98,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { this } - /** - * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by - * a local or distributed file system). - * - * @since 1.4.0 - */ - // TODO: Remove this one in Spark 2.0. - def load(path: String): DataFrame = { - option("path", path).load() - } - /** * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external * key-value stores). @@ -125,6 +114,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { DataFrame(sqlContext, LogicalRelation(resolved.relation)) } + /** + * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by + * a local or distributed file system). + * + * @since 1.4.0 + */ + def load(path: String): DataFrame = { + option("path", path).load() + } + /** * Loads input in as a [[DataFrame]], for data sources that support multiple paths. * Only works if the source is a HadoopFsRelationProvider. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 022303239f2af..3a875c4f9a284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -888,9 +888,6 @@ class SQLContext private[sql]( }.toArray } - @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") - protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) - @transient protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) @@ -908,10 +905,6 @@ class SQLContext private[sql]( ) } - @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") - protected[sql] class QueryExecution(logical: LogicalPlan) - extends sparkexecution.QueryExecution(this, logical) - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. @@ -952,301 +945,6 @@ class SQLContext private[sql]( } } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else { - read.parquet(paths : _*) - } - } - - /** - * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonFile(path: String): DataFrame = { - read.json(path) - } - - /** - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonFile(path: String, schema: StructType): DataFrame = { - read.schema(schema).json(path) - } - - /** - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonFile(path: String, samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(path) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: RDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an JavaRDD storing JSON objects (one object per record) and applies the given - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") - def load(path: String): DataFrame = { - read.load(path) - } - - /** - * Returns the dataset stored at path as a DataFrame, using the given data source. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") - def load(path: String, source: String): DataFrame = { - read.format(source).load(path) - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use read.format(source).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - */ - @deprecated("Use read.format(source).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, options: Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = - { - read.format(source).schema(schema).options(options).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def jdbc(url: String, table: String): DataFrame = { - read.jdbc(url, table, new Properties) - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - read.jdbc(url, table, theParts, new Properties) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - // Register a succesfully instantiatd context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index af964b4d35611..8e1fe8090cc1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -44,6 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} @@ -638,7 +639,7 @@ private[sql] object ParquetRelation extends Logging { logInfo( s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(serializedSchema.get) + LegacyTypeStringParser.parse(serializedSchema.get) } .recover { case cause: Throwable => logWarning( @@ -821,7 +822,7 @@ private[sql] object ParquetRelation extends Logging { logInfo( s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + LegacyTypeStringParser.parse(schemaString).asInstanceOf[StructType] }.recoverWith { case cause: Throwable => logWarning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3572f3c3a1f2c..2b3db398aaecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -558,13 +558,6 @@ object functions extends LegacyFunctions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. - */ - @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") - def cumeDist(): Column = cume_dist() - /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -579,13 +572,6 @@ object functions extends LegacyFunctions { */ def cume_dist(): Column = withExpr { new CumeDist } - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. - */ - @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") - def denseRank(): Column = dense_rank() - /** * Window function: returns the rank of rows within a window partition, without any gaps. * @@ -715,13 +701,6 @@ object functions extends LegacyFunctions { */ def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. - */ - @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") - def percentRank(): Column = percent_rank() - /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. * @@ -752,13 +731,6 @@ object functions extends LegacyFunctions { */ def rank(): Column = withExpr { new Rank } - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. - */ - @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") - def rowNumber(): Column = row_number() - /** * Window function: returns a sequential number starting at 1 within a window partition. * @@ -827,13 +799,6 @@ object functions extends LegacyFunctions { @scala.annotation.varargs def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } - /** - * @group normal_funcs - * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. - */ - @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") - def inputFileName(): Column = input_file_name() - /** * Creates a string column for the file name of the current Spark task. * @@ -842,13 +807,6 @@ object functions extends LegacyFunctions { */ def input_file_name(): Column = withExpr { InputFileName() } - /** - * @group normal_funcs - * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. - */ - @deprecated("Use isnan. This will be removed in Spark 2.0.", "1.6.0") - def isNaN(e: Column): Column = isnan(e) - /** * Return true iff the column is NaN. * @@ -972,14 +930,6 @@ object functions extends LegacyFunctions { */ def randn(): Column = randn(Utils.random.nextLong) - /** - * @group normal_funcs - * @since 1.4.0 - * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. - */ - @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") - def sparkPartitionId(): Column = spark_partition_id() - /** * Partition ID of the Spark task. * @@ -2534,24 +2484,6 @@ object functions extends LegacyFunctions { }""") } - (0 to 10).map { x => - val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") - val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") - println(s""" - /** - * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr { - ScalaUDF(f, returnType, Option(Seq($argsInUDF))) - }""") - } */ /** * Defines a user-defined function of 0 arguments as user-defined function (UDF). @@ -2685,161 +2617,6 @@ object functions extends LegacyFunctions { UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } - ////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { - ScalaUDF(f, returnType, Seq()) - } - - /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr)) - } - - /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) - } - - /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) - } - - /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) - } - - /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) - } - - /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) - } - - /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) - } - - /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) - } - - /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf(). - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) - } - - /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf(). - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) - } - // scalastyle:on parameter.number // scalastyle:on line.size.limit @@ -2877,33 +2654,4 @@ object functions extends LegacyFunctions { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } - /** - * Call an user-defined function. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUDF", $"value")) - * }}} - * - * @group udf_funcs - * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. - * This will be removed in Spark 2.0. - */ - @deprecated("Use callUDF. This will be removed in Spark 2.0.", "1.5.0") - def callUdf(udfName: String, cols: Column*): Column = withExpr { - // Note: we avoid using closures here because on file systems that are case-insensitive, the - // compiled class file for the closure here will conflict with the one in callUDF (upper case). - val exprs = new Array[Expression](cols.size) - var i = 0 - while (i < cols.size) { - exprs(i) = cols(i).expr - i += 1 - } - UnresolvedFunction(udfName, exprs, isDistinct = false) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index a9c600b139b18..bd73a36fd40b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -42,10 +42,4 @@ package object sql { @DeveloperApi type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] - /** - * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. - * @deprecated As of 1.3.0, replaced by `DataFrame`. - */ - @deprecated("use DataFrame", "1.3.0") - type SchemaRDD = DataFrame } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 7b50aad4ad498..640efcc737eaa 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -107,7 +107,7 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.applySchema(rowRDD, schema); + DataFrame df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); @@ -143,7 +143,7 @@ public Row call(Person person) { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.applySchema(rowRDD, schema); + DataFrame df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 38c0eb589f965..53a9788024ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -298,7 +298,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - testData.select(isNaN($"a"), isNaN($"b")), + testData.select(isnan($"a"), isnan($"b")), Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( @@ -586,7 +586,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( - df.select(sparkPartitionId()), + df.select(spark_partition_id()), Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil ) } @@ -595,11 +595,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(inputFileName()).limit(1), Row("")) + checkAnswer(data.select(input_file_name()).limit(1), Row("")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ab02b32f91aff..e8fa663363731 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -341,15 +341,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("deprecated callUdf in SQLContext") { - val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - val sqlctx = df.sqlContext - sqlctx.udf.register("simpleUdf", (v: Int) => v * v) - checkAnswer( - df.select($"id", callUdf("simpleUdf", $"value")), - Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) - } - test("callUDF in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5d00e7367026f..86769f1a0d412 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} import org.apache.spark.sql.execution.datasources.{ResolveDataSource, DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.types._ @@ -567,7 +567,7 @@ class HiveContext private[hive]( } @transient - private val hivePlanner = new SparkPlanner with HiveStrategies { + private val hivePlanner = new SparkPlanner(this) with HiveStrategies { val hiveContext = self override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d38ad9127327d..e8376083c0d39 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.hive.execution._ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. - self: SQLContext#SparkPlanner => + self: SparkPlanner => val hiveContext: HiveContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 71f05f3b00291..bff6811bf4164 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -363,7 +363,7 @@ object SPARK_11009 extends QueryTest { val df = sqlContext.range(1 << 20) val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) - val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0") + val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0") if (df3.rdd.count() != 0) { throw new Exception("df3 should have 0 output row.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b08db6de2d2f6..dd13b8392880a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, sql} import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) From b1a771231e20df157fb3e780287390a883c0cc6f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Jan 2016 18:49:41 -0800 Subject: [PATCH 1429/2089] [SPARK-12480][SQL] add Hash expression that can calculate hash value for a group of expressions just write the arguments into unsafe row and use murmur3 to calculate hash code Author: Wenchen Fan Closes #10435 from cloud-fan/hash-expr. --- .../sql/catalyst/expressions/UnsafeRow.java | 4 + .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../spark/sql/catalyst/expressions/misc.scala | 44 +++++++++++ .../catalyst/encoders/RowEncoderSuite.scala | 2 +- .../expressions/MiscFunctionsSuite.scala | 73 ++++++++++++++++++- .../org/apache/spark/sql/functions.scala | 11 +++ .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +++ .../execution/HiveCompatibilitySuite.scala | 3 + .../apache/spark/sql/hive/test/TestHive.scala | 24 ++++++ .../sql/hive/execution/HiveQuerySuite.scala | 3 - 10 files changed, 171 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1a351933a366c..b8d3c49100476 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -566,6 +566,10 @@ public int hashCode() { return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); } + public int hashCode(int seed) { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed); + } + @Override public boolean equals(Object other) { if (other instanceof UnsafeRow) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 57d1a1107e5f6..5c2aa3c06b3e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -49,7 +49,7 @@ trait FunctionRegistry { class SimpleFunctionRegistry extends FunctionRegistry { - private val functionBuilders = + private[sql] val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) override def registerFunction( @@ -278,6 +278,7 @@ object FunctionRegistry { // misc functions expression[Crc32]("crc32"), expression[Md5]("md5"), + expression[Murmur3Hash]("hash"), expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index d0ec99b2320df..8834924687c0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -22,6 +22,8 @@ import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -177,3 +179,45 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + +/** + * A function that calculates hash value for a group of expressions. Note that the `seed` argument + * is not exposed to users and should only be set inside spark SQL. + * + * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of + * the unsafe row using murmur3 hasher with a seed. + * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle + * and bucketing have same data distribution. + */ +case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression { + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be empty") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private lazy val unsafeProjection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + unsafeProjection(input).hashCode(seed) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children) + ev.isNull = "false" + s""" + ${unsafeRow.code} + final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); + """ + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 8f4faab7bace5..b17f8d5ec70af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -99,7 +99,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("binary", BinaryType) .add("date", DateType) .add("timestamp", TimestampType) - .add("udt", new ExamplePointUDT, false)) + .add("udt", new ExamplePointUDT)) encodeDecodeTest( new StructType() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75d17417e5a02..9175568f43a4e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} +import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -59,4 +61,73 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + + private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) + private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) + private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) + + testMurmur3Hash( + new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT)) + + testMurmur3Hash( + new StructType() + .add("arrayOfNull", arrayOfNull) + .add("arrayOfString", arrayOfString) + .add("arrayOfArrayOfString", ArrayType(arrayOfString)) + .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) + .add("arrayOfMap", ArrayType(mapOfString)) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) + + testMurmur3Hash( + new StructType() + .add("mapOfIntAndString", MapType(IntegerType, StringType)) + .add("mapOfStringAndArray", MapType(StringType, arrayOfString)) + .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) + .add("mapOfArray", MapType(arrayOfString, arrayOfString)) + .add("mapOfStringAndStruct", MapType(StringType, structOfString)) + .add("mapOfStructAndString", MapType(structOfString, StringType)) + .add("mapOfStruct", MapType(structOfString, structOfString))) + + testMurmur3Hash( + new StructType() + .add("structOfString", structOfString) + .add("structOfStructOfString", new StructType().add("struct", structOfString)) + .add("structOfArray", new StructType().add("array", arrayOfString)) + .add("structOfMap", new StructType().add("map", mapOfString)) + .add("structOfArrayAndMap", + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + private def testMurmur3Hash(inputSchema: StructType): Unit = { + val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get + val encoder = RowEncoder(inputSchema) + val seed = scala.util.Random.nextInt() + test(s"murmur3 hash: ${inputSchema.simpleString}") { + for (_ <- 1 to 10) { + val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] + val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { + case (value, dt) => Literal.create(value, dt) + } + checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed)) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2b3db398aaecd..e223e32fd702e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1813,6 +1813,17 @@ object functions extends LegacyFunctions { */ def crc32(e: Column): Column = withExpr { Crc32(e.expr) } + /** + * Calculates the hash code of given columns, and returns the result as a int column. + * + * @group misc_funcs + * @since 2.0 + */ + @scala.annotation.varargs + def hash(col: Column, cols: Column*): Column = withExpr { + new Murmur3Hash((col +: cols).map(_.expr)) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 115b617c211fc..72845711adddd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2057,4 +2057,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("hash function") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + withTempTable("tbl") { + df.registerTempTable("tbl") + checkAnswer( + df.select(hash($"i", $"j")), + sql("SELECT hash(i, j) from tbl") + ) + } + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2b0e48dbfcf28..bd1a52e5f3303 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -53,6 +53,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + // Use Hive hash expression instead of the native one + TestHive.functionRegistry.unregisterFunction("hash") RuleExecutor.resetTime() } @@ -62,6 +64,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + TestHive.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 013fbab0a812b..66d5f20d88421 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -31,10 +31,13 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.sql.{SQLContext, SQLConf} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} @@ -451,6 +454,27 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logError("FATAL ERROR: Failed to reset TestDB state.", e) } } + + @transient + override protected[sql] lazy val functionRegistry = new TestHiveFunctionRegistry( + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) +} + +private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper) + extends HiveFunctionRegistry(fr, client) { + + private val removedFunctions = + collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))] + + def unregisterFunction(name: String): Unit = { + fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) + } + + def restore(): Unit = { + removedFunctions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + } } private[hive] object TestHiveContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 8a5acaf3e10bc..acd1130f2762c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -387,9 +387,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("partitioned table scan", "SELECT ds, hr, key, value FROM srcpart") - createQueryTest("hash", - "SELECT hash('test') FROM src LIMIT 1") - createQueryTest("create table as", """ |CREATE TABLE createdtable AS SELECT * FROM src; From 8896ec9f02a6747917f3ae42a517ff0e3742eaf6 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 5 Jan 2016 08:39:58 +0530 Subject: [PATCH 1430/2089] [SPARKR][DOC] minor doc update for version in migration guide checked that the change is in Spark 1.6.0. shivaram Author: felixcheung Closes #10574 from felixcheung/rwritemodedoc. --- docs/sparkr.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index 9ddd2eda3fe8b..ea81532c611e2 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -385,12 +385,12 @@ The following functions are masked by the SparkR package:
    driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded - on the master and workers before running an JDBC commands to allow the driver to - register itself with the JDBC subsystem. + The class name of the JDBC driver to use to connect to this URL.
    Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. - + You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) # Migration Guide -## Upgrading From SparkR 1.6 to 1.7 +## Upgrading From SparkR 1.5.x to 1.6 - - Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API. + - Before Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API. From b634901bb28070ac5d9a24a9bc7b7640472a54e2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Jan 2016 21:05:27 -0800 Subject: [PATCH 1431/2089] [SPARK-12600][SQL] follow up: add range check for DecimalType This addresses davies' code review feedback in https://github.com/apache/spark/pull/10559 Author: Reynold Xin Closes #10586 from rxin/remove-deprecated-sql-followup. --- .../scala/org/apache/spark/sql/types/DecimalType.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index fdae2e82a01b3..cf5322125bd72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression @@ -39,6 +40,15 @@ import org.apache.spark.sql.catalyst.expressions.Expression @DeveloperApi case class DecimalType(precision: Int, scale: Int) extends FractionalType { + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } + + if (precision > DecimalType.MAX_PRECISION) { + throw new AnalysisException(s"DecimalType can only support precision up to 38") + } + // default constructor for Java def this(precision: Int) = this(precision, 0) def this() = this(10) From cc4d5229c98a589da76a4d5e5fdc5ea92385183b Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 4 Jan 2016 22:32:07 -0800 Subject: [PATCH 1432/2089] [SPARK-12625][SPARKR][SQL] replace R usage of Spark SQL deprecated API rxin davies shivaram Took save mode from my PR #10480, and move everything to writer methods. This is related to PR #10559 - [x] it seems jsonRDD() is broken, need to investigate - this is not a public API though; will look into some more tonight. (fixed) Author: felixcheung Closes #10584 from felixcheung/rremovedeprecated. --- R/pkg/R/DataFrame.R | 33 +++++++++++------------ R/pkg/R/SQLContext.R | 10 +++---- R/pkg/R/column.R | 2 +- R/pkg/R/utils.R | 9 +++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 +-- dev/run-tests.py | 11 ++++---- 6 files changed, 38 insertions(+), 31 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0cfa12b997d69..c126f9efb475a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -458,7 +458,10 @@ setMethod("registerTempTable", setMethod("insertInto", signature(x = "DataFrame", tableName = "character"), function(x, tableName, overwrite = FALSE) { - callJMethod(x@sdf, "insertInto", tableName, overwrite) + jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) + write <- callJMethod(x@sdf, "write") + write <- callJMethod(write, "mode", jmode) + callJMethod(write, "insertInto", tableName) }) #' Cache @@ -1948,18 +1951,15 @@ setMethod("write.df", source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - allModes <- c("append", "overwrite", "error", "ignore") - # nolint start - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') - } - # nolint end - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path } - callJMethod(df@sdf, "save", source, jmode, options) + write <- callJMethod(df@sdf, "write") + write <- callJMethod(write, "format", source) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "save", path) }) #' @rdname write.df @@ -2013,15 +2013,14 @@ setMethod("saveAsTable", source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - allModes <- c("append", "overwrite", "error", "ignore") - # nolint start - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') - } - # nolint end - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) - callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + + write <- callJMethod(df@sdf, "write") + write <- callJMethod(write, "format", source) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) + callJMethod(write, "saveAsTable", tableName) }) #' summary diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9243d70e66f75..ccc683d86a3e5 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -256,9 +256,12 @@ jsonFile <- function(sqlContext, path) { # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { + .Deprecated("read.json") rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + read <- callJMethod(sqlContext, "read") + # samplingRatio is deprecated + sdf <- callJMethod(read, "json", callJMethod(getJRDD(rdd), "rdd")) dataFrame(sdf) } else { stop("not implemented") @@ -289,10 +292,7 @@ read.parquet <- function(sqlContext, path) { # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) { .Deprecated("read.parquet") - # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) - sdf <- callJMethod(sqlContext, "parquetFile", paths) - dataFrame(sdf) + read.parquet(sqlContext, unlist(list(...))) } #' SQL Query diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 356bcee3cf5c6..3ffd9a9890b2e 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -209,7 +209,7 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - jc <- callJMethod(x@jc, "in", as.list(table)) + jc <- callJMethod(x@jc, "isin", as.list(table)) return(column(jc)) }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 43105aaa38424..aa386e5da933b 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -641,3 +641,12 @@ assignNewEnv <- function(data) { splitString <- function(input) { Filter(nzchar, unlist(strsplit(input, ",|\\s"))) } + +convertToJSaveMode <- function(mode) { + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') # nolint + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode +} diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9e5d0ebf60720..ebe8faa34cf7d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -423,12 +423,12 @@ test_that("read/write json files", { test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) expect_equal(count(rdd), 3) - df <- jsonRDD(sqlContext, rdd) + df <- suppressWarnings(jsonRDD(sqlContext, rdd)) expect_is(df, "DataFrame") expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlContext, rdd2) + df <- suppressWarnings(jsonRDD(sqlContext, rdd2)) expect_is(df, "DataFrame") expect_equal(count(df), 6) }) diff --git a/dev/run-tests.py b/dev/run-tests.py index acc9450586fe3..8726889cbc777 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -425,13 +425,12 @@ def run_build_tests(): def run_sparkr_tests(): - # set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") + set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") - # if which("R"): - # run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) - # else: - # print("Ignoring SparkR tests as R was not found in PATH") - pass + if which("R"): + run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) + else: + print("Ignoring SparkR tests as R was not found in PATH") def parse_opts(): From 7058dc115047258197f6c09eee404f1ccf41038d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 4 Jan 2016 22:42:54 -0800 Subject: [PATCH 1433/2089] [SPARK-3873][EXAMPLES] Import ordering fixes. Author: Marcelo Vanzin Closes #10575 from vanzin/SPARK-3873-examples. --- .../org/apache/spark/examples/CassandraCQLTest.scala | 3 +-- .../org/apache/spark/examples/CassandraTest.scala | 2 +- .../org/apache/spark/examples/DFSReadWriteTest.scala | 2 +- .../scala/org/apache/spark/examples/HBaseTest.scala | 3 +-- .../org/apache/spark/examples/LocalFileLR.scala | 2 +- .../org/apache/spark/examples/LocalKMeans.scala | 2 +- .../scala/org/apache/spark/examples/LocalLR.scala | 2 +- .../apache/spark/examples/MultiBroadcastTest.scala | 2 +- .../org/apache/spark/examples/SparkHdfsLR.scala | 2 +- .../org/apache/spark/examples/SparkKMeans.scala | 2 +- .../scala/org/apache/spark/examples/SparkLR.scala | 2 +- .../org/apache/spark/examples/SparkPageRank.scala | 2 +- .../scala/org/apache/spark/examples/SparkTC.scala | 2 +- .../apache/spark/examples/SparkTachyonHdfsLR.scala | 3 +-- .../org/apache/spark/examples/graphx/Analytics.scala | 5 +++-- .../spark/examples/graphx/SynthBenchmark.scala | 5 +++-- .../examples/ml/AFTSurvivalRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/BinarizerExample.scala | 2 +- .../apache/spark/examples/ml/BucketizerExample.scala | 2 +- .../spark/examples/ml/ChiSqSelectorExample.scala | 2 +- .../spark/examples/ml/CountVectorizerExample.scala | 3 +-- .../spark/examples/ml/CrossValidatorExample.scala | 2 +- .../org/apache/spark/examples/ml/DCTExample.scala | 2 +- .../ml/DecisionTreeClassificationExample.scala | 8 ++++---- .../spark/examples/ml/DecisionTreeExample.scala | 7 +++---- .../examples/ml/DecisionTreeRegressionExample.scala | 12 +++++++----- .../examples/ml/ElementwiseProductExample.scala | 2 +- .../ml/GradientBoostedTreeClassifierExample.scala | 2 +- .../ml/GradientBoostedTreeRegressorExample.scala | 2 +- .../spark/examples/ml/IndexToStringExample.scala | 4 ++-- .../org/apache/spark/examples/ml/LDAExample.scala | 4 ++-- .../ml/LinearRegressionWithElasticNetExample.scala | 2 +- .../ml/LogisticRegressionSummaryExample.scala | 2 +- .../ml/LogisticRegressionWithElasticNetExample.scala | 2 +- .../spark/examples/ml/MinMaxScalerExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 4 ++-- .../org/apache/spark/examples/ml/NGramExample.scala | 2 +- .../apache/spark/examples/ml/NormalizerExample.scala | 2 +- .../spark/examples/ml/OneHotEncoderExample.scala | 2 +- .../apache/spark/examples/ml/OneVsRestExample.scala | 4 ++-- .../org/apache/spark/examples/ml/PCAExample.scala | 2 +- .../examples/ml/PolynomialExpansionExample.scala | 2 +- .../examples/ml/QuantileDiscretizerExample.scala | 2 +- .../apache/spark/examples/ml/RFormulaExample.scala | 2 +- .../examples/ml/RandomForestClassifierExample.scala | 2 +- .../examples/ml/RandomForestRegressorExample.scala | 2 +- .../spark/examples/ml/SQLTransformerExample.scala | 3 +-- .../spark/examples/ml/StandardScalerExample.scala | 2 +- .../spark/examples/ml/StopWordsRemoverExample.scala | 2 +- .../spark/examples/ml/StringIndexerExample.scala | 2 +- .../org/apache/spark/examples/ml/TfIdfExample.scala | 2 +- .../apache/spark/examples/ml/TokenizerExample.scala | 2 +- .../examples/ml/TrainValidationSplitExample.scala | 2 +- .../spark/examples/ml/VectorAssemblerExample.scala | 2 +- .../spark/examples/ml/VectorIndexerExample.scala | 2 +- .../spark/examples/ml/VectorSlicerExample.scala | 2 +- .../apache/spark/examples/ml/Word2VecExample.scala | 2 +- .../examples/mllib/AssociationRulesExample.scala | 3 +-- .../spark/examples/mllib/BinaryClassification.scala | 2 +- .../mllib/BinaryClassificationMetricsExample.scala | 2 +- .../examples/mllib/BisectingKMeansExample.scala | 2 +- .../apache/spark/examples/mllib/Correlations.scala | 3 +-- .../spark/examples/mllib/CosineSimilarity.scala | 2 +- .../mllib/DecisionTreeClassificationExample.scala | 2 +- .../mllib/DecisionTreeRegressionExample.scala | 2 +- .../spark/examples/mllib/DecisionTreeRunner.scala | 2 +- .../spark/examples/mllib/FPGrowthExample.scala | 2 +- .../examples/mllib/GradientBoostedTreesRunner.scala | 3 +-- .../GradientBoostingClassificationExample.scala | 2 +- .../mllib/GradientBoostingRegressionExample.scala | 2 +- .../examples/mllib/IsotonicRegressionExample.scala | 2 +- .../apache/spark/examples/mllib/LBFGSExample.scala | 3 +-- .../org/apache/spark/examples/mllib/LDAExample.scala | 4 ++-- .../spark/examples/mllib/LinearRegression.scala | 2 +- .../examples/mllib/MultiLabelMetricsExample.scala | 2 +- .../examples/mllib/MulticlassMetricsExample.scala | 2 +- .../examples/mllib/MultivariateSummarizer.scala | 3 +-- .../spark/examples/mllib/NaiveBayesExample.scala | 2 +- .../mllib/PowerIterationClusteringExample.scala | 2 +- .../spark/examples/mllib/PrefixSpanExample.scala | 3 +-- .../mllib/RandomForestClassificationExample.scala | 2 +- .../mllib/RandomForestRegressionExample.scala | 2 +- .../spark/examples/mllib/RandomRDDGeneration.scala | 3 +-- .../spark/examples/mllib/RankingMetricsExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../examples/mllib/RegressionMetricsExample.scala | 4 ++-- .../apache/spark/examples/mllib/SampledRDDs.scala | 2 +- .../apache/spark/examples/mllib/SimpleFPGrowth.scala | 3 +-- .../examples/mllib/StreamingLinearRegression.scala | 2 +- .../examples/mllib/StreamingLogisticRegression.scala | 4 ++-- .../examples/pythonconverters/HBaseConverters.scala | 7 ++++--- .../spark/examples/sql/hive/HiveFromSpark.scala | 4 ++-- .../spark/examples/streaming/ActorWordCount.scala | 6 +++--- .../spark/examples/streaming/CustomReceiver.scala | 4 ++-- .../examples/streaming/FlumePollingEventCount.scala | 3 ++- .../spark/examples/streaming/KafkaWordCount.scala | 4 ++-- .../spark/examples/streaming/NetworkWordCount.scala | 2 +- .../streaming/RecoverableNetworkWordCount.scala | 2 +- .../examples/streaming/SqlNetworkWordCount.scala | 4 ++-- .../streaming/StatefulNetworkWordCount.scala | 2 +- .../spark/examples/streaming/StreamingExamples.scala | 4 ++-- .../examples/streaming/TwitterAlgebirdHLL.scala | 4 ++-- .../streaming/TwitterHashTagJoinSentiments.scala | 2 +- .../spark/examples/streaming/ZeroMQWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../streaming/clickstream/PageViewStream.scala | 3 ++- 106 files changed, 147 insertions(+), 154 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 5a80985a49458..973b005f91f63 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -22,15 +22,14 @@ import java.nio.ByteBuffer import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper -import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat import org.apache.cassandra.hadoop.cql3.CqlConfigHelper import org.apache.cassandra.hadoop.cql3.CqlOutputFormat +import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} - /* Need to create following keyspace and column family in cassandra before running this example Start CQL shell using ./bin/cqlsh and execute following commands diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ad39a012b4ae6..6a8f73ad000f6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -23,9 +23,9 @@ import java.util.Arrays import java.util.SortedMap import org.apache.cassandra.db.IColumn +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper -import org.apache.cassandra.hadoop.ColumnFamilyInputFormat import org.apache.cassandra.thrift._ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index d651fe4d6ee75..b26db0b2462e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -22,7 +22,7 @@ import java.io.File import scala.io.Source._ -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 244742327a907..65d7489586062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, TableName} +import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.mapreduce.TableInputFormat import org.apache.spark._ - object HBaseTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HBaseTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 9c8aae53cf48d..a3901850f283e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import java.util.Random -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} /** * Logistic regression based classification. diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index e7b28d38bdfc6..407e3e08b968f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -23,7 +23,7 @@ import java.util.Random import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import breeze.linalg.{Vector, DenseVector, squaredDistance} +import breeze.linalg.{squaredDistance, DenseVector, Vector} import org.apache.spark.SparkContext._ diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 4f6b092a59ca5..58adbabe4454d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import java.util.Random -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} /** * Logistic regression based classification. diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 61ce9db914f9f..a797111dbad1a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.rdd.RDD /** * Usage: MultiBroadcastTest [slices] [numElem] diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 505ea5a4c7a85..6c90dbec3d531 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -22,7 +22,7 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index c56e1124ad415..1ea9121e2749e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import breeze.linalg.{Vector, DenseVector, squaredDistance} +import breeze.linalg.{squaredDistance, DenseVector, Vector} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index d265c227f4ed2..132800e6e4ca0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -22,7 +22,7 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.spark._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 0fd79660dd196..018bdf6d31033 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.SparkContext._ import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ /** * Computes the PageRank of URLs from an input file. Input file should diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 95072071ccddb..b92740f1fbcbd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import scala.util.Random import scala.collection.mutable +import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index cfbdae02212a5..e492582710e12 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -22,14 +22,13 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.storage.StorageLevel - /** * Logistic regression based classification. * This example uses Tachyon to persist rdds during computation. diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 8dd6c9706e7df..39cb83d9eeb7e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -19,11 +19,12 @@ package org.apache.spark.examples.graphx import scala.collection.mutable + import org.apache.spark._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.PartitionStrategy._ +import org.apache.spark.graphx.lib._ +import org.apache.spark.storage.StorageLevel /** * Driver program for running graph algorithms. diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 46e52aacd90bb..41ca5cbb9f083 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -18,11 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.graphx +import java.io.{FileOutputStream, PrintWriter} + +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy} -import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.graphx.util.GraphGenerators -import java.io.{PrintWriter, FileOutputStream} /** * The SynthBenchmark application can be used to run various GraphX algorithms on diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index f4b3613ccb94f..21f58ddf3cfb7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.regression.AFTSurvivalRegression import org.apache.spark.mllib.linalg.Vectors // $example off$ +import org.apache.spark.sql.SQLContext /** * An example for AFTSurvivalRegression. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala index e724aa587294b..2ed8101c133cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Binarizer // $example off$ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} object BinarizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 7c75e3d72b47b..6f6236a2b0588 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Bucketizer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object BucketizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala index a8d2bc4907e80..2be61537e613a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.ChiSqSelector import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object ChiSqSelectorExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala index ba916f66c4c07..7d07fc7dd113a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - object CountVectorizerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 14b358d46f6ab..bca301d412f4c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala index 314c2c28a2a10..dc26b55a768a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.DCT import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object DCTExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index db024b5cad935..224d8da5f0ec3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -18,15 +18,15 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object DecisionTreeClassificationExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index c4e98dfaca6c9..a37d12aa636cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -27,14 +27,13 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} -import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.evaluation.{MulticlassMetrics, RegressionMetrics} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{SQLContext, DataFrame} - +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example runner for decision trees. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ad01f55df72b5..ad32e5635a3ea 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -17,15 +17,17 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} + +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.DecisionTreeRegressor -import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.regression.DecisionTreeRegressor // $example off$ +import org.apache.spark.sql.SQLContext + object DecisionTreeRegressionExample { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala index 872de51dc75df..629d322c4357f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object ElementwiseProductExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 474af7db4b49b..cd62a803820cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object GradientBoostedTreeClassifierExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index da1cd9c2ce525..b8cf9629bbdab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} // $example off$ +import org.apache.spark.sql.SQLContext object GradientBoostedTreeRegressorExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala index 52537e5bb568d..4cea09ba12656 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.ml.feature.{StringIndexer, IndexToString} +import org.apache.spark.ml.feature.{IndexToString, StringIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object IndexToStringExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala index 419ce3d87a6ac..f9ddac77090ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -18,10 +18,10 @@ package org.apache.spark.examples.ml // scalastyle:off println -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.clustering.LDA +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.types.{StructField, StructType} // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 22c824cea84d3..c7352b3e7ab9c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.regression.LinearRegression // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object LinearRegressionWithElasticNetExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 4c420421b670e..04c60c0c1d067 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} // $example off$ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.functions.max -import org.apache.spark.{SparkConf, SparkContext} object LogisticRegressionSummaryExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala index 9ee995b52c90b..f632960f26ae5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.LogisticRegression // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object LogisticRegressionWithElasticNetExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala index fb7f28c9886bb..9a03f69f5af03 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.MinMaxScaler // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object MinMaxScalerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 9c98076bd24b1..d7d1e82f6f849 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.MultilayerPerceptronClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator // $example off$ +import org.apache.spark.sql.SQLContext /** * An example for Multilayer Perceptron Classification. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala index 8a85f71b56f3d..77b913aaa3fa0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.NGram // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object NGramExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala index 1990b55e8c5e8..6b33c16c74037 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Normalizer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object NormalizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala index 66602e2118506..cb9fe65a85e86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object OneHotEncoderExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index b46faea5713fb..ccee3b2aef980 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -22,10 +22,10 @@ import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} import scopt.OptionParser -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala index 4c806f71a32c3..535652ec6c793 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.PCA import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object PCAExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala index 39fb79af35766..3014008ea0ce4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.PolynomialExpansion import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object PolynomialExpansionExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index 8f29b7eaa6d26..e64e673a485ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.QuantileDiscretizer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object QuantileDiscretizerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala index 286866edea502..bec831d51c581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.RFormula // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object RFormulaExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index e79176ca6ca1c..6c9b52cf259e6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.classification.{RandomForestClassificationModel, Rand import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object RandomForestClassifierExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index acec1437a1af5..4d2db017f346f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} // $example off$ +import org.apache.spark.sql.SQLContext object RandomForestRegressorExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala index 014abd1fdbc63..202925acadff2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.SQLTransformer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - object SQLTransformerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala index e0a41e383a7ea..e3439677e78d6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StandardScaler // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object StandardScalerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala index 655ffce08d3ab..8199be12c155b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StopWordsRemover // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object StopWordsRemoverExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala index 9fa494cd2473b..3f0e870c8dc6b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StringIndexer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object StringIndexerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala index 40c33e4e7d44e..28115f939082e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object TfIdfExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala index 01e0d1388a2f4..c667728d6326d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object TokenizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala index cd1b0e9358beb..fbba17eba6a2f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -17,11 +17,11 @@ package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} /** * A simple example demonstrating model selection using TrainValidationSplit. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala index d527924419f81..768a8c0690477 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object VectorAssemblerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index 685891c164e70..3bef37ba360b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.VectorIndexer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object VectorIndexerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala index 04f19829eff87..01377d80e7e5c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -18,6 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.feature.VectorSlicer @@ -26,7 +27,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object VectorSlicerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala index 631ab4c8efa0d..e77aa59ba32b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Word2Vec // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object Word2VecExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ca22ddafc3c48..11e18c9f040bc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.AssociationRules import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object AssociationRulesExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 1a4016f76c2ad..2282bd2b7d680 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -24,8 +24,8 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} /** * An example app for binary classification. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index 13a37827ab935..ade33fc5090f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -18,13 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkContext, SparkConf} object BinaryClassificationMetricsExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala index 3a596cccb87d3..53d0b8fc208ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala @@ -18,11 +18,11 @@ package org.apache.spark.examples.mllib // scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.clustering.BisectingKMeans import org.apache.spark.mllib.linalg.{Vector, Vectors} // $example off$ -import org.apache.spark.{SparkConf, SparkContext} /** * An example demonstrating a bisecting k-means clustering in spark.mllib. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index 026d4ecc6d10a..e003f35ed399f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -20,10 +20,9 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for summarizing multivariate data from a file. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index 69988cc1b9334..eda211b5a8dfd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -20,10 +20,10 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} -import org.apache.spark.{SparkConf, SparkContext} /** * Compute the similar columns of a matrix, using cosine similarity. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index d427bbadaa0c1..c6c7c6f5e2ed8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object DecisionTreeClassificationExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index fb05e7d9c5065..9c8baed3b8668 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object DecisionTreeRegressionExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index cc6bce3cb7c9c..c263f4f595a37 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} +import org.apache.spark.mllib.tree.{impurity, DecisionTree, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.util.MLUtils diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 14b930550d554..a7a3eade04a0c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -20,8 +20,8 @@ package org.apache.spark.examples.mllib import scopt.OptionParser -import org.apache.spark.mllib.fpm.FPGrowth import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.fpm.FPGrowth /** * Example for mining frequent itemsets using FP-growth. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index e16a6bf033574..b0144ef533133 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -23,10 +23,9 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy} import org.apache.spark.util.Utils - /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 139e1f909bdce..0ec2e11214e89 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index 3dc86da8e4d2b..b87ba0defe695 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 52ac9ae7dd2d0..3834ea807acbf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object IsotonicRegressionExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index 61d2e7715f53d..75a0419da5ec3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -18,6 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics @@ -26,8 +27,6 @@ import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Up import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object LBFGSExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 70010b05e4345..d28323555b990 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,16 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.log4j.{Level, Logger} import scopt.OptionParser -import org.apache.log4j.{Level, Logger} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} /** * An example Latent Dirichlet Allocation (LDA) app. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 8878061a0970b..f87611f5d4613 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -22,9 +22,9 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1Updater} /** * An example app for linear regression. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala index 4503c15360adc..c0d447bf69dd7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.evaluation.MultilabelMetrics import org.apache.spark.rdd.RDD // $example off$ -import org.apache.spark.{SparkContext, SparkConf} object MultiLabelMetricsExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala index 0904449245989..4f925ede24d82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -18,13 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkContext, SparkConf} object MulticlassMetricsExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 5f839c75dd581..3c598172dadf0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -20,11 +20,10 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for summarizing multivariate data from a file. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index a7a47c2a3556a..8bae1b9d1832d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object NaiveBayesExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 0723223954610..9208d8e245881 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -21,9 +21,9 @@ package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.clustering.PowerIterationClustering import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} /** * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index d237232c430ca..ef86eab9e4ec5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.PrefixSpan // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object PrefixSpanExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index 5e55abd5121c4..7805153ba7b95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.RandomForest import org.apache.spark.mllib.tree.model.RandomForestModel diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index a54fb3ab7e37a..655a277e28ae8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.RandomForest import org.apache.spark.mllib.tree.model.RandomForestModel diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index bee85ba0f9969..7ccbb5a0640cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -18,11 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for randomly generated RDDs. Run with * {{{ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala index cffa03d5cc9f4..fdb01b86dd787 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.evaluation.{RankingMetrics, RegressionMetrics} import org.apache.spark.mllib.recommendation.{ALS, Rating} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} object RankingMetricsExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 64e4602465444..bc946951aebf9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.recommendation.ALS import org.apache.spark.mllib.recommendation.MatrixFactorizationModel diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala index 47d44532521ca..ace16ff1ea225 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -18,13 +18,13 @@ package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.util.MLUtils // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object RegressionMetricsExample { def main(args: Array[String]) : Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 6963f43e082c4..c4e5e965b8f40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.mllib.util.MLUtils import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.MLUtils /** * An example app for randomly generated and sampled RDDs. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b4e06afa7410f..ab15ac2c54d3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.FPGrowth import org.apache.spark.rdd.RDD // $example off$ -import org.apache.spark.{SparkContext, SparkConf} - object SimpleFPGrowth { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index b4a5dca031abd..e5592966f13fa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,9 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.SparkConf import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} -import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index b42f4cb5f9338..a8b144a197229 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.SparkConf +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD -import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 0a25ee7ae56f4..e252ca882e534 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -20,12 +20,13 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject -import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.CellUtil +import org.apache.hadoop.hbase.KeyValue.Type import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes -import org.apache.hadoop.hbase.KeyValue.Type -import org.apache.hadoop.hbase.CellUtil + +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts all diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index bf40bd1ef13df..4e427f54daa52 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.sql.hive -import com.google.common.io.{ByteStreams, Files} - import java.io.File +import com.google.common.io.{ByteStreams, Files} + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ import org.apache.spark.sql.hive.HiveContext diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index e9c9907198769..8b8dae0be6119 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -22,13 +22,13 @@ import scala.collection.mutable.LinkedList import scala.reflect.ClassTag import scala.util.Random -import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} +import akka.actor.{actorRef2Scala, Actor, ActorRef, Props} -import org.apache.spark.{SparkConf, SecurityManager} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.util.AkkaUtils import org.apache.spark.streaming.receiver.ActorHelper +import org.apache.spark.util.AkkaUtils case class SubscribeReceiver(receiverActor: ActorRef) case class UnsubscribeReceiver(receiverActor: ActorRef) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 28e9bf520e568..ad13d437dd546 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import java.io.{InputStreamReader, BufferedReader, InputStream} +import java.io.{BufferedReader, InputStream, InputStreamReader} import java.net.Socket -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.receiver.Receiver diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 2bdbc37e2a289..fe3b79ed5d29d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -18,12 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import java.net.InetSocketAddress + import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.flume._ import org.apache.spark.util.IntParam -import java.net.InetSocketAddress /** * Produces a count of events received from Flume. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index b40d17e9c2fa3..e7f9bf36e35cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -20,11 +20,11 @@ package org.apache.spark.examples.streaming import java.util.HashMap -import org.apache.kafka.clients.producer.{ProducerConfig, KafkaProducer, ProducerRecord} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka._ -import org.apache.spark.SparkConf /** * Consumes messages from one or more topics in Kafka and does wordcount. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 9a57fe286d1ae..15b57fccb4076 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -19,8 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 38d4fd11f97d1..05f8e65d65a20 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.apache.spark.{Accumulator, SparkConf, SparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.streaming.{Seconds, StreamingContext, Time} import org.apache.spark.util.IntParam /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index ed617754cbf1c..9aa0f54312d28 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -21,10 +21,10 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, Seconds, StreamingContext} -import org.apache.spark.util.IntParam import org.apache.spark.sql.SQLContext import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, Time} +import org.apache.spark.util.IntParam /** * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 2dce1820d9734..c85d6843dc993 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import org.apache.spark.SparkConf import org.apache.spark.HashPartitioner +import org.apache.spark.SparkConf import org.apache.spark.streaming._ /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala index 8396e65d0d588..22a5654405dd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala @@ -17,10 +17,10 @@ package org.apache.spark.examples.streaming -import org.apache.spark.Logging - import org.apache.log4j.{Level, Logger} +import org.apache.spark.Logging + /** Utility functions for Spark Streaming examples. */ object StreamingExamples extends Logging { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 49826ede70418..0ec6214fdef1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -18,13 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import com.twitter.algebird.HyperLogLogMonoid import com.twitter.algebird.HyperLogLog._ +import com.twitter.algebird.HyperLogLogMonoid +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.twitter._ -import org.apache.spark.SparkConf // scalastyle:off /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala index 0328fa81ea897..edf0e0b7b2b46 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala @@ -19,8 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.streaming.twitter.TwitterUtils import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.twitter.TwitterUtils /** * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6ac9a72c37941..96448905760fb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -18,18 +18,18 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import scala.language.implicitConversions + import akka.actor.ActorSystem import akka.actor.actorRef2Scala +import akka.util.ByteString import akka.zeromq._ import akka.zeromq.Subscribe -import akka.util.ByteString +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.zeromq._ -import scala.language.implicitConversions -import org.apache.spark.SparkConf - /** * A simple publisher for demonstration purposes, repeatedly publishes random Messages * every one second. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 2fcccb22dddf7..ce1a62060ef6c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -18,9 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import java.net.ServerSocket import java.io.PrintWriter -import util.Random +import java.net.ServerSocket +import java.util.Random /** Represents a page view on a website with associated dimension data. */ class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 723616817f6a2..4b43550a065bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -18,8 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.examples.streaming.StreamingExamples +import org.apache.spark.streaming.{Seconds, StreamingContext} + // scalastyle:off /** Analyses a streaming dataset of web page views. This class demonstrates several types of * operators available in Spark streaming. From 53beddc5bf04a35ab73de99158919c2fdd5d4508 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 4 Jan 2016 23:23:41 -0800 Subject: [PATCH 1434/2089] [SPARK-12568][SQL] Add BINARY to Encoders Author: Michael Armbrust Closes #10516 from marmbrus/datasetCleanup. --- .../src/main/scala/org/apache/spark/sql/Encoder.scala | 6 ++++++ .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 9 +++++++++ .../sql/catalyst/encoders/ExpressionEncoderSuite.scala | 6 +++--- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index bb0fdc4c3d83b..22b7e1ea0c4cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -157,6 +157,12 @@ object Encoders { */ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + /** + * An encoder for arrays of bytes. + * @since 1.6.1 + */ + def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() + /** * Creates an encoder for Java Bean of type T. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index ad4beda9c4916..6c058463b9cf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -198,6 +198,15 @@ case class ExpressionEncoder[T]( @transient private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) + /** + * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns + * is performed). + */ + def defaultBinding: ExpressionEncoder[T] = { + val attrs = schema.toAttributes + resolve(attrs, OuterScopes.outerScopes).bind(attrs) + } + /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 666699e18d4a5..3740dea8aad51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -77,6 +77,8 @@ class JavaSerializable(val value: Int) extends Serializable { } class ExpressionEncoderSuite extends SparkFunSuite { + OuterScopes.outerScopes.put(getClass.getName, this) + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() // test flat encoders @@ -278,8 +280,6 @@ class ExpressionEncoderSuite extends SparkFunSuite { } } - private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() - outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { @@ -287,7 +287,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema, outers).bind(schema) + val boundEncoder = encoder.defaultBinding val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( From 8eb2dc7133b4d2143adffc2bdbb61d96bd41a0ac Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 5 Jan 2016 00:39:50 -0800 Subject: [PATCH 1435/2089] [SPARK-12641] Remove unused code related to Hadoop 0.23 Currently we don't support Hadoop 0.23 but there is a few code related to it so let's clean it up. Author: Kousuke Saruta Closes #10590 from sarutak/SPARK-12641. --- .../main/scala/org/apache/spark/util/Utils.scala | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0c1f9c1ae2496..9bdcc4d817a4a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -662,9 +662,7 @@ private[spark] object Utils extends Logging { private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { // These environment variables are set by YARN. - // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs()) - // For Hadoop 2.X, we check for CONTAINER_ID. - conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null + conf.getenv("CONTAINER_ID") != null } /** @@ -740,17 +738,12 @@ private[spark] object Utils extends Logging { logError(s"Failed to create local root dir in $root. Ignoring this directory.") None } - }.toArray + } } /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(conf.getenv("LOCAL_DIRS")) - .getOrElse("")) + val localDirs = Option(conf.getenv("LOCAL_DIRS")).getOrElse("") if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") From 1cdc42d2b99edfec01066699a7620cca02b61f0e Mon Sep 17 00:00:00 2001 From: Imran Younus Date: Tue, 5 Jan 2016 11:48:45 +0000 Subject: [PATCH 1436/2089] [SPARK-12331][ML] R^2 for regression through the origin. Modified the definition of R^2 for regression through origin. Added modified test for regression metrics. Author: Imran Younus Author: Imran Younus Closes #10384 from iyounus/SPARK_12331_R2_for_regression_through_origin. --- .../ml/regression/LinearRegression.scala | 3 +- .../mllib/evaluation/RegressionMetrics.scala | 24 ++- .../evaluation/RegressionMetricsSuite.scala | 156 ++++++++++-------- 3 files changed, 112 insertions(+), 71 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index dee26337dcdf6..c54e08b2ad9a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -534,7 +534,8 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions .select(predictionCol, labelCol) - .map { case Row(pred: Double, label: Double) => (pred, label) } ) + .map { case Row(pred: Double, label: Double) => (pred, label) }, + !model.getFitIntercept) /** * Returns the explained variance regression score. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 34883f2f390d1..18c90b204a26a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,11 +27,18 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param throughOrigin True if the regression is through the origin. For example, in linear + * regression, it will be true without fitting intercept. */ @Since("1.2.0") -class RegressionMetrics @Since("1.2.0") ( - predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("2.0.0") ( + predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + extends Logging { + + @Since("1.2.0") + def this(predictionAndObservations: RDD[(Double, Double)]) = + this(predictionAndObservations, false) /** * An auxiliary constructor taking a DataFrame. @@ -53,6 +60,8 @@ class RegressionMetrics @Since("1.2.0") ( ) summary } + + private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) private lazy val SStot = summary.variance(0) * (summary.count - 1) private lazy val SSreg = { @@ -102,9 +111,16 @@ class RegressionMetrics @Since("1.2.0") ( /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * In case of regression through the origin, the definition of R^2^ is to be modified. + * @see J. G. Eisenhauer, Regression through the Origin. Teaching Statistics 25, 76-80 (2003) + * [[https://online.stat.psu.edu/~ajw13/stat501/SpecialTopics/Reg_thru_origin.pdf]] */ @Since("1.2.0") def r2: Double = { - 1 - SSerr / SStot + if (throughOrigin) { + 1 - SSerr / SSy + } else { + 1 - SSerr / SStot + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 4b7f1be58f99b..f1d517383643d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -22,91 +22,115 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + val obs = List[Double](77, 85, 62, 55, 63, 88, 57, 81, 51) + val eps = 1E-5 test("regression metrics for unbiased (includes intercept term) predictor") { /* Verify results in R: - preds = c(2.25, -0.25, 1.75, 7.75) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.796875 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.3125 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.559017 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9571734 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.08 91.88 65.48 52.28 62.18 81.98 58.88 78.68 55.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + [1] 157.3 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 3.7355556 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 17.539511 + rmse = sqrt(meanSquaredError) + rmse + [1] 4.18802 + r2 = summary(model)$r.squared + r2 + [1] 0.89968225 */ - val predictionAndObservations = sc.parallelize( - Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val preds = List(72.08, 91.88, 65.48, 52.28, 62.18, 81.98, 58.88, 78.68, 55.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + assert(metrics.explainedVariance ~== 157.3 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 3.7355556 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 17.539511 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 4.18802 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.89968225 absTol eps, "r2 score mismatch") } test("regression metrics for biased (no intercept term) predictor") { /* Verify results in R: - preds = c(2.5, 0.0, 2.0, 8.0) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.859375 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.375 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.6123724 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9486081 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ 0 + x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.12 99.17 63.11 45.08 58.60 85.65 54.09 81.14 49.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 294.88167 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 4.5888889 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 39.958711 + rmse = sqrt(meanSquaredError) + rmse + [1] 6.3212903 + r2 = summary(model)$r.squared + r2 + [1] 0.99185395 */ - val predictionAndObservations = sc.parallelize( - Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) - val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, + val preds = List(72.12, 99.17, 63.11, 45.08, 58.6, 85.65, 54.09, 81.14, 49.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) + val metrics = new RegressionMetrics(predictionAndObservations, true) + assert(metrics.explainedVariance ~== 294.88167 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 4.5888889 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 39.958711 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 6.3212903 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.99185395 absTol eps, "r2 score mismatch") } test("regression metrics with complete fitting") { - val predictionAndObservations = sc.parallelize( - Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) + /* Verify results in R: + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + preds = y + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 174.8395 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 0 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 0 + rmse = sqrt(meanSquaredError) + rmse + [1] 0 + r2 = 1 - sum((preds - y)^2)/sum((y - mean(y))^2) + r2 + [1] 1 + */ + val preds = obs + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, + assert(metrics.explainedVariance ~== 174.83951 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 0.0 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } } From b3c48e39f4a0a42a0b6b433511b2cce0d1e3f03d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jan 2016 10:19:56 -0800 Subject: [PATCH 1437/2089] [SPARK-12438][SQL] Add SQLUserDefinedType support for encoder JIRA: https://issues.apache.org/jira/browse/SPARK-12438 ScalaReflection lacks the support of SQLUserDefinedType. We should add it. Author: Liang-Chi Hsieh Closes #10390 from viirya/encoder-udt. --- .../spark/sql/catalyst/ScalaReflection.scala | 22 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 14 ++++++++++++ .../encoders/ExpressionEncoderSuite.scala | 2 ++ 3 files changed, 38 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9784c969665de..c6aa60b0b4d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } + val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -360,6 +361,16 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } + + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } @@ -409,6 +420,7 @@ object ScalaReflection extends ScalaReflection { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { + val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -559,6 +571,16 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case other => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b18f49f3203fb..d82d3edae4e38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -81,6 +82,9 @@ object Cast { toField.nullable) } + case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass => + true + case _ => false } @@ -431,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) case map: MapType => castMap(from.asInstanceOf[MapType], map) case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) @@ -473,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + (c, evPrim, evNull) => s"$evPrim = $c;" + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 3740dea8aad51..6453f1c191ba0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -244,6 +244,8 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + productTest(("UDT", new ExamplePoint(0.1, 0.2))) + test("nullable of encoder schema") { def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) From 9a6ba7e2c538124f539b50512a7f95059f81cc16 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jan 2016 10:21:47 -0800 Subject: [PATCH 1438/2089] [SPARK-12643][BUILD] Set lib directory for antlr JIRA: https://issues.apache.org/jira/browse/SPARK-12643 Without setting lib directory for antlr, the updates of imported grammar files can not be detected. So SparkSqlParser.g will not be rebuilt automatically. Since it is a minor update, no JIRA ticket is opened. Let me know if it is needed. Thanks. Author: Liang-Chi Hsieh Closes #10571 from viirya/antlr-build. --- project/SparkBuild.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 588e97f64e054..af1d36c6ea57b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -443,6 +443,10 @@ object Hive { val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) antlr.addGrammarFile(relGFilePath) + // We will set library directory multiple times here. However, only the + // last one has effect. Because the grammar files are located under the same directory, + // We assume there is only one library directory. + antlr.setLibDirectory(gFilePath.getParent) } // Generate the parser. From 76768337beec6842660db7522ad15c25ee66d346 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Jan 2016 10:23:36 -0800 Subject: [PATCH 1439/2089] [SPARK-12480][FOLLOW-UP] use a single column vararg for hash address comments in #10435 This makes the API easier to use if user programmatically generate the call to hash, and they will get analysis exception if the arguments of hash is empty. Author: Wenchen Fan Closes #10588 from cloud-fan/hash. --- python/pyspark/sql/functions.py | 12 ++++++++++++ .../apache/spark/sql/catalyst/expressions/misc.scala | 2 +- .../analysis/ExpressionTypeCheckingSuite.scala | 1 + .../main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7c15e38458690..b0390cb9942e6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1018,6 +1018,18 @@ def sha2(col, numBits): return Column(jc) +@since(2.0) +def hash(*cols): + """Calculates the hash code of given columns, and returns the result as a int column. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() + [Row(hash=1358996357)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------- String/Binary functions ------------------------------ _string_functions = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8834924687c0c..6697d463614d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -200,7 +200,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { - TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be empty") + TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 915c585ec91fb..f3df716a57824 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -163,6 +163,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Coalesce(Seq('intField, 'booleanField)), "input to function coalesce should all be the same type") assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e223e32fd702e..1c96f647b6345 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1820,8 +1820,8 @@ object functions extends LegacyFunctions { * @since 2.0 */ @scala.annotation.varargs - def hash(col: Column, cols: Column*): Column = withExpr { - new Murmur3Hash((col +: cols).map(_.expr)) + def hash(cols: Column*): Column = withExpr { + new Murmur3Hash(cols.map(_.expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// From 8ce645d4eeda203cf5e100c4bdba2d71edd44e6a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 5 Jan 2016 11:10:14 -0800 Subject: [PATCH 1440/2089] [SPARK-12615] Remove some deprecated APIs in RDD/SparkContext I looked at each case individually and it looks like they can all be removed. The only one that I had to think twice was toArray (I even thought about un-deprecating it, until I realized it was a problem in Java to have toArray returning java.util.List). Author: Reynold Xin Closes #10569 from rxin/SPARK-12615. --- .../scala/org/apache/spark/Aggregator.scala | 8 - .../scala/org/apache/spark/SparkContext.scala | 261 +----------------- .../scala/org/apache/spark/TaskContext.scala | 16 -- .../org/apache/spark/TaskContextImpl.scala | 10 - .../apache/spark/api/java/JavaDoubleRDD.scala | 1 - .../apache/spark/api/java/JavaRDDLike.scala | 10 - .../spark/api/java/JavaSparkContext.scala | 28 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 - .../main/scala/org/apache/spark/rdd/RDD.scala | 101 ------- .../org/apache/spark/rdd/SampledRDD.scala | 71 ----- .../spark/rdd/SequenceFileRDDFunctions.scala | 5 - .../org/apache/spark/scheduler/TaskInfo.scala | 3 - .../org/apache/spark/util/RpcUtils.scala | 11 - .../java/org/apache/spark/JavaAPISuite.java | 7 - .../scala/org/apache/spark/rdd/RDDSuite.scala | 60 ---- .../spark/scheduler/TaskContextSuite.scala | 8 - .../spark/util/ClosureCleanerSuite.scala | 21 -- ...avaBinaryClassificationMetricsExample.java | 12 +- .../apache/spark/examples/SparkHdfsLR.scala | 7 +- .../spark/examples/SparkTachyonHdfsLR.scala | 6 +- project/MimaExcludes.scala | 53 +++- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 22 files changed, 64 insertions(+), 643 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 7196e57d5d2e2..62629000cfc23 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,10 +34,6 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = - combineValuesByKey(iter, null) - def combineValuesByKey( iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { @@ -47,10 +43,6 @@ case class Aggregator[K, V, C] ( combiners.iterator } - @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = - combineCombinersByKey(iter, null) - def combineCombinersByKey( iter: Iterator[_ <: Product2[K, C]], context: TaskContext): Iterator[(K, C)] = { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 77e44ee0264af..87301202dea27 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -27,7 +27,7 @@ import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicIntege import java.util.UUID.randomUUID import scala.collection.JavaConverters._ -import scala.collection.{Map, Set} +import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} @@ -122,20 +122,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def this() = this(new SparkConf()) - /** - * :: DeveloperApi :: - * Alternative constructor for setting preferred locations where Spark will create executors. - * - * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters - * @param preferredNodeLocationData not used. Left for backward compatibility. - */ - @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") - @DeveloperApi - def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { - this(config) - logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") - } - /** * Alternative constructor that allows setting common Spark properties directly * @@ -155,21 +141,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. - * @param preferredNodeLocationData not used. Left for backward compatibility. */ - @deprecated("Passing in preferred locations has no effect at all, see SPARK-10921", "1.6.0") def this( master: String, appName: String, sparkHome: String = null, jars: Seq[String] = Nil, - environment: Map[String, String] = Map(), - preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = + environment: Map[String, String] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) - if (preferredNodeLocationData.nonEmpty) { - logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") - } } // NOTE: The below constructors could be consolidated using default arguments. Due to @@ -267,8 +247,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Generate the random name for a temp folder in external block store. // Add a timestamp as the suffix here to make it more safe val externalBlockStoreFolderName = "spark-" + randomUUID.toString() - @deprecated("Use externalBlockStoreFolderName instead.", "1.4.0") - val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) @@ -641,11 +619,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli localProperties.set(props) } - @deprecated("Properties no longer need to be explicitly initialized.", "1.0.0") - def initLocalProperties() { - localProperties.set(new Properties()) - } - /** * Set a local property that affects jobs submitted from this thread, such as the * Spark fair scheduler pool. @@ -1585,15 +1558,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli taskScheduler.schedulingMode } - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding files no longer creates local copies that need to be deleted", "1.0.0") - def clearFiles() { - addedFiles.clear() - } - /** * Gets the locality information associated with the partition in a particular rdd * @param rdd of interest @@ -1685,15 +1649,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli postEnvironmentUpdate() } - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding jars no longer creates local copies that need to be deleted", "1.0.0") - def clearJars() { - addedJars.clear() - } - // Shut down the SparkContext. def stop() { if (AsynchronousListenerBus.withinListenerThread.value) { @@ -1864,63 +1819,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) } - - /** - * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. - * - * The allowLocal flag is deprecated as of Spark 1.5.0+. - */ - @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") - def runJob[T, U: ClassTag]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit): Unit = { - if (allowLocal) { - logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") - } - runJob(rdd, func, partitions, resultHandler) - } - - /** - * Run a function on a given set of partitions in an RDD and return the results as an array. - * - * The allowLocal flag is deprecated as of Spark 1.5.0+. - */ - @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") - def runJob[T, U: ClassTag]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - if (allowLocal) { - logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") - } - runJob(rdd, func, partitions) - } - - /** - * Run a job on a given set of partitions of an RDD, but take a function of type - * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. - * - * The allowLocal argument is deprecated as of Spark 1.5.0+. - */ - @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") - def runJob[T, U: ClassTag]( - rdd: RDD[T], - func: Iterator[T] => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - if (allowLocal) { - logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") - } - runJob(rdd, func, partitions) - } - /** * Run a job on all partitions in an RDD and return the results in an array. */ @@ -2094,10 +1992,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli taskScheduler.defaultParallelism } - /** Default min number of partitions for Hadoop RDDs when not given by user */ - @deprecated("use defaultMinPartitions", "1.0.0") - def defaultMinSplits: Int = defaultMinPartitions - /** * Default min number of partitions for Hadoop RDDs when not given by user * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. @@ -2364,113 +2258,6 @@ object SparkContext extends Logging { */ private[spark] val LEGACY_DRIVER_IDENTIFIER = "" - // The following deprecated objects have already been copied to `object AccumulatorParam` to - // make the compiler find them automatically. They are duplicate codes only for backward - // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the - // following ones. - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double): Double = 0.0 - } - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object IntAccumulatorParam extends AccumulatorParam[Int] { - def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int): Int = 0 - } - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long): Long = t1 + t2 - def zero(initialValue: Long): Long = 0L - } - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float): Float = t1 + t2 - def zero(initialValue: Float): Float = 0f - } - - // The following deprecated functions have already been moved to `object RDD` to - // make the compiler find them automatically. They are still kept here for backward compatibility - // and just call the corresponding functions in `object RDD`. - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = - RDD.rddToPairRDDFunctions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = - RDD.rddToAsyncRDDActions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { - val kf = implicitly[K => Writable] - val vf = implicitly[V => Writable] - // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it - implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf) - implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf) - RDD.rddToSequenceFileRDDFunctions(rdd) - } - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = - RDD.rddToOrderedRDDFunctions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = - RDD.doubleRDDToDoubleRDDFunctions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = - RDD.numericRDDToDoubleRDDFunctions(rdd) - - // The following deprecated functions have already been moved to `object WritableFactory` to - // make the compiler find them automatically. They are still kept here for backward compatibility. - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def stringToText(s: String): Text = new Text(s) - private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) : ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u @@ -2479,50 +2266,6 @@ object SparkContext extends Logging { arr.map(x => anyToWritable(x)).toArray) } - // The following deprecated functions have already been moved to `object WritableConverter` to - // make the compiler find them automatically. They are still kept here for backward compatibility - // and just call the corresponding functions in `object WritableConverter`. - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def intWritableConverter(): WritableConverter[Int] = - WritableConverter.intWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def longWritableConverter(): WritableConverter[Long] = - WritableConverter.longWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def doubleWritableConverter(): WritableConverter[Double] = - WritableConverter.doubleWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def floatWritableConverter(): WritableConverter[Float] = - WritableConverter.floatWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def booleanWritableConverter(): WritableConverter[Boolean] = - WritableConverter.booleanWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def bytesWritableConverter(): WritableConverter[Array[Byte]] = - WritableConverter.bytesWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def stringWritableConverter(): WritableConverter[String] = - WritableConverter.stringWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def writableWritableConverter[T <: Writable](): WritableConverter[T] = - WritableConverter.writableWritableConverter() - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to SparkContext. diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index af558d6e5b474..e25ed0fdd7fd2 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -95,9 +95,6 @@ abstract class TaskContext extends Serializable { */ def isInterrupted(): Boolean - @deprecated("use isRunningLocally", "1.2.0") - def runningLocally(): Boolean - /** * Returns true if the task is running locally in the driver program. * @return @@ -118,16 +115,6 @@ abstract class TaskContext extends Serializable { */ def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext - /** - * Adds a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * - * @param f Callback function. - */ - @deprecated("use addTaskCompletionListener", "1.2.0") - def addOnCompleteCallback(f: () => Unit) - /** * The ID of the stage that this task belong to. */ @@ -144,9 +131,6 @@ abstract class TaskContext extends Serializable { */ def attemptNumber(): Int - @deprecated("use attemptNumber", "1.3.0") - def attemptId(): Long - /** * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index f0ae83a9341bd..6c493630997eb 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -38,9 +38,6 @@ private[spark] class TaskContextImpl( extends TaskContext with Logging { - // For backwards-compatibility; this method is now deprecated as of 1.3.0. - override def attemptId(): Long = taskAttemptId - // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -62,13 +59,6 @@ private[spark] class TaskContextImpl( this } - @deprecated("use addTaskCompletionListener", "1.1.0") - override def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f() - } - } - /** Marks the task as completed and triggers the listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index c32aefac465bc..37ae007f880c2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -23,7 +23,6 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.Partitioner -import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0e4d7dce0f2f5..9cf68672beca2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -57,9 +57,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] - @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = rdd.partitions.toSeq.asJava - /** Set of partitions in this RDD. */ def partitions: JList[Partition] = rdd.partitions.toSeq.asJava @@ -346,13 +343,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def toLocalIterator(): JIterator[T] = asJavaIteratorConverter(rdd.toLocalIterator).asJava - /** - * Return an array that contains all of the elements in this RDD. - * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead - */ - @deprecated("use collect()", "1.0.0") - def toArray(): JList[T] = collect() - /** * Return an array that contains all of the elements in a specific partition of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4f54cd69e2175..9f5b89bb4ba45 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -102,7 +102,7 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala)) private[spark] val env = sc.env @@ -126,14 +126,6 @@ class JavaSparkContext(val sc: SparkContext) /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: java.lang.Integer = sc.defaultParallelism - /** - * Default min number of partitions for Hadoop RDDs when not given by user. - * @deprecated As of Spark 1.0.0, defaultMinSplits is deprecated, use - * {@link #defaultMinPartitions()} instead - */ - @deprecated("use defaultMinPartitions", "1.0.0") - def defaultMinSplits: java.lang.Integer = sc.defaultMinSplits - /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions @@ -671,24 +663,6 @@ class JavaSparkContext(val sc: SparkContext) sc.addJar(path) } - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding jars no longer creates local copies that need to be deleted", "1.0.0") - def clearJars() { - sc.clearJars() - } - - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding files no longer creates local copies that need to be deleted", "1.0.0") - def clearFiles() { - sc.clearFiles() - } - /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index b87230142532b..76b31165aa74c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -359,12 +359,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } - /** Alias for reduceByKeyLocally */ - @deprecated("Use reduceByKeyLocally", "1.0.0") - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = self.withScope { - reduceByKeyLocally(func) - } - /** * Count the number of elements for each key, collecting the results to a local Map. * diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9fe9d83a705b2..394f79dc7734e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -746,99 +746,6 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } - /** - * :: DeveloperApi :: - * Return a new RDD by applying a function to each partition of this RDD. This is a variant of - * mapPartitions that also passes the TaskContext into the closure. - * - * `preservesPartitioning` indicates whether the input function preserves the partitioner, which - * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. - */ - @DeveloperApi - @deprecated("use TaskContext.get", "1.2.0") - def mapPartitionsWithContext[U: ClassTag]( - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { - val cleanF = sc.clean(f) - val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter) - new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { - mapPartitionsWithIndex(f, preservesPartitioning) - } - - /** - * Maps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex", "1.0.0") - def mapWith[A, U: ClassTag] - (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => U): RDD[U] = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.map(t => cleanF(t, a)) - }, preservesPartitioning) - } - - /** - * FlatMaps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0") - def flatMapWith[A, U: ClassTag] - (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => Seq[U]): RDD[U] = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.flatMap(t => cleanF(t, a)) - }, preservesPartitioning) - } - - /** - * Applies f to each element of this RDD, where f takes an additional parameter of type A. - * This additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0") - def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex { (index, iter) => - val a = cleanA(index) - iter.map(t => {cleanF(t, a); t}) - } - } - - /** - * Filters this RDD with p, where p takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and filter", "1.0.0") - def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope { - val cleanP = sc.clean(p) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.filter(t => cleanP(t, a)) - }, preservesPartitioning = true) - } - /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, * second element in each RDD, etc. Assumes that the two RDDs have the *same number of @@ -944,14 +851,6 @@ abstract class RDD[T: ClassTag]( (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } - /** - * Return an array that contains all of the elements in this RDD. - */ - @deprecated("use collect", "1.0.0") - def toArray(): Array[T] = withScope { - collect() - } - /** * Return an RDD that contains all matching values by applying `f`. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala deleted file mode 100644 index 9e8cee5331cf8..0000000000000 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.util.Random - -import scala.reflect.ClassTag - -import org.apache.commons.math3.distribution.PoissonDistribution - -import org.apache.spark.{Partition, TaskContext} - -@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0") -private[spark] -class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { - override val index: Int = prev.index -} - -@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0") -private[spark] class SampledRDD[T: ClassTag]( - prev: RDD[T], - withReplacement: Boolean, - frac: Double, - seed: Int) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = { - val rg = new Random(seed) - firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = - firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) - - override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { - val split = splitIn.asInstanceOf[SampledRDDPartition] - if (withReplacement) { - // For large datasets, the expected number of occurrences of each element in a sample with - // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new PoissonDistribution(frac) - poisson.reseedRandomGenerator(split.seed) - - firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.sample() - if (count == 0) { - Iterator.empty // Avoid object allocation when we return 0 items, which is quite often - } else { - Iterator.fill(count)(element) - } - } - } else { // Sampling without replacement - val rand = new Random(split.seed) - firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 4b5f15dd06b85..c4bc85a5ea2d5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -38,11 +38,6 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag extends Logging with Serializable { - @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") - def this(self: RDD[(K, V)]) { - this(self, null, null) - } - private val keyWritableClass = if (_keyWritableClass == null) { // pre 1.3.0, we need to use Reflection to get the Writable class diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index f113c2b1b8433..a42990addb9c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -95,9 +95,6 @@ class TaskInfo( } } - @deprecated("Use attemptNumber", "1.6.0") - def attempt: Int = attemptNumber - def id: String = s"$index.$attemptNumber" def duration: Long = { diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index a51f30b9c2921..b68936f5c9f0a 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.util -import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps import org.apache.spark.SparkConf @@ -50,18 +49,8 @@ private[spark] object RpcUtils { RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") } - @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0") - def askTimeout(conf: SparkConf): FiniteDuration = { - askRpcTimeout(conf).duration - } - /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") } - - @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0") - def lookupTimeout(conf: SparkConf): FiniteDuration = { - lookupRpcTimeout(conf).duration - } } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index d91948e44694b..502f86f178fd2 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -687,13 +687,6 @@ public Boolean call(Integer i) { }).isEmpty()); } - @Test - public void toArray() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3)); - List list = rdd.toArray(); - Assert.assertEquals(Arrays.asList(1, 2, 3), list); - } - @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 007a71f87cf10..18d1466bb7c30 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -441,66 +441,6 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(prunedData(0) === 10) } - test("mapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val randoms = ones.mapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextDouble * t}.collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(2) === prn42_3) - assert(randoms(5) === prn43_3) - } - - test("flatMapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val randoms = ones.flatMapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => - val random = prng.nextDouble() - Seq(random * t, random * t * 10)}. - collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(5) === prn42_3 * 10) - assert(randoms(11) === prn43_3 * 10) - } - - test("filterWith") { - import java.util.Random - val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val sample = ints.filterWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextInt(3) == 0}. - collect() - val checkSample = { - val prng42 = new Random(42) - val prng43 = new Random(43) - Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) else 0 == prng43.nextInt(3) - } - } - assert(sample.size === checkSample.size) - for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) - } - test("collect large number of empty partitions") { // Regression test for SPARK-4019 assert(sc.makeRDD(0 until 10, 1000).repartition(2001).collect().toSet === (0 until 10).toSet) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index d83d0aee42254..40ebfdde928fe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -99,14 +99,6 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }.collect() assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } - - test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") { - sc = new SparkContext("local", "test") - val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter => - Seq(TaskContext.get().attemptId).iterator - }.collect() - assert(attemptIds.toSet === Set(0, 1, 2, 3)) - } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 480722a5ac182..5e745e0a95769 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.util import java.io.NotSerializableException -import java.util.Random import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} @@ -91,11 +90,6 @@ class ClosureCleanerSuite extends SparkFunSuite { expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) } @@ -269,21 +263,6 @@ private object TestUserClosuresActuallyCleaned { def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = { rdd.mapPartitionsWithIndex { (_, it) => return; it }.count() } - def testFlatMapWith(rdd: RDD[Int]): Unit = { - rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count() - } - def testMapWith(rdd: RDD[Int]): Unit = { - rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count() - } - def testFilterWith(rdd: RDD[Int]): Unit = { - rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count() - } - def testForEachWith(rdd: RDD[Int]): Unit = { - rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return } - } - def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = { - rdd.mapPartitionsWithContext { (_, it) => return; it }.count() - } def testZipPartitions2(rdd: RDD[Int]): Unit = { rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count() } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java index 980a9108af53f..779fac01c4be0 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -68,22 +68,22 @@ public Tuple2 call(LabeledPoint p) { // Precision by threshold JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); + System.out.println("Precision by threshold: " + precision.collect()); // Recall by threshold JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); + System.out.println("Recall by threshold: " + recall.collect()); // F Score by threshold JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); + System.out.println("F1 Score by threshold: " + f1Score.collect()); JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); + System.out.println("F2 Score by threshold: " + f2Score.collect()); // Precision-recall curve JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); + System.out.println("Precision-recall curve: " + prc.collect()); // Thresholds JavaRDD thresholds = precision.map( @@ -96,7 +96,7 @@ public Double call(Tuple2 t) { // ROC Curve JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); + System.out.println("ROC curve: " + roc.collect()); // AUPRC System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 6c90dbec3d531..04dec57b71e16 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -26,8 +26,6 @@ import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.scheduler.InputFormatInfo - /** * Logistic regression based classification. @@ -74,10 +72,7 @@ object SparkHdfsLR { val sparkConf = new SparkConf().setAppName("SparkHdfsLR") val inputPath = args(0) val conf = new Configuration() - val sc = new SparkContext(sparkConf, - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) - )) + val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).cache() val ITERATIONS = args(1).toInt diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index e492582710e12..ddc99d3f90690 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -26,7 +26,6 @@ import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.storage.StorageLevel /** @@ -70,10 +69,7 @@ object SparkTachyonHdfsLR { val inputPath = args(0) val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") val conf = new Configuration() - val sc = new SparkContext(sparkConf, - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) - )) + val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) val ITERATIONS = args(1).toInt diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index cf11504b99451..8c3a40d2412a7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -59,7 +59,58 @@ object MimaExcludes { ) ++ Seq( // SPARK-12481 Remove Hadoop 1.x - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil") + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), + // SPARK-12615 Remove deprecated APIs in core + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.$default$6"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles") ) case v if v.startsWith("1.6") => Seq( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index bff6811bf4164..2edc8f932c4a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -87,7 +87,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-8489: MissingRequirementError during reflection") { + ignore("SPARK-8489: MissingRequirementError during reflection") { // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates // a HiveContext and uses it to create a data frame from an RDD using reflection. // Before the fix in SPARK-8470, this results in a MissingRequirementError because From d202ad2fc24b54de38ad7bfb646bf7703069e9f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jan 2016 12:33:21 -0800 Subject: [PATCH 1441/2089] [SPARK-12439][SQL] Fix toCatalystArray and MapObjects JIRA: https://issues.apache.org/jira/browse/SPARK-12439 In toCatalystArray, we should look at the data type returned by dataTypeFor instead of silentSchemaFor, to determine if the element is native type. An obvious problem is when the element is Option[Int] class, catalsilentSchemaFor will return Int, then we will wrongly recognize the element is native type. There is another problem when using Option as array element. When we encode data like Seq(Some(1), Some(2), None) with encoder, we will use MapObjects to construct an array for it later. But in MapObjects, we don't check if the return value of lambdaFunction is null or not. That causes a bug that the decoded data for Seq(Some(1), Some(2), None) would be Seq(1, 2, -1), instead of Seq(1, 2, null). Author: Liang-Chi Hsieh Closes #10391 from viirya/fix-catalystarray. --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/encoders/RowEncoder.scala | 11 ++++++++--- .../spark/sql/catalyst/expressions/objects.scala | 4 ++-- .../catalyst/encoders/ExpressionEncoderSuite.scala | 3 +++ 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c6aa60b0b4d72..b0efdf3ef4024 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { val externalDataType = dataTypeFor(elementType) val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(catalystType)) { + if (isNativeType(externalDataType)) { NewInstance( classOf[GenericArrayData], input :: Nil, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 6f3d5ba84c9ae..3903086a4c45b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -35,7 +35,8 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpressions = extractorsFor(inputObject, schema) + // We use an If expression to wrap extractorsFor result of StructType + val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue val constructExpression = constructorFor(schema) new ExpressionEncoder[Row]( schema, @@ -129,7 +130,9 @@ object RowEncoder { Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), f.dataType)) } - CreateStruct(convertedFields) + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) } private def externalDataTypeFor(dt: DataType): DataType = dt match { @@ -220,6 +223,8 @@ object RowEncoder { Literal.create(null, externalDataTypeFor(f.dataType)), constructorFor(GetStructField(input, i))) } - CreateExternalRow(convertedFields) + If(IsNull(input), + Literal.create(null, externalDataTypeFor(input.dataType)), + CreateExternalRow(convertedFields)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index fb404c12d5a04..c0c3e6e891669 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -456,10 +456,10 @@ case class MapObjects( ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck - if (${loopVar.isNull}) { + ${genFunction.code} + if (${genFunction.isNull}) { $convertedArray[$loopIndex] = null; } else { - ${genFunction.code} $convertedArray[$loopIndex] = ${genFunction.value}; } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 6453f1c191ba0..98f29e53df9f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { productTest(OptionalData(None, None, None, None, None, None, None, None)) + encodeDecodeTest(Seq(Some(1), None), "Option in array") + encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map") + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) productTest(BoxedData(null, null, null, null, null, null, null)) From 047a31bb1042867b20132b347b1e08feab4562eb Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 5 Jan 2016 13:10:46 -0800 Subject: [PATCH 1442/2089] [SPARK-12617] [PYSPARK] Clean up the leak sockets of Py4J This patch added Py4jCallbackConnectionCleaner to clean the leak sockets of Py4J every 30 seconds. This is a workaround before Py4J fixes the leak issue https://github.com/bartdag/py4j/issues/187 Author: Shixiong Zhu Closes #10579 from zsxwing/SPARK-12617. --- python/pyspark/context.py | 61 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 529d16b480399..5e4aeac330c5a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -54,6 +54,64 @@ } +class Py4jCallbackConnectionCleaner(object): + + """ + A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617. + It will scan all callback connections every 30 seconds and close the dead connections. + """ + + def __init__(self, gateway): + self._gateway = gateway + self._stopped = False + self._timer = None + self._lock = RLock() + + def start(self): + if self._stopped: + return + + def clean_closed_connections(): + from py4j.java_gateway import quiet_close, quiet_shutdown + + callback_server = self._gateway._callback_server + with callback_server.lock: + try: + closed_connections = [] + for connection in callback_server.connections: + if not connection.isAlive(): + quiet_close(connection.input) + quiet_shutdown(connection.socket) + quiet_close(connection.socket) + closed_connections.append(connection) + + for closed_connection in closed_connections: + callback_server.connections.remove(closed_connection) + except Exception: + import traceback + traceback.print_exc() + + self._start_timer(clean_closed_connections) + + self._start_timer(clean_closed_connections) + + def _start_timer(self, f): + from threading import Timer + + with self._lock: + if not self._stopped: + self._timer = Timer(30.0, f) + self._timer.daemon = True + self._timer.start() + + def stop(self): + with self._lock: + self._stopped = True + if self._timer: + self._timer.cancel() + self._timer = None + + class SparkContext(object): """ @@ -68,6 +126,7 @@ class SparkContext(object): _active_spark_context = None _lock = RLock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + _py4j_cleaner = None PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') @@ -244,6 +303,8 @@ def _ensure_initialized(cls, instance=None, gateway=None): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm + _py4j_cleaner = Py4jCallbackConnectionCleaner(SparkContext._gateway) + _py4j_cleaner.start() if instance: if (SparkContext._active_spark_context and From 13a3b636d9425c5713cd1381203ee1b60f71b8c8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 5 Jan 2016 13:31:59 -0800 Subject: [PATCH 1443/2089] [SPARK-6724][MLLIB] Support model save/load for FPGrowthModel Support model save/load for FPGrowthModel Author: Yanbo Liang Closes #9267 from yanboliang/spark-6724. --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 100 +++++++++++++++++- .../spark/mllib/fpm/JavaFPGrowthSuite.java | 40 +++++++ .../spark/mllib/fpm/FPGrowthSuite.scala | 68 ++++++++++++ 3 files changed, 205 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 70ef1ed30c71a..5273ed4d76650 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,19 +17,29 @@ package org.apache.spark.mllib.fpm -import java.{util => ju} import java.lang.{Iterable => JavaIterable} +import java.{util => ju} -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ + +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.3.0") class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( - @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) + extends Saveable with Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. * @param confidence minimal confidence of the rules produced @@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( val associationRules = new AssociationRules(confidence) associationRules.run(freqItemsets) } + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[FPGrowthModel.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + FPGrowthModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object FPGrowthModel extends Loader[FPGrowthModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): FPGrowthModel[_] = { + FPGrowthModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel" + + def save(model: FPGrowthModel[_], path: String): Unit = { + val sc = model.freqItemsets.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqItemsets.first().items(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("items", ArrayType(itemType)), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqItemsets.map { x => + Row(x.items, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): FPGrowthModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqItemsets.select("items").head().get(0) + loadImpl(freqItemsets, sample) + } + + def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { + val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => + val items = x.getAs[Seq[Item]](0).toArray + val freq = x.getLong(1) + new FreqItemset(items, freq) + } + new FPGrowthModel(freqItemsetsRDD) + } + } } /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 154f75d75e4a6..eeeabfe359e68 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.util.Utils; public class JavaFPGrowthSuite implements Serializable { private transient JavaSparkContext sc; @@ -69,4 +71,42 @@ public void runFPGrowth() { long freq = itemset.freq(); } } + + @Test + public void runFPGrowthSaveLoad() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Arrays.asList( + Arrays.asList("r z h k p".split(" ")), + Arrays.asList("z y x w v u t s".split(" ")), + Arrays.asList("s x o n r".split(" ")), + Arrays.asList("x z y m t s q e".split(" ")), + Arrays.asList("z".split(" ")), + Arrays.asList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath); + List> freqItemsets = newModel.freqItemsets().toJavaRDD() + .collect(); + assertEquals(18, freqItemsets.size()); + + for (FPGrowth.FreqItemset itemset: freqItemsets) { + // Test return types. + List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 4a9bfdb348d9f..b9e997c207bc7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { */ assert(model1.freqItemsets.count() === 65) } + + test("model save/load with String type") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load with Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } } From c26d174265f6b4682210fcc406e6603b4f7dc784 Mon Sep 17 00:00:00 2001 From: Nong Date: Tue, 5 Jan 2016 13:47:24 -0800 Subject: [PATCH 1444/2089] [SPARK-12636] [SQL] Update UnsafeRowParquetRecordReader to support reading files directly. As noted in the code, this change is to make this component easier to test in isolation. Author: Nong Closes #10581 from nongli/spark-12636. --- .../SpecificParquetRecordReaderBase.java | 72 +++++++++++++++++- .../parquet/UnsafeRowParquetRecordReader.java | 61 ++++++++------- .../datasources/parquet/ParquetIOSuite.scala | 74 +++++++++++++++++++ 3 files changed, 178 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 842dcb8c93dc2..f8e32d60a489a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.ByteArrayInputStream; +import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -36,6 +37,7 @@ import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; @@ -56,6 +58,8 @@ import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.hadoop.util.ConfigurationUtil; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Types; +import org.apache.spark.sql.types.StructType; /** * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. @@ -69,7 +73,7 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader readSupport; + protected StructType sparkSchema; /** * The total number of rows this RecordReader will eventually read. The sum of the @@ -125,20 +129,80 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont + " in range " + split.getStart() + ", " + split.getEnd()); } } - MessageType fileSchema = footer.getFileMetaData().getSchema(); + this.fileSchema = footer.getFileMetaData().getSchema(); Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); - this.readSupport = getReadSupportInstance( + ReadSupport readSupport = getReadSupportInstance( (Class>) getReadSupportClass(configuration)); ReadSupport.ReadContext readContext = readSupport.init(new InitContext( taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); this.requestedSchema = readContext.getRequestedSchema(); - this.fileSchema = fileSchema; + this.sparkSchema = new CatalystSchemaConverter(configuration).convert(requestedSchema); this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); for (BlockMetaData block : blocks) { this.totalRowCount += block.getRowCount(); } } + /** + * Returns the list of files at 'path' recursively. This skips files that are ignored normally + * by MapReduce. + */ + public static List listDirectory(File path) throws IOException { + List result = new ArrayList(); + if (path.isDirectory()) { + for (File f: path.listFiles()) { + result.addAll(listDirectory(f)); + } + } else { + char c = path.getName().charAt(0); + if (c != '.' && c != '_') { + result.add(path.getAbsolutePath()); + } + } + return result; + } + + /** + * Initializes the reader to read the file at `path` with `columns` projected. If columns is + * null, all the columns are projected. + * + * This is exposed for testing to be able to create this reader without the rest of the Hadoop + * split machinery. It is not intended for general use and those not support all the + * configurations. + */ + protected void initialize(String path, List columns) throws IOException { + Configuration config = new Configuration(); + config.set("spark.sql.parquet.binaryAsString", "false"); + config.set("spark.sql.parquet.int96AsTimestamp", "false"); + config.set("spark.sql.parquet.writeLegacyFormat", "false"); + + this.file = new Path(path); + long length = FileSystem.get(config).getFileStatus(this.file).getLen(); + ParquetMetadata footer = readFooter(config, file, range(0, length)); + + List blocks = footer.getBlocks(); + this.fileSchema = footer.getFileMetaData().getSchema(); + + if (columns == null) { + this.requestedSchema = fileSchema; + } else { + Types.MessageTypeBuilder builder = Types.buildMessage(); + for (String s: columns) { + if (!fileSchema.containsField(s)) { + throw new IOException("Can only project existing columns. Unknown field: " + s + + " File schema:\n" + fileSchema); + } + builder.addFields(fileSchema.getType(s)); + } + this.requestedSchema = builder.named("spark_schema"); + } + this.sparkSchema = new CatalystSchemaConverter(config).convert(requestedSchema); + this.reader = new ParquetFileReader(config, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + @Override public Void getCurrentKey() throws IOException, InterruptedException { return null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 198bfb6d67aee..47818c0939f2a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.List; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; @@ -121,14 +122,42 @@ public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttem public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException { super.initialize(inputSplit, taskAttemptContext); + initializeInternal(); + } + + /** + * Utility API that will read all the data in path. This circumvents the need to create Hadoop + * objects to use this class. `columns` can contain the list of columns to project. + */ + @Override + public void initialize(String path, List columns) throws IOException { + super.initialize(path, columns); + initializeInternal(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + private void initializeInternal() throws IOException { /** * Check that the requested schema is supported. */ - if (requestedSchema.getFieldCount() == 0) { - // TODO: what does this mean? - throw new IOException("Empty request schema not supported."); - } int numVarLenFields = 0; originalTypes = new OriginalType[requestedSchema.getFieldCount()]; for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { @@ -182,25 +211,6 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont } } - @Override - public boolean nextKeyValue() throws IOException, InterruptedException { - if (batchIdx >= numBatched) { - if (!loadBatch()) return false; - } - ++batchIdx; - return true; - } - - @Override - public UnsafeRow getCurrentValue() throws IOException, InterruptedException { - return rows[batchIdx - 1]; - } - - @Override - public float getProgress() throws IOException, InterruptedException { - return (float) rowsReturned / totalRowCount; - } - /** * Decodes a batch of values into `rows`. This function is the hot path. */ @@ -253,10 +263,11 @@ private boolean loadBatch() throws IOException { case INT96: throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); } - numBatched = num; - batchIdx = 0; } + numBatched = num; + batchIdx = 0; + // Update the total row lengths if the schema contained variable length. We did not maintain // this as we populated the columns. if (containsVarLenFields) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b0581e8b35510..7f82cce0a122d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.column.{Encoding, ParquetProperties} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.Utils import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -642,6 +645,77 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } } + + test("UnsafeRowParquetRecordReader - direct path read") { + val data = (0 to 10).map(i => (i, ((i + 'a').toChar.toString))) + withTempPath { dir => + sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, null) + val result = mutable.ArrayBuffer.empty[(Int, String)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue + val v = (row.getInt(0), row.getString(1)) + result += v + } + assert(data == result) + } finally { + reader.close() + } + } + + // Project just one column + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, ("_2" :: Nil).asJava) + val result = mutable.ArrayBuffer.empty[(String)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue + result += row.getString(0) + } + assert(data.map(_._2) == result) + } finally { + reader.close() + } + } + + // Project columns in opposite order + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) + val result = mutable.ArrayBuffer.empty[(String, Int)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue + val v = (row.getString(0), row.getInt(1)) + result += v + } + assert(data.map { x => (x._2, x._1) } == result) + } finally { + reader.close() + } + } + + // Empty projection + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, List[String]().asJava) + var result = 0 + while (reader.nextKeyValue()) { + result += 1 + } + assert(result == data.length) + } finally { + reader.close() + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From 6cfe341ee89baa952929e91d33b9ecbca73a3ea0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 5 Jan 2016 13:48:47 -0800 Subject: [PATCH 1445/2089] [SPARK-12511] [PYSPARK] [STREAMING] Make sure PythonDStream.registerSerializer is called only once There is an issue that Py4J's PythonProxyHandler.finalize blocks forever. (https://github.com/bartdag/py4j/pull/184) Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when calling "registerSerializer". If we call "registerSerializer" twice, the second PythonProxyHandler will override the first one, then the first one will be GCed and trigger "PythonProxyHandler.finalize". To avoid that, we should not call"registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't be GCed. Author: Shixiong Zhu Closes #10514 from zsxwing/SPARK-12511. --- python/pyspark/streaming/context.py | 33 ++++++++++++++----- python/pyspark/streaming/util.py | 3 +- .../spark/streaming/StreamingContext.scala | 12 +++++++ 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 3deed52be0be2..5cc4bbde39958 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -98,8 +98,28 @@ def _ensure_initialized(cls): # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing - cls._transformerSerializer = TransformFunctionSerializer( - SparkContext._active_spark_context, CloudPickleSerializer(), gw) + if cls._transformerSerializer is None: + transformer_serializer = TransformFunctionSerializer() + transformer_serializer.init( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM + # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever. + # (https://github.com/bartdag/py4j/pull/184) + # + # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when + # calling "registerSerializer". If we call "registerSerializer" twice, the second + # PythonProxyHandler will override the first one, then the first one will be GCed and + # trigger "PythonProxyHandler.finalize". To avoid that, we should not call + # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't + # be GCed. + # + # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version. + transformer_serializer.gateway.jvm.PythonDStream.registerSerializer( + transformer_serializer) + cls._transformerSerializer = transformer_serializer + else: + cls._transformerSerializer.init( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) @classmethod def getOrCreate(cls, checkpointPath, setupFunc): @@ -116,16 +136,13 @@ def getOrCreate(cls, checkpointPath, setupFunc): gw = SparkContext._gateway # Check whether valid checkpoint information exists in the given path - if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): + ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath) + if ssc_option.isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - try: - jssc = gw.jvm.JavaStreamingContext(checkpointPath) - except Exception: - print("failed to load StreamingContext from checkpoint", file=sys.stderr) - raise + jssc = gw.jvm.JavaStreamingContext(ssc_option.get()) # If there is already an active instance of Python SparkContext use it, or create a new one if not SparkContext._active_spark_context: diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..e617fc9ce9eec 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -89,11 +89,10 @@ class TransformFunctionSerializer(object): it uses this class to invoke Python, which returns the serialized function as a byte array. """ - def __init__(self, ctx, serializer, gateway=None): + def init(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer self.gateway = gateway or self.ctx._gateway - self.gateway.jvm.PythonDStream.registerSerializer(self) self.failure = None def dumps(self, id): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index c4a10aa2dd3b9..a5ab66697589b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -902,3 +902,15 @@ object StreamingContext extends Logging { result } } + +private class StreamingContextPythonHelper { + + /** + * This is a private method only for Python to implement `getOrCreate`. + */ + def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = { + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false) + checkpointOption.map(new StreamingContext(null, _, null)) + } +} From 1c6cf1a5639bf5111324e44d93a8c6462958750a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 5 Jan 2016 14:24:32 -0800 Subject: [PATCH 1446/2089] [SPARK-12570][ML][DOC] DecisionTreeRegressor: provide variance of prediction: user guide update Update user guide doc for ```DecisionTreeRegressor``` providing variance of prediction. cc jkbradley Author: Yanbo Liang Closes #10594 from yanboliang/spark-12570. --- docs/ml-classification-regression.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d63438bf74c17..8ffc997b4bf5a 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -535,7 +535,9 @@ The main differences between this API and the [original MLlib Decision Tree API] * use of DataFrame metadata to distinguish continuous and categorical features -The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). +The Pipelines API for Decision Trees offers a bit more functionality than the original API. +In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities); +for regression, users can get the biased sample variance of prediction. Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). @@ -605,6 +607,13 @@ All output columns are optional; to exclude an output column, set its correspond Vector of length # classes equal to rawPrediction normalized to a multinomial distribution Classification only + + varianceCol + Double + + The biased sample variance of prediction + Regression only + From 78015a8b7cc316343e302eeed6fe30af9f2961e8 Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Tue, 5 Jan 2016 15:05:04 -0800 Subject: [PATCH 1447/2089] [SPARK-12450][MLLIB] Un-persist broadcasted variables in KMeans SPARK-12450 . Un-persist broadcasted variables in KMeans. Author: RJ Nowling Closes #10415 from rnowling/spark-12450. --- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 2895db7c9061b..e47c4db62955d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -301,6 +301,8 @@ class KMeans private ( contribs.iterator }.reduceByKey(mergeContribs).collectAsMap() + bcActiveCenters.unpersist(blocking = false) + // Update the cluster centers and costs for each active run for ((run, i) <- activeRuns.zipWithIndex) { var changed = false @@ -419,7 +421,10 @@ class KMeans private ( s0 } ) + + bcNewCenters.unpersist(blocking = false) preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => @@ -448,6 +453,9 @@ class KMeans private ( ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() + + bcCenters.unpersist(blocking = false) + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray From ff89975543b153d0d235c0cac615d45b34aa8fe7 Mon Sep 17 00:00:00 2001 From: BrianLondon Date: Tue, 5 Jan 2016 23:15:07 +0000 Subject: [PATCH 1448/2089] [SPARK-12453][STREAMING] Remove explicit dependency on aws-java-sdk Successfully ran kinesis demo on a live, aws hosted kinesis stream against master and 1.6 branches. For reasons I don't entirely understand it required a manual merge to 1.5 which I did as shown here: https://github.com/BrianLondon/spark/commit/075c22e89bc99d5e99be21f40e0d72154a1e23a2 The demo ran successfully on the 1.5 branch as well. According to `mvn dependency:tree` it is still pulling a fairly old version of the aws-java-sdk (1.9.37), but this appears to have fixed the kinesis regression in 1.5.2. Author: BrianLondon Closes #10492 from BrianLondon/remove-only. --- extras/kinesis-asl/pom.xml | 5 ----- .../org/apache/spark/streaming/kinesis/KinesisReceiver.scala | 1 + pom.xml | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 3c5722502e5c1..20e2c5e0ffbee 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -59,11 +59,6 @@ amazon-kinesis-client ${aws.kinesis.client.version} - - com.amazonaws - aws-java-sdk - ${aws.java.sdk.version} - com.amazonaws amazon-kinesis-producer diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 80edda59e1719..abb9b6cd32f1c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -185,6 +185,7 @@ private[kinesis] class KinesisReceiver[T]( workerThread.setName(s"Kinesis Receiver ${streamId}") workerThread.setDaemon(true) workerThread.start() + logInfo(s"Started receiver with workerId $workerId") } diff --git a/pom.xml b/pom.xml index 398fcc92db994..d0ac1eb39aabe 100644 --- a/pom.xml +++ b/pom.xml @@ -152,7 +152,6 @@ 1.7.7 hadoop2 0.7.1 - 1.9.40 1.4.0 0.10.1 From 1537e55604cafafa49a8b7f3ce915f9745392bc0 Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Tue, 5 Jan 2016 15:33:27 -0800 Subject: [PATCH 1449/2089] [SPARK-12041][ML][PYSPARK] Add columnSimilarities to IndexedRowMatrix Add `columnSimilarities` to IndexedRowMatrix for PySpark spark.mllib.linalg. Author: Kai Jiang Closes #10158 from vectorijk/spark-12041. --- python/pyspark/mllib/linalg/distributed.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 0e76050788630..e1f022187d500 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -297,6 +297,20 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") + def columnSimilarities(self): + """ + Compute all cosine similarities between columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows) + >>> cs = mat.columnSimilarities() + >>> print(cs.numCols()) + 3 + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("columnSimilarities") + return CoordinateMatrix(java_coordinate_matrix) + def toRowMatrix(self): """ Convert this matrix to a RowMatrix. From df8bd97520fc67dad95141c5a8cf2e0d5332e693 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 5 Jan 2016 16:48:59 -0800 Subject: [PATCH 1450/2089] [SPARK-3873][SQL] Import ordering fixes. Author: Marcelo Vanzin Closes #10573 from vanzin/SPARK-3873-sql. --- .../main/scala/org/apache/spark/sql/Encoder.scala | 6 +++--- .../spark/sql/catalyst/JavaTypeInference.scala | 11 +++++------ .../spark/sql/catalyst/ScalaReflection.scala | 4 ++-- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/analysis/Catalog.scala | 2 +- .../spark/sql/catalyst/analysis/unresolved.scala | 4 ++-- .../apache/spark/sql/catalyst/dsl/package.scala | 4 ++-- .../sql/catalyst/encoders/ExpressionEncoder.scala | 10 +++++----- .../spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- .../catalyst/expressions/ExpectsInputTypes.scala | 2 +- .../sql/catalyst/expressions/InputFileName.scala | 2 +- .../sql/catalyst/expressions/JoinedRow.scala | 3 +-- .../expressions/MonotonicallyIncreasingID.scala | 4 ++-- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 3 +-- .../sql/catalyst/expressions/SortOrder.scala | 2 +- .../catalyst/expressions/SparkPartitionID.scala | 5 ++--- .../expressions/aggregate/interfaces.scala | 4 ++-- .../expressions/codegen/CodeGenerator.scala | 2 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateSafeProjection.scala | 2 +- .../codegen/GenerateUnsafeRowJoiner.scala | 3 +-- .../expressions/collectionOperations.scala | 2 +- .../catalyst/expressions/complexTypeCreator.scala | 2 +- .../expressions/complexTypeExtractors.scala | 4 ++-- .../expressions/datetimeExpressions.scala | 8 ++++---- .../sql/catalyst/expressions/generators.scala | 2 +- .../spark/sql/catalyst/expressions/literals.scala | 3 ++- .../catalyst/expressions/mathExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 2 +- .../catalyst/expressions/windowExpressions.scala | 2 +- .../catalyst/plans/logical/LocalRelation.scala | 2 +- .../plans/logical/ScriptTransformation.scala | 2 +- .../catalyst/plans/logical/basicOperators.scala | 3 ++- .../sql/catalyst/plans/logical/commands.scala | 2 +- .../catalyst/plans/physical/partitioning.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 15 ++++++++------- .../spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- .../apache/spark/sql/types/AbstractDataType.scala | 2 +- .../org/apache/spark/sql/types/ArrayType.scala | 7 +++---- .../org/apache/spark/sql/types/ByteType.scala | 3 +-- .../org/apache/spark/sql/types/DataType.scala | 3 +-- .../org/apache/spark/sql/types/Decimal.scala | 2 +- .../org/apache/spark/sql/types/DoubleType.scala | 2 +- .../org/apache/spark/sql/types/FloatType.scala | 2 +- .../org/apache/spark/sql/types/IntegerType.scala | 2 +- .../org/apache/spark/sql/types/LongType.scala | 2 +- .../org/apache/spark/sql/types/ShortType.scala | 2 +- .../main/scala/org/apache/spark/sql/Column.scala | 2 +- .../scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- .../org/apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameStatFunctions.scala | 2 +- .../org/apache/spark/sql/DataFrameWriter.scala | 5 ++--- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../scala/org/apache/spark/sql/GroupedData.scala | 5 ++--- .../org/apache/spark/sql/GroupedDataset.scala | 4 ++-- .../scala/org/apache/spark/sql/SQLContext.scala | 6 +++--- .../scala/org/apache/spark/sql/SQLImplicits.scala | 5 ++--- .../org/apache/spark/sql/api/r/SQLUtils.scala | 8 ++++---- .../spark/sql/execution/CoGroupedIterator.scala | 2 +- .../spark/sql/execution/ExchangeCoordinator.scala | 4 ++-- .../apache/spark/sql/execution/ExistingRDD.scala | 5 ++--- .../spark/sql/execution/GroupedIterator.scala | 4 ++-- .../apache/spark/sql/execution/Queryable.scala | 1 + .../apache/spark/sql/execution/SQLExecution.scala | 4 ++-- .../spark/sql/execution/SortPrefixUtils.scala | 3 +-- .../spark/sql/execution/SparkSqlSerializer.scala | 5 ++--- .../spark/sql/execution/SparkStrategies.scala | 4 ++-- .../spark/sql/execution/UnsafeRowSerializer.scala | 2 +- .../org/apache/spark/sql/execution/Window.scala | 2 +- .../execution/aggregate/SortBasedAggregate.scala | 2 +- .../execution/aggregate/TungstenAggregate.scala | 2 +- .../aggregate/TungstenAggregationIterator.scala | 4 ++-- .../aggregate/TypedAggregateExpression.scala | 6 +++--- .../spark/sql/execution/aggregate/udaf.scala | 6 +++--- .../spark/sql/execution/basicOperators.scala | 3 +-- .../sql/execution/columnar/ColumnStats.scala | 2 +- .../columnar/GenerateColumnAccessor.scala | 2 +- .../columnar/InMemoryColumnarTableScan.scala | 2 +- .../columnar/NullableColumnAccessor.scala | 2 +- .../columnar/compression/CompressionScheme.scala | 1 + .../org/apache/spark/sql/execution/commands.scala | 4 ++-- .../sql/execution/datasources/DDLParser.scala | 2 +- .../datasources/DataSourceStrategy.scala | 6 +++--- .../datasources/InsertIntoHadoopFsRelation.scala | 1 + .../datasources/ResolvedDataSource.scala | 4 ++-- .../execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- .../execution/datasources/WriterContainer.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../datasources/jdbc/DefaultSource.scala | 2 +- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 4 ++-- .../execution/datasources/jdbc/JDBCRelation.scala | 2 +- .../execution/datasources/jdbc/JdbcUtils.scala | 4 ++-- .../execution/datasources/json/JSONOptions.scala | 2 +- .../execution/datasources/json/JSONRelation.scala | 4 ++-- .../datasources/json/JacksonGenerator.scala | 5 ++--- .../datasources/json/JacksonParser.scala | 1 + .../datasources/parquet/CatalystReadSupport.scala | 4 ++-- .../parquet/CatalystRowConverter.scala | 6 +++--- .../parquet/CatalystSchemaConverter.scala | 6 +++--- .../parquet/CatalystWriteSupport.scala | 2 +- .../parquet/DirectParquetOutputCommitter.scala | 4 ++-- .../datasources/parquet/ParquetFilters.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 7 +++---- .../spark/sql/execution/datasources/rules.scala | 2 +- .../datasources/text/DefaultSource.scala | 14 +++++++------- .../spark/sql/execution/debug/package.scala | 2 +- .../sql/execution/joins/BroadcastHashJoin.scala | 4 ++-- .../execution/joins/BroadcastHashOuterJoin.scala | 6 +++--- .../execution/joins/BroadcastNestedLoopJoin.scala | 2 +- .../sql/execution/joins/CartesianProduct.scala | 4 ++-- .../sql/execution/joins/HashedRelation.scala | 7 +++---- .../sql/execution/joins/LeftSemiJoinHash.scala | 2 +- .../sql/execution/joins/SortMergeOuterJoin.scala | 4 ++-- .../sql/execution/local/BinaryHashJoinNode.scala | 2 +- .../spark/sql/execution/local/LocalNode.scala | 2 +- .../sql/execution/local/NestedLoopJoinNode.scala | 2 +- .../spark/sql/execution/local/ProjectNode.scala | 2 +- .../spark/sql/execution/metric/SQLMetrics.scala | 2 +- .../org/apache/spark/sql/execution/python.scala | 6 +++--- .../spark/sql/execution/stat/FrequentItems.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 4 ++-- .../spark/sql/execution/ui/SQLListener.scala | 7 +++---- .../apache/spark/sql/expressions/Aggregator.scala | 2 +- .../apache/spark/sql/expressions/WindowSpec.scala | 3 +-- .../org/apache/spark/sql/expressions/udaf.scala | 6 +++--- .../scala/org/apache/spark/sql/functions.scala | 6 +++--- .../org/apache/spark/sql/jdbc/DB2Dialect.scala | 3 +-- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../org/apache/spark/sql/jdbc/MySQLDialect.scala | 3 +-- .../org/apache/spark/sql/sources/interfaces.scala | 8 ++++---- .../apache/spark/sql/test/ExamplePointUDT.scala | 2 +- .../spark/sql/util/QueryExecutionListener.scala | 2 +- .../server/HiveServerServerOptionsProcessor.scala | 2 +- .../sql/hive/thriftserver/HiveThriftServer2.scala | 5 ++--- .../SparkExecuteStatementOperation.scala | 5 ++--- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 6 ++---- .../hive/thriftserver/SparkSQLCLIService.scala | 2 +- .../sql/hive/thriftserver/SparkSQLDriver.scala | 6 +++--- .../spark/sql/hive/thriftserver/SparkSQLEnv.scala | 2 +- .../server/SparkSQLOperationManager.scala | 4 +++- .../hive/thriftserver/ui/ThriftServerPage.scala | 5 +++-- .../thriftserver/ui/ThriftServerSessionPage.scala | 3 ++- .../hive/thriftserver/ui/ThriftServerTab.scala | 2 +- .../spark/sql/hive/ExtendedHiveQlParser.scala | 2 +- .../org/apache/spark/sql/hive/HiveContext.scala | 13 ++++++------- .../apache/spark/sql/hive/HiveInspectors.scala | 9 +++++---- .../spark/sql/hive/HiveMetastoreCatalog.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 7 ++++--- .../org/apache/spark/sql/hive/HiveShim.scala | 4 +--- .../apache/spark/sql/hive/HiveStrategies.scala | 3 +-- .../org/apache/spark/sql/hive/TableReader.scala | 4 ++-- .../spark/sql/hive/client/ClientWrapper.scala | 7 ++++--- .../apache/spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/execution/CreateTableAsSelect.scala | 4 ++-- .../sql/hive/execution/CreateViewAsSelect.scala | 4 ++-- .../hive/execution/DescribeHiveTableCommand.scala | 2 +- .../sql/hive/execution/HiveNativeCommand.scala | 2 +- .../sql/hive/execution/InsertIntoHiveTable.scala | 10 +++++----- .../sql/hive/execution/ScriptTransformation.scala | 6 +++--- .../org/apache/spark/sql/hive/hiveUDFs.scala | 7 +++---- .../spark/sql/hive/hiveWriterContainers.scala | 4 ++-- .../apache/spark/sql/hive/orc/OrcFilters.scala | 2 +- .../apache/spark/sql/hive/orc/OrcRelation.scala | 4 ++-- .../org/apache/spark/sql/hive/test/TestHive.scala | 4 ++-- 164 files changed, 301 insertions(+), 318 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 22b7e1ea0c4cb..b19538a23f19f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql import java.lang.reflect.Modifier import scala.annotation.implicitNotFound -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index ed153d1f88945..b5de60cdb7b76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,21 +17,20 @@ package org.apache.spark.sql.catalyst -import java.beans.{PropertyDescriptor, Introspector} +import java.beans.{Introspector, PropertyDescriptor} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap, List => JList} +import java.util.{Iterator => JIterator, List => JList, Map => JMap} import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - /** * Type-inference utilities for POJOs and Java collections. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b0efdf3ef4024..79f723cf9b8a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 06efcd42aa62e..e362b55d80cd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 3b775c3ca87b8..e8b2fcf819bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} +import org.apache.spark.sql.catalyst.{CatalystConf, EmptyConf, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 64cad6ee787d0..fc0e87aa68ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{errors, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.{TableIdentifier, errors} import org.apache.spark.sql.types.{DataType, StructType} /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8102c93c6f107..5ac1984043d87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,11 +21,11 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6c058463b9cf2..05f746e72b498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -22,15 +22,15 @@ import java.util.concurrent.ConcurrentMap import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} -import org.apache.spark.util.Utils import org.apache.spark.sql.{AnalysisException, Encoder} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} -import org.apache.spark.sql.types.{StructField, ObjectType, StructType} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.util.Utils /** * A factory for constructing encoders that convert objects and primitives to and from the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3903086a4c45b..89d40b3b2c141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 2dcbd4eb15031..04650d85dec03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * An trait that gets mixin to define the expected input types of an expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 50ec1d0cccfba..f33833c3918df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 935c3aa28c999..ed894f6d6e10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 6b5aebc428a23..d0b78e15d99d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types.{LongType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types.{DataType, LongType} /** * Returns monotonically increasing 64-bit integers. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 64d397bf848a9..3a6c909fffce7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 290c128d65b30..3add722da7816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 63ec8c64c14e2..aa3951480c503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types.{IntegerType, DataType} - +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Expression that returns the current partition id of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b616d6953baa8..b47f32d1768b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 440c7d2fc1156..6daa8ee2f42bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,7 +27,7 @@ import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 80c5e41baa927..3353580148799 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} +import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic} /** * A trait that can be used to provide a fallback mode for expression code generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 13634b69457a2..364dbb770f5e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 037ae83d485d0..88b3c5e47f6ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform - abstract class UnsafeRowJoiner { def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 741ad1f3efd8a..7aac2e5e6c1b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 72cc89c8be915..d71bbd63c8e89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 91c275b1aa1c7..9c73239f67ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 311540e33576e..3d65946a1bc65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -20,15 +20,15 @@ package org.apache.spark.sql.catalyst.expressions import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import scala.util.Try + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, + GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -import scala.util.Try - /** * Returns the current date at the start of query evaluation. * All calls of current_date within the same query return the same value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 894a0730d1c2a..e7ef21aa85891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e3573b4947379..672cc9c45e0af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.json4s.JsonAST._ import java.sql.{Date, Timestamp} +import org.json4s.JsonAST._ + import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 9c1a3294def24..002f5929cc26b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.NumberConverter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 814b3c22f8806..387d979484f2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index f1a333b8e56a7..3934e33628bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.aggregate.{NoOp, DeclarativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 572d7d2f0b537..d3b5879777a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet} import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index ccf5291219add..578027da776e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} /** * Transforms the input by forking and running the specified script. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 986062e3971c0..79759b5a37b34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ -import scala.collection.mutable.ArrayBuffer case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index e6621e0f50a9e..47b34d1fa2e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types.StringType /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f6fb31a2af594..1bfe0ecb1e20b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, Unevaluable} import org.apache.spark.sql.types.{DataType, IntegerType} /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index c97dc2d8be7e6..d4be545a35ab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,25 +18,26 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID + import scala.collection.Map import scala.collection.mutable.Stack + import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.util.Utils -import org.apache.spark.storage.StorageLevel import org.apache.spark.rdd.{EmptyRDD, RDD} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.{ScalaReflectionLock, TableIdentifier} import org.apache.spark.sql.catalyst.ScalaReflection._ -import org.apache.spark.sql.catalyst.{TableIdentifier, ScalaReflectionLock} +import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 2b93882919487..f18c052b68e37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index a5ae8bb0e5eb6..90af10f7a6b1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.types import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{runtimeMirror, TypeTag} import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index a001eadcc61d0..6533622492d41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.types -import org.apache.spark.sql.catalyst.util.ArrayData +import scala.math.Ordering + import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi - -import scala.math.Ordering - +import org.apache.spark.sql.catalyst.util.ArrayData object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 2ca427975a1cf..d37130e27ba5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock - /** * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 301b3a70f68f3..136a97e066df7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.types +import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The base type of all Spark SQL data types. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c7a1a2e7469ee..38ce1604b1ede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import java.math.{RoundingMode, MathContext} +import java.math.{MathContext, RoundingMode} import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 2a1bf0938e5a8..e553f65f3c99d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Fractional, Numeric} +import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 08e22252aef82..ae9aa9eefaf2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.types +import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral -import scala.math.{Ordering, Fractional, Numeric} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index a2c6e19b05b3c..38a7b8ee52651 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 2b3adf6ade83b..88aff0c87755c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index a13119e659064..486cf585284df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 71fa970907f6f..e8c61d6e01dc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.SqlParser._ import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c42192c83de89..7cf2818590a78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -30,15 +30,15 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.datasources.json.JacksonGenerator +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d4df913e472a1..d948e4894253c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -30,10 +30,10 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.SqlParser +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 69c984717526d..e66aa5f947181 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.{util => ju, lang => jl} +import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 9afa6856907a7..e2d72a549e6b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -24,12 +24,11 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} -import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.sources.HadoopFsRelation - /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a763a951440cc..42f01e9359c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -23,9 +23,9 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 2aa82f1496ae5..c74ef2c03541e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,13 +21,12 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Aggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} import org.apache.spark.sql.types.NumericType - /** * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 4bf0b256fcb4f..a819ddceb1b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,8 +21,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.Aggregator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3a875c4f9a284..e827427c19e25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,11 +26,14 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException @@ -38,16 +41,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6735d02954b8d..ab414799f1a42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -21,11 +21,10 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 67da7b808bf59..d912aeb70d517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.util.matching.Regex + import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SaveMode, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, GenericRowWithSchema, NamedExpression} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} - -import scala.util.matching.Regex private[r] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala index 663bc904f39c8..33475bea9af43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala index 827fdd278460a..07015e5a5aaef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution -import java.util.{Map => JMap, HashMap => JHashMap} +import java.util.{HashMap => JHashMap, Map => JMap} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SimpleFutureAction, ShuffleDependency, MapOutputStatistics} +import org.apache.spark.{Logging, MapOutputStatistics, ShuffleDependency, SimpleFutureAction} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index fc508bfafa1c0..569a21feaa8a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, AttributeSet, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{Row, SQLContext} - object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala index 6a8850129f1ac..ef84992e6979c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering} -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, GenerateUnsafeProjection} object GroupedIterator { def apply( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 3f391fd9a9ddb..38263af0f7e30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils + import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 34971986261c2..0a11b16d0ed35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, - SparkListenerSQLExecutionEnd} +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, + SparkListenerSQLExecutionStart} import org.apache.spark.util.Utils private[sql] object SQLExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index e17b50edc62dd..909f124d2c9cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -21,8 +21,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} - +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparator, PrefixComparators} object SortPrefixUtils { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 45a8e03248267..c590f7c6c3e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -22,16 +22,15 @@ import java.util.{HashMap => JavaHashMap} import scala.reflect.ClassTag -import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Input, Output} import com.twitter.chill.ResourcePool +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair -import org.apache.spark.{SparkConf, SparkEnv} - private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 183d9b65023b9..6cf75bc17039c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.{execution, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -24,10 +25,9 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.{Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 4730647c4be9c..a23ebec95333b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import com.google.common.io.ByteStreams -import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index b79d93d7ca4c9..89b17c82459f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType -import org.apache.spark.rdd.RDD /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 01d076678f041..1d56592c40b96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.execution.metric.SQLMetrics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 999ebb768af50..a9cf04388d2e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType case class TungstenAggregate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 582fdbe547061..41799c596b6d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index a9719128a626e..1df38f7ff59cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -21,11 +21,11 @@ import scala.language.existentials import org.apache.spark.Logging import org.apache.spark.sql.Encoder -import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index c0d00104e8bfd..5a19920add717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedMutableProjection, MutableRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, ImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index af7237ef25886..95bef683238a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.{HashPartitioner, SparkEnv} import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow @@ -28,8 +29,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler -import org.apache.spark.{HashPartitioner, SparkEnv} - case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index c52ee9ffd6d2a..5d4476989a369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index b208425ffc3c3..55e2c0ed70029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator, UnsafeRowWriter} import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index d80912309bab9..9084b74d1a741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Accumulable, Accumulator, Accumulators} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -30,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Accumulable, Accumulator, Accumulators} private[sql] object InMemoryRelation { def apply( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 8d99546924de1..2465633162c4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.MutableRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 920381f9c63d0..b90d00b15b180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 6ec4cadeeb072..2e2ce88211a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,13 +21,13 @@ import java.util.NoSuchElementException import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index f22508b21090c..48eff62b297f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -22,7 +22,7 @@ import scala.util.matching.Regex import org.apache.spark.Logging import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.DataTypeParser diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3741a9cb32fd4..1d6290e027f3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,22 +19,22 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 758bcd706a8c3..38152d0cf1a48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -22,6 +22,7 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat + import org.apache.spark._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index e02ee6cd6b907..0ca0a38f712ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -21,14 +21,14 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} -import scala.util.{Success, Failure, Try} +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 12f8783f846d6..d45d2db62f3a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -27,6 +27,8 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} + +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod @@ -34,8 +36,6 @@ import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} -import org.apache.spark.{Partition => SparkPartition, _} - private[spark] class SqlNewHadoopPartition( rddId: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 8b0b647744559..9f23d531072aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{StructType, StringType} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e759c011e75d2..aed5d0dcf2d8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index 5ae6cff9b5584..4dcd261f5cbe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} class DefaultSource extends RelationProvider with DataSourceRegister { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index cb8d9504af997..d867e144e517f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -24,15 +24,15 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} /** * Data corresponding to one partition of a JDBCRDD. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 375266f1be42c..1d40d23edc11a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,9 +23,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Instructions on how to partition the table among workers. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 10f650693f288..69ba84646f088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -25,9 +25,9 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} -import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} +import org.apache.spark.sql.types._ /** * Util functions for JDBC tables. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index f805c0092585c..aee9cf2bdbcaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.json -import com.fasterxml.jackson.core.{JsonParser, JsonFactory} +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} /** * Options for the JSON data source. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 54a8552134c82..8bf538178b5d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -24,19 +24,19 @@ import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 3f34520afe6b6..078e1cbec577e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.execution.datasources.json -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils} - import scala.collection.Map import com.fasterxml.jackson.core._ import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ private[sql] object JacksonGenerator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 55a1c24e9e000..2e3fe3da15389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream + import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.core._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index e5d8e6088b395..c3b7483e80ab5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -22,11 +22,11 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.io.api.RecordMaterializer -import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema._ +import org.apache.parquet.schema.Type.Repetition import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 8851bc23cd050..42d89f4bf81d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -25,14 +25,14 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{DOUBLE, INT32, INT64, BINARY, FIXED_LEN_BYTE_ARRAY} import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 5f9f9083098a7..fb97a03df60f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema._ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ -import org.apache.parquet.schema._ -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} -import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.types._ /** * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 6862dea5e6c3b..e78afa5ae6d0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, minBytesForPrecision} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index e54f51e3830f3..ecadb9e7c6ac6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.Log -import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.hadoop.util.ContextUtil /** * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ac9b65b66d986..e9b734b0abf50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable -import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.OriginalType import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 8e1fe8090cc1e..45f1dff96db08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Logger => JLogger} import java.util.{List => JList} +import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -32,14 +32,15 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl +import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType -import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ @@ -49,8 +50,6 @@ import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} - private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1a8e7ab202dc2..50ecbd35760d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} /** * Try to replaces [[UnresolvedRelation]]s with [[ResolvedDataSource]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index fe69c72d28cb0..bd2d17c0189ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.datasources.text import com.google.common.base.Objects -import org.apache.hadoop.fs.{Path, FileStatus} -import org.apache.hadoop.io.{NullWritable, Text, LongWritable} -import org.apache.hadoop.mapred.{TextInputFormat, JobConf} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 74892e4e13fa4..dbb6b654b1a38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import scala.collection.mutable.HashSet +import org.apache.spark.{Accumulator, AccumulatorParam, Logging} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.{Accumulator, AccumulatorParam, Logging} /** * Contains methods for debugging query execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 1d381e2eaef38..0a818cc2c2a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils -import org.apache.spark.{InternalAccumulator, TaskContext} /** * Performs an inner hash join of two child relations. When the output RDD of this operator is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab81bd7b3fc04..6c7fa2eee5bfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.{InternalAccumulator, TaskContext} /** * Performs a outer hash join for two child relations. When the output RDD of this operator is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 54275c2cc1134..e55f8694781a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index d9fa4c6b83798..93d32e1fb93ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index c6f56cfaed22c..ee7a1bdc343c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,8 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -30,10 +31,8 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} +import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} import org.apache.spark.util.collection.CompactBuffer -import org.apache.spark.{SparkConf, SparkEnv} - /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index bf3b05be981fb..25b3b5ca2377f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index c3a2bfc59c7a4..ed41ad2a005eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -22,10 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index 3dcef94095647..59345046da495 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} /** * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index e46217050bad5..a0dfe996ccd55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, Row} +import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index b7fa0c0202221..b93bde58a55e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.util.collection.{BitSet, CompactBuffer} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index 11529d6dd9b83..bd73b08263f87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, UnsafeProjection} case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 6c0f6f8a52dc5..52735c9d7f8c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.metric -import org.apache.spark.util.Utils import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} +import org.apache.spark.util.Utils /** * Create a layer for specialized metric. We cannot add `@specialized` to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index efb4b09c16348..41e35fd724cde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ -import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} -import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext} +import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index db463029aedf7..a191759813de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging +import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, Column, DataFrame} private[sql] object FrequentItems extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 725d6821bf11c..7d701949afcf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Row, Column, DataFrame} -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 622e01c46e1ad..cd56136927088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable +import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} -import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue} import org.apache.spark.ui.SparkUI @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 65117d5824755..6eea92451734e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.expressions +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} /** * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 9397fb84105a9..3921147857a07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.{catalyst, Column} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ - /** * :: Experimental :: * A window specification that defines the partitioning, ordering, and frame boundaries. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 11dbf391cff98..8b355befc34a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ -import org.apache.spark.annotation.Experimental /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1c96f647b6345..592d79df3109a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index b1cb0e55026be..f12b6ca9d6ad2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.sql.types.{BooleanType, StringType, DataType} - +import org.apache.spark.sql.types.{BooleanType, DataType, StringType} private object DB2Dialect extends JdbcDialect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 13db141f27db6..ca2d909e2cccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.jdbc import java.sql.Connection -import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types._ /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index da413ed1f08b5..e1717049f383d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types -import org.apache.spark.sql.types.{BooleanType, LongType, DataType, MetadataBuilder} - +import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder} private case object MySQLDialect extends JdbcDialect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index d6c5d1435702d..f4c7f0a269323 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,21 +21,21 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} +import org.apache.spark.sql.execution.datasources.{Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 8d4854b698ed7..20a17ba82be9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index ac432e2baa3c0..e6f8779929d55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.util import java.util.concurrent.locks.ReentrantReadWriteLock + import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal @@ -25,7 +26,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.execution.QueryExecution - /** * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala index 2228f651e2387..60bb4dc5e77b6 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -16,7 +16,7 @@ */ package org.apache.hive.service.server -import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} +import org.apache.hive.service.server.HiveServer2.{ServerOptionsProcessor, StartOptionExecutor} /** * Class to upgrade a package-private class to public, and diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3e3f0382f6a3b..66eaa3ebcd737 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -27,8 +27,9 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} +import org.apache.hive.service.server.{HiveServer2, HiveServerServerOptionsProcessor} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} import org.apache.spark.sql.SQLConf @@ -36,8 +37,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{Logging, SparkContext} - /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e022ee86a763a..cd2167c4ecb18 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} +import java.util.{Arrays, Map => JMap, UUID} import java.util.concurrent.RejectedExecutionException -import java.util.{Arrays, UUID, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} @@ -33,11 +33,10 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} - private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 8e7aa75bc3b2c..03bc830df2034 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -20,13 +20,10 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ import java.util.{ArrayList => JArrayList, Locale} -import org.apache.spark.sql.AnalysisException - import scala.collection.JavaConverters._ import jline.console.ConsoleReader import jline.console.history.FileHistory - import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -35,11 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, CommandProcessor, CommandProcessorFactory, SetProcessor} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 5ad8c54f296d5..6fe57554cf580 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,11 +27,11 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ import org.apache.hive.service.server.HiveServer2 -import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index f1ec7238520ac..4278aa30fbbd7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{Arrays, ArrayList => JArrayList, List => JList} -import org.apache.log4j.LogManager -import org.apache.spark.sql.AnalysisException +import java.util.{ArrayList => JArrayList, Arrays, List => JList} import scala.collection.JavaConverters._ @@ -27,8 +25,10 @@ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse +import org.apache.log4j.LogManager import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index bacf6cc458fd5..ca25d23c3e37c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -21,9 +21,9 @@ import java.io.PrintStream import scala.collection.JavaConverters._ +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 476651a559d2c..9954d3436d37c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.thriftserver.server import java.util.{Map => JMap} + import scala.collection.mutable.Map import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession + import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, ReflectionUtils} +import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index e990bd06011ff..3719da4925ccb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.Logging -import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{SessionInfo, ExecutionState, ExecutionInfo} -import org.apache.spark.ui.UIUtils._ +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState, SessionInfo} import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ /** Page for Spark Web UI that shows statistics of a thrift server */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index af16cb31df187..27d1c8bab4d9f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.Logging import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} -import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ /** Page for Spark Web UI that shows statistics of a streaming job */ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 4eabeaa6735e6..1dc7d79436d72 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.thriftserver.ui +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.sql.hive.thriftserver.{HiveThriftServer2, SparkSQLEnv} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab._ import org.apache.spark.ui.{SparkUI, SparkUITab} -import org.apache.spark.{SparkContext, Logging, SparkException} /** * Spark Web UI tab that shows statistics of a streaming job. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 7f8449cdc282d..395c8bff53f47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -21,7 +21,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} +import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand} /** * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 86769f1a0d412..cbaf00603e189 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -37,25 +37,24 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.hadoop.util.VersionInfo +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} -import org.apache.spark.sql.execution.datasources.{ResolveDataSource, DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} -import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck, ResolveDataSource} +import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkContext} - /** * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 95b57d6ad124a..7a260e72eb459 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,18 +19,19 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ +import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f099e146d1e37..1616c4595221d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -31,21 +31,21 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.execution.{datasources, FileRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} -import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} private[hive] case class HiveSerDe( inputFormat: Option[String] = None, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index cbfe09b31d380..31d82eb20f6e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -25,21 +25,23 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse.SemanticException import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.DescribeCommand @@ -48,7 +50,6 @@ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} import org.apache.spark.sql.parser._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index f0697613cff3b..b8cced0b80969 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID -import org.apache.avro.Schema - import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} - +import org.apache.avro.Schema import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index e8376083c0d39..0b4f5a0fd6ea6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -22,11 +22,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ - private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 70ee02823eeba..fd465e80a87e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -23,11 +23,11 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable, Hive, HiveUtils, HiveStorageHandler} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveStorageHandler, HiveUtils, Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index d3da22aa0ae5c..ce7a305d437a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -25,15 +25,16 @@ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} +import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} +import org.apache.hadoop.hive.ql.{metadata, Driver} import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.{SparkConf, SparkException, Logging} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 346840079b853..ca636b0265d41 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.serde.serdeConstants import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegralType} +import org.apache.spark.sql.types.{IntegralType, StringType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index e72a60b42e653..4c0aae6c04bd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** * Create table and insert the query result into it. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 2c81115ee4fed..6e288afbb4d2d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 441b6b6033e1f..dfa5a982b1584 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -21,10 +21,10 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation -import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 41b645b2c9c93..381fb61160ac4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StringType -import org.apache.spark.sql.{Row, SQLContext} private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 44dc68e6ba47f..b02ace786c66c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -23,22 +23,22 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection} import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types.DataType import org.apache.spark.util.SerializableJobConf -import org.apache.spark.{SparkException, TaskContext} private[hive] case class InsertIntoHiveTable( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 6ccd4178190cd..5e6641693798f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -31,16 +31,16 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} /** * Transforms the input by forking and running the specified script. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a1787fc92d6d2..b1a6d0ab7df3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -21,15 +21,14 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ import scala.util.Try -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 777e7857d2db2..22182ba00986f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -23,17 +23,17 @@ import java.util.Date import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 165210f9ff301..99a232f74fac2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 84ef12a68e1ba..3538d642d5231 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -28,19 +28,19 @@ import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspect import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 66d5f20d88421..d26cb48479066 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -29,7 +29,8 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.{SQLContext, SQLConf} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -39,7 +40,6 @@ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{SparkConf, SparkContext} // SPARK-3729: Test key required to check for initialization errors with config. object TestHive From 0d42292f6a2dbe626e8f6a50e6c61dd79533f235 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 5 Jan 2016 17:48:05 -0800 Subject: [PATCH 1451/2089] [SPARK-12504][SQL] Masking credentials in the sql plan explain output for JDBC data sources. This fix masks JDBC credentials in the explain output. URL patterns to specify credential seems to be vary between different databases. Added a new method to dialect to mask the credentials according to the database specific URL pattern. While adding tests I noticed explain output includes array variable for partitions ([Lorg.apache.spark.Partition;3ff74546,). Modified the code to include the first, and last partition information. Author: sureshthalamati Closes #10452 from sureshthalamati/mask_jdbc_credentials_spark-12504. --- .../datasources/jdbc/JDBCRelation.scala | 5 +++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 1d40d23edc11a..572be823ca87c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -108,4 +108,9 @@ private[sql] case class JDBCRelation( .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) } + + override def toString: String = { + // credentials should not be included in the plan output, table information is sufficient. + s"JDBCRelation(${table})" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index dae72e8acb5a7..73e548e00f588 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -27,6 +27,8 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -551,4 +553,24 @@ class JDBCSuite extends SparkFunSuite assert(rows(0).getAs[java.sql.Timestamp](2) === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) } + + test("test credentials in the properties are not in plan output") { + val df = sql("SELECT * FROM parts") + val explain = ExplainCommand(df.queryExecution.logical, extended = true) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + r => assert(!List("testPass", "testUser").exists(r.toString.contains)) + } + // test the JdbcRelation toString output + df.queryExecution.analyzed.collect { + case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)") + } + } + + test("test credentials in the connection url are not in the plan output") { + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val explain = ExplainCommand(df.queryExecution.logical, extended = true) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + r => assert(!List("testPass", "testUser").exists(r.toString.contains)) + } + } } From 70fe6ce52f26904aa53bd20409db69b52bccf315 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 5 Jan 2016 18:46:52 -0800 Subject: [PATCH 1452/2089] [SPARK-12659] fix NPE in UnsafeExternalSorter (used by cartesian product) Cartesian product use UnsafeExternalSorter without comparator to do spilling, it will NPE if spilling happens. This bug also hitted by #10605 cc JoshRosen Author: Davies Liu Closes #10606 from davies/fix_spilling. --- .../unsafe/sort/UnsafeExternalSorter.java | 7 +++-- .../unsafe/sort/UnsafeInMemorySorter.java | 17 +++++----- .../sort/UnsafeExternalSorterSuite.java | 31 +++++++++++++++++++ project/MimaExcludes.scala | 1 + 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 79d74b23ceaef..77d0b70bb892e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -400,6 +400,7 @@ public void merge(UnsafeExternalSorter other) throws IOException { * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(recordComparator != null); if (spillWriters.isEmpty()) { assert(inMemSorter != null); readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); @@ -531,18 +532,20 @@ public long getKeyPrefix() { * * It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. + * + * TODO: support forced spilling */ public UnsafeSorterIterator getIterator() throws IOException { if (spillWriters.isEmpty()) { assert(inMemSorter != null); - return inMemSorter.getIterator(); + return inMemSorter.getSortedIterator(); } else { LinkedList queue = new LinkedList<>(); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { queue.add(spillWriter.getReader(blockManager)); } if (inMemSorter != null) { - queue.add(inMemSorter.getIterator()); + queue.add(inMemSorter.getSortedIterator()); } return new ChainedIterator(queue); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c16cbce9a0f6c..b7ab45675ee1e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -99,7 +99,11 @@ public UnsafeInMemorySorter( this.consumer = consumer; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); - this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + if (recordComparator != null) { + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } else { + this.sortComparator = null; + } this.array = array; } @@ -223,14 +227,9 @@ public void loadNext() { * {@code next()} will return the same mutable object. */ public SortedIterator getSortedIterator() { - sorter.sort(array, 0, pos / 2, sortComparator); - return new SortedIterator(pos / 2); - } - - /** - * Returns an iterator over record pointers in original order (inserted). - */ - public SortedIterator getIterator() { + if (sortComparator != null) { + sorter.sort(array, 0, pos / 2, sortComparator); + } return new SortedIterator(pos / 2); } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index e0ee281e98b71..32f5a1a7e6c5a 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -369,6 +369,37 @@ public void forcedSpillingWithNotReadIterator() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void forcedSpillingWithoutComparator() throws Exception { + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + taskContext, + null, + null, + /* initialSize */ 1024, + pageSizeBytes); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + int batch = n / 4; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); + if (i % batch == batch - 1) { + sorter.spill(); + } + } + UnsafeSorterIterator iter = sorter.getIterator(); + for (int i = 0; i < n; i++) { + iter.hasNext(); + iter.loadNext(); + assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8c3a40d2412a7..940fedfa2ab60 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -40,6 +40,7 @@ object MimaExcludes { excludePackage("org.apache.spark.rpc"), excludePackage("org.spark-project.jetty"), excludePackage("org.apache.spark.unused"), + excludePackage("org.apache.spark.util.collection.unsafe"), excludePackage("org.apache.spark.sql.catalyst"), excludePackage("org.apache.spark.sql.execution"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), From 7a375bb87a8df56d9dde0c484e725e5c497a9876 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 5 Jan 2016 19:02:25 -0800 Subject: [PATCH 1453/2089] [SPARK-3873][CORE] Import ordering fixes. Author: Marcelo Vanzin Closes #10578 from vanzin/SPARK-3873-core. --- .../org/apache/spark/ContextCleaner.scala | 2 +- .../spark/ExecutorAllocationManager.scala | 4 ++-- .../org/apache/spark/HeartbeatReceiver.scala | 4 ++-- .../scala/org/apache/spark/HttpServer.scala | 7 +++---- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/Partitioner.scala | 4 ++-- .../scala/org/apache/spark/SparkConf.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 18 ++++++++---------- .../main/scala/org/apache/spark/SparkEnv.scala | 6 +++--- .../org/apache/spark/SparkHadoopWriter.scala | 2 +- .../apache/spark/api/java/JavaPairRDD.scala | 2 +- .../apache/spark/api/java/JavaRDDLike.scala | 2 +- .../spark/api/java/JavaSparkContext.scala | 2 +- .../api/java/JavaSparkStatusTracker.scala | 2 +- .../org/apache/spark/api/java/JavaUtils.scala | 6 +++--- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../apache/spark/api/python/PythonUtils.scala | 2 +- .../spark/api/python/PythonWorkerFactory.scala | 2 +- .../apache/spark/api/python/SerDeUtil.scala | 5 ++--- .../WriteInputFormatTestDataGenerator.scala | 3 +-- .../apache/spark/api/r/RBackendHandler.scala | 2 +- .../scala/org/apache/spark/api/r/RRDD.scala | 3 +-- .../scala/org/apache/spark/api/r/SerDe.scala | 2 +- .../org/apache/spark/broadcast/Broadcast.scala | 6 +++--- .../spark/broadcast/BroadcastManager.scala | 3 +-- .../scala/org/apache/spark/deploy/Client.scala | 4 ++-- .../apache/spark/deploy/ClientArguments.scala | 1 + .../spark/deploy/FaultToleranceTest.scala | 2 +- .../spark/deploy/LocalSparkCluster.scala | 4 ++-- .../org/apache/spark/deploy/PythonRunner.scala | 2 +- .../apache/spark/deploy/RPackageUtils.scala | 2 +- .../org/apache/spark/deploy/RRunner.scala | 2 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 4 ++-- .../org/apache/spark/deploy/SparkSubmit.scala | 4 ++-- .../apache/spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/client/TestClient.scala | 4 ++-- .../deploy/history/FsHistoryProvider.scala | 2 +- .../spark/deploy/history/HistoryPage.scala | 2 +- .../spark/deploy/history/HistoryServer.scala | 1 + .../apache/spark/deploy/master/Master.scala | 4 ++-- .../deploy/master/PersistenceEngine.scala | 4 ++-- .../master/ZooKeeperLeaderElectionAgent.scala | 5 +++-- .../deploy/master/ui/ApplicationPage.scala | 2 +- .../spark/deploy/master/ui/MasterPage.scala | 4 ++-- .../spark/deploy/master/ui/MasterWebUI.scala | 2 +- .../deploy/mesos/MesosClusterDispatcher.scala | 2 +- .../spark/deploy/mesos/ui/DriverPage.scala | 3 +-- .../deploy/mesos/ui/MesosClusterPage.scala | 1 + .../spark/deploy/mesos/ui/MesosClusterUI.scala | 4 ++-- .../deploy/rest/RestSubmissionClient.scala | 4 ++-- .../deploy/rest/RestSubmissionServer.scala | 5 +++-- .../deploy/rest/StandaloneRestServer.scala | 4 ++-- .../deploy/rest/mesos/MesosRestServer.scala | 3 +-- .../spark/deploy/worker/DriverRunner.scala | 4 ++-- .../spark/deploy/worker/ExecutorRunner.scala | 5 +++-- .../apache/spark/deploy/worker/Worker.scala | 4 ++-- .../spark/deploy/worker/ui/LogPage.scala | 4 ++-- .../spark/deploy/worker/ui/WorkerPage.scala | 7 ++++--- .../CoarseGrainedExecutorBackend.scala | 3 ++- .../spark/executor/MesosExecutorBackend.scala | 2 +- .../input/FixedLengthBinaryInputFormat.scala | 2 +- .../input/FixedLengthBinaryRecordReader.scala | 2 +- .../spark/input/PortableDataStream.scala | 2 +- .../input/WholeTextFileRecordReader.scala | 5 ++--- .../spark/mapred/SparkHadoopMapRedUtil.scala | 2 +- .../apache/spark/memory/MemoryManager.scala | 2 +- .../spark/memory/StorageMemoryPool.scala | 4 ++-- .../spark/memory/UnifiedMemoryManager.scala | 2 +- .../apache/spark/metrics/MetricsConfig.scala | 2 +- .../apache/spark/metrics/MetricsSystem.scala | 3 +-- .../spark/metrics/sink/GraphiteSink.scala | 2 +- .../apache/spark/metrics/sink/JmxSink.scala | 1 + .../spark/metrics/sink/MetricsServlet.scala | 3 +-- .../apache/spark/metrics/sink/Slf4jSink.scala | 2 +- .../spark/network/BlockTransferService.scala | 8 ++++---- .../netty/NettyBlockTransferService.scala | 4 ++-- .../network/netty/SparkTransportConf.scala | 2 +- .../apache/spark/partial/StudentTCacher.scala | 2 +- .../apache/spark/partial/SumEvaluator.scala | 2 +- .../org/apache/spark/rdd/AsyncRDDActions.scala | 4 ++-- .../org/apache/spark/rdd/BinaryFileRDD.scala | 2 +- .../scala/org/apache/spark/rdd/BlockRDD.scala | 1 - .../org/apache/spark/rdd/CoGroupedRDD.scala | 5 ++--- .../apache/spark/rdd/DoubleRDDFunctions.scala | 2 +- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 10 +++++----- .../scala/org/apache/spark/rdd/JdbcRDD.scala | 6 +++--- .../org/apache/spark/rdd/NewHadoopRDD.scala | 6 +++--- .../apache/spark/rdd/PairRDDFunctions.scala | 4 ++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- .../spark/rdd/SequenceFileRDDFunctions.scala | 2 +- .../org/apache/spark/rpc/RpcEndpointRef.scala | 2 +- .../org/apache/spark/rpc/RpcTimeout.scala | 3 +-- .../apache/spark/rpc/netty/Dispatcher.scala | 4 ++-- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../spark/scheduler/EventLoggingListener.scala | 2 +- .../scheduler/OutputCommitCoordinator.scala | 2 +- .../apache/spark/scheduler/ResultTask.scala | 3 +-- .../apache/spark/scheduler/SparkListener.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 5 ++--- .../apache/spark/scheduler/TaskScheduler.scala | 2 +- .../spark/scheduler/TaskSchedulerImpl.scala | 6 +++--- .../spark/scheduler/TaskSetManager.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 4 ++-- .../spark/scheduler/cluster/ExecutorData.scala | 2 +- .../cluster/SimrSchedulerBackend.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 2 +- .../mesos/CoarseMesosSchedulerBackend.scala | 6 +++--- .../cluster/mesos/MesosClusterScheduler.scala | 8 ++++---- .../cluster/mesos/MesosSchedulerBackend.scala | 1 + .../cluster/mesos/MesosSchedulerUtils.scala | 6 +++--- .../serializer/GenericAvroSerializer.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 6 +++--- .../apache/spark/serializer/Serializer.scala | 2 +- .../spark/shuffle/BaseShuffleHandle.scala | 2 +- .../spark/shuffle/FetchFailedException.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../spark/shuffle/ShuffleBlockResolver.scala | 1 + .../apache/spark/shuffle/ShuffleManager.scala | 2 +- .../status/api/v1/AllStagesResource.scala | 2 +- .../status/api/v1/OneApplicationResource.scala | 2 +- .../spark/status/api/v1/OneJobResource.scala | 2 +- .../spark/status/api/v1/OneRDDResource.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../spark/storage/BlockManagerMaster.scala | 4 ++-- .../storage/BlockManagerMasterEndpoint.scala | 2 +- .../spark/storage/DiskBlockManager.scala | 4 ++-- .../spark/storage/DiskBlockObjectWriter.scala | 4 ++-- .../org/apache/spark/storage/DiskStore.scala | 2 +- .../org/apache/spark/storage/RDDInfo.scala | 2 +- .../spark/storage/TachyonBlockManager.scala | 2 -- .../scala/org/apache/spark/ui/SparkUI.scala | 10 +++++----- .../main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../spark/ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../apache/spark/ui/exec/ExecutorsTab.scala | 2 +- .../org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +- .../apache/spark/ui/jobs/AllStagesPage.scala | 2 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 2 +- .../org/apache/spark/ui/jobs/JobPage.scala | 5 ++--- .../org/apache/spark/ui/jobs/PoolPage.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 2 +- .../org/apache/spark/ui/jobs/UIData.scala | 6 +++--- .../spark/ui/scope/RDDOperationGraph.scala | 2 +- .../apache/spark/ui/storage/StorageTab.scala | 2 +- .../spark/util/AsynchronousListenerBus.scala | 1 + .../org/apache/spark/util/EventLoop.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../spark/util/MutableURLClassLoader.scala | 2 +- .../spark/util/ShutdownHookManager.scala | 1 + .../org/apache/spark/util/SizeEstimator.scala | 4 ++-- .../apache/spark/util/TimeStampedHashMap.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../spark/util/collection/ExternalSorter.scala | 4 ++-- .../spark/util/collection/OpenHashSet.scala | 1 + .../spark/util/collection/Spillable.scala | 2 +- .../util/logging/RollingFileAppender.scala | 6 ++++-- .../spark/util/random/RandomSampler.scala | 2 +- 158 files changed, 246 insertions(+), 250 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bc732535fed87..4628093b91cb8 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,7 +18,7 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{TimeUnit, ScheduledExecutorService} +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 4926cafaed1b0..3431fc13dcb4e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -24,9 +24,9 @@ import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.scheduler._ import org.apache.spark.metrics.source.Source -import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 1f1f0b75de5f1..e03977828b86d 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.concurrent.Future import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index faa3ef3d7561d..3c808420c8b29 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,18 +19,17 @@ package org.apache.spark import java.io.File -import org.eclipse.jetty.server.ssl.SslSocketConnector -import org.eclipse.jetty.util.security.{Constraint, Password} -import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} +import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector +import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} +import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils - /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 72355cdfa68b3..8670f705cdb7e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ef9a2dab1c106..a7c2790c8360b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -21,13 +21,13 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import scala.util.hashing.byteswap32 import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} -import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ff2c4c34c0ca7..340e1f7824d1e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet -import org.apache.avro.{SchemaNormalization, Schema} +import org.apache.avro.{Schema, SchemaNormalization} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 87301202dea27..4a99c0b081d6a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -17,20 +17,19 @@ package org.apache.spark -import scala.language.implicitConversions - import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} import java.util.UUID.randomUUID import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.HashMap -import scala.reflect.{ClassTag, classTag} +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal import org.apache.commons.lang.SerializationUtils @@ -42,27 +41,26 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} - import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, - FixedLengthBinaryInputFormat} +import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, + WholeTextFileInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, - SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SimrSchedulerBackend, + SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump -import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} +import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b98cc964eda87..ec43be0e2f3a4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -29,12 +29,12 @@ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager -import org.apache.spark.metrics.MetricsSystem import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} -import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index dd400b8ae8a16..58647860623ee 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -22,9 +22,9 @@ import java.text.NumberFormat import java.text.SimpleDateFormat import java.util.Date -import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.mapred.SparkHadoopMapRedUtil diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 87deaf20e2b25..91dc18697c352 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} +import java.util.{Comparator, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.language.implicitConversions diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 9cf68672beca2..6d3485d88a163 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.java import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.util.{Comparator, Iterator => JIterator, List => JList} import scala.collection.JavaConverters._ import scala.reflect.ClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 9f5b89bb4ba45..9990b22e14a25 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration -import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -35,6 +34,7 @@ import org.apache.spark._ import org.apache.spark.AccumulatorParam._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast +import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 3300cad9efbab..99ca3c77cced0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext} +import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} /** * Low-level status reporting APIs for monitoring job and stage progress. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 8f9647eea9e25..b2a4d053fa650 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,13 +17,13 @@ package org.apache.spark.api.java +import java.{util => ju} import java.util.Map.Entry -import com.google.common.base.Optional - -import java.{util => ju} import scala.collection.mutable +import com.google.common.base.Optional + private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8464b578ed09e..f12e2dfafa19d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 292ac4cfc35b9..2d97cd9a9a208 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} private[spark] object PythonUtils { /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 7039b734d2e40..a2a2f89f1e875 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.util.Arrays diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index fd27276e70bfe..b0d858486bfb4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -20,16 +20,15 @@ package org.apache.spark.api.python import java.nio.ByteOrder import java.util.{ArrayList => JArrayList} -import org.apache.spark.api.java.JavaRDD - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure import scala.util.Try -import net.razorvine.pickle.{Unpickler, Pickler} +import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.{Logging, SparkException} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index ee1fb056f0d96..9549784aeabf5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -17,13 +17,12 @@ package org.apache.spark.api.python -import java.io.{DataOutput, DataInput} import java.{util => ju} +import java.io.{DataInput, DataOutput} import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 - import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0095548c463cc..9bddd7248c7ef 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable.HashMap import scala.language.existentials -import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import io.netty.channel.ChannelHandler.Sharable import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 7509b3d3f44bb..401f362fee827 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -19,8 +19,7 @@ package org.apache.spark.api.r import java.io._ import java.net.{InetAddress, ServerSocket} -import java.util.Arrays -import java.util.{Map => JMap} +import java.util.{Arrays, Map => JMap} import scala.collection.JavaConverters._ import scala.io.Source diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index da126bac7ad1f..af815f885e8ae 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Timestamp, Date, Time} +import java.sql.{Date, Time, Timestamp} import scala.collection.JavaConverters._ import scala.collection.mutable.WrappedArray diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 12d79f6ed311b..0d68872dcb6e4 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -19,12 +19,12 @@ package org.apache.spark.broadcast import java.io.Serializable -import org.apache.spark.SparkException +import scala.reflect.ClassTag + import org.apache.spark.Logging +import org.apache.spark.SparkException import org.apache.spark.util.Utils -import scala.reflect.ClassTag - /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable * cached on each machine rather than shipping a copy of it with tasks. They can be used, for diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 61343607a13bc..be416c4f74cb3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,8 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.spark.{Logging, SparkConf, SecurityManager} - +import org.apache.spark.{Logging, SecurityManager, SparkConf} private[spark] class BroadcastManager( val isDriver: Boolean, diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 328a1bb84f5fb..63a20ab41a0f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -24,11 +24,11 @@ import scala.util.{Failure, Success} import org.apache.log4j.{Level, Logger} -import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.util.{SparkExitCode, ThreadUtils, Utils} /** * Proxy that relays messages to the driver. diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 72cc330a398da..255420182b493 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -22,6 +22,7 @@ import java.net.{URI, URISyntaxException} import scala.collection.mutable.ListBuffer import org.apache.log4j.Level + import org.apache.spark.util.{IntParam, MemoryParam, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index b4edb6109e839..c0ede4b7c8734 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -22,7 +22,7 @@ import java.net.URL import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer -import scala.concurrent.{Await, future, promise} +import scala.concurrent.{future, promise, Await} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.language.postfixOps diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 5bb62d37d6374..2dfb813d5fb41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index d85327603f64d..c0a9e3f280ba1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy -import java.net.URI import java.io.File +import java.net.URI import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index d46dc87a92c97..4911c3be3a024 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.{ByteStreams, Files} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.api.r.RUtils import org.apache.spark.util.{RedirectThread, Utils} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 661f7317c674b..d0466830b2177 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.{SparkException, SparkUserAppException} +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 4bd94f13e57e6..8ba3f5e241899 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -29,15 +29,15 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkException} /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 669b6b614e38c..a1e8da1ec8f5d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -37,9 +37,9 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository -import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SparkException, SparkUserAppException, SPARK_VERSION} +import org.apache.spark.{SPARK_VERSION, SparkException, SparkUserAppException} import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index f7c33214c2406..a7a0a78f14562 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.client import java.util.concurrent._ -import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.util.control.NonFatal diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index adb3f02258029..f8d3da24b956f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,9 @@ package org.apache.spark.deploy.client -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, Command} +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils private[spark] object TestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index c93bc8c127f58..22e4155cc5452 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 642d71b18c9e2..04bad79dccab8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 6143a33b69344..96007a06e3c54 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -25,6 +25,7 @@ import scala.util.control.NonFatal import com.google.common.cache._ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} + import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index bd3d981ce08b4..0deab8ddd5270 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -24,14 +24,13 @@ import java.util.Date import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration import scala.language.postfixOps import scala.util.Random import org.apache.hadoop.fs.Path -import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -42,6 +41,7 @@ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc._ import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index 58a00bceee6af..dddf2be57ee42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -17,11 +17,11 @@ package org.apache.spark.deploy.master +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rpc.RpcEnv -import scala.reflect.ClassTag - /** * Allows Master to persist any state that is necessary in order to recover from a failure. * The following semantics are required: diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index d317206a614fb..336cb24c19b1f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.master -import org.apache.spark.{Logging, SparkConf} import org.apache.curator.framework.CuratorFramework -import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.curator.framework.recipes.leader.{LeaderLatch, LeaderLatchListener} + +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f405aa2bdc8b3..1b18cf0ded69d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -21,8 +21,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.master.ExecutorDesc import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index ee539dd1f5113..f9b0279c3d1e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -23,10 +23,10 @@ import scala.xml.Node import org.json4s.JValue +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, MasterStateResponse, RequestKillDriver, RequestMasterState} import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index e41554a5a6d26..750ef0a962550 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master.ui import org.apache.spark.Logging import org.apache.spark.deploy.master.Master -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo, +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 89f1a8671fdb6..66e1e645007a7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,11 +19,11 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.scheduler.cluster.mesos._ import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{Logging, SecurityManager, SparkConf} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index bc67fd460d9a9..807835105ec3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -23,10 +23,9 @@ import scala.xml.Node import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.scheduler.cluster.mesos.{MesosClusterSubmissionState, MesosClusterRetryState} +import org.apache.spark.scheduler.cluster.mesos.{MesosClusterRetryState, MesosClusterSubmissionState} import org.apache.spark.ui.{UIUtils, WebUIPage} - private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") { override def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 7419fa9699648..166f666fbcfdc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.mesos.Protos.TaskStatus + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 3f693545a0349..da9740bb41f59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy.mesos.ui -import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.ui.{SparkUI, WebUI} +import org.apache.spark.ui.JettyUtils._ /** * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 0744c64d5e944..4ec6bfe2f9eb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -23,15 +23,15 @@ import java.util.concurrent.TimeoutException import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import scala.concurrent.duration._ import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} +import org.apache.spark.util.Utils /** * A client that submits applications to a [[RestSubmissionServer]]. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 2e78d03e5c0cc..8e0862df4c29a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -21,14 +21,15 @@ import java.net.InetSocketAddress import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source + import com.fasterxml.jackson.core.JsonProcessingException import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index d5b9bcab1423f..c19296c7b3e00 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,11 +20,11 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** * A server that responds to requests submitted by the [[RestSubmissionClient]]. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 868cc35d06ef3..a8b2f788893d9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -23,13 +23,12 @@ import java.util.Date import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest._ import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.util.Utils -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} - /** * A server that responds to requests submitted by the [[RestSubmissionClient]]. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 89159ff5e2b3c..6049db6d989ae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -25,13 +25,13 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.{Utils, Clock, SystemClock} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Manages the execution of one driver, including automatically restarting the driver on failure. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 9c4b8cdc646b0..c6687a4c63a6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -23,12 +23,13 @@ import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged + +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender -import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Manages the execution of one executor process. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 37b94e02cc9d1..98e17da489741 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{UUID, Date} +import java.util.{Date, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -37,7 +37,7 @@ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, ThreadUtils, Utils} private[deploy] class Worker( override val rpcEnv: RpcEnv, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 5a1d06eb87db9..49803a27a5b00 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -23,9 +23,9 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.{WebUIPage, UIUtils} -import org.apache.spark.util.Utils import org.apache.spark.Logging +import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.util.Utils import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index fd905feb97e92..8ebcbcb6a1738 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,16 +17,17 @@ package org.apache.spark.deploy.worker.ui +import javax.servlet.http.HttpServletRequest + import scala.xml.Node -import javax.servlet.http.HttpServletRequest import org.json4s.JValue -import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} +import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index edbd7225ca06a..58bd9ca3d12c3 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,11 +22,12 @@ import java.nio.ByteBuffer import scala.collection.mutable import scala.util.{Failure, Success} -import org.apache.spark.rpc._ + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher +import org.apache.spark.rpc._ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index d85465eb25683..cfd9bcd65c566 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -21,9 +21,9 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} +import org.apache.mesos.protobuf.ByteString import org.apache.spark.{Logging, SparkConf, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index 30431a9b986bb..bc98273add3a6 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -19,8 +19,8 @@ package org.apache.spark.input import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.Logging diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 25596a15d93c0..549395314ba61 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -20,8 +20,8 @@ package org.apache.spark.input import java.io.IOException import org.apache.hadoop.fs.FSDataInputStream -import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileSplit diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index cb76e3c344fca..8009491a1b0e0 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.JavaConverters._ -import com.google.common.io.{Closeables, ByteStreams} +import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 998c898a3fc25..6b7f086678e93 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -17,15 +17,14 @@ package org.apache.spark.input -import org.apache.hadoop.conf.{Configuration, Configurable => HConfigurable} import com.google.common.io.{ByteStreams, Closeables} - +import org.apache.hadoop.conf.{Configurable => HConfigurable, Configuration} import org.apache.hadoop.io.Text import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.lib.input.{CombineFileRecordReader, CombineFileSplit} /** * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface. diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 249bdf5994f8f..6841485f4b930 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -22,8 +22,8 @@ import java.io.IOException import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} -import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.executor.CommitDeniedException object SparkHadoopMapRedUtil extends Logging { /** diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index e707e27d96b50..33f8b9f16c11b 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -21,7 +21,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 70af83b5ee092..4036484aada23 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -22,8 +22,8 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{TaskContext, Logging} -import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} /** * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 829f054dba0e9..57a24ac140287 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import scala.collection.mutable import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockStatus, BlockId} +import org.apache.spark.storage.{BlockId, BlockStatus} /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index dd2d325d87034..8540984bfe827 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils private[spark] class MetricsConfig(conf: SparkConf) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index fdf76d312db3b..e34cfc698dcef 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,8 +20,6 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit -import org.apache.spark.util.Utils - import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -30,6 +28,7 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source +import org.apache.spark.util.Utils /** * Spark Metrics System, created by specific "instance", combined by source, diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 2d25ebd66159f..22454e50b14b4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter} +import com.codahale.metrics.graphite.{Graphite, GraphiteReporter, GraphiteUDP} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 2588fe2c9edb8..1992b42ac7f6b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics.sink import java.util.Properties import com.codahale.metrics.{JmxReporter, MetricRegistry} + import org.apache.spark.SecurityManager private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry, diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 4193e1d21d3c1..68b58b8490641 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -19,7 +19,6 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit - import javax.servlet.http.HttpServletRequest import com.codahale.metrics.MetricRegistry @@ -27,7 +26,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{SparkConf, SecurityManager} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 11dfcfe2f04e1..773e074336cb0 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -20,7 +20,7 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit -import com.codahale.metrics.{Slf4jReporter, MetricRegistry} +import com.codahale.metrics.{MetricRegistry, Slf4jReporter} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index dcbda5a8515dd..15d3540f3427b 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,13 +20,13 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Promise, Await, Future} +import scala.concurrent.{Await, Future, Promise} import scala.concurrent.duration.Duration import org.apache.spark.Logging -import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} -import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.storage.{BlockId, BlockManagerId, StorageLevel} private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 40604a4da18d5..f588a28eed28d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -25,10 +25,10 @@ import scala.concurrent.{Future, Promise} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 84833f59d7afe..86874e2067dd4 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -18,7 +18,7 @@ package org.apache.spark.network.netty import org.apache.spark.SparkConf -import org.apache.spark.network.util.{TransportConf, ConfigProvider} +import org.apache.spark.network.util.{ConfigProvider, TransportConf} /** * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor, diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala index 828bf96c2c0bd..55acb9ca64d3f 100644 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} /** * A utility class for caching Student's T distribution values for a given confidence level diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 1753c2561b678..44295e5a1affe 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} import org.apache.spark.util.StatCounter diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 14f541f937b4c..ec48925823a02 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -20,10 +20,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} import scala.reflect.ClassTag -import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter, Logging} import org.apache.spark.util.ThreadUtils /** diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 2bf2337d49fef..be0cb175f5340 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -22,8 +22,8 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.JobContextImpl +import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat -import org.apache.spark.{ Partition, SparkContext } private[spark] class BinaryFileRDD[T]( sc: SparkContext, diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fc1710fbad0a3..8358244987a6d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.storage.{BlockId, BlockManager} -import scala.Some private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition { val index = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 3a0ca1d813297..3587e7eb1afaf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,18 +17,17 @@ package org.apache.spark.rdd -import scala.language.existentials - import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils -import org.apache.spark.serializer.Serializer /** The references to rdd and splitIndex are transient because redundant information is stored * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 7fbaadcea3a3b..c07f346bbafd5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -17,8 +17,8 @@ package org.apache.spark.rdd +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.annotation.Experimental -import org.apache.spark.{TaskContext, Logging} import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.MeanEvaluator import org.apache.spark.partial.PartialResult diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 920d3bf219ff5..a7a6e0b8a94f6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,22 +17,22 @@ package org.apache.spark.rdd +import java.io.EOFException import java.text.SimpleDateFormat import java.util.Date -import java.io.EOFException import scala.collection.immutable.Map -import scala.reflect.ClassTag import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter -import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID import org.apache.hadoop.mapred.lib.CombineFileSplit @@ -45,9 +45,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} -import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} +import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager, Utils} /** * A Spark split class that wraps around a Hadoop InputSplit. diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0c28f045e46e9..469962db6763c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,15 +17,15 @@ package org.apache.spark.rdd -import java.sql.{PreparedStatement, Connection, ResultSet} +import java.sql.{Connection, PreparedStatement, ResultSet} import scala.reflect.ClassTag +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.util.NextIterator -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { override def index: Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 8b330a34c3d3a..146609ae3911a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -28,13 +28,13 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} private[spark] class NewHadoopPartition( rddId: Int, diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 76b31165aa74c..16a856f594e97 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} -import scala.collection.{Map, mutable} +import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -33,7 +33,7 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskType, TaskAttemptID} +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 394f79dc7734e..d6eac7888d5fd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -40,7 +40,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, +import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} /** diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index c4bc85a5ea2d5..92d9e3581ee59 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.rdd -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import org.apache.hadoop.io.Writable import org.apache.hadoop.io.compress.CompressionCodec diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 623da3e9c11b8..154398b57280a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -20,8 +20,8 @@ package org.apache.spark.rpc import scala.concurrent.Future import scala.reflect.ClassTag +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.util.RpcUtils -import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 285786ebf9f1b..8b4ebf34ba83c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -19,13 +19,12 @@ package org.apache.spark.rpc import java.util.concurrent.TimeoutException -import scala.concurrent.{Awaitable, Await} +import scala.concurrent.{Await, Awaitable} import scala.concurrent.duration._ import org.apache.spark.SparkConf import org.apache.spark.util.Utils - /** * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. */ diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 533c9847661b6..19259e0e800c3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,14 +17,14 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.concurrent.Promise import scala.util.control.NonFatal -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b128ed50cad52..92438ba892cc0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -40,8 +40,8 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.util._ /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 68792c58c9b4e..aa607c5a2df93 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.fs.permission.FsPermission import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION} +import org.apache.spark.{Logging, SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 4d146678174f6..3e3ab15d8a24b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable import org.apache.spark._ -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} private sealed trait OutputCommitCoordinationMessage extends Serializable diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index fb693721a9cb6..6590cf6ffd24f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import java.nio.ByteBuffer - import java.io._ +import java.nio.ByteBuffer import org.apache.spark._ import org.apache.spark.broadcast.Broadcast diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 075a7f13172de..3130a65240a99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -29,8 +29,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} -import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9f27eed626be3..0379ca2af6ab3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,14 +22,13 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} - /** * A unit of execution. We have two kinds of Task's in Spark: * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cb9a3008107d7..7c0b007db708e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,8 +17,8 @@ package org.apache.spark.scheduler -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bdf19f9f277d9..6e3ef0e54f0fd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{TimerTask, Timer} +import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -30,11 +30,11 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.{ThreadUtils, Utils} -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 380301f1c9aec..aa39b59d8cce4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -25,7 +25,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.{min, max} +import scala.math.{max, min} import scala.util.control.NonFatal import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f222007a38c9b..b808993aa6cd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -22,12 +22,12 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} +import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME -import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, SerializableBuffer, ThreadUtils, Utils} /** * A scheduler backend that waits for coarse-grained executors to connect. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 626a2b7d69abe..b25a4bfb501fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 781ecfff7e5e7..0a6f2c01c18df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 1209cce6d1a61..16f33163789ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,11 +19,11 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress} import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index a4ed85cd2a4a3..58c30e7d97886 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,20 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 16815d51d4c67..05fda0fded7f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -24,16 +24,16 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} -import org.apache.mesos.{Scheduler, SchedulerDriver} + +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils -import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} - /** * Tracks the current state of a Mesos Task that runs a Spark driver. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 281965a5981bb..eaf0cb06d6c73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString + import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 721861fbbc517..010caff3e39b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -25,12 +25,12 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} -import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 8d6af9cae8927..3d5b7105f0ca8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -29,7 +29,7 @@ import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.avro.io._ import org.apache.commons.io.IOUtils -import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec /** diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1b4538e6afb85..bc9fd50c2cd2b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{DataInput, DataOutput, EOFException, IOException, InputStream, OutputStream} +import java.io.{DataInput, DataOutput, EOFException, InputStream, IOException, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable @@ -25,9 +25,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} -import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap @@ -37,8 +37,8 @@ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index bd2704dc81871..90c0728557b99 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.annotation.{DeveloperApi, Private} -import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} +import org.apache.spark.util.{ByteBufferInputStream, NextIterator, Utils} /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index b36c457d6d514..0a65bbf8ddab4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} +import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency} import org.apache.spark.serializer.Serializer /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index be184464e0ae9..b2d050b218f53 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,8 +17,8 @@ package org.apache.spark.shuffle -import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index cc5f933393adf..7abcb29672cf5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} -import org.apache.spark.{Logging, SparkConf, SparkEnv} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index fadb8fe7ed0ab..68aba52fd7c6b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,12 +21,12 @@ import java.io._ import com.google.common.io.ByteStreams +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils -import org.apache.spark.{SparkEnv, Logging, SparkConf} /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 4342b0d598b16..81aea33ee41b4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer + import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a3444bf4daa3b..76fd249fbd2dc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{TaskContext, ShuffleDependency} +import org.apache.spark.{ShuffleDependency, TaskContext} /** * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 31b4dd7c0f427..341ae782362a0 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,8 +17,8 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.core.MediaType import javax.ws.rs.{GET, Produces, QueryParam} +import javax.ws.rs.core.MediaType import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index b5ef72649e295..d7e6a8b589953 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -16,8 +16,8 @@ */ package org.apache.spark.status.api.v1 +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType -import javax.ws.rs.{Produces, PathParam, GET} @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class OneApplicationResource(uiRoot: UIRoot) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala index 6d8a60d480aed..a0f6360bc5c72 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.JobExecutionStatus diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index dfdc09c6caf3b..237aeac185877 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index b5b7804d54ce2..8caf9e55359e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,8 +21,8 @@ import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import scala.util.Random import scala.util.control.NonFatal diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 440c4c18aadd0..da1de11d605c9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -21,10 +21,10 @@ import scala.collection.Iterable import scala.collection.generic.CanBuildFrom import scala.concurrent.{Await, Future} -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ThreadUtils, RpcUtils} +import org.apache.spark.util.{RpcUtils, ThreadUtils} private[spark] class BlockManagerMaster( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 41892b4ffce5b..4db400a3442ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f7e84a2c2e14c..4daf22f71415e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -17,10 +17,10 @@ package org.apache.spark.storage +import java.io.{File, IOException} import java.util.UUID -import java.io.{IOException, File} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e2dd80f243930..e36a367323b20 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -17,12 +17,12 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} +import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{SerializationStream, SerializerInstance} import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 6c4477184d5b4..1f3f193f2ffa2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{IOException, File, FileOutputStream, RandomAccessFile} +import java.io.{File, FileOutputStream, IOException, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 94e8559bd2e91..673f7ad79def0 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDDOperationScope, RDD} +import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.{CallSite, Utils} @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 7f88f2fe6d503..6aa7e13901779 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -25,7 +25,6 @@ import java.util.{Date, Random} import scala.util.control.NonFatal import com.google.common.io.ByteStreams - import tachyon.{Constants, TachyonURI} import tachyon.client.ClientContext import tachyon.client.file.{TachyonFile, TachyonFileSystem} @@ -38,7 +37,6 @@ import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.util.Utils - /** * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By * default, one block is mapped to one file with a name given by its BlockId. diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 8da6884a38535..e319937702f23 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -21,18 +21,18 @@ import java.util.{Date, ServiceLoader} import scala.collection.JavaConverters._ -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, - UIRoot} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, + UIRoot} import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} -import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} -import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.util.Utils /** * Top level user interface for a Spark application. diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 81a121fd441bd..3925235984723 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -26,9 +26,9 @@ import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * The top level component of the UI hierarchy that contains the server. diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 58575d154ce5c..1a6f0fdd50df7 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -21,7 +21,7 @@ import java.net.URLDecoder import javax.servlet.http.HttpServletRequest import scala.util.Try -import scala.xml.{Text, Node} +import scala.xml.{Node, Text} import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index a88fc4c37d3c9..2d955a66601ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext} +import org.apache.spark.{ExceptionFailure, Resubmitted, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d467dd9e1f29d..db9912bc817e8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.{HashMap, ListBuffer} import scala.xml._ import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData} import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} +import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData} /** Page showing list of all ongoing and recently finished jobs */ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 5e52942b64f3f..e75f1c57a69d0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, NodeSeq} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 1268f44596f8a..1304efd8f2ec7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.jobs import scala.collection.mutable -import scala.xml.{Unparsed, Node} +import scala.xml.{Node, Unparsed} import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 2cad0a796913e..654d988807f9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -18,11 +18,10 @@ package org.apache.spark.ui.jobs import java.util.Date +import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, HashMap, ListBuffer} -import scala.xml.{NodeSeq, Node, Unparsed, Utility} - -import javax.servlet.http.HttpServletRequest +import scala.xml.{Node, NodeSeq, Unparsed, Utility} import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.StageInfo diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index f3e0b38523f32..fa30f2bda4272 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.scheduler.StageInfo -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing specific pool details */ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 08e7576b0c08e..2cc6c75a9ac12 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -31,7 +31,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ -import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index f008d40180611..78165d7b743e2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -17,14 +17,14 @@ package org.apache.spark.ui.jobs +import scala.collection.mutable +import scala.collection.mutable.HashMap + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.collection.OpenHashSet -import scala.collection.mutable -import scala.collection.mutable.HashMap - private[spark] object UIData { class ExecutorSummary { diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index e9c8a8e299cd7..06da74f1b6b5f 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.{StringBuilder, ListBuffer} +import scala.collection.mutable.{ListBuffer, StringBuilder} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 22e2993b3b5bd..2d9b885c684b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -20,9 +20,9 @@ package org.apache.spark.ui.storage import scala.collection.mutable import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ui._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ +import org.apache.spark.ui._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 6c1fca71f2281..f6b7ea2f37869 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean + import scala.util.DynamicVariable import org.apache.spark.SparkContext diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index e9b2b8d24b476..542c5fccf458a 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -17,8 +17,8 @@ package org.apache.spark.util -import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} +import java.util.concurrent.atomic.AtomicBoolean import scala.util.control.NonFatal diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index cb0f1bf79f3d5..a62fd2f339285 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -25,8 +25,8 @@ import scala.collection.Map import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats -import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 945217203be72..0a3180da87987 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.net.{URLClassLoader, URL} +import java.net.{URL, URLClassLoader} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index acc24ca0fb814..38523be791cec 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -23,6 +23,7 @@ import java.util.PriorityQueue import scala.util.Try import org.apache.hadoop.fs.FileSystem + import org.apache.spark.Logging /** diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 09864e3f8392d..52587d2188943 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -17,8 +17,6 @@ package org.apache.spark.util -import com.google.common.collect.MapMaker - import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} @@ -27,6 +25,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer import scala.runtime.ScalaRunTime +import com.google.common.collect.MapMaker + import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index d7e5143c30953..1733025041067 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,8 +17,8 @@ package org.apache.spark.util -import java.util.Set import java.util.Map.Entry +import java.util.Set import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9bdcc4d817a4a..9ecbffbf715c5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,8 +22,8 @@ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels -import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} +import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection import scala.collection.JavaConverters._ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f6d81ee5bf05e..4a44481cf4e14 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -28,12 +28,12 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator -import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 44b1d90667e65..63ba954a7fa7e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -20,15 +20,15 @@ package org.apache.spark.util.collection import java.io._ import java.util.Comparator -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60bf4dd7469f1..0f6a425e3db9a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.collection import scala.reflect._ + import com.google.common.hash.Hashing import org.apache.spark.annotation.Private diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 3a48af82b1dae..e1592184ca6d3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,8 +17,8 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} /** * Spills contents of an in-memory collection to disk when the memory threshold diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 1e8476c4a047e..050ece12f1728 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -20,8 +20,8 @@ package org.apache.spark.util.logging import java.io.{File, FileFilter, InputStream} import com.google.common.io.Files + import org.apache.spark.SparkConf -import RollingFileAppender._ /** * Continuously appends data from input stream into the given file, and rolls @@ -39,9 +39,11 @@ private[spark] class RollingFileAppender( activeFile: File, val rollingPolicy: RollingPolicy, conf: SparkConf, - bufferSize: Int = DEFAULT_BUFFER_SIZE + bufferSize: Int = RollingFileAppender.DEFAULT_BUFFER_SIZE ) extends FileAppender(inputStream, activeFile, bufferSize) { + import RollingFileAppender._ + private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) /** Stop the appender */ diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index c156b03cdb7c4..1314217023d15 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -19,8 +19,8 @@ package org.apache.spark.util.random import java.util.Random -import scala.reflect.ClassTag import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import org.apache.commons.math3.distribution.PoissonDistribution From b3ba1be3b77e42120145252b2730a56f1d55fd21 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 5 Jan 2016 19:07:39 -0800 Subject: [PATCH 1454/2089] [SPARK-3873][TESTS] Import ordering fixes. Author: Marcelo Vanzin Closes #10582 from vanzin/SPARK-3873-tests. --- .../org/apache/spark/ContextCleanerSuite.scala | 10 +++------- .../spark/ExecutorAllocationManagerSuite.scala | 1 + .../scala/org/apache/spark/FailureSuite.scala | 4 ++-- .../org/apache/spark/FileServerSuite.scala | 4 ++-- .../scala/org/apache/spark/FileSuite.scala | 9 ++++----- .../apache/spark/HeartbeatReceiverSuite.scala | 6 +++--- .../org/apache/spark/LocalSparkContext.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 ++-- .../org/apache/spark/SSLOptionsSuite.scala | 3 ++- .../scala/org/apache/spark/ShuffleSuite.scala | 6 +++--- .../org/apache/spark/SortShuffleSuite.scala | 2 +- .../org/apache/spark/SparkConfSuite.scala | 9 +++++---- .../apache/spark/SparkContextInfoSuite.scala | 1 + .../SparkContextSchedulerCreationSuite.scala | 2 +- .../org/apache/spark/SparkContextSuite.scala | 10 +++++----- .../org/apache/spark/ThreadingSuite.scala | 5 ++--- .../api/python/PythonBroadcastSuite.scala | 4 ++-- .../apache/spark/deploy/DeployTestUtils.scala | 2 +- .../org/apache/spark/deploy/IvyTestUtils.scala | 4 +--- .../spark/deploy/JsonProtocolSuite.scala | 2 +- .../spark/deploy/LogUrlsStandaloneSuite.scala | 4 ++-- .../spark/deploy/RPackageUtilsSuite.scala | 4 ++-- .../spark/deploy/SparkSubmitUtilsSuite.scala | 4 ++-- .../history/FsHistoryProviderSuite.scala | 2 +- .../deploy/master/PersistenceEngineSuite.scala | 2 +- .../deploy/master/ui/MasterWebUISuite.scala | 4 ++-- .../rest/StandaloneRestSubmitSuite.scala | 8 ++++---- .../deploy/worker/CommandUtilsSuite.scala | 3 ++- .../spark/deploy/worker/DriverRunnerTest.scala | 2 +- .../deploy/worker/ExecutorRunnerTest.scala | 2 +- .../spark/deploy/worker/WorkerSuite.scala | 4 ++-- .../deploy/worker/WorkerWatcherSuite.scala | 5 ++--- .../input/WholeTextFileRecordReaderSuite.scala | 5 ++--- .../spark/memory/MemoryManagerSuite.scala | 2 +- .../spark/memory/MemoryTestingUtils.scala | 2 +- .../spark/memory/TestMemoryManager.scala | 2 +- .../metrics/InputOutputMetricsSuite.scala | 10 +++++----- .../spark/metrics/MetricsConfigSuite.scala | 3 +-- .../spark/metrics/MetricsSystemSuite.scala | 7 +++---- .../NettyBlockTransferSecuritySuite.scala | 13 +++++++------ .../netty/NettyBlockTransferServiceSuite.scala | 5 +++-- .../spark/rdd/LocalCheckpointSuite.scala | 4 ++-- .../spark/rdd/PairRDDFunctionsSuite.scala | 16 ++++++++-------- .../org/apache/spark/rdd/PipedRDDSuite.scala | 8 ++++---- .../scala/org/apache/spark/rdd/RDDSuite.scala | 8 ++++---- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 ++-- .../apache/spark/rpc/netty/InboxSuite.scala | 2 +- .../spark/rpc/netty/NettyRpcHandlerSuite.scala | 4 ++-- .../CoarseGrainedSchedulerBackendSuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/MapStatusSuite.scala | 8 ++++---- .../scheduler/NotSerializableFakeTask.scala | 2 +- ...tputCommitCoordinatorIntegrationSuite.scala | 4 ++-- .../OutputCommitCoordinatorSuite.scala | 13 ++++++------- .../spark/scheduler/SparkListenerSuite.scala | 3 +-- .../spark/scheduler/TaskContextSuite.scala | 6 ++---- .../CoarseMesosSchedulerBackendSuite.scala | 8 ++++---- .../mesos/MesosSchedulerBackendSuite.scala | 8 ++++---- .../mesos/MesosClusterSchedulerSuite.scala | 3 +-- .../GenericAvroSerializerSuite.scala | 6 +++--- .../KryoSerializerDistributedSuite.scala | 3 +-- .../KryoSerializerResizableOutputSuite.scala | 3 +-- .../spark/serializer/KryoSerializerSuite.scala | 5 ++--- .../spark/serializer/TestSerializer.scala | 3 +-- .../BypassMergeSortShuffleWriterSuite.scala | 6 +++--- .../storage/BlockManagerReplicationSuite.scala | 4 ++-- .../spark/storage/BlockManagerSuite.scala | 7 +++---- .../apache/spark/storage/LocalDirsSuite.scala | 3 +-- .../org/apache/spark/ui/UISeleniumSuite.scala | 6 +++--- .../scala/org/apache/spark/ui/UISuite.scala | 2 +- .../scope/RDDOperationGraphListenerSuite.scala | 3 --- .../spark/ui/storage/StorageTabSuite.scala | 1 + .../spark/util/ClosureCleanerSuite.scala | 2 +- .../apache/spark/util/FileAppenderSuite.scala | 5 ++--- .../apache/spark/util/JsonProtocolSuite.scala | 5 ++--- .../apache/spark/util/SizeEstimatorSuite.scala | 2 +- .../apache/spark/util/ThreadUtilsSuite.scala | 2 +- .../org/apache/spark/util/UtilsSuite.scala | 3 ++- .../util/collection/ExternalSorterSuite.scala | 4 +--- .../unsafe/sort/PrefixComparatorsSuite.scala | 1 + .../spark/util/random/RandomSamplerSuite.scala | 3 ++- .../util/random/XORShiftRandomSuite.scala | 5 ++--- .../sql/jdbc/DockerJDBCIntegrationSuite.scala | 2 +- .../sql/jdbc/PostgresIntegrationSuite.scala | 2 +- .../org/apache/spark/util/DockerUtils.scala | 2 +- .../streaming/flume/sink/SparkSinkSuite.scala | 2 +- .../spark/streaming/TestOutputStream.scala | 6 +++--- .../flume/FlumePollingStreamSuite.scala | 4 ++-- .../kafka/DirectKafkaStreamSuite.scala | 5 ++--- .../spark/streaming/kafka/KafkaRDDSuite.scala | 2 +- .../spark/streaming/mqtt/MQTTTestUtils.scala | 2 +- .../streaming/twitter/TwitterStreamSuite.scala | 5 ++--- .../spark/graphx/impl/EdgePartitionSuite.scala | 3 +-- .../graphx/impl/VertexPartitionSuite.scala | 3 +-- .../LogisticRegressionSuite.scala | 2 +- .../ml/classification/OneVsRestSuite.scala | 4 ++-- .../spark/ml/feature/InteractionSuite.scala | 2 +- .../org/apache/spark/ml/feature/PCASuite.scala | 2 +- .../ml/feature/PolynomialExpansionSuite.scala | 2 +- .../ml/feature/QuantileDiscretizerSuite.scala | 2 +- .../spark/ml/feature/StandardScalerSuite.scala | 1 - .../spark/ml/feature/StringIndexerSuite.scala | 2 +- .../ml/feature/VectorAssemblerSuite.scala | 2 +- .../spark/ml/feature/VectorSlicerSuite.scala | 2 +- .../spark/ml/feature/Word2VecSuite.scala | 2 +- .../org/apache/spark/ml/impl/TreeTests.scala | 3 +-- .../spark/ml/recommendation/ALSSuite.scala | 3 +-- .../DecisionTreeRegressorSuite.scala | 3 +-- .../ml/regression/LinearRegressionSuite.scala | 2 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 10 +++++----- .../spark/ml/util/DefaultReadWriteTest.scala | 2 +- .../mllib/api/python/PythonMLLibAPISuite.scala | 4 ++-- .../StreamingLogisticRegressionSuite.scala | 2 +- .../clustering/GaussianMixtureSuite.scala | 2 +- .../spark/mllib/clustering/LDASuite.scala | 2 +- .../mllib/evaluation/RankingMetricsSuite.scala | 2 +- .../apache/spark/mllib/feature/IDFSuite.scala | 2 +- .../mllib/feature/StandardScalerSuite.scala | 2 +- .../spark/mllib/feature/Word2VecSuite.scala | 1 - .../apache/spark/mllib/linalg/BLASSuite.scala | 2 +- .../linalg/BreezeMatrixConversionSuite.scala | 2 +- .../spark/mllib/linalg/VectorsSuite.scala | 2 +- .../linalg/distributed/BlockMatrixSuite.scala | 2 +- .../distributed/CoordinateMatrixSuite.scala | 2 +- .../distributed/IndexedRowMatrixSuite.scala | 2 +- .../linalg/distributed/RowMatrixSuite.scala | 4 ++-- .../optimization/GradientDescentSuite.scala | 2 +- .../spark/mllib/random/RandomRDDsSuite.scala | 4 ++-- .../mllib/rdd/MLPairRDDFunctionsSuite.scala | 2 +- .../spark/mllib/rdd/RDDFunctionsSuite.scala | 2 +- .../spark/mllib/regression/LassoSuite.scala | 2 +- .../regression/LinearRegressionSuite.scala | 2 +- .../regression/RidgeRegressionSuite.scala | 2 +- .../spark/mllib/stat/StreamingTestSuite.scala | 4 ++-- .../MultivariateGaussianSuite.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 2 +- .../spark/mllib/tree/EnsembleTestHelper.scala | 4 ++-- .../mllib/tree/GradientBoostedTreesSuite.scala | 5 ++--- .../mllib/util/LocalClusterSparkContext.scala | 2 +- .../apache/spark/mllib/util/TestingUtils.scala | 3 ++- .../spark/mllib/util/TestingUtilsSuite.scala | 3 ++- .../spark/repl/ExecutorClassLoaderSuite.scala | 8 ++++---- .../scala/org/apache/spark/sql/RowTest.scala | 3 ++- .../spark/sql/catalyst/DistributionSuite.scala | 3 +-- .../spark/sql/catalyst/SqlParserSuite.scala | 4 ++-- .../catalyst/analysis/AnalysisErrorSuite.scala | 12 ++++++------ .../sql/catalyst/analysis/AnalysisTest.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 4 ++-- .../analysis/ExpressionTypeCheckingSuite.scala | 2 +- .../analysis/HiveTypeCoercionSuite.scala | 3 +-- .../encoders/ExpressionEncoderSuite.scala | 7 ++++--- .../catalyst/encoders/RowEncoderSuite.scala | 2 +- .../sql/catalyst/expressions/CastSuite.scala | 4 ++-- .../ConditionalExpressionSuite.scala | 3 +-- .../expressions/DecimalExpressionSuite.scala | 3 +-- .../expressions/MiscFunctionsSuite.scala | 2 +- .../catalyst/expressions/OrderingSuite.scala | 4 ++-- .../aggregate/HyperLogLogPlusPlusSuite.scala | 9 +++++---- .../optimizer/AggregateOptimizeSuite.scala | 2 +- .../optimizer/BooleanSimplificationSuite.scala | 6 +++--- .../optimizer/ColumnPruningSuite.scala | 4 ++-- .../optimizer/CombiningLimitsSuite.scala | 4 ++-- .../optimizer/ConstantFoldingSuite.scala | 11 +++++------ .../optimizer/FilterPushdownSuite.scala | 6 +++--- .../optimizer/LikeSimplificationSuite.scala | 7 +++---- .../catalyst/optimizer/OptimizeInSuite.scala | 8 ++++---- .../optimizer/ProjectCollapsingSuite.scala | 3 +-- .../optimizer/SetOperationPushDownSuite.scala | 4 ++-- ...implifyCaseConversionExpressionsSuite.scala | 7 +++---- .../spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../sql/catalyst/plans/SameResultSuite.scala | 4 ++-- .../sql/catalyst/trees/TreeNodeSuite.scala | 2 +- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 2 +- .../sql/catalyst/util/MetadataSuite.scala | 2 +- .../apache/spark/sql/types/DecimalSuite.scala | 5 +++-- .../apache/spark/sql/CachedTableSuite.scala | 10 ++++------ .../spark/sql/ColumnExpressionSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../spark/sql/DataFrameWindowSuite.scala | 2 +- .../spark/sql/DatasetAggregatorSuite.scala | 5 ++--- .../org/apache/spark/sql/DatasetSuite.scala | 3 +-- .../apache/spark/sql/DateFunctionsSuite.scala | 2 +- .../spark/sql/ExtraStrategiesSuite.scala | 2 +- .../org/apache/spark/sql/ListTablesSuite.scala | 2 +- .../spark/sql/MultiSQLContextsSuite.scala | 3 ++- .../scala/org/apache/spark/sql/QueryTest.scala | 10 +++++----- .../scala/org/apache/spark/sql/RowSuite.scala | 2 +- .../org/apache/spark/sql/SQLConfSuite.scala | 3 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/UserDefinedTypeSuite.scala | 4 +--- .../execution/ExchangeCoordinatorSuite.scala | 4 ++-- .../sql/execution/GroupedIteratorSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../apache/spark/sql/execution/SortSuite.scala | 3 +-- .../UnsafeFixedWidthAggregationMapSuite.scala | 8 ++++---- .../UnsafeKVExternalSorterSuite.scala | 2 +- .../execution/UnsafeRowSerializerSuite.scala | 11 +++++------ .../execution/columnar/ColumnTypeSuite.scala | 7 +++---- .../execution/columnar/ColumnarTestUtils.scala | 2 +- .../columnar/NullableColumnAccessorSuite.scala | 2 +- .../columnar/NullableColumnBuilderSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 2 +- .../execution/datasources/json/JsonSuite.scala | 1 + .../parquet/ParquetCompatibilityTest.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 7 ++----- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 2 +- .../datasources/parquet/ParquetTest.scala | 7 +++---- .../execution/datasources/text/TextSuite.scala | 3 +-- .../execution/joins/BroadcastJoinSuite.scala | 2 +- .../sql/execution/joins/InnerJoinSuite.scala | 2 +- .../sql/execution/joins/OuterJoinSuite.scala | 4 ++-- .../sql/execution/joins/SemiJoinSuite.scala | 4 ++-- .../execution/local/HashJoinNodeSuite.scala | 4 ++-- .../sql/execution/local/LocalNodeTest.scala | 3 +-- .../sql/execution/ui/SQLListenerSuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../sql/sources/PartitionedWriteSuite.scala | 2 +- .../spark/sql/sources/SaveLoadSuite.scala | 2 +- .../spark/sql/test/ProcessTestUtils.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../spark/sql/hive/thriftserver/CliSuite.scala | 6 +++--- .../thriftserver/HiveThriftServer2Suites.scala | 4 ++-- .../spark/sql/hive/CachedTableSuite.scala | 2 +- .../spark/sql/hive/ErrorPositionSuite.scala | 3 +-- .../sql/hive/HiveDataFrameAnalyticsSuite.scala | 3 ++- .../sql/hive/HiveDataFrameJoinSuite.scala | 2 +- .../spark/sql/hive/HiveInspectorSuite.scala | 2 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 2 +- .../spark/sql/hive/HiveParquetSuite.scala | 2 +- .../apache/spark/sql/hive/HiveQlSuite.scala | 5 ++--- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../spark/sql/hive/ListTablesSuite.scala | 4 ++-- .../sql/hive/MetastoreDataSourcesSuite.scala | 6 +++--- .../spark/sql/hive/MultiDatabaseSuite.scala | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala | 2 +- .../spark/sql/hive/QueryPartitionSuite.scala | 4 ++-- .../spark/sql/hive/StatisticsSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 ++-- .../hive/execution/ConcurrentHiveSuite.scala | 3 ++- .../hive/execution/HiveComparisonTest.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 2 +- .../execution/HiveOperatorQueryableSuite.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 6 +++--- .../hive/execution/HiveTableScanSuite.scala | 1 - .../sql/hive/execution/HiveUDFSuite.scala | 12 ++++++------ .../sql/hive/execution/SQLQuerySuite.scala | 8 ++++---- .../execution/ScriptTransformationSuite.scala | 2 +- .../spark/sql/hive/orc/OrcFilterSuite.scala | 2 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 2 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 2 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 4 ++-- .../CommitFailureTestRelationSuite.scala | 2 +- .../SimpleTextHadoopFsRelationSuite.scala | 5 ++--- .../spark/sql/sources/SimpleTextRelation.scala | 6 +++--- .../spark/streaming/DStreamScopeSuite.scala | 2 +- .../apache/spark/streaming/FailureSuite.scala | 2 +- .../spark/streaming/InputStreamsSuite.scala | 16 ++++++++-------- .../spark/streaming/MapWithStateSuite.scala | 4 ++-- .../spark/streaming/MasterFailureTest.scala | 18 ++++++++---------- .../streaming/ReceivedBlockHandlerSuite.scala | 5 +++-- .../streaming/ReceiverInputDStreamSuite.scala | 2 +- .../apache/spark/streaming/ReceiverSuite.scala | 2 +- .../streaming/StreamingListenerSuite.scala | 12 ++++++------ .../apache/spark/streaming/TestSuiteBase.scala | 6 +++--- .../streaming/WindowOperationsSuite.scala | 2 +- .../streaming/rdd/MapWithStateRDDSuite.scala | 2 +- .../rdd/WriteAheadLogBackedBlockRDDSuite.scala | 2 +- .../receiver/BlockGeneratorSuite.scala | 4 ++-- .../scheduler/InputInfoTrackerSuite.scala | 2 +- .../streaming/util/WriteAheadLogSuite.scala | 11 +++++------ .../util/WriteAheadLogUtilsSuite.scala | 2 +- .../types/UTF8StringPropertyCheckSuite.scala | 1 - .../ClientDistributedCacheManagerSuite.scala | 14 ++++++-------- .../apache/spark/deploy/yarn/ClientSuite.scala | 2 +- .../spark/deploy/yarn/YarnAllocatorSuite.scala | 9 +++------ .../deploy/yarn/YarnSparkHadoopUtilSuite.scala | 6 ++---- .../network/shuffle/ShuffleTestAccessor.scala | 2 +- 281 files changed, 517 insertions(+), 575 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 0c14bef7befd8..7b0238091730d 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -24,18 +24,14 @@ import scala.language.existentials import scala.util.Random import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.time.SpanSugar._ -import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD} -import org.apache.spark.storage._ +import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BroadcastBlockId -import org.apache.spark.storage.RDDBlockId -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.storage.ShuffleIndexBlockId +import org.apache.spark.storage._ /** * An abstract base class for context cleaner tests, which sets up a context with a config diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index fedfbd547b91b..4e678fbac6a39 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable import org.scalatest.{BeforeAndAfter, PrivateMethodTester} + import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 203dab934ca1f..3def8b0b1850e 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark -import org.apache.spark.util.NonSerializable - import java.io.{IOException, NotSerializableException, ObjectInputStream} +import org.apache.spark.util.NonSerializable + // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully // be copied for each task, so there's no other way for them to share state. diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 2c32b69715484..bc7059b77fec5 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -27,10 +27,10 @@ import org.apache.commons.lang3.RandomUtils import org.apache.spark.util.Utils -import SSLSampleConfigs._ - class FileServerSuite extends SparkFunSuite with LocalSparkContext { + import SSLSampleConfigs._ + @transient var tmpDir: File = _ @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 2e47801aafd75..993834f8d7d42 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -21,17 +21,16 @@ import java.io.{File, FileWriter} import scala.io.Source -import org.apache.spark.input.PortableDataStream -import org.apache.spark.storage.StorageLevel - import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec -import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} +import org.apache.hadoop.mapred.{FileAlreadyExistsException, FileSplit, JobConf, TextInputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class FileSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 9b43341576a8a..18e53508406dc 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -25,13 +25,13 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} -import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ +import org.mockito.Mockito.{mock, spy, verify, when} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 214681970acbf..e1a0bf7c933b9 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -17,7 +17,7 @@ package org.apache.spark -import _root_.io.netty.util.internal.logging.{Slf4JLoggerFactory, InternalLoggerFactory} +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 5b29d69cd9428..3819c0a8f31dc 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} +import org.mockito.Mockito._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 25b79bce6ab98..fa35819f55ac2 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -21,9 +21,10 @@ import java.io.File import javax.net.ssl.SSLContext import com.google.common.io.Files -import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfterAll +import org.apache.spark.util.Utils + class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 0de10ae485378..c45d81459e8e2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark -import java.util.concurrent.{Callable, Executors, ExecutorService, CyclicBarrier} +import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter -import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} +import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 5354731465a4a..7a897c2b4698f 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -26,8 +26,8 @@ import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util.Utils class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index ff9a92cc0a421..2fe99e3f81948 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -17,17 +17,18 @@ package org.apache.spark -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.{Executors, TimeUnit} import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.{Try, Random} +import scala.util.{Random, Try} + +import com.esotericsoftware.kryo.Kryo import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} -import org.apache.spark.util.{RpcUtils, ResetSystemProperties} -import com.esotericsoftware.kryo.Kryo +import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("Test byteString conversion") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 2bdbd70c638a5..3706455c3facc 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.scalatest.Assertions + import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index d18e0782c0392..52919c1ec0b1e 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark import org.scalatest.PrivateMethodTester -import org.apache.spark.util.Utils import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend +import org.apache.spark.util.Utils class SparkContextSchedulerCreationSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 172ef050cc275..556afd08bbfe5 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,18 +20,18 @@ package org.apache.spark import java.io.File import java.util.concurrent.TimeUnit +import scala.concurrent.Await +import scala.concurrent.duration.Duration + import com.google.common.base.Charsets._ import com.google.common.io.Files - import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} -import org.apache.spark.util.Utils - -import scala.concurrent.Await -import scala.concurrent.duration.Duration import org.scalatest.Matchers._ +import org.apache.spark.util.Utils + class SparkContextSuite extends SparkFunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 54c131cdae367..fc31b784c7ae1 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark -import java.util.concurrent.{TimeUnit, Semaphore} -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import org.apache.spark.scheduler._ diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index 135c56bf5bc9d..b38a3667abee1 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.api.python -import scala.io.Source +import java.io.{File, PrintWriter} -import java.io.{PrintWriter, File} +import scala.io.Source import org.scalatest.Matchers diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 3164760b08a71..86455a13d0fe7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -20,9 +20,9 @@ package org.apache.spark.deploy import java.io.File import java.util.Date +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{SecurityManager, SparkConf} private[deploy] object DeployTestUtils { def createAppDesc(): ApplicationDescription = { diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index d93febcfd23fd..9ecf49b59898b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -24,10 +24,8 @@ import java.util.jar.Manifest import scala.collection.mutable.ArrayBuffer -import com.google.common.io.{Files, ByteStreams} - +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.FileUtils - import org.apache.ivy.core.settings.IvySettings import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 0a9f128a3a6b6..2d48e75cfbd96 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -23,10 +23,10 @@ import com.fasterxml.jackson.core.JsonParseException import org.json4s._ import org.json4s.jackson.JsonMethods +import org.apache.spark.{JsonTestUtils, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} import org.apache.spark.deploy.worker.ExecutorRunner -import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 8dd31b4b6fdda..f416ace5c2b71 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -22,9 +22,9 @@ import java.net.URL import scala.collection.mutable import scala.io.Source -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.SparkConfWithEnv class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index cc30ba223e1c3..13cba94578a6a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy -import java.io.{PrintStream, OutputStream, File} +import java.io.{File, OutputStream, PrintStream} import java.net.URI -import java.util.jar.Attributes.Name import java.util.jar.{JarFile, Manifest} +import java.util.jar.Attributes.Name import java.util.zip.ZipFile import scala.collection.JavaConverters._ diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 4b5039b668a46..4877710c1237d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.deploy -import java.io.{File, PrintStream, OutputStream} +import java.io.{File, OutputStream, PrintStream} import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.{AbstractResolver, FileSystemResolver, IBiblioResolver} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 5cab17f8a38f5..6cbf911395a84 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -23,8 +23,8 @@ import java.net.URI import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} -import scala.io.Source import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import com.google.common.base.Charsets diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 7a44728675680..b4deed7f877e8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -25,7 +25,7 @@ import org.apache.curator.test.TestingServer import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} -import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.util.Utils class PersistenceEngineSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index fba835f054f8a..0c9382a92bcaf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -23,11 +23,11 @@ import scala.io.Source import scala.language.postfixOps import org.json4s.jackson.JsonMethods._ -import org.json4s.JsonAST.{JNothing, JString, JInt} +import org.json4s.JsonAST.{JInt, JNothing, JString} import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.MasterStateResponse import org.apache.spark.deploy.DeployTestUtils._ import org.apache.spark.deploy.master._ diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index fa39aa2cb1311..ee889bf144546 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -24,16 +24,16 @@ import javax.servlet.http.HttpServletResponse import scala.collection.mutable import com.google.common.base.Charsets -import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ +import org.scalatest.BeforeAndAfterEach import org.apache.spark._ -import org.apache.spark.rpc._ -import org.apache.spark.util.Utils -import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} +import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState._ +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils /** * Tests for the REST application submission protocol used in standalone cluster mode. diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 7101cb9978df3..607c0a4fac46b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy.worker +import org.scalatest.{Matchers, PrivateMethodTester} + import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.{Matchers, PrivateMethodTester} class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 6258c18d177fd..bd8b0655f4bbf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File -import org.mockito.Mockito._ import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 98664dc1101e6..0240bf8aed4cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File -import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 082d5e86eb512..101a44edd8ee2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.deploy.worker import org.scalatest.Matchers +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.{Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState -import org.apache.spark.deploy.{Command, ExecutorState} import org.apache.spark.rpc.{RpcAddress, RpcEnv} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} class WorkerSuite extends SparkFunSuite with Matchers { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 0ffd91d8ffc06..31bea3293ae77 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcEndpointAddress, RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 24184b02cb4c1..d852255a4fd29 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,13 +23,12 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.io.Text +import org.apache.hadoop.io.compress.{CompressionCodecFactory, DefaultCodec, GzipCodec} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} /** * Tests the correctness of diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 555b640cb4244..f2924a6a5c052 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration import org.mockito.Matchers.{any, anyLong} import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 4b4c3b0311328..0e60cc8e77873 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.memory -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} /** * Helper methods for mocking out memory-management-related classes in tests. diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 0706a6e45de8f..4a1e49b45df40 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import scala.collection.mutable import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockStatus, BlockId} +import org.apache.spark.storage.{BlockId, BlockStatus} class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) { diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 44eb5a0469122..aaf62e0f91067 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -25,17 +25,17 @@ import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, + JobConf, LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, + Reporter, TextInputFormat => OldTextInputFormat} import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} -import org.apache.hadoop.mapred.{JobConf, Reporter, FileSplit => OldFileSplit, - InputSplit => OldInputSplit, LineRecordReader => OldLineRecordReader, - RecordReader => OldRecordReader, TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, - RecordReader => NewRecordReader} import org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 41f2ff725a17b..b24f5d732f292 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.metrics -import org.apache.spark.SparkConf - import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkConf import org.apache.spark.SparkFunSuite class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 9c389c76bf3bd..5d8554229dbe1 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.metrics +import scala.collection.mutable.ArrayBuffer + +import com.codahale.metrics.MetricRegistry import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.metrics.source.Source -import com.codahale.metrics.MetricRegistry - -import scala.collection.mutable.ArrayBuffer - class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 98da94139f7f8..47dbcb8fc0eaa 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -22,20 +22,21 @@ import java.nio._ import java.nio.charset.Charset import java.util.concurrent.TimeUnit -import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} import com.google.common.io.CharStreams -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.network.{BlockDataManager, BlockTransferService} -import org.apache.spark.storage.{BlockId, ShuffleBlockId} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.ShouldMatchers +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.storage.{BlockId, ShuffleBlockId} + class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 92daf4e6a2169..cc1a9e0287082 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.network.netty -import org.apache.spark.network.BlockDataManager -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.BlockDataManager + class NettyBlockTransferServiceSuite extends SparkFunSuite with BeforeAndAfterEach diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 3a22a9850a096..e694f5e5e7ad2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.rdd -import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite} - import org.mockito.Mockito.spy + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} /** diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 7d2cfcca9436a..16e2d2e636c11 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.rdd -import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution} -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util.Progressable - import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, -OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, -TaskAttemptContext => NewTaskAttempContext} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, + OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, + RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} +import org.apache.hadoop.util.Progressable + import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite} import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 5f73ec8675966..1eebc924a534d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.rdd import java.io.File -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} - import scala.collection.Map import scala.language.postfixOps import scala.sys.process._ import scala.util.Try +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} + import org.apache.spark._ import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 18d1466bb7c30..24acbed4d7258 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.rdd -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import com.esotericsoftware.kryo.KryoException - -import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.reflect.ClassTag +import com.esotericsoftware.kryo.KryoException + import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 924fce7f61c26..64e486d791cde 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.rpc import java.io.{File, NotSerializableException} -import java.util.UUID import java.nio.charset.StandardCharsets.UTF_8 -import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} +import java.util.UUID +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import scala.collection.mutable import scala.concurrent.Await diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 2136795b18813..12113be75c238 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, TestRpcEndpoint} class InboxSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index d4aebe9fd915e..0c156fef0ae0f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -21,11 +21,11 @@ import java.net.InetSocketAddress import java.nio.ByteBuffer import io.netty.channel.Channel -import org.mockito.Mockito._ import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.client.{TransportClient, TransportResponseHandler} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index eef6aafa624ee..70f40fb26c2f6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} -import org.apache.spark.util.{SerializableBuffer, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, SerializableBuffer} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2869f0fde4c53..370a284d2950f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 15c8de61b8240..56e0f01b3b41b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.scheduler -import org.apache.spark.storage.BlockManagerId +import scala.util.Random -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer import org.roaringbitmap.RoaringBitmap -import scala.util.Random +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index f33324792495b..1dca4bd89fd9e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import org.apache.spark.TaskContext diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 1ae5b030f0832..9f41aca8a1e14 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} import org.scalatest.concurrent.Timeouts -import org.scalatest.time.{Span, Seconds} +import org.scalatest.time.{Seconds, Span} -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} import org.apache.spark.util.Utils /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 7345508bfe995..c461da65bdc43 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -20,22 +20,21 @@ package org.apache.spark.scheduler import java.io.File import java.util.concurrent.TimeoutException +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter -import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} - import org.apache.spark._ -import org.apache.spark.rdd.{RDD, FakeOutputCommitter} +import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.Utils -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps - /** * Unit tests for the output commit coordination functionality. * diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index f20d5be7c0ee0..dc15f5932d6f8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,10 +24,9 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers -import org.apache.spark.SparkException +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 40ebfdde928fe..e5ec44a9f3b6b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.scheduler -import org.mockito.Mockito._ import org.mockito.Matchers.any - +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import org.apache.spark.metrics.source.JvmSource - class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 525ee0d3bdc5a..a4110d2d462de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -20,17 +20,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.util import java.util.Collections -import org.apache.mesos.Protos.Value.Scalar -import org.apache.mesos.Protos._ import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.Matchers import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.Matchers import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter +import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index c4dc560031207..504e5780f3d8a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -26,19 +26,19 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.SchedulerDriver +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.{ArgumentCaptor, Matchers} import org.scalatest.mock.MockitoSugar +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.cluster.ExecutorInfo class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala index f5cef1caaf1ac..98fdc58786ec8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -21,11 +21,10 @@ import java.util.Date import org.scalatest.mock.MockitoSugar +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} - class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index 87f25e7245e1f..3734f1cb408fe 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer -import com.esotericsoftware.kryo.io.{Output, Input} -import org.apache.avro.{SchemaBuilder, Schema} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.generic.GenericData.Record -import org.apache.spark.{SparkFunSuite, SharedSparkContext} +import org.apache.spark.{SharedSparkContext, SparkFunSuite} class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 935a091f14f9b..a0483f6483889 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.serializer -import org.apache.spark.util.Utils - import com.esotericsoftware.kryo.Kryo import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ +import org.apache.spark.util.Utils class KryoSerializerDistributedSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index a9b209ccfc76e..21251f0b93760 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkContext import org.apache.spark.SparkException - class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 9fcc22b608c65..8f9b453a6eeec 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileOutputStream, FileInputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -25,14 +25,13 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} - import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ -import org.apache.spark.util.Utils import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index c1e0a29a34bb1..17037870f7a15 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -17,12 +17,11 @@ package org.apache.spark.serializer -import java.io.{EOFException, OutputStream, InputStream} +import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer import scala.reflect.ClassTag - /** * A serializer implementation that always returns two elements in a deserialization stream. */ diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index bb331bb385df3..e33408b94e2cf 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -23,8 +23,8 @@ import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -32,9 +32,9 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach import org.apache.spark._ -import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} +import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 6e3f500e15dc0..3fd6fb4560465 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,11 +26,11 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 2224a444c7b54..21db3b1c9ffbd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -25,24 +25,23 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import org.mockito.Mockito.{mock, when} import org.mockito.{Matchers => mc} +import org.mockito.Mockito.{mock, when} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ - class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index cc50289c7b3ea..c7074078d8fd2 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.storage import java.io.File -import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.SparkConfWithEnv +import org.apache.spark.util.{SparkConfWithEnv, Utils} /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 0e36d7fda430d..aa22f3ba2b4d1 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source import scala.xml.Node @@ -26,16 +26,16 @@ import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler import org.json4s._ import org.json4s.jackson.JsonMethods -import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.openqa.selenium.{By, WebDriver} +import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ import org.w3c.css.sac.CSSParseException -import org.apache.spark.LocalSparkContext._ import org.apache.spark._ +import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 8f9502b5673d1..2d28b67ef23fa 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -26,8 +26,8 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ class UISuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index 86b078851851f..3fb78da0c7476 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.ui.scope import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SparkListenerStageSubmitted -import org.apache.spark.scheduler.SparkListenerStageCompleted -import org.apache.spark.scheduler.SparkListenerJobStart /** * Tests that this listener populates and cleans up its data structures properly. diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 37e2670de9685..4b838a8ab1335 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.storage import org.scalatest.BeforeAndAfter + import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 5e745e0a95769..932704c1a3659 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.util import java.io.NotSerializableException -import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} +import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 2b76ae1f8a24b..98d1b28d5a167 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -22,13 +22,12 @@ import java.io._ import scala.collection.mutable.HashSet import scala.reflect._ -import org.scalatest.BeforeAndAfter - import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.scalatest.BeforeAndAfter import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} +import org.apache.spark.util.logging.{FileAppender, RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy} class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 1939ce5c743b0..6566400e63799 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.util import java.util.Properties -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.shuffle.MetadataFetchFailedException - import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -30,6 +27,8 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage._ class JsonProtocolSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index fbe7b956682d5..49088aa0a53b5 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.SparkFunSuite diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 92ae038967528..6652a41b6990b 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.duration._ import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ import scala.util.Random import org.scalatest.concurrent.Eventually._ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 7de995af512db..bc926c280c7cd 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -33,8 +33,9 @@ import com.google.common.io.Files import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.network.util.ByteUnit + import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.ByteUnit class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index d7b2d07a40052..a62adf1c2c543 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.MemoryTestingUtils - import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} - class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { import TestUtils.{assertNotSpilled, assertSpilled} diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 0326ed70b5edb..c12f78447197c 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection.unsafe.sort import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks + import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index d6af0aebde733..791491daf0817 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.util.random import java.util.Random + import scala.collection.mutable.ArrayBuffer -import org.apache.commons.math3.distribution.PoissonDistribution +import org.apache.commons.math3.distribution.PoissonDistribution import org.scalatest.Matchers import org.apache.spark.SparkFunSuite diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index a5b50fce5c0a9..853503bbc2bba 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.util.random -import org.scalatest.Matchers +import scala.language.reflectiveCalls import org.apache.commons.math3.stat.inference.ChiSquareTest +import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils.times -import scala.language.reflectiveCalls - class XORShiftRandomSuite extends SparkFunSuite with Matchers { private def fixture = new { diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index c503c4a13b482..f73231fc80a08 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -30,8 +30,8 @@ import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.util.DockerUtils import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.DockerUtils abstract class DatabaseOnDocker { /** diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 6eb6b3391a4a4..559dc1fed163a 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import java.util.Properties import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Literal, If} +import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.tags.DockerTest @DockerTest diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala index 87271776d8564..fda377e032350 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.net.{Inet4Address, NetworkInterface, InetAddress} +import java.net.{Inet4Address, InetAddress, NetworkInterface} import scala.collection.JavaConverters._ import scala.sys.process._ diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 941fde45cd7b7..7f6cecf9cd18d 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume.sink import java.net.InetSocketAddress +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala index 79077e4a49e1a..57374ef515431 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming import java.io.{IOException, ObjectInputStream} +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} import org.apache.spark.util.Utils -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - /** * This is a output stream just for the testsuites. All the output is collected into a * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index bb951a6ef100d..60db846ffb7a2 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import scala.collection.JavaConverters._ -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps @@ -30,8 +30,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestOutputStream} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 02225d5aa7cc5..655b161734d94 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.streaming.kafka import java.io.File import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.rate.RateEstimator - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -38,7 +35,9 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils class DirectKafkaStreamSuite diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index f52a738afd65b..5e539c1d790cc 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming.kafka import scala.util.Random -import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata +import kafka.serializer.StringDecoder import org.scalatest.BeforeAndAfterAll import org.apache.spark._ diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala index 1618e2c088b70..26c6dc45d5115 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -27,8 +27,8 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils /** * Share codes for Scala and Python unit tests diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index d9acb568879fe..7e5fc0cbb9b30 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.streaming.twitter - import org.scalatest.BeforeAndAfter import twitter4j.Status -import twitter4j.auth.{NullAuthorization, Authorization} +import twitter4j.auth.{Authorization, NullAuthorization} import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 7435647c6d9ee..a73dfd219ea43 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -21,11 +21,10 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.graphx._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.graphx._ - class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index 1203f8959f506..0fb8451fdcab1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.graphx.impl import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.graphx._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.graphx._ - class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 1087afb0cdf79..ff0d0ff771042 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -23,7 +23,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 5ea71c5317b7a..d7983f92a3483 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,9 +21,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 932d331b472b9..0d4e00668ddb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 9f6618b929296..f372ec58269e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 70892dc57170a..dfdc5792c6dbc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 3a4f6d235aa6c..722f1abde4359 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{SparkContext, SparkFunSuite} class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 1eae125a524ef..28631cef79431 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature - import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 749bfac747826..5d199ca9b51b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -25,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 9c1c00f41ab1d..f7de7c1e93fb2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 8acc3369c489c..94191e5df383b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.StructType class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index d561bbbb25529..a73b565125668 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 4e2d0e93bd412..a808177cb9bf0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -25,8 +25,7 @@ import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericA import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame} - +import org.apache.spark.sql.{DataFrame, SQLContext} private[ml] object TreeTests extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2c3fb84160dcb..ff0d8f5568279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,7 +25,6 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} - +import org.apache.spark.util.Utils class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 0b39af5543e93..13165f67014c5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -26,8 +26,7 @@ import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} - +import org.apache.spark.sql.{DataFrame, Row} class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2f3e703f4c252..273c882c2a47f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index d281084f913c0..56545de14bd30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.ml.{Pipeline, Estimator, Model} -import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} +import org.apache.spark.ml.{Estimator, Model, Pipeline} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.{ParamPair, ParamMap} +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 84d06b43d6224..0aa774b66078e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,7 +22,7 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 59944416d96a6..0eb839f20c003 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.mllib.api.python import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors} import org.apache.spark.mllib.recommendation.Rating +import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index d7b291d5a6330..bf98bf2f5fde5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.dstream.DStream class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index a72723eb00daf..fb3bd3f412f81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 37fb69d68f6be..faef60e084cc1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.{ArrayList => JArrayList} -import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} +import breeze.linalg.{argmax, argtopk, max, DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index c0924a213a844..77ec49d005398 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 21163633051e5..5c938a61ed990 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 6ab2fa6770123..b4e26b2aeb3cf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 37d01e2876695..e74ecc16ee9f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext - import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 96e5ffef7a131..80da03cc2efeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.linalg import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.util.TestingUtils._ class BLASSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index dc04258e41d27..de2c3c13bd923 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.linalg -import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} +import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM} import org.apache.spark.SparkFunSuite diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index f895e2a8e4afb..832ccc0aacf8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} +import breeze.linalg.{squaredDistance => breezeSquaredDistance, DenseMatrix => BDM} import org.json4s.jackson.JsonMethods.{parse => parseJson} import org.apache.spark.{Logging, SparkException, SparkFunSuite} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index b8eb10305801c..d91ba8a6fdb72 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -22,7 +22,7 @@ import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f3728cd036a3f..37d75103d18d2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 6de6cf2fa8634..5b7ccb90158b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrices, Vectors} class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 0ff901ddc4979..2dff52c601d81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -21,11 +21,11 @@ import java.util.Arrays import scala.util.Random +import breeze.linalg.{norm => brzNorm, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.abs -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 36ac7d267243d..1c9b7c78e5b8d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext, MLUtils} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 413db2000d6d7..0b4c7eb302d40 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} +import org.apache.spark.mllib.rdd.{RandomRDD, RandomRDDPartition} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 10f5a2be48f7c..56231429859ee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ +import org.apache.spark.mllib.util.MLlibTestSparkContext class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index ac93733bab5f5..0e931fca6cf07 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.util.MLlibTestSparkContext class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 39537e7bb4c72..d96103d01e4ab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index f88a1c33c9f7c..0694079b9df9e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 7a781fee634c8..8fb8886645cde 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -23,7 +23,7 @@ import org.jblas.DoubleMatrix import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index 3c657c8cfe743..1142102bb040e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, - WelchTTest, BinarySample} +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest, StreamingTestResult, + StudentTTest, WelchTTest} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 6e7a003475458..669d44223d713 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat.distribution import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{ Vectors, Matrices } +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bf8fe1acac2fe..a9c935bd42445 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 3d3f80063f904..1cc8f342021a0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree +import scala.collection.mutable + import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.util.StatCounter -import scala.collection.mutable - object EnsembleTestHelper { /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 6fc9e8df621df..acb3b953b53be 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.tree import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils - /** * Test suite for [[GradientBoostedTrees]]. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 4f73b0809dca4..9b2d023bbf738 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import org.scalatest.{Suite, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.spark.{SparkConf, SparkContext} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 352193a67860c..6de9aaf94f1b2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.util -import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.scalatest.exceptions.TestFailedException +import org.apache.spark.mllib.linalg.{Matrix, Vector} + object TestingUtils { val ABS_TOL_MSG = " using absolute tolerance" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 8f475f30249d6..44c39704e5b92 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.util +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -import org.scalatest.exceptions.TestFailedException class TestingUtilsSuite extends SparkFunSuite { diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 05bf7a3aaefbf..ce3f51bd72dd8 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -30,14 +30,14 @@ import scala.language.implicitConversions import scala.language.postfixOps import com.google.common.io.Files +import org.mockito.Matchers.anyString +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.mockito.Matchers.anyString -import org.mockito.Mockito._ import org.apache.spark._ import org.apache.spark.rpc.RpcEnv diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 72624e7cbc112..1e7118144f2ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import org.scalatest.{FunSpec, Matchers} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types._ -import org.scalatest.{Matchers, FunSpec} class RowTest extends FunSpec with Matchers { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 827f7ce692712..b47b8adfe5d55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.physical._ - /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ class DistributionSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 9ff893b84775b..b0884f528742f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command} +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project} import org.apache.spark.unsafe.types.CalendarInterval private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 12079992b5b84..fc35959f20547 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.sql.catalyst.analysis +import scala.beans.{BeanInfo, BeanProperty} + import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Count, Sum, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Sum} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import scala.beans.{BeanProperty, BeanInfo} - @BeanInfo private[sql] case class GroupableData(@BeanProperty data: Int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 23861ed15da61..af214b7af0624 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} trait AnalysisTest extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fed591fd90a9a..39c8f56c1bca6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union} import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index f3df716a57824..0521ed848c793 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{LongType, TypeCollection, StringType} +import org.apache.spark.sql.types.{LongType, StringType, TypeCollection} class ExpressionTypeCheckingSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 142915056f451..58d808c55860d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.PlanTest - import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 98f29e53df9f4..88c558d80a79a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.encoders -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.util.Arrays import java.util.concurrent.ConcurrentMap + import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag @@ -27,10 +28,10 @@ import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.{StructType, ArrayType} +import org.apache.spark.sql.types.{ArrayType, StructType} case class RepeatedStruct(s: Seq[PrimitiveData]) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index b17f8d5ec70af..932511134c63a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index c99a4ac9645ac..43af3592070fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} -import java.util.{TimeZone, Calendar} +import java.sql.{Date, Timestamp} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 0df673bb9fa02..4029da5925580 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ - class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("if") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index 511f0307901df..a8f758d625a02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} - +import org.apache.spark.sql.types.{Decimal, DecimalType, LongType} class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 9175568f43a4e..64161bebdcbe8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index 7ad8657bde128..b190d3a00dfb8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.math._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index 0d329497758c6..83838294a9918 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.util.Random -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, BoundReference} -import org.apache.spark.sql.types.{DataType, IntegerType} - import scala.collection.mutable + import org.scalatest.Assertions._ +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{BoundReference, MutableRow, SpecificMutableRow} +import org.apache.spark.sql.types.{DataType, IntegerType} + class HyperLogLogPlusPlusSuite extends SparkFunSuite { /** Create a HLL++ instance and an input and output buffer. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 2d080b95b1292..37148a226f293 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index a0c71d83d7e39..000a3b7ecb7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 9bf61ae091786..81f3928035619 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types.StringType class ColumnPruningSuite extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 06c592f4905a3..9fe2b2d1f48ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class CombiningLimitsSuite extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 8aaefa84937c2..48f9ac77b74c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue} +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ -// For implicit conversions -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index fba4c5ca77d64..b998636909a7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.IntegerType class FilterPushdownSuite extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index b3df487c84dc8..741bc113cfcda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - class LikeSimplificationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 48cab01ac1004..3e384e473e5f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -18,17 +18,17 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet + import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ -// For implicit conversions -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala index 1aa89991cc698..85b6530481b03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Rand import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor - class ProjectCollapsingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 1595ad9327423..a498b463a69e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class SetOperationPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index 6b1e53cd42b24..41455221cfdc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - class SimplifyCaseConversionExpressionsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2efee1fc54706..f9874088b5884 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.util._ /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 62d5f6ac74885..fb4f34d059b56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 965bdb1515e55..6a188e7e55126 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{IntegerType, StringType, NullType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType} case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { override def children: Seq[Expression] = optKey.toSeq diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 0ce5a2fb69505..d5f1c4d74efcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -22,8 +22,8 @@ import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index 4030a1b1df358..a0c1d97bfc3a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{MetadataBuilder, Metadata} +import org.apache.spark.sql.types.{Metadata, MetadataBuilder} class MetadataSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 50683947da224..e1675c95907af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.types -import org.apache.spark.SparkFunSuite +import scala.language.postfixOps + import org.scalatest.PrivateMethodTester -import scala.language.postfixOps +import org.apache.spark.SparkFunSuite class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 6b735bcf16104..89b9a687682d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql - -import org.apache.spark.sql.execution.Exchange -import org.apache.spark.sql.execution.PhysicalRDD - import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators +import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} -import org.apache.spark.storage.{StorageLevel, RDDBlockId} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} private case class BigData(s: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 53a9788024ba4..076db0c08dee0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e8fa663363731..ade1391ecd74a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} +import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.types._ class DataFrameSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 3917b9762ba63..09a56f6f3ae28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DataType, LongType, StructType} class DataFrameWindowSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index c6d2bf07b2803..3258f3782d8cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql - import scala.language.postfixOps -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.functions._ import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext /** An `Aggregator` that adds up any numeric type returned by the given function. */ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c19b5a4d98a85..53b5f45c2d4a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import scala.language.postfixOps @@ -26,7 +26,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} - class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index a61c3aa48a73f..f7aa3b747ae5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 359a1e7f8424a..2c4b4f80ff9ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 5688f46e5e3d4..3d7c576965fc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -import org.apache.spark.sql.catalyst.TableIdentifier class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 162c0b56c6e11..6a375a33bfcf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql -import org.apache.spark._ import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ + class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalActiveSQLContext: Option[SQLContext] = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 815372f19233b..0e60573dc6b2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -21,15 +21,15 @@ import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.{LogicalRDD, Queryable} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.{LogicalRDD, Queryable} abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 3ba14d7602a62..4552eb6ce00a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 3d2bd236ceead..43300cd635c05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} - +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 72845711adddd..5de0979606b88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 2a1117318ad14..6800a8ddf6e3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} - import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet - @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 2715179e8500c..35ff1c40fe6ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql._ -import org.apache.spark.{SparkFunSuite, SparkContext, SparkConf, MapOutputStatistics} class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index e7a08481cfa80..6f10e4b80577a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} class GroupedIteratorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2fb439f50117a..858e289c2716e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.joins.{SortMergeJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index af971dfc6faec..6259453da26a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{RandomDataGenerator, Row} - /** * Test sorting. Many of the test cases generate random data and compares the sorted result with one diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 5a8406789ab81..9c258cb31f460 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.execution -import scala.util.control.NonFatal import scala.collection.mutable -import scala.util.{Try, Random} +import scala.util.{Random, Try} +import scala.util.control.NonFatal import org.scalatest.Matchers -import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 29027a664b4b4..95c9550aebb0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 09e258299de5a..9f09eb4429c12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.execution -import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ -import org.apache.spark._ - +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.Utils /** * used to test close InputStream in UnsafeRowSerializer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 706ff1f998501..9ca8c4d2ed2b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ -import org.apache.spark.{Logging, SparkFunSuite} - class ColumnTypeSuite extends SparkFunSuite with Logging { private val DEFAULT_BUFFER_SIZE = 512 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 9cae65ef6f5dc..97cba1e349e8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index 35dc9a276cef7..dc22d3e8e4d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 93be3e16a5ed9..cdd4551d64b50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index ccbddef0fad3a..f67e9c7dae278 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index baa258ad26152..b3b6b7df0c1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 0835bd123049b..4217c81ff3e24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapA import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index f42f173b2a863..587aa5fd30d2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet +import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 7f82cce0a122d..ab48e971b507a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.parquet.column.{Encoding, ParquetProperties} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.Utils - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -29,8 +25,9 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 71e9034d97792..0feb945fbb79a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index f777e973052d3..0bc64404f1648 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index fdd7697c91f5e..449fcc860fac9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File -import org.apache.parquet.schema.MessageType - import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -28,12 +26,13 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} +import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.{DataFrame, SaveMode, SQLConf} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 02c416af50cd7..f95272530d585 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources.text +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.util.Utils - class TextSuite extends QueryTest with SharedSQLContext { test("reading text file") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 5b2998c3c76d3..58581d71e1bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,8 +22,8 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} /** * Test various broadcast join operators. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 2ec17146476fe..42fadaa8e2215 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 9c80714a9af4a..3d3e9a7b90928 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 3afd762942bcf..9c86084f9b8a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index c30327185e169..eb70747926fe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -22,8 +22,8 @@ import org.mockito.Mockito.{mock, when} import org.apache.spark.broadcast.TorrentBroadcast import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression} -import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedMutableProjection, UnsafeProjection} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} class HashJoinNodeSuite extends LocalNodeTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 615c417093612..1a485f967dd38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.execution.local import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.types.{IntegerType, StringType} - class LocalNodeTest extends SparkFunSuite { protected val conf: SQLConf = new SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 11a6ce91116f4..eef3c1f3e34d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.ui import java.util.Properties -import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 73e548e00f588..1fa22e2933318 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources._ import org.apache.spark.util.Utils class JDBCSuite extends SparkFunSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 3eaa817f9c0b0..27b02d6e1ab36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 10d261368993d..e055da9e8a39a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLConf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala index 152c9c8459de9..df530d8587ef7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import scala.sys.process.BasicIO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index e87da1527c4d2..7df344edb4edd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.test import java.io.File import java.util.UUID -import scala.util.Try import scala.language.implicitConversions +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index fcf039916913a..ab31d45a79a2e 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -22,15 +22,15 @@ import java.sql.Timestamp import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import scala.concurrent.duration._ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfterAll -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.util.Utils /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ebb2575416b72..e598284ab22f8 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -23,8 +23,8 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{future, Await, ExecutionContext, Promise} import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Promise, future} import scala.io.Source import scala.util.{Random, Try} @@ -40,10 +40,10 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.{ThreadUtils, Utils} -import org.apache.spark.{Logging, SparkFunSuite} object TestData { def getTestDataFilePath(name: String): URL = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 99478e82d419f..9b37dd1103764 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index cf737836939f9..400f7f3708cf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -21,10 +21,9 @@ import scala.util.Try import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{AnalysisException, QueryTest} - class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { import hiveContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf765265..35e433964da91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index f621367eb553b..63cf5030ab8b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 8bb9058cd74ef..3b867bbfa1817 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index d63f3d3996523..14a83d53904a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row, SaveMode, SQLConf} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{SQLConf, QueryTest, Row, SaveMode} class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { import hiveContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 5596ec6882ea2..7841ffe5e03d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index a330362b4e1d1..f4a1a17422483 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.logical.Generate import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.plans.logical.Generate import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable} - class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { private def extractTableDesc(sql: String): (HiveTable, Boolean) = { HiveQl.createPlan(sql).collect { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 2edc8f932c4a1..8932ce9503a3d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.{QueryTest, SQLContext} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 81ee9ba71beb6..da7303c791064 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 183aca29cf98d..a94f7053c39ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.hive.test.TestHiveSingleton class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { import hiveContext._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f74eb1500b989..e22dac3bc9e87 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.hive -import java.io.{IOException, File} +import java.io.{File, IOException} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.util.Utils /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index f16c257ab5ab4..c2c896e5f61bb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { private lazy val df = sqlContext.range(10).coalesce(1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 49aab85cf1aaf..4a73153a80ee8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -21,8 +21,8 @@ import java.sql.Timestamp import org.apache.hadoop.hive.conf.HiveConf -import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.hive.test.TestHiveSingleton class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index f542a5a02508c..f49ee690ac041 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.util.Utils -import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f775f1e955876..78f74cdc19ddb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag -import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 502b240f3650f..ff10a251f3b45 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.hadoop.util.VersionInfo -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.IntegerType import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index e38d1eb5779fe..f5cd73d45ed75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.BeforeAndAfterAll class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4455430aa727a..d7e8ebc8d312f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.test.TestHive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index a7b7ad0093915..b7ef5d1db7297 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index 0d4c7f86b315a..9bdc24162b73d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index acd1130f2762c..98e22c2e2c1b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -26,14 +26,14 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter +import org.apache.spark.{SparkException, SparkFiles} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} -import org.apache.spark.{SparkException, SparkFiles} +import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 2209fc2f30a3c..b0c0dcbe5c25c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ - import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 9deb1a6db15ad..c5ff8825abd7f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -17,23 +17,23 @@ package org.apache.spark.sql.hive.execution -import java.io.{PrintWriter, File, DataInput, DataOutput} +import java.io.{DataInput, DataOutput, File, PrintWriter} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.UDAFPercentile -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDFOPAnd, GenericUDTFExplode, GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF, GenericUDFOPAnd, GenericUDTFExplode} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.Writable -import org.apache.spark.sql.test.SQLTestUtils + import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils - case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da02..bf65325d54fe2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -22,13 +22,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{TableIdentifier, DefaultParserDialect} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} +import org.apache.spark.sql.catalyst.{DefaultParserDialect, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 7cfdb886b585d..8f163f27c94cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryNode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 7b61b635bdb48..5afc7e77ab775 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.orc import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, PredicateLeaf} +import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.spark.sql.{Column, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.dsl.expressions._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 52e09f9496f05..6161412a49775 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -22,8 +22,8 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 47e73b4006fa5..27ea3e8041651 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -21,9 +21,9 @@ import java.io.File import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{QueryTest, Row} case class OrcData(intField: Int, stringField: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 88a0ed511749f..637c10611afc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 905eb7a3925b2..2ceb836681901 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index dc0531a6d4bc5..64c61a5092540 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils - class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index b554d135e4b5c..058c101eebb04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -21,14 +21,13 @@ import java.io.File import org.apache.hadoop.fs.Path -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{execution, Column, DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, PredicateHelper} import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, Row, execution} import org.apache.spark.util.Utils class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index e10d21d5e3682..9fc437bf8815a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -22,14 +22,14 @@ import java.text.NumberFormat import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{sources, Row, SQLContext} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLContext, sources} /** * A simple example [[HadoopFsRelationProvider]]. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index 4c12ecc399e41..94f1bcebc3a39 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -21,11 +21,11 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.ManualClock -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} /** * Tests whether scope information is passed from DStream operations to RDDs correctly. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index e82c2fa4e72ad..6a0b0a1d47bc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFunSuite, Logging} +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 3a3176b91b1ee..2e231601c3953 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -17,30 +17,30 @@ package org.apache.spark.streaming -import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{Socket, SocketException, ServerSocket} +import java.io.{BufferedWriter, File, OutputStreamWriter} +import java.net.{ServerSocket, Socket, SocketException} import java.nio.charset.Charset -import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue} +import java.util.concurrent.{ArrayBlockingQueue, CountDownLatch, Executors, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer, SynchronizedQueue} import scala.language.postfixOps import com.google.common.io.Files -import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} -import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted} +import org.apache.spark.util.{ManualClock, Utils} class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 62d75a9e0e7aa..2984fd2b298dc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -22,12 +22,12 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag -import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.PrivateMethodTester._ +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class MapWithStateSuite extends SparkFunSuite with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 0e64b57e0ffd8..4e56dfbd424b0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -17,23 +17,21 @@ package org.apache.spark.streaming -import org.apache.spark.Logging -import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.util.Utils - -import scala.util.Random -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - import java.io.{File, IOException} import java.nio.charset.Charset import java.util.UUID -import com.google.common.io.Files +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag +import scala.util.Random -import org.apache.hadoop.fs.Path +import com.google.common.io.Files import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils private[streaming] object MasterFailureTest extends Logging { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c17fb7238151b..dd16fc3ecaf5d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -39,8 +39,6 @@ import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ import org.apache.spark.util.{ManualClock, Utils} -import WriteAheadLogBasedBlockHandler._ -import WriteAheadLogSuite._ class ReceivedBlockHandlerSuite extends SparkFunSuite @@ -48,6 +46,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { + import WriteAheadLogBasedBlockHandler._ + import WriteAheadLogSuite._ + val conf = new SparkConf() .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") .set("spark.app.id", "streaming-test") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala index e6d8fbd4d7c57..a4871b460eb4d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -28,7 +29,6 @@ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} -import org.apache.spark.{SparkConf, SparkEnv} class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 01279b34f73dc..917232c9cdd63 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,8 +24,8 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 04cd5bdc26be2..628a5082074db 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -18,20 +18,20 @@ package org.apache.spark.streaming import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} -import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.apache.spark.Logging import org.apache.spark.SparkException import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ -import org.scalatest.Matchers -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ -import org.apache.spark.Logging - class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index be0f4636a6cb8..54eff2b214290 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import java.io.{ObjectInputStream, IOException} +import java.io.{IOException, ObjectInputStream} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.SynchronizedBuffer @@ -25,13 +25,13 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.scalatest.BeforeAndAfter -import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration +import org.scalatest.time.{Seconds => ScalaTestSeconds, Span} import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream, InputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index c39ad05f41520..c7d085ec0799b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.DStream class WindowOperationsSuite extends TestSuiteBase { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index 1640b9e6b7a6c..5b13fd6ad611a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.util.OpenHashMapBasedStateMap import org.apache.spark.streaming.{State, Time} +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap import org.apache.spark.util.Utils class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 43833c4361473..79ac833c1846b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -23,10 +23,10 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 92ad9fe52b777..f5ec0ff60aa27 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -22,13 +22,13 @@ import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ -import org.scalatest.concurrent.Timeouts._ import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.util.ManualClock -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index f5248acf712b9..a7e365649d3e8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index a670c7d638192..b5d6a24ce8dd6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import java.util.concurrent.{CountDownLatch, RejectedExecutionException, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{RejectedExecutionException, TimeUnit, CountDownLatch, ThreadPoolExecutor} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -31,17 +31,16 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.mockito.ArgumentCaptor -import org.mockito.Matchers.{eq => meq} -import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar -import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{CompletionIterator, ThreadUtils, ManualClock, Utils} import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{CompletionIterator, ManualClock, ThreadUtils, Utils} /** Common tests for WriteAheadLogs that we would like to test with different configurations. */ abstract class CommonWriteAheadLogTests( diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala index bfc5b0cf60fb1..2a41177a5e638 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils class WriteAheadLogUtilsSuite extends SparkFunSuite { diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 12a002befa0ac..b3bbd68827b0e 100644 --- a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.types import org.apache.commons.lang3.StringUtils - import org.scalacheck.{Arbitrary, Gen} import org.scalatest.prop.GeneratorDrivenPropertyChecks // scalastyle:off diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 804dfecde7867..4cffbb2e9b96d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.mock.MockitoSugar -import org.mockito.Mockito.when +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus @@ -28,16 +28,14 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.yarn.api.records.LocalResource -import org.apache.hadoop.yarn.api.records.LocalResourceVisibility import org.apache.hadoop.yarn.api.records.LocalResourceType -import org.apache.hadoop.yarn.util.{Records, ConverterUtils} - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Map +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.mockito.Mockito.when +import org.scalatest.mock.MockitoSugar import org.apache.spark.SparkFunSuite - class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 7709c2f6e4f5f..998bd1377d562 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -41,7 +41,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.{Utils, ResetSystemProperties} +import org.apache.spark.util.{ResetSystemProperties, Utils} class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll with ResetSystemProperties { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 57edbd67253d4..1dd2f93bb708b 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -25,15 +25,12 @@ import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.scalatest.{BeforeAndAfterEach, Matchers} - -import org.scalatest.{BeforeAndAfterEach, Matchers} import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.apache.spark.{SecurityManager, SparkFunSuite} -import org.apache.spark.SparkConf -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index c2861c9d7fbc7..d3acaf229cc85 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -27,15 +27,13 @@ import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers -import org.apache.hadoop.yarn.api.records.ApplicationAccessType - import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.{Utils, ResetSystemProperties} - +import org.apache.spark.util.{ResetSystemProperties, Utils} class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging with ResetSystemProperties { diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala index 94bf579dc8247..d6902c7bb0739 100644 --- a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.network.shuffle -import java.io.{IOException, File} +import java.io.{File, IOException} import java.util.concurrent.ConcurrentMap import org.apache.hadoop.yarn.api.records.ApplicationId From d1fea41363c175a67b97cb7b3fe89f9043708739 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 6 Jan 2016 12:05:41 +0530 Subject: [PATCH 1455/2089] [SPARK-12393][SPARKR] Add read.text and write.text for SparkR Add ```read.text``` and ```write.text``` for SparkR. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10348 from yanboliang/spark-12393. --- R/pkg/NAMESPACE | 4 +++- R/pkg/R/DataFrame.R | 28 +++++++++++++++++++++++ R/pkg/R/SQLContext.R | 26 +++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 21 +++++++++++++++++ 5 files changed, 82 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ccc01fe169601..beacc39500aaa 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -94,7 +94,8 @@ exportMethods("arrange", "withColumnRenamed", "write.df", "write.json", - "write.parquet") + "write.parquet", + "write.text") exportClasses("Column") @@ -274,6 +275,7 @@ export("as.DataFrame", "parquetFile", "read.df", "read.parquet", + "read.text", "sql", "table", "tableNames", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c126f9efb475a..3bf5bc924f7dc 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -664,6 +664,34 @@ setMethod("saveAsParquetFile", write.parquet(x, path) }) +#' write.text +#' +#' Saves the content of the DataFrame in a text file at the specified path. +#' The DataFrame must have only one column of string type with the name "value". +#' Each row becomes a new line in the output file. +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.text +#' @name write.text +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.txt" +#' df <- read.text(sqlContext, path) +#' write.text(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.text", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "text", path)) + }) + #' Distinct #' #' Return a new DataFrame containing the distinct rows in this DataFrame. diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ccc683d86a3e5..99679b4a774d3 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -295,6 +295,32 @@ parquetFile <- function(sqlContext, ...) { read.parquet(sqlContext, unlist(list(...))) } +#' Create a DataFrame from a text file. +#' +#' Loads a text file and returns a DataFrame with a single string column named "value". +#' Each line in the text file is a new row in the resulting DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @rdname read.text +#' @name read.text +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.txt" +#' df <- read.text(sqlContext, path) +#' } +read.text <- function(sqlContext, path) { + # Allow the user to have a more flexible definiton of the text file path + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "text", paths) + dataFrame(sdf) +} + #' SQL Query #' #' Executes a SQL query using Spark, returning the result as a DataFrame. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 62be2ddc8f522..ba6861709754d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -549,6 +549,10 @@ setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +#' @rdname write.text +#' @export +setGeneric("write.text", function(x, path) { standardGeneric("write.text") }) + #' @rdname schema #' @export setGeneric("schema", function(x) { standardGeneric("schema") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ebe8faa34cf7d..eaf60beda3473 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1497,6 +1497,27 @@ test_that("read/write Parquet files", { unlink(parquetPath4) }) +test_that("read/write text files", { + # Test write.df and read.df + df <- read.df(sqlContext, jsonPath, "text") + expect_is(df, "DataFrame") + expect_equal(colnames(df), c("value")) + expect_equal(count(df), 3) + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + write.df(df, textPath, "text", mode="overwrite") + + # Test write.text and read.text + textPath2 <- tempfile(pattern = "textPath2", fileext = ".txt") + write.text(df, textPath2) + df2 <- read.text(sqlContext, c(textPath, textPath2)) + expect_is(df2, "DataFrame") + expect_equal(colnames(df2), c("value")) + expect_equal(count(df2), count(df) * 2) + + unlink(textPath) + unlink(textPath2) +}) + test_that("describe() and summarize() on a DataFrame", { df <- read.json(sqlContext, jsonPath) stats <- describe(df, "age") From b2467b381096804b862990d9ecda554f67e07ee1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Jan 2016 00:40:14 -0800 Subject: [PATCH 1456/2089] [SPARK-12578][SQL] Distinct should not be silently ignored when used in an aggregate function with OVER clause JIRA: https://issues.apache.org/jira/browse/SPARK-12578 Slightly update to Hive parser. We should keep the distinct keyword when used in an aggregate function with OVER clause. So the CheckAnalysis will detect it and throw exception later. Author: Liang-Chi Hsieh Closes #10557 from viirya/keep-distinct-hivesql. --- .../spark/sql/parser/IdentifiersParser.g | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g index 5c3d7ef866240..9f1e168374f01 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g +++ b/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g @@ -195,7 +195,7 @@ function RPAREN (KW_OVER ws=window_specification)? -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) - -> ^(TOK_FUNCTIONDI functionName (selectExpression+)?) + -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?) ; functionName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bf65325d54fe2..593fac2c32817 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -915,6 +915,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: distinct should not be silently ignored") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + val e = intercept[AnalysisException] { + sql( + """ + |select month, area, product, sum(distinct product + 1) over (partition by 1 order by 2) + |from windowData + """.stripMargin) + } + assert(e.getMessage.contains("Distinct window functions are not supported")) + } + test("window function: expressions in arguments of a window functions") { val data = Seq( WindowData(1, "a", 5), From 5d871ea43efdde59e05896a50d57021388412d30 Mon Sep 17 00:00:00 2001 From: QiangCai Date: Wed, 6 Jan 2016 18:13:07 +0900 Subject: [PATCH 1457/2089] [SPARK-12340][SQL] fix Int overflow in the SparkPlan.executeTake, RDD.take and AsyncRDDActions.takeAsync I have closed pull request https://github.com/apache/spark/pull/10487. And I create this pull request to resolve the problem. spark jira https://issues.apache.org/jira/browse/SPARK-12340 Author: QiangCai Closes #10562 from QiangCai/bugfix. --- .../scala/org/apache/spark/rdd/AsyncRDDActions.scala | 12 ++++++------ core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 ++++---- .../org/apache/spark/sql/execution/SparkPlan.scala | 8 ++++---- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++++++++++ 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ec48925823a02..94719a4572ef6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val localProperties = self.context.getLocalProperties // Cached thread pool to handle aggregation of subtasks. implicit val executionContext = AsyncRDDActions.futureExecutionContext - val results = new ArrayBuffer[T](num) + val results = new ArrayBuffer[T] val totalParts = self.partitions.length /* @@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi This implementation is non-blocking, asynchronously handling the results of each job and triggering the next job using callbacks on futures. */ - def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = if (results.size >= num || partsScanned >= totalParts) { Future.successful(results.toSeq) } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate it @@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } val left = num - results.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt val buf = new Array[Array[T]](p.size) self.context.setCallSite(callSite) @@ -111,11 +111,11 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi Unit) job.flatMap {_ => buf.foreach(results ++= _.take(num - results.size)) - continue(partsScanned + numPartsToTry) + continue(partsScanned + p.size) } } - new ComplexFutureAction[Seq[T]](continue(0)(_)) + new ComplexFutureAction[Seq[T]](continue(0L)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d6eac7888d5fd..e25657cc109be 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1190,11 +1190,11 @@ abstract class RDD[T: ClassTag]( } else { val buf = new ArrayBuffer[T] val totalParts = this.partitions.length - var partsScanned = 0 + var partsScanned = 0L while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1209,11 +1209,11 @@ abstract class RDD[T: ClassTag]( } val left = num - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) - partsScanned += numPartsToTry + partsScanned += p.size } buf.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index f20f32aaced2e..21a6fba9078df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -165,11 +165,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length - var partsScanned = 0 + var partsScanned = 0L while (buf.size < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the first iteration, just try all partitions next. // Otherwise, interpolate the number of partitions we need to try, but overestimate it @@ -183,13 +183,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = n - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt val sc = sqlContext.sparkContext val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(n - buf.size)) - partsScanned += numPartsToTry + partsScanned += p.size } buf.toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5de0979606b88..bd987ae1bb03a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2067,4 +2067,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } + + test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") { + val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 ) + rdd.toDF("key").registerTempTable("spark12340") + checkAnswer( + sql("select key from spark12340 limit 2147483638"), + Row(1) :: Row(2) :: Row(3) :: Nil + ) + assert(rdd.take(2147483638).size === 3) + assert(rdd.takeAsync(2147483638).get.size === 3) + } + } From 94c202c7d2c613d00a0a56f870a0cddbff599be3 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 6 Jan 2016 10:19:41 -0800 Subject: [PATCH 1458/2089] [SPARK-12665][CORE][GRAPHX] Remove Vector, VectorSuite and GraphKryoRegistrator which are deprecated and no longer used Whole code of Vector.scala, VectorSuite.scala and GraphKryoRegistrator.scala are no longer used so it's time to remove them in Spark 2.0. Author: Kousuke Saruta Closes #10613 from sarutak/SPARK-12665. --- .../scala/org/apache/spark/util/Vector.scala | 159 ------------------ .../org/apache/spark/util/VectorSuite.scala | 45 ----- .../spark/graphx/GraphKryoRegistrator.scala | 49 ------ project/MimaExcludes.scala | 7 + 4 files changed, 7 insertions(+), 253 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/Vector.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/VectorSuite.scala delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala deleted file mode 100644 index 6b3fa8491904a..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.language.implicitConversions -import scala.util.Random - -import org.apache.spark.util.random.XORShiftRandom - -@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") -class Vector(val elements: Array[Double]) extends Serializable { - def length: Int = elements.length - - def apply(index: Int): Double = elements(index) - - def + (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - Vector(length, i => this(i) + other(i)) - } - - def add(other: Vector): Vector = this + other - - def - (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - Vector(length, i => this(i) - other(i)) - } - - def subtract(other: Vector): Vector = this - other - - def dot(other: Vector): Double = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var ans = 0.0 - var i = 0 - while (i < length) { - ans += this(i) * other(i) - i += 1 - } - ans - } - - /** - * return (this + plus) dot other, but without creating any intermediate storage - * @param plus - * @param other - * @return - */ - def plusDot(plus: Vector, other: Vector): Double = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - if (length != plus.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) + plus(i)) * other(i) - i += 1 - } - ans - } - - def += (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var i = 0 - while (i < length) { - elements(i) += other(i) - i += 1 - } - this - } - - def addInPlace(other: Vector): Vector = this +=other - - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - - def multiply (d: Double): Vector = this * d - - def / (d: Double): Vector = this * (1 / d) - - def divide (d: Double): Vector = this / d - - def unary_- : Vector = this * -1 - - def sum: Double = elements.reduceLeft(_ + _) - - def squaredDist(other: Vector): Double = { - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) - other(i)) * (this(i) - other(i)) - i += 1 - } - ans - } - - def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - - override def toString: String = elements.mkString("(", ", ", ")") -} - -@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") -object Vector { - def apply(elements: Array[Double]): Vector = new Vector(elements) - - def apply(elements: Double*): Vector = new Vector(elements.toArray) - - def apply(length: Int, initializer: Int => Double): Vector = { - val elements: Array[Double] = Array.tabulate(length)(initializer) - new Vector(elements) - } - - def zeros(length: Int): Vector = new Vector(new Array[Double](length)) - - def ones(length: Int): Vector = Vector(length, _ => 1) - - /** - * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers - * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. - */ - def random(length: Int, random: Random = new XORShiftRandom()): Vector = - Vector(length, _ => random.nextDouble()) - - class Multiplier(num: Double) { - def * (vec: Vector): Vector = vec * num - } - - implicit def doubleToMultiplier(num: Double): Multiplier = new Multiplier(num) - - implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector): Vector = t1 + t2 - - def zero(initialValue: Vector): Vector = Vector.zeros(initialValue.length) - } - -} diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala deleted file mode 100644 index 11194cd22a419..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.util.Random - -import org.apache.spark.SparkFunSuite - -/** - * Tests org.apache.spark.util.Vector functionality - */ -@deprecated("suppress compile time deprecation warning", "1.0.0") -class VectorSuite extends SparkFunSuite { - - def verifyVector(vector: Vector, expectedLength: Int): Unit = { - assert(vector.length == expectedLength) - assert(vector.elements.min > 0.0) - assert(vector.elements.max < 1.0) - } - - test("random with default random number generator") { - val vector100 = Vector.random(100) - verifyVector(vector100, 100) - } - - test("random with given random number generator") { - val vector100 = Vector.random(100, new Random(100)) - verifyVector(vector100, 100) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala deleted file mode 100644 index eaa71dad17a04..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx - -import com.esotericsoftware.kryo.Kryo - -import org.apache.spark.graphx.impl._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.util.BoundedPriorityQueue -import org.apache.spark.util.collection.BitSet -import org.apache.spark.util.collection.OpenHashSet - -/** - * Registers GraphX classes with Kryo for improved performance. - */ -@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0") -class GraphKryoRegistrator extends KryoRegistrator { - - def registerClasses(kryo: Kryo) { - kryo.register(classOf[Edge[Object]]) - kryo.register(classOf[(VertexId, Object)]) - kryo.register(classOf[EdgePartition[Object, Object]]) - kryo.register(classOf[BitSet]) - kryo.register(classOf[VertexIdToIndexMap]) - kryo.register(classOf[VertexAttributeBlock[Object]]) - kryo.register(classOf[PartitionStrategy]) - kryo.register(classOf[BoundedPriorityQueue[Object]]) - kryo.register(classOf[EdgeDirection]) - kryo.register(classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]]) - kryo.register(classOf[OpenHashSet[Int]]) - kryo.register(classOf[OpenHashSet[Long]]) - } -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 940fedfa2ab60..43ca4690dc2bb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -112,6 +112,13 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles") + ) ++ + // SPARK-12665 Remove deprecated and unused classes + Seq( + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") ) case v if v.startsWith("1.6") => Seq( From 9061e777fdcd5767718808e325e8953d484aa761 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Wed, 6 Jan 2016 10:37:53 -0800 Subject: [PATCH 1459/2089] [SPARK-11878][SQL] Eliminate distribute by in case group by is present with exactly the same grouping expressi For queries like : select <> from table group by a distribute by a we can eliminate distribute by ; since group by will anyways do a hash partitioning Also applicable when user uses Dataframe API Author: Yash Datta Closes #9858 from saucam/eliminatedistribute. --- .../apache/spark/sql/execution/Exchange.scala | 6 +++ .../spark/sql/execution/PlannerSuite.scala | 39 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7b4161930b7d2..6b100577077c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -467,6 +467,12 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator @ Exchange(partitioning, child, _) => + child.children match { + case Exchange(childPartitioning, baseChild, _)::Nil => + if (childPartitioning.guarantees(partitioning)) child else operator + case _ => operator + } case operator: SparkPlan => ensureDistributionAndOrdering(operator) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 858e289c2716e..03a1b8e11d455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -417,6 +417,45 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = Exchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.size == 2) { + fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") + } + } + + test("EnsureRequirements does not eliminate Exchange with different partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = Exchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.size == 1) { + fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") + } + } + // --------------------------------------------------------------------------------------------- } From 3b29004d2439c03a7d9bfdf7c2edd757d3d8c240 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 6 Jan 2016 10:43:03 -0800 Subject: [PATCH 1460/2089] [SPARK-7675][ML][PYSPARK] sparkml params type conversion From JIRA: Currently, PySpark wrappers for spark.ml Scala classes are brittle when accepting Param types. E.g., Normalizer's "p" param cannot be set to "2" (an integer); it must be set to "2.0" (a float). Fixing this is not trivial since there does not appear to be a natural place to insert the conversion before Python wrappers call Java's Params setter method. A possible fix will be to include a method "_checkType" to PySpark's Param class which checks the type, prints an error if needed, and converts types when relevant (e.g., int to float, or scipy matrix to array). The Java wrapper method which copies params to Scala can call this method when available. This fix instead checks the types at set time since I think failing sooner is better, but I can switch it around to check at copy time if that would be better. So far this only converts int to float and other conversions (like scipymatrix to array) are left for the future. Author: Holden Karau Closes #9581 from holdenk/SPARK-7675-PySpark-sparkml-Params-type-conversion. --- python/pyspark/ml/param/__init__.py | 22 ++- .../ml/param/_shared_params_code_gen.py | 63 ++++---- python/pyspark/ml/param/shared.py | 144 +++++++++--------- python/pyspark/ml/tests.py | 22 +++ 4 files changed, 148 insertions(+), 103 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 35c9b776a3d5e..92ce96aa3c4df 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -32,12 +32,13 @@ class Param(object): .. versionadded:: 1.3.0 """ - def __init__(self, parent, name, doc): + def __init__(self, parent, name, doc, expectedType=None): if not isinstance(parent, Identifiable): raise TypeError("Parent must be an Identifiable but got type %s." % type(parent)) self.parent = parent.uid self.name = str(name) self.doc = str(doc) + self.expectedType = expectedType def __str__(self): return str(self.parent) + "__" + self.name @@ -247,7 +248,24 @@ def _set(self, **kwargs): Sets user-supplied params. """ for param, value in kwargs.items(): - self._paramMap[getattr(self, param)] = value + p = getattr(self, param) + if p.expectedType is None or type(value) == p.expectedType or value is None: + self._paramMap[getattr(self, param)] = value + else: + try: + # Try and do "safe" conversions that don't lose information + if p.expectedType == float: + self._paramMap[getattr(self, param)] = float(value) + # Python 3 unified long & int + elif p.expectedType == int and type(value).__name__ == 'long': + self._paramMap[getattr(self, param)] = value + else: + raise Exception( + "Provided type {0} incompatible with type {1} for param {2}" + .format(type(value), p.expectedType, p)) + except ValueError: + raise Exception(("Failed to convert {0} to type {1} for param {2}" + .format(type(value), p.expectedType, p))) return self def _setDefault(self, **kwargs): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 0528dc1e3a6b9..82855bc4c75ba 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -38,7 +38,7 @@ # python _shared_params_code_gen.py > shared.py -def _gen_param_header(name, doc, defaultValueStr): +def _gen_param_header(name, doc, defaultValueStr, expectedType): """ Generates the header part for shared variables @@ -51,22 +51,26 @@ def _gen_param_header(name, doc, defaultValueStr): """ # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc") + $name = Param(Params._dummy(), "$name", "$doc", $expectedType) def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc")''' + self.$name = Param(self, "$name", "$doc", $expectedType)''' if defaultValueStr is not None: template += ''' self._setDefault($name=$defaultValueStr)''' Name = name[0].upper() + name[1:] + expectedTypeName = str(expectedType) + if expectedType is not None: + expectedTypeName = expectedType.__name__ return template \ .replace("$name", name) \ .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValueStr", str(defaultValueStr)) + .replace("$defaultValueStr", str(defaultValueStr)) \ + .replace("$expectedType", expectedTypeName) def _gen_param_code(name, doc, defaultValueStr): @@ -84,7 +88,7 @@ def set$Name(self, value): """ Sets the value of :py:attr:`$name`. """ - self._paramMap[self.$name] = value + self._set($name=value) return self def get$Name(self): @@ -105,44 +109,45 @@ def get$Name(self): print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n") print("from pyspark.ml.param import Param, Params\n\n") shared = [ - ("maxIter", "max number of iterations (>= 0).", None), - ("regParam", "regularization parameter (>= 0).", None), - ("featuresCol", "features column name.", "'features'"), - ("labelCol", "label column name.", "'label'"), - ("predictionCol", "prediction column name.", "'prediction'"), + ("maxIter", "max number of iterations (>= 0).", None, int), + ("regParam", "regularization parameter (>= 0).", None, float), + ("featuresCol", "features column name.", "'features'", str), + ("labelCol", "label column name.", "'label'", str), + ("predictionCol", "prediction column name.", "'prediction'", str), ("probabilityCol", "Column name for predicted class conditional probabilities. " + "Note: Not all models output well-calibrated probability estimates! These probabilities " + - "should be treated as confidences, not precise probabilities.", "'probability'"), - ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'"), - ("inputCol", "input column name.", None), - ("inputCols", "input column names.", None), - ("outputCol", "output column name.", "self.uid + '__output'"), - ("numFeatures", "number of features.", None), + "should be treated as confidences, not precise probabilities.", "'probability'", str), + ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'", + str), + ("inputCol", "input column name.", None, str), + ("inputCols", "input column names.", None, None), + ("outputCol", "output column name.", "self.uid + '__output'", str), + ("numFeatures", "number of features.", None, int), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), - ("seed", "random seed.", "hash(type(self).__name__)"), - ("tol", "the convergence tolerance for iterative algorithms.", None), - ("stepSize", "Step size to be used for each iteration of optimization.", None), + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, int), + ("seed", "random seed.", "hash(type(self).__name__)", int), + ("tol", "the convergence tolerance for iterative algorithms.", None, float), + ("stepSize", "Step size to be used for each iteration of optimization.", None, float), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None), + "added later.", None, str), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), - ("fitIntercept", "whether to fit an intercept term.", "True"), + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", float), + ("fitIntercept", "whether to fit an intercept term.", "True", bool), ("standardization", "whether to standardize the training features before fitting the " + - "model.", "True"), + "model.", "True", bool), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None), + "probability of that class and t is the class' threshold.", None, None), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None), + "all instance weights as 1.0.", None, str), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + - "default value is 'auto'.", "'auto'")] + "default value is 'auto'.", "'auto'", str)] code = [] - for name, doc, defaultValueStr in shared: - param_code = _gen_param_header(name, doc, defaultValueStr) + for name, doc, defaultValueStr, expectedType in shared: + param_code = _gen_param_header(name, doc, defaultValueStr, expectedType) code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr)) decisionTreeParams = [ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4d960801502c2..23f94314844f6 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -26,18 +26,18 @@ class HasMaxIter(Params): """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).") + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", int) def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations (>= 0). - self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).") + self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).", int) def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self._paramMap[self.maxIter] = value + self._set(maxIter=value) return self def getMaxIter(self): @@ -53,18 +53,18 @@ class HasRegParam(Params): """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).") + regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", float) def __init__(self): super(HasRegParam, self).__init__() #: param for regularization parameter (>= 0). - self.regParam = Param(self, "regParam", "regularization parameter (>= 0).") + self.regParam = Param(self, "regParam", "regularization parameter (>= 0).", float) def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self._paramMap[self.regParam] = value + self._set(regParam=value) return self def getRegParam(self): @@ -80,19 +80,19 @@ class HasFeaturesCol(Params): """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name.") + featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", str) def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name. - self.featuresCol = Param(self, "featuresCol", "features column name.") + self.featuresCol = Param(self, "featuresCol", "features column name.", str) self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. """ - self._paramMap[self.featuresCol] = value + self._set(featuresCol=value) return self def getFeaturesCol(self): @@ -108,19 +108,19 @@ class HasLabelCol(Params): """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name.") + labelCol = Param(Params._dummy(), "labelCol", "label column name.", str) def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name. - self.labelCol = Param(self, "labelCol", "label column name.") + self.labelCol = Param(self, "labelCol", "label column name.", str) self._setDefault(labelCol='label') def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self._paramMap[self.labelCol] = value + self._set(labelCol=value) return self def getLabelCol(self): @@ -136,19 +136,19 @@ class HasPredictionCol(Params): """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.") + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", str) def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name. - self.predictionCol = Param(self, "predictionCol", "prediction column name.") + self.predictionCol = Param(self, "predictionCol", "prediction column name.", str) self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self._paramMap[self.predictionCol] = value + self._set(predictionCol=value) return self def getPredictionCol(self): @@ -164,19 +164,19 @@ class HasProbabilityCol(Params): """ # a placeholder to make it appear in the generated doc - probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) def __init__(self): super(HasProbabilityCol, self).__init__() #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. - self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) self._setDefault(probabilityCol='probability') def setProbabilityCol(self, value): """ Sets the value of :py:attr:`probabilityCol`. """ - self._paramMap[self.probabilityCol] = value + self._set(probabilityCol=value) return self def getProbabilityCol(self): @@ -192,19 +192,19 @@ class HasRawPredictionCol(Params): """ # a placeholder to make it appear in the generated doc - rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) def __init__(self): super(HasRawPredictionCol, self).__init__() #: param for raw prediction (a.k.a. confidence) column name. - self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") + self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): """ Sets the value of :py:attr:`rawPredictionCol`. """ - self._paramMap[self.rawPredictionCol] = value + self._set(rawPredictionCol=value) return self def getRawPredictionCol(self): @@ -220,18 +220,18 @@ class HasInputCol(Params): """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name.") + inputCol = Param(Params._dummy(), "inputCol", "input column name.", str) def __init__(self): super(HasInputCol, self).__init__() #: param for input column name. - self.inputCol = Param(self, "inputCol", "input column name.") + self.inputCol = Param(self, "inputCol", "input column name.", str) def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self._paramMap[self.inputCol] = value + self._set(inputCol=value) return self def getInputCol(self): @@ -247,18 +247,18 @@ class HasInputCols(Params): """ # a placeholder to make it appear in the generated doc - inputCols = Param(Params._dummy(), "inputCols", "input column names.") + inputCols = Param(Params._dummy(), "inputCols", "input column names.", None) def __init__(self): super(HasInputCols, self).__init__() #: param for input column names. - self.inputCols = Param(self, "inputCols", "input column names.") + self.inputCols = Param(self, "inputCols", "input column names.", None) def setInputCols(self, value): """ Sets the value of :py:attr:`inputCols`. """ - self._paramMap[self.inputCols] = value + self._set(inputCols=value) return self def getInputCols(self): @@ -274,19 +274,19 @@ class HasOutputCol(Params): """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name.") + outputCol = Param(Params._dummy(), "outputCol", "output column name.", str) def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name. - self.outputCol = Param(self, "outputCol", "output column name.") + self.outputCol = Param(self, "outputCol", "output column name.", str) self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self._paramMap[self.outputCol] = value + self._set(outputCol=value) return self def getOutputCol(self): @@ -302,18 +302,18 @@ class HasNumFeatures(Params): """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features.") + numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", int) def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features. - self.numFeatures = Param(self, "numFeatures", "number of features.") + self.numFeatures = Param(self, "numFeatures", "number of features.", int) def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. """ - self._paramMap[self.numFeatures] = value + self._set(numFeatures=value) return self def getNumFeatures(self): @@ -329,18 +329,18 @@ class HasCheckpointInterval(Params): """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def __init__(self): super(HasCheckpointInterval, self).__init__() #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. - self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def setCheckpointInterval(self, value): """ Sets the value of :py:attr:`checkpointInterval`. """ - self._paramMap[self.checkpointInterval] = value + self._set(checkpointInterval=value) return self def getCheckpointInterval(self): @@ -356,19 +356,19 @@ class HasSeed(Params): """ # a placeholder to make it appear in the generated doc - seed = Param(Params._dummy(), "seed", "random seed.") + seed = Param(Params._dummy(), "seed", "random seed.", int) def __init__(self): super(HasSeed, self).__init__() #: param for random seed. - self.seed = Param(self, "seed", "random seed.") + self.seed = Param(self, "seed", "random seed.", int) self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): """ Sets the value of :py:attr:`seed`. """ - self._paramMap[self.seed] = value + self._set(seed=value) return self def getSeed(self): @@ -384,18 +384,18 @@ class HasTol(Params): """ # a placeholder to make it appear in the generated doc - tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.") + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", float) def __init__(self): super(HasTol, self).__init__() #: param for the convergence tolerance for iterative algorithms. - self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.") + self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.", float) def setTol(self, value): """ Sets the value of :py:attr:`tol`. """ - self._paramMap[self.tol] = value + self._set(tol=value) return self def getTol(self): @@ -411,18 +411,18 @@ class HasStepSize(Params): """ # a placeholder to make it appear in the generated doc - stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.") + stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", float) def __init__(self): super(HasStepSize, self).__init__() #: param for Step size to be used for each iteration of optimization. - self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.") + self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.", float) def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self._paramMap[self.stepSize] = value + self._set(stepSize=value) return self def getStepSize(self): @@ -438,18 +438,18 @@ class HasHandleInvalid(Params): """ # a placeholder to make it appear in the generated doc - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) def __init__(self): super(HasHandleInvalid, self).__init__() #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. - self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) def setHandleInvalid(self, value): """ Sets the value of :py:attr:`handleInvalid`. """ - self._paramMap[self.handleInvalid] = value + self._set(handleInvalid=value) return self def getHandleInvalid(self): @@ -465,19 +465,19 @@ class HasElasticNetParam(Params): """ # a placeholder to make it appear in the generated doc - elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) def __init__(self): super(HasElasticNetParam, self).__init__() #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) self._setDefault(elasticNetParam=0.0) def setElasticNetParam(self, value): """ Sets the value of :py:attr:`elasticNetParam`. """ - self._paramMap[self.elasticNetParam] = value + self._set(elasticNetParam=value) return self def getElasticNetParam(self): @@ -493,19 +493,19 @@ class HasFitIntercept(Params): """ # a placeholder to make it appear in the generated doc - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", bool) def __init__(self): super(HasFitIntercept, self).__init__() #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.", bool) self._setDefault(fitIntercept=True) def setFitIntercept(self, value): """ Sets the value of :py:attr:`fitIntercept`. """ - self._paramMap[self.fitIntercept] = value + self._set(fitIntercept=value) return self def getFitIntercept(self): @@ -521,19 +521,19 @@ class HasStandardization(Params): """ # a placeholder to make it appear in the generated doc - standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", bool) def __init__(self): super(HasStandardization, self).__init__() #: param for whether to standardize the training features before fitting the model. - self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.", bool) self._setDefault(standardization=True) def setStandardization(self, value): """ Sets the value of :py:attr:`standardization`. """ - self._paramMap[self.standardization] = value + self._set(standardization=value) return self def getStandardization(self): @@ -549,18 +549,18 @@ class HasThresholds(Params): """ # a placeholder to make it appear in the generated doc - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def __init__(self): super(HasThresholds, self).__init__() #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. - self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. """ - self._paramMap[self.thresholds] = value + self._set(thresholds=value) return self def getThresholds(self): @@ -576,18 +576,18 @@ class HasWeightCol(Params): """ # a placeholder to make it appear in the generated doc - weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def __init__(self): super(HasWeightCol, self).__init__() #: param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. - self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def setWeightCol(self, value): """ Sets the value of :py:attr:`weightCol`. """ - self._paramMap[self.weightCol] = value + self._set(weightCol=value) return self def getWeightCol(self): @@ -603,19 +603,19 @@ class HasSolver(Params): """ # a placeholder to make it appear in the generated doc - solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) def __init__(self): super(HasSolver, self).__init__() #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. - self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) self._setDefault(solver='auto') def setSolver(self, value): """ Sets the value of :py:attr:`solver`. """ - self._paramMap[self.solver] = value + self._set(solver=value) return self def getSolver(self): @@ -658,7 +658,7 @@ def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ - self._paramMap[self.maxDepth] = value + self._set(maxDepth=value) return self def getMaxDepth(self): @@ -671,7 +671,7 @@ def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ - self._paramMap[self.maxBins] = value + self._set(maxBins=value) return self def getMaxBins(self): @@ -684,7 +684,7 @@ def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ - self._paramMap[self.minInstancesPerNode] = value + self._set(minInstancesPerNode=value) return self def getMinInstancesPerNode(self): @@ -697,7 +697,7 @@ def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ - self._paramMap[self.minInfoGain] = value + self._set(minInfoGain=value) return self def getMinInfoGain(self): @@ -710,7 +710,7 @@ def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ - self._paramMap[self.maxMemoryInMB] = value + self._set(maxMemoryInMB=value) return self def getMaxMemoryInMB(self): @@ -723,7 +723,7 @@ def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. """ - self._paramMap[self.cacheNodeIds] = value + self._set(cacheNodeIds=value) return self def getCacheNodeIds(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7a16cf52cccb2..4eb17bfdcca90 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -37,6 +37,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand +from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed @@ -92,6 +93,27 @@ class MockModel(MockTransformer, Model, HasFake): pass +class ParamTypeConversionTests(PySparkTestCase): + """ + Test that param type conversion happens. + """ + + def test_int_to_float(self): + from pyspark.mllib.linalg import Vectors + df = self.sc.parallelize([ + Row(label=1.0, weight=2.0, features=Vectors.dense(1.0))]).toDF() + lr = LogisticRegression(elasticNetParam=0) + lr.fit(df) + lr.setElasticNetParam(0) + lr.fit(df) + + def test_invalid_to_float(self): + from pyspark.mllib.linalg import Vectors + self.assertRaises(Exception, lambda: LogisticRegression(elasticNetParam="happy")) + lr = LogisticRegression(elasticNetParam=0) + self.assertRaises(Exception, lambda: lr.setElasticNetParam("panda")) + + class PipelineTests(PySparkTestCase): def test_pipeline(self): From 007da1a9dc3bb912da841cc0f5832a4fa28e6d9d Mon Sep 17 00:00:00 2001 From: Joshi Date: Wed, 6 Jan 2016 10:48:14 -0800 Subject: [PATCH 1461/2089] [SPARK-11531][ML] SparseVector error Msg PySpark SparseVector should have "Found duplicate indices" error message Author: Joshi Author: Rekha Joshi Closes #9525 from rekhajoshm/SPARK-11531. --- python/pyspark/mllib/linalg/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index ae9ce58450905..131b855bf99c3 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -528,7 +528,9 @@ def __init__(self, size, *args): assert len(self.indices) == len(self.values), "index and value arrays not same length" for i in xrange(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: - raise TypeError("indices array must be sorted") + raise TypeError( + "Indices %s and %s are not strictly increasing" + % (self.indices[i], self.indices[i + 1])) def numNonzeros(self): """ From 95eb65163391b9e910277a948b72efccf6136e0c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 6 Jan 2016 10:50:02 -0800 Subject: [PATCH 1462/2089] [SPARK-11945][ML][PYSPARK] Add computeCost to KMeansModel for PySpark spark.ml Add ```computeCost``` to ```KMeansModel``` as evaluator for PySpark spark.ml. Author: Yanbo Liang Closes #9931 from yanboliang/SPARK-11945. --- python/pyspark/ml/clustering.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7bb8ab94e17df..9189c02220228 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -36,6 +36,14 @@ def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] + @since("2.0.0") + def computeCost(self, dataset): + """ + Return the K-means cost (sum of squared distances of points to their nearest center) + for this model on the given data. + """ + return self._call_java("computeCost", dataset) + @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): @@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol >>> centers = model.clusterCenters() >>> len(centers) 2 + >>> model.computeCost(df) + 2.000... >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction From 3aa3488225af12a77da3ba807906bc6a461ef11c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 6 Jan 2016 10:52:25 -0800 Subject: [PATCH 1463/2089] [SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier & DecisionTreeRegressor should support setSeed PySpark ```DecisionTreeClassifier``` & ```DecisionTreeRegressor``` should support ```setSeed``` like what we do at Scala side. Author: Yanbo Liang Closes #9807 from yanboliang/spark-11815. --- python/pyspark/ml/classification.py | 13 ++++++++----- python/pyspark/ml/regression.py | 14 +++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5599b8f3ecd88..265c6a14f1ca4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -273,7 +273,7 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - TreeClassifierParams, HasCheckpointInterval): + TreeClassifierParams, HasCheckpointInterval, HasSeed): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. @@ -313,12 +313,14 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ + seed=None) """ super(DecisionTreeClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -335,12 +337,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="gini"): + impurity="gini", seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ + seed=None) Sets params for the DecisionTreeClassifier. """ kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a0bb8ceed8861..401bac0223ebd 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval): + DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, + HasSeed): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", + seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None) """ super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -435,11 +438,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance"): + impurity="variance", seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None) Sets params for the DecisionTreeRegressor. """ kwargs = self.setParams._input_kwargs From ea489f14f11b2fdfb44c86634d2e2c2167b6ea18 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 6 Jan 2016 11:16:53 -0800 Subject: [PATCH 1464/2089] [SPARK-12573][SPARK-12574][SQL] Move SQL Parser from Hive to Catalyst This PR moves a major part of the new SQL parser to Catalyst. This is a prelude to start using this parser for all of our SQL parsing. The following key changes have been made: The ANTLR Parser & Supporting classes have been moved to the Catalyst project. They are now part of the ```org.apache.spark.sql.catalyst.parser``` package. These classes contained quite a bit of code that was originally from the Hive project, I have added aknowledgements whenever this applied. All Hive dependencies have been factored out. I have also taken this chance to clean-up the ```ASTNode``` class, and to improve the error handling. The HiveQl object that provides the functionality to convert an AST into a LogicalPlan has been refactored into three different classes, one for every SQL sub-project: - ```CatalystQl```: This implements Query and Expression parsing functionality. - ```SparkQl```: This is a subclass of CatalystQL and provides SQL/Core only functionality such as Explain and Describe. - ```HiveQl```: This is a subclass of ```SparkQl``` and this adds Hive-only functionality to the parser such as Analyze, Drop, Views, CTAS & Transforms. This class still depends on Hive. cc rxin Author: Herman van Hovell Closes #10583 from hvanhovell/SPARK-12575. --- dev/deps/spark-deps-hadoop-2.2 | 4 +- dev/deps/spark-deps-hadoop-2.3 | 4 +- dev/deps/spark-deps-hadoop-2.4 | 4 +- dev/deps/spark-deps-hadoop-2.6 | 4 +- pom.xml | 6 + project/SparkBuild.scala | 104 +- sql/catalyst/pom.xml | 22 + .../sql/catalyst}/parser/FromClauseParser.g | 4 +- .../sql/catalyst}/parser/IdentifiersParser.g | 4 +- .../sql/catalyst}/parser/SelectClauseParser.g | 4 +- .../sql/catalyst}/parser/SparkSqlLexer.g | 27 +- .../sql/catalyst}/parser/SparkSqlParser.g | 31 +- .../spark/sql/catalyst/parser/ParseUtils.java | 162 ++ .../spark/sql/catalyst/CatalystQl.scala | 961 ++++++++ .../spark/sql/catalyst/parser/ASTNode.scala | 93 + .../sql/catalyst/parser/ParseDriver.scala | 156 ++ .../sql/catalyst/parser/ParserConf.scala | 26 + .../scala/org/apache/spark/sql/SQLConf.scala | 20 +- .../apache/spark/sql/execution/SparkQl.scala | 84 + sql/hive/pom.xml | 20 - .../apache/spark/sql/parser/ASTErrorNode.java | 49 - .../org/apache/spark/sql/parser/ASTNode.java | 245 -- .../apache/spark/sql/parser/ParseDriver.java | 213 -- .../apache/spark/sql/parser/ParseError.java | 54 - .../spark/sql/parser/ParseException.java | 51 - .../apache/spark/sql/parser/ParseUtils.java | 96 - .../spark/sql/parser/SemanticAnalyzer.java | 406 ---- .../org/apache/spark/sql/hive/HiveQl.scala | 2050 ++++------------- .../spark/sql/hive/ErrorPositionSuite.scala | 9 +- 29 files changed, 2107 insertions(+), 2806 deletions(-) rename sql/{hive/src/main/antlr3/org/apache/spark/sql => catalyst/src/main/antlr3/org/apache/spark/sql/catalyst}/parser/FromClauseParser.g (98%) rename sql/{hive/src/main/antlr3/org/apache/spark/sql => catalyst/src/main/antlr3/org/apache/spark/sql/catalyst}/parser/IdentifiersParser.g (99%) rename sql/{hive/src/main/antlr3/org/apache/spark/sql => catalyst/src/main/antlr3/org/apache/spark/sql/catalyst}/parser/SelectClauseParser.g (97%) rename sql/{hive/src/main/antlr3/org/apache/spark/sql => catalyst/src/main/antlr3/org/apache/spark/sql/catalyst}/parser/SparkSqlLexer.g (93%) rename sql/{hive/src/main/antlr3/org/apache/spark/sql => catalyst/src/main/antlr3/org/apache/spark/sql/catalyst}/parser/SparkSqlParser.g (99%) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 44727f9876deb..e4373f79f7922 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -5,8 +5,7 @@ activation-1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar @@ -179,7 +178,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 6014d50c6b6fd..7478181406d07 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -5,8 +5,7 @@ activation-1.1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar @@ -170,7 +169,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index f56e6f4393e78..faffb8bf398a5 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -5,8 +5,7 @@ activation-1.1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar @@ -171,7 +170,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e37484473db2e..e703c7acd3876 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -5,8 +5,7 @@ activation-1.1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar @@ -177,7 +176,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/pom.xml b/pom.xml index d0ac1eb39aabe..e414a8bfe6ce5 100644 --- a/pom.xml +++ b/pom.xml @@ -183,6 +183,7 @@ 3.5.2 1.3.9 0.9.2 + 3.5.2 ${java.home} @@ -1843,6 +1844,11 @@ + + org.antlr + antlr-runtime + ${antlr.version} + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index af1d36c6ea57b..5d4f19ab14a29 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -247,6 +247,9 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst ANTLR generation settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -357,6 +360,58 @@ object OldDeps { ) } +object Catalyst { + lazy val settings = Seq( + // ANTLR code-generation step. + // + // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of + // build errors in the current plugin. + // Create Parser from ANTLR grammar files. + sourceGenerators in Compile += Def.task { + val log = streams.value.log + + val grammarFileNames = Seq( + "SparkSqlLexer.g", + "SparkSqlParser.g") + val sourceDir = (sourceDirectory in Compile).value / "antlr3" + val targetDir = (sourceManaged in Compile).value + + // Create default ANTLR Tool. + val antlr = new org.antlr.Tool + + // Setup input and output directories. + antlr.setInputDirectory(sourceDir.getPath) + antlr.setOutputDirectory(targetDir.getPath) + antlr.setForceRelativeOutput(true) + antlr.setMake(true) + + // Add grammar files. + grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => + val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath + log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) + antlr.addGrammarFile(relGFilePath) + // We will set library directory multiple times here. However, only the + // last one has effect. Because the grammar files are located under the same directory, + // We assume there is only one library directory. + antlr.setLibDirectory(gFilePath.getParent) + } + + // Generate the parser. + antlr.process + if (antlr.getNumErrors > 0) { + log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) + } + + // Return all generated java files. + (targetDir ** "*.java").get.toSeq + }.taskValue, + // Include ANTLR tokens files. + resourceGenerators in Compile += Def.task { + ((sourceManaged in Compile).value ** "*.tokens").get.toSeq + }.taskValue + ) +} + object SQL { lazy val settings = Seq( initialCommands in console := @@ -414,54 +469,7 @@ object Hive { // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new // new query tests. - fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") }, - // ANTLR code-generation step. - // - // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of - // build errors in the current plugin. - // Create Parser from ANTLR grammar files. - sourceGenerators in Compile += Def.task { - val log = streams.value.log - - val grammarFileNames = Seq( - "SparkSqlLexer.g", - "SparkSqlParser.g") - val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value - - // Create default ANTLR Tool. - val antlr = new org.antlr.Tool - - // Setup input and output directories. - antlr.setInputDirectory(sourceDir.getPath) - antlr.setOutputDirectory(targetDir.getPath) - antlr.setForceRelativeOutput(true) - antlr.setMake(true) - - // Add grammar files. - grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => - val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath - log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) - antlr.addGrammarFile(relGFilePath) - // We will set library directory multiple times here. However, only the - // last one has effect. Because the grammar files are located under the same directory, - // We assume there is only one library directory. - antlr.setLibDirectory(gFilePath.getParent) - } - - // Generate the parser. - antlr.process - if (antlr.getNumErrors > 0) { - log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) - } - - // Return all generated java files. - (targetDir ** "*.java").get.toSeq - }.taskValue, - // Include ANTLR tokens files. - resourceGenerators in Compile += Def.task { - ((sourceManaged in Compile).value ** "*.tokens").get.toSeq - }.taskValue + fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } ) } diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index cfa520b7b9db2..76ca3f3bb1bfa 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -71,6 +71,10 @@ org.codehaus.janino janino + + org.antlr + antlr-runtime + target/scala-${scala.binary.version}/classes @@ -103,6 +107,24 @@ + + org.antlr + antlr3-maven-plugin + + + + antlr + + + + + ../catalyst/src/main/antlr3 + + **/SparkSqlLexer.g + **/SparkSqlParser.g + + + diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g similarity index 98% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index e4a80f0ce8ebf..ba6cfc60f045f 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar. */ parser grammar FromClauseParser; @@ -33,7 +35,7 @@ k=3; @Override public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); + gParent.displayRecognitionError(tokenNames, e); } protected boolean useSQL11ReservedKeywordsForIdentifier() { return gParent.useSQL11ReservedKeywordsForIdentifier(); diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g similarity index 99% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g index 9f1e168374f01..86c6bd610f912 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. */ parser grammar IdentifiersParser; @@ -33,7 +35,7 @@ k=3; @Override public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); + gParent.displayRecognitionError(tokenNames, e); } protected boolean useSQL11ReservedKeywordsForIdentifier() { return gParent.useSQL11ReservedKeywordsForIdentifier(); diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g similarity index 97% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g index 48bc8b0a300af..2d2bafb1ee34f 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar. */ parser grammar SelectClauseParser; @@ -33,7 +35,7 @@ k=3; @Override public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); + gParent.displayRecognitionError(tokenNames, e); } protected boolean useSQL11ReservedKeywordsForIdentifier() { return gParent.useSQL11ReservedKeywordsForIdentifier(); diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g similarity index 93% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index ee1b8989b5aff..e01e7101d0b7e 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -13,26 +13,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar. */ lexer grammar SparkSqlLexer; @lexer::header { -package org.apache.spark.sql.parser; +package org.apache.spark.sql.catalyst.parser; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; } @lexer::members { - private Configuration hiveConf; + private ParserConf parserConf; + private ParseErrorReporter reporter; - public void setHiveConf(Configuration hiveConf) { - this.hiveConf = hiveConf; + public void configure(ParserConf parserConf, ParseErrorReporter reporter) { + this.parserConf = parserConf; + this.reporter = reporter; } protected boolean allowQuotedId() { - String supportedQIds = HiveConf.getVar(hiveConf, HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT); - return !"none".equals(supportedQIds); + if (parserConf == null) { + return true; + } + return parserConf.supportQuotedId(); + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + if (reporter != null) { + reporter.report(this, e, tokenNames); + } } } diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g similarity index 99% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 69574d713d0be..98b46794a630c 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar. */ parser grammar SparkSqlParser; @@ -369,18 +371,15 @@ TOK_SET_AUTOCOMMIT; // Package headers @header { -package org.apache.spark.sql.parser; +package org.apache.spark.sql.catalyst.parser; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; } @members { - ArrayList errors = new ArrayList(); Stack msgs = new Stack(); private static HashMap xlateMap; @@ -563,9 +562,10 @@ import org.apache.hadoop.hive.conf.HiveConf; } @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - errors.add(new ParseError(this, e, tokenNames)); + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + if (reporter != null) { + reporter.report(this, e, tokenNames); + } } @Override @@ -654,15 +654,20 @@ import org.apache.hadoop.hive.conf.HiveConf; private CommonTree throwColumnNameException() throws RecognitionException { throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); } - private Configuration hiveConf; - public void setHiveConf(Configuration hiveConf) { - this.hiveConf = hiveConf; + + private ParserConf parserConf; + private ParseErrorReporter reporter; + + public void configure(ParserConf parserConf, ParseErrorReporter reporter) { + this.parserConf = parserConf; + this.reporter = reporter; } + protected boolean useSQL11ReservedKeywordsForIdentifier() { - if(hiveConf==null){ - return false; + if (parserConf == null) { + return true; } - return !HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS); + return !parserConf.supportSQL11ReservedKeywords(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java new file mode 100644 index 0000000000000..5bc87b680f9ad --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser; + +import java.io.UnsupportedEncodingException; + +/** + * A couple of utility methods that help with parsing ASTs. + * + * Both methods in this class were take from the SemanticAnalyzer in Hive: + * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java + */ +public final class ParseUtils { + private ParseUtils() { + super(); + } + + public static String charSetString(String charSetName, String charSetString) + throws UnsupportedEncodingException { + // The character set name starts with a _, so strip that + charSetName = charSetName.substring(1); + if (charSetString.charAt(0) == '\'') { + return new String(unescapeSQLString(charSetString).getBytes(), charSetName); + } else // hex input is also supported + { + assert charSetString.charAt(0) == '0'; + assert charSetString.charAt(1) == 'x'; + charSetString = charSetString.substring(2); + + byte[] bArray = new byte[charSetString.length() / 2]; + int j = 0; + for (int i = 0; i < charSetString.length(); i += 2) { + int val = Character.digit(charSetString.charAt(i), 16) * 16 + + Character.digit(charSetString.charAt(i + 1), 16); + if (val > 127) { + val = val - 256; + } + bArray[j++] = (byte)val; + } + + return new String(bArray, charSetName); + } + } + + private static final int[] multiplier = new int[] {1000, 100, 10, 1}; + + @SuppressWarnings("nls") + public static String unescapeSQLString(String b) { + Character enclosure = null; + + // Some of the strings can be passed in as unicode. For example, the + // delimiter can be passed in as \002 - So, we first check if the + // string is a unicode number, else go back to the old behavior + StringBuilder sb = new StringBuilder(b.length()); + for (int i = 0; i < b.length(); i++) { + + char currentChar = b.charAt(i); + if (enclosure == null) { + if (currentChar == '\'' || b.charAt(i) == '\"') { + enclosure = currentChar; + } + // ignore all other chars outside the enclosure + continue; + } + + if (enclosure.equals(currentChar)) { + enclosure = null; + continue; + } + + if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { + int code = 0; + int base = i + 2; + for (int j = 0; j < 4; j++) { + int digit = Character.digit(b.charAt(j + base), 16); + code += digit * multiplier[j]; + } + sb.append((char)code); + i += 5; + continue; + } + + if (currentChar == '\\' && (i + 4 < b.length())) { + char i1 = b.charAt(i + 1); + char i2 = b.charAt(i + 2); + char i3 = b.charAt(i + 3); + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') + && (i3 >= '0' && i3 <= '7')) { + byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); + byte[] bValArr = new byte[1]; + bValArr[0] = bVal; + String tmp = new String(bValArr); + sb.append(tmp); + i += 3; + continue; + } + } + + if (currentChar == '\\' && (i + 2 < b.length())) { + char n = b.charAt(i + 1); + switch (n) { + case '0': + sb.append("\0"); + break; + case '\'': + sb.append("'"); + break; + case '"': + sb.append("\""); + break; + case 'b': + sb.append("\b"); + break; + case 'n': + sb.append("\n"); + break; + case 'r': + sb.append("\r"); + break; + case 't': + sb.append("\t"); + break; + case 'Z': + sb.append("\u001A"); + break; + case '\\': + sb.append("\\"); + break; + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala new file mode 100644 index 0000000000000..42bdf25b61ea5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -0,0 +1,961 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst + +import java.sql.Date + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. + */ +private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { + object Token { + def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { + CurrentOrigin.setPosition(node.line, node.positionInLine) + node.pattern + } + } + + + /** + * Returns the AST for the given SQL string. + */ + protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf) + + /** Creates LogicalPlan for a given HiveQL string. */ + def createPlan(sql: String): LogicalPlan = { + try { + createPlan(sql, ParseDriver.parse(sql, conf)) + } catch { + case e: MatchError => throw e + case e: AnalysisException => throw e + case e: Exception => + throw new AnalysisException(e.getMessage) + case e: NotImplementedError => + throw new AnalysisException( + s""" + |Unsupported language features in query: $sql + |${getAst(sql).treeString} + |$e + |${e.getStackTrace.head} + """.stripMargin) + } + } + + protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree) + + def parseDdl(ddl: String): Seq[Attribute] = { + val tree = getAst(ddl) + assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.") + val tableOps = tree.children + val colList = tableOps + .find(_.text == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")) + + colList.children.map(nodeToAttribute) + } + + protected def getClauses( + clauseNames: Seq[String], + nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { + var remainingNodes = nodeList + val clauses = clauseNames.map { clauseName => + val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName) + remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) + matches.headOption + } + + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}. + |You are likely trying to use an unsupported Hive feature."""".stripMargin) + } + clauses + } + + protected def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = + getClauseOption(clauseName, nodeList).getOrElse(sys.error( + s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}")) + + protected def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = { + nodeList.filter { case ast: ASTNode => ast.text == clauseName } match { + case Seq(oneMatch) => Some(oneMatch) + case Seq() => None + case _ => sys.error(s"Found multiple instances of clause $clauseName") + } + } + + protected def nodeToAttribute(node: ASTNode): Attribute = node match { + case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => + AttributeReference(colName, nodeToDataType(dataType), nullable = true)() + case _ => + noParseRule("Attribute", node) + } + + protected def nodeToDataType(node: ASTNode): DataType = node match { + case Token("TOK_DECIMAL", precision :: scale :: Nil) => + DecimalType(precision.text.toInt, scale.text.toInt) + case Token("TOK_DECIMAL", precision :: Nil) => + DecimalType(precision.text.toInt, 0) + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT + case Token("TOK_BIGINT", Nil) => LongType + case Token("TOK_INT", Nil) => IntegerType + case Token("TOK_TINYINT", Nil) => ByteType + case Token("TOK_SMALLINT", Nil) => ShortType + case Token("TOK_BOOLEAN", Nil) => BooleanType + case Token("TOK_STRING", Nil) => StringType + case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType + case Token("TOK_FLOAT", Nil) => FloatType + case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_DATE", Nil) => DateType + case Token("TOK_TIMESTAMP", Nil) => TimestampType + case Token("TOK_BINARY", Nil) => BinaryType + case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) + case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) => + StructType(fields.map(nodeToStructField)) + case Token("TOK_MAP", keyType :: valueType :: Nil) => + MapType(nodeToDataType(keyType), nodeToDataType(valueType)) + case _ => + noParseRule("DataType", node) + } + + protected def nodeToStructField(node: ASTNode): StructField = node match { + case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case _ => + noParseRule("StructField", node) + } + + protected def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = { + tableNameParts.children.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { + case Seq(tableOnly) => TableIdentifier(tableOnly) + case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) + case other => sys.error("Hive only supports tables names like 'tableName' " + + s"or 'databaseName.tableName', found '$other'") + } + } + + /** + * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) + * is equivalent to + * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 + * Check the following link for details. + * +https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup + * + * The bitmask denotes the grouping expressions validity for a grouping set, + * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) + * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of + * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. + */ + protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { + val (keyASTs, setASTs) = children.partition { + case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets + case _ => true // grouping keys + } + + val keys = keyASTs.map(nodeToExpr) + val keyMap = keyASTs.zipWithIndex.toMap + + val bitmasks: Seq[Int] = setASTs.map { + case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 + case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => + columns.foldLeft(0)((bitmap, col) => { + val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2) + bitmap | 1 << keyIndex.getOrElse( + throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list")) + }) + case _ => sys.error("Expect GROUPING SETS clause") + } + + (keys, bitmasks) + } + + protected def nodeToPlan(node: ASTNode): LogicalPlan = node match { + case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => + val (fromClause: Option[ASTNode], insertClauses, cteRelations) = + queryArgs match { + case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => + val cteRelations = ctes.map { node => + val relation = nodeToRelation(node).asInstanceOf[Subquery] + relation.alias -> relation + } + (Some(from.head), inserts, Some(cteRelations.toMap)) + case Token("TOK_FROM", from) :: inserts => + (Some(from.head), inserts, None) + case Token("TOK_INSERT", _) :: Nil => + (None, queryArgs, None) + } + + // Return one query for each insert clause. + val queries = insertClauses.map { + case Token("TOK_INSERT", singleInsert) => + val ( + intoClause :: + destClause :: + selectClause :: + selectDistinctClause :: + whereClause :: + groupByClause :: + rollupGroupByClause :: + cubeGroupByClause :: + groupingSetsClause :: + orderByClause :: + havingClause :: + sortByClause :: + clusterByClause :: + distributeByClause :: + limitClause :: + lateralViewClause :: + windowClause :: Nil) = { + getClauses( + Seq( + "TOK_INSERT_INTO", + "TOK_DESTINATION", + "TOK_SELECT", + "TOK_SELECTDI", + "TOK_WHERE", + "TOK_GROUPBY", + "TOK_ROLLUP_GROUPBY", + "TOK_CUBE_GROUPBY", + "TOK_GROUPING_SETS", + "TOK_ORDERBY", + "TOK_HAVING", + "TOK_SORTBY", + "TOK_CLUSTERBY", + "TOK_DISTRIBUTEBY", + "TOK_LIMIT", + "TOK_LATERAL_VIEW", + "WINDOW"), + singleInsert) + } + + val relations = fromClause match { + case Some(f) => nodeToRelation(f) + case None => OneRowRelation + } + + val withWhere = whereClause.map { whereNode => + val Seq(whereExpr) = whereNode.children + Filter(nodeToExpr(whereExpr), relations) + }.getOrElse(relations) + + val select = (selectClause orElse selectDistinctClause) + .getOrElse(sys.error("No select clause.")) + + val transformation = nodeToTransformation(select.children.head, withWhere) + + val withLateralView = lateralViewClause.map { lv => + nodeToGenerate(lv.children.head, outer = false, withWhere) + }.getOrElse(withWhere) + + // The projection of the query can either be a normal projection, an aggregation + // (if there is a group by) or a script transformation. + val withProject: LogicalPlan = transformation.getOrElse { + val selectExpressions = + select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) + Seq( + groupByClause.map(e => e match { + case Token("TOK_GROUPBY", children) => + // Not a transformation so must be either project or aggregation. + Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) + case _ => sys.error("Expect GROUP BY") + }), + groupingSetsClause.map(e => e match { + case Token("TOK_GROUPING_SETS", children) => + val(groupByExprs, masks) = extractGroupingSet(children) + GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) + case _ => sys.error("Expect GROUPING SETS") + }), + rollupGroupByClause.map(e => e match { + case Token("TOK_ROLLUP_GROUPBY", children) => + Aggregate( + Seq(Rollup(children.map(nodeToExpr))), + selectExpressions, + withLateralView) + case _ => sys.error("Expect WITH ROLLUP") + }), + cubeGroupByClause.map(e => e match { + case Token("TOK_CUBE_GROUPBY", children) => + Aggregate( + Seq(Cube(children.map(nodeToExpr))), + selectExpressions, + withLateralView) + case _ => sys.error("Expect WITH CUBE") + }), + Some(Project(selectExpressions, withLateralView))).flatten.head + } + + // Handle HAVING clause. + val withHaving = havingClause.map { h => + val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) } + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(havingExpr, BooleanType), withProject) + }.getOrElse(withProject) + + // Handle SELECT DISTINCT + val withDistinct = + if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withSort = + (orderByClause, sortByClause, distributeByClause, clusterByClause) match { + case (Some(totalOrdering), None, None, None) => + Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct) + case (None, Some(perPartitionOrdering), None, None) => + Sort( + perPartitionOrdering.children.map(nodeToSortOrder), + global = false, withDistinct) + case (None, None, Some(partitionExprs), None) => + RepartitionByExpression( + partitionExprs.children.map(nodeToExpr), withDistinct) + case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => + Sort( + perPartitionOrdering.children.map(nodeToSortOrder), global = false, + RepartitionByExpression( + partitionExprs.children.map(nodeToExpr), + withDistinct)) + case (None, None, None, Some(clusterExprs)) => + Sort( + clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression( + clusterExprs.children.map(nodeToExpr), + withDistinct)) + case (None, None, None, None) => withDistinct + case _ => sys.error("Unsupported set of ordering / distribution clauses.") + } + + val withLimit = + limitClause.map(l => nodeToExpr(l.children.head)) + .map(Limit(_, withSort)) + .getOrElse(withSort) + + // Collect all window specifications defined in the WINDOW clause. + val windowDefinitions = windowClause.map(_.children.collect { + case Token("TOK_WINDOWDEF", + Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => + windowName -> nodesToWindowSpecification(spec) + }.toMap) + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val resolvedCrossReference = windowDefinitions.map { + windowDefMap => windowDefMap.map { + case (windowName, WindowSpecReference(other)) => + (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) + case o => o.asInstanceOf[(String, WindowSpecDefinition)] + } + } + + val withWindowDefinitions = + resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) + + // TOK_INSERT_INTO means to add files to the table. + // TOK_DESTINATION means to overwrite the table. + val resultDestination = + (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) + val overwrite = intoClause.isEmpty + nodeToDest( + resultDestination, + withWindowDefinitions, + overwrite) + } + + // If there are multiple INSERTS just UNION them together into on query. + val query = queries.reduceLeft(Union) + + // return With plan if there is CTE + cteRelations.map(With(query, _)).getOrElse(query) + + // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT + case Token("TOK_UNIONALL", left :: right :: Nil) => + Union(nodeToPlan(left), nodeToPlan(right)) + + case _ => + noParseRule("Plan", node) + } + + val allJoinTokens = "(TOK_.*JOIN)".r + val laterViewToken = "TOK_LATERAL_VIEW(.*)".r + protected def nodeToRelation(node: ASTNode): LogicalPlan = { + node match { + case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => + Subquery(cleanIdentifier(alias), nodeToPlan(query)) + + case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => + nodeToGenerate( + selectClause, + outer = isOuter.nonEmpty, + nodeToRelation(relationClause)) + + /* All relations, possibly with aliases or sampling clauses. */ + case Token("TOK_TABREF", clauses) => + // If the last clause is not a token then it's the alias of the table. + val (nonAliasClauses, aliasClause) = + if (clauses.last.text.startsWith("TOK")) { + (clauses, None) + } else { + (clauses.dropRight(1), Some(clauses.last)) + } + + val (Some(tableNameParts) :: + splitSampleClause :: + bucketSampleClause :: Nil) = { + getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), + nonAliasClauses) + } + + val tableIdent = extractTableIdent(tableNameParts) + val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } + val relation = UnresolvedRelation(tableIdent, alias) + + // Apply sampling if requested. + (bucketSampleClause orElse splitSampleClause).map { + case Token("TOK_TABLESPLITSAMPLE", + Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) => + Limit(Literal(count.toInt), relation) + case Token("TOK_TABLESPLITSAMPLE", + Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + require( + fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) + && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), + s"Sampling fraction ($fraction) must be on interval [0, 100]") + Sample(0.0, fraction.toDouble / 100, withReplacement = false, + (math.random * 1000).toInt, + relation) + case Token("TOK_TABLEBUCKETSAMPLE", + Token(numerator, Nil) :: + Token(denominator, Nil) :: Nil) => + val fraction = numerator.toDouble / denominator.toDouble + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) + case a => + noParseRule("Sampling", a) + }.getOrElse(relation) + + case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => + if (!(other.size <= 1)) { + sys.error(s"Unsupported join operation: $other") + } + + val joinType = joinToken match { + case "TOK_JOIN" => Inner + case "TOK_CROSSJOIN" => Inner + case "TOK_RIGHTOUTERJOIN" => RightOuter + case "TOK_LEFTOUTERJOIN" => LeftOuter + case "TOK_FULLOUTERJOIN" => FullOuter + case "TOK_LEFTSEMIJOIN" => LeftSemi + case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) + case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) + } + Join(nodeToRelation(relation1), + nodeToRelation(relation2), + joinType, + other.headOption.map(nodeToExpr)) + + case _ => + noParseRule("Relation", node) + } + } + + protected def nodeToSortOrder(node: ASTNode): SortOrder = node match { + case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => + SortOrder(nodeToExpr(sortExpr), Ascending) + case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => + SortOrder(nodeToExpr(sortExpr), Descending) + case _ => + noParseRule("SortOrder", node) + } + + val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r + protected def nodeToDest( + node: ASTNode, + query: LogicalPlan, + overwrite: Boolean): LogicalPlan = node match { + case Token(destinationToken(), + Token("TOK_DIR", + Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => + query + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val tableIdent = extractTableIdent(tableNameParts) + + val partitionKeys = partitionClause.map(_.children.map { + // Parse partitions. We also make keys case insensitive. + case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false) + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: + Token("TOK_IFNOTEXISTS", + ifNotExists) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val tableIdent = extractTableIdent(tableNameParts) + + val partitionKeys = partitionClause.map(_.children.map { + // Parse partitions. We also make keys case insensitive. + case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true) + + case _ => + noParseRule("Destination", node) + } + + protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match { + case Token("TOK_SELEXPR", e :: Nil) => + Some(nodeToExpr(e)) + + case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => + Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) + + case Token("TOK_SELEXPR", e :: aliasChildren) => + val aliasNames = aliasChildren.collect { + case Token(name, Nil) => cleanIdentifier(name) + } + Some(MultiAlias(nodeToExpr(e), aliasNames)) + + /* Hints are ignored */ + case Token("TOK_HINTLIST", _) => None + + case _ => + noParseRule("Select", node) + } + + protected val escapedIdentifier = "`([^`]+)`".r + protected val doubleQuotedString = "\"([^\"]+)\"".r + protected val singleQuotedString = "'([^']+)'".r + + protected def unquoteString(str: String) = str match { + case singleQuotedString(s) => s + case doubleQuotedString(s) => s + case other => other + } + + /** Strips backticks from ident if present */ + protected def cleanIdentifier(ident: String): String = ident match { + case escapedIdentifier(i) => i + case plainIdent => plainIdent + } + + val numericAstTypes = Seq( + SparkSqlParser.Number, + SparkSqlParser.TinyintLiteral, + SparkSqlParser.SmallintLiteral, + SparkSqlParser.BigintLiteral, + SparkSqlParser.DecimalLiteral) + + /* Case insensitive matches */ + val COUNT = "(?i)COUNT".r + val SUM = "(?i)SUM".r + val AND = "(?i)AND".r + val OR = "(?i)OR".r + val NOT = "(?i)NOT".r + val TRUE = "(?i)TRUE".r + val FALSE = "(?i)FALSE".r + val LIKE = "(?i)LIKE".r + val RLIKE = "(?i)RLIKE".r + val REGEXP = "(?i)REGEXP".r + val IN = "(?i)IN".r + val DIV = "(?i)DIV".r + val BETWEEN = "(?i)BETWEEN".r + val WHEN = "(?i)WHEN".r + val CASE = "(?i)CASE".r + + protected def nodeToExpr(node: ASTNode): Expression = node match { + /* Attribute References */ + case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => + UnresolvedAttribute.quoted(cleanIdentifier(name)) + case Token(".", qualifier :: Token(attr, Nil) :: Nil) => + nodeToExpr(qualifier) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) + case other => UnresolvedExtractValue(other, Literal(attr)) + } + + /* Stars (*) */ + case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) + // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only + // has a single child which is tableName. + case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => + UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) + + /* Aggregate Functions */ + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => + Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => + Count(Literal(1)).toAggregateExpression() + + /* Casts */ + case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), IntegerType) + case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), LongType) + case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), FloatType) + case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DoubleType) + case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), ShortType) + case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), ByteType) + case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), BinaryType) + case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), BooleanType) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt)) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0)) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) + case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), TimestampType) + case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DateType) + + /* Arithmetic */ + case Token("+", child :: Nil) => nodeToExpr(child) + case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) + case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) + case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) + case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) + case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) + case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token(DIV(), left :: right:: Nil) => + Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) + case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) + case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) + case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) + + /* Comparisons */ + case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) + case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) + case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) + case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => + IsNotNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => + IsNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => + In(nodeToExpr(value), list.map(nodeToExpr)) + case Token("TOK_FUNCTION", + Token(BETWEEN(), Nil) :: + kw :: + target :: + minValue :: + maxValue :: Nil) => + + val targetExpression = nodeToExpr(target) + val betweenExpr = + And( + GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), + LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) + kw match { + case Token("KW_FALSE", Nil) => betweenExpr + case Token("KW_TRUE", Nil) => Not(betweenExpr) + } + + /* Boolean Logic */ + case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) + case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) + case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + case Token("!", child :: Nil) => Not(nodeToExpr(child)) + + /* Case statements */ + case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => + CaseWhen(branches.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => + val keyExpr = nodeToExpr(branches.head) + CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) + + /* Complex datatype manipulation */ + case Token("[", child :: ordinal :: Nil) => + UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) + + /* Window Functions */ + case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = nodeToExpr(node.copy(children = node.children.init)) + nodesToWindowSpecification(spec) match { + case reference: WindowSpecReference => + UnresolvedWindowExpression(function, reference) + case definition: WindowSpecDefinition => + WindowExpression(function, definition) + } + + /* UDFs - Must be last otherwise will preempt built in functions */ + case Token("TOK_FUNCTION", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. + case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) + + /* Literals */ + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) + case Token("TOK_STRINGLITERALSEQUENCE", strings) => + Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) + + // This code is adapted from + // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 + case ast: ASTNode if numericAstTypes contains ast.tokenType => + var v: Literal = null + try { + if (ast.text.endsWith("L")) { + // Literal bigint. + v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) + } else if (ast.text.endsWith("S")) { + // Literal smallint. + v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) + } else if (ast.text.endsWith("Y")) { + // Literal tinyint. + v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) + } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) { + // Literal decimal + val strVal = ast.text.stripSuffix("D").stripSuffix("B") + v = Literal(Decimal(strVal)) + } else { + v = Literal.create(ast.text.toDouble, DoubleType) + v = Literal.create(ast.text.toLong, LongType) + v = Literal.create(ast.text.toInt, IntegerType) + } + } catch { + case nfe: NumberFormatException => // Do nothing + } + + if (v == null) { + sys.error(s"Failed to parse number '${ast.text}'.") + } else { + v + } + + case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral => + Literal(ParseUtils.unescapeSQLString(ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL => + Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.text)) + + case _ => + noParseRule("Expression", node) + } + + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r + protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { + case Token(windowName, Nil) :: Nil => + // Refer to a window spec defined in the window clause. + WindowSpecReference(windowName) + case Nil => + // OVER() + WindowSpecDefinition( + partitionSpec = Nil, + orderSpec = Nil, + frameSpecification = UnspecifiedFrame) + case spec => + val (partitionClause :: rowFrame :: rangeFrame :: Nil) = + getClauses( + Seq( + "TOK_PARTITIONINGSPEC", + "TOK_WINDOWRANGE", + "TOK_WINDOWVALUES"), + spec) + + // Handle Partition By and Order By. + val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => + val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = + getClauses( + Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), + partitionAndOrdering.children) + + (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { + case (Some(partitionByExpr), Some(orderByExpr), None) => + (partitionByExpr.children.map(nodeToExpr), + orderByExpr.children.map(nodeToSortOrder)) + case (Some(partitionByExpr), None, None) => + (partitionByExpr.children.map(nodeToExpr), Nil) + case (None, Some(orderByExpr), None) => + (Nil, orderByExpr.children.map(nodeToSortOrder)) + case (None, None, Some(clusterByExpr)) => + val expressions = clusterByExpr.children.map(nodeToExpr) + (expressions, expressions.map(SortOrder(_, Ascending))) + case _ => + noParseRule("Partition & Ordering", partitionAndOrdering) + } + }.getOrElse { + (Nil, Nil) + } + + // Handle Window Frame + val windowFrame = + if (rowFrame.isEmpty && rangeFrame.isEmpty) { + UnspecifiedFrame + } else { + val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) + def nodeToBoundary(node: ASTNode): FrameBoundary = node match { + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow + case _ => + noParseRule("Window Frame Boundary", node) + } + + rowFrame.orElse(rangeFrame).map { frame => + frame.children match { + case precedingNode :: followingNode :: Nil => + SpecifiedWindowFrame( + frameType, + nodeToBoundary(precedingNode), + nodeToBoundary(followingNode)) + case precedingNode :: Nil => + SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) + case _ => + noParseRule("Window Frame", frame) + } + }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) + } + + WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) + } + + protected def nodeToTransformation( + node: ASTNode, + child: LogicalPlan): Option[ScriptTransformation] = None + + val explode = "(?i)explode".r + val jsonTuple = "(?i)json_tuple".r + protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { + val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node + + val alias = getClause("TOK_TABALIAS", clauses).children.head.text + + val generator = clauses.head match { + case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => + Explode(nodeToExpr(childNode)) + case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => + JsonTuple(children.map(nodeToExpr)) + case other => + nodeToGenerator(other) + } + + val attributes = clauses.collect { + case Token(a, Nil) => UnresolvedAttribute(a.toLowerCase) + } + + Generate(generator, join = true, outer = outer, Some(alias.toLowerCase), attributes, child) + } + + protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node) + + protected def noParseRule(msg: String, node: ASTNode): Nothing = throw new NotImplementedError( + s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala new file mode 100644 index 0000000000000..ec5e71042d4be --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.antlr.runtime.{Token, TokenRewriteStream} + +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} + +case class ASTNode( + token: Token, + startIndex: Int, + stopIndex: Int, + children: List[ASTNode], + stream: TokenRewriteStream) extends TreeNode[ASTNode] { + /** Cache the number of children. */ + val numChildren = children.size + + /** tuple used in pattern matching. */ + val pattern = Some((token.getText, children)) + + /** Line in which the ASTNode starts. */ + lazy val line: Int = { + val line = token.getLine + if (line == 0) { + if (children.nonEmpty) children.head.line + else 0 + } else { + line + } + } + + /** Position of the Character at which ASTNode starts. */ + lazy val positionInLine: Int = { + val line = token.getCharPositionInLine + if (line == -1) { + if (children.nonEmpty) children.head.positionInLine + else 0 + } else { + line + } + } + + /** Origin of the ASTNode. */ + override val origin = Origin(Some(line), Some(positionInLine)) + + /** Source text. */ + lazy val source = stream.toString(startIndex, stopIndex) + + def text: String = token.getText + + def tokenType: Int = token.getType + + /** + * Checks if this node is equal to another node. + * + * Right now this function only checks the name, type, text and children of the node + * for equality. + */ + def treeEquals(other: ASTNode): Boolean = { + def check(f: ASTNode => Any): Boolean = { + val l = f(this) + val r = f(other) + (l == null && r == null) || l.equals(r) + } + if (other == null) { + false + } else if (!check(_.token.getType) + || !check(_.token.getText) + || !check(_.numChildren)) { + false + } else { + children.zip(other.children).forall { + case (l, r) => l treeEquals r + } + } + } + + override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine " +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala new file mode 100644 index 0000000000000..0e93af8b92cd2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.antlr.runtime._ +import org.antlr.runtime.tree.CommonTree + +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException + +/** + * The ParseDriver takes a SQL command and turns this into an AST. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver + */ +object ParseDriver extends Logging { + def parse(command: String, conf: ParserConf): ASTNode = { + logInfo(s"Parsing command: $command") + + // Setup error collection. + val reporter = new ParseErrorReporter() + + // Create lexer. + val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command)) + val tokens = new TokenRewriteStream(lexer) + lexer.configure(conf, reporter) + + // Create the parser. + val parser = new SparkSqlParser(tokens) + parser.configure(conf, reporter) + + try { + val result = parser.statement() + + // Check errors. + reporter.checkForErrors() + + // Return the AST node from the result. + logInfo(s"Parse completed.") + + // Find the non null token tree in the result. + def nonNullToken(tree: CommonTree): CommonTree = { + if (tree.token != null || tree.getChildCount == 0) tree + else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) + } + val tree = nonNullToken(result.getTree) + + // Make sure all boundaries are set. + tree.setUnknownTokenBoundaries() + + // Construct the immutable AST. + def createASTNode(tree: CommonTree): ASTNode = { + val children = (0 until tree.getChildCount).map { i => + createASTNode(tree.getChild(i).asInstanceOf[CommonTree]) + }.toList + ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens) + } + createASTNode(tree) + } + catch { + case e: RecognitionException => + logInfo(s"Parse failed.") + reporter.throwError(e) + } + } +} + +/** + * This string stream provides the lexer with upper case characters only. This greatly simplifies + * lexing the stream, while we can maintain the original command. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream + * + * The comment below (taken from the original class) describes the rationale for doing this: + * + * This class provides and implementation for a case insensitive token checker for the lexical + * analysis part of antlr. By converting the token stream into upper case at the time when lexical + * rules are checked, this class ensures that the lexical rules need to just match the token with + * upper case letters as opposed to combination of upper case and lower case characters. This is + * purely used for matching lexical rules. The actual token text is stored in the same way as the + * user input without actually converting it into an upper case. The token values are generated by + * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead + * function and is purely used for matching lexical rules. This also means that the grammar will + * only accept capitalized tokens in case it is run from other tools like antlrworks which do not + * have the ANTLRNoCaseStringStream implementation. + */ + +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) { + override def LA(i: Int): Int = { + val la = super.LA(i) + if (la == 0 || la == CharStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * Utility used by the Parser and the Lexer for error collection and reporting. + */ +private[parser] class ParseErrorReporter { + val errors = scala.collection.mutable.Buffer.empty[ParseError] + + def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = { + errors += ParseError(br, re, tokenNames) + } + + def checkForErrors(): Unit = { + if (errors.nonEmpty) { + val first = errors.head + val e = first.re + throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail) + } + } + + def throwError(e: RecognitionException): Nothing = { + throwError(e.line, e.charPositionInLine, e.toString, errors) + } + + private def throwError( + line: Int, + startPosition: Int, + msg: String, + errors: Seq[ParseError]): Nothing = { + val b = new StringBuilder + b.append(msg).append("\n") + errors.foreach(error => error.buildMessage(b).append("\n")) + throw new AnalysisException(b.toString, Option(line), Option(startPosition)) + } +} + +/** + * Error collected during the parsing process. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + */ +private[parser] case class ParseError( + br: BaseRecognizer, + re: RecognitionException, + tokenNames: Array[String]) { + def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { + s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala new file mode 100644 index 0000000000000..ce449b11431a5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +trait ParserConf { + def supportQuotedId: Boolean + def supportSQL11ReservedKeywords: Boolean +} + +case class SimpleParserConf( + supportQuotedId: Boolean = true, + supportSQL11ReservedKeywords: Boolean = false) extends ParserConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b58a3739912bc..26c00dc250b4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.parser.ParserConf //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -451,6 +452,19 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) + val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers", + defaultValue = Some(true), + isPublic = false, + doc = "Whether to use quoted identifier.\n false: default(past) behavior. Implies only" + + "alphaNumeric and underscore are valid characters in identifiers.\n" + + " true: implies column names can contain any character.") + + val PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS = booleanConf( + "spark.sql.parser.supportSQL11ReservedKeywords", + defaultValue = Some(false), + isPublic = false, + doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -471,7 +485,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf { +private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -569,6 +583,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID) + + def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala new file mode 100644 index 0000000000000..a322688a259e2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} + +private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { + /** Check if a command should not be explained. */ + protected def isNoExplainCommand(command: String): Boolean = "TOK_DESCTABLE" == command + + protected override def nodeToPlan(node: ASTNode): LogicalPlan = { + node match { + // Just fake explain for any of the native commands. + case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => + ExplainCommand(OneRowRelation) + + case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.text => + val Some(crtTbl) :: _ :: extended :: Nil = + getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) + ExplainCommand(nodeToPlan(crtTbl), extended = extended.isDefined) + + case Token("TOK_EXPLAIN", explainArgs) => + // Ignore FORMATTED if present. + val Some(query) :: _ :: extended :: Nil = + getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) + ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + + case Token("TOK_DESCTABLE", describeArgs) => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val Some(tableType) :: formatted :: extended :: pretty :: Nil = + getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) + if (formatted.isDefined || pretty.isDefined) { + // FORMATTED and PRETTY are not supported and this statement will be treated as + // a Hive native command. + nodeToDescribeFallback(node) + } else { + tableType match { + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => + nameParts match { + case Token(".", dbName :: tableName :: Nil) => + // It is describing a table with the format like "describe db.table". + // TODO: Actually, a user may mean tableName.columnName. Need to resolve this + // issue. + val tableIdent = extractTableIdent(nameParts) + datasources.DescribeCommand( + UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) + case Token(".", dbName :: tableName :: colName :: Nil) => + // It is describing a column with the format like "describe db.table column". + nodeToDescribeFallback(node) + case tableName => + // It is describing a table with the format like "describe table". + datasources.DescribeCommand( + UnresolvedRelation(TableIdentifier(tableName.text), None), + isExtended = extended.isDefined) + } + // All other cases. + case _ => nodeToDescribeFallback(node) + } + } + + case _ => + super.nodeToPlan(node) + } + } + + protected def nodeToDescribeFallback(node: ASTNode): LogicalPlan = noParseRule("Describe", node) +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ffabb92179a18..cd0c2aeb93a9f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -262,26 +262,6 @@ - - - org.antlr - antlr3-maven-plugin - - - - antlr - - - - - ${basedir}/src/main/antlr3 - - **/SparkSqlLexer.g - **/SparkSqlParser.g - - - - diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java deleted file mode 100644 index 35ecdc5ad10a9..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.antlr.runtime.RecognitionException; -import org.antlr.runtime.Token; -import org.antlr.runtime.TokenStream; -import org.antlr.runtime.tree.CommonErrorNode; - -public class ASTErrorNode extends ASTNode { - - /** - * - */ - private static final long serialVersionUID = 1L; - CommonErrorNode delegate; - - public ASTErrorNode(TokenStream input, Token start, Token stop, - RecognitionException e){ - delegate = new CommonErrorNode(input,start,stop,e); - } - - @Override - public boolean isNil() { return delegate.isNil(); } - - @Override - public int getType() { return delegate.getType(); } - - @Override - public String getText() { return delegate.getText(); } - @Override - public String toString() { return delegate.toString(); } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java deleted file mode 100644 index 33d9322b628ec..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java +++ /dev/null @@ -1,245 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -import org.antlr.runtime.Token; -import org.antlr.runtime.tree.CommonTree; -import org.antlr.runtime.tree.Tree; -import org.apache.hadoop.hive.ql.lib.Node; - -public class ASTNode extends CommonTree implements Node, Serializable { - private static final long serialVersionUID = 1L; - private transient StringBuffer astStr; - private transient int startIndx = -1; - private transient int endIndx = -1; - private transient ASTNode rootNode; - private transient boolean isValidASTStr; - - public ASTNode() { - } - - /** - * Constructor. - * - * @param t - * Token for the CommonTree Node - */ - public ASTNode(Token t) { - super(t); - } - - public ASTNode(ASTNode node) { - super(node); - } - - @Override - public Tree dupNode() { - return new ASTNode(this); - } - - /* - * (non-Javadoc) - * - * @see org.apache.hadoop.hive.ql.lib.Node#getChildren() - */ - @Override - public ArrayList getChildren() { - if (super.getChildCount() == 0) { - return null; - } - - ArrayList ret_vec = new ArrayList(); - for (int i = 0; i < super.getChildCount(); ++i) { - ret_vec.add((Node) super.getChild(i)); - } - - return ret_vec; - } - - /* - * (non-Javadoc) - * - * @see org.apache.hadoop.hive.ql.lib.Node#getName() - */ - @Override - public String getName() { - return (Integer.valueOf(super.getToken().getType())).toString(); - } - - public String dump() { - StringBuilder sb = new StringBuilder("\n"); - dump(sb, ""); - return sb.toString(); - } - - private StringBuilder dump(StringBuilder sb, String ws) { - sb.append(ws); - sb.append(toString()); - sb.append("\n"); - - ArrayList children = getChildren(); - if (children != null) { - for (Node node : getChildren()) { - if (node instanceof ASTNode) { - ((ASTNode) node).dump(sb, ws + " "); - } else { - sb.append(ws); - sb.append(" NON-ASTNODE!!"); - sb.append("\n"); - } - } - } - return sb; - } - - private ASTNode getRootNodeWithValidASTStr(boolean useMemoizedRoot) { - if (useMemoizedRoot && rootNode != null && rootNode.parent == null && - rootNode.hasValidMemoizedString()) { - return rootNode; - } - ASTNode retNode = this; - while (retNode.parent != null) { - retNode = (ASTNode) retNode.parent; - } - rootNode=retNode; - if (!rootNode.isValidASTStr) { - rootNode.astStr = new StringBuffer(); - rootNode.toStringTree(rootNode); - rootNode.isValidASTStr = true; - } - return retNode; - } - - private boolean hasValidMemoizedString() { - return isValidASTStr && astStr != null; - } - - private void resetRootInformation() { - // Reset the previously stored rootNode string - if (rootNode != null) { - rootNode.astStr = null; - rootNode.isValidASTStr = false; - } - } - - private int getMemoizedStringLen() { - return astStr == null ? 0 : astStr.length(); - } - - private String getMemoizedSubString(int start, int end) { - return (astStr == null || start < 0 || end > astStr.length() || start >= end) ? null : - astStr.subSequence(start, end).toString(); - } - - private void addtoMemoizedString(String string) { - if (astStr == null) { - astStr = new StringBuffer(); - } - astStr.append(string); - } - - @Override - public void setParent(Tree t) { - super.setParent(t); - resetRootInformation(); - } - - @Override - public void addChild(Tree t) { - super.addChild(t); - resetRootInformation(); - } - - @Override - public void addChildren(List kids) { - super.addChildren(kids); - resetRootInformation(); - } - - @Override - public void setChild(int i, Tree t) { - super.setChild(i, t); - resetRootInformation(); - } - - @Override - public void insertChild(int i, Object t) { - super.insertChild(i, t); - resetRootInformation(); - } - - @Override - public Object deleteChild(int i) { - Object ret = super.deleteChild(i); - resetRootInformation(); - return ret; - } - - @Override - public void replaceChildren(int startChildIndex, int stopChildIndex, Object t) { - super.replaceChildren(startChildIndex, stopChildIndex, t); - resetRootInformation(); - } - - @Override - public String toStringTree() { - - // The root might have changed because of tree modifications. - // Compute the new root for this tree and set the astStr. - getRootNodeWithValidASTStr(true); - - // If rootNotModified is false, then startIndx and endIndx will be stale. - if (startIndx >= 0 && endIndx <= rootNode.getMemoizedStringLen()) { - return rootNode.getMemoizedSubString(startIndx, endIndx); - } - return toStringTree(rootNode); - } - - private String toStringTree(ASTNode rootNode) { - this.rootNode = rootNode; - startIndx = rootNode.getMemoizedStringLen(); - // Leaf node - if ( children==null || children.size()==0 ) { - rootNode.addtoMemoizedString(this.toString()); - endIndx = rootNode.getMemoizedStringLen(); - return this.toString(); - } - if ( !isNil() ) { - rootNode.addtoMemoizedString("("); - rootNode.addtoMemoizedString(this.toString()); - rootNode.addtoMemoizedString(" "); - } - for (int i = 0; children!=null && i < children.size(); i++) { - ASTNode t = (ASTNode)children.get(i); - if ( i>0 ) { - rootNode.addtoMemoizedString(" "); - } - t.toStringTree(rootNode); - } - if ( !isNil() ) { - rootNode.addtoMemoizedString(")"); - } - endIndx = rootNode.getMemoizedStringLen(); - return rootNode.getMemoizedSubString(startIndx, endIndx); - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java deleted file mode 100644 index c77198b087cbd..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.util.ArrayList; -import org.antlr.runtime.ANTLRStringStream; -import org.antlr.runtime.CharStream; -import org.antlr.runtime.NoViableAltException; -import org.antlr.runtime.RecognitionException; -import org.antlr.runtime.Token; -import org.antlr.runtime.TokenRewriteStream; -import org.antlr.runtime.TokenStream; -import org.antlr.runtime.tree.CommonTree; -import org.antlr.runtime.tree.CommonTreeAdaptor; -import org.antlr.runtime.tree.TreeAdaptor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.apache.hadoop.hive.ql.Context; - -/** - * ParseDriver. - * - */ -public class ParseDriver { - - private static final Logger LOG = LoggerFactory.getLogger("hive.ql.parse.ParseDriver"); - - /** - * ANTLRNoCaseStringStream. - * - */ - //This class provides and implementation for a case insensitive token checker - //for the lexical analysis part of antlr. By converting the token stream into - //upper case at the time when lexical rules are checked, this class ensures that the - //lexical rules need to just match the token with upper case letters as opposed to - //combination of upper case and lower case characters. This is purely used for matching lexical - //rules. The actual token text is stored in the same way as the user input without - //actually converting it into an upper case. The token values are generated by the consume() - //function of the super class ANTLRStringStream. The LA() function is the lookahead function - //and is purely used for matching lexical rules. This also means that the grammar will only - //accept capitalized tokens in case it is run from other tools like antlrworks which - //do not have the ANTLRNoCaseStringStream implementation. - public class ANTLRNoCaseStringStream extends ANTLRStringStream { - - public ANTLRNoCaseStringStream(String input) { - super(input); - } - - @Override - public int LA(int i) { - - int returnChar = super.LA(i); - if (returnChar == CharStream.EOF) { - return returnChar; - } else if (returnChar == 0) { - return returnChar; - } - - return Character.toUpperCase((char) returnChar); - } - } - - /** - * HiveLexerX. - * - */ - public class HiveLexerX extends SparkSqlLexer { - - private final ArrayList errors; - - public HiveLexerX(CharStream input) { - super(input); - errors = new ArrayList(); - } - - @Override - public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - errors.add(new ParseError(this, e, tokenNames)); - } - - @Override - public String getErrorMessage(RecognitionException e, String[] tokenNames) { - String msg = null; - - if (e instanceof NoViableAltException) { - // @SuppressWarnings("unused") - // NoViableAltException nvae = (NoViableAltException) e; - // for development, can add - // "decision=<<"+nvae.grammarDecisionDescription+">>" - // and "(decision="+nvae.decisionNumber+") and - // "state "+nvae.stateNumber - msg = "character " + getCharErrorDisplay(e.c) + " not supported here"; - } else { - msg = super.getErrorMessage(e, tokenNames); - } - - return msg; - } - - public ArrayList getErrors() { - return errors; - } - - } - - /** - * Tree adaptor for making antlr return ASTNodes instead of CommonTree nodes - * so that the graph walking algorithms and the rules framework defined in - * ql.lib can be used with the AST Nodes. - */ - public static final TreeAdaptor adaptor = new CommonTreeAdaptor() { - /** - * Creates an ASTNode for the given token. The ASTNode is a wrapper around - * antlr's CommonTree class that implements the Node interface. - * - * @param payload - * The token. - * @return Object (which is actually an ASTNode) for the token. - */ - @Override - public Object create(Token payload) { - return new ASTNode(payload); - } - - @Override - public Object dupNode(Object t) { - - return create(((CommonTree)t).token); - }; - - @Override - public Object errorNode(TokenStream input, Token start, Token stop, RecognitionException e) { - return new ASTErrorNode(input, start, stop, e); - }; - }; - - public ASTNode parse(String command) throws ParseException { - return parse(command, null); - } - - public ASTNode parse(String command, Context ctx) - throws ParseException { - return parse(command, ctx, true); - } - - /** - * Parses a command, optionally assigning the parser's token stream to the - * given context. - * - * @param command - * command to parse - * - * @param ctx - * context with which to associate this parser's token stream, or - * null if either no context is available or the context already has - * an existing stream - * - * @return parsed AST - */ - public ASTNode parse(String command, Context ctx, boolean setTokenRewriteStream) - throws ParseException { - LOG.info("Parsing command: " + command); - - HiveLexerX lexer = new HiveLexerX(new ANTLRNoCaseStringStream(command)); - TokenRewriteStream tokens = new TokenRewriteStream(lexer); - if (ctx != null) { - if ( setTokenRewriteStream) { - ctx.setTokenRewriteStream(tokens); - } - lexer.setHiveConf(ctx.getConf()); - } - SparkSqlParser parser = new SparkSqlParser(tokens); - if (ctx != null) { - parser.setHiveConf(ctx.getConf()); - } - parser.setTreeAdaptor(adaptor); - SparkSqlParser.statement_return r = null; - try { - r = parser.statement(); - } catch (RecognitionException e) { - e.printStackTrace(); - throw new ParseException(parser.errors); - } - - if (lexer.getErrors().size() == 0 && parser.errors.size() == 0) { - LOG.info("Parse Completed"); - } else if (lexer.getErrors().size() != 0) { - throw new ParseException(lexer.getErrors()); - } else { - throw new ParseException(parser.errors); - } - - ASTNode tree = (ASTNode) r.getTree(); - tree.setUnknownTokenBoundaries(); - return tree; - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java deleted file mode 100644 index b47bcfb2914df..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.antlr.runtime.BaseRecognizer; -import org.antlr.runtime.RecognitionException; - -/** - * - */ -public class ParseError { - private final BaseRecognizer br; - private final RecognitionException re; - private final String[] tokenNames; - - ParseError(BaseRecognizer br, RecognitionException re, String[] tokenNames) { - this.br = br; - this.re = re; - this.tokenNames = tokenNames; - } - - BaseRecognizer getBaseRecognizer() { - return br; - } - - RecognitionException getRecognitionException() { - return re; - } - - String[] getTokenNames() { - return tokenNames; - } - - String getMessage() { - return br.getErrorHeader(re) + " " + br.getErrorMessage(re, tokenNames); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java deleted file mode 100644 index fff891ced5550..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.util.ArrayList; - -/** - * ParseException. - * - */ -public class ParseException extends Exception { - - private static final long serialVersionUID = 1L; - ArrayList errors; - - public ParseException(ArrayList errors) { - super(); - this.errors = errors; - } - - @Override - public String getMessage() { - - StringBuilder sb = new StringBuilder(); - for (ParseError err : errors) { - if (sb.length() > 0) { - sb.append('\n'); - } - sb.append(err.getMessage()); - } - - return sb.toString(); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java deleted file mode 100644 index a5c2998f86cc1..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.apache.hadoop.hive.common.type.HiveDecimal; -import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; - - -/** - * Library of utility functions used in the parse code. - * - */ -public final class ParseUtils { - /** - * Performs a descent of the leftmost branch of a tree, stopping when either a - * node with a non-null token is found or the leaf level is encountered. - * - * @param tree - * candidate node from which to start searching - * - * @return node at which descent stopped - */ - public static ASTNode findRootNonNullToken(ASTNode tree) { - while ((tree.getToken() == null) && (tree.getChildCount() > 0)) { - tree = (org.apache.spark.sql.parser.ASTNode) tree.getChild(0); - } - return tree; - } - - private ParseUtils() { - // prevent instantiation - } - - public static VarcharTypeInfo getVarcharTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() != 1) { - throw new SemanticException("Bad params for type varchar"); - } - - String lengthStr = node.getChild(0).getText(); - return TypeInfoFactory.getVarcharTypeInfo(Integer.valueOf(lengthStr)); - } - - public static CharTypeInfo getCharTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() != 1) { - throw new SemanticException("Bad params for type char"); - } - - String lengthStr = node.getChild(0).getText(); - return TypeInfoFactory.getCharTypeInfo(Integer.valueOf(lengthStr)); - } - - public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() > 2) { - throw new SemanticException("Bad params for type decimal"); - } - - int precision = HiveDecimal.USER_DEFAULT_PRECISION; - int scale = HiveDecimal.USER_DEFAULT_SCALE; - - if (node.getChildCount() >= 1) { - String precStr = node.getChild(0).getText(); - precision = Integer.valueOf(precStr); - } - - if (node.getChildCount() == 2) { - String scaleStr = node.getChild(1).getText(); - scale = Integer.valueOf(scaleStr); - } - - return TypeInfoFactory.getDecimalTypeInfo(precision, scale); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java deleted file mode 100644 index 4b2015e0df84e..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java +++ /dev/null @@ -1,406 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.io.UnsupportedEncodingException; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.antlr.runtime.tree.Tree; -import org.apache.commons.lang.StringUtils; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.metastore.api.FieldSchema; -import org.apache.hadoop.hive.ql.ErrorMsg; -import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.serde.serdeConstants; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; - -/** - * SemanticAnalyzer. - * - */ -public abstract class SemanticAnalyzer { - public static String charSetString(String charSetName, String charSetString) - throws SemanticException { - try { - // The character set name starts with a _, so strip that - charSetName = charSetName.substring(1); - if (charSetString.charAt(0) == '\'') { - return new String(unescapeSQLString(charSetString).getBytes(), - charSetName); - } else // hex input is also supported - { - assert charSetString.charAt(0) == '0'; - assert charSetString.charAt(1) == 'x'; - charSetString = charSetString.substring(2); - - byte[] bArray = new byte[charSetString.length() / 2]; - int j = 0; - for (int i = 0; i < charSetString.length(); i += 2) { - int val = Character.digit(charSetString.charAt(i), 16) * 16 - + Character.digit(charSetString.charAt(i + 1), 16); - if (val > 127) { - val = val - 256; - } - bArray[j++] = (byte)val; - } - - String res = new String(bArray, charSetName); - return res; - } - } catch (UnsupportedEncodingException e) { - throw new SemanticException(e); - } - } - - /** - * Remove the encapsulating "`" pair from the identifier. We allow users to - * use "`" to escape identifier for table names, column names and aliases, in - * case that coincide with Hive language keywords. - */ - public static String unescapeIdentifier(String val) { - if (val == null) { - return null; - } - if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { - val = val.substring(1, val.length() - 1); - } - return val; - } - - /** - * Converts parsed key/value properties pairs into a map. - * - * @param prop ASTNode parent of the key/value pairs - * - * @param mapProp property map which receives the mappings - */ - public static void readProps( - ASTNode prop, Map mapProp) { - - for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { - String key = unescapeSQLString(prop.getChild(propChild).getChild(0) - .getText()); - String value = null; - if (prop.getChild(propChild).getChild(1) != null) { - value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); - } - mapProp.put(key, value); - } - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? - case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } - - /** - * Get the list of FieldSchema out of the ASTNode. - */ - public static List getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { - List colList = new ArrayList(); - int numCh = ast.getChildCount(); - for (int i = 0; i < numCh; i++) { - FieldSchema col = new FieldSchema(); - ASTNode child = (ASTNode) ast.getChild(i); - Tree grandChild = child.getChild(0); - if(grandChild != null) { - String name = grandChild.getText(); - if(lowerCase) { - name = name.toLowerCase(); - } - // child 0 is the name of the column - col.setName(unescapeIdentifier(name)); - // child 1 is the type of the column - ASTNode typeChild = (ASTNode) (child.getChild(1)); - col.setType(getTypeStringFromAST(typeChild)); - - // child 2 is the optional comment of the column - if (child.getChildCount() == 3) { - col.setComment(unescapeSQLString(child.getChild(2).getText())); - } - } - colList.add(col); - } - return colList; - } - - protected static String getTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - switch (typeNode.getType()) { - case SparkSqlParser.TOK_LIST: - return serdeConstants.LIST_TYPE_NAME + "<" - + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; - case SparkSqlParser.TOK_MAP: - return serdeConstants.MAP_TYPE_NAME + "<" - + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," - + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; - case SparkSqlParser.TOK_STRUCT: - return getStructTypeStringFromAST(typeNode); - case SparkSqlParser.TOK_UNIONTYPE: - return getUnionTypeStringFromAST(typeNode); - default: - return getTypeName(typeNode); - } - } - - private static String getStructTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - String typeStr = serdeConstants.STRUCT_TYPE_NAME + "<"; - typeNode = (ASTNode) typeNode.getChild(0); - int children = typeNode.getChildCount(); - if (children <= 0) { - throw new SemanticException("empty struct not allowed."); - } - StringBuilder buffer = new StringBuilder(typeStr); - for (int i = 0; i < children; i++) { - ASTNode child = (ASTNode) typeNode.getChild(i); - buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); - buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); - if (i < children - 1) { - buffer.append(","); - } - } - - buffer.append(">"); - return buffer.toString(); - } - - private static String getUnionTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; - typeNode = (ASTNode) typeNode.getChild(0); - int children = typeNode.getChildCount(); - if (children <= 0) { - throw new SemanticException("empty union not allowed."); - } - StringBuilder buffer = new StringBuilder(typeStr); - for (int i = 0; i < children; i++) { - buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); - if (i < children - 1) { - buffer.append(","); - } - } - buffer.append(">"); - typeStr = buffer.toString(); - return typeStr; - } - - public static String getAstNodeText(ASTNode tree) { - return tree.getChildCount() == 0?tree.getText() : - getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); - } - - public static String generateErrorMessage(ASTNode ast, String message) { - StringBuilder sb = new StringBuilder(); - if (ast == null) { - sb.append(message).append(". Cannot tell the position of null AST."); - return sb.toString(); - } - sb.append(ast.getLine()); - sb.append(":"); - sb.append(ast.getCharPositionInLine()); - sb.append(" "); - sb.append(message); - sb.append(". Error encountered near token '"); - sb.append(getAstNodeText(ast)); - sb.append("'"); - return sb.toString(); - } - - private static final Map TokenToTypeName = new HashMap(); - - static { - TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); - } - - public static String getTypeName(ASTNode node) throws SemanticException { - int token = node.getType(); - String typeName; - - // datetime type isn't currently supported - if (token == SparkSqlParser.TOK_DATETIME) { - throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); - } - - switch (token) { - case SparkSqlParser.TOK_CHAR: - CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); - typeName = charTypeInfo.getQualifiedName(); - break; - case SparkSqlParser.TOK_VARCHAR: - VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); - typeName = varcharTypeInfo.getQualifiedName(); - break; - case SparkSqlParser.TOK_DECIMAL: - DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); - typeName = decTypeInfo.getQualifiedName(); - break; - default: - typeName = TokenToTypeName.get(token); - } - return typeName; - } - - public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { - boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); - if (testMode) { - URI uri = new Path(location).toUri(); - String scheme = uri.getScheme(); - String authority = uri.getAuthority(); - String path = uri.getPath(); - if (!path.startsWith("/")) { - path = (new Path(System.getProperty("test.tmp.dir"), - path)).toUri().getPath(); - } - if (StringUtils.isEmpty(scheme)) { - scheme = "pfile"; - } - try { - uri = new URI(scheme, authority, path, null, null); - } catch (URISyntaxException e) { - throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); - } - return uri.toString(); - } else { - //no-op for non-test mode for now - return location; - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 31d82eb20f6e4..bf3fe12d5c5d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -17,41 +17,30 @@ package org.apache.spark.sql.hive -import java.sql.Date import java.util.Locale import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.{Context, ErrorMsg} -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse.SemanticException -import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.ParseUtils._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.execution.SparkQl +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} -import org.apache.spark.sql.parser._ +import org.apache.spark.sql.hive.execution.{HiveNativeCommand, AnalyzeTable, DropTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler +import org.apache.spark.sql.AnalysisException /** * Used when we need to start parsing the AST before deciding that we are going to pass the command @@ -71,7 +60,7 @@ private[hive] case class CreateTableAsSelect( override def output: Seq[Attribute] = Seq.empty[Attribute] override lazy val resolved: Boolean = tableDesc.specifiedDatabase.isDefined && - tableDesc.schema.size > 0 && + tableDesc.schema.nonEmpty && tableDesc.serde.isDefined && tableDesc.inputFormat.isDefined && tableDesc.outputFormat.isDefined && @@ -89,7 +78,7 @@ private[hive] case class CreateViewAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl extends Logging { +private[hive] object HiveQl extends SparkQl with Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", @@ -180,103 +169,6 @@ private[hive] object HiveQl extends Logging { protected val hqlParser = new ExtendedHiveQlParser - /** - * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations - * similar to [[catalyst.trees.TreeNode]]. - * - * Note that this should be considered very experimental and is not indented as a replacement - * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to - * have clean copy semantics. Therefore, users of this class should take care when - * copying/modifying trees that might be used elsewhere. - */ - implicit class TransformableNode(n: ASTNode) { - /** - * Returns a copy of this node where `rule` has been recursively applied to it and all of its - * children. When `rule` does not apply to a given node it is left unchanged. - * @param rule the function use to transform this nodes children - */ - def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = { - try { - val afterRule = rule.applyOrElse(n, identity[ASTNode]) - afterRule.withChildren( - nilIfEmpty(afterRule.getChildren) - .asInstanceOf[Seq[ASTNode]] - .map(ast => Option(ast).map(_.transform(rule)).orNull)) - } catch { - case e: Exception => - logError(dumpTree(n).toString) - throw e - } - } - - /** - * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. - */ - private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.asScala).getOrElse(Nil) - - /** - * Returns this ASTNode with the text changed to `newText`. - */ - def withText(newText: String): ASTNode = { - n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText) - n - } - - /** - * Returns this ASTNode with the children changed to `newChildren`. - */ - def withChildren(newChildren: Seq[ASTNode]): ASTNode = { - (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - newChildren.foreach(n.addChild(_)) - n - } - - /** - * Throws an error if this is not equal to other. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. - */ - def checkEquals(other: ASTNode): Unit = { - def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { - sys.error(s"$field does not match for trees. " + - s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") - } - check("name", _.getName) - check("type", _.getType) - check("text", _.getText) - check("numChildren", n => nilIfEmpty(n.getChildren).size) - - val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] - val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] - leftChildren zip rightChildren foreach { - case (l, r) => l checkEquals r - } - } - } - - /** - * Returns the AST for the given SQL string. - */ - def getAst(sql: String): ASTNode = { - /* - * Context has to be passed in hive0.13.1. - * Otherwise, there will be Null pointer exception, - * when retrieving properties form HiveConf. - */ - val hContext = createContext() - val node = getAst(sql, hContext) - hContext.clear() - node - } - - private def createContext(): Context = new Context(hiveConf) - - private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(sql, context)) - /** * Returns the HiveConf */ @@ -296,226 +188,16 @@ private[hive] object HiveQl extends Logging { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) - val errorRegEx = "line (\\d+):(\\d+) (.*)".r - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String): LogicalPlan = { - try { - val context = createContext() - val tree = getAst(sql, context) - val plan = if (nativeCommands contains tree.getText) { - HiveNativeCommand(sql) - } else { - nodeToPlan(tree, context) match { - case NativePlaceholder => HiveNativeCommand(sql) - case other => other - } - } - context.clear() - plan - } catch { - case pe: ParseException => - pe.getMessage match { - case errorRegEx(line, start, message) => - throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) - case otherMessage => - throw new AnalysisException(otherMessage) - } - case e: MatchError => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) - } - } - - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = - try { - ParseUtils.findRootNonNullToken( - (new ParseDriver) - .parse(ddl, null /* no context required for parsing alone */)) - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) - } - assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.getChildren - val colList = - tableOps.asScala - .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")).getChildren - - colList.asScala.map(nodeToAttribute) - } - - /** Extractor for matching Hive's AST Tokens. */ - private[hive] case class Token(name: String, children: Seq[ASTNode]) extends Node { - def getName(): String = name - def getChildren(): java.util.List[Node] = { - val col = new java.util.ArrayList[Node](children.size) - children.foreach(col.add(_)) - col - } - } - object Token { - /** @return matches of the form (tokenName, children). */ - def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { - case t: ASTNode => - CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) - Some((t.getText, - Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) - case t: Token => Some((t.name, t.children)) - case _ => None - } - } - - protected def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } - - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. - |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses - } - - def getClause(clauseName: String, nodeList: Seq[Node]): Node = - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) - - def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { - nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } - - protected def nodeToAttribute(node: Node): Attribute = node match { - case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => - AttributeReference(colName, nodeToDataType(dataType), true)() - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", - Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", - keyType :: - valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") - } - - protected def nodeToStructField(node: Node): StructField = node match { - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: - _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") - } - - protected def extractTableIdent(tableNameParts: Node): TableIdentifier = { - tableNameParts.getChildren.asScala.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } - } - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition( n => n match { - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets - case _ => true // grouping keys - }) - - val keys = keyASTs.map(nodeToExpr).toSeq - val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap - - val bitmasks: Seq[Int] = setASTs.map(set => set match { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => - children.foldLeft(0)((bitmap, col) => { - val colString = col.asInstanceOf[ASTNode].toStringTree() - require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") - bitmap | 1 << keyMap(colString) - }) - case _ => sys.error("Expect GROUPING SETS clause") - }) - - (keys, bitmasks) - } - - protected def getProperties(node: Node): Seq[(String, String)] = node match { + protected def getProperties(node: ASTNode): Seq[(String, String)] = node match { case Token("TOK_TABLEPROPLIST", list) => list.map { case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => - (unquoteString(key) -> unquoteString(value)) + unquoteString(key) -> unquoteString(value) } } private def createView( view: ASTNode, - context: Context, viewNameParts: ASTNode, query: ASTNode, schema: Seq[HiveColumn], @@ -524,8 +206,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C replace: Boolean): CreateViewAsSelect = { val TableIdentifier(viewName, dbName) = extractTableIdent(viewNameParts) - val originalText = context.getTokenRewriteStream - .toString(query.getTokenStartIndex, query.getTokenStopIndex) + val originalText = query.source val tableDesc = HiveTable( specifiedDatabase = dbName, @@ -544,104 +225,67 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // We need to keep the original SQL string so that if `spark.sql.nativeView` is // false, we can fall back to use hive native command later. // We can remove this when parser is configurable(can access SQLConf) in the future. - val sql = context.getTokenRewriteStream - .toString(view.getTokenStartIndex, view.getTokenStopIndex) - CreateViewAsSelect(tableDesc, nodeToPlan(query, context), allowExist, replace, sql) + val sql = view.source + CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) } - protected def nodeToPlan(node: ASTNode, context: Context): LogicalPlan = node match { - // Special drop table that also uncaches. - case Token("TOK_DROPTABLE", - Token("TOK_TABNAME", tableNameParts) :: - ifExists) => - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - DropTable(tableName, ifExists.nonEmpty) - // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" - case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: - isNoscan) => - // Reference: - // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables - if (partitionSpec.nonEmpty) { - // Analyze partitions will be treated as a Hive native command. - NativePlaceholder - } else if (isNoscan.isEmpty) { - // If users do not specify "noscan", it will be treated as a Hive native command. - NativePlaceholder - } else { - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - AnalyzeTable(tableName) + protected override def createPlan( + sql: String, + node: ASTNode): LogicalPlan = { + if (nativeCommands.contains(node.text)) { + HiveNativeCommand(sql) + } else { + nodeToPlan(node) match { + case NativePlaceholder => HiveNativeCommand(sql) + case plan => plan } - // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) - if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(OneRowRelation) - case Token("TOK_EXPLAIN", explainArgs) - if "TOK_CREATETABLE" == explainArgs.head.getText => - val Some(crtTbl) :: _ :: extended :: Nil = - getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(crtTbl, context), - extended = extended.isDefined) - case Token("TOK_EXPLAIN", explainArgs) => - // Ignore FORMATTED if present. - val Some(query) :: _ :: extended :: Nil = - getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(query, context), - extended = extended.isDefined) - - case Token("TOK_DESCTABLE", describeArgs) => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val Some(tableType) :: formatted :: extended :: pretty :: Nil = - getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) - if (formatted.isDefined || pretty.isDefined) { - // FORMATTED and PRETTY are not supported and this statement will be treated as - // a Hive native command. - NativePlaceholder - } else { - tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => { - nameParts match { - case Token(".", dbName :: tableName :: Nil) => - // It is describing a table with the format like "describe db.table". - // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts) - DescribeCommand( - UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => - // It is describing a column with the format like "describe db.table column". - NativePlaceholder - case tableName => - // It is describing a table with the format like "describe table". - DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.getText), None), - isExtended = extended.isDefined) - } - } - // All other cases. - case _ => NativePlaceholder + } + } + + protected override def isNoExplainCommand(command: String): Boolean = + noExplainCommands.contains(command) + + protected override def nodeToPlan(node: ASTNode): LogicalPlan = { + node match { + // Special drop table that also uncaches. + case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + DropTable(tableName, ifExists.nonEmpty) + + // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" + case Token("TOK_ANALYZE", + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) => + // Reference: + // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables + if (partitionSpec.nonEmpty) { + // Analyze partitions will be treated as a Hive native command. + NativePlaceholder + } else if (isNoscan.isEmpty) { + // If users do not specify "noscan", it will be treated as a Hive native command. + NativePlaceholder + } else { + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + AnalyzeTable(tableName) } - } - case view @ Token("TOK_ALTERVIEW", children) => - val Some(viewNameParts) :: maybeQuery :: ignores = - getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME"), children) + case view @ Token("TOK_ALTERVIEW", children) => + val Some(nameParts) :: maybeQuery :: _ = + getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_ALTERVIEW_ADDPARTS", + "TOK_ALTERVIEW_DROPPARTS", + "TOK_ALTERVIEW_PROPERTIES", + "TOK_ALTERVIEW_RENAME"), children) - // if ALTER VIEW doesn't have query part, let hive to handle it. - maybeQuery.map { query => - createView(view, context, viewNameParts, query, Nil, Map(), false, true) - }.getOrElse(NativePlaceholder) + // if ALTER VIEW doesn't have query part, let hive to handle it. + maybeQuery.map { query => + createView(view, nameParts, query, Nil, Map(), allowExist = false, replace = true) + }.getOrElse(NativePlaceholder) - case view @ Token("TOK_CREATEVIEW", children) + case view @ Token("TOK_CREATEVIEW", children) if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - val Seq( + val Seq( Some(viewNameParts), Some(query), maybeComment, @@ -650,1224 +294,466 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeProperties, maybeColumns, maybePartCols - ) = getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_TABLECOMMENT", - "TOK_ORREPLACE", - "TOK_IFNOTEXISTS", - "TOK_TABLEPROPERTIES", - "TOK_TABCOLNAME", - "TOK_VIEWPARTCOLS"), children) - - // If the view is partitioned, we let hive handle it. - if (maybePartCols.isDefined) { - NativePlaceholder - } else { - val schema = maybeColumns.map { cols => - SemanticAnalyzer.getColumns(cols, true).asScala.map { field => + ) = getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_TABLECOMMENT", + "TOK_ORREPLACE", + "TOK_IFNOTEXISTS", + "TOK_TABLEPROPERTIES", + "TOK_TABCOLNAME", + "TOK_VIEWPARTCOLS"), children) + + // If the view is partitioned, we let hive handle it. + if (maybePartCols.isDefined) { + NativePlaceholder + } else { + val schema = maybeColumns.map { cols => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. - HiveColumn(field.getName, null, field.getComment) - } - }.getOrElse(Seq.empty[HiveColumn]) - - val properties = scala.collection.mutable.Map.empty[String, String] - - maybeProperties.foreach { - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - properties ++= getProperties(list) - } - - maybeComment.foreach { - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = SemanticAnalyzer.unescapeSQLString(child.getText) - if (comment ne null) { - properties += ("comment" -> comment) - } - } - - createView(view, context, viewNameParts, query, schema, properties.toMap, - allowExisting.isDefined, replace.isDefined) - } - - case Token("TOK_CREATETABLE", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val ( - Some(tableNameParts) :: - _ /* likeTable */ :: - externalTable :: - Some(query) :: - allowExisting +: - ignores) = - getClauses( - Seq( - "TOK_TABNAME", - "TOK_LIKETABLE", - "EXTERNAL", - "TOK_QUERY", - "TOK_IFNOTEXISTS", - "TOK_TABLECOMMENT", - "TOK_TABCOLLIST", - "TOK_TABLEPARTCOLS", // Partitioned by - "TOK_TABLEBUCKETS", // Clustered by - "TOK_TABLESKEWED", // Skewed by - "TOK_TABLEROWFORMAT", - "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", - "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat - "TOK_STORAGEHANDLER", // Storage handler - "TOK_TABLELOCATION", - "TOK_TABLEPROPERTIES"), - children) - val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) - - // TODO add bucket support - var tableDesc: HiveTable = HiveTable( - specifiedDatabase = dbName, - name = tblName, - schema = Seq.empty[HiveColumn], - partitionColumns = Seq.empty[HiveColumn], - properties = Map[String, String](), - serdeProperties = Map[String, String](), - tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, - viewText = None) - - // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) - val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbreviation - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } + nodeToColumns(cols, lowerCase = true).map(_.copy(hiveType = null)) + }.getOrElse(Seq.empty[HiveColumn]) - hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) - hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) - hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) + val properties = scala.collection.mutable.Map.empty[String, String] - children.collect { - case list @ Token("TOK_TABCOLLIST", _) => - val cols = SemanticAnalyzer.getColumns(list, true) - if (cols != null) { - tableDesc = tableDesc.copy( - schema = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) + maybeProperties.foreach { + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + properties ++= getProperties(list) } - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = SemanticAnalyzer.unescapeSQLString(child.getText) - // TODO support the sql text - tableDesc = tableDesc.copy(viewText = Option(comment)) - case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = SemanticAnalyzer.getColumns(list(0), false) - if (cols != null) { - tableDesc = tableDesc.copy( - partitionColumns = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) - } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => - val serdeParams = new java.util.HashMap[String, String]() - child match { - case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = SemanticAnalyzer.unescapeSQLString (rowChild1.getText()) - serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) - serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) - if (rowChild2.length > 1) { - val fieldEscape = SemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) - serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) - } - case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) - case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) - case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - if (!(lineDelim == "\n") && !(lineDelim == "10")) { - throw new AnalysisException( - SemanticAnalyzer.generateErrorMessage( - rowChild, - ErrorMsg.LINES_TERMINATED_BY_NON_NEWLINE.getMsg)) - } - serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) - case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - // TODO support the nullFormat - case _ => assert(false) - } - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - case Token("TOK_TABLELOCATION", child :: Nil) => - var location = SemanticAnalyzer.unescapeSQLString(child.getText) - location = SemanticAnalyzer.relativeToAbsolutePath(hiveConf, location) - tableDesc = tableDesc.copy(location = Option(location)) - case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.copy( - serde = Option(SemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) - if (child.getChildCount == 2) { - val serdeParams = new java.util.HashMap[String, String]() - SemanticAnalyzer.readProps( - (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - } - case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - child.getText().toLowerCase(Locale.ENGLISH) match { - case "orc" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - case "parquet" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - case "rcfile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } - case "textfile" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - - case "sequencefile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - case "avro" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + maybeComment.foreach { + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = unescapeSQLString(child.text) + if (comment ne null) { + properties += ("comment" -> comment) } - - case _ => - throw new SemanticException( - s"Unrecognized file format in STORED AS clause: ${child.getText}") - } - - case Token("TOK_TABLESERIALIZER", - Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - - otherProps match { - case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) - case Nil => } - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", children) => - tableDesc = tableDesc.copy( - inputFormat = - Option(SemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), - outputFormat = - Option(SemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) - case Token("TOK_STORAGEHANDLER", _) => - throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) - case _ => // Unsupport features - } - - CreateTableAsSelect(tableDesc, nodeToPlan(query, context), allowExisting != None) - - // If its not a "CTAS" like above then take it as a native command - case Token("TOK_CREATETABLE", _) => NativePlaceholder - - // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" - case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder - - case Token("TOK_QUERY", queryArgs) - if Seq("TOK_CTE", "TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - - val (fromClause: Option[ASTNode], insertClauses, cteRelations) = - queryArgs match { - case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => - val cteRelations = ctes.map { node => - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - relation.alias -> relation - } - (Some(from.head), inserts, Some(cteRelations.toMap)) - case Token("TOK_FROM", from) :: inserts => - (Some(from.head), inserts, None) - case Token("TOK_INSERT", _) :: Nil => - (None, queryArgs, None) + createView(view, viewNameParts, query, schema, properties.toMap, + allowExisting.isDefined, replace.isDefined) } - // Return one query for each insert clause. - val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => + case Token("TOK_CREATETABLE", children) + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { + Some(tableNameParts) :: + _ /* likeTable */ :: + externalTable :: + Some(query) :: + allowExisting +: + _) = getClauses( Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) + "TOK_TABNAME", + "TOK_LIKETABLE", + "EXTERNAL", + "TOK_QUERY", + "TOK_IFNOTEXISTS", + "TOK_TABLECOMMENT", + "TOK_TABCOLLIST", + "TOK_TABLEPARTCOLS", // Partitioned by + "TOK_TABLEBUCKETS", // Clustered by + "TOK_TABLESKEWED", // Skewed by + "TOK_TABLEROWFORMAT", + "TOK_TABLESERIALIZER", + "TOK_FILEFORMAT_GENERIC", + "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat + "TOK_STORAGEHANDLER", // Storage handler + "TOK_TABLELOCATION", + "TOK_TABLEPROPERTIES"), + children) + val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) + + // TODO add bucket support + var tableDesc: HiveTable = HiveTable( + specifiedDatabase = dbName, + name = tblName, + schema = Seq.empty[HiveColumn], + partitionColumns = Seq.empty[HiveColumn], + properties = Map[String, String](), + serdeProperties = Map[String, String](), + tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, + location = None, + inputFormat = None, + outputFormat = None, + serde = None, + viewText = None) + + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) + val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) } - val relations = fromClause match { - case Some(f) => nodeToRelation(f, context) - case None => OneRowRelation - } + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.asScala - Filter(nodeToExpr(whereExpr), relations) - }.getOrElse(relations) - - val select = - (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause.")) - - // Script transformations are expressed as a select clause with a single expression of type - // TOK_TRANSFORM - val transformation = select.getChildren.iterator().next() match { - case Token("TOK_SELEXPR", - Token("TOK_TRANSFORM", - Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", inputSerdeClause) :: - Token("TOK_RECORDWRITER", writerClause) :: - // TODO: Need to support other types of (in/out)put - Token(script, Nil) :: - Token("TOK_SERDE", outputSerdeClause) :: - Token("TOK_RECORDREADER", readerClause) :: - outputClause) :: Nil) => - - val (output, schemaLess) = outputClause match { - case Token("TOK_ALIASLIST", aliases) :: Nil => - (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, - false) - case Token("TOK_TABCOLLIST", attributes) :: Nil => - (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(name, nodeToDataType(dataType))() }, false) - case Nil => - (List(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) + children.collect { + case list @ Token("TOK_TABCOLLIST", _) => + val cols = nodeToColumns(list, lowerCase = true) + if (cols != null) { + tableDesc = tableDesc.copy(schema = cols) } - - type SerDeInfo = ( - Seq[(String, String)], // Input row format information - Option[String], // Optional input SerDe class - Seq[(String, String)], // Input SerDe properties - Boolean // Whether to use default record reader/writer - ) - - def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { - case Token("TOK_SERDEPROPS", propsClause) :: Nil => - val rowFormat = propsClause.map { - case Token(name, Token(value, Nil) :: Nil) => (name, value) + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = unescapeSQLString(child.text) + // TODO support the sql text + tableDesc = tableDesc.copy(viewText = Option(comment)) + case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => + val cols = nodeToColumns(list.head, lowerCase = false) + if (cols != null) { + tableDesc = tableDesc.copy(partitionColumns = cols) + } + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => + val serdeParams = new java.util.HashMap[String, String]() + child match { + case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => + val fieldDelim = unescapeSQLString (rowChild1.text) + serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) + serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) + if (rowChild2.length > 1) { + val fieldEscape = unescapeSQLString (rowChild2.head.text) + serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } - (rowFormat, None, Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(SemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: - Token("TOK_TABLEPROPERTIES", - Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => - val serdeProps = propsClause.map { - case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (SemanticAnalyzer.unescapeSQLString(name), - SemanticAnalyzer.unescapeSQLString(value)) + case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => + val collItemDelim = unescapeSQLString(rowChild.text) + serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) + case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => + val mapKeyDelim = unescapeSQLString(rowChild.text) + serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) + case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => + val lineDelim = unescapeSQLString(rowChild.text) + if (!(lineDelim == "\n") && !(lineDelim == "10")) { + throw new AnalysisException( + s"LINES TERMINATED BY only supports newline '\\n' right now: $rowChild") } - - // SPARK-10310: Special cases LazySimpleSerDe - // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = SemanticAnalyzer.unescapeSQLString(serdeClass) - val useDefaultRecordReaderWriter = - unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName - (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) - - case Nil => - // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here - val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") - (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) - } - - val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = - matchSerDe(inputSerdeClause) - - val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = - matchSerDe(outputSerdeClause) - - val unescapedScript = SemanticAnalyzer.unescapeSQLString(script) - - // TODO Adds support for user-defined record reader/writer classes - val recordReaderClass = if (useDefaultRecordReader) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) - } else { - None + serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) + case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => + val nullFormat = unescapeSQLString(rowChild.text) + // TODO support the nullFormat + case _ => assert(false) } - - val recordWriterClass = if (useDefaultRecordWriter) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) - } else { - None + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) + case Token("TOK_TABLELOCATION", child :: Nil) => + val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text)) + tableDesc = tableDesc.copy(location = Option(location)) + case Token("TOK_TABLESERIALIZER", child :: Nil) => + tableDesc = tableDesc.copy( + serde = Option(unescapeSQLString(child.children.head.text))) + if (child.numChildren == 2) { + // This is based on the readProps(..) method in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: + val serdeParams = child.children(1).children.head.children.map { + case Token(_, Token(prop, Nil) :: valueNode) => + val value = valueNode.headOption + .map(_.text) + .map(unescapeSQLString) + .orNull + (unescapeSQLString(prop), value) + }.toMap + tableDesc = tableDesc.copy(serdeProperties = tableDesc.serdeProperties ++ serdeParams) } + case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => + child.text.toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - val schema = HiveScriptIOSchema( - inRowFormat, outRowFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - recordReaderClass, recordWriterClass, - schemaLess) - - Some( - logical.ScriptTransformation( - inputExprs.map(nodeToExpr), - unescapedScript, - output, - withWhere, schema)) - case _ => None - } - - val withLateralView = lateralViewClause.map { lv => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = false, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - withWhere) - }.getOrElse(withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. - val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Aggregate(Seq(Rollup(children.map(nodeToExpr))), selectExpressions, withLateralView) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Aggregate(Seq(Cube(children.map(nodeToExpr))), selectExpressions, withLateralView) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withLateralView))).flatten.head - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - // Handle HAVING clause. - val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), - false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), - false, - RepartitionByExpression( - clusterExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy(serde = + Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. - val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. - val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) + case _ => + throw new AnalysisException( + s"Unrecognized file format in STORED AS clause: ${child.text}") + } - // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left, context), nodeToPlan(right, context)) + case Token("TOK_TABLESERIALIZER", + Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => + tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") - } + otherProps match { + case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) + case _ => + } - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node, context: Context): LogicalPlan = node match { - case Token("TOK_SUBQUERY", - query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query, context)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - nodeToRelation(relationClause, context)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. - val (nonAliasClauses, aliasClause) = - if (clauses.last.getText.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) + case list @ Token("TOK_TABLEFILEFORMAT", _) => + tableDesc = tableDesc.copy( + inputFormat = + Option(unescapeSQLString(list.children.head.text)), + outputFormat = + Option(unescapeSQLString(list.children(1).text))) + case Token("TOK_STORAGEHANDLER", _) => + throw new AnalysisException( + "CREATE TABLE AS SELECT cannot be used for a non-native table") + case _ => // Unsupport features } - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: - Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: - Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. - require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, - relation) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - }.getOrElse(relation) - - case Token("TOK_UNIQUEJOIN", joinArgs) => - val tableOrdinals = - joinArgs.zipWithIndex.filter { - case (arg, i) => arg.getText == "TOK_TABREF" - }.map(_._2) - - val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") - val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i), context)) - val joinExpressions = - tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) - - val joinConditions = joinExpressions.sliding(2).map { - case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } - predicates.reduceLeft(And) - }.toBuffer - - val joinType = isPreserved.sliding(2).map { - case Seq(true, true) => FullOuter - case Seq(true, false) => LeftOuter - case Seq(false, true) => RightOuter - case Seq(false, false) => Inner - }.toBuffer - - val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) - - // Must be transform down. - val joinedResult = joinedTables transform { - case j: Join => - j.copy( - condition = Some(joinConditions.remove(joinConditions.length - 1)), - joinType = joinType.remove(joinType.length - 1)) - } - - val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) - - // Unique join is not really the same as an outer join so we must group together results where - // the joinExpressions are the same, taking the First of each value is only okay because the - // user of a unique join is implicitly promising that there is only one result. - // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression. - // instead we should figure out how important supporting this feature is and whether it is - // worth the number of hacks that will be required to implement it. Namely, we need to add - // some sort of mapped star expansion that would expand all child output row to be similarly - // named output expressions where some aggregate expression has been applied (i.e. First). - // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) - throw new UnsupportedOperationException - - case Token(allJoinTokens(joinToken), - relation1 :: - relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - case "TOK_ANTIJOIN" => throw new NotImplementedError("Anti join not supported") - } - Join(nodeToRelation(relation1, context), - nodeToRelation(relation2, context), - joinType, - other.headOption.map(nodeToExpr)) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } + CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) - def nodeToSortOrder(node: Node): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) + // If its not a "CTAS" like above then take it as a native command + case Token("TOK_CREATETABLE", _) => + NativePlaceholder - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } + // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" + case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION", table) :: Nil) => + NativePlaceholder - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: Node, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName}:" + - s"\n ${dumpTree(a).toString} ") + case _ => + super.nodeToPlan(node) + } } - protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - var aliasNames = ArrayBuffer[String]() - aliasChildren.foreach { _ match { - case Token(name, Nil) => aliasNames += cleanIdentifier(name) + protected override def nodeToDescribeFallback(node: ASTNode): LogicalPlan = NativePlaceholder + + protected override def nodeToTransformation( + node: ASTNode, + child: LogicalPlan): Option[ScriptTransformation] = node match { + case Token("TOK_SELEXPR", + Token("TOK_TRANSFORM", + Token("TOK_EXPLIST", inputExprs) :: + Token("TOK_SERDE", inputSerdeClause) :: + Token("TOK_RECORDWRITER", writerClause) :: + // TODO: Need to support other types of (in/out)put + Token(script, Nil) :: + Token("TOK_SERDE", outputSerdeClause) :: + Token("TOK_RECORDREADER", readerClause) :: + outputClause) :: Nil) => + + val (output, schemaLess) = outputClause match { + case Token("TOK_ALIASLIST", aliases) :: Nil => + (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, + false) + case Token("TOK_TABCOLLIST", attributes) :: Nil => + (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => + AttributeReference(name, nodeToDataType(dataType))() }, false) + case Nil => + (List(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) case _ => - } + noParseRule("Transform", node) } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName }:" + - s"\n ${dumpTree(a).toString } ") - } - - protected val escapedIdentifier = "`([^`]+)`".r - protected val doubleQuotedString = "\"([^\"]+)\"".r - protected val singleQuotedString = "'([^']+)'".r + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { + case Token("TOK_SERDEPROPS", propsClause) :: Nil => + val rowFormat = propsClause.map { + case Token(name, Token(value, Nil) :: Nil) => (name, value) + } + (rowFormat, None, Nil, false) - protected def unquoteString(str: String) = str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => + (Nil, Some(unescapeSQLString(serdeClass)), Nil, false) - /** Strips backticks from ident if present */ - protected def cleanIdentifier(ident: String): String = ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: + Token("TOK_TABLEPROPERTIES", + Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => + val serdeProps = propsClause.map { + case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => + (unescapeSQLString(name), unescapeSQLString(value)) + } - val numericAstTypes = Seq( - SparkSqlParser.Number, - SparkSqlParser.TinyintLiteral, - SparkSqlParser.SmallintLiteral, - SparkSqlParser.BigintLiteral, - SparkSqlParser.DecimalLiteral) - - /* Case insensitive matches */ - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - - protected def nodeToExpr(node: Node): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", - Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only - // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => - Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => - Count(Literal(1)).toAggregateExpression() - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) - case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token(name, args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = nodeToExpr(Token(name, args)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. - case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => SemanticAnalyzer.unescapeSQLString(s.getText)).mkString) - - // This code is adapted from - // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 - case ast: ASTNode if numericAstTypes contains ast.getType => - var v: Literal = null - try { - if (ast.getText.endsWith("L")) { - // Literal bigint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) - } else if (ast.getText.endsWith("S")) { - // Literal smallint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) - } else if (ast.getText.endsWith("Y")) { - // Literal tinyint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) - } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { - // Literal decimal - val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(Decimal(strVal)) - } else { - v = Literal.create(ast.getText.toDouble, DoubleType) - v = Literal.create(ast.getText.toLong, LongType) - v = Literal.create(ast.getText.toInt, IntegerType) - } - } catch { - case nfe: NumberFormatException => // Do nothing - } + val unescapedScript = unescapeSQLString(script) - if (v == null) { - sys.error(s"Failed to parse number '${ast.getText}'.") + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) } else { - v + None } - case ast: ASTNode if ast.getType == SparkSqlParser.StringLiteral => - Literal(SemanticAnalyzer.unescapeSQLString(ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_CHARSETLITERAL => - Literal(SemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. - val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), - orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - throw new NotImplementedError( - s"""No parse rules for Node ${partitionAndOrdering.getName} - """.stripMargin) - } - }.getOrElse { - (Nil, Nil) + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None } - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} - """.stripMargin) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.asScala.toList match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame based on Node ${frame.getName} - """.stripMargin) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) + val schema = HiveScriptIOSchema( + inRowFormat, outRowFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) + + Some( + ScriptTransformation( + inputExprs.map(nodeToExpr), + unescapedScript, + output, + child, schema)) + case _ => None } - val explode = "(?i)explode".r - val jsonTuple = "(?i)json_tuple".r - def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { - val function = nodes.head - - val attributes = nodes.flatMap { - case Token(a, Nil) => a.toLowerCase :: Nil - case _ => Nil - } - - function match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - (Explode(nodeToExpr(child)), attributes) - - case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => - (JsonTuple(children.map(nodeToExpr)), attributes) - - case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $functionName")) - val functionClassName = functionInfo.getFunctionClass.getName - - (HiveGenericUDTF( - new HiveFunctionWrapper(functionClassName), - children.map(nodeToExpr)), attributes) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: - |${dumpTree(a).toString} - """.stripMargin) - } + protected override def nodeToGenerator(node: ASTNode): Generator = node match { + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) + case other => super.nodeToGenerator(node) } - def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) - : StringBuilder = { - node match { - case a: ASTNode => builder.append( - (" " * indent) + a.getText + " " + - a.getLine + ", " + - a.getTokenStartIndex + "," + - a.getTokenStopIndex + ", " + - a.getCharPositionInLine + "\n") - case other => sys.error(s"Non ASTNode encountered: $other") + // This is based the getColumns methods in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java + protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[HiveColumn] = { + node.children.map(_.children).collect { + case Token(rawColName, Nil) :: colTypeNode :: comment => + val colName = if (!lowerCase) rawColName + else rawColName.toLowerCase + HiveColumn( + cleanIdentifier(colName), + nodeToTypeString(colTypeNode), + comment.headOption.map(n => unescapeSQLString(n.text)).orNull) } + } - Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) - builder + // This is based on the following methods in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: + // getTypeStringFromAST + // getStructTypeStringFromAST + // getUnionTypeStringFromAST + protected def nodeToTypeString(node: ASTNode): String = node.tokenType match { + case SparkSqlParser.TOK_LIST => + val listType :: Nil = node.children + val listTypeString = nodeToTypeString(listType) + s"${serdeConstants.LIST_TYPE_NAME}<$listTypeString>" + + case SparkSqlParser.TOK_MAP => + val keyType :: valueType :: Nil = node.children + val keyTypeString = nodeToTypeString(keyType) + val valueTypeString = nodeToTypeString(valueType) + s"${serdeConstants.MAP_TYPE_NAME}<$keyTypeString,$valueTypeString>" + + case SparkSqlParser.TOK_STRUCT => + val typeNode = node.children.head + require(typeNode.children.nonEmpty, "Struct must have one or more columns.") + val structColStrings = typeNode.children.map { columnNode => + val Token(colName, Nil) :: colTypeNode :: Nil = columnNode.children + cleanIdentifier(colName) + ":" + nodeToTypeString(colTypeNode) + } + s"${serdeConstants.STRUCT_TYPE_NAME}<${structColStrings.mkString(",")}>" + + case SparkSqlParser.TOK_UNIONTYPE => + val typeNode = node.children.head + val unionTypesString = typeNode.children.map(nodeToTypeString).mkString(",") + s"${serdeConstants.UNION_TYPE_NAME}<$unionTypesString>" + + case SparkSqlParser.TOK_CHAR => + val Token(size, Nil) :: Nil = node.children + s"${serdeConstants.CHAR_TYPE_NAME}($size)" + + case SparkSqlParser.TOK_VARCHAR => + val Token(size, Nil) :: Nil = node.children + s"${serdeConstants.VARCHAR_TYPE_NAME}($size)" + + case SparkSqlParser.TOK_DECIMAL => + val precisionAndScale = node.children match { + case Token(precision, Nil) :: Token(scale, Nil) :: Nil => + precision + "," + scale + case Token(precision, Nil) :: Nil => + precision + "," + HiveDecimal.USER_DEFAULT_SCALE + case Nil => + HiveDecimal.USER_DEFAULT_PRECISION + "," + HiveDecimal.USER_DEFAULT_SCALE + case _ => + noParseRule("Decimal", node) + } + s"${serdeConstants.DECIMAL_TYPE_NAME}($precisionAndScale)" + + // Simple data types. + case SparkSqlParser.TOK_BOOLEAN => serdeConstants.BOOLEAN_TYPE_NAME + case SparkSqlParser.TOK_TINYINT => serdeConstants.TINYINT_TYPE_NAME + case SparkSqlParser.TOK_SMALLINT => serdeConstants.SMALLINT_TYPE_NAME + case SparkSqlParser.TOK_INT => serdeConstants.INT_TYPE_NAME + case SparkSqlParser.TOK_BIGINT => serdeConstants.BIGINT_TYPE_NAME + case SparkSqlParser.TOK_FLOAT => serdeConstants.FLOAT_TYPE_NAME + case SparkSqlParser.TOK_DOUBLE => serdeConstants.DOUBLE_TYPE_NAME + case SparkSqlParser.TOK_STRING => serdeConstants.STRING_TYPE_NAME + case SparkSqlParser.TOK_BINARY => serdeConstants.BINARY_TYPE_NAME + case SparkSqlParser.TOK_DATE => serdeConstants.DATE_TYPE_NAME + case SparkSqlParser.TOK_TIMESTAMP => serdeConstants.TIMESTAMP_TYPE_NAME + case SparkSqlParser.TOK_INTERVAL_YEAR_MONTH => serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME + case SparkSqlParser.TOK_INTERVAL_DAY_TIME => serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME + case SparkSqlParser.TOK_DATETIME => serdeConstants.DATETIME_TYPE_NAME + case _ => null } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 400f7f3708cf4..a2d283622ca52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -21,6 +21,7 @@ import scala.util.Try import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -116,8 +117,9 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { + def ast = ParseDriver.parse(query, hiveContext.conf) def parseTree = - Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + Try(quietly(ast.treeString)).getOrElse("") test(name) { val error = intercept[AnalysisException] { @@ -139,10 +141,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd val expectedStart = line.indexOf(token) val actualStart = error.startPosition.getOrElse { - fail( - s"start not returned for error on token $token\n" + - HiveQl.dumpTree(HiveQl.getAst(query)) - ) + fail(s"start not returned for error on token $token\n${ast.treeString}") } assert(expectedStart === actualStart, s"""Incorrect start position. From fcd013cf70e7890aa25a8fe3cb6c8b36bf0e1f04 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 6 Jan 2016 11:58:33 -0800 Subject: [PATCH 1465/2089] [SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None If initial model passed to GMM is not empty it causes `net.razorvine.pickle.PickleException`. It can be fixed by converting `initialModel.weights` to `list`. Author: zero323 Closes #9986 from zero323/SPARK-12006. --- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c9e6f1dec6bf8..48daa87e82d13 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -346,7 +346,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = initialModel.weights + initialModelWeights = list(initialModel.weights) initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6ed03e35828ed..97fed7662ea90 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -475,6 +475,18 @@ def test_gmm_deterministic(self): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ From f82ebb15224ec5375f25f67d598ec3ef1cb65210 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Wed, 6 Jan 2016 12:01:05 -0800 Subject: [PATCH 1466/2089] [SPARK-12368][ML][DOC] Better doc for the binary classification evaluator' metricName For the BinaryClassificationEvaluator, the scaladoc doesn't mention that "areaUnderPR" is supported, only that the default is "areadUnderROC". Also, in the documentation, it is said that: "The default metric used to choose the best ParamMap can be overriden by the setMetric method in each of these evaluators." However, the method is called setMetricName. This PR aims to fix both issues. Author: BenFradet Closes #10328 from BenFradet/SPARK-12368. --- docs/ml-guide.md | 4 ++-- .../spark/ml/evaluation/BinaryClassificationEvaluator.scala | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 44a316a07dfef..1343753bce246 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -628,7 +628,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetricName` method in each of these evaluators. The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. @@ -951,4 +951,4 @@ model.transform(test) {% endhighlight %} - \ No newline at end of file + diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index bfb70963b151d..f71726f110e84 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -39,8 +39,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va def this() = this(Identifiable.randomUID("binEval")) /** - * param for metric name in evaluation - * Default: areaUnderROC + * param for metric name in evaluation (supports `"areaUnderROC"` (default), `"areaUnderPR"`) * @group param */ @Since("1.2.0") From 1e6648d62fb82b708ea54c51cd23bfe4f542856e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 6 Jan 2016 12:03:01 -0800 Subject: [PATCH 1467/2089] [SPARK-12617][PYSPARK] Move Py4jCallbackConnectionCleaner to Streaming Move Py4jCallbackConnectionCleaner to Streaming because the callback server starts only in StreamingContext. Author: Shixiong Zhu Closes #10621 from zsxwing/SPARK-12617-2. --- python/pyspark/context.py | 61 ---------------------------- python/pyspark/streaming/context.py | 63 +++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 61 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5e4aeac330c5a..529d16b480399 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -54,64 +54,6 @@ } -class Py4jCallbackConnectionCleaner(object): - - """ - A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617. - It will scan all callback connections every 30 seconds and close the dead connections. - """ - - def __init__(self, gateway): - self._gateway = gateway - self._stopped = False - self._timer = None - self._lock = RLock() - - def start(self): - if self._stopped: - return - - def clean_closed_connections(): - from py4j.java_gateway import quiet_close, quiet_shutdown - - callback_server = self._gateway._callback_server - with callback_server.lock: - try: - closed_connections = [] - for connection in callback_server.connections: - if not connection.isAlive(): - quiet_close(connection.input) - quiet_shutdown(connection.socket) - quiet_close(connection.socket) - closed_connections.append(connection) - - for closed_connection in closed_connections: - callback_server.connections.remove(closed_connection) - except Exception: - import traceback - traceback.print_exc() - - self._start_timer(clean_closed_connections) - - self._start_timer(clean_closed_connections) - - def _start_timer(self, f): - from threading import Timer - - with self._lock: - if not self._stopped: - self._timer = Timer(30.0, f) - self._timer.daemon = True - self._timer.start() - - def stop(self): - with self._lock: - self._stopped = True - if self._timer: - self._timer.cancel() - self._timer = None - - class SparkContext(object): """ @@ -126,7 +68,6 @@ class SparkContext(object): _active_spark_context = None _lock = RLock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH - _py4j_cleaner = None PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') @@ -303,8 +244,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - _py4j_cleaner = Py4jCallbackConnectionCleaner(SparkContext._gateway) - _py4j_cleaner.start() if instance: if (SparkContext._active_spark_context and diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 5cc4bbde39958..0f1f005ce3edf 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -19,6 +19,7 @@ import os import sys +from threading import RLock, Timer from py4j.java_gateway import java_import, JavaObject @@ -32,6 +33,63 @@ __all__ = ["StreamingContext"] +class Py4jCallbackConnectionCleaner(object): + + """ + A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617. + It will scan all callback connections every 30 seconds and close the dead connections. + """ + + def __init__(self, gateway): + self._gateway = gateway + self._stopped = False + self._timer = None + self._lock = RLock() + + def start(self): + if self._stopped: + return + + def clean_closed_connections(): + from py4j.java_gateway import quiet_close, quiet_shutdown + + callback_server = self._gateway._callback_server + if callback_server: + with callback_server.lock: + try: + closed_connections = [] + for connection in callback_server.connections: + if not connection.isAlive(): + quiet_close(connection.input) + quiet_shutdown(connection.socket) + quiet_close(connection.socket) + closed_connections.append(connection) + + for closed_connection in closed_connections: + callback_server.connections.remove(closed_connection) + except Exception: + import traceback + traceback.print_exc() + + self._start_timer(clean_closed_connections) + + self._start_timer(clean_closed_connections) + + def _start_timer(self, f): + with self._lock: + if not self._stopped: + self._timer = Timer(30.0, f) + self._timer.daemon = True + self._timer.start() + + def stop(self): + with self._lock: + self._stopped = True + if self._timer: + self._timer.cancel() + self._timer = None + + class StreamingContext(object): """ Main entry point for Spark Streaming functionality. A StreamingContext @@ -47,6 +105,9 @@ class StreamingContext(object): # Reference to a currently active StreamingContext _activeContext = None + # A cleaner to clean leak sockets of callback server every 30 seconds + _py4j_cleaner = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -95,6 +156,8 @@ def _ensure_initialized(cls): jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + _py4j_cleaner = Py4jCallbackConnectionCleaner(gw) + _py4j_cleaner.start() # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing From 19e4e9febf9bb4fd69f6d7bc13a54844e4e096f1 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Wed, 6 Jan 2016 12:48:57 -0800 Subject: [PATCH 1468/2089] [SPARK-12672][STREAMING][UI] Use the uiRoot function instead of default root path to gain the streaming batch url. Author: huangzhaowei Closes #10617 from SaintBacchus/SPARK-12672. --- .../org/apache/spark/streaming/scheduler/JobScheduler.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1ed6fb0aa9d52..2c57706636fa5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -26,7 +26,8 @@ import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} +import org.apache.spark.ui.{UIUtils => SparkUIUtils} +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -203,7 +204,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { try { val formattedTime = UIUtils.formatBatchTime( job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) - val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" + val batchUrl = s"${SparkUIUtils.uiRoot}/streaming/batch/?id=${job.time.milliseconds}" val batchLinkText = s"[output operation ${job.outputOpId}, batch time ${formattedTime}]" ssc.sc.setJobDescription( From cbaea9591f089171f3af654d1f9a52916e9f28b9 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 6 Jan 2016 13:51:50 -0800 Subject: [PATCH 1469/2089] Revert "[SPARK-12672][STREAMING][UI] Use the uiRoot function instead of default root path to gain the streaming batch url." This reverts commit 19e4e9febf9bb4fd69f6d7bc13a54844e4e096f1. Will merge #10618 instead. --- .../org/apache/spark/streaming/scheduler/JobScheduler.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 2c57706636fa5..1ed6fb0aa9d52 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -26,8 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.ui.{UIUtils => SparkUIUtils} -import org.apache.spark.util.{EventLoop, ThreadUtils} +import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} private[scheduler] sealed trait JobSchedulerEvent @@ -204,7 +203,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { try { val formattedTime = UIUtils.formatBatchTime( job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) - val batchUrl = s"${SparkUIUtils.uiRoot}/streaming/batch/?id=${job.time.milliseconds}" + val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" val batchLinkText = s"[output operation ${job.outputOpId}, batch time ${formattedTime}]" ssc.sc.setJobDescription( From 6f7ba6409a39fd2e34865e3e7a84a3dd0b00d6a4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Jan 2016 15:54:00 -0800 Subject: [PATCH 1470/2089] [SPARK-12681] [SQL] split IdentifiersParser.g into two files To avoid to have a huge Java source (over 64K loc), that can't be compiled. cc hvanhovell Author: Davies Liu Closes #10624 from davies/split_ident. --- .../sql/catalyst/parser/ExpressionParser.g | 565 ++++++++++++++++++ .../sql/catalyst/parser/IdentifiersParser.g | 515 ---------------- .../sql/catalyst/parser/SparkSqlParser.g | 2 +- 3 files changed, 566 insertions(+), 516 deletions(-) create mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g new file mode 100644 index 0000000000000..cad770122d150 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -0,0 +1,565 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. +*/ + +parser grammar ExpressionParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +// fun(par1, par2, par3) +function +@init { gParent.pushMsg("function specification", state); } +@after { gParent.popMsg(state); } + : + functionName + LPAREN + ( + (STAR) => (star=STAR) + | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? + ) + RPAREN (KW_OVER ws=window_specification)? + -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) + -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) + -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?) + ; + +functionName +@init { gParent.pushMsg("function name", state); } +@after { gParent.popMsg(state); } + : // Keyword IF is also a function name + (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) + | + (functionIdentifier) => functionIdentifier + | + {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] + ; + +castExpression +@init { gParent.pushMsg("cast expression", state); } +@after { gParent.popMsg(state); } + : + KW_CAST + LPAREN + expression + KW_AS + primitiveType + RPAREN -> ^(TOK_FUNCTION primitiveType expression) + ; + +caseExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE expression + (KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_CASE expression*) + ; + +whenExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE + ( KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) + ; + +constant +@init { gParent.pushMsg("constant", state); } +@after { gParent.popMsg(state); } + : + Number + | dateLiteral + | timestampLiteral + | intervalLiteral + | StringLiteral + | stringLiteralSequence + | BigintLiteral + | SmallintLiteral + | TinyintLiteral + | DecimalLiteral + | charSetStringLiteral + | booleanValue + ; + +stringLiteralSequence + : + StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) + ; + +charSetStringLiteral +@init { gParent.pushMsg("character string literal", state); } +@after { gParent.popMsg(state); } + : + csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) + ; + +dateLiteral + : + KW_DATE StringLiteral -> + { + // Create DateLiteral token, but with the text of the string value + // This makes the dateLiteral more consistent with the other type literals. + adaptor.create(TOK_DATELITERAL, $StringLiteral.text) + } + | + KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) + ; + +timestampLiteral + : + KW_TIMESTAMP StringLiteral -> + { + adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) + } + | + KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) + ; + +intervalLiteral + : + KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> + { + adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + } + ; + +intervalQualifiers + : + KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL + | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL + | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL + | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL + | KW_DAY -> TOK_INTERVAL_DAY_LITERAL + | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL + | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL + | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + ; + +expression +@init { gParent.pushMsg("expression specification", state); } +@after { gParent.popMsg(state); } + : + precedenceOrExpression + ; + +atomExpression + : + (KW_NULL) => KW_NULL -> TOK_NULL + | (constant) => constant + | castExpression + | caseExpression + | whenExpression + | (functionName LPAREN) => function + | tableOrColumn + | LPAREN! expression RPAREN! + ; + + +precedenceFieldExpression + : + atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* + ; + +precedenceUnaryOperator + : + PLUS | MINUS | TILDE + ; + +nullCondition + : + KW_NULL -> ^(TOK_ISNULL) + | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) + ; + +precedenceUnaryPrefixExpression + : + (precedenceUnaryOperator^)* precedenceFieldExpression + ; + +precedenceUnarySuffixExpression + : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) + -> precedenceUnaryPrefixExpression + ; + + +precedenceBitwiseXorOperator + : + BITWISEXOR + ; + +precedenceBitwiseXorExpression + : + precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* + ; + + +precedenceStarOperator + : + STAR | DIVIDE | MOD | DIV + ; + +precedenceStarExpression + : + precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* + ; + + +precedencePlusOperator + : + PLUS | MINUS + ; + +precedencePlusExpression + : + precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* + ; + + +precedenceAmpersandOperator + : + AMPERSAND + ; + +precedenceAmpersandExpression + : + precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* + ; + + +precedenceBitwiseOrOperator + : + BITWISEOR + ; + +precedenceBitwiseOrExpression + : + precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* + ; + + +// Equal operators supporting NOT prefix +precedenceEqualNegatableOperator + : + KW_LIKE | KW_RLIKE | KW_REGEXP + ; + +precedenceEqualOperator + : + precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +subQueryExpression + : + LPAREN! selectStatement[true] RPAREN! + ; + +precedenceEqualExpression + : + (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple + | + precedenceEqualExpressionSingle + ; + +precedenceEqualExpressionSingle + : + (left=precedenceBitwiseOrExpression -> $left) + ( + (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) + -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) + | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) + -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) + | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) + -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) + | (KW_NOT KW_IN expressions) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) + | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) + | (KW_IN expressions) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) + | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) + | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) + )* + | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) + ; + +expressions + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) +precedenceEqualExpressionMutiple + : + (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) + ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) + | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) + ; + +expressionsToStruct + : + LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) + ; + +precedenceNotOperator + : + KW_NOT + ; + +precedenceNotExpression + : + (precedenceNotOperator^)* precedenceEqualExpression + ; + + +precedenceAndOperator + : + KW_AND + ; + +precedenceAndExpression + : + precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* + ; + + +precedenceOrOperator + : + KW_OR + ; + +precedenceOrExpression + : + precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* + ; + + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) + ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : db=identifier DOT fn=identifier + -> Identifier[$db.text + "." + $fn.text] + | + identifier + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. +//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. +sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g index 86c6bd610f912..916eb6a7ac26b 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g @@ -182,518 +182,3 @@ sortByClause columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) ) ; - -// fun(par1, par2, par3) -function -@init { gParent.pushMsg("function specification", state); } -@after { gParent.popMsg(state); } - : - functionName - LPAREN - ( - (STAR) => (star=STAR) - | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? - ) - RPAREN (KW_OVER ws=window_specification)? - -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) - -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) - -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?) - ; - -functionName -@init { gParent.pushMsg("function name", state); } -@after { gParent.popMsg(state); } - : // Keyword IF is also a function name - (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) - | - (functionIdentifier) => functionIdentifier - | - {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] - ; - -castExpression -@init { gParent.pushMsg("cast expression", state); } -@after { gParent.popMsg(state); } - : - KW_CAST - LPAREN - expression - KW_AS - primitiveType - RPAREN -> ^(TOK_FUNCTION primitiveType expression) - ; - -caseExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE expression - (KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_CASE expression*) - ; - -whenExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE - ( KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) - ; - -constant -@init { gParent.pushMsg("constant", state); } -@after { gParent.popMsg(state); } - : - Number - | dateLiteral - | timestampLiteral - | intervalLiteral - | StringLiteral - | stringLiteralSequence - | BigintLiteral - | SmallintLiteral - | TinyintLiteral - | DecimalLiteral - | charSetStringLiteral - | booleanValue - ; - -stringLiteralSequence - : - StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) - ; - -charSetStringLiteral -@init { gParent.pushMsg("character string literal", state); } -@after { gParent.popMsg(state); } - : - csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) - ; - -dateLiteral - : - KW_DATE StringLiteral -> - { - // Create DateLiteral token, but with the text of the string value - // This makes the dateLiteral more consistent with the other type literals. - adaptor.create(TOK_DATELITERAL, $StringLiteral.text) - } - | - KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) - ; - -timestampLiteral - : - KW_TIMESTAMP StringLiteral -> - { - adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) - } - | - KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) - ; - -intervalLiteral - : - KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> - { - adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) - } - ; - -intervalQualifiers - : - KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL - | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL - | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL - | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL - | KW_DAY -> TOK_INTERVAL_DAY_LITERAL - | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL - | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL - | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL - ; - -expression -@init { gParent.pushMsg("expression specification", state); } -@after { gParent.popMsg(state); } - : - precedenceOrExpression - ; - -atomExpression - : - (KW_NULL) => KW_NULL -> TOK_NULL - | (constant) => constant - | castExpression - | caseExpression - | whenExpression - | (functionName LPAREN) => function - | tableOrColumn - | LPAREN! expression RPAREN! - ; - - -precedenceFieldExpression - : - atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* - ; - -precedenceUnaryOperator - : - PLUS | MINUS | TILDE - ; - -nullCondition - : - KW_NULL -> ^(TOK_ISNULL) - | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) - ; - -precedenceUnaryPrefixExpression - : - (precedenceUnaryOperator^)* precedenceFieldExpression - ; - -precedenceUnarySuffixExpression - : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? - -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) - -> precedenceUnaryPrefixExpression - ; - - -precedenceBitwiseXorOperator - : - BITWISEXOR - ; - -precedenceBitwiseXorExpression - : - precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* - ; - - -precedenceStarOperator - : - STAR | DIVIDE | MOD | DIV - ; - -precedenceStarExpression - : - precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* - ; - - -precedencePlusOperator - : - PLUS | MINUS - ; - -precedencePlusExpression - : - precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* - ; - - -precedenceAmpersandOperator - : - AMPERSAND - ; - -precedenceAmpersandExpression - : - precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* - ; - - -precedenceBitwiseOrOperator - : - BITWISEOR - ; - -precedenceBitwiseOrExpression - : - precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* - ; - - -// Equal operators supporting NOT prefix -precedenceEqualNegatableOperator - : - KW_LIKE | KW_RLIKE | KW_REGEXP - ; - -precedenceEqualOperator - : - precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -subQueryExpression - : - LPAREN! selectStatement[true] RPAREN! - ; - -precedenceEqualExpression - : - (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple - | - precedenceEqualExpressionSingle - ; - -precedenceEqualExpressionSingle - : - (left=precedenceBitwiseOrExpression -> $left) - ( - (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) - -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) - | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) - -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) - | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) - -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) - | (KW_NOT KW_IN expressions) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) - | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) - -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) - | (KW_IN expressions) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) - | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) - | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) - )* - | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) - ; - -expressions - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) -precedenceEqualExpressionMutiple - : - (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) - ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) - | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) - ; - -expressionsToStruct - : - LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) - ; - -precedenceNotOperator - : - KW_NOT - ; - -precedenceNotExpression - : - (precedenceNotOperator^)* precedenceEqualExpression - ; - - -precedenceAndOperator - : - KW_AND - ; - -precedenceAndExpression - : - precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* - ; - - -precedenceOrOperator - : - KW_OR - ; - -precedenceOrExpression - : - precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* - ; - - -booleanValue - : - KW_TRUE^ | KW_FALSE^ - ; - -booleanValueTok - : - KW_TRUE -> TOK_TRUE - | KW_FALSE -> TOK_FALSE - ; - -tableOrPartition - : - tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) - ; - -partitionSpec - : - KW_PARTITION - LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) - ; - -partitionVal - : - identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) - ; - -dropPartitionSpec - : - KW_PARTITION - LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) - ; - -dropPartitionVal - : - identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) - ; - -dropPartitionOperator - : - EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -sysFuncNames - : - KW_AND - | KW_OR - | KW_NOT - | KW_LIKE - | KW_IF - | KW_CASE - | KW_WHEN - | KW_TINYINT - | KW_SMALLINT - | KW_INT - | KW_BIGINT - | KW_FLOAT - | KW_DOUBLE - | KW_BOOLEAN - | KW_STRING - | KW_BINARY - | KW_ARRAY - | KW_MAP - | KW_STRUCT - | KW_UNIONTYPE - | EQUAL - | EQUAL_NS - | NOTEQUAL - | LESSTHANOREQUALTO - | LESSTHAN - | GREATERTHANOREQUALTO - | GREATERTHAN - | DIVIDE - | PLUS - | MINUS - | STAR - | MOD - | DIV - | AMPERSAND - | TILDE - | BITWISEOR - | BITWISEXOR - | KW_RLIKE - | KW_REGEXP - | KW_IN - | KW_BETWEEN - ; - -descFuncNames - : - (sysFuncNames) => sysFuncNames - | StringLiteral - | functionIdentifier - ; - -identifier - : - Identifier - | nonReserved -> Identifier[$nonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -functionIdentifier -@init { gParent.pushMsg("function identifier", state); } -@after { gParent.popMsg(state); } - : db=identifier DOT fn=identifier - -> Identifier[$db.text + "." + $fn.text] - | - identifier - ; - -principalIdentifier -@init { gParent.pushMsg("identifier for principal spec", state); } -@after { gParent.popMsg(state); } - : identifier - | QuotedIdentifier - ; - -//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved -//Non reserved keywords are basically the keywords that can be used as identifiers. -//All the KW_* are automatically not only keywords, but also reserved keywords. -//That means, they can NOT be used as identifiers. -//If you would like to use them as identifiers, put them in the nonReserved list below. -//If you are not sure, please refer to the SQL2011 column in -//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html -nonReserved - : - KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS - | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS - | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY - | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY - | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE - | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT - | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE - | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR - | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG - | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE - | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY - | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER - | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE - | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED - | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED - | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED - | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET - | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR - | KW_WORK - | KW_TRANSACTION - | KW_WRITE - | KW_ISOLATION - | KW_LEVEL - | KW_SNAPSHOT - | KW_AUTOCOMMIT - | KW_ANTI -; - -//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. -sql11ReservedKeywordsUsedAsCastFunctionName - : - KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP - ; - -//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. -//We are planning to remove the following whole list after several releases. -//Thus, please do not change the following list unless you know what to do. -sql11ReservedKeywordsUsedAsIdentifier - : - KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN - | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE - | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT - | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL - | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION - | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT - | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE - | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH -//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. - | KW_REGEXP | KW_RLIKE - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 98b46794a630c..4afce3090f739 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -26,7 +26,7 @@ ASTLabelType=CommonTree; backtrack=false; k=3; } -import SelectClauseParser, FromClauseParser, IdentifiersParser; +import SelectClauseParser, FromClauseParser, IdentifiersParser, ExpressionParser; tokens { TOK_INSERT; From 917d3fc069fb9ea1c1487119c9c12b373f4f9b77 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Jan 2016 16:58:10 -0800 Subject: [PATCH 1471/2089] [SPARK-12539][SQL] support writing bucketed table This PR adds bucket write support to Spark SQL. User can specify bucketing columns, numBuckets and sorting columns with or without partition columns. For example: ``` df.write.partitionBy("year").bucketBy(8, "country").sortBy("amount").saveAsTable("sales") ``` When bucketing is used, we will calculate bucket id for each record, and group the records by bucket id. For each group, we will create a file with bucket id in its name, and write data into it. For each bucket file, if sorting columns are specified, the data will be sorted before write. Note that there may be multiply files for one bucket, as the data is distributed. Currently we store the bucket metadata at hive metastore in a non-hive-compatible way. We use different bucketing hash function compared to hive, so we can't be compatible anyway. Limitations: * Can't write bucketed data without hive metastore. * Can't insert bucketed data into existing hive tables. Author: Wenchen Fan Closes #10498 from cloud-fan/bucket-write. --- .../apache/spark/sql/DataFrameWriter.scala | 89 ++++++- .../spark/sql/execution/SparkStrategies.scala | 7 +- .../sql/execution/datasources/DDLParser.scala | 1 + .../InsertIntoHadoopFsRelation.scala | 2 +- .../datasources/ResolvedDataSource.scala | 2 + .../datasources/WriterContainer.scala | 219 +++++++++++++----- .../sql/execution/datasources/bucket.scala | 57 +++++ .../spark/sql/execution/datasources/ddl.scala | 10 +- .../datasources/json/JSONRelation.scala | 35 ++- .../datasources/parquet/ParquetRelation.scala | 28 ++- .../sql/execution/datasources/rules.scala | 24 +- .../apache/spark/sql/sources/interfaces.scala | 34 ++- .../spark/sql/hive/HiveMetastoreCatalog.scala | 23 +- .../spark/sql/hive/HiveStrategies.scala | 7 +- .../spark/sql/hive/execution/commands.scala | 15 +- .../spark/sql/hive/orc/OrcRelation.scala | 20 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 1 + .../sql/sources/BucketedWriteSuite.scala | 169 ++++++++++++++ 18 files changed, 626 insertions(+), 117 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e2d72a549e6b0..00f9817b53976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,9 +23,9 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.sources.HadoopFsRelation @@ -128,6 +128,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Buckets the output by the given columns. If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + this.numBuckets = Option(numBuckets) + this.bucketColumnNames = Option(colName +: colNames) + this + } + + /** + * Sorts the output in each bucket by the given columns. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def sortBy(colName: String, colNames: String*): DataFrameWriter = { + this.sortColumnNames = Option(colName +: colNames) + this + } + /** * Saves the content of the [[DataFrame]] at the specified path. * @@ -144,10 +172,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { + assertNotBucketed() ResolvedDataSource( df.sqlContext, source, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df) @@ -166,6 +196,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { + assertNotBucketed() val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -188,13 +219,47 @@ final class DataFrameWriter private[sql](df: DataFrame) { ifNotExists = false)).toRdd } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => - parCols.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols => + cols.map(normalize(_, "Bucketing")) + } + + private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols => + cols.map(normalize(_, "Sorting")) + } + + private def getBucketSpec: Option[BucketSpec] = { + if (sortColumnNames.isDefined) { + require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + } + + for { + n <- numBuckets + } yield { + require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) + } + } + + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. + */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNotBucketed(): Unit = { + if (numBuckets.isDefined || sortColumnNames.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing bucketed data to this data source.") } } @@ -244,6 +309,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df.logicalPlan) @@ -372,4 +438,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var partitioningColumns: Option[Seq[String]] = None + private var bucketColumnNames: Option[Seq[String]] = None + + private var numBuckets: Option[Int] = None + + private var sortColumnNames: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6cf75bc17039c..482130a18d939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -382,13 +382,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) - if partitionsCols.nonEmpty => + case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => + case c: CreateTableUsingAsSelect if c.temporary => val cmd = CreateTempTableUsingAsSelect( - tableIdent, provider, Array.empty[String], mode, opts, query) + c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 48eff62b297f2..d8d21b06b8b35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -109,6 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan) provider, temp.isDefined, Array.empty[String], + bucketSpec = None, mode, options, queryPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 38152d0cf1a48..7a8691e7cb9c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 0ca0a38f712ce..ece9b8a9a9174 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -210,6 +210,7 @@ object ResolvedDataSource extends Logging { sqlContext: SQLContext, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { @@ -244,6 +245,7 @@ object ResolvedDataSource extends Logging { Array(outputPath.toString), Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + bucketSpec, caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9f23d531072aa..4f8524f4b967c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType, StringType} import org.apache.spark.util.SerializableConfiguration @@ -121,9 +121,9 @@ private[sql] abstract class BaseWriterContainer( } } - protected def newOutputWriter(path: String): OutputWriter = { + protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = { try { - outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -312,19 +312,23 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] - executorSideSetup(taskContext) + private val bucketSpec = relation.bucketSpec - var outputWritersCleared = false + private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) + } - // Returns the partition key given an input row - val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) + } + + private def bucketIdExpression: Option[Expression] = for { + BucketSpec(numBuckets, _, _) <- bucketSpec + } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) - // Expressions that given a partition key build a string like: col1=val/col2=val/... - val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + // Expressions that given a partition key build a string like: col1=val/col2=val/... + private def partitionStringExpression: Seq[Expression] = { + partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( PartitioningUtils.escapePathName _, @@ -335,6 +339,121 @@ private[sql] class DynamicPartitionWriterContainer( val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } + } + + private def getBucketIdFromKey(key: InternalRow): Option[Int] = { + if (bucketSpec.isDefined) { + Some(key.getInt(partitionColumns.length)) + } else { + None + } + } + + private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = { + val bucketIdIndex = partitionColumns.length + if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) { + false + } else { + var i = partitionColumns.length - 1 + while (i >= 0) { + val dt = partitionColumns(i).dataType + if (key1.get(i, dt) != key2.get(i, dt)) return false + i -= 1 + } + true + } + } + + private def sortBasedWrite( + sorter: UnsafeKVExternalSorter, + iterator: Iterator[InternalRow], + getSortingKey: UnsafeProjection, + getOutputRow: UnsafeProjection, + getPartitionString: UnsafeProjection, + outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) { + (key1, key2) => key1 != key2 + } else { + (key1, key2) => key1 == null || !sameBucket(key1, key2) + } + + val sortedIterator = sorter.sortedIterator() + var currentKey: UnsafeRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (needNewWriter(currentKey, sortedIterator.getKey)) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + /** + * Open and returns a new OutputWriter given a partition key and optional bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + */ + private def newOutputWriter( + key: InternalRow, + getPartitionString: UnsafeProjection): OutputWriter = { + val configuration = taskAttemptContext.getConfiguration + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + configuration.set("spark.sql.sources.output.path", outputPath) + getWorkPath + } + val bucketId = getBucketIdFromKey(key) + val newWriter = super.newOutputWriter(path, bucketId) + newWriter.initConverter(dataSchema) + newWriter + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + var outputWritersCleared = false + + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val getSortingKey = + UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema) + + val sortingKeySchema = if (bucketSpec.isEmpty) { + StructType.fromAttributes(partitionColumns) + } else { // If it's bucketed, we should also consider bucket id as part of the key. + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns) + StructType(fields) + } + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) // Returns the partition path given a partition key. val getPartitionString = @@ -342,22 +461,34 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // This will be filled in if we have to fall back on sorting. - var sorter: UnsafeKVExternalSorter = null + // If there is no sorting columns, we set sorter to null and try the hash-based writing first, + // and fill the sorter if there are too many writers and we need to fall back on sorting. + // If there are sorting columns, then we have to sort the data anyway, and no need to try the + // hash-based writing first. + var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { + new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + } else { + null + } while (iterator.hasNext && sorter == null) { val inputRow = iterator.next() - val currentKey = getPartitionKey(inputRow) + // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key. + val currentKey = getSortingKey(inputRow) var currentWriter = outputWriters.get(currentKey) if (currentWriter == null) { if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) + currentWriter = newOutputWriter(currentKey, getPartitionString) outputWriters.put(currentKey.copy(), currentWriter) currentWriter.writeInternal(getOutputRow(inputRow)) } else { logInfo(s"Maximum partitions reached, falling back on sorting.") sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(partitionColumns), + sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, TaskContext.get().taskMemoryManager().pageSizeBytes) @@ -369,39 +500,15 @@ private[sql] class DynamicPartitionWriterContainer( } // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort. + // using external sort, or there are sorting columns and we need to sort the whole data set. if (sorter != null) { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } + sortBasedWrite( + sorter, + iterator, + getSortingKey, + getOutputRow, + getPartitionString, + outputWriters) } commitTask() @@ -412,18 +519,6 @@ private[sql] class DynamicPartitionWriterContainer( throw new SparkException("Task failed while writing rows.", cause) } - /** Open and returns a new OutputWriter given a partition key. */ - def newOutputWriter(key: InternalRow): OutputWriter = { - val partitionPath = getPartitionString(key).getString(0) - val path = new Path(getWorkPath, partitionPath) - val configuration = taskAttemptContext.getConfiguration - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = super.newOutputWriter(path.toString) - newWriter.initConverter(dataSchema) - newWriter - } - def clearOutputWriters(): Unit = { if (!outputWritersCleared) { outputWriters.asScala.values.foreach(_.close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala new file mode 100644 index 0000000000000..82287c8967134 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} +import org.apache.spark.sql.types.StructType + +/** + * A container for bucketing information. + * Bucketing is a technology for decomposing data sets into more manageable parts, and the number + * of buckets is fixed so it does not fluctuate with data. + * + * @param numBuckets number of buckets. + * @param bucketColumnNames the names of the columns that used to generate the bucket id. + * @param sortColumnNames the names of the columns that used to sort data in each bucket. + */ +private[sql] case class BucketSpec( + numBuckets: Int, + bucketColumnNames: Seq[String], + sortColumnNames: Seq[String]) + +private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProvider { + final override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = + // TODO: throw exception here as we won't call this method during execution, after bucketed read + // support is finished. + createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters) +} + +private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { + final override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + throw new UnsupportedOperationException("use bucket version") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index aed5d0dcf2d8a..0897fcadbc011 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -76,6 +76,7 @@ case class CreateTableUsingAsSelect( provider: String, temporary: Boolean, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { @@ -109,7 +110,14 @@ case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + bucketSpec = None, + mode, + options, + df) sqlContext.catalog.registerTable( tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 8bf538178b5d9..b92edf65bfb6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -34,13 +34,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "json" @@ -49,6 +49,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { new JSONRelation( @@ -56,6 +57,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, + bucketSpec = bucketSpec, paths = paths, parameters = parameters)(sqlContext) } @@ -66,11 +68,29 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { + def this( + inputRDD: Option[RDD[String]], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + userDefinedPartitionColumns: Option[StructType], + paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String])(sqlContext: SQLContext) = { + this( + inputRDD, + maybeDataSchema, + maybePartitionSpec, + userDefinedPartitionColumns, + None, + paths, + parameters)(sqlContext) + } + val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) /** Constraints to be imposed on schema to be stored. */ @@ -158,13 +178,14 @@ private[sql] class JSONRelation( partitionColumns) } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + new BucketedOutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, dataSchema, context) + new JsonOutputWriter(path, bucketId, dataSchema, context) } } } @@ -172,6 +193,7 @@ private[sql] class JSONRelation( private[json] class JsonOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with Logging { @@ -188,7 +210,8 @@ private[json] class JsonOutputWriter( val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 45f1dff96db08..4b375de05e9e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -45,13 +45,13 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser -import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "parquet" @@ -60,13 +60,17 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) +private[sql] class ParquetOutputWriter( + path: String, + bucketId: Option[Int], + context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { @@ -86,7 +90,8 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } } } @@ -107,6 +112,7 @@ private[sql] class ParquetRelation( // This is for metastore conversion. private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -123,6 +129,7 @@ private[sql] class ParquetRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + None, parameters)(sqlContext) } @@ -216,7 +223,7 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible @@ -276,10 +283,13 @@ private[sql] class ParquetRelation( sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, bucketId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 50ecbd35760d8..d484403d1c641 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -165,22 +165,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => + case c: CreateTableUsingAsSelect => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { + if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { + EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _) => // Get all input data source relations of the query. - val srcRelations = query.collect { + val srcRelations = c.child.collect { case LogicalRelation(src: BaseRelation, _) => src } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableIdent that is also being read from.") + s"Cannot overwrite table ${c.tableIdent} that is also being read from.") } else { // OK } @@ -192,7 +192,17 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) + + for { + spec <- c.bucketSpec + sortColumnName <- spec.sortColumnNames + sortColumn <- c.child.schema.find(_.name == sortColumnName) + } { + if (!RowOrdering.isOrderable(sortColumn.dataType)) { + failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") + } + } case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f4c7f0a269323..c35f33132f602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -161,6 +161,20 @@ trait HadoopFsRelationProvider { dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation + + // TODO: expose bucket API to users. + private[sql] def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], + parameters: Map[String, String]): HadoopFsRelation = { + if (bucketSpec.isDefined) { + throw new AnalysisException("Currently we don't support bucketing for this data source.") + } + createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters) + } } /** @@ -351,7 +365,18 @@ abstract class OutputWriterFactory extends Serializable { * * @since 1.4.0 */ - def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter + def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter + + // TODO: expose bucket API to users. + private[sql] def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + newInstance(path, dataSchema, context) } /** @@ -435,6 +460,9 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ + // TODO: expose bucket API to users. + private[sql] def bucketSpec: Option[BucketSpec] = None + private class FileStatusCache { var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 1616c4595221d..43d84d507b20e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.execution.{datasources, FileRelation} -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand @@ -211,6 +211,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { @@ -240,6 +241,25 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { + val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get + + tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) + tableProperties.put("spark.sql.sources.schema.numBucketCols", + bucketColumnNames.length.toString) + bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) => + tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) + } + + if (sortColumnNames.nonEmpty) { + tableProperties.put("spark.sql.sources.schema.numSortCols", + sortColumnNames.length.toString) + sortColumnNames.zipWithIndex.foreach { case (sortCol, index) => + tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + } + } + } + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover @@ -596,6 +616,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive conf.defaultDataSourceName, temporary = false, Array.empty[String], + bucketSpec = None, mode, options = Map.empty[String, String], child diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0b4f5a0fd6ea6..3687dd6f5a7ab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -88,10 +88,9 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect( - tableIdent, provider, false, partitionCols, mode, opts, query) => - val cmd = - CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) + case c: CreateTableUsingAsSelect => + val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, + c.bucketSpec, c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 94210a5394f9b..612f01cda88ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -151,6 +151,7 @@ case class CreateMetastoreDataSource( tableIdent, userSpecifiedSchema, Array.empty[String], + bucketSpec = None, provider, optionsWithPath, isExternal) @@ -164,6 +165,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -254,8 +256,14 @@ case class CreateMetastoreDataSourceAsSelect( } // Create the relation based on the data of df. - val resolved = - ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + bucketSpec, + mode, + optionsWithPath, + df) if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of @@ -265,6 +273,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent, Some(resolved.relation.schema), partitionColumns, + bucketSpec, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 3538d642d5231..14fa152c2331d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -37,13 +37,13 @@ import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "orc" @@ -52,17 +52,19 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") - new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + new OrcRelation(paths, dataSchema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } private[orc] class OrcOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with HiveInspectors { @@ -101,7 +103,8 @@ private[orc] class OrcOutputWriter( val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId - val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), @@ -153,6 +156,7 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -169,6 +173,7 @@ private[sql] class OrcRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + None, parameters)(sqlContext) } @@ -205,7 +210,7 @@ private[sql] class OrcRelation( OrcTableScan(output, this, filters, inputPaths).execute() } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { job.getConfiguration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) @@ -216,12 +221,13 @@ private[sql] class OrcRelation( classOf[MapRedOutputFormat[_, _]]) } - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, dataSchema, context) + new OrcOutputWriter(path, bucketId, dataSchema, context) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e22dac3bc9e87..202851ae1366e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -707,6 +707,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], + bucketSpec = None, provider = "json", options = Map("path" -> "just a dummy path"), isExternal = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala new file mode 100644 index 0000000000000..579da0291f291 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest} + +class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("bucketed by non-existing column") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) + } + + test("numBuckets not greater than 0 or less than 100000") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt")) + intercept[IllegalArgumentException](df.write.bucketBy(100000, "i").saveAsTable("tt")) + } + + test("specify sorting columns without bucketing columns") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + } + + test("sorting by non-orderable column") { + val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) + } + + test("write bucketed data to unsupported data source") { + val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i") + intercept[AnalysisException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt")) + } + + test("write bucketed data to non-hive-table or existing hive table") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) + } + + private val testFileName = """.*-(\d+)$""".r + private val otherFileName = """.*-(\d+)\..*""".r + private def getBucketId(fileName: String): Int = { + fileName match { + case testFileName(bucketId) => bucketId.toInt + case otherFileName(bucketId) => bucketId.toInt + } + } + + private def testBucketing( + dataDir: File, + source: String, + bucketCols: Seq[String], + sortCols: Seq[String] = Nil): Unit = { + val allBucketFiles = dataDir.listFiles().filterNot(f => + f.getName.startsWith(".") || f.getName.startsWith("_") + ) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath) + .select((bucketCols ++ sortCols).map(col): _*) + + if (sortCols.nonEmpty) { + checkAnswer(df.sort(sortCols.map(col): _*), df.collect()) + } + + val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) + } + } + } + } + + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + + test("write bucketed data") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k")) + } + } + } + } + + test("write bucketed data with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k")) + } + } + } + } + + test("write bucketed data without partitionBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j")) + } + } + } + + test("write bucketed data without partitionBy with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j"), Seq("k")) + } + } + } +} From ac56cf605b61803c26e0004b43c703cca7e02d61 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 6 Jan 2016 17:17:32 -0800 Subject: [PATCH 1472/2089] [SPARK-12604][CORE] Java count(AprroxDistinct)ByKey methods return Scala Long not Java Change Java countByKey, countApproxDistinctByKey return types to use Java Long, not Scala; update similar methods for consistency on java.long.Long.valueOf with no API change Author: Sean Owen Closes #10554 from srowen/SPARK-12604. --- .../apache/spark/api/java/JavaPairRDD.scala | 32 +++++++++++-------- .../apache/spark/api/java/JavaRDDLike.scala | 16 +++++----- .../java/org/apache/spark/JavaAPISuite.java | 18 +++++------ .../streaming/api/java/JavaDStreamLike.scala | 18 +++++------ .../streaming/api/java/JavaPairDStream.scala | 7 ++-- 5 files changed, 49 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 91dc18697c352..76752e1fde663 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.java +import java.{lang => jl} import java.lang.{Iterable => JIterable} -import java.util.{Comparator, List => JList, Map => JMap} +import java.util.{Comparator, List => JList} import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -139,7 +140,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * math.ceil(numItems * samplingRate) over all key values. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double], + fractions: java.util.Map[K, Double], seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions.asScala, seed)) @@ -154,7 +155,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Use Utils.random.nextLong as the default seed for the random number generator. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double]): JavaPairRDD[K, V] = + fractions: java.util.Map[K, Double]): JavaPairRDD[K, V] = sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** @@ -168,7 +169,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * two additional passes. */ def sampleByKeyExact(withReplacement: Boolean, - fractions: JMap[K, Double], + fractions: java.util.Map[K, Double], seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) @@ -184,7 +185,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * Use Utils.random.nextLong as the default seed for the random number generator. */ - def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + def sampleByKeyExact( + withReplacement: Boolean, + fractions: java.util.Map[K, Double]): JavaPairRDD[K, V] = sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) /** @@ -292,7 +295,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, jl.Long] = + mapAsSerializableJavaMap(rdd.countByKey().mapValues(jl.Long.valueOf)) /** * Approximate version of countByKey that can return a partial result if it does @@ -934,9 +938,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * It must be greater than 0.000017. * @param partitioner partitioner of the resulting RDD. */ - def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): JavaPairRDD[K, Long] = - { - fromRDD(rdd.countApproxDistinctByKey(relativeSD, partitioner)) + def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner) + : JavaPairRDD[K, jl.Long] = { + fromRDD(rdd.countApproxDistinctByKey(relativeSD, partitioner)). + asInstanceOf[JavaPairRDD[K, jl.Long]] } /** @@ -950,8 +955,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * It must be greater than 0.000017. * @param numPartitions number of partitions of the resulting RDD. */ - def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaPairRDD[K, Long] = { - fromRDD(rdd.countApproxDistinctByKey(relativeSD, numPartitions)) + def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaPairRDD[K, jl.Long] = { + fromRDD(rdd.countApproxDistinctByKey(relativeSD, numPartitions)). + asInstanceOf[JavaPairRDD[K, jl.Long]] } /** @@ -964,8 +970,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. */ - def countApproxDistinctByKey(relativeSD: Double): JavaPairRDD[K, Long] = { - fromRDD(rdd.countApproxDistinctByKey(relativeSD)) + def countApproxDistinctByKey(relativeSD: Double): JavaPairRDD[K, jl.Long] = { + fromRDD(rdd.countApproxDistinctByKey(relativeSD)).asInstanceOf[JavaPairRDD[K, jl.Long]] } /** Assign a name to this RDD */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 6d3485d88a163..1b1a9dce397fd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.java import java.{lang => jl} -import java.lang.{Iterable => JIterable, Long => JLong} +import java.lang.{Iterable => JIterable} import java.util.{Comparator, Iterator => JIterator, List => JList} import scala.collection.JavaConverters._ @@ -305,8 +305,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. */ - def zipWithUniqueId(): JavaPairRDD[T, JLong] = { - JavaPairRDD.fromRDD(rdd.zipWithUniqueId()).asInstanceOf[JavaPairRDD[T, JLong]] + def zipWithUniqueId(): JavaPairRDD[T, jl.Long] = { + JavaPairRDD.fromRDD(rdd.zipWithUniqueId()).asInstanceOf[JavaPairRDD[T, jl.Long]] } /** @@ -316,8 +316,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. */ - def zipWithIndex(): JavaPairRDD[T, JLong] = { - JavaPairRDD.fromRDD(rdd.zipWithIndex()).asInstanceOf[JavaPairRDD[T, JLong]] + def zipWithIndex(): JavaPairRDD[T, jl.Long] = { + JavaPairRDD.fromRDD(rdd.zipWithIndex()).asInstanceOf[JavaPairRDD[T, jl.Long]] } // Actions (launch a job to return a value to the user program) @@ -448,7 +448,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, jl.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue().mapValues(jl.Long.valueOf)) /** * (Experimental) Approximate version of countByValue(). @@ -631,8 +631,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * The asynchronous version of `count`, which returns a * future for counting the number of elements in this RDD. */ - def countAsync(): JavaFutureAction[JLong] = { - new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf) + def countAsync(): JavaFutureAction[jl.Long] = { + new JavaFutureActionWrapper[Long, jl.Long](rdd.countAsync(), jl.Long.valueOf) } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 502f86f178fd2..47382e4231563 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1580,11 +1580,11 @@ public void countApproxDistinctByKey() { } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); - for (Tuple2 resItem : res) { - double count = (double)resItem._1(); - Long resCount = (Long)resItem._2(); - Double error = Math.abs((resCount - count) / count); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); + for (Tuple2 resItem : res) { + double count = resItem._1(); + long resCount = resItem._2(); + double error = Math.abs((resCount - count) / count); Assert.assertTrue(error < 0.1); } @@ -1633,12 +1633,12 @@ public Tuple2 call(Integer i) { fractions.put(0, 0.5); fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); - Map wrCounts = (Map) (Object) wr.countByKey(); + Map wrCounts = wr.countByKey(); Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); - Map worCounts = (Map) (Object) wor.countByKey(); + Map worCounts = wor.countByKey(); Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); @@ -1659,12 +1659,12 @@ public Tuple2 call(Integer i) { fractions.put(0, 0.5); fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); - Map wrExactCounts = (Map) (Object) wrExact.countByKey(); + Map wrExactCounts = wrExact.countByKey(); Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); - Map worExactCounts = (Map) (Object) worExact.countByKey(); + Map worExactCounts = worExact.countByKey(); Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 84acec7d8e330..733147f63ea2e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong} +import java.{lang => jl} import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -50,8 +50,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def wrapRDD(in: RDD[T]): R - implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { - in.map(new JLong(_)) + implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[jl.Long] = { + in.map(jl.Long.valueOf) } /** @@ -74,14 +74,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): JavaDStream[JLong] = dstream.count() + def count(): JavaDStream[jl.Long] = dstream.count() /** * Return a new DStream in which each RDD contains the counts of each distinct value in * each RDD of this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. */ - def countByValue(): JavaPairDStream[T, JLong] = { + def countByValue(): JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong(dstream.countByValue()) } @@ -91,7 +91,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * partitions. * @param numPartitions number of partitions of each RDD in the new DStream. */ - def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { + def countByValue(numPartitions: Int): JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) } @@ -101,7 +101,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of elements in a window over this DStream. windowDuration and slideDuration are as defined in * the window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = { + def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[jl.Long] = { dstream.countByWindow(windowDuration, slideDuration) } @@ -116,7 +116,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * DStream's batching interval */ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[T, JLong] = { + : JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong( dstream.countByValueAndWindow(windowDuration, slideDuration)) } @@ -133,7 +133,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param numPartitions number of partitions of each RDD in the new DStream. */ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[T, JLong] = { + : JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong( dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 2bf3ccec6bc55..af0d84b33224f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.api.java -import java.lang.{Iterable => JIterable, Long => JLong} +import java.{lang => jl} +import java.lang.{Iterable => JIterable} import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -847,7 +848,7 @@ object JavaPairDStream { } def scalaToJavaLong[K: ClassTag](dstream: JavaPairDStream[K, Long]) - : JavaPairDStream[K, JLong] = { - DStream.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) + : JavaPairDStream[K, jl.Long] = { + DStream.toPairDStreamFunctions(dstream.dstream).mapValues(jl.Long.valueOf) } } From a74d743cc7c52a78fa023fdd0d06847b7d48bf78 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 6 Jan 2016 19:20:43 -0800 Subject: [PATCH 1473/2089] [SPARK-12640][SQL] Add simple benchmarking utility class and add Parquet scan benchmarks. [SPARK-12640][SQL] Add simple benchmarking utility class and add Parquet scan benchmarks. We've run benchmarks ad hoc to measure the scanner performance. We will continue to invest in this and it makes sense to get these benchmarks into code. This adds a simple benchmarking utility to do this. Author: Nong Li Author: Nong Closes #10589 from nongli/spark-12640. --- .../org/apache/spark/util/Benchmark.scala | 120 +++++++++++++ .../parquet/ParquetReadBenchmark.scala | 158 ++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/util/Benchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala new file mode 100644 index 0000000000000..457a1a05a1bf5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable + +import org.apache.commons.lang3.SystemUtils + +/** + * Utility class to benchmark components. An example of how to use this is: + * val benchmark = new Benchmark("My Benchmark", valuesPerIteration) + * benchmark.addCase("V1")() + * benchmark.addCase("V2")() + * benchmark.run + * This will output the average time to run each function and the rate of each function. + * + * The benchmark function takes one argument that is the iteration that's being run. + * + * If outputPerIteration is true, the timing for each run will be printed to stdout. + */ +private[spark] class Benchmark( + name: String, valuesPerIteration: Long, + iters: Int = 5, + outputPerIteration: Boolean = false) { + val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case] + + def addCase(name: String)(f: Int => Unit): Unit = { + benchmarks += Benchmark.Case(name, f) + } + + /** + * Runs the benchmark and outputs the results to stdout. This should be copied and added as + * a comment with the benchmark. Although the results vary from machine to machine, it should + * provide some baseline. + */ + def run(): Unit = { + require(benchmarks.nonEmpty) + // scalastyle:off + println("Running benchmark: " + name) + + val results = benchmarks.map { c => + println(" Running case: " + c.name) + Benchmark.measure(valuesPerIteration, iters, outputPerIteration)(c.fn) + } + println + + val firstRate = results.head.avgRate + // The results are going to be processor specific so it is useful to include that. + println(Benchmark.getProcessorName()) + printf("%-24s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") + println("-------------------------------------------------------------------------") + results.zip(benchmarks).foreach { r => + printf("%-24s %16s %16s %14s\n", + r._2.name, + "%10.2f" format r._1.avgMs, + "%10.2f" format r._1.avgRate, + "%6.2f X" format (r._1.avgRate / firstRate)) + } + println + // scalastyle:on + } +} + +private[spark] object Benchmark { + case class Case(name: String, fn: Int => Unit) + case class Result(avgMs: Double, avgRate: Double) + + /** + * This should return a user helpful processor information. Getting at this depends on the OS. + * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz" + */ + def getProcessorName(): String = { + if (SystemUtils.IS_OS_MAC_OSX) { + Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string")) + } else if (SystemUtils.IS_OS_LINUX) { + Utils.executeAndGetOutput(Seq("/usr/bin/grep", "-m", "1", "\"model name\"", "/proc/cpuinfo")) + } else { + System.getenv("PROCESSOR_IDENTIFIER") + } + } + + /** + * Runs a single function `f` for iters, returning the average time the function took and + * the rate of the function. + */ + def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { + var totalTime = 0L + for (i <- 0 until iters + 1) { + val start = System.nanoTime() + + f(i) + + val end = System.nanoTime() + if (i != 0) totalTime += end - start + + if (outputPerIteration) { + // scalastyle:off + println(s"Iteration $i took ${(end - start) / 1000} microseconds") + // scalastyle:on + } + } + Result(totalTime.toDouble / 1000000 / iters, num * iters / (totalTime.toDouble / 1000)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala new file mode 100644 index 0000000000000..cab6abde6da23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Benchmark to measure parquet read performance. + * To run this: + * spark-submit --class --jars + */ +object ParquetReadBenchmark { + val conf = new SparkConf() + conf.set("spark.sql.parquet.compression.codec", "snappy") + val sc = new SparkContext("local[1]", "test-sql-context", conf) + val sqlContext = new SQLContext(sc) + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) + } + } + } + + def intScanBenchmark(values: Int): Unit = { + withTempPath { dir => + sqlContext.range(values).write.parquet(dir.getCanonicalPath) + withTempTable("tempTable") { + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + val benchmark = new Benchmark("Single Int Column Scan", values) + + benchmark.addCase("SQL Parquet Reader") { iter => + sqlContext.sql("select sum(id) from tempTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(id) from tempTable").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + benchmark.addCase("ParquetReader") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new UnsafeRowParquetRecordReader + reader.initialize(p, ("id" :: Nil).asJava) + + while (reader.nextKeyValue()) { + val record = reader.getCurrentValue + if (!record.isNullAt(0)) sum += record.getInt(0) + } + reader.close() + }} + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + SQL Parquet Reader 1910.0 13.72 1.00 X + SQL Parquet MR 2330.0 11.25 0.82 X + ParquetReader 1252.6 20.93 1.52 X + */ + benchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select id as c1, cast(id as STRING) as c2 from t1") + .write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("Int and String Scan", values) + + benchmark.addCase("SQL Parquet Reader") { iter => + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + + benchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + benchmark.addCase("ParquetReader") { num => + var sum1 = 0L + var sum2 = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new UnsafeRowParquetRecordReader + reader.initialize(p, null) + while (reader.nextKeyValue()) { + val record = reader.getCurrentValue + if (!record.isNullAt(0)) sum1 += record.getInt(0) + if (!record.isNullAt(1)) sum2 += record.getUTF8String(1).numBytes() + } + reader.close() + } + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + SQL Parquet Reader 2245.6 7.00 1.00 X + SQL Parquet MR 2914.2 5.40 0.77 X + ParquetReader 1544.6 10.18 1.45 X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + intScanBenchmark(1024 * 1024 * 15) + intStringScanBenchmark(1024 * 1024 * 10) + } +} From 6b6d02be0d4e2ce562dddfb391b3302f79de8276 Mon Sep 17 00:00:00 2001 From: Robert Dodier Date: Wed, 6 Jan 2016 19:49:10 -0800 Subject: [PATCH 1474/2089] [SPARK-12663][MLLIB] More informative error message in MLUtils.loadLibSVMFile This PR contains 1 commit which resolves [SPARK-12663](https://issues.apache.org/jira/browse/SPARK-12663). For the record, I got a positive response from 2 people when I floated this idea on devspark.apache.org on 2015-10-23. [Link to archived discussion.](http://apache-spark-developers-list.1001551.n3.nabble.com/slightly-more-informative-error-message-in-MLUtils-loadLibSVMFile-td14764.html) Author: Robert Dodier Closes #10611 from robert-dodier/loadlibsvmfile-error-msg-branch. --- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 89186de96988f..74e9271e40329 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -86,7 +86,8 @@ object MLUtils { val indicesLength = indices.length while (i < indicesLength) { val current = indices(i) - require(current > previous, "indices should be one-based and in ascending order" ) + require(current > previous, s"indices should be one-based and in ascending order;" + + " found current=$current, previous=$previous; line=\"$line\"") previous = current i += 1 } From 8e19c7663a067d55b32af68d62da42c7cd5d6009 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 Jan 2016 20:50:31 -0800 Subject: [PATCH 1475/2089] [SPARK-7689] Remove TTL-based metadata cleaning in Spark 2.0 This PR removes `spark.cleaner.ttl` and the associated TTL-based metadata cleaning code. Now that we have the `ContextCleaner` and a timer to trigger periodic GCs, I don't think that `spark.cleaner.ttl` is necessary anymore. The TTL-based cleaning isn't enabled by default, isn't included in our end-to-end tests, and has been a source of user confusion when it is misconfigured. If the TTL is set too low, data which is still being used may be evicted / deleted, leading to hard to diagnose bugs. For all of these reasons, I think that we should remove this functionality in Spark 2.0. Additional benefits of doing this include marginally reduced memory usage, since we no longer need to store timetsamps in hashmaps, and a handful fewer threads. Author: Josh Rosen Closes #10534 from JoshRosen/remove-ttl-based-cleaning. --- .../org/apache/spark/MapOutputTracker.scala | 25 +-- .../scala/org/apache/spark/SparkContext.scala | 21 +-- .../shuffle/FileShuffleBlockResolver.scala | 28 ++- .../apache/spark/storage/BlockManager.scala | 63 ++----- .../apache/spark/util/MetadataCleaner.scala | 110 ----------- .../spark/util/TimeStampedHashSet.scala | 86 --------- .../util/TimeStampedWeakValueHashMap.scala | 171 ------------------ .../spark/util/TimeStampedHashMapSuite.scala | 86 --------- docs/configuration.md | 11 -- .../apache/spark/streaming/Checkpoint.scala | 3 +- .../spark/streaming/dstream/DStream.scala | 14 +- .../streaming/StreamingContextSuite.scala | 19 +- 12 files changed, 48 insertions(+), 589 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 8670f705cdb7e..1b59beb8d6efd 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,7 +18,6 @@ package org.apache.spark import java.io._ -import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} @@ -267,8 +266,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map - * output information, which allows old output information based on a TTL. + * MapOutputTracker for the driver. */ private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { @@ -291,17 +289,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - /** - * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver, - * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). - * Other than these two scenarios, nothing should be dropped from this HashMap. - */ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() - - // For cleaning up TimeStampedHashMaps - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) + // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // Statuses are dropped only by explicit de-registering. + protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala def registerShuffle(shuffleId: Int, numMaps: Int) { if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { @@ -462,14 +453,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) sendTracker(StopMapOutputTracker) mapStatuses.clear() trackerEndpoint = null - metadataCleaner.cancel() cachedSerializedStatuses.clear() } - - private def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) - } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4a99c0b081d6a..98075cef112db 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} +import java.util.concurrent.ConcurrentMap import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} import java.util.UUID.randomUUID @@ -32,6 +33,7 @@ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal +import com.google.common.collect.MapMaker import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -199,7 +201,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private var _eventLogDir: Option[URI] = None private var _eventLogCodec: Option[String] = None private var _env: SparkEnv = _ - private var _metadataCleaner: MetadataCleaner = _ private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ private var _progressBar: Option[ConsoleProgressBar] = None @@ -271,8 +272,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] val persistentRdds = { + val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]() + map.asScala + } private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener def statusTracker: SparkStatusTracker = _statusTracker @@ -439,8 +442,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.repl.class.uri", replUri) } - _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - _statusTracker = new SparkStatusTracker(this) _progressBar = @@ -1674,11 +1675,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli env.metricsSystem.report() } } - if (metadataCleaner != null) { - Utils.tryLogNonFatalError { - metadataCleaner.cancel() - } - } Utils.tryLogNonFatalError { _cleaner.foreach(_.stop()) } @@ -2085,11 +2081,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ - private[spark] def cleanup(cleanupTime: Long) { - persistentRdds.clearOldValues(cleanupTime) - } - // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having finished construction. // NOTE: this must be placed at the end of the SparkContext constructor. diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 7abcb29672cf5..294e16cde1931 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import scala.collection.JavaConverters._ @@ -27,7 +27,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util.Utils /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -63,10 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val completedMapTasks = new ConcurrentLinkedQueue[Int]() } - private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] - - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState] /** * Get a ShuffleWriterGroup for the given map task, which will register it as complete @@ -75,9 +72,12 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) - private val shuffleState = shuffleStates(shuffleId) - + private val shuffleState: ShuffleState = { + // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use + // .getOrElseUpdate() because that's actually NOT atomic. + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) + shuffleStates.get(shuffleId) + } val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() val writers: Array[DiskBlockObjectWriter] = { @@ -114,7 +114,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) /** Remove all the blocks / files related to a particular shuffle. */ private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { - shuffleStates.get(shuffleId) match { + Option(shuffleStates.get(shuffleId)) match { case Some(state) => for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) @@ -131,11 +131,5 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) - } - - override def stop() { - metadataCleaner.cancel() - } + override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8caf9e55359e0..5c80ac17b8d90 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,7 +19,9 @@ package org.apache.spark.storage import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ @@ -75,7 +77,7 @@ private[spark] class BlockManager( val diskBlockManager = new DiskBlockManager(this, conf) - private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val blockInfo = new ConcurrentHashMap[BlockId, BlockInfo] private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) @@ -147,11 +149,6 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) - private val broadcastCleaner = new MetadataCleaner( - MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) - // Field related to peer block managers that are necessary for block replication @volatile private var cachedPeers: Seq[BlockManagerId] = _ private val peerFetchLock = new Object @@ -232,7 +229,7 @@ private[spark] class BlockManager( */ private def reportAllBlocks(): Unit = { logInfo(s"Reporting ${blockInfo.size} blocks to the master.") - for ((blockId, info) <- blockInfo) { + for ((blockId, info) <- blockInfo.asScala) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { logError(s"Failed to report $blockId to master; giving up.") @@ -313,7 +310,7 @@ private[spark] class BlockManager( * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { - blockInfo.get(blockId).map { info => + blockInfo.asScala.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L // Assume that block is not in external block store @@ -327,7 +324,7 @@ private[spark] class BlockManager( * may not know of). */ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { - (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + (blockInfo.asScala.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq } /** @@ -439,7 +436,7 @@ private[spark] class BlockManager( } private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { - val info = blockInfo.get(blockId).orNull + val info = blockInfo.get(blockId) if (info != null) { info.synchronized { // Double check to make sure the block is still there. There is a small chance that the @@ -447,7 +444,7 @@ private[spark] class BlockManager( // Note that this only checks metadata tracking. If user intentionally deleted the block // on disk or from off heap storage without using removeBlock, this conditional check will // still pass but eventually we will get an exception because we can't find the block. - if (blockInfo.get(blockId).isEmpty) { + if (blockInfo.asScala.get(blockId).isEmpty) { logWarning(s"Block $blockId had been removed") return None } @@ -731,7 +728,7 @@ private[spark] class BlockManager( val putBlockInfo = { val tinfo = new BlockInfo(level, tellMaster) // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + val oldBlockOpt = Option(blockInfo.putIfAbsent(blockId, tinfo)) if (oldBlockOpt.isDefined) { if (oldBlockOpt.get.waitForReady()) { logWarning(s"Block $blockId already exists on this machine; not re-adding it") @@ -1032,7 +1029,7 @@ private[spark] class BlockManager( data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") - val info = blockInfo.get(blockId).orNull + val info = blockInfo.get(blockId) // If the block has not already been dropped if (info != null) { @@ -1043,7 +1040,7 @@ private[spark] class BlockManager( // If we get here, the block write failed. logWarning(s"Block $blockId was marked as failure. Nothing to drop") return None - } else if (blockInfo.get(blockId).isEmpty) { + } else if (blockInfo.asScala.get(blockId).isEmpty) { logWarning(s"Block $blockId was already dropped.") return None } @@ -1095,7 +1092,7 @@ private[spark] class BlockManager( def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo(s"Removing RDD $rddId") - val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocksToRemove = blockInfo.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } @@ -1105,7 +1102,7 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { logDebug(s"Removing broadcast $broadcastId") - val blocksToRemove = blockInfo.keys.collect { + val blocksToRemove = blockInfo.asScala.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } @@ -1117,7 +1114,7 @@ private[spark] class BlockManager( */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { logDebug(s"Removing block $blockId") - val info = blockInfo.get(blockId).orNull + val info = blockInfo.get(blockId) if (info != null) { info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. @@ -1141,36 +1138,6 @@ private[spark] class BlockManager( } } - private def dropOldNonBroadcastBlocks(cleanupTime: Long): Unit = { - logInfo(s"Dropping non broadcast blocks older than $cleanupTime") - dropOldBlocks(cleanupTime, !_.isBroadcast) - } - - private def dropOldBroadcastBlocks(cleanupTime: Long): Unit = { - logInfo(s"Dropping broadcast blocks older than $cleanupTime") - dropOldBlocks(cleanupTime, _.isBroadcast) - } - - private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)): Unit = { - val iterator = blockInfo.getEntrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) - if (time < cleanupTime && shouldDrop(id)) { - info.synchronized { - val level = info.level - if (level.useMemory) { memoryStore.remove(id) } - if (level.useDisk) { diskStore.remove(id) } - if (level.useOffHeap) { externalBlockStore.remove(id) } - iterator.remove() - logInfo(s"Dropped block $id") - } - val status = getCurrentBlockStatus(id, info) - reportBlockStatus(id, info, status) - } - } - } - private def shouldCompress(blockId: BlockId): Boolean = { blockId match { case _: ShuffleBlockId => compressShuffle @@ -1248,8 +1215,6 @@ private[spark] class BlockManager( if (externalBlockStoreInitialized) { externalBlockStore.clear() } - metadataCleaner.cancel() - broadcastCleaner.cancel() futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala deleted file mode 100644 index a8bbad086849e..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.{Timer, TimerTask} - -import org.apache.spark.{Logging, SparkConf} - -/** - * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) - */ -private[spark] class MetadataCleaner( - cleanerType: MetadataCleanerType.MetadataCleanerType, - cleanupFunc: (Long) => Unit, - conf: SparkConf) - extends Logging -{ - val name = cleanerType.toString - - private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType) - private val periodSeconds = math.max(10, delaySeconds / 10) - private val timer = new Timer(name + " cleanup timer", true) - - - private val task = new TimerTask { - override def run() { - try { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - - if (delaySeconds > 0) { - logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + - "and period of " + periodSeconds + " secs") - timer.schedule(task, delaySeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} - -private[spark] object MetadataCleanerType extends Enumeration { - - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER, - SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value - - type MetadataCleanerType = Value - - def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { - "spark.cleaner.ttl." + which.toString - } -} - -// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the -// initialization of StreamingContext. It's okay for users trying to configure stuff themselves. -private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf): Int = { - conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt - } - - def getDelaySeconds( - conf: SparkConf, - cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt - } - - def setDelaySeconds( - conf: SparkConf, - cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) { - conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) - } - - /** - * Set the default delay time (in seconds). - * @param conf SparkConf instance - * @param delay default delay time to set - * @param resetAll whether to reset all to default - */ - def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { - conf.set("spark.cleaner.ttl", delay.toString) - if (resetAll) { - for (cleanerType <- MetadataCleanerType.values) { - System.clearProperty(MetadataCleanerType.systemProperty(cleanerType)) - } - } - } -} - diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala deleted file mode 100644 index 65efeb1f4c19c..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ -import scala.collection.mutable.Set - -private[spark] class TimeStampedHashSet[A] extends Set[A] { - val internalMap = new ConcurrentHashMap[A, Long]() - - def contains(key: A): Boolean = { - internalMap.contains(key) - } - - def iterator: Iterator[A] = { - val jIterator = internalMap.entrySet().iterator() - jIterator.asScala.map(_.getKey) - } - - override def + (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet += elem - newSet - } - - override def - (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet -= elem - newSet - } - - override def += (key: A): this.type = { - internalMap.put(key, currentTime) - this - } - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def empty: Set[A] = new TimeStampedHashSet[A]() - - override def size(): Int = internalMap.size() - - override def foreach[U](f: (A) => U): Unit = { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - f(iterator.next.getKey) - } - } - - /** - * Removes old values that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue < threshTime) { - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala deleted file mode 100644 index 310c0c109416c..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.lang.ref.WeakReference -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable -import scala.language.implicitConversions - -import org.apache.spark.Logging - -/** - * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. - * - * If the value is garbage collected and the weak reference is null, get() will return a - * non-existent value. These entries are removed from the map periodically (every N inserts), as - * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are - * older than a particular threshold can be removed using the clearOldValues method. - * - * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it - * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, - * so all operations on this HashMap are thread-safe. - * - * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. - */ -private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() with Logging { - - import TimeStampedWeakValueHashMap._ - - private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) - private val insertCount = new AtomicInteger(0) - - /** Return a map consisting only of entries whose values are still strongly reachable. */ - private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } - - def get(key: A): Option[B] = internalMap.get(key) - - def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator - - override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { - val newMap = new TimeStampedWeakValueHashMap[A, B1] - val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] - newMap.internalMap.putAll(oldMap.toMap) - newMap.internalMap += kv - newMap - } - - override def - (key: A): mutable.Map[A, B] = { - val newMap = new TimeStampedWeakValueHashMap[A, B] - newMap.internalMap.putAll(nonNullReferenceMap.toMap) - newMap.internalMap -= key - newMap - } - - override def += (kv: (A, B)): this.type = { - internalMap += kv - if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { - clearNullValues() - } - this - } - - override def -= (key: A): this.type = { - internalMap -= key - this - } - - override def update(key: A, value: B): Unit = this += ((key, value)) - - override def apply(key: A): B = internalMap.apply(key) - - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) - - override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) - - def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) - - def toMap: Map[A, B] = iterator.toMap - - /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) - - /** Remove entries with values that are no longer strongly reachable. */ - def clearNullValues() { - val it = internalMap.getEntrySet.iterator - while (it.hasNext) { - val entry = it.next() - if (entry.getValue.value.get == null) { - logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") - it.remove() - } - } - } - - // For testing - - def getTimestamp(key: A): Option[Long] = { - internalMap.getTimeStampedValue(key).map(_.timestamp) - } - - def getReference(key: A): Option[WeakReference[B]] = { - internalMap.getTimeStampedValue(key).map(_.value) - } -} - -/** - * Helper methods for converting to and from WeakReferences. - */ -private object TimeStampedWeakValueHashMap { - - // Number of inserts after which entries with null references are removed - val CLEAR_NULL_VALUES_INTERVAL = 100 - - /* Implicit conversion methods to WeakReferences. */ - - implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) - - implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { - kv match { case (k, v) => (k, toWeakReference(v)) } - } - - implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { - (kv: (K, WeakReference[V])) => p(kv) - } - - /* Implicit conversion methods from WeakReferences. */ - - implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get - - implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { - v match { - case Some(ref) => Option(fromWeakReference(ref)) - case None => None - } - } - - implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { - kv match { case (k, v) => (k, fromWeakReference(v)) } - } - - implicit def fromWeakReferenceIterator[K, V]( - it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { - it.map(fromWeakReferenceTuple) - } - - implicit def fromWeakReferenceMap[K, V]( - map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { - mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 9b3169026cda3..25fc15dd54d04 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.util -import java.lang.ref.WeakReference - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -34,10 +32,6 @@ class TimeStampedHashMapSuite extends SparkFunSuite { testMap(new TimeStampedHashMap[String, String]()) testMapThreadSafety(new TimeStampedHashMap[String, String]()) - // Test TimeStampedWeakValueHashMap basic functionality - testMap(new TimeStampedWeakValueHashMap[String, String]()) - testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) - test("TimeStampedHashMap - clearing by timestamp") { // clearing by insertion time val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) @@ -68,86 +62,6 @@ class TimeStampedHashMapSuite extends SparkFunSuite { assert(map1.get("k2").isDefined) } - test("TimeStampedWeakValueHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis - assert(map.getTimestamp("k1").isDefined) - assert(map.getTimestamp("k1").get < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.getTimestamp("k1").isDefined) - assert(map1.getTimestamp("k1").get < threshTime1) - assert(map1.getTimestamp("k2").isDefined) - assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) // should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - test("TimeStampedWeakValueHashMap - clearing weak references") { - var strongRef = new Object - val weakRef = new WeakReference(strongRef) - val map = new TimeStampedWeakValueHashMap[String, Object] - map("k1") = strongRef - map("k2") = "v2" - map("k3") = "v3" - val isEquals = map("k1") == strongRef - assert(isEquals) - - // clear strong reference to "k1" - strongRef = null - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - System.runFinalization() - Thread.sleep(100) - } - assert(map.getReference("k1").isDefined) - val ref = map.getReference("k1").get - assert(ref.get === null) - assert(map.get("k1") === None) - - // operations should only display non-null entries - assert(map.iterator.forall { case (k, v) => k != "k1" }) - assert(map.filter { case (k, v) => k != "k2" }.size === 1) - assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") - assert(map.toMap.size === 2) - assert(map.toMap.forall { case (k, v) => k != "k1" }) - val buffer = new ArrayBuffer[String] - map.foreach { case (k, v) => buffer += v.toString } - assert(buffer.size === 2) - assert(buffer.forall(_ != "k1")) - val plusMap = map + (("k4", "v4")) - assert(plusMap.size === 3) - assert(plusMap.forall { case (k, v) => k != "k1" }) - val minusMap = map - "k2" - assert(minusMap.size === 1) - assert(minusMap.head._1 == "k3") - - // clear null values - should only clear k1 - map.clearNullValues() - assert(map.getReference("k1") === None) - assert(map.get("k1") === None) - assert(map.get("k2").isDefined) - assert(map.get("k2").get === "v2") - } - /** Test basic operations of a Scala mutable Map. */ def testMap(hashMapConstructor: => mutable.Map[String, String]) { def newMap() = hashMapConstructor diff --git a/docs/configuration.md b/docs/configuration.md index 7d743d572b582..3ffc77dcc62e0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -823,17 +823,6 @@ Apart from these, the following properties are also available, and may be useful too small, BlockManager might take a performance hit. - - spark.cleaner.ttl - (infinite) - - Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks - generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be - forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in - case of Spark Streaming applications). Note that any RDD that persists in memory for more than - this duration will be cleared as well. - - spark.executor.cores 1 in YARN mode, all the available cores on the worker in standalone mode. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 61b230ab6f98a..b186d297610e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec +import org.apache.spark.util.Utils import org.apache.spark.streaming.scheduler.JobGenerator -import org.apache.spark.util.{MetadataCleaner, Utils} private[streaming] class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) @@ -40,7 +40,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration val pendingTimes = ssc.scheduler.getPendingTimes().toArray - val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 91a43e14a8b1b..c59348a89d34f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} +import org.apache.spark.util.{CallSite, Utils} /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -271,18 +271,6 @@ abstract class DStream[T: ClassTag] ( checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." ) - val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) - logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - require( - metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, - "It seems you are doing some DStream window operation or setting a checkpoint interval " + - "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + - "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" + - "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " + - "set the Java cleaner delay to more than " + - math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." - ) - dependencies.foreach(_.validateAtStart()) logInfo("Slide time = " + slideDuration) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 860fac29c0ee0..0ae4c45988032 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -81,9 +81,9 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("from existing SparkContext") { @@ -93,26 +93,27 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert( Utils.timeStringAsSeconds(cp.sparkConfPairs - .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) + .toMap.getOrElse("spark.dummyTimeConfig", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert( + newCp.createSparkConf().getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("checkPoint from conf") { @@ -288,7 +289,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600s") + conf.set("spark.dummyTimeConfig", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") From 174e72ceca41a6ac17ad05d50832ee9c561918c0 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 6 Jan 2016 21:28:29 -0800 Subject: [PATCH 1476/2089] [SPARK-12673][UI] Add missing uri prepending for job description Otherwise the url will be failed to proxy to the right one if in YARN mode. Here is the screenshot: ![screen shot 2016-01-06 at 5 28 26 pm](https://cloud.githubusercontent.com/assets/850797/12139632/bbe78ecc-b49c-11e5-8932-94e8b3622a09.png) Author: jerryshao Closes #10618 from jerryshao/SPARK-12673. --- .../main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index db9912bc817e8..451cd83b51ae7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -224,10 +224,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(lastStageDescription, parent.basePath) + val basePathUri = UIUtils.prependBaseUri(parent.basePath) + val jobDescription = UIUtils.makeDescription(lastStageDescription, basePathUri) - val detailUrl = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) + val detailUrl = "%s/jobs/job?id=%s".format(basePathUri, job.jobId) {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} From b6738520374637347ab5ae6c801730cdb6b35daa Mon Sep 17 00:00:00 2001 From: Guillaume Poulin Date: Wed, 6 Jan 2016 21:34:46 -0800 Subject: [PATCH 1477/2089] [SPARK-12678][CORE] MapPartitionsRDD clearDependencies MapPartitionsRDD was keeping a reference to `prev` after a call to `clearDependencies` which could lead to memory leak. Author: Guillaume Poulin Closes #10623 from gpoulin/map_partition_deps. --- .../main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4312d3a417759..e4587c96eae1c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -25,7 +25,7 @@ import org.apache.spark.{Partition, TaskContext} * An RDD that applies the provided function to every partition of the parent RDD. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], + var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { @@ -36,4 +36,9 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) + + override def clearDependencies() { + super.clearDependencies() + prev = null + } } From e5cde7ab11a43334fa01b1bb8904da5c0774bc62 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 6 Jan 2016 22:03:31 -0800 Subject: [PATCH 1478/2089] Revert "[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None" This reverts commit fcd013cf70e7890aa25a8fe3cb6c8b36bf0e1f04. Author: Yin Huai Closes #10632 from yhuai/pythonStyle. --- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/tests.py | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 48daa87e82d13..c9e6f1dec6bf8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -346,7 +346,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = list(initialModel.weights) + initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 97fed7662ea90..6ed03e35828ed 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -475,18 +475,6 @@ def test_gmm_deterministic(self): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ From 84e77a15df18ba3f1cc871a3c52c783b46e52369 Mon Sep 17 00:00:00 2001 From: zzcclp Date: Wed, 6 Jan 2016 23:06:21 -0800 Subject: [PATCH 1479/2089] [DOC] fix 'spark.memory.offHeap.enabled' default value to false modify 'spark.memory.offHeap.enabled' default value to false Author: zzcclp Closes #10633 from zzcclp/fix_spark.memory.offHeap.enabled_default_value. --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 3ffc77dcc62e0..6bd0658b3e056 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -750,7 +750,7 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled - true + false If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. From 6a1c864ab6ee3e869a16ffdbaf6fead21c7aac6d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Jan 2016 23:21:52 -0800 Subject: [PATCH 1480/2089] [SPARK-12295] [SQL] external spilling for window functions This PR manage the memory used by window functions (buffered rows), also enable external spilling. After this PR, we can run window functions on a partition with hundreds of millions of rows with only 1G. Author: Davies Liu Closes #10605 from davies/unsafe_window. --- .../unsafe/sort/UnsafeExternalSorter.java | 21 +- .../unsafe/sort/UnsafeInMemorySorter.java | 18 +- .../unsafe/sort/UnsafeSorterIterator.java | 2 + .../unsafe/sort/UnsafeSorterSpillMerger.java | 7 + .../unsafe/sort/UnsafeSorterSpillReader.java | 8 +- .../apache/spark/sql/execution/Window.scala | 314 +++++++++++++----- 6 files changed, 276 insertions(+), 94 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 77d0b70bb892e..68dc0c6d415f6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -45,7 +45,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + @Nullable private final PrefixComparator prefixComparator; + @Nullable private final RecordComparator recordComparator; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; @@ -431,7 +433,11 @@ class SpillableIterator extends UnsafeSorterIterator { public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) { this.upstream = inMemIterator; - this.numRecords = inMemIterator.numRecordsLeft(); + this.numRecords = inMemIterator.getNumRecords(); + } + + public int getNumRecords() { + return numRecords; } public long spill() throws IOException { @@ -558,13 +564,23 @@ class ChainedIterator extends UnsafeSorterIterator { private final Queue iterators; private UnsafeSorterIterator current; + private int numRecords; public ChainedIterator(Queue iterators) { assert iterators.size() > 0; + this.numRecords = 0; + for (UnsafeSorterIterator iter: iterators) { + this.numRecords += iter.getNumRecords(); + } this.iterators = iterators; this.current = iterators.remove(); } + @Override + public int getNumRecords() { + return numRecords; + } + @Override public boolean hasNext() { while (!current.hasNext() && !iterators.isEmpty()) { @@ -575,6 +591,9 @@ public boolean hasNext() { @Override public void loadNext() throws IOException { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } current.loadNext(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index b7ab45675ee1e..f71b8d154cc24 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,6 +19,8 @@ import java.util.Comparator; +import org.apache.avro.reflect.Nullable; + import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -66,7 +68,9 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; + @Nullable private final Sorter sorter; + @Nullable private final Comparator sortComparator; /** @@ -98,10 +102,11 @@ public UnsafeInMemorySorter( LongArray array) { this.consumer = consumer; this.memoryManager = memoryManager; - this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); if (recordComparator != null) { + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); } else { + this.sorter = null; this.sortComparator = null; } this.array = array; @@ -190,12 +195,13 @@ public SortedIterator clone() { } @Override - public boolean hasNext() { - return position / 2 < numRecords; + public int getNumRecords() { + return numRecords; } - public int numRecordsLeft() { - return numRecords - position / 2; + @Override + public boolean hasNext() { + return position / 2 < numRecords; } @Override @@ -227,7 +233,7 @@ public void loadNext() { * {@code next()} will return the same mutable object. */ public SortedIterator getSortedIterator() { - if (sortComparator != null) { + if (sorter != null) { sorter.sort(array, 0, pos / 2, sortComparator); } return new SortedIterator(pos / 2); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java index 16ac2e8d821ba..1b3167fcc250c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator { public abstract int getRecordLength(); public abstract long getKeyPrefix(); + + public abstract int getNumRecords(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 3874a9f9cbdb6..ceb59352af64b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -23,6 +23,7 @@ final class UnsafeSorterSpillMerger { + private int numRecords = 0; private final PriorityQueue priorityQueue; public UnsafeSorterSpillMerger( @@ -59,6 +60,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator. spillReader.loadNext(); priorityQueue.add(spillReader); + numRecords += spillReader.getNumRecords(); } } @@ -67,6 +69,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { private UnsafeSorterIterator spillReader; + @Override + public int getNumRecords() { + return numRecords; + } + @Override public boolean hasNext() { return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index dcb13e6581e54..20ee1c8eb0c77 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -38,6 +38,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen // Variables that change with every record read: private int recordLength; private long keyPrefix; + private int numRecords; private int numRecordsRemaining; private byte[] arr = new byte[1024 * 1024]; @@ -53,13 +54,18 @@ public UnsafeSorterSpillReader( try { this.in = blockManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { Closeables.close(bs, /* swallowIOException = */ true); throw e; } } + @Override + public int getNumRecords() { + return numRecords; + } + @Override public boolean hasNext() { return (numRecordsRemaining > 0); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 89b17c82459f3..be885397a7d40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -26,6 +28,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} +import org.apache.spark.{SparkEnv, TaskContext} /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -283,23 +287,26 @@ case class Window( val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. - var nextRow: InternalRow = EmptyRow - var nextGroup: InternalRow = EmptyRow + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null var nextRowAvailable: Boolean = false private[this] def fetchNextRow() { nextRowAvailable = stream.hasNext if (nextRowAvailable) { - nextRow = stream.next() + nextRow = stream.next().asInstanceOf[UnsafeRow] nextGroup = grouping(nextRow) } else { - nextRow = EmptyRow - nextGroup = EmptyRow + nextRow = null + nextGroup = null } } fetchNextRow() // Manage the current partition. - val rows = ArrayBuffer.empty[InternalRow] + val rows = ArrayBuffer.empty[UnsafeRow] + val inputFields = child.output.length + var sorter: UnsafeExternalSorter = null + var rowBuffer: RowBuffer = null val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length @@ -307,27 +314,63 @@ case class Window( // Collect all the rows in the current partition. // Before we start to fetch new input rows, make a copy of nextGroup. val currentGroup = nextGroup.copy() - rows.clear() + + // clear last partition + if (sorter != null) { + // the last sorter of this task will be cleaned up via task completion listener + sorter.cleanupResources() + sorter = null + } else { + rows.clear() + } + while (nextRowAvailable && nextGroup == currentGroup) { - rows += nextRow.copy() + if (sorter == null) { + rows += nextRow.copy() + + if (rows.length >= 4096) { + // We will not sort the rows, so prefixComparator and recordComparator are null. + sorter = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + rows.foreach { r => + sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0) + } + rows.clear() + } + } else { + sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, + nextRow.getSizeInBytes, 0) + } fetchNextRow() } + if (sorter != null) { + rowBuffer = new ExternalRowBuffer(sorter, inputFields) + } else { + rowBuffer = new ArrayRowBuffer(rows) + } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rows) + frames(i).prepare(rowBuffer.copy()) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rows.size + rowsSize = rowBuffer.size() } // Iteration var rowIndex = 0 - var rowsSize = 0 + var rowsSize = 0L + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow @@ -340,13 +383,14 @@ case class Window( if (rowIndex < rowsSize) { // Get the results for the window frames. var i = 0 + val current = rowBuffer.next() while (i < numFrames) { - frames(i).write() + frames(i).write(rowIndex, current) i += 1 } // 'Merge' the input row with the window function result - join(rows(rowIndex), windowFunctionResult) + join(current, windowFunctionResult) rowIndex += 1 // Return the projection. @@ -362,14 +406,18 @@ case class Window( * Function for comparing boundary values. */ private[execution] abstract class BoundOrdering { - def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int + def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int } /** * Compare the input index to the bound of the output index. */ private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { - override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = inputIndex - (outputIndex + offset) } @@ -380,8 +428,100 @@ private[execution] final case class RangeBoundOrdering( ordering: Ordering[InternalRow], current: Projection, bound: Projection) extends BoundOrdering { - override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = - ordering.compare(current(input(inputIndex)), bound(input(outputIndex))) + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + ordering.compare(current(inputRow), bound(outputRow)) +} + +/** + * The interface of row buffer for a partition + */ +private[execution] abstract class RowBuffer { + + /** Number of rows. */ + def size(): Int + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow + + /** Skip the next `n` rows. */ + def skip(n: Int): Unit + + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer +} + +/** + * A row buffer based on ArrayBuffer (the number of rows is limited) + */ +private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { + + private[this] var cursor: Int = -1 + + /** Number of rows. */ + def size(): Int = buffer.length + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow = { + cursor += 1 + if (cursor < buffer.length) { + buffer(cursor) + } else { + null + } + } + + /** Skip the next `n` rows. */ + def skip(n: Int): Unit = { + cursor += n + } + + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer = { + new ArrayRowBuffer(buffer) + } +} + +/** + * An external buffer of rows based on UnsafeExternalSorter + */ +private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) + extends RowBuffer { + + private[this] val iter: UnsafeSorterIterator = sorter.getIterator + + private[this] val currentRow = new UnsafeRow(numFields) + + /** Number of rows. */ + def size(): Int = iter.getNumRecords() + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow = { + if (iter.hasNext) { + iter.loadNext() + currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + currentRow + } else { + null + } + } + + /** Skip the next `n` rows. */ + def skip(n: Int): Unit = { + var i = 0 + while (i < n && iter.hasNext) { + iter.loadNext() + i += 1 + } + } + + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer = { + new ExternalRowBuffer(sorter, numFields) + } } /** @@ -395,12 +535,12 @@ private[execution] abstract class WindowFunctionFrame { * * @param rows to calculate the frame results for. */ - def prepare(rows: ArrayBuffer[InternalRow]): Unit + def prepare(rows: RowBuffer): Unit /** * Write the current results to the target row. */ - def write(): Unit + def write(index: Int, current: InternalRow): Unit } /** @@ -421,14 +561,11 @@ private[execution] final class OffsetWindowFunctionFrame( offset: Int) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: ArrayBuffer[InternalRow] = null + private[this] var input: RowBuffer = null /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 - /** Index of the current output row. */ - private[this] var outputIndex = 0 - /** Row used when there is no valid input. */ private[this] val emptyRow = new GenericInternalRow(inputSchema.size) @@ -463,22 +600,26 @@ private[execution] final class OffsetWindowFunctionFrame( newMutableProjection(boundExpressions, Nil)().target(target) } - override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows + // drain the first few rows if offset is larger than zero + inputIndex = 0 + while (inputIndex < offset) { + input.next() + inputIndex += 1 + } inputIndex = offset - outputIndex = 0 } - override def write(): Unit = { - val size = input.size - if (inputIndex >= 0 && inputIndex < size) { - join(input(inputIndex), input(outputIndex)) + override def write(index: Int, current: InternalRow): Unit = { + if (inputIndex >= 0 && inputIndex < input.size) { + val r = input.next() + join(r, current) } else { - join(emptyRow, input(outputIndex)) + join(emptyRow, current) } projection(join) inputIndex += 1 - outputIndex += 1 } } @@ -498,7 +639,13 @@ private[execution] final class SlidingWindowFunctionFrame( ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: ArrayBuffer[InternalRow] = null + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null + + /** The rows within current sliding window. */ + private[this] val buffer = new util.ArrayDeque[InternalRow]() /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -508,33 +655,32 @@ private[execution] final class SlidingWindowFunctionFrame( * current output row. */ private[this] var inputLowIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows + nextRow = rows.next() inputHighIndex = 0 inputLowIndex = 0 - outputIndex = 0 + buffer.clear() } /** Write the frame columns for the current row to the given target row. */ - override def write(): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. - while (inputHighIndex < input.size && - ubound.compare(input, inputHighIndex, outputIndex) <= 0) { + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + buffer.add(nextRow.copy()) + nextRow = input.next() inputHighIndex += 1 bufferUpdated = true } // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (inputLowIndex < inputHighIndex && - lbound.compare(input, inputLowIndex, outputIndex) < 0) { + while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + buffer.remove() inputLowIndex += 1 bufferUpdated = true } @@ -542,12 +688,12 @@ private[execution] final class SlidingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { processor.initialize(input.size) - processor.update(input, inputLowIndex, inputHighIndex) + val iter = buffer.iterator() + while (iter.hasNext) { + processor.update(iter.next()) + } processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } } @@ -567,13 +713,18 @@ private[execution] final class UnboundedWindowFunctionFrame( processor: AggregateProcessor) extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { - processor.initialize(rows.size) - processor.update(rows, 0, rows.size) + override def prepare(rows: RowBuffer): Unit = { + val size = rows.size() + processor.initialize(size) + var i = 0 + while (i < size) { + processor.update(rows.next()) + i += 1 + } } /** Write the frame columns for the current row to the given target row. */ - override def write(): Unit = { + override def write(index: Int, current: InternalRow): Unit = { // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate // for each row. processor.evaluate(target) @@ -600,31 +751,32 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: ArrayBuffer[InternalRow] = null + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null /** Index of the first input row with a value greater than the upper bound of the current * output row. */ private[this] var inputIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows + nextRow = rows.next() inputIndex = 0 - outputIndex = 0 processor.initialize(input.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. - while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - processor.update(input(inputIndex)) + while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { + processor.update(nextRow) + nextRow = input.next() inputIndex += 1 bufferUpdated = true } @@ -633,9 +785,6 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( if (bufferUpdated) { processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } } @@ -661,29 +810,31 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: ArrayBuffer[InternalRow] = null + private[this] var input: RowBuffer = null /** Index of the first input row with a value equal to or greater than the lower bound of the * current output row. */ private[this] var inputIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows inputIndex = 0 - outputIndex = 0 } /** Write the frame columns for the current row to the given target row. */ - override def write(): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Duplicate the input to have a new iterator + val tmp = input.copy() // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) { + tmp.skip(inputIndex) + var nextRow = tmp.next() + while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { + nextRow = tmp.next() inputIndex += 1 bufferUpdated = true } @@ -691,12 +842,12 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { processor.initialize(input.size) - processor.update(input, inputIndex, input.size) + while (nextRow != null) { + processor.update(nextRow) + nextRow = tmp.next() + } processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } } @@ -825,15 +976,6 @@ private[execution] final class AggregateProcessor( } } - /** Bulk update the given buffer. */ - def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = { - var i = begin - while (i < end) { - update(input(i)) - i += 1 - } - } - /** Evaluate buffer. */ def evaluate(target: MutableRow): Unit = evaluateProjection.target(target)(buffer) From fd1dcfaf2608c2cc3a439ed3ca044ae655982306 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Jan 2016 23:46:12 -0800 Subject: [PATCH 1481/2089] [SPARK-12542][SQL] support except/intersect in HiveQl Parse the SQL query with except/intersect in FROM clause for HivQL. Author: Davies Liu Closes #10622 from davies/intersect. --- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 1 + .../sql/catalyst/parser/SparkSqlParser.g | 12 ++++--- .../spark/sql/catalyst/CatalystQl.scala | 7 +++- .../spark/sql/catalyst/CatalystQlSuite.scala | 32 +++++++++++++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 18 +++++++++++ 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e01e7101d0b7e..44a63fbef258c 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -103,6 +103,7 @@ KW_CLUSTER: 'CLUSTER'; KW_DISTRIBUTE: 'DISTRIBUTE'; KW_SORT: 'SORT'; KW_UNION: 'UNION'; +KW_EXCEPT: 'EXCEPT'; KW_LOAD: 'LOAD'; KW_EXPORT: 'EXPORT'; KW_IMPORT: 'IMPORT'; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 4afce3090f739..cf8a56566d32d 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -88,6 +88,8 @@ TOK_DISTRIBUTEBY; TOK_SORTBY; TOK_UNIONALL; TOK_UNIONDISTINCT; +TOK_EXCEPT; +TOK_INTERSECT; TOK_JOIN; TOK_LEFTOUTERJOIN; TOK_RIGHTOUTERJOIN; @@ -2122,6 +2124,8 @@ setOperator @after { popMsg(state); } : KW_UNION KW_ALL -> ^(TOK_UNIONALL) | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) + | KW_EXCEPT -> ^(TOK_EXCEPT) + | KW_INTERSECT -> ^(TOK_INTERSECT) ; queryStatementExpression[boolean topLevel] @@ -2242,7 +2246,7 @@ setOpSelectStatement[CommonTree t, boolean topLevel] ^(TOK_QUERY ^(TOK_FROM ^(TOK_SUBQUERY - ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + ^($u {$setOpSelectStatement.tree} $b) {adaptor.create(Identifier, generateUnionAlias())} ) ) @@ -2252,12 +2256,12 @@ setOpSelectStatement[CommonTree t, boolean topLevel] ) ) -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}? - ^(TOK_UNIONALL {$setOpSelectStatement.tree} $b) + ^($u {$setOpSelectStatement.tree} $b) -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? ^(TOK_QUERY ^(TOK_FROM ^(TOK_SUBQUERY - ^(TOK_UNIONALL {$t} $b) + ^($u {$t} $b) {adaptor.create(Identifier, generateUnionAlias())} ) ) @@ -2266,7 +2270,7 @@ setOpSelectStatement[CommonTree t, boolean topLevel] ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) ) ) - -> ^(TOK_UNIONALL {$t} $b) + -> ^($u {$t} $b) )+ o=orderByClause? c=clusterByClause? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 42bdf25b61ea5..1eda4a9a97644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -399,9 +399,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) - // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT case Token("TOK_UNIONALL", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + case Token("TOK_UNIONDISTINCT", left :: right :: Nil) => + Distinct(Union(nodeToPlan(left), nodeToPlan(right))) + case Token("TOK_EXCEPT", left :: right :: Nil) => + Except(nodeToPlan(left), nodeToPlan(right)) + case Token("TOK_INTERSECT", left :: right :: Nil) => + Intersect(nodeToPlan(left), nodeToPlan(right)) case _ => noParseRule("Plan", node) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala new file mode 100644 index 0000000000000..0fee97fb0718c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.plans.PlanTest + +class CatalystQlSuite extends PlanTest { + + test("parse union/except/intersect") { + val paresr = new CatalystQl() + paresr.createPlan("select * from t1 union all select * from t2") + paresr.createPlan("select * from t1 union distinct select * from t2") + paresr.createPlan("select * from t1 union select * from t2") + paresr.createPlan("select * from t1 except select * from t2") + paresr.createPlan("select * from t1 intersect select * from t2") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 98e22c2e2c1b0..fa99289b41971 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -787,6 +787,24 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("union/except/intersect") { + assertResult(Array(Row(1), Row(1))) { + sql("select 1 as a union all select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a union distinct select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a union select 1 as a").collect() + } + assertResult(Array()) { + sql("select 1 as a except select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a intersect select 1 as a").collect() + } + } + test("SPARK-5383 alias for udfs with multi output columns") { assert( sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") From 8113dbda0bd51fdbe20dbfad466b8d25304a01f4 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 7 Jan 2016 00:27:13 -0800 Subject: [PATCH 1482/2089] [STREAMING][DOCS][EXAMPLES] Minor fixes Author: Jacek Laskowski Closes #10603 from jaceklaskowski/streaming-actor-custom-receiver. --- docs/streaming-custom-receivers.md | 8 ++++---- .../spark/examples/streaming/ActorWordCount.scala | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a75587a92adc7..97db865daa371 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -257,9 +257,9 @@ The following table summarizes the characteristics of both types of receivers ## Implementing and Using a Custom Actor-based Receiver -Custom [Akka Actors](http://doc.akka.io/docs/akka/2.2.4/scala/actors.html) can also be used to +Custom [Akka Actors](http://doc.akka.io/docs/akka/2.3.11/scala/actors.html) can also be used to receive data. The [`ActorHelper`](api/scala/index.html#org.apache.spark.streaming.receiver.ActorHelper) -trait can be applied on any Akka actor, which allows received data to be stored in Spark using +trait can be mixed in to any Akka actor, which allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of this actor can be configured to handle failures, etc. {% highlight scala %} @@ -273,8 +273,8 @@ class CustomActor extends Actor with ActorHelper { And a new input stream can be created with this custom actor as {% highlight scala %} -// Assuming ssc is the StreamingContext -val lines = ssc.actorStream[String](Props(new CustomActor()), "CustomReceiver") +val ssc: StreamingContext = ... +val lines = ssc.actorStream[String](Props[CustomActor], "CustomReceiver") {% endhighlight %} See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 8b8dae0be6119..a47fb7b7d7906 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -62,15 +62,13 @@ class FeederActor extends Actor { }.start() def receive: Receive = { - case SubscribeReceiver(receiverActor: ActorRef) => println("received subscribe from %s".format(receiverActor.toString)) - receivers = LinkedList(receiverActor) ++ receivers + receivers = LinkedList(receiverActor) ++ receivers case UnsubscribeReceiver(receiverActor: ActorRef) => println("received unsubscribe from %s".format(receiverActor.toString)) - receivers = receivers.dropWhile(x => x eq receiverActor) - + receivers = receivers.dropWhile(x => x eq receiverActor) } } @@ -129,9 +127,9 @@ object FeederActor { * and describe the AkkaSystem that Spark Sample feeder is running on. * * To run this example locally, you may run Feeder Actor as - * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` + * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.0.1 9999` * and then run the example - * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount 127.0.1.1 9999` + * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount 127.0.0.1 9999` */ object ActorWordCount { def main(args: Array[String]) { From 592f64985d0d58b4f6a0366bf975e04ca496bdbe Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 7 Jan 2016 10:32:56 -0800 Subject: [PATCH 1483/2089] [SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None If initial model passed to GMM is not empty it causes net.razorvine.pickle.PickleException. It can be fixed by converting initialModel.weights to list. Author: zero323 Closes #10644 from zero323/SPARK-12006. --- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c9e6f1dec6bf8..48daa87e82d13 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -346,7 +346,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = initialModel.weights + initialModelWeights = list(initialModel.weights) initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6ed03e35828ed..3436a28b2974f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -475,6 +475,18 @@ def test_gmm_deterministic(self): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ From f194d9911a93fc3a78be820096d4836f22d09976 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 7 Jan 2016 10:37:15 -0800 Subject: [PATCH 1484/2089] [SPARK-12662][SQL] Fix DataFrame.randomSplit to avoid creating overlapping splits https://issues.apache.org/jira/browse/SPARK-12662 cc yhuai Author: Sameer Agarwal Closes #10626 from sameeragarwal/randomsplit. --- .../org/apache/spark/sql/DataFrame.scala | 7 +++++- .../apache/spark/sql/DataFrameStatSuite.scala | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7cf2818590a78..60d2f05b8605b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1062,10 +1062,15 @@ class DataFrame private[sql]( * @since 1.4.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its + // constituent partitions each time a split is materialized which could result in + // overlapping splits. To prevent this, we explicitly sort each input partition to make the + // ordering deterministic. + val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index b15af42caa3ab..63ad6c439a870 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } } + test("randomSplit on reordered partitions") { + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val data = + sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + + // Verify that the splits don't overalap + assert(splits(0).intersect(splits(1)).collect().isEmpty) + + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") From 07b314a57a638a232ee0b5cd14169e57d742f0f9 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 7 Jan 2016 10:39:46 -0800 Subject: [PATCH 1485/2089] [MINOR] Fix for BUILD FAILURE for Scala 2.11 It was introduced in 917d3fc069fb9ea1c1487119c9c12b373f4f9b77 /cc cloud-fan rxin Author: Jacek Laskowski Closes #10636 from jaceklaskowski/fix-for-build-failure-2.11. --- .../datasources/json/JSONRelation.scala | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index b92edf65bfb6b..8a6fa4aeebc09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -68,29 +68,12 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - override val bucketSpec: Option[BucketSpec], + override val bucketSpec: Option[BucketSpec] = None, override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - def this( - inputRDD: Option[RDD[String]], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - userDefinedPartitionColumns: Option[StructType], - paths: Array[String] = Array.empty[String], - parameters: Map[String, String] = Map.empty[String, String])(sqlContext: SQLContext) = { - this( - inputRDD, - maybeDataSchema, - maybePartitionSpec, - userDefinedPartitionColumns, - None, - paths, - parameters)(sqlContext) - } - val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) /** Constraints to be imposed on schema to be stored. */ From 1b2c2162af4d5d2d950af94571e69273b49bf913 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 7 Jan 2016 21:12:57 +0000 Subject: [PATCH 1486/2089] =?UTF-8?q?[STREAMING][MINOR]=20More=20contextua?= =?UTF-8?q?l=20information=20in=20logs=20+=20minor=20code=20i=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …mprovements Please review and merge at your convenience. Thanks! Author: Jacek Laskowski Closes #10595 from jaceklaskowski/streaming-minor-fixes. --- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 4 +- .../spark/network/client/StreamCallback.java | 4 +- .../spark/network/client/TransportClient.java | 2 +- .../spark/network/server/RpcHandler.java | 2 +- .../spark/streaming/StreamingContext.scala | 12 +-- .../spark/streaming/dstream/DStream.scala | 86 +++++++++---------- .../streaming/dstream/InputDStream.scala | 3 +- .../dstream/ReceiverInputDStream.scala | 4 +- .../receiver/ReceivedBlockHandler.scala | 2 +- .../spark/streaming/receiver/Receiver.scala | 4 +- .../receiver/ReceiverSupervisor.scala | 8 +- .../spark/streaming/scheduler/JobSet.scala | 8 +- 14 files changed, 69 insertions(+), 74 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 92438ba892cc0..6b01a10fc136e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -747,7 +747,7 @@ class DAGScheduler( } /** - * Check for waiting or failed stages which are now eligible for resubmission. + * Check for waiting stages which are now eligible for resubmission. * Ordinarily run on every iteration of the event loop. */ private def submitWaitingStages() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5c80ac17b8d90..4479e6875a731 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -59,7 +59,7 @@ private[spark] class BlockResult( * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). * - * Note that #initialize() must be called before the BlockManager is usable. + * Note that [[initialize()]] must be called before the BlockManager is usable. */ private[spark] class BlockManager( executorId: String, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0d0448feb5b06..037bec1d9c33b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks * in a pipelined fashion as they are received. * - * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid * using too much memory. * * @param context [[TaskContext]], used for metrics update @@ -329,7 +329,7 @@ final class ShuffleBlockFetcherIterator( } /** - * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() */ private class BufferReleasingInputStream( private val delegate: InputStream, diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java index 51d34cac6e636..29e6a30dc1f67 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -21,8 +21,8 @@ import java.nio.ByteBuffer; /** - * Callback for streaming data. Stream data will be offered to the {@link onData(String, ByteBuffer)} - * method as it arrives. Once all the stream data is received, {@link onComplete(String)} will be + * Callback for streaming data. Stream data will be offered to the {@link #onData(String, ByteBuffer)} + * method as it arrives. Once all the stream data is received, {@link #onComplete(String)} will be * called. *

    * The network library guarantees that a single thread will call these methods at a time, but diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index c49ca4d5ee925..e15f096d36913 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -288,7 +288,7 @@ public void send(ByteBuffer message) { /** * Removes any state associated with the given RPC. * - * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}. + * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}. */ public void removeRpcRequest(long requestId) { handler.removeRpcRequest(requestId); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index c6ed0f459ad71..a99c3015b0e05 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -57,7 +57,7 @@ public abstract void receive( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link receive(TransportClient, byte[], RpcResponseCallback)}" and log a warning if + * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if * any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index a5ab66697589b..ca0a21fbb79ff 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -226,7 +226,7 @@ class StreamingContext private[streaming] ( * Set the context to periodically checkpoint the DStream operations for driver * fault-tolerance. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored. - * Note that this must be a fault-tolerant file system like HDFS for + * Note that this must be a fault-tolerant file system like HDFS. */ def checkpoint(directory: String) { if (directory != null) { @@ -274,7 +274,7 @@ class StreamingContext private[streaming] ( * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver * - * @deprecated As of 1.0.0", replaced by `receiverStream`. + * @deprecated As of 1.0.0 replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -285,7 +285,7 @@ class StreamingContext private[streaming] ( /** * Create an input stream with any arbitrary user implemented receiver. - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Find more details at http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver */ def receiverStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -549,7 +549,7 @@ class StreamingContext private[streaming] ( // Verify whether the DStream checkpoint is serializable if (isCheckpointingEnabled) { - val checkpoint = new Checkpoint(this, Time.apply(0)) + val checkpoint = new Checkpoint(this, Time(0)) try { Checkpoint.serialize(checkpoint, conf) } catch { @@ -575,9 +575,9 @@ class StreamingContext private[streaming] ( * * Return the current state of the context. The context can be in three possible states - * - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * - StreamingContextState.INITIALIZED - The context has been created, but not started yet. * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * - StreamingContextState.ACTIVE - The context has been started, and not stopped. * Input DStreams, transformations and output operations cannot be created on the context. * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index c59348a89d34f..1dfb4e7abc0ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -103,7 +103,7 @@ abstract class DStream[T: ClassTag] ( // Reference to whole DStream graph private[streaming] var graph: DStreamGraph = null - private[streaming] def isInitialized = (zeroTime != null) + private[streaming] def isInitialized = zeroTime != null // Duration for which the DStream requires its parent DStream to remember each RDD created private[streaming] def parentRememberDuration = rememberDuration @@ -189,15 +189,15 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def initialize(time: Time) { if (zeroTime != null && zeroTime != time) { - throw new SparkException("ZeroTime is already initialized to " + zeroTime - + ", cannot initialize it again to " + time) + throw new SparkException(s"ZeroTime is already initialized to $zeroTime" + + s", cannot initialize it again to $time") } zeroTime = time // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger if (mustCheckpoint && checkpointDuration == null) { checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt - logInfo("Checkpoint interval automatically set to " + checkpointDuration) + logInfo(s"Checkpoint interval automatically set to $checkpointDuration") } // Set the minimum value of the rememberDuration if not already set @@ -234,7 +234,7 @@ abstract class DStream[T: ClassTag] ( require( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + + s"The checkpoint interval for ${this.getClass.getSimpleName} has not been set." + " Please use DStream.checkpoint() to set the interval." ) @@ -245,53 +245,53 @@ abstract class DStream[T: ClassTag] ( require( checkpointDuration == null || checkpointDuration >= slideDuration, - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + - "Please set it to at least " + slideDuration + "." + s"The checkpoint interval for ${this.getClass.getSimpleName} has been set to " + + s"$checkpointDuration which is lower than its slide time ($slideDuration). " + + s"Please set it to at least $slideDuration." ) require( checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple of " + slideDuration + "." + s"The checkpoint interval for ${this.getClass.getSimpleName} has been set to " + + s" $checkpointDuration which not a multiple of its slide time ($slideDuration). " + + s"Please set it to a multiple of $slideDuration." ) require( checkpointDuration == null || storageLevel != StorageLevel.NONE, - "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + s"${this.getClass.getSimpleName} has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) require( checkpointDuration == null || rememberDuration > checkpointDuration, - "The remember duration for " + this.getClass.getSimpleName + " has been set to " + - rememberDuration + " which is not more than the checkpoint interval (" + - checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." + s"The remember duration for ${this.getClass.getSimpleName} has been set to " + + s" $rememberDuration which is not more than the checkpoint interval" + + s" ($checkpointDuration). Please set it to higher than $checkpointDuration." ) dependencies.foreach(_.validateAtStart()) - logInfo("Slide time = " + slideDuration) - logInfo("Storage level = " + storageLevel) - logInfo("Checkpoint interval = " + checkpointDuration) - logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized and validated " + this) + logInfo(s"Slide time = $slideDuration") + logInfo(s"Storage level = ${storageLevel.description}") + logInfo(s"Checkpoint interval = $checkpointDuration") + logInfo(s"Remember duration = $rememberDuration") + logInfo(s"Initialized and validated $this") } private[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { - throw new SparkException("Context is already set in " + this + ", cannot set it again") + throw new SparkException(s"Context must not be set again for $this") } ssc = s - logInfo("Set context for " + this) + logInfo(s"Set context for $this") dependencies.foreach(_.setContext(ssc)) } private[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { - throw new SparkException("Graph is already set in " + this + ", cannot set it again") + throw new SparkException(s"Graph must not be set again for $this") } graph = g dependencies.foreach(_.setGraph(graph)) @@ -300,7 +300,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def remember(duration: Duration) { if (duration != null && (rememberDuration == null || duration > rememberDuration)) { rememberDuration = duration - logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) + logInfo(s"Duration for remembering RDDs set to $rememberDuration for $this") } dependencies.foreach(_.remember(parentRememberDuration)) } @@ -310,11 +310,11 @@ abstract class DStream[T: ClassTag] ( if (!isInitialized) { throw new SparkException (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { - logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + - " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) + logInfo(s"Time $time is invalid as zeroTime is $zeroTime" + + s" , slideDuration is $slideDuration and difference is ${time - zeroTime}") false } else { - logDebug("Time " + time + " is valid") + logDebug(s"Time $time is valid") true } } @@ -452,20 +452,20 @@ abstract class DStream[T: ClassTag] ( oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]") generatedRDDs --= oldRDDs.keys if (unpersistData) { - logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", ")) + logDebug(s"Unpersisting old RDDs: ${oldRDDs.values.map(_.id).mkString(", ")}") oldRDDs.values.foreach { rdd => rdd.unpersist(false) // Explicitly remove blocks of BlockRDD rdd match { case b: BlockRDD[_] => - logInfo("Removing blocks of RDD " + b + " of time " + time) + logInfo(s"Removing blocks of RDD $b of time $time") b.removeBlocks() case _ => } } } - logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) + logDebug(s"Cleared ${oldRDDs.size} RDDs that were older than " + + s"${time - rememberDuration}: ${oldRDDs.keys.mkString(", ")}") dependencies.foreach(_.clearMetadata(time)) } @@ -477,10 +477,10 @@ abstract class DStream[T: ClassTag] ( * this method to save custom checkpoint data. */ private[streaming] def updateCheckpointData(currentTime: Time) { - logDebug("Updating checkpoint data for time " + currentTime) + logDebug(s"Updating checkpoint data for time $currentTime") checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) - logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) + logDebug(s"Updated checkpoint data for time $currentTime: $checkpointData") } private[streaming] def clearCheckpointData(time: Time) { @@ -509,13 +509,13 @@ abstract class DStream[T: ClassTag] ( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { - logDebug(this.getClass().getSimpleName + ".writeObject used") + logDebug(s"${this.getClass().getSimpleName}.writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { oos.defaultWriteObject() } else { - val msg = "Object of " + this.getClass.getName + " is being serialized " + + val msg = s"Object of ${this.getClass.getName} is being serialized " + " possibly as a part of closure of an RDD operation. This is because " + " the DStream object is being referred to from within the closure. " + " Please rewrite the RDD operation inside this DStream to avoid this. " + @@ -532,7 +532,7 @@ abstract class DStream[T: ClassTag] ( @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { - logDebug(this.getClass().getSimpleName + ".readObject used") + logDebug(s"${this.getClass().getSimpleName}.readObject used") ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[T]] () } @@ -756,7 +756,7 @@ abstract class DStream[T: ClassTag] ( val firstNum = rdd.take(num + 1) // scalastyle:off println println("-------------------------------------------") - println("Time: " + time) + println(s"Time: $time") println("-------------------------------------------") firstNum.take(num).foreach(println) if (firstNum.length > num) println("...") @@ -903,21 +903,19 @@ abstract class DStream[T: ClassTag] ( val alignedToTime = if ((toTime - zeroTime).isMultipleOf(slideDuration)) { toTime } else { - logWarning("toTime (" + toTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") - toTime.floor(slideDuration, zeroTime) + logWarning(s"toTime ($toTime) is not a multiple of slideDuration ($slideDuration)") + toTime.floor(slideDuration, zeroTime) } val alignedFromTime = if ((fromTime - zeroTime).isMultipleOf(slideDuration)) { fromTime } else { - logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") + logWarning(s"fromTime ($fromTime) is not a multiple of slideDuration ($slideDuration)") fromTime.floor(slideDuration, zeroTime) } - logInfo("Slicing from " + fromTime + " to " + toTime + - " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") + logInfo(s"Slicing from $fromTime to $toTime" + + s" (aligned to $alignedFromTime and $alignedToTime)") alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { if (time >= zeroTime) getOrCompute(time) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 95994c983c0cc..d60f418e5c4de 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -28,7 +28,8 @@ import org.apache.spark.util.Utils /** * This is the abstract base class for all input streams. This class provides methods - * start() and stop() which is called by Spark Streaming system to start and stop receiving data. + * start() and stop() which are called by Spark Streaming system to start and stop + * receiving data, respectively. * Input streams that can generate RDDs from new data by running a service/thread only on * the driver node (that is, without running a receiver on worker nodes), can be * implemented by directly inheriting this InputDStream. For example, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a18551fac719a..565b137228d00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] * that has to start a receiver on worker nodes to receive external data. * Specific implementations of ReceiverInputDStream must - * define `the getReceiver()` function that gets the receiver object of type + * define [[getReceiver]] function that gets the receiver object of type * [[org.apache.spark.streaming.receiver.Receiver]] that will be sent * to the workers to receive data. * @param ssc_ Streaming context that will execute this input stream @@ -121,7 +121,7 @@ abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) } if (validBlockIds.size != blockIds.size) { logWarning("Some blocks could not be recovered as they were not found in memory. " + - "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "To prevent such data loss, enable Write Ahead Log (see programming guide " + "for more details.") } new BlockRDD[T](ssc.sc, validBlockIds) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 43c605af73716..faa5aca1d8f7a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -69,7 +69,7 @@ private[streaming] class BlockManagerBasedBlockHandler( def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { - var numRecords = None: Option[Long] + var numRecords: Option[Long] = None val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index b08152485ab5b..639f4259e2e73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -103,7 +103,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable /** * This method is called by the system when the receiver is stopped. All resources - * (threads, buffers, etc.) setup in `onStart()` must be cleaned up in this method. + * (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method. */ def onStop() @@ -273,7 +273,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable /** Get the attached supervisor. */ private[streaming] def supervisor: ReceiverSupervisor = { assert(_supervisor != null, - "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " + + "A ReceiverSupervisor has not been attached to the receiver yet. Maybe you are starting " + "some computation in the receiver before the Receiver.onStart() has been called.") _supervisor } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index c42a9ac233f87..d0195fb14f0a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -143,10 +143,10 @@ private[streaming] abstract class ReceiverSupervisor( def startReceiver(): Unit = synchronized { try { if (onReceiverStart()) { - logInfo("Starting receiver") + logInfo(s"Starting receiver $streamId") receiverState = Started receiver.onStart() - logInfo("Called receiver onStart") + logInfo(s"Called receiver $streamId onStart") } else { // The driver refused us stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) @@ -218,11 +218,9 @@ private[streaming] abstract class ReceiverSupervisor( stopLatch.await() if (stoppingError != null) { logError("Stopped receiver with error: " + stoppingError) + throw stoppingError } else { logInfo("Stopped receiver without error") } - if (stoppingError != null) { - throw stoppingError - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index f76300351e3c0..6e7232a2a0886 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -59,17 +59,15 @@ case class JobSet( // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay: Long = { - processingEndTime - time.milliseconds - } + def totalDelay: Long = processingEndTime - time.milliseconds def toBatchInfo: BatchInfo = { BatchInfo( time, streamIdToInputInfo, submissionTime, - if (processingStartTime >= 0) Some(processingStartTime) else None, - if (processingEndTime >= 0) Some(processingEndTime) else None, + if (hasStarted) Some(processingStartTime) else None, + if (hasCompleted) Some(processingEndTime) else None, jobs.map { job => (job.outputOpId, job.toOutputOperationInfo) }.toMap ) } From 8346518357f4a3565ae41e9a5ccd7e2c3ed6c468 Mon Sep 17 00:00:00 2001 From: Darek Blasiak Date: Thu, 7 Jan 2016 21:15:40 +0000 Subject: [PATCH 1487/2089] [SPARK-12598][CORE] bug in setMinPartitions There is a bug in the calculation of ```maxSplitSize```. The ```totalLen``` should be divided by ```minPartitions``` and not by ```files.size```. Author: Darek Blasiak Closes #10546 from datafarmer/setminpartitionsbug. --- .../scala/org/apache/spark/input/PortableDataStream.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 8009491a1b0e0..18cb7631b3d4c 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -41,9 +41,8 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong + val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum + val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong super.setMaxSplitSize(maxSplitSize) } From 34dbc8af21da63702bc0694d471fbfee4cd08dda Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 7 Jan 2016 13:56:34 -0800 Subject: [PATCH 1488/2089] [SPARK-12580][SQL] Remove string concatenations from usage and extended in @ExpressionDescription Use multi-line string literals for ExpressionDescription with ``// scalastyle:off line.size.limit`` and ``// scalastyle:on line.size.limit`` The policy is here, as describe at https://github.com/apache/spark/pull/10488 Let's use multi-line string literals. If we have to have a line with more than 100 characters, let's use ``// scalastyle:off line.size.limit`` and ``// scalastyle:on line.size.limit`` to just bypass the line number requirement. Author: Kazuaki Ishizaki Closes #10524 from kiszk/SPARK-12580. --- .../spark/sql/catalyst/expressions/misc.scala | 12 +++--- .../expressions/windowExpressions.scala | 38 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6697d463614d5..fd95b124b2455 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -58,13 +58,13 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or * the hash length is not one of the permitted values, the return value is NULL. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = - """_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the input. - SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.""" - , - extended = "> SELECT _FUNC_('Spark', 0);\n " + - "'529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b'") + usage = """_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the input. + SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.""", + extended = """> SELECT _FUNC_('Spark', 0); + '529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b'""") +// scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 3934e33628bd8..afe122f6a0e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -366,8 +366,8 @@ abstract class OffsetWindowFunction * @param default to use when the input value is null or when the offset is larger than the window. */ @ExpressionDescription(usage = - """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows after the - current row in the window""") + """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows + after the current row in the window""") case class Lead(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -393,8 +393,8 @@ case class Lead(input: Expression, offset: Expression, default: Expression) * @param default to use when the input value is null or when the offset is smaller than the window. */ @ExpressionDescription(usage = - """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows before the - current row in the window""") + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows + before the current row in the window""") case class Lag(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -446,9 +446,9 @@ object SizeBasedWindowFunction { * This documentation has been based upon similar documentation for the Hive and Presto projects. */ @ExpressionDescription(usage = - """_FUNC_() - The ROW_NUMBER() function assigns a unique, sequential - number to each row, starting with one, according to the ordering of rows within the window - partition.""") + """_FUNC_() - The ROW_NUMBER() function assigns a unique, sequential number to + each row, starting with one, according to the ordering of rows within + the window partition.""") case class RowNumber() extends RowNumberLike { override val evaluateExpression = rowNumber } @@ -462,8 +462,8 @@ case class RowNumber() extends RowNumberLike { * This documentation has been based upon similar documentation for the Hive and Presto projects. */ @ExpressionDescription(usage = - """_FUNC_() - The CUME_DIST() function computes the position of a value relative to a all values - in the partition.""") + """_FUNC_() - The CUME_DIST() function computes the position of a value relative to + a all values in the partition.""") case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { override def dataType: DataType = DoubleType // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must @@ -494,8 +494,8 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { * @param buckets number of buckets to divide the rows in. Default value is 1. */ @ExpressionDescription(usage = - """_FUNC_(x) - The NTILE(n) function divides the rows for each window partition into 'n' buckets - ranging from 1 to at most 'n'.""") + """_FUNC_(x) - The NTILE(n) function divides the rows for each window partition + into 'n' buckets ranging from 1 to at most 'n'.""") case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { def this() = this(Literal(1)) @@ -602,9 +602,9 @@ abstract class RankLike extends AggregateWindowFunction { * Analyser. */ @ExpressionDescription(usage = - """_FUNC_() - RANK() computes the rank of a value in a group of values. The result is one plus - the number of rows preceding or equal to the current row in the ordering of the partition. Tie - values will produce gaps in the sequence.""") + """_FUNC_() - RANK() computes the rank of a value in a group of values. The result + is one plus the number of rows preceding or equal to the current row in the + ordering of the partition. Tie values will produce gaps in the sequence.""") case class Rank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): Rank = Rank(order) @@ -622,9 +622,9 @@ case class Rank(children: Seq[Expression]) extends RankLike { * Analyser. */ @ExpressionDescription(usage = - """_FUNC_() - The DENSE_RANK() function computes the rank of a value in a group of values. The - result is one plus the previously assigned rank value. Unlike Rank, DenseRank will not produce - gaps in the ranking sequence.""") + """_FUNC_() - The DENSE_RANK() function computes the rank of a value in a group of + values. The result is one plus the previously assigned rank value. Unlike Rank, + DenseRank will not produce gaps in the ranking sequence.""") case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) @@ -649,8 +649,8 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { * Analyser. */ @ExpressionDescription(usage = - """_FUNC_() - PERCENT_RANK() The PercentRank function computes the percentage ranking of a value - in a group of values.""") + """_FUNC_() - PERCENT_RANK() The PercentRank function computes the percentage + ranking of a value in a group of values.""") case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) From c0c397509bc909b9bf2d5186182f461155b021ab Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 7 Jan 2016 15:26:55 -0800 Subject: [PATCH 1489/2089] [SPARK-12510][STREAMING] Refactor ActorReceiver to support Java This PR includes the following changes: 1. Rename `ActorReceiver` to `ActorReceiverSupervisor` 2. Remove `ActorHelper` 3. Add a new `ActorReceiver` for Scala and `JavaActorReceiver` for Java 4. Add `JavaActorWordCount` example Author: Shixiong Zhu Closes #10457 from zsxwing/java-actor-stream. --- .../streaming/JavaActorWordCount.java | 137 ++++++++++++++++++ .../examples/streaming/ActorWordCount.scala | 9 +- .../streaming/zeromq/ZeroMQReceiver.scala | 5 +- project/MimaExcludes.scala | 3 + .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/receiver/ActorReceiver.scala | 64 ++++++-- 6 files changed, 202 insertions(+), 20 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java new file mode 100644 index 0000000000000..2377207779fec --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; + +import scala.Tuple2; + +import akka.actor.ActorSelection; +import akka.actor.Props; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.receiver.JavaActorReceiver; + +/** + * A sample actor as receiver, is also simplest. This receiver actor + * goes and subscribe to a typical publisher/feeder actor and receives + * data. + * + * @see [[org.apache.spark.examples.streaming.FeederActor]] + */ +class JavaSampleActorReceiver extends JavaActorReceiver { + + private final String urlOfPublisher; + + public JavaSampleActorReceiver(String urlOfPublisher) { + this.urlOfPublisher = urlOfPublisher; + } + + private ActorSelection remotePublisher; + + @Override + public void preStart() { + remotePublisher = getContext().actorSelection(urlOfPublisher); + remotePublisher.tell(new SubscribeReceiver(getSelf()), getSelf()); + } + + public void onReceive(Object msg) throws Exception { + store((T) msg); + } + + @Override + public void postStop() { + remotePublisher.tell(new UnsubscribeReceiver(getSelf()), getSelf()); + } +} + +/** + * A sample word count program demonstrating the use of plugging in + * Actor as Receiver + * Usage: JavaActorWordCount + * and describe the AkkaSystem that Spark Sample feeder is running on. + * + * To run this example locally, you may run Feeder Actor as + *

    + *     $ bin/run-example org.apache.spark.examples.streaming.FeederActor localhost 9999
    + * 
    + * and then run the example + *
    + *     $ bin/run-example org.apache.spark.examples.streaming.JavaActorWordCount localhost 9999
    + * 
    + */ +public class JavaActorWordCount { + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaActorWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + final String host = args[0]; + final String port = args[1]; + SparkConf sparkConf = new SparkConf().setAppName("JavaActorWordCount"); + // Create the context and set the batch size + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); + + String feederActorURI = "akka.tcp://test@" + host + ":" + port + "/user/FeederActor"; + + /* + * Following is the use of actorStream to plug in custom actor as receiver + * + * An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e type of data received and InputDstream + * should be same. + * + * For example: Both actorStream and JavaSampleActorReceiver are parameterized + * to same type to ensure type safety. + */ + JavaDStream lines = jssc.actorStream( + Props.create(JavaSampleActorReceiver.class, feederActorURI), "SampleReceiver"); + + // compute wordcount + lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String s) { + return Arrays.asList(s.split("\\s+")); + } + }).mapToPair(new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }).print(); + + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index a47fb7b7d7906..88cdc6bc144e5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -26,8 +26,7 @@ import akka.actor.{actorRef2Scala, Actor, ActorRef, Props} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.streaming.receiver.ActorHelper +import org.apache.spark.streaming.receiver.ActorReceiver import org.apache.spark.util.AkkaUtils case class SubscribeReceiver(receiverActor: ActorRef) @@ -80,7 +79,7 @@ class FeederActor extends Actor { * @see [[org.apache.spark.examples.streaming.FeederActor]] */ class SampleActorReceiver[T: ClassTag](urlOfPublisher: String) -extends Actor with ActorHelper { +extends ActorReceiver { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) @@ -127,9 +126,9 @@ object FeederActor { * and describe the AkkaSystem that Spark Sample feeder is running on. * * To run this example locally, you may run Feeder Actor as - * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.0.1 9999` + * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor localhost 9999` * and then run the example - * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount 127.0.0.1 9999` + * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount localhost 9999` */ object ActorWordCount { def main(args: Array[String]) { diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 588e6bac7b14a..506ba8782d3d5 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -19,12 +19,11 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import akka.actor.Actor import akka.util.ByteString import akka.zeromq._ import org.apache.spark.Logging -import org.apache.spark.streaming.receiver.ActorHelper +import org.apache.spark.streaming.receiver.ActorReceiver /** * A receiver to subscribe to ZeroMQ stream. @@ -33,7 +32,7 @@ private[streaming] class ZeroMQReceiver[T: ClassTag]( publisherUrl: String, subscribe: Subscribe, bytesToObjects: Seq[ByteString] => Iterator[T]) - extends Actor with ActorHelper with Logging { + extends ActorReceiver with Logging { override def preStart(): Unit = { ZeroMQExtension(context.system) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 43ca4690dc2bb..69e5bc881b593 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -119,6 +119,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") + ) ++ Seq( + // SPARK-12510 Refactor ActorReceiver to support Java + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") ) case v if v.startsWith("1.6") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ca0a21fbb79ff..ba509a1030af7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -41,7 +41,7 @@ import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} +import org.apache.spark.streaming.receiver.{ActorReceiverSupervisor, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} @@ -312,7 +312,7 @@ class StreamingContext private[streaming] ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy ): ReceiverInputDStream[T] = withNamedScope("actor stream") { - receiverStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) + receiverStream(new ActorReceiverSupervisor[T](props, name, storageLevel, supervisorStrategy)) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index 7ec74016a1c2c..0eabf3d260b26 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -47,13 +47,12 @@ object ActorSupervisorStrategy { /** * :: DeveloperApi :: - * A receiver trait to be mixed in with your Actor to gain access to - * the API for pushing received data into Spark Streaming for being processed. + * A base Actor that provides APIs for pushing received data into Spark Streaming for processing. * * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * * @example {{{ - * class MyActor extends Actor with ActorHelper{ + * class MyActor extends ActorReceiver { * def receive { * case anything: String => store(anything) * } @@ -69,13 +68,60 @@ object ActorSupervisorStrategy { * should be same. */ @DeveloperApi -trait ActorHelper extends Logging{ +abstract class ActorReceiver extends Actor { - self: Actor => // to ensure that this can be added to Actor classes only + /** Store an iterator of received data as a data block into Spark's memory. */ + def store[T](iter: Iterator[T]) { + context.parent ! IteratorData(iter) + } + + /** + * Store the bytes of received data as a data block into Spark's memory. Note + * that the data in the ByteBuffer must be serialized using the same serializer + * that Spark is configured to use. + */ + def store(bytes: ByteBuffer) { + context.parent ! ByteBufferData(bytes) + } + + /** + * Store a single item of received data to Spark's memory. + * These single items will be aggregated together into data blocks before + * being pushed into Spark's memory. + */ + def store[T](item: T) { + context.parent ! SingleItemData(item) + } +} + +/** + * :: DeveloperApi :: + * A Java UntypedActor that provides APIs for pushing received data into Spark Streaming for + * processing. + * + * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * + * @example {{{ + * class MyActor extends JavaActorReceiver { + * def receive { + * case anything: String => store(anything) + * } + * } + * + * // Can be used with an actorStream as follows + * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") + * + * }}} + * + * @note Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of push block and InputDStream + * should be same. + */ +@DeveloperApi +abstract class JavaActorReceiver extends UntypedActor { /** Store an iterator of received data as a data block into Spark's memory. */ def store[T](iter: Iterator[T]) { - logDebug("Storing iterator") context.parent ! IteratorData(iter) } @@ -85,7 +131,6 @@ trait ActorHelper extends Logging{ * that Spark is configured to use. */ def store(bytes: ByteBuffer) { - logDebug("Storing Bytes") context.parent ! ByteBufferData(bytes) } @@ -95,7 +140,6 @@ trait ActorHelper extends Logging{ * being pushed into Spark's memory. */ def store[T](item: T) { - logDebug("Storing item") context.parent ! SingleItemData(item) } } @@ -104,7 +148,7 @@ trait ActorHelper extends Logging{ * :: DeveloperApi :: * Statistics for querying the supervisor about state of workers. Used in * conjunction with `StreamingContext.actorStream` and - * [[org.apache.spark.streaming.receiver.ActorHelper]]. + * [[org.apache.spark.streaming.receiver.ActorReceiver]]. */ @DeveloperApi case class Statistics(numberOfMsgs: Int, @@ -137,7 +181,7 @@ private[streaming] case class ByteBufferData(bytes: ByteBuffer) extends ActorRec * context.parent ! Props(new Worker, "Worker") * }}} */ -private[streaming] class ActorReceiver[T: ClassTag]( +private[streaming] class ActorReceiverSupervisor[T: ClassTag]( props: Props, name: String, storageLevel: StorageLevel, From 5a4021998ab0f1c8bbb610eceecdf879d149a7b8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 7 Jan 2016 17:21:03 -0800 Subject: [PATCH 1490/2089] [SPARK-12604][CORE] Addendum - use casting vs mapValues for countBy{Key,Value} Per rxin, let's use the casting for countByKey and countByValue as well. Let's see if this passes. Author: Sean Owen Closes #10641 from srowen/SPARK-12604.2. --- core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala | 2 +- core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 76752e1fde663..59af1052ebd05 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -296,7 +296,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): java.util.Map[K, jl.Long] = - mapAsSerializableJavaMap(rdd.countByKey().mapValues(jl.Long.valueOf)) + mapAsSerializableJavaMap(rdd.countByKey()).asInstanceOf[java.util.Map[K, jl.Long]] /** * Approximate version of countByKey that can return a partial result if it does diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 1b1a9dce397fd..242438237f987 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -448,7 +448,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, jl.Long] = - mapAsSerializableJavaMap(rdd.countByValue().mapValues(jl.Long.valueOf)) + mapAsSerializableJavaMap(rdd.countByValue()).asInstanceOf[java.util.Map[T, jl.Long]] /** * (Experimental) Approximate version of countByValue(). From c94199e977279d9b4658297e8108b46bdf30157b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 7 Jan 2016 17:37:46 -0800 Subject: [PATCH 1491/2089] [SPARK-12507][STREAMING][DOCUMENT] Expose closeFileAfterWrite and allowBatching configurations for Streaming /cc tdas brkyvz Author: Shixiong Zhu Closes #10453 from zsxwing/streaming-conf. --- docs/configuration.md | 18 ++++++++++++++++++ docs/streaming-programming-guide.md | 12 +++++------- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 6bd0658b3e056..08392c39187b9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1574,6 +1574,24 @@ Apart from these, the following properties are also available, and may be useful How many batches the Spark Streaming UI and status APIs remember before garbage collecting. + + spark.streaming.driver.writeAheadLog.closeFileAfterWrite + false + + Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + when you want to use S3 (or any file system that does not support flushing) for the metadata WAL + on the driver. + + + + spark.streaming.receiver.writeAheadLog.closeFileAfterWrite + false + + Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + when you want to use S3 (or any file system that does not support flushing) for the data WAL + on the receivers. + + #### SparkR diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 3b071c7da5596..1edc0fe34706b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1985,7 +1985,11 @@ To run a Spark Streaming applications, you need to have the following. to increase aggregate throughput. Additionally, it is recommended that the replication of the received data within Spark be disabled when the write ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the - input stream to `StorageLevel.MEMORY_AND_DISK_SER`. + input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that + does not support flushing) for _write ahead logs_, please remember to enable + `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and + `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See + [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming application to process data as fast as it is being received, the receivers can be rate limited @@ -2023,12 +2027,6 @@ contains serialized Scala/Java/Python objects and trying to deserialize objects modified classes may lead to errors. In this case, either start the upgraded app with a different checkpoint directory, or delete the previous checkpoint directory. -### Other Considerations -{:.no_toc} -If the data is being received by the receivers faster than what can be processed, -you can limit the rate by setting the [configuration parameter](configuration.html#spark-streaming) -`spark.streaming.receiver.maxRate`. - *** ## Monitoring Applications From 28e0e500a2062baeda8c887e17dc8ab2b7d7d4b4 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 7 Jan 2016 17:46:24 -0800 Subject: [PATCH 1492/2089] [SPARK-12591][STREAMING] Register OpenHashMapBasedStateMap for Kryo The default serializer in Kryo is FieldSerializer and it ignores transient fields and never calls `writeObject` or `readObject`. So we should register OpenHashMapBasedStateMap using `DefaultSerializer` to make it work with Kryo. Author: Shixiong Zhu Closes #10609 from zsxwing/SPARK-12591. --- .../spark/serializer/KryoSerializer.scala | 24 ++++- .../serializer/KryoSerializerSuite.scala | 20 +++- project/MimaExcludes.scala | 4 + .../spark/streaming/util/StateMap.scala | 71 ++++++++++---- .../spark/streaming/StateMapSuite.scala | 96 ++++++++++++++++--- 5 files changed, 174 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index bc9fd50c2cd2b..150ddc12e0694 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{DataInput, DataOutput, EOFException, InputStream, IOException, OutputStream} +import java.io._ import java.nio.ByteBuffer import javax.annotation.Nullable @@ -378,18 +378,24 @@ private[serializer] object KryoSerializer { private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() { override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = { - bitmap.serialize(new KryoOutputDataOutputBridge(output)) + bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output)) } override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = { val ret = new RoaringBitmap - ret.deserialize(new KryoInputDataInputBridge(input)) + ret.deserialize(new KryoInputObjectInputBridge(kryo, input)) ret } } ) } -private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput { +/** + * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all + * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects + * an InputStream or ObjectInput but you want to use Kryo. + */ +private[spark] class KryoInputObjectInputBridge( + kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput { override def readLong(): Long = input.readLong() override def readChar(): Char = input.readChar() override def readFloat(): Float = input.readFloat() @@ -408,9 +414,16 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readBoolean(): Boolean = input.readBoolean() override def readUnsignedByte(): Int = input.readByteUnsigned() override def readDouble(): Double = input.readDouble() + override def readObject(): AnyRef = kryo.readClassAndObject(input) } -private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput { +/** + * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all + * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects + * an OutputStream or ObjectOutput but you want to use Kryo. + */ +private[spark] class KryoOutputObjectOutputBridge( + kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput { override def writeFloat(v: Float): Unit = output.writeFloat(v) // There is no "readChars" counterpart, except maybe "readLine", which is not supported override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars") @@ -426,6 +439,7 @@ private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends override def writeChar(v: Int): Unit = output.writeChar(v.toChar) override def writeLong(v: Long): Unit = output.writeLong(v) override def writeByte(v: Int): Unit = output.writeByte(v) + override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj) } /** diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 8f9b453a6eeec..f869bcd708619 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -362,19 +362,35 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { bitmap.add(1) bitmap.add(3) bitmap.add(5) - bitmap.serialize(new KryoOutputDataOutputBridge(output)) + // Ignore Kryo because it doesn't use writeObject + bitmap.serialize(new KryoOutputObjectOutputBridge(null, output)) output.flush() output.close() val inStream = new FileInputStream(tmpfile) val input = new KryoInput(inStream) val ret = new RoaringBitmap - ret.deserialize(new KryoInputDataInputBridge(input)) + // Ignore Kryo because it doesn't use readObject + ret.deserialize(new KryoInputObjectInputBridge(null, input)) input.close() assert(ret == bitmap) Utils.deleteRecursively(dir) } + test("KryoOutputObjectOutputBridge.writeObject and KryoInputObjectInputBridge.readObject") { + val kryo = new KryoSerializer(conf).newKryo() + + val bytesOutput = new ByteArrayOutputStream() + val objectOutput = new KryoOutputObjectOutputBridge(kryo, new KryoOutput(bytesOutput)) + objectOutput.writeObject("test") + objectOutput.close() + + val bytesInput = new ByteArrayInputStream(bytesOutput.toByteArray) + val objectInput = new KryoInputObjectInputBridge(kryo, new KryoInput(bytesInput)) + assert(objectInput.readObject() === "test") + objectInput.close() + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 69e5bc881b593..40559a0910ce8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -119,6 +119,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") + ) ++ Seq( + // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") ) ++ Seq( // SPARK-12510 Refactor ActorReceiver to support Java ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3f139ad138c88..4e5baebaae04b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -17,16 +17,20 @@ package org.apache.spark.streaming.util -import java.io.{ObjectInputStream, ObjectOutputStream} +import java.io._ import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ import org.apache.spark.util.collection.OpenHashMap /** Internal interface for defining the map that keeps track of sessions. */ -private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { +private[streaming] abstract class StateMap[K, S] extends Serializable { /** Get the state for a key if it exists */ def get(key: K): Option[S] @@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser /** Companion object for [[StateMap]], with utility methods */ private[streaming] object StateMap { - def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S] def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", @@ -64,7 +68,7 @@ private[streaming] object StateMap { } /** Implementation of StateMap interface representing an empty map */ -private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { +private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { throw new NotImplementedError("put() should not be called on an EmptyStateMap") } @@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa } /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ -private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( +private[streaming] class OpenHashMapBasedStateMap[K, S]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, - deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD - ) extends StateMap[K, S] { self => + private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, + private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S]) + extends StateMap[K, S] with KryoSerializable { self => - def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + def this(initialCapacity: Int, deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( new EmptyStateMap[K, S], initialCapacity = initialCapacity, deltaChainThreshold = deltaChainThreshold) - def this(deltaChainThreshold: Int) = this( + def this(deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) - def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = { + this(DELTA_CHAIN_LENGTH_THRESHOLD) + } require(initialCapacity >= 1, "Invalid initial capacity") require(deltaChainThreshold >= 1, "Invalid delta chain threshold") @@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( * Serialize the map data. Besides serialization, this method actually compact the deltas * (if needed) in a single pass over all the data in the map. */ - - private def writeObject(outputStream: ObjectOutputStream): Unit = { - // Write all the non-transient fields, especially class tags, etc. - outputStream.defaultWriteObject() - + private def writeObjectInternal(outputStream: ObjectOutput): Unit = { // Write the data in the delta of this state map outputStream.writeInt(deltaMap.size) val deltaMapIterator = deltaMap.iterator @@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } /** Deserialize the map data. */ - private def readObject(inputStream: ObjectInputStream): Unit = { - - // Read the non-transient fields, especially class tags, etc. - inputStream.defaultReadObject() - + private def readObjectInternal(inputStream: ObjectInput): Unit = { // Read the data of the delta val deltaMapSize = inputStream.readInt() deltaMap = if (deltaMapSize != 0) { @@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } parentStateMap = newParentSessionStore } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + writeObjectInternal(outputStream) + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + readObjectInternal(inputStream) + } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(initialCapacity) + output.writeInt(deltaChainThreshold) + kryo.writeClassAndObject(output, keyClassTag) + kryo.writeClassAndObject(output, stateClassTag) + writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output)) + } + + override def read(kryo: Kryo, input: Input): Unit = { + initialCapacity = input.readInt() + deltaChainThreshold = input.readInt() + keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]] + stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]] + readObjectInternal(new KryoInputObjectInputBridge(kryo, input)) + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index c4a01eaea739e..ea32bbf95ce59 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,15 +17,23 @@ package org.apache.spark.streaming +import org.apache.spark.streaming.rdd.MapWithStateRDDRecord + import scala.collection.{immutable, mutable, Map} +import scala.reflect.ClassTag import scala.util.Random -import org.apache.spark.SparkFunSuite +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Output, Input} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer._ import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} -import org.apache.spark.util.Utils class StateMapSuite extends SparkFunSuite { + private val conf = new SparkConf() + test("EmptyStateMap") { val map = new EmptyStateMap[Int, Int] intercept[scala.NotImplementedError] { @@ -128,17 +136,17 @@ class StateMapSuite extends SparkFunSuite { map1.put(2, 200, 2) testSerialization(map1, "error deserializing and serialized map with data + no delta") - val map2 = map1.copy() + val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] // Do not test compaction - assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + assert(map2.shouldCompact === false) testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") map2.put(3, 300, 3) map2.put(4, 400, 4) testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") - val map3 = map2.copy() - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + assert(map3.shouldCompact === false) testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) @@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } - private def testSerialization[MapType <: StateMap[Int, Int]]( - map: MapType, msg: String): MapType = { - val deserMap = Utils.deserialize[MapType]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + private def testSerialization[T: ClassTag]( + map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = { + testSerialization(new JavaSerializer(conf), map, msg) + testSerialization(new KryoSerializer(conf), map, msg) + } + + private def testSerialization[T : ClassTag]( + serializer: Serializer, + map: OpenHashMapBasedStateMap[T, T], + msg: String): OpenHashMapBasedStateMap[T, T] = { + val deserMap = serializeAndDeserialize(serializer, map) assertMap(deserMap, map, 1, msg) deserMap } // Assert whether all the data and operations on a state map matches that of a reference state map - private def assertMap( - mapToTest: StateMap[Int, Int], - refMapToTestWith: StateMap[Int, Int], + private def assertMap[T]( + mapToTest: StateMap[T, T], + refMapToTestWith: StateMap[T, T], time: Long, msg: String): Unit = { withClue(msg) { @@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite { } } } + + test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + testSerialization( + new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states") + } + + test("EmptyStateMap - serializing and deserializing") { + val map = StateMap.empty[KryoState, KryoState] + // Since EmptyStateMap doesn't contains any date, KryoState won't break JavaSerializer. + assert(serializeAndDeserialize(new JavaSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + assert(serializeAndDeserialize(new KryoSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + } + + test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + + val record = + MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c"))) + val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record) + assert(!(record eq deserRecord)) + assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq) + assert(record.mappedData === deserRecord.mappedData) + } + + private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = { + val serializerInstance = serializer.newInstance() + serializerInstance.deserialize[T]( + serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader) + } +} + +/** A class that only supports Kryo serialization. */ +private[streaming] final class KryoState(var state: String) extends KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + kryo.writeClassAndObject(output, state) + } + + override def read(kryo: Kryo, input: Input): Unit = { + state = kryo.readClassAndObject(input).asInstanceOf[String] + } + + override def equals(other: Any): Boolean = other match { + case that: KryoState => state == that.state + case _ => false + } + + override def hashCode(): Int = { + if (state == null) 0 else state.hashCode() + } } From 5028a001d51a9e9a13e3c39f6a080618f3425d87 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 7 Jan 2016 21:13:17 -0800 Subject: [PATCH 1493/2089] [SPARK-12317][SQL] Support units (m,k,g) in SQLConf This PR is continue from previous closed PR 10314. In this PR, SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE will be taken memory string conventions as input. For example, the user can now specify 10g for SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE in SQLConf file. marmbrus srowen : Can you help review this code changes ? Thanks. Author: Kevin Yu Closes #10629 from kevinyu98/spark-12317. --- .../scala/org/apache/spark/sql/SQLConf.scala | 22 ++++++++++- .../org/apache/spark/sql/SQLConfSuite.scala | 39 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 26c00dc250b4b..7976795ff5919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -26,6 +26,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.parser.ParserConf +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -115,6 +116,25 @@ private[spark] object SQLConf { } }, _.toString, doc, isPublic) + def longMemConf( + key: String, + defaultValue: Option[Long] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Long] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toLong + } catch { + case _: NumberFormatException => + try { + Utils.byteStringAsBytes(v) + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be long, but was $v") + } + } + }, _.toString, doc, isPublic) + def doubleConf( key: String, defaultValue: Option[Double] = None, @@ -235,7 +255,7 @@ private[spark] object SQLConf { doc = "The default number of partitions to use when shuffling data for joins or aggregations.") val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = - longConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", + longMemConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", defaultValue = Some(64 * 1024 * 1024), doc = "The target post-shuffle input size in bytes of a task.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 43300cd635c05..a2eddc8fe173e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -92,4 +92,43 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } + + test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { + sqlContext.conf.clear() + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(sqlContext.conf.targetPostShuffleInputSize === 100) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(sqlContext.conf.targetPostShuffleInputSize === -1) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MaxValue + // Utils.byteStringAsBytes("90000000000g") + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + } + + intercept[IllegalArgumentException] { + // This value less than Int.MinValue + // Utils.byteStringAsBytes("-90000000000g") + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + } + // Test invalid input + intercept[IllegalArgumentException] { + // This value exceeds Long.MaxValue + // Utils.byteStringAsBytes("-1g") + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1g") + } + sqlContext.conf.clear() + } } From 726bd3c4ece33667096f04be4d3e1ea13048a1af Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Jan 2016 21:15:43 -0800 Subject: [PATCH 1494/2089] Fix indentation for the previous patch. --- .../org/apache/spark/sql/SQLConfSuite.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index a2eddc8fe173e..cf0701eca29ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -113,22 +113,20 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { // Test overflow exception intercept[IllegalArgumentException] { - // This value exceeds Long.MaxValue - // Utils.byteStringAsBytes("90000000000g") - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + // This value exceeds Long.MaxValue + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") } intercept[IllegalArgumentException] { - // This value less than Int.MinValue - // Utils.byteStringAsBytes("-90000000000g") + // This value less than Int.MinValue sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") - } + } + // Test invalid input intercept[IllegalArgumentException] { - // This value exceeds Long.MaxValue - // Utils.byteStringAsBytes("-1g") - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1g") - } + // This value exceeds Long.MaxValue + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1g") + } sqlContext.conf.clear() } } From 794ea553bd0fcfece15b610b47ee86d6644134c9 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 8 Jan 2016 00:53:15 -0800 Subject: [PATCH 1495/2089] [SPARK-12692][BUILD] Scala style: check no white space before comma and colon We should not put a white space before `,` and `:` so let's check it. Because there are lots of style violations, first, I'd like to add a checker, enable and let the level `warning`. Then, I'd like to fix the style step by step. Author: Kousuke Saruta Closes #10643 from sarutak/SPARK-12692. --- scalastyle-config.xml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index ee855ca0e09cb..9714c46fe99a0 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -218,6 +218,12 @@ This file is divided into 3 sections: + + + + COLON, COMMA + + From b9c835337880f57fe8b953962913bcc524162348 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 8 Jan 2016 17:47:44 +0000 Subject: [PATCH 1496/2089] [SPARK-12618][CORE][STREAMING][SQL] Clean up build warnings: 2.0.0 edition Fix most build warnings: mostly deprecated API usages. I'll annotate some of the changes below. CC rxin who is leading the charge to remove the deprecated APIs. Author: Sean Owen Closes #10570 from srowen/SPARK-12618. --- .../test/scala/org/apache/spark/Smuggle.scala | 1 + ...avaBinaryClassificationMetricsExample.java | 5 +- .../mllib/JavaRankingMetricsExample.java | 21 ++++-- .../JavaRecoverableNetworkWordCount.java | 8 +-- .../streaming/JavaSqlNetworkWordCount.java | 8 +-- .../JavaTwitterHashTagJoinSentiments.java | 36 +++++------ .../apache/spark/examples/SparkHdfsLR.scala | 2 +- .../spark/examples/SparkTachyonHdfsLR.scala | 2 +- .../kafka/JavaDirectKafkaStreamSuite.java | 7 +- .../streaming/kafka/JavaKafkaStreamSuite.java | 8 +-- .../kinesis/KinesisStreamSuite.scala | 8 +-- .../spark/mllib/clustering/KMeans.scala | 8 +-- .../mllib/recommendation/JavaALSSuite.java | 4 +- .../JavaIsotonicRegressionSuite.java | 18 +++--- python/pyspark/mllib/clustering.py | 2 +- .../expressions/ExpressionEvalHelper.scala | 8 +-- .../catalyst/util/DateTimeUtilsSuite.scala | 3 - .../SpecificParquetRecordReaderBase.java | 19 +++--- .../spark/sql/ColumnExpressionSuite.scala | 4 +- .../org/apache/spark/sql/QueryTest.scala | 5 +- .../columnar/ColumnarTestUtils.scala | 1 + .../apache/spark/streaming/JavaAPISuite.java | 4 +- .../streaming/JavaMapWithStateSuite.java | 64 +++++++------------ .../spark/streaming/JavaReceiverAPISuite.java | 14 ++-- 24 files changed, 123 insertions(+), 137 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala index 01694a6e6f741..9f0a1b4c25dd1 100644 --- a/core/src/test/scala/org/apache/spark/Smuggle.scala +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -21,6 +21,7 @@ import java.util.UUID import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable +import scala.language.implicitConversions /** * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java index 779fac01c4be0..3d8babba04a53 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -56,6 +56,7 @@ public static void main(String[] args) { // Compute raw scores on the test set. JavaRDD> predictionAndLabels = test.map( new Function>() { + @Override public Tuple2 call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2(prediction, p.label()); @@ -88,6 +89,7 @@ public Tuple2 call(LabeledPoint p) { // Thresholds JavaRDD thresholds = precision.map( new Function, Double>() { + @Override public Double call(Tuple2 t) { return new Double(t._1().toString()); } @@ -106,8 +108,7 @@ public Double call(Tuple2 t) { // Save and load model model.save(sc, "target/tmp/LogisticRegressionModel"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, - "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel"); // $example off$ } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index 47ab3fc358246..4ad2104763330 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -41,6 +41,7 @@ public static void main(String[] args) { JavaRDD data = sc.textFile(path); JavaRDD ratings = data.map( new Function() { + @Override public Rating call(String line) { String[] parts = line.split("::"); return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double @@ -57,13 +58,14 @@ public Rating call(String line) { JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); JavaRDD> userRecsScaled = userRecs.map( new Function, Tuple2>() { + @Override public Tuple2 call(Tuple2 t) { Rating[] scaledRatings = new Rating[t._2().length]; for (int i = 0; i < scaledRatings.length; i++) { double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); } - return new Tuple2(t._1(), scaledRatings); + return new Tuple2<>(t._1(), scaledRatings); } } ); @@ -72,6 +74,7 @@ public Tuple2 call(Tuple2 t) { // Map ratings to 1 or 0, 1 indicating a movie that should be recommended JavaRDD binarizedRatings = ratings.map( new Function() { + @Override public Rating call(Rating r) { double binaryRating; if (r.rating() > 0.0) { @@ -87,6 +90,7 @@ public Rating call(Rating r) { // Group ratings by common user JavaPairRDD> userMovies = binarizedRatings.groupBy( new Function() { + @Override public Object call(Rating r) { return r.user(); } @@ -96,8 +100,9 @@ public Object call(Rating r) { // Get true relevant documents from all user ratings JavaPairRDD> userMoviesList = userMovies.mapValues( new Function, List>() { + @Override public List call(Iterable docs) { - List products = new ArrayList(); + List products = new ArrayList<>(); for (Rating r : docs) { if (r.rating() > 0.0) { products.add(r.product()); @@ -111,8 +116,9 @@ public List call(Iterable docs) { // Extract the product id from each recommendation JavaPairRDD> userRecommendedList = userRecommended.mapValues( new Function>() { + @Override public List call(Rating[] docs) { - List products = new ArrayList(); + List products = new ArrayList<>(); for (Rating r : docs) { products.add(r.product()); } @@ -124,7 +130,7 @@ public List call(Rating[] docs) { userRecommendedList).values(); // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); + RankingMetrics metrics = RankingMetrics.of(relevantDocs); // Precision and NDCG at k Integer[] kVector = {1, 3, 5}; @@ -139,6 +145,7 @@ public List call(Rating[] docs) { // Evaluate the model using numerical ratings and regression metrics JavaRDD> userProducts = ratings.map( new Function>() { + @Override public Tuple2 call(Rating r) { return new Tuple2(r.user(), r.product()); } @@ -147,18 +154,20 @@ public Tuple2 call(Rating r) { JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( new Function, Object>>() { + @Override public Tuple2, Object> call(Rating r) { return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); + new Tuple2<>(r.user(), r.product()), r.rating()); } } )); JavaRDD> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map( new Function, Object>>() { + @Override public Tuple2, Object> call(Rating r) { return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); + new Tuple2<>(r.user(), r.product()), r.rating()); } } )).join(predictions).values(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index 90d473703ec5a..bc963a02be608 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -36,6 +36,7 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.function.VoidFunction2; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; @@ -154,9 +155,9 @@ public Integer call(Integer i1, Integer i2) { } }); - wordCounts.foreachRDD(new Function2, Time, Void>() { + wordCounts.foreachRDD(new VoidFunction2, Time>() { @Override - public Void call(JavaPairRDD rdd, Time time) throws IOException { + public void call(JavaPairRDD rdd, Time time) throws IOException { // Get or register the blacklist Broadcast final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); // Get or register the droppedWordsCounter Accumulator @@ -164,7 +165,7 @@ public Void call(JavaPairRDD rdd, Time time) throws IOException // Use blacklist to drop words and use droppedWordsCounter to count them String counts = rdd.filter(new Function, Boolean>() { @Override - public Boolean call(Tuple2 wordCount) throws Exception { + public Boolean call(Tuple2 wordCount) { if (blacklist.value().contains(wordCount._1())) { droppedWordsCounter.add(wordCount._2()); return false; @@ -178,7 +179,6 @@ public Boolean call(Tuple2 wordCount) throws Exception { System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); System.out.println("Appending to " + outputFile.getAbsolutePath()); Files.append(output + "\n", outputFile, Charset.defaultCharset()); - return null; } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 3515d7be45d37..084f68a8be437 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.VoidFunction2; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.DataFrame; import org.apache.spark.api.java.StorageLevels; @@ -78,13 +78,14 @@ public Iterable call(String x) { }); // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD(new Function2, Time, Void>() { + words.foreachRDD(new VoidFunction2, Time>() { @Override - public Void call(JavaRDD rdd, Time time) { + public void call(JavaRDD rdd, Time time) { SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { + @Override public JavaRecord call(String word) { JavaRecord record = new JavaRecord(); record.setWord(word); @@ -101,7 +102,6 @@ public JavaRecord call(String word) { sqlContext.sql("select word, count(*) as total from words group by word"); System.out.println("========= " + time + "========="); wordCountsDataFrame.show(); - return null; } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java index 030ee30b93381..d869768026ae3 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java @@ -17,13 +17,13 @@ package org.apache.spark.examples.streaming; -import org.apache.commons.io.IOUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -33,8 +33,6 @@ import scala.Tuple2; import twitter4j.Status; -import java.io.IOException; -import java.net.URI; import java.util.Arrays; import java.util.List; @@ -44,7 +42,7 @@ */ public class JavaTwitterHashTagJoinSentiments { - public static void main(String[] args) throws IOException { + public static void main(String[] args) { if (args.length < 4) { System.err.println("Usage: JavaTwitterHashTagJoinSentiments " + " []"); @@ -79,7 +77,7 @@ public Iterable call(Status s) { JavaDStream hashTags = words.filter(new Function() { @Override - public Boolean call(String word) throws Exception { + public Boolean call(String word) { return word.startsWith("#"); } }); @@ -91,8 +89,7 @@ public Boolean call(String word) throws Exception { @Override public Tuple2 call(String line) { String[] columns = line.split("\t"); - return new Tuple2(columns[0], - Double.parseDouble(columns[1])); + return new Tuple2<>(columns[0], Double.parseDouble(columns[1])); } }); @@ -101,7 +98,7 @@ public Tuple2 call(String line) { @Override public Tuple2 call(String s) { // leave out the # character - return new Tuple2(s.substring(1), 1); + return new Tuple2<>(s.substring(1), 1); } }); @@ -120,9 +117,8 @@ public Integer call(Integer a, Integer b) { hashTagTotals.transformToPair(new Function, JavaPairRDD>>() { @Override - public JavaPairRDD> call(JavaPairRDD topicCount) - throws Exception { + public JavaPairRDD> call( + JavaPairRDD topicCount) { return wordSentiments.join(topicCount); } }); @@ -131,9 +127,9 @@ public JavaPairRDD> call(JavaPairRDD>, String, Double>() { @Override public Tuple2 call(Tuple2> topicAndTuplePair) throws Exception { + Tuple2> topicAndTuplePair) { Tuple2 happinessAndCount = topicAndTuplePair._2(); - return new Tuple2(topicAndTuplePair._1(), + return new Tuple2<>(topicAndTuplePair._1(), happinessAndCount._1() * happinessAndCount._2()); } }); @@ -141,9 +137,8 @@ public Tuple2 call(Tuple2 happinessTopicPairs = topicHappiness.mapToPair( new PairFunction, Double, String>() { @Override - public Tuple2 call(Tuple2 topicHappiness) - throws Exception { - return new Tuple2(topicHappiness._2(), + public Tuple2 call(Tuple2 topicHappiness) { + return new Tuple2<>(topicHappiness._2(), topicHappiness._1()); } }); @@ -151,17 +146,17 @@ public Tuple2 call(Tuple2 topicHappiness) JavaPairDStream happiest10 = happinessTopicPairs.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD happinessAndTopics) throws Exception { + public JavaPairRDD call( + JavaPairRDD happinessAndTopics) { return happinessAndTopics.sortByKey(false); } } ); // Print hash tags with the most positive sentiment values - happiest10.foreachRDD(new Function, Void>() { + happiest10.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaPairRDD happinessTopicPairs) throws Exception { + public void call(JavaPairRDD happinessTopicPairs) { List> topList = happinessTopicPairs.take(10); System.out.println( String.format("\nHappiest topics in last 10 seconds (%s total):", @@ -170,7 +165,6 @@ public Void call(JavaPairRDD happinessTopicPairs) throws Excepti System.out.println( String.format("%s (%s happiness)", pair._2(), pair._1())); } - return null; } }); diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 04dec57b71e16..e4486b949fb3e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -74,7 +74,7 @@ object SparkHdfsLR { val conf = new Configuration() val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).cache() + val points = lines.map(parsePoint).cache() val ITERATIONS = args(1).toInt // Initialize w to a random value diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index ddc99d3f90690..8b739c9d7c1db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -71,7 +71,7 @@ object SparkTachyonHdfsLR { val conf = new Configuration() val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) + val points = lines.map(parsePoint).persist(StorageLevel.OFF_HEAP) val ITERATIONS = args(1).toInt // Initialize w to a random value diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index fbdfbf7e509b3..4891e4f4a17bc 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -130,17 +131,15 @@ public String call(MessageAndMetadata msgAndMd) { JavaDStream unifiedStream = stream1.union(stream2); final Set result = Collections.synchronizedSet(new HashSet()); - unifiedStream.foreachRDD( - new Function, Void>() { + unifiedStream.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) { + public void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() ); } - return null; } } ); diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 1e69de46cd35d..617c92a008fc5 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; @@ -103,10 +104,9 @@ public String call(Tuple2 tuple2) { } ); - words.countByValue().foreachRDD( - new Function, Void>() { + words.countByValue().foreachRDD(new VoidFunction>() { @Override - public Void call(JavaPairRDD rdd) { + public void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -115,8 +115,6 @@ public Void call(JavaPairRDD rdd) { result.put(r._1(), r._2()); } } - - return null; } } ); diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 6fe24fe81165b..78263f9dca65c 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -137,8 +137,8 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Verify that the generated KinesisBackedBlockRDD has the all the right information val blockInfos = Seq(blockInfo1, blockInfo2) val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) - nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] - val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[_]] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[_]] assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) @@ -203,7 +203,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun Seconds(10), StorageLevel.MEMORY_ONLY, addFive, awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) - stream shouldBe a [ReceiverInputDStream[Int]] + stream shouldBe a [ReceiverInputDStream[_]] val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.foreachRDD { rdd => @@ -272,7 +272,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun times.foreach { time => val (arrayOfSeqNumRanges, data) = collectedData(time) val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] - rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + rdd shouldBe a [KinesisBackedBlockRDD[_]] // Verify the recovered sequence ranges val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index e47c4db62955d..ca11ede4ccd47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -107,7 +107,7 @@ class KMeans private ( * Number of runs of the algorithm to execute in parallel. */ @Since("1.4.0") - @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") + @deprecated("Support for runs is deprecated. This param will have no effect in 2.0.0.", "1.6.0") def getRuns: Int = runs /** @@ -117,7 +117,7 @@ class KMeans private ( * return the best clustering found over any run. Default: 1. */ @Since("0.8.0") - @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") + @deprecated("Support for runs is deprecated. This param will have no effect in 2.0.0.", "1.6.0") def setRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") @@ -431,7 +431,7 @@ class KMeans private ( val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) } - if (rs.length > 0) Some(p, rs) else None + if (rs.length > 0) Some((p, rs)) else None } }.collect() mergeNewCenters() diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index 271dda4662e0d..a6631ed7ebd6f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -56,10 +56,10 @@ void validatePrediction( double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - List> localUsersProducts = new ArrayList(users * products); + List> localUsersProducts = new ArrayList<>(users * products); for (int u=0; u < users; ++u) { for (int p=0; p < products; ++p) { - localUsersProducts.add(new Tuple2(u, p)); + localUsersProducts.add(new Tuple2<>(u, p)); } } JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 32c2f4f3395b7..3db9b39e740e7 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -36,11 +36,11 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; - private List> generateIsotonicInput(double[] labels) { - ArrayList> input = new ArrayList(labels.length); + private static List> generateIsotonicInput(double[] labels) { + List> input = new ArrayList<>(labels.length); for (int i = 1; i <= labels.length; i++) { - input.add(new Tuple3(labels[i-1], (double) i, 1d)); + input.add(new Tuple3<>(labels[i-1], (double) i, 1.0)); } return input; @@ -70,7 +70,7 @@ public void testIsotonicRegressionJavaRDD() { runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); Assert.assertArrayEquals( - new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); + new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); } @Test @@ -81,10 +81,10 @@ public void testIsotonicRegressionPredictionsJavaRDD() { JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); - Assert.assertTrue(predictions.get(0) == 1d); - Assert.assertTrue(predictions.get(1) == 1d); - Assert.assertTrue(predictions.get(2) == 10d); - Assert.assertTrue(predictions.get(3) == 12d); - Assert.assertTrue(predictions.get(4) == 12d); + Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14); + Assert.assertEquals(1.0, predictions.get(1).doubleValue(), 1.0e-14); + Assert.assertEquals(10.0, predictions.get(2).doubleValue(), 1.0e-14); + Assert.assertEquals(12.0, predictions.get(3).doubleValue(), 1.0e-14); + Assert.assertEquals(12.0, predictions.get(4).doubleValue(), 1.0e-14); } } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 48daa87e82d13..d22a7f4c3b167 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -173,7 +173,7 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" """Train a k-means clustering model.""" if runs != 1: warnings.warn( - "Support for runs is deprecated in 1.6.0. This param will have no effect in 1.7.0.") + "Support for runs is deprecated in 1.6.0. This param will have no effect in 2.0.0.") clusterInitialModel = [] if initialModel is not None: if not isinstance(initialModel, KMeansModel): diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f869a96edb1ce..e028d22a54ba0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -57,8 +57,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, expected: Spread[Double]) => - expected.isWithin(result) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) case _ => result == expected } } @@ -275,8 +275,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, expected: Spread[Double]) => - expected.isWithin(result) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) case (result: Double, expected: Double) if result.isNaN && expected.isNaN => true case (result: Float, expected: Float) if result.isNaN && expected.isNaN => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index d5f1c4d74efcf..6745b4b6c3c67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -384,9 +384,6 @@ class DateTimeUtilsSuite extends SparkFunSuite { Timestamp.valueOf("1700-02-28 12:14:50.123456")).foreach { t => val us = fromJavaTimestamp(t) assert(toJavaTimestamp(us) === t) - assert(getHours(us) === t.getHours) - assert(getMinutes(us) === t.getMinutes) - assert(getSeconds(us) === t.getSeconds) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index f8e32d60a489a..6bcd155ccdc49 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -21,6 +21,7 @@ import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -62,7 +63,7 @@ import org.apache.spark.sql.types.StructType; /** - * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * Base class for custom RecordReaders for Parquet that directly materialize to `T`. * This class handles computing row groups, filtering on them, setting up the column readers, * etc. * This is heavily based on parquet-mr's RecordReader. @@ -83,6 +84,7 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); - ReadSupport readSupport = getReadSupportInstance( - (Class>) getReadSupportClass(configuration)); + ReadSupport readSupport = getReadSupportInstance(getReadSupportClass(configuration)); ReadSupport.ReadContext readContext = readSupport.init(new InitContext( taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); this.requestedSchema = readContext.getRequestedSchema(); @@ -282,8 +283,9 @@ private static Map> toSetMultiMap(Map map) { return Collections.unmodifiableMap(setMultiMap); } - private static Class getReadSupportClass(Configuration configuration) { - return ConfigurationUtil.getClassFromConfig(configuration, + @SuppressWarnings("unchecked") + private Class> getReadSupportClass(Configuration configuration) { + return (Class>) ConfigurationUtil.getClassFromConfig(configuration, ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); } @@ -294,10 +296,9 @@ private static Class getReadSupportClass(Configuration configuration) { private static ReadSupport getReadSupportInstance( Class> readSupportClass){ try { - return readSupportClass.newInstance(); - } catch (InstantiationException e) { - throw new BadConfigurationException("could not instantiate read support class", e); - } catch (IllegalAccessException e) { + return readSupportClass.getConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | + NoSuchMethodException | InvocationTargetException e) { throw new BadConfigurationException("could not instantiate read support class", e); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 076db0c08dee0..eb4efcd1d4e41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -580,7 +580,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } - test("sparkPartitionId") { + test("spark_partition_id") { // Make sure we have 2 partitions, each with 2 records. val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) @@ -591,7 +591,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } - test("InputFileName") { + test("input_file_name") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 0e60573dc6b2c..fac26bd0c0269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate @@ -206,7 +207,7 @@ abstract class QueryTest extends PlanTest { val jsonString = try { logicalPlan.toJSON } catch { - case e => + case NonFatal(e) => fail( s""" |Failed to parse logical plan to JSON: @@ -231,7 +232,7 @@ abstract class QueryTest extends PlanTest { val jsonBackPlan = try { TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) } catch { - case e => + case NonFatal(e) => fail( s""" |Failed to rebuild the logical plan from JSON: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 97cba1e349e8f..1529313dfbd51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -60,6 +60,7 @@ object ColumnarTestUtils { case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) + case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 9722c60bba1c3..ddc56fc869ae1 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -772,8 +772,8 @@ public Iterable call(String x) { @SuppressWarnings("unchecked") @Test public void testForeachRDD() { - final Accumulator accumRdd = ssc.sc().accumulator(0); - final Accumulator accumEle = ssc.sc().accumulator(0); + final Accumulator accumRdd = ssc.sparkContext().accumulator(0); + final Accumulator accumEle = ssc.sparkContext().accumulator(0); List> inputData = Arrays.asList( Arrays.asList(1,1,1), Arrays.asList(1,1,1)); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index bc4bc2eb42231..20e2a1c3d5c31 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.streaming; import java.io.Serializable; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -26,10 +27,10 @@ import scala.Tuple2; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.util.ManualClock; import org.junit.Assert; @@ -51,10 +52,8 @@ public void testAPI() { JavaPairRDD initialRDD = null; JavaPairDStream wordsDstream = null; - final Function4, State, Optional> - mappingFunc = + Function4, State, Optional> mappingFunc = new Function4, State, Optional>() { - @Override public Optional call( Time time, String word, Optional one, State state) { @@ -76,11 +75,10 @@ public Optional call( .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream stateSnapshots = stateDstream.stateSnapshots(); + stateDstream.stateSnapshots(); - final Function3, State, Double> mappingFunc2 = + Function3, State, Double> mappingFunc2 = new Function3, State, Double>() { - @Override public Double call(String key, Optional one, State state) { // Use all State's methods here @@ -95,13 +93,13 @@ public Double call(String key, Optional one, State state) { JavaMapWithStateDStream stateDstream2 = wordsDstream.mapWithState( - StateSpec.function(mappingFunc2) + StateSpec.function(mappingFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream stateSnapshots2 = stateDstream2.stateSnapshots(); + stateDstream2.stateSnapshots(); } @Test @@ -126,33 +124,21 @@ public void testBasicFunction() { Collections.emptySet() ); + @SuppressWarnings("unchecked") List>> stateData = Arrays.asList( Collections.>emptySet(), - Sets.newHashSet(new Tuple2("a", 1)), - Sets.newHashSet(new Tuple2("a", 2), new Tuple2("b", 1)), - Sets.newHashSet( - new Tuple2("a", 3), - new Tuple2("b", 2), - new Tuple2("c", 1)), - Sets.newHashSet( - new Tuple2("a", 4), - new Tuple2("b", 3), - new Tuple2("c", 1)), - Sets.newHashSet( - new Tuple2("a", 5), - new Tuple2("b", 3), - new Tuple2("c", 1)), - Sets.newHashSet( - new Tuple2("a", 5), - new Tuple2("b", 3), - new Tuple2("c", 1)) + Sets.newHashSet(new Tuple2<>("a", 1)), + Sets.newHashSet(new Tuple2<>("a", 2), new Tuple2<>("b", 1)), + Sets.newHashSet(new Tuple2<>("a", 3), new Tuple2<>("b", 2), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 4), new Tuple2<>("b", 3), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)) ); Function3, State, Integer> mappingFunc = new Function3, State, Integer>() { - @Override - public Integer call(String key, Optional value, State state) throws Exception { + public Integer call(String key, Optional value, State state) { int sum = value.or(0) + (state.exists() ? state.get() : 0); state.update(sum); return sum; @@ -160,7 +146,7 @@ public Integer call(String key, Optional value, State state) t }; testOperation( inputData, - StateSpec.function(mappingFunc), + StateSpec.function(mappingFunc), outputData, stateData); } @@ -175,27 +161,25 @@ private void testOperation( JavaMapWithStateDStream mapWithStateDStream = JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { @Override - public Tuple2 call(K x) throws Exception { - return new Tuple2(x, 1); + public Tuple2 call(K x) { + return new Tuple2<>(x, 1); } })).mapWithState(mapWithStateSpec); final List> collectedOutputs = - Collections.synchronizedList(Lists.>newArrayList()); - mapWithStateDStream.foreachRDD(new Function, Void>() { + Collections.synchronizedList(new ArrayList>()); + mapWithStateDStream.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public void call(JavaRDD rdd) { collectedOutputs.add(Sets.newHashSet(rdd.collect())); - return null; } }); final List>> collectedStateSnapshots = - Collections.synchronizedList(Lists.>>newArrayList()); - mapWithStateDStream.stateSnapshots().foreachRDD(new Function, Void>() { + Collections.synchronizedList(new ArrayList>>()); + mapWithStateDStream.stateSnapshots().foreachRDD(new VoidFunction>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public void call(JavaPairRDD rdd) { collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); - return null; } }); BatchCounter batchCounter = new BatchCounter(ssc.ssc()); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 7a8ef9d14784c..d09258e0e4a85 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -18,13 +18,14 @@ package org.apache.spark.streaming; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import static org.junit.Assert.*; import com.google.common.io.Closeables; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -68,12 +69,11 @@ public String call(String v1) { return v1 + "."; } }); - mapped.foreachRDD(new Function, Void>() { + mapped.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) { + public void call(JavaRDD rdd) { long count = rdd.count(); dataCounter.addAndGet(count); - return null; } }); @@ -90,7 +90,7 @@ public Void call(JavaRDD rdd) { Thread.sleep(100); } ssc.stop(); - assertTrue(dataCounter.get() > 0); + Assert.assertTrue(dataCounter.get() > 0); } finally { server.stop(); } @@ -98,8 +98,8 @@ public Void call(JavaRDD rdd) { private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + private String host = null; + private int port = -1; JavaSocketReceiver(String host_ , int port_) { super(StorageLevel.MEMORY_AND_DISK()); From cfe1ba56e4ab281a9e8eaf419fb7429f93c7a0ce Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 Jan 2016 09:50:41 -0800 Subject: [PATCH 1497/2089] [SPARK-12687] [SQL] Support from clause surrounded by `()`. JIRA: https://issues.apache.org/jira/browse/SPARK-12687 Some queries such as `(select 1 as a) union (select 2 as a)` can't work. This patch fixes it. Author: Liang-Chi Hsieh Closes #10660 from viirya/fix-union. --- .../sql/catalyst/parser/FromClauseParser.g | 2 +- .../sql/catalyst/parser/SparkSqlParser.g | 21 ++++++++++++++++++- .../spark/sql/catalyst/CatalystQlSuite.scala | 4 ++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index ba6cfc60f045f..972c52e3ffcec 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -151,8 +151,8 @@ fromSource @after { gParent.popMsg(state); } : (LPAREN KW_VALUES) => fromSource0 - | (LPAREN) => LPAREN joinSource RPAREN -> joinSource | fromSource0 + | (LPAREN joinSource) => LPAREN joinSource RPAREN -> joinSource ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index cf8a56566d32d..b04bb677774c5 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -2216,6 +2216,8 @@ regularBody[boolean topLevel] selectStatement[boolean topLevel] : ( + ( + LPAREN s=selectClause f=fromClause? w=whereClause? @@ -2227,6 +2229,20 @@ selectStatement[boolean topLevel] sort=sortByClause? win=window_clause? l=limitClause? + RPAREN + | + s=selectClause + f=fromClause? + w=whereClause? + g=groupByClause? + h=havingClause? + o=orderByClause? + c=clusterByClause? + d=distributeByClause? + sort=sortByClause? + win=window_clause? + l=limitClause? + ) -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) $s $w? $g? $h? $o? $c? $d? $sort? $win? $l?)) @@ -2241,7 +2257,10 @@ selectStatement[boolean topLevel] setOpSelectStatement[CommonTree t, boolean topLevel] : - (u=setOperator b=simpleSelectStatement + (( + u=setOperator LPAREN b=simpleSelectStatement RPAREN + | + u=setOperator b=simpleSelectStatement) -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? ^(TOK_QUERY ^(TOK_FROM diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 0fee97fb0718c..30978d9b49e2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -28,5 +28,9 @@ class CatalystQlSuite extends PlanTest { paresr.createPlan("select * from t1 union select * from t2") paresr.createPlan("select * from t1 except select * from t2") paresr.createPlan("select * from t1 intersect select * from t2") + paresr.createPlan("(select * from t1) union all (select * from t2)") + paresr.createPlan("(select * from t1) union distinct (select * from t2)") + paresr.createPlan("(select * from t1) union (select * from t2)") + paresr.createPlan("select * from ((select * from t1) union (select * from t2)) t") } } From ea104b8f1ce8aa109d1b16b696a61a47df6283b2 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 8 Jan 2016 11:08:45 -0800 Subject: [PATCH 1498/2089] [SPARK-12701][CORE] FileAppender should use join to ensure writing thread completion Changed Logging FileAppender to use join in `awaitTermination` to ensure that thread is properly finished before returning. Author: Bryan Cutler Closes #10654 from BryanCutler/fileAppender-join-thread-SPARK-12701. --- .../org/apache/spark/util/logging/FileAppender.scala | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 14b6ba4af489a..58c8560a3d049 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -29,7 +29,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi extends Logging { @volatile private var outputStream: FileOutputStream = null @volatile private var markedForStop = false // has the appender been asked to stopped - @volatile private var stopped = false // has the appender stopped // Thread that reads the input stream and writes to file private val writingThread = new Thread("File appending thread for " + file) { @@ -47,11 +46,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi * or because of any error in appending */ def awaitTermination() { - synchronized { - if (!stopped) { - wait() - } - } + writingThread.join() } /** Stop the appender */ @@ -77,10 +72,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi logError(s"Error writing stream to file $file", e) } finally { closeFile() - synchronized { - stopped = true - notifyAll() - } } } From 00d9261724feb48d358679efbae6889833e893e0 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 8 Jan 2016 11:38:46 -0800 Subject: [PATCH 1499/2089] [DOCUMENTATION] doc fix of job scheduling spark.shuffle.service.enabled is spark application related configuration, it is not necessary to set it in yarn-site.xml Author: Jeff Zhang Closes #10657 from zjffdu/doc-fix. --- docs/job-scheduling.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 36327c6efeaf3..6c587b3f0d8db 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -91,7 +91,7 @@ pre-packaged distribution. 2. Add this jar to the classpath of all `NodeManager`s in your cluster. 3. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to -`org.apache.spark.network.yarn.YarnShuffleService` and `spark.shuffle.service.enabled` to true. +`org.apache.spark.network.yarn.YarnShuffleService`. 4. Restart all `NodeManager`s in your cluster. All other relevant configurations are optional and under the `spark.dynamicAllocation.*` and From 8c70cb4c62a353bea99f37965dfc829c4accc391 Mon Sep 17 00:00:00 2001 From: Udo Klein Date: Fri, 8 Jan 2016 20:32:37 +0000 Subject: [PATCH 1500/2089] fixed numVertices in transitive closure example Author: Udo Klein Closes #10642 from udoklein/patch-2. --- examples/src/main/python/transitive_closure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 7bf5fb6ddfe29..3d61250d8b230 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -30,8 +30,8 @@ def generateGraph(): edges = set() while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) + src = rand.randrange(0, numVertices) + dst = rand.randrange(0, numVertices) if src != dst: edges.add((src, dst)) return edges From 553fd7b912a32476b481fd3f80c1d0664b6c6484 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 8 Jan 2016 14:38:19 -0600 Subject: [PATCH 1501/2089] =?UTF-8?q?[SPARK-12654]=20sc.wholeTextFiles=20w?= =?UTF-8?q?ith=20spark.hadoop.cloneConf=3Dtrue=20fail=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …s on secure Hadoop https://issues.apache.org/jira/browse/SPARK-12654 So the bug here is that WholeTextFileRDD.getPartitions has: val conf = getConf in getConf if the cloneConf=true it creates a new Hadoop Configuration. Then it uses that to create a new newJobContext. The newJobContext will copy credentials around, but credentials are only present in a JobConf not in a Hadoop Configuration. So basically when it is cloning the hadoop configuration its changing it from a JobConf to Configuration and dropping the credentials that were there. NewHadoopRDD just uses the conf passed in for the getPartitions (not getConf) which is why it works. Author: Thomas Graves Closes #10651 from tgravescs/SPARK-12654. --- .../main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 146609ae3911a..7a1197830443f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} @@ -93,7 +94,13 @@ class NewHadoopRDD[K, V]( // issues, this cloning is disabled by default. NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { logDebug("Cloning Hadoop Configuration") - new Configuration(conf) + // The Configuration passed in is actually a JobConf and possibly contains credentials. + // To keep those credentials properly we have to create a new JobConf not a Configuration. + if (conf.isInstanceOf[JobConf]) { + new JobConf(conf) + } else { + new Configuration(conf) + } } } else { conf From 659fd9d04b988d48960eac4f352ca37066f43f5c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 8 Jan 2016 13:02:30 -0800 Subject: [PATCH 1502/2089] [SPARK-4819] Remove Guava's "Optional" from public API Replace Guava `Optional` with (an API clone of) Java 8 `java.util.Optional` (edit: and a clone of Guava `Optional`) See also https://github.com/apache/spark/pull/10512 Author: Sean Owen Closes #10513 from srowen/SPARK-4819. --- .../org/apache/spark/api/java/Optional.java | 187 ++++++++++++++++++ .../apache/spark/api/java/JavaPairRDD.scala | 2 - .../apache/spark/api/java/JavaRDDLike.scala | 4 - .../spark/api/java/JavaSparkContext.scala | 1 - .../org/apache/spark/api/java/JavaUtils.scala | 9 +- .../java/org/apache/spark/JavaAPISuite.java | 46 ++--- .../apache/spark/api/java/OptionalSuite.java | 94 +++++++++ docs/streaming-programming-guide.md | 1 - .../JavaStatefulNetworkWordCount.java | 20 +- .../java/org/apache/spark/Java8APISuite.java | 2 +- .../apache/spark/streaming/Java8APISuite.java | 1 - network/common/pom.xml | 6 - pom.xml | 11 -- project/MimaExcludes.scala | 11 +- .../apache/spark/streaming/StateSpec.scala | 12 +- .../streaming/api/java/JavaPairDStream.scala | 3 +- .../apache/spark/streaming/JavaAPISuite.java | 2 +- .../streaming/JavaMapWithStateSuite.java | 4 +- .../tools/JavaAPICompletenessChecker.scala | 2 +- 19 files changed, 333 insertions(+), 85 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/Optional.java create mode 100644 core/src/test/java/org/apache/spark/api/java/OptionalSuite.java diff --git a/core/src/main/java/org/apache/spark/api/java/Optional.java b/core/src/main/java/org/apache/spark/api/java/Optional.java new file mode 100644 index 0000000000000..ca7babc3f01c7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/Optional.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import java.io.Serializable; + +import com.google.common.base.Preconditions; + +/** + *

    Like {@code java.util.Optional} in Java 8, {@code scala.Option} in Scala, and + * {@code com.google.common.base.Optional} in Google Guava, this class represents a + * value of a given type that may or may not exist. It is used in methods that wish + * to optionally return a value, in preference to returning {@code null}.

    + * + *

    In fact, the class here is a reimplementation of the essential API of both + * {@code java.util.Optional} and {@code com.google.common.base.Optional}. From + * {@code java.util.Optional}, it implements:

    + * + *
      + *
    • {@link #empty()}
    • + *
    • {@link #of(Object)}
    • + *
    • {@link #ofNullable(Object)}
    • + *
    • {@link #get()}
    • + *
    • {@link #orElse(Object)}
    • + *
    • {@link #isPresent()}
    • + *
    + * + *

    From {@code com.google.common.base.Optional} it implements:

    + * + *
      + *
    • {@link #absent()}
    • + *
    • {@link #of(Object)}
    • + *
    • {@link #fromNullable(Object)}
    • + *
    • {@link #get()}
    • + *
    • {@link #or(Object)}
    • + *
    • {@link #orNull()}
    • + *
    • {@link #isPresent()}
    • + *
    + * + *

    {@code java.util.Optional} itself is not used at this time because the + * project does not require Java 8. Using {@code com.google.common.base.Optional} + * has in the past caused serious library version conflicts with Guava that can't + * be resolved by shading. Hence this work-alike clone.

    + * + * @param type of value held inside + */ +public final class Optional implements Serializable { + + private static final Optional EMPTY = new Optional<>(); + + private final T value; + + private Optional() { + this.value = null; + } + + private Optional(T value) { + Preconditions.checkNotNull(value); + this.value = value; + } + + // java.util.Optional API (subset) + + /** + * @return an empty {@code Optional} + */ + public static Optional empty() { + @SuppressWarnings("unchecked") + Optional t = (Optional) EMPTY; + return t; + } + + /** + * @param value non-null value to wrap + * @return {@code Optional} wrapping this value + * @throws NullPointerException if value is null + */ + public static Optional of(T value) { + return new Optional<>(value); + } + + /** + * @param value value to wrap, which may be null + * @return {@code Optional} wrapping this value, which may be empty + */ + public static Optional ofNullable(T value) { + if (value == null) { + return empty(); + } else { + return of(value); + } + } + + /** + * @return the value wrapped by this {@code Optional} + * @throws NullPointerException if this is empty (contains no value) + */ + public T get() { + Preconditions.checkNotNull(value); + return value; + } + + /** + * @param other value to return if this is empty + * @return this {@code Optional}'s value if present, or else the given value + */ + public T orElse(T other) { + return value != null ? value : other; + } + + /** + * @return true iff this {@code Optional} contains a value (non-empty) + */ + public boolean isPresent() { + return value != null; + } + + // Guava API (subset) + // of(), get() and isPresent() are identically present in the Guava API + + /** + * @return an empty {@code Optional} + */ + public static Optional absent() { + return empty(); + } + + /** + * @param value value to wrap, which may be null + * @return {@code Optional} wrapping this value, which may be empty + */ + public static Optional fromNullable(T value) { + return ofNullable(value); + } + + /** + * @param other value to return if this is empty + * @return this {@code Optional}'s value if present, or else the given value + */ + public T or(T other) { + return value != null ? value : other; + } + + /** + * @return this {@code Optional}'s value if present, or else null + */ + public T orNull() { + return value; + } + + // Common methods + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Optional)) { + return false; + } + Optional other = (Optional) obj; + return value == null ? other.value == null : value.equals(other.value); + } + + @Override + public int hashCode() { + return value == null ? 0 : value.hashCode(); + } + + @Override + public String toString() { + return value == null ? "Optional.empty" : String.format("Optional[%s]", value); + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 59af1052ebd05..fb04472ee73fd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -25,7 +25,6 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{JobConf, OutputFormat} @@ -655,7 +654,6 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * keys; this also retains the original RDD's partitioning. */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { - import scala.collection.JavaConverters._ def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 242438237f987..0f8d13cf5cc2f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -24,7 +24,6 @@ import java.util.{Comparator, Iterator => JIterator, List => JList} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ @@ -122,7 +121,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { - import scala.collection.JavaConverters._ def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -132,7 +130,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { - import scala.collection.JavaConverters._ def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } @@ -142,7 +139,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - import scala.collection.JavaConverters._ def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 9990b22e14a25..01433ca2efc14 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -25,7 +25,6 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index b2a4d053fa650..f820401da2fc3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -22,13 +22,12 @@ import java.util.Map.Entry import scala.collection.mutable -import com.google.common.base.Optional - private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = - option match { - case Some(value) => Optional.of(value) - case None => Optional.absent() + if (option.isDefined) { + Optional.of(option.get) + } else { + Optional.empty[T] } // Workaround for SPARK-3926 / SI-8911 diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 47382e4231563..44d5cac7c2de5 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,7 +21,17 @@ import java.nio.channels.FileChannel; import java.nio.ByteBuffer; import java.net.URI; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.*; import scala.Tuple2; @@ -35,7 +45,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.base.Throwables; -import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; @@ -49,7 +58,12 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.*; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaFutureAction; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; @@ -1785,32 +1799,6 @@ public void testAsyncActionErrorWrapping() throws Exception { Assert.assertTrue(future.isDone()); } - - /** - * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, - * since that's the only artifact where Guava classes have been relocated. - */ - @Test - public void testGuavaOptional() { - // Stop the context created in setUp() and start a local-cluster one, to force usage of the - // assembly. - sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); - try { - JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); - JavaRDD> rdd2 = rdd1.map( - new Function>() { - @Override - public Optional call(Integer i) { - return Optional.fromNullable(i); - } - }); - rdd2.collect(); - } finally { - localCluster.stop(); - } - } - static class Class1 {} static class Class2 {} diff --git a/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java b/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java new file mode 100644 index 0000000000000..4b97c18198c1a --- /dev/null +++ b/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests {@link Optional}. + */ +public class OptionalSuite { + + @Test + public void testEmpty() { + Assert.assertFalse(Optional.empty().isPresent()); + Assert.assertNull(Optional.empty().orNull()); + Assert.assertEquals("foo", Optional.empty().or("foo")); + Assert.assertEquals("foo", Optional.empty().orElse("foo")); + } + + @Test(expected = NullPointerException.class) + public void testEmptyGet() { + Optional.empty().get(); + } + + @Test + public void testAbsent() { + Assert.assertFalse(Optional.absent().isPresent()); + Assert.assertNull(Optional.absent().orNull()); + Assert.assertEquals("foo", Optional.absent().or("foo")); + Assert.assertEquals("foo", Optional.absent().orElse("foo")); + } + + @Test(expected = NullPointerException.class) + public void testAbsentGet() { + Optional.absent().get(); + } + + @Test + public void testOf() { + Assert.assertTrue(Optional.of(1).isPresent()); + Assert.assertNotNull(Optional.of(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).orElse(2)); + } + + @Test(expected = NullPointerException.class) + public void testOfWithNull() { + Optional.of(null); + } + + @Test + public void testOfNullable() { + Assert.assertTrue(Optional.ofNullable(1).isPresent()); + Assert.assertNotNull(Optional.ofNullable(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).orElse(2)); + Assert.assertFalse(Optional.ofNullable(null).isPresent()); + Assert.assertNull(Optional.ofNullable(null).orNull()); + Assert.assertEquals(Integer.valueOf(2), Optional.ofNullable(null).or(2)); + Assert.assertEquals(Integer.valueOf(2), Optional.ofNullable(null).orElse(2)); + } + + @Test + public void testFromNullable() { + Assert.assertTrue(Optional.fromNullable(1).isPresent()); + Assert.assertNotNull(Optional.fromNullable(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).orElse(2)); + Assert.assertFalse(Optional.fromNullable(null).isPresent()); + Assert.assertNull(Optional.fromNullable(null).orNull()); + Assert.assertEquals(Integer.valueOf(2), Optional.fromNullable(null).or(2)); + Assert.assertEquals(Integer.valueOf(2), Optional.fromNullable(null).orElse(2)); + } + +} diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 1edc0fe34706b..8fd075d02b78e 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -881,7 +881,6 @@ Scala code, take a look at the example
    {% highlight java %} -import com.google.common.base.Optional; Function2, Optional, Optional> updateFunction = new Function2, Optional, Optional>() { @Override public Optional call(List values, Optional state) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 14997c64d505e..f52cc7c20576b 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -23,17 +23,14 @@ import scala.Tuple2; -import com.google.common.base.Optional; -import com.google.common.collect.Lists; - import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.StorageLevels; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.State; import org.apache.spark.streaming.StateSpec; -import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.*; /** @@ -67,8 +64,8 @@ public static void main(String[] args) { // Initial state RDD input to mapWithState @SuppressWarnings("unchecked") - List> tuples = Arrays.asList(new Tuple2("hello", 1), - new Tuple2("world", 1)); + List> tuples = + Arrays.asList(new Tuple2<>("hello", 1), new Tuple2<>("world", 1)); JavaPairRDD initialRDD = ssc.sparkContext().parallelizePairs(tuples); JavaReceiverInputDStream lines = ssc.socketTextStream( @@ -77,7 +74,7 @@ public static void main(String[] args) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + return Arrays.asList(SPACE.split(x)); } }); @@ -85,18 +82,17 @@ public Iterable call(String x) { new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }); // Update the cumulative count function - final Function3, State, Tuple2> mappingFunc = + Function3, State, Tuple2> mappingFunc = new Function3, State, Tuple2>() { - @Override public Tuple2 call(String word, Optional one, State state) { - int sum = one.or(0) + (state.exists() ? state.get() : 0); - Tuple2 output = new Tuple2(word, sum); + int sum = one.orElse(0) + (state.exists() ? state.get() : 0); + Tuple2 output = new Tuple2<>(word, sum); state.update(sum); return output; } diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 14975265ab2ce..27d494ce355f7 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -24,7 +24,6 @@ import scala.Tuple2; import com.google.common.collect.Iterables; -import com.google.common.base.Optional; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -38,6 +37,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.util.Utils; diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index e8a0dfc0f0a5f..604d818ef1947 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -22,7 +22,6 @@ import scala.Tuple2; -import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.junit.Assert; diff --git a/network/common/pom.xml b/network/common/pom.xml index 32c34c63a45c5..92ca0046d4f53 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -52,15 +52,9 @@ com.google.code.findbugs jsr305 - com.google.guava guava - compile diff --git a/pom.xml b/pom.xml index e414a8bfe6ce5..9c975a45f8d23 100644 --- a/pom.xml +++ b/pom.xml @@ -2251,17 +2251,6 @@ com.google.common org.spark-project.guava - - - com/google/common/base/Absent* - com/google/common/base/Function - com/google/common/base/Optional* - com/google/common/base/Present* - com/google/common/base/Supplier - diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 40559a0910ce8..0d5f938d9ef5c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -57,7 +57,16 @@ object MimaExcludes { ) ++ Seq( ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") - ) ++ + ) ++ + Seq( + // SPARK-4819 replace Guava Optional + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner") + ) ++ Seq( // SPARK-12481 Remove Hadoop 1.x ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 0b094558dfd59..f1114c1e5ac6a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming -import com.google.common.base.Optional - import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils, Optional} import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner @@ -200,7 +198,11 @@ object StateSpec { StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s) - Option(t.orNull) + if (t.isPresent) { + Some(t.get) + } else { + None + } } StateSpec.function(wrappedFunc) } @@ -220,7 +222,7 @@ object StateSpec { mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { - mappingFunction.call(k, Optional.fromNullable(v.get), s) + mappingFunction.call(k, Optional.ofNullable(v.get), s) } StateSpec.function(wrappedFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index af0d84b33224f..d718f1d6fc43e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -25,14 +25,13 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.Partitioner import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils} +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils, Optional} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index ddc56fc869ae1..4dbcef293487c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -33,7 +33,6 @@ import org.junit.Assert; import org.junit.Test; -import com.google.common.base.Optional; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -43,6 +42,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index 20e2a1c3d5c31..9b7701003d8d0 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -26,7 +26,6 @@ import scala.Tuple2; -import com.google.common.base.Optional; import com.google.common.collect.Sets; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; @@ -38,6 +37,7 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.Function3; import org.apache.spark.api.java.function.Function4; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -139,7 +139,7 @@ public void testBasicFunction() { new Function3, State, Integer>() { @Override public Integer call(String key, Optional value, State state) { - int sum = value.or(0) + (state.exists() ? state.get() : 0); + int sum = value.orElse(0) + (state.exists() ? state.get() : 0); state.update(sum); return sum; } diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 6fb7184e877ee..ccd8fd3969f61 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -161,7 +161,7 @@ object JavaAPICompletenessChecker { } case "scala.Option" => { if (isReturnType) { - ParameterizedType("com.google.common.base.Optional", parameters.map(applySubs)) + ParameterizedType("org.apache.spark.api.java.Optional", parameters.map(applySubs)) } else { applySubs(parameters(0)) } From d9447cac747823e71b676c08c75f4aab34de12a2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 8 Jan 2016 14:08:13 -0800 Subject: [PATCH 1503/2089] [SPARK-12593][SQL] Converts resolved logical plan back to SQL This PR tries to enable Spark SQL to convert resolved logical plans back to SQL query strings. For now, the major use case is to canonicalize Spark SQL native view support. The major entry point is `SQLBuilder.toSQL`, which returns an `Option[String]` if the logical plan is recognized. The current version is still in WIP status, and is quite limited. Known limitations include: 1. The logical plan must be analyzed but not optimized The optimizer erases `Subquery` operators, which contain necessary scope information for SQL generation. Future versions should be able to recover erased scope information by inserting subqueries when necessary. 1. The logical plan must be created using HiveQL query string Query plans generated by composing arbitrary DataFrame API combinations are not supported yet. Operators within these query plans need to be rearranged into a canonical form that is more suitable for direct SQL generation. For example, the following query plan ``` Filter (a#1 < 10) +- MetastoreRelation default, src, None ``` need to be canonicalized into the following form before SQL generation: ``` Project [a#1, b#2, c#3] +- Filter (a#1 < 10) +- MetastoreRelation default, src, None ``` Otherwise, the SQL generation process will have to handle a large number of special cases. 1. Only a fraction of expressions and basic logical plan operators are supported in this PR Currently, 95.7% (1720 out of 1798) query plans in `HiveCompatibilitySuite` can be successfully converted to SQL query strings. Known unsupported components are: - Expressions - Part of math expressions - Part of string expressions (buggy?) - Null expressions - Calendar interval literal - Part of date time expressions - Complex type creators - Special `NOT` expressions, e.g. `NOT LIKE` and `NOT IN` - Logical plan operators/patterns - Cube, rollup, and grouping set - Script transformation - Generator - Distinct aggregation patterns that fit `DistinctAggregationRewriter` analysis rule - Window functions Support for window functions, generators, and cubes etc. will be added in follow-up PRs. This PR leverages `HiveCompatibilitySuite` for testing SQL generation in a "round-trip" manner: * For all select queries, we try to convert it back to SQL * If the query plan is convertible, we parse the generated SQL into a new logical plan * Run the new logical plan instead of the original one If the query plan is inconvertible, the test case simply falls back to the original logic. TODO - [x] Fix failed test cases - [x] Support for more basic expressions and logical plan operators (e.g. distinct aggregation etc.) - [x] Comments and documentation Author: Cheng Lian Closes #10541 from liancheng/sql-generation. --- .../sql/catalyst/parser/SparkSqlParser.g | 48 ++-- .../sql/catalyst/analysis/Analyzer.scala | 20 +- .../spark/sql/catalyst/analysis/Catalog.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 8 + .../sql/catalyst/expressions/Expression.scala | 23 +- .../catalyst/expressions/InputFileName.scala | 1 + .../MonotonicallyIncreasingID.scala | 4 + .../sql/catalyst/expressions/SortOrder.scala | 14 +- .../expressions/aggregate/interfaces.scala | 14 +- .../sql/catalyst/expressions/arithmetic.scala | 8 + .../expressions/complexTypeExtractors.scala | 2 + .../expressions/conditionalExpressions.scala | 41 ++- .../expressions/datetimeExpressions.scala | 22 ++ .../expressions/decimalExpressions.scala | 3 + .../sql/catalyst/expressions/literals.scala | 37 ++- .../expressions/mathExpressions.scala | 2 + .../spark/sql/catalyst/expressions/misc.scala | 4 + .../expressions/namedExpressions.scala | 12 + .../expressions/nullExpressions.scala | 6 + .../sql/catalyst/expressions/predicates.scala | 19 ++ .../expressions/randomExpressions.scala | 3 + .../expressions/regexpExpressions.scala | 2 + .../expressions/stringExpressions.scala | 28 +- .../sql/catalyst/optimizer/Optimizer.scala | 52 ++++ .../spark/sql/catalyst/plans/joinTypes.scala | 24 +- .../plans/logical/basicOperators.scala | 1 + .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../spark/sql/catalyst/util/package.scala | 14 + .../apache/spark/sql/types/ArrayType.scala | 2 + .../org/apache/spark/sql/types/DataType.scala | 2 + .../org/apache/spark/sql/types/MapType.scala | 2 + .../apache/spark/sql/types/StructType.scala | 5 + .../spark/sql/types/UserDefinedType.scala | 2 + .../sql/catalyst/analysis/AnalysisSuite.scala | 38 --- .../optimizer/ComputeCurrentTimeSuite.scala | 68 +++++ .../optimizer/FilterPushdownSuite.scala | 6 +- .../datasources/parquet/ParquetRelation.scala | 16 +- .../execution/HiveCompatibilitySuite.scala | 12 +- .../HiveWindowFunctionQuerySuite.scala | 1 + .../org/apache/spark/sql/hive/HiveQl.scala | 3 +- .../apache/spark/sql/hive/SQLBuilder.scala | 244 ++++++++++++++++++ .../org/apache/spark/sql/hive/hiveUDFs.scala | 48 ++-- .../sql/hive/ExpressionSQLBuilderSuite.scala | 75 ++++++ .../sql/hive/LogicalPlanToSQLSuite.scala | 146 +++++++++++ .../spark/sql/hive/SQLBuilderTest.scala | 74 ++++++ .../hive/execution/HiveComparisonTest.scala | 70 ++++- .../sql/hive/execution/HiveQuerySuite.scala | 1 + 47 files changed, 1087 insertions(+), 146 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index b04bb677774c5..2c13d3056f468 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -1,9 +1,9 @@ /** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with + (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -582,7 +582,7 @@ import java.util.HashMap; return header; } - + @Override public String getErrorMessage(RecognitionException e, String[] tokenNames) { String msg = null; @@ -619,7 +619,7 @@ import java.util.HashMap; } return msg; } - + public void pushMsg(String msg, RecognizerSharedState state) { // ANTLR generated code does not wrap the @init code wit this backtracking check, // even if the matching @after has it. If we have parser rules with that are doing @@ -639,7 +639,7 @@ import java.util.HashMap; // counter to generate unique union aliases private int aliasCounter; private String generateUnionAlias() { - return "_u" + (++aliasCounter); + return "u_" + (++aliasCounter); } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { @@ -1235,7 +1235,7 @@ alterTblPartitionStatementSuffixSkewedLocation : KW_SET KW_SKEWED KW_LOCATION skewedLocations -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) ; - + skewedLocations @init { pushMsg("skewed locations", state); } @after { popMsg(state); } @@ -1264,7 +1264,7 @@ alterStatementSuffixLocation -> ^(TOK_ALTERTABLE_LOCATION $newLoc) ; - + alterStatementSuffixSkewedby @init {pushMsg("alter skewed by statement", state);} @after{popMsg(state);} @@ -1336,10 +1336,10 @@ tabTypeExpr (identifier (DOT^ ( (KW_ELEM_TYPE) => KW_ELEM_TYPE - | + | (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier ))* )? @@ -1376,7 +1376,7 @@ descStatement analyzeStatement @init { pushMsg("analyze statement", state); } @after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) ; @@ -1389,7 +1389,7 @@ showStatement | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? -> ^(TOK_SHOWCOLUMNS tableName $db_name?) | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) - | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) | KW_SHOW KW_CREATE ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) | @@ -1398,7 +1398,7 @@ showStatement | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS + | KW_SHOW KW_LOCKS ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) | @@ -1511,7 +1511,7 @@ showCurrentRole setRole @init {pushMsg("set role", state);} @after {popMsg(state);} - : KW_SET KW_ROLE + : KW_SET KW_ROLE ( (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) | @@ -1966,7 +1966,7 @@ columnNameOrderList skewedValueElement @init { pushMsg("skewed value element", state); } @after { popMsg(state); } - : + : skewedColumnValues | skewedColumnValuePairList ; @@ -1980,8 +1980,8 @@ skewedColumnValuePairList skewedColumnValuePair @init { pushMsg("column value pair", state); } @after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN + : + LPAREN colValues=skewedColumnValues RPAREN -> ^(TOK_TABCOLVALUES $colValues) ; @@ -2001,11 +2001,11 @@ skewedColumnValue skewedValueLocationElement @init { pushMsg("skewed value location element", state); } @after { popMsg(state); } - : + : skewedColumnValue | skewedColumnValuePair ; - + columnNameOrder @init { pushMsg("column name order", state); } @after { popMsg(state); } @@ -2118,7 +2118,7 @@ unionType @after { popMsg(state); } : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) ; - + setOperator @init { pushMsg("set operator", state); } @after { popMsg(state); } @@ -2172,7 +2172,7 @@ fromStatement[boolean topLevel] {adaptor.create(Identifier, generateUnionAlias())} ) ) - ^(TOK_INSERT + ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) ) @@ -2414,8 +2414,8 @@ setColumnsClause KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) ; -/* - UPDATE +/* + UPDATE
    SET col1 = val1, col2 = val2... WHERE ... */ updateStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e362b55d80cd1..8a33af8207350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,8 +86,7 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic, - ComputeCurrentTime), + PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Computes the current date and time to make sure we return the same result in a single query. - */ -object ComputeCurrentTime extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() - val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) - - plan transformAllExpressions { - case CurrentDate() => currentDate - case CurrentTimestamp() => currentTime - } - } -} - /** * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e8b2fcf819bf6..a8f89ce6de457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias + .map(a => Subquery(a, tableWithQualifiers)) + .getOrElse(tableWithQualifiers) } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d82d3edae4e38..6f199cfc5d8cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -931,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { $evPrim = $result.copy(); """ } + + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this + // type of casting can only be introduced by the analyzer, and can be omitted when converting + // back to SQL query string. + case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a9c12127d367..d6219514b752b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] { protected def toCommentSafeString: String = this.toString .replace("*/", "\\*\\/") .replace("\\u", "\\\\u") + + /** + * Returns SQL representation of this expression. For expressions that don't have a SQL + * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`. + */ + @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") + def sql: String = throw new UnsupportedOperationException( + s"Cannot map expression $this to its SQL representation" + ) } @@ -356,6 +366,8 @@ abstract class UnaryExpression extends Expression { """ } } + + override def sql: String = s"($prettyName(${child.sql}))" } @@ -456,6 +468,8 @@ abstract class BinaryExpression extends Expression { """ } } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -492,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { TypeCheckResult.TypeCheckSuccess } } + + override def sql: String = s"(${left.sql} $symbol ${right.sql})" } @@ -593,4 +609,9 @@ abstract class TernaryExpression extends Expression { """ } } + + override def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index f33833c3918df..827dce8af100e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic { "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } + override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index d0b78e15d99d1..94f8801dec369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with $countTerm++; """ } + + override def prettyName: String = "monotonically_increasing_id" + + override def sql: String = s"$prettyName()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3add722da7816..1cb1b9da3049b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -24,9 +24,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator -abstract sealed class SortDirection -case object Ascending extends SortDirection -case object Descending extends SortDirection +abstract sealed class SortDirection { + def sql: String +} + +case object Ascending extends SortDirection { + override def sql: String = "ASC" +} + +case object Descending extends SortDirection { + override def sql: String = "DESC" +} /** * An expression that can be used to sort a tuple. This class extends expression primarily so that diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b47f32d1768b9..ddd99c51ab0c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ @@ -93,11 +94,13 @@ private[sql] case class AggregateExpression( override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + + override def sql: String = aggregateFunction.sql(isDistinct) } /** - * AggregateFunction2 is the superclass of two aggregation function interfaces: + * AggregateFunction is the superclass of two aggregation function interfaces: * * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. @@ -163,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } + + def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0fe..7bd851c059d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp numeric.negate(input) } } + + override def sql: String = s"(-${child.sql})" } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input + + override def sql: String = s"(+${child.sql})" } /** @@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { @@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val r = a % n if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c73239f67ff2..5bd97cc7467ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } + + override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f79c8676fb58c..19da849d2bec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} import org.apache.spark.sql.types._ @@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } override def toString: String = s"if ($predicate) $trueValue else $falseValue" + + override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" } trait CaseWhenLike extends Expression { @@ -110,7 +112,7 @@ trait CaseWhenLike extends Expression { override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) } } @@ -206,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } // scalastyle:off @@ -310,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val keySQL = key.sql + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE $keySQL " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 3d65946a1bc65..17f1df06f2fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { DateTimeUtils.millisToDays(System.currentTimeMillis()) } + + override def prettyName: String = "current_date" } /** @@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { System.currentTimeMillis() * 1000L } + + override def prettyName: String = "current_timestamp" } /** @@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression) s"""${ev.value} = $sd + $d;""" }) } + + override def prettyName: String = "date_add" } /** @@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression) s"""${ev.value} = $sd - $d;""" }) } + + override def prettyName: String = "date_sub" } case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } + + override def prettyName: String = "to_unix_timestamp" } /** @@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi def this() = { this(CurrentTimestamp()) } + + override def prettyName: String = "unix_timestamp" } abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { @@ -437,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { """ } } + + override def prettyName: String = "unix_time" } /** @@ -451,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression) override def left: Expression = sec override def right: Expression = format + override def prettyName: String = "from_unixtime" + def this(unix: Expression) = { this(unix, Literal("yyyy-MM-dd HH:mm:ss")) } @@ -733,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression) s"""$dtu.dateAddMonths($sd, $m)""" }) } + + override def prettyName: String = "add_months" } /** @@ -758,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression) s"""$dtu.monthsBetween($l, $r)""" }) } + + override def prettyName: String = "months_between" } /** @@ -823,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, d => d) } + + override def prettyName: String = "to_date" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c54bcdd774021..5f8b544edb511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def prettyName: String = "promote_precision" + override def sql: String = child.sql } /** @@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 672cc9c45e0af..0eb915fdc1691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,9 +21,9 @@ import java.sql.{Date, Timestamp} import org.json4s.JsonAST._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ @@ -214,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType) } } } + + override def sql: String = (value, dataType) match { + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => + "NULL" + + case _ if value == null => + s"CAST(NULL AS ${dataType.sql})" + + case (v: UTF8String, StringType) => + // Escapes all backslashes and double quotes. + "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + + case (v: Byte, ByteType) => + s"CAST($v AS ${ByteType.simpleString.toUpperCase})" + + case (v: Short, ShortType) => + s"CAST($v AS ${ShortType.simpleString.toUpperCase})" + + case (v: Long, LongType) => + s"CAST($v AS ${LongType.simpleString.toUpperCase})" + + case (v: Float, FloatType) => + s"CAST($v AS ${FloatType.simpleString.toUpperCase})" + + case (v: Decimal, DecimalType.Fixed(precision, scale)) => + s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" + + case (v: Int, DateType) => + s"DATE '${DateTimeUtils.toJavaDate(v)}'" + + case (v: Long, TimestampType) => + s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" + + case _ => value.toString + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 002f5929cc26b..66d8631a846ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } + + override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index fd95b124b2455..cc406a39f0408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); """ } + + override def prettyName: String = "hash" + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index eefd9c7482553..eee708cb02f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)( explicitMetadata == a.explicitMetadata case _ => false } + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"${child.sql} AS $qualifiersString`$name`" + } } /** @@ -271,6 +277,12 @@ case class AttributeReference( // Since the expression id is not in the first constructor it is missing from the default // tree string. override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"$qualifiersString`$name`" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index df4747d4e6f7a..89aec2b20fd0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.value = eval.isNull eval.code } + + override def sql: String = s"(${child.sql} IS NULL)" } @@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { ev.value = s"(!(${eval.isNull}))" eval.code } + + override def sql: String = s"(${child.sql} IS NOT NULL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 304b438c84ba4..bca12a8d21023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -101,6 +101,8 @@ case class Not(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } + + override def sql: String = s"(NOT ${child.sql})" } @@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } """ } + + override def sql: String = { + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") + s"($valueSQL IN ($listSQL))" + } } /** @@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } """ } + + override def sql: String = { + val valueSQL = child.sql + val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + s"($valueSQL IN ($listSQL))" + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { @@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } + + override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } + + override def sql: String = s"(${left.sql} OR ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bde8cb9fe876..8de47e9ddc28d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = DoubleType + + // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. + override def sql: String = s"$prettyName($seed)" } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index adef6050c3565..db266639b8560 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { matches(regex, input1.asInstanceOf[UTF8String].toString) } } + + override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 50c8b9d59847e..931f752b4dc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression]) """ } } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { @@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" + s"$termDict == null" } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); + $termLastMatching = $matching.clone(); + $termLastReplace = $replace.clone(); + $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatching, $termLastReplace); } - ${ev.value} = ${src}.translate(${termDict}); + ${ev.value} = $src.translate($termDict); """ }) } @@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def dataType: DataType = IntegerType + + override def prettyName: String = "find_in_set" } /** @@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } - } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8b..f8121a733a8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Compute Current Time", Once, + ComputeCurrentTime) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -333,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] { ) Project(cleanedProjection, child) } + + // TODO Eliminate duplicate code + // This clause is identical to the one above except that the inner operator is an `Aggregate` + // rather than a `Project`. + case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliasMap = AttributeMap(projectList2.collect { + case a: Alias => (a.toAttribute, a) + }) + + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.exists(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }.exists(!_.deterministic)) + + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + agg.copy(aggregateExpressions = cleanedProjection) + } } } @@ -976,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b5..a5f6764aef7ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -37,14 +37,26 @@ object JoinType { } } -sealed abstract class JoinType +sealed abstract class JoinType { + def sql: String +} -case object Inner extends JoinType +case object Inner extends JoinType { + override def sql: String = "INNER" +} -case object LeftOuter extends JoinType +case object LeftOuter extends JoinType { + override def sql: String = "LEFT OUTER" +} -case object RightOuter extends JoinType +case object RightOuter extends JoinType { + override def sql: String = "RIGHT OUTER" +} -case object FullOuter extends JoinType +case object FullOuter extends JoinType { + override def sql: String = "FULL OUTER" +} -case object LeftSemi extends JoinType +case object LeftSemi extends JoinType { + override def sql: String = "LEFT SEMI" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 79759b5a37b34..64957db6b4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -423,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 62ea731ab5f38..9ebacb4680dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -37,7 +37,7 @@ object RuleExecutor { val maxSize = map.keys.map(_.toString.length).max map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n") + }.mkString("\n", "\n", "") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 71293475ca0f9..7a0d0de6328a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -130,6 +130,20 @@ package object util { ret } + /** + * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. + */ + def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { + case xs if xs.isEmpty => + Option(Seq.empty[T]) + + case xs => + for { + head <- xs.head + tail <- sequenceOption(xs.tail) + } yield head +: tail + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 6533622492d41..520e344361625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 136a97e066df7..92cf8d4c46bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString + def sql: String = simpleString.toUpperCase + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 00461e529ca0a..5474954af70e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,6 +62,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 34382bf124eb0..9b5c86a8984be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -279,6 +279,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + override def sql: String = { + val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + s"STRUCT<${fieldTypes.mkString(", ")}>" + } + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd9..d7a2c23be8a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def sql: String = sqlType.sql } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fa823e3021835..cf84855885a37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } - test("analyzer should replace current_timestamp with literals") { - val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), - LocalRelation()) - - val min = System.currentTimeMillis() * 1000 - val plan = in.analyze.asInstanceOf[Project] - val max = (System.currentTimeMillis() + 1) * 1000 - - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - - test("analyzer should replace current_date with literals") { - val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) - - val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val plan = in.analyze.asInstanceOf[Project] - val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) - - val lits = new scala.collection.mutable.ArrayBuffer[Int] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Int] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala new file mode 100644 index 0000000000000..10ed4e46ddd1c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class ComputeCurrentTimeSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) + } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b998636909a7d..f9f3bd55aa578 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a) - .select('a).analyze + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c) - .select('c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4b375de05e9e3..ca8d010090401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.{List => JList} import java.util.logging.{Logger => JLogger} +import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -32,24 +32,24 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl -import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { @@ -147,6 +147,12 @@ private[sql] class ParquetRelation( .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) + // If this relation is converted from a Hive metastore table, this method returns the name of the + // original Hive metastore table. + private[sql] def metastoreTableName: Option[TableIdentifier] = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier) + } + private lazy val metadataCache: MetadataCache = { val meta = new MetadataCache meta.refresh() diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd1a52e5f3303..afd2f611580fc 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -68,10 +71,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) + super.afterAll() } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList = Seq( + override def blackList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", "hook_context_cs", @@ -106,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - //"describe_table", + // "describe_table", "describe_comment_nonascii", "create_merge_compressed", @@ -323,7 +327,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { * The set of tests that are believed to be working in catalyst. Tests not on whiteList or * blacklist are implicitly marked as ignored. */ - override def whiteList = Seq( + override def whiteList: Seq[String] = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 98bbdf0653c2a..bad3ca6da231f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.reset() + super.afterAll() } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bf3fe12d5c5d2..5b13dbe47370e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -668,7 +668,8 @@ private[hive] object HiveQl extends SparkQl with Logging { Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) + HiveGenericUDTF( + functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) case other => super.nodeToGenerator(node) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala new file mode 100644 index 0000000000000..1c910051faccf --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.{DataFrame, SQLContext} + +/** + * A builder class used to convert a resolved logical plan into a SQL query string. Note that this + * all resolved logical plan are convertible. They either don't have corresponding SQL + * representations (e.g. logical plans that operate on local Scala collections), or are simply not + * supported by this builder (yet). + */ +class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + + def toSQL: Option[String] = { + val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val maybeSQL = try { + toSQL(canonicalizedPlan) + } catch { case cause: UnsupportedOperationException => + logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") + None + } + + if (maybeSQL.isDefined) { + logDebug( + s"""Built SQL query string successfully from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + |# Built SQL query string: + |${maybeSQL.get} + """.stripMargin) + } else { + logDebug( + s"""Failed to build SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + } + + maybeSQL + } + + private def projectToSQL( + projectList: Seq[NamedExpression], + child: LogicalPlan, + isDistinct: Boolean): Option[String] = { + for { + childSQL <- toSQL(child) + listSQL = projectList.map(_.sql).mkString(", ") + maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + distinct = if (isDistinct) " DISTINCT " else " " + } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL" + } + + private def aggregateToSQL( + groupingExprs: Seq[Expression], + aggExprs: Seq[Expression], + child: LogicalPlan): Option[String] = { + val aggSQL = aggExprs.map(_.sql).mkString(", ") + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + val maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + + toSQL(child).map { childSQL => + s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" + } + } + + private def toSQL(node: LogicalPlan): Option[String] = node match { + case Distinct(Project(list, child)) => + projectToSQL(list, child, isDistinct = true) + + case Project(list, child) => + projectToSQL(list, child, isDistinct = false) + + case Aggregate(groupingExprs, aggExprs, child) => + aggregateToSQL(groupingExprs, aggExprs, child) + + case Limit(limit, child) => + for { + childSQL <- toSQL(child) + limitSQL = limit.sql + } yield s"$childSQL LIMIT $limitSQL" + + case Filter(condition, child) => + for { + childSQL <- toSQL(child) + whereOrHaving = child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + conditionSQL = condition.sql + } yield s"$childSQL $whereOrHaving $conditionSQL" + + case Union(left, right) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + } yield s"$leftSQL UNION ALL $rightSQL" + + // ParquetRelation converted from Hive metastore table + case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => + // There seems to be a bug related to `ParquetConversions` analysis rule. The problem is + // that, the metastore database name and table name are not always propagated to converted + // `ParquetRelation` instances via data source options. Here we use subquery alias as a + // workaround. + Some(s"`$alias`") + + case Subquery(alias, child) => + toSQL(child).map(childSQL => s"($childSQL) AS $alias") + + case Join(left, right, joinType, condition) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + joinTypeSQL = joinType.sql + conditionSQL = condition.map(" ON " + _.sql).getOrElse("") + } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" + + case MetastoreRelation(database, table, alias) => + val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") + Some(s"`$database`.`$table`$aliasSQL") + + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) + if orders.map(_.child) == partitionExprs => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL CLUSTER BY $partitionExprsSQL" + + case Sort(orders, global, child) => + for { + childSQL <- toSQL(child) + ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + orderOrSort = if (global) "ORDER" else "SORT" + } yield s"$childSQL $orderOrSort BY $ordersSQL" + + case RepartitionByExpression(partitionExprs, child, _) => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" + + case OneRowRelation => + Some("") + + case _ => None + } + + object Canonicalizer extends RuleExecutor[LogicalPlan] { + override protected def batches: Seq[Batch] = Seq( + Batch("Canonicalizer", FixedPoint(100), + // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over + // `Aggregate`s to perform type casting. This rule merges these `Project`s into + // `Aggregate`s. + ProjectCollapsing, + + // Used to handle other auxiliary `Project`s added by analyzer (e.g. + // `ResolveAggregateFunctions` rule) + RecoverScopingInfo + ) + ) + + object RecoverScopingInfo extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + // This branch handles aggregate functions within HAVING clauses. For example: + // + // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" + // + // This kind of query results in query plans of the following form because of analysis rule + // `ResolveAggregateFunctions`: + // + // Project ... + // +- Filter ... + // +- Aggregate ... + // +- MetastoreRelation default, src, None + case plan @ Project(_, Filter(_, _: Aggregate)) => + wrapChildWithSubquery(plan) + + case plan @ Project(_, + _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit + ) => plan + + case plan: Project => + wrapChildWithSubquery(plan) + } + + def wrapChildWithSubquery(project: Project): Project = project match { + case Project(projectList, child) => + val alias = SQLBuilder.newSubqueryName + val childAttributes = child.outputSet + val aliasedProjectList = projectList.map(_.transform { + case a: Attribute if childAttributes.contains(a) => + a.withQualifiers(alias :: Nil) + }.asInstanceOf[NamedExpression]) + + Project(aliasedProjectList, Subquery(alias, child)) + } + } + } +} + +object SQLBuilder { + private val nextSubqueryId = new AtomicLong(0) + + private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index b1a6d0ab7df3c..e76c18fa528f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,30 +17,26 @@ package org.apache.spark.sql.hive -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.hive.ql.exec._ -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.catalyst.{InternalRow, analysis} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ @@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry( try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUDF( - new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. udtf } else { @@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry( } } -private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic @@ -191,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -205,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def nullable: Boolean = true @@ -257,6 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -271,6 +273,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUDTF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors with CodegenFallback { @@ -336,6 +339,8 @@ private[hive] case class HiveGenericUDTF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -343,6 +348,7 @@ private[hive] case class HiveGenericUDTF( * performance a lot. */ private[hive] case class HiveUDAFFunction( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false, @@ -427,5 +433,9 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false override val dataType: DataType = inspectorToDataType(returnInspector) -} + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$name($distinct${children.map(_.sql).mkString(", ")})" + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala new file mode 100644 index 0000000000000..3a6eb57add4e3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} + +class ExpressionSQLBuilderSuite extends SQLBuilderTest { + test("literal") { + checkSQL(Literal("foo"), "\"foo\"") + checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") + checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(4: Int), "4") + checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(2.5D), "2.5") + checkSQL( + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), + "TIMESTAMP('2016-01-01 00:00:00.0')") + // TODO tests for decimals + } + + test("binary comparisons") { + checkSQL('a.int === 'b.int, "(`a` = `b`)") + checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") + checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))") + + checkSQL('a.int < 'b.int, "(`a` < `b`)") + checkSQL('a.int <= 'b.int, "(`a` <= `b`)") + checkSQL('a.int > 'b.int, "(`a` > `b`)") + checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + + checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") + checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + + checkSQL('a.int.isNull, "(`a` IS NULL)") + checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + } + + test("logical operators") { + checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") + checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") + checkSQL(!'a.boolean, "(NOT `a`)") + checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + } + + test("arithmetic expressions") { + checkSQL('a.int + 'b.int, "(`a` + `b`)") + checkSQL('a.int - 'b.int, "(`a` - `b`)") + checkSQL('a.int * 'b.int, "(`a` * `b`)") + checkSQL('a.int / 'b.int, "(`a` / `b`)") + checkSQL('a.int % 'b.int, "(`a` % `b`)") + + checkSQL(-'a.int, "(-`a`)") + checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala new file mode 100644 index 0000000000000..0e81acf532a03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.functions._ + +class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sqlContext.range(10).write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + } + + override protected def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + } + + private def checkHiveQl(hiveQl: String): Unit = { + val df = sql(hiveQl) + val convertedSQL = new SQLBuilder(df).toSQL + + if (convertedSQL.isEmpty) { + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + val sqlString = convertedSQL.get + try { + checkAnswer(sql(sqlString), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$sqlString + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } + } + + test("in") { + checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)") + } + + test("aggregate function in having clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0") + } + + test("aggregate function in order by clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") + } + + // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule + // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into + // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query + // execution since these aliases have different expression ID. But this introduces name collision + // when converting resolved plans back to SQL query strings as expression IDs are stripped. + ignore("aggregate function in order by clause with multiple order keys") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") + } + + test("type widening in union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") + } + + test("case") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") + } + + test("case with else") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0") + } + + test("case with key") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0") + } + + test("case with key and else") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0") + } + + test("select distinct without aggregate functions") { + checkHiveQl("SELECT DISTINCT id FROM t0") + } + + test("cluster by") { + checkHiveQl("SELECT id FROM t0 CLUSTER BY id") + } + + test("distribute by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id") + } + + test("distribute by with sort by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id") + } + + test("distinct aggregation") { + checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0") + } + + // TODO Enable this + // Query plans transformed by DistinctAggregationRewriter are not recognized yet + ignore("distinct and non-distinct aggregation") { + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala new file mode 100644 index 0000000000000..cf4a3fdd88806 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{DataFrame, QueryTest} + +abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + + protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { + val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL + + if (maybeSQL.isEmpty) { + fail( + s"""Cannot convert the following logical query plan to SQL: + | + |${plan.treeString} + """.stripMargin) + } + + val actualSQL = maybeSQL.get + + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following logical query plan: + | + |${plan.treeString} + | + |$cause + """.stripMargin) + } + + checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + } + + protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { + checkSQL(df.queryExecution.analyzed, expectedSQL) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index d7e8ebc8d312f..57358a07840e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,9 +27,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} /** * Allows the creations of tests that execute the same query against both hive @@ -130,6 +131,28 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } + /** Used for testing [[SQLBuilder]] */ + private var numConvertibleQueries: Int = 0 + private var numTotalQueries: Int = 0 + + override protected def afterAll(): Unit = { + logInfo({ + val percentage = if (numTotalQueries > 0) { + numConvertibleQueries.toDouble / numTotalQueries * 100 + } else { + 0D + } + + s"""SQLBuiler statistics: + |- Total query number: $numTotalQueries + |- Number of convertible queries: $numConvertibleQueries + |- Percentage of convertible queries: $percentage% + """.stripMargin + }) + + super.afterAll() + } + protected def prepareAnswer( hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { @@ -372,8 +395,49 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) - try { (query, prepareAnswer(query, query.stringResult())) } catch { + var query: TestHive.QueryExecution = null + try { + query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + numTotalQueries += 1 + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + numConvertibleQueries += 1 + logInfo( + s""" + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} + """.stripMargin.trim) + new TestHive.QueryExecution(sql) + }.getOrElse { + logInfo( + s""" + |### Cannot convert the following logical plan back to SQL {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + |}}} + """.stripMargin.trim) + originalQuery + } + } + } + + (query, prepareAnswer(query, query.stringResult())) + } catch { case e: Throwable => val errorMessage = s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fa99289b41971..4659d745fe78b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION udtf_count2") + super.afterAll() } test("SPARK-4908: concurrent hive native commands") { From 1fdf9bbd67b884f23150b651f0fefdab6ccf008a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 Jan 2016 20:50:08 -0800 Subject: [PATCH 1504/2089] [SPARK-12730][TESTS] De-duplicate some test code in BlockManagerSuite This patch deduplicates some test code in BlockManagerSuite. I'm splitting this change off from a larger PR in order to make things easier to review. Author: Josh Rosen Closes #10667 from JoshRosen/block-mgr-tests-cleanup. --- .../spark/storage/BlockManagerSuite.scala | 88 ++++++------------- 1 file changed, 25 insertions(+), 63 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 21db3b1c9ffbd..67210e5d4c50e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -505,38 +505,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } test("in-memory LRU storage") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") + testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY) } test("in-memory LRU storage with serialization") { + testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY_SER) + } + + private def testInMemoryLRUStorage(storageLevel: StorageLevel): Unit = { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a1", a1, storageLevel) + store.putSingle("a2", a2, storageLevel) + store.putSingle("a3", a3, storageLevel) assert(store.getSingle("a2").isDefined, "a2 was not in store") assert(store.getSingle("a3").isDefined, "a3 was not in store") assert(store.getSingle("a1") === None, "a1 was in store") assert(store.getSingle("a2").isDefined, "a2 was not in store") // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a1", a1, storageLevel) assert(store.getSingle("a1").isDefined, "a1 was not in store") assert(store.getSingle("a2").isDefined, "a2 was not in store") assert(store.getSingle("a3") === None, "a3 was in store") @@ -618,62 +607,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } test("disk and memory storage") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingle) } test("disk and memory storage with getLocalBytes") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getLocalBytes("a2").isDefined, "a2 was not in store") - assert(store.getLocalBytes("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytes) } test("disk and memory storage with serialization") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingle) } test("disk and memory storage with serialization and getLocalBytes") { + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytes) + } + + def testDiskAndMemoryStorage( + storageLevel: StorageLevel, + accessMethod: BlockManager => BlockId => Option[_]): Unit = { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getLocalBytes("a2").isDefined, "a2 was not in store") - assert(store.getLocalBytes("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1").isDefined, "a1 was not in store") + store.putSingle("a1", a1, storageLevel) + store.putSingle("a2", a2, storageLevel) + store.putSingle("a3", a3, storageLevel) + assert(accessMethod(store)("a2").isDefined, "a2 was not in store") + assert(accessMethod(store)("a3").isDefined, "a3 was not in store") + assert(store.memoryStore.getValues("a1").isEmpty, "a1 was in memory store") + assert(accessMethod(store)("a1").isDefined, "a1 was not in store") assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") } From 090d691323063c436601943506baac3ec5255dd9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 Jan 2016 20:58:53 -0800 Subject: [PATCH 1505/2089] [SPARK-4628][BUILD] Remove all non-Maven-Central repositories from build This patch removes all non-Maven-central repositories from Spark's build, thereby avoiding any risk of future build-breaks due to us accidentally depending on an artifact which is not present in an immutable public Maven repository. I tested this by running ``` build/mvn \ -Phive \ -Phive-thriftserver \ -Pkinesis-asl \ -Pspark-ganglia-lgpl \ -Pyarn \ dependency:go-offline ``` inside of a fresh Ubuntu Docker container with no Ivy or Maven caches (I did a similar test for SBT). Author: Josh Rosen Closes #10659 from JoshRosen/SPARK-4628. --- external/mqtt/pom.xml | 2 +- pom.xml | 87 ---------------------------------------- project/SparkBuild.scala | 7 +++- project/plugins.sbt | 6 --- 4 files changed, 7 insertions(+), 95 deletions(-) diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index b3ba72a0087ad..d3a2bf5825b08 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -51,7 +51,7 @@ org.eclipse.paho org.eclipse.paho.client.mqttv3 - 1.0.1 + 1.0.2 org.scalacheck diff --git a/pom.xml b/pom.xml index 9c975a45f8d23..0eac212754320 100644 --- a/pom.xml +++ b/pom.xml @@ -226,93 +226,6 @@ false - - apache-repo - Apache Repository - https://repository.apache.org/content/repositories/releases - - true - - - false - - - - jboss-repo - JBoss Repository - https://repository.jboss.org/nexus/content/repositories/releases - - true - - - false - - - - mqtt-repo - MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases - - true - - - false - - - - cloudera-repo - Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos - - true - - - false - - - - spark-hive-staging - Staging Repo for Hive 1.2.1 (Spark Version) - https://oss.sonatype.org/content/repositories/orgspark-project-1113 - - true - - - - mapr-repo - MapR Repository - http://repository.mapr.com/maven/ - - true - - - false - - - - - spring-releases - Spring Release Repository - https://repo.spring.io/libs-release - - false - - - false - - - - - twttr-repo - Twttr Repository - http://maven.twttr.com - - true - - - false - - diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5d4f19ab14a29..4c34c888cfd5e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -141,7 +141,12 @@ object SparkBuild extends PomBuild { publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", - resolvers += Resolver.mavenLocal, + // Override SBT's default resolvers: + resolvers := Seq( + DefaultMavenRepository, + Resolver.mavenLocal + ), + externalResolvers := resolvers.value, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) diff --git a/project/plugins.sbt b/project/plugins.sbt index 15ba3a36d51ca..822a7c4a82d5e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,9 +1,3 @@ -resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) - -resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" - -resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" - addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") From 95cd5d95ce8aec8b2462204c791ba927326305ba Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 Jan 2016 21:48:06 -0800 Subject: [PATCH 1506/2089] [SPARK-12577] [SQL] Better support of parentheses in partition by and order by clause of window function's over clause JIRA: https://issues.apache.org/jira/browse/SPARK-12577 Author: Liang-Chi Hsieh Closes #10620 from viirya/fix-parentheses. --- .../sql/catalyst/parser/ExpressionParser.g | 7 +++- .../spark/sql/catalyst/CatalystQlSuite.scala | 36 +++++++++++++------ 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index cad770122d150..aabb5d49582c8 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -223,7 +223,12 @@ precedenceUnaryPrefixExpression ; precedenceUnarySuffixExpression - : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + : + ( + (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN + | + precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + ) -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) -> precedenceUnaryPrefixExpression ; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 30978d9b49e2b..d7204c3488313 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -20,17 +20,33 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.plans.PlanTest class CatalystQlSuite extends PlanTest { + val parser = new CatalystQl() test("parse union/except/intersect") { - val paresr = new CatalystQl() - paresr.createPlan("select * from t1 union all select * from t2") - paresr.createPlan("select * from t1 union distinct select * from t2") - paresr.createPlan("select * from t1 union select * from t2") - paresr.createPlan("select * from t1 except select * from t2") - paresr.createPlan("select * from t1 intersect select * from t2") - paresr.createPlan("(select * from t1) union all (select * from t2)") - paresr.createPlan("(select * from t1) union distinct (select * from t2)") - paresr.createPlan("(select * from t1) union (select * from t2)") - paresr.createPlan("select * from ((select * from t1) union (select * from t2)) t") + parser.createPlan("select * from t1 union all select * from t2") + parser.createPlan("select * from t1 union distinct select * from t2") + parser.createPlan("select * from t1 union select * from t2") + parser.createPlan("select * from t1 except select * from t2") + parser.createPlan("select * from t1 intersect select * from t2") + parser.createPlan("(select * from t1) union all (select * from t2)") + parser.createPlan("(select * from t1) union distinct (select * from t2)") + parser.createPlan("(select * from t1) union (select * from t2)") + parser.createPlan("select * from ((select * from t1) union (select * from t2)) t") + } + + test("window function: better support of parentheses") { + parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + + "order by 2) from windowData") + parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + + "order by 2) from windowData") + parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + + "order by 2) from windowData") + + parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + + "from windowData") + parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + + "from windowData") + parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + + "from windowData") } } From 3d77cffec093bed4d330969f1a996f3358b9a772 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 9 Jan 2016 12:29:51 +0530 Subject: [PATCH 1507/2089] [SPARK-12645][SPARKR] SparkR support hash function Add ```hash``` function for SparkR ```DataFrame```. Author: Yanbo Liang Closes #10597 from yanboliang/spark-12645. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 20 ++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index beacc39500aaa..34be7f0ebd752 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -130,6 +130,7 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "hash", "cume_dist", "date_add", "date_format", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index df36bc869acb4..9bb7876b384ce 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -340,6 +340,26 @@ setMethod("crc32", column(jc) }) +#' hash +#' +#' Calculates the hash code of given columns, and returns the result as a int column. +#' +#' @rdname hash +#' @name hash +#' @family misc_funcs +#' @export +#' @examples \dontrun{hash(df$c)} +setMethod("hash", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "hash", jcols) + column(jc) + }) + #' dayofmonth #' #' Extracts the day of the month as an integer from a given date/timestamp/string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ba6861709754d..5ba68e3a4f378 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -736,6 +736,10 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname hash +#' @export +setGeneric("hash", function(x, ...) { standardGeneric("hash") }) + #' @rdname cume_dist #' @export setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index eaf60beda3473..97625b94a0e23 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -922,7 +922,7 @@ test_that("column functions", { c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) - c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) From b23c4521f5df905e4fe4d79dd5b670286e2697f7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Jan 2016 11:21:58 -0800 Subject: [PATCH 1508/2089] [SPARK-12340] Fix overflow in various take functions. This is a follow-up for the original patch #10562. Author: Reynold Xin Closes #10670 from rxin/SPARK-12340. --- .../scala/org/apache/spark/rdd/AsyncRDDActions.scala | 8 ++++---- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 4 ++++ .../org/apache/spark/sql/execution/SparkPlan.scala | 7 +++---- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ------------ 6 files changed, 19 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 94719a4572ef6..7de9df1e489fb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi This implementation is non-blocking, asynchronously handling the results of each job and triggering the next job using callbacks on futures. */ - def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] = if (results.size >= num || partsScanned >= totalParts) { Future.successful(results.toSeq) } else { @@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } val left = num - results.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val buf = new Array[Array[T]](p.size) self.context.setCallSite(callSite) @@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - job.flatMap {_ => + job.flatMap { _ => buf.foreach(results ++= _.take(num - results.size)) continue(partsScanned + p.size) } } - new ComplexFutureAction[Seq[T]](continue(0L)(_)) + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e25657cc109be..de7102f5b6245 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag]( } else { val buf = new ArrayBuffer[T] val totalParts = this.partitions.length - var partsScanned = 0L + var partsScanned = 0 while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. @@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag]( } val left = num - buf.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 24acbed4d7258..ef2ed445005d3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(nums.take(501) === (1 to 501).toArray) assert(nums.take(999) === (1 to 999).toArray) assert(nums.take(1000) === (1 to 999).toArray) + + nums = sc.parallelize(1 to 2, 2) + assert(nums.take(2147483638).size === 2) + assert(nums.takeAsync(2147483638).get.size === 2) } test("top with predefined ordering") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21a6fba9078df..2355de3d05865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -165,7 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length - var partsScanned = 0L + var partsScanned = 0 while (buf.size < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. @@ -183,10 +183,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = n - buf.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) + val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(n - buf.size)) partsScanned += p.size diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ade1391ecd74a..983dfbdedeefe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -308,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( mapData.toDF().limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) + + // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake + checkAnswer( + sqlContext.range(2).limit(2147483638), + Row(0) :: Row(1) :: Nil + ) } test("except") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bd987ae1bb03a..5de0979606b88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2067,16 +2067,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } - - test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") { - val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 ) - rdd.toDF("key").registerTempTable("spark12340") - checkAnswer( - sql("select key from spark12340 limit 2147483638"), - Row(1) :: Row(2) :: Row(3) :: Nil - ) - assert(rdd.take(2147483638).size === 3) - assert(rdd.takeAsync(2147483638).get.size === 3) - } - } From 3efd106e5cc1312bfba693a694ed33a3609a6741 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Jan 2016 20:25:28 -0800 Subject: [PATCH 1509/2089] Close #10665 From 5b0d544339ef02fc25c816b6d6841031ef3902c2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Jan 2016 20:28:20 -0800 Subject: [PATCH 1510/2089] [SPARK-12735] Consolidate & move spark-ec2 to AMPLab managed repository. Author: Reynold Xin Closes #10673 from rxin/SPARK-12735. --- .gitignore | 1 - dev/create-release/release-tag.sh | 3 - dev/create-release/releaseutils.py | 1 - dev/lint-python | 2 +- dev/sparktestsupport/modules.py | 9 - docs/_layouts/global.html | 2 - docs/cluster-overview.md | 2 - docs/ec2-scripts.md | 192 --- docs/index.md | 5 +- ec2/README | 4 - .../root/spark-ec2/ec2-variables.sh | 34 - ec2/spark-ec2 | 25 - ec2/spark_ec2.py | 1530 ----------------- make-distribution.sh | 1 - 14 files changed, 3 insertions(+), 1808 deletions(-) delete mode 100644 docs/ec2-scripts.md delete mode 100644 ec2/README delete mode 100644 ec2/deploy.generic/root/spark-ec2/ec2-variables.sh delete mode 100755 ec2/spark-ec2 delete mode 100755 ec2/spark_ec2.py diff --git a/.gitignore b/.gitignore index 07524bc429e92..8ecf536e79a5f 100644 --- a/.gitignore +++ b/.gitignore @@ -60,7 +60,6 @@ dev/create-release/*final spark-*-bin-*.tgz unit-tests.log /lib/ -ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b0a3374becc6a..d404939d1caee 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -64,9 +64,6 @@ git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG -# TODO: It would be nice to do some verifications here -# i.e. check whether ec2 scripts have the new version - # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 7f152b7f53559..5d0ac16b3b0a1 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -159,7 +159,6 @@ def get_commits(tag): "build": CORE_COMPONENT, "deploy": CORE_COMPONENT, "documentation": CORE_COMPONENT, - "ec2": "EC2", "examples": CORE_COMPONENT, "graphx": "GraphX", "input/output": CORE_COMPONENT, diff --git a/dev/lint-python b/dev/lint-python index 0b97213ae3dff..1765a07d2f22b 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 47cd600bd18a4..1fc6596164124 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -406,15 +406,6 @@ def contains_file(self, filename): should_run_build_tests=True ) -ec2 = Module( - name="ec2", - dependencies=[], - source_file_regexes=[ - "ec2/", - ] -) - - yarn = Module( name="yarn", dependencies=[], diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 62d75eff71057..d493f62f0e578 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -98,8 +98,6 @@
  • Spark Standalone
  • Mesos
  • YARN
  • -
  • -
  • Amazon EC2
  • diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index faaf154d243f5..2810112f5294e 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -53,8 +53,6 @@ The system currently supports three cluster managers: and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone -cluster on Amazon EC2. # Submitting Applications diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md deleted file mode 100644 index 7f60f82b966fe..0000000000000 --- a/docs/ec2-scripts.md +++ /dev/null @@ -1,192 +0,0 @@ ---- -layout: global -title: Running Spark on EC2 ---- - -The `spark-ec2` script, located in Spark's `ec2` directory, allows you -to launch, manage and shut down Spark clusters on Amazon EC2. It automatically -sets up Spark and HDFS on the cluster for you. This guide describes -how to use `spark-ec2` to launch clusters, how to run jobs on them, and how -to shut them down. It assumes you've already signed up for an EC2 account -on the [Amazon Web Services site](http://aws.amazon.com/). - -`spark-ec2` is designed to manage multiple named clusters. You can -launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named -`test` will contain a master node in a security group called -`test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. - - -# Before You Start - -- Create an Amazon EC2 key pair for yourself. This can be done by - logging into your Amazon Web Services account through the [AWS - console](http://aws.amazon.com/console/), clicking Key Pairs on the - left sidebar, and creating and downloading a key. Make sure that you - set the permissions for the private key file to `600` (i.e. only you - can read and write it) so that `ssh` will work. -- Whenever you want to use the `spark-ec2` script, set the environment - variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to your - Amazon EC2 access key ID and secret access key. These can be - obtained from the [AWS homepage](http://aws.amazon.com/) by clicking - Account \> Security Credentials \> Access Credentials. - -# Launching a Cluster - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run - `./spark-ec2 -k -i -s launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster - ``` - -- After everything launches, check that the cluster scheduler is up and sees - all the slaves by going to its web UI, which will be printed at the end of - the script (typically `http://:8080`). - -You can also run `./spark-ec2 --help` to see more usage options. The -following options are worth pointing out: - -- `--instance-type=` can be used to specify an EC2 -instance type to use. For now, the script only supports 64-bit instance -types, and the default type is `m1.large` (which has 2 cores and 7.5 GB -RAM). Refer to the Amazon pages about [EC2 instance -types](http://aws.amazon.com/ec2/instance-types) and [EC2 -pricing](http://aws.amazon.com/ec2/#pricing) for information about other -instance types. -- `--region=` specifies an EC2 region in which to launch -instances. The default region is `us-east-1`. -- `--zone=` can be used to specify an EC2 availability zone -to launch instances in. Sometimes, you will get an error because there -is not enough capacity in one zone, and you should try to launch in -another. -- `--ebs-vol-size=` will attach an EBS volume with a given amount - of space to each node so that you can have a persistent HDFS cluster - on your nodes across cluster restarts (see below). -- `--spot-price=` will launch the worker nodes as - [Spot Instances](http://aws.amazon.com/ec2/spot-instances/), - bidding for the given maximum price (in dollars). -- `--spark-version=` will pre-load the cluster with the - specified version of Spark. The `` can be a version number - (e.g. "0.7.3") or a specific git hash. By default, a recent - version will be used. -- `--spark-git-repo=` will let you run a custom version of - Spark that is built from the given git repository. By default, the - [Apache Github mirror](https://github.com/apache/spark) will be used. - When using a custom Spark version, `--spark-version` must be set to git - commit hash, such as 317e114, instead of a version number. -- If one of your launches fails due to e.g. not having the right -permissions on your private key file, you can run `launch` with the -`--resume` option to restart the setup process on an existing cluster. - -# Launching a Cluster in a VPC - -- Run - `./spark-ec2 -k -i -s --vpc-id= --subnet-id= launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), `` is the name of your VPC, `` is the - name of your subnet, and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --vpc-id=vpc-a28d24c7 --subnet-id=subnet-4eb27b39 --spark-version=1.1.0 launch my-spark-cluster - ``` - -# Running Applications - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 -k -i login ` to - SSH into the cluster, where `` and `` are as - above. (This is just for convenience; you could also use - the EC2 console.) -- To deploy code or data within your cluster, you can log in and use the - provided script `~/spark-ec2/copy-dir`, which, - given a directory path, RSYNCs it to the same location on all the slaves. -- If your application needs to access large datasets, the fastest way to do - that is to load them from Amazon S3 or an Amazon EBS device into an - instance of the Hadoop Distributed File System (HDFS) on your nodes. - The `spark-ec2` script already sets up a HDFS instance for you. It's - installed in `/root/ephemeral-hdfs`, and can be accessed using the - `bin/hadoop` script in that directory. Note that the data in this - HDFS goes away when you stop and restart a machine. -- There is also a *persistent HDFS* instance in - `/root/persistent-hdfs` that will keep data across cluster restarts. - Typically each node has relatively little space of persistent data - (about 3 GB), but you can use the `--ebs-vol-size` option to - `spark-ec2` to attach a persistent EBS volume to each node for - storing the persistent HDFS. -- Finally, if you get errors while running your application, look at the slave's logs - for that application inside of the scheduler work directory (/root/spark/work). You can - also view the status of the cluster using the web UI: `http://:8080`. - -# Configuration - -You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such -as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to -do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master, -then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers. - -The [configuration guide](configuration.html) describes the available configuration options. - -# Terminating a Cluster - -***Note that there is no way to recover data on EC2 nodes after shutting -them down! Make sure you have copied everything important off the nodes -before stopping them.*** - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 destroy `. - -# Pausing and Restarting Clusters - -The `spark-ec2` script also supports pausing a cluster. In this case, -the VMs are stopped but not terminated, so they -***lose all data on ephemeral disks*** but keep the data in their -root partitions and their `persistent-hdfs`. Stopped machines will not -cost you any EC2 cycles, but ***will*** continue to cost money for EBS -storage. - -- To stop one of your clusters, go into the `ec2` directory and run -`./spark-ec2 --region= stop `. -- To restart it later, run -`./spark-ec2 -i --region= start `. -- To ultimately destroy the cluster and stop consuming EBS space, run -`./spark-ec2 --region= destroy ` as described in the previous -section. - -# Limitations - -- Support for "cluster compute" nodes is limited -- there's no way to specify a - locality group. However, you can launch slave nodes in your - `-slaves` group manually and then use `spark-ec2 launch - --resume` to start a cluster with them. - -If you have a patch or suggestion for one of these limitations, feel free to -[contribute](contributing-to-spark.html) it! - -# Accessing Data in S3 - -Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n:///path`. To provide AWS credentials for S3 access, launch the Spark cluster with the option `--copy-aws-credentials`. Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3). - -In addition to using a single input file, you can also use a directory of files as input by simply giving the path to the directory. diff --git a/docs/index.md b/docs/index.md index ae26f97c86c21..9dfc52a2bdc9b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,7 +64,7 @@ To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] Example applications are also provided in R. For example, - + ./bin/spark-submit examples/src/main/r/dataframe.R # Launching on a Cluster @@ -73,7 +73,6 @@ The Spark [cluster mode overview](cluster-overview.html) explains the key concep Spark can run both by itself, or over several existing cluster managers. It currently provides several options for deployment: -* [Amazon EC2](ec2-scripts.html): our EC2 scripts let you launch a cluster in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): simplest way to deploy Spark on a private cluster * [Apache Mesos](running-on-mesos.html) * [Hadoop YARN](running-on-yarn.html) @@ -103,7 +102,7 @@ options for deployment: * [Cluster Overview](cluster-overview.html): overview of concepts and components when running on a cluster * [Submitting Applications](submitting-applications.html): packaging and deploying applications * Deployment modes: - * [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes + * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using [Apache Mesos](http://mesos.apache.org) diff --git a/ec2/README b/ec2/README deleted file mode 100644 index 72434f24bf98d..0000000000000 --- a/ec2/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder contains a script, spark-ec2, for launching Spark clusters on -Amazon EC2. Usage instructions are available online at: - -http://spark.apache.org/docs/latest/ec2-scripts.html diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh deleted file mode 100644 index 4f3e8da809f7f..0000000000000 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# These variables are automatically filled in by the spark-ec2 script. -export MASTERS="{{master_list}}" -export SLAVES="{{slave_list}}" -export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" -export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" -export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" -export MODULES="{{modules}}" -export SPARK_VERSION="{{spark_version}}" -export TACHYON_VERSION="{{tachyon_version}}" -export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" -export SWAP_MB="{{swap}}" -export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" -export SPARK_MASTER_OPTS="{{spark_master_opts}}" -export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" -export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 deleted file mode 100755 index 26e7d22655694..0000000000000 --- a/ec2/spark-ec2 +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Preserve the user's CWD so that relative paths are passed correctly to -#+ the underlying Python script. -SPARK_EC2_DIR="$(dirname "$0")" - -python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py deleted file mode 100755 index 19d5980560fef..0000000000000 --- a/ec2/spark_ec2.py +++ /dev/null @@ -1,1530 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import division, print_function, with_statement - -import codecs -import hashlib -import itertools -import logging -import os -import os.path -import pipes -import random -import shutil -import string -from stat import S_IRUSR -import subprocess -import sys -import tarfile -import tempfile -import textwrap -import time -import warnings -from datetime import datetime -from optparse import OptionParser -from sys import stderr - -if sys.version < "3": - from urllib2 import urlopen, Request, HTTPError -else: - from urllib.request import urlopen, Request - from urllib.error import HTTPError - raw_input = input - xrange = range - -SPARK_EC2_VERSION = "1.6.0" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) - -VALID_SPARK_VERSIONS = set([ - "0.7.3", - "0.8.0", - "0.8.1", - "0.9.0", - "0.9.1", - "0.9.2", - "1.0.0", - "1.0.1", - "1.0.2", - "1.1.0", - "1.1.1", - "1.2.0", - "1.2.1", - "1.3.0", - "1.3.1", - "1.4.0", - "1.4.1", - "1.5.0", - "1.5.1", - "1.5.2", - "1.6.0", -]) - -SPARK_TACHYON_MAP = { - "1.0.0": "0.4.1", - "1.0.1": "0.4.1", - "1.0.2": "0.4.1", - "1.1.0": "0.5.0", - "1.1.1": "0.5.0", - "1.2.0": "0.5.0", - "1.2.1": "0.5.0", - "1.3.0": "0.5.0", - "1.3.1": "0.5.0", - "1.4.0": "0.6.4", - "1.4.1": "0.6.4", - "1.5.0": "0.7.1", - "1.5.1": "0.7.1", - "1.5.2": "0.7.1", - "1.6.0": "0.8.2", -} - -DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION -DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" - -# Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" - - -def setup_external_libs(libs): - """ - Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. - """ - PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" - SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") - - if not os.path.exists(SPARK_EC2_LIB_DIR): - print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( - path=SPARK_EC2_LIB_DIR - )) - print("This should be a one-time operation.") - os.mkdir(SPARK_EC2_LIB_DIR) - - for lib in libs: - versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) - lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) - - if not os.path.isdir(lib_dir): - tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") - print(" - Downloading {lib}...".format(lib=lib["name"])) - download_stream = urlopen( - "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( - prefix=PYPI_URL_PREFIX, - first_letter=lib["name"][:1], - lib_name=lib["name"], - lib_version=lib["version"] - ) - ) - with open(tgz_file_path, "wb") as tgz_file: - tgz_file.write(download_stream.read()) - with open(tgz_file_path, "rb") as tar: - if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: - print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) - sys.exit(1) - tar = tarfile.open(tgz_file_path) - tar.extractall(path=SPARK_EC2_LIB_DIR) - tar.close() - os.remove(tgz_file_path) - print(" - Finished downloading {lib}.".format(lib=lib["name"])) - sys.path.insert(1, lib_dir) - - -# Only PyPI libraries are supported. -external_libs = [ - { - "name": "boto", - "version": "2.34.0", - "md5": "5556223d2d0cc4d06dd4829e671dcecd" - } -] - -setup_external_libs(external_libs) - -import boto -from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType -from boto import ec2 - - -class UsageError(Exception): - pass - - -# Configure and parse our command-line arguments -def parse_args(): - parser = OptionParser( - prog="spark-ec2", - version="%prog {v}".format(v=SPARK_EC2_VERSION), - usage="%prog [options] \n\n" - + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") - - parser.add_option( - "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: %default)") - parser.add_option( - "-w", "--wait", type="int", - help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") - parser.add_option( - "-k", "--key-pair", - help="Key pair to use on instances") - parser.add_option( - "-i", "--identity-file", - help="SSH private key file to use for logging into instances") - parser.add_option( - "-p", "--profile", default=None, - help="If you have multiple profiles (AWS or boto config), you can configure " + - "additional, named profiles by using this option (default: %default)") - parser.add_option( - "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: %default). " + - "WARNING: must be 64-bit; small instances won't work") - parser.add_option( - "-m", "--master-instance-type", default="", - help="Master instance type (leave empty for same as instance-type)") - parser.add_option( - "-r", "--region", default="us-east-1", - help="EC2 region used to launch instances in, or to find them in (default: %default)") - parser.add_option( - "-z", "--zone", default="", - help="Availability zone to launch instances in, or 'all' to spread " + - "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies) (default: a single zone chosen at random)") - parser.add_option( - "-a", "--ami", - help="Amazon Machine Image ID to use") - parser.add_option( - "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, - help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") - parser.add_option( - "--spark-git-repo", - default=DEFAULT_SPARK_GITHUB_REPO, - help="Github repo from which to checkout supplied commit hash (default: %default)") - parser.add_option( - "--spark-ec2-git-repo", - default=DEFAULT_SPARK_EC2_GITHUB_REPO, - help="Github repo from which to checkout spark-ec2 (default: %default)") - parser.add_option( - "--spark-ec2-git-branch", - default=DEFAULT_SPARK_EC2_BRANCH, - help="Github repo branch of spark-ec2 to use (default: %default)") - parser.add_option( - "--deploy-root-dir", - default=None, - help="A directory to copy into / on the first master. " + - "Must be absolute. Note that a trailing slash is handled as per rsync: " + - "If you omit it, the last directory of the --deploy-root-dir path will be created " + - "in / before copying its contents. If you append the trailing slash, " + - "the directory is not created and its contents are copied directly into /. " + - "(default: %default).") - parser.add_option( - "--hadoop-major-version", default="1", - help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + - "(Hadoop 2.4.0) (default: %default)") - parser.add_option( - "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", - help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + - "the given local address (for use with login)") - parser.add_option( - "--resume", action="store_true", default=False, - help="Resume installation on a previously launched cluster " + - "(for debugging)") - parser.add_option( - "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Size (in GB) of each EBS volume.") - parser.add_option( - "--ebs-vol-type", default="standard", - help="EBS volume type (e.g. 'gp2', 'standard').") - parser.add_option( - "--ebs-vol-num", type="int", default=1, - help="Number of EBS volumes to attach to each node as /vol[x]. " + - "The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0. " + - "Only support up to 8 EBS volumes.") - parser.add_option( - "--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") - parser.add_option( - "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: %default)") - parser.add_option( - "--spot-price", metavar="PRICE", type="float", - help="If specified, launch slaves as spot instances with the given " + - "maximum price (in dollars)") - parser.add_option( - "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + - "the Ganglia page will be publicly accessible") - parser.add_option( - "--no-ganglia", action="store_false", dest="ganglia", - help="Disable Ganglia monitoring for the cluster") - parser.add_option( - "-u", "--user", default="root", - help="The SSH user you want to connect as (default: %default)") - parser.add_option( - "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") - parser.add_option( - "--use-existing-master", action="store_true", default=False, - help="Launch fresh slaves, but use an existing stopped master if possible") - parser.add_option( - "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + - "is used as Hadoop major version (default: %default)") - parser.add_option( - "--master-opts", type="string", default="", - help="Extra options to give to master through SPARK_MASTER_OPTS variable " + - "(e.g -Dspark.worker.timeout=180)") - parser.add_option( - "--user-data", type="string", default="", - help="Path to a user-data file (most AMIs interpret this as an initialization script)") - parser.add_option( - "--authorized-address", type="string", default="0.0.0.0/0", - help="Address to authorize on created security groups (default: %default)") - parser.add_option( - "--additional-security-group", type="string", default="", - help="Additional security group to place the machines in") - parser.add_option( - "--additional-tags", type="string", default="", - help="Additional tags to set on the machines; tags are comma-separated, while name and " + - "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") - parser.add_option( - "--copy-aws-credentials", action="store_true", default=False, - help="Add AWS credentials to hadoop configuration to allow Spark to access S3") - parser.add_option( - "--subnet-id", default=None, - help="VPC subnet to launch instances in") - parser.add_option( - "--vpc-id", default=None, - help="VPC to launch instances in") - parser.add_option( - "--private-ips", action="store_true", default=False, - help="Use private IPs for instances rather than public if VPC/subnet " + - "requires that.") - parser.add_option( - "--instance-initiated-shutdown-behavior", default="stop", - choices=["stop", "terminate"], - help="Whether instances should terminate when shut down or just stop") - parser.add_option( - "--instance-profile-name", default=None, - help="IAM profile name to launch instances under") - - (opts, args) = parser.parse_args() - if len(args) != 2: - parser.print_help() - sys.exit(1) - (action, cluster_name) = args - - # Boto config check - # http://boto.cloudhackers.com/en/latest/boto_config_tut.html - home_dir = os.getenv('HOME') - if home_dir is None or not os.path.isfile(home_dir + '/.boto'): - if not os.path.isfile('/etc/boto.cfg'): - # If there is no boto config, check aws credentials - if not os.path.isfile(home_dir + '/.aws/credentials'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) - return (opts, action, cluster_name) - - -# Get the EC2 security group of the given name, creating it if it doesn't exist -def get_or_make_group(conn, name, vpc_id): - groups = conn.get_all_security_groups() - group = [g for g in groups if g.name == name] - if len(group) > 0: - return group[0] - else: - print("Creating security group " + name) - return conn.create_security_group(name, "Spark EC2 group", vpc_id) - - -def get_validate_spark_version(version, repo): - if "." in version: - version = version.replace("v", "") - if version not in VALID_SPARK_VERSIONS: - print("Don't know about Spark version: {v}".format(v=version), file=stderr) - sys.exit(1) - return version - else: - github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = Request(github_commit_url) - request.get_method = lambda: 'HEAD' - try: - response = urlopen(request) - except HTTPError as e: - print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), - file=stderr) - print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) - sys.exit(1) - return version - - -# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-06-19 -# For easy maintainability, please keep this manually-inputted dictionary sorted by key. -EC2_INSTANCE_TYPES = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c4.large": "hvm", - "c4.xlarge": "hvm", - "c4.2xlarge": "hvm", - "c4.4xlarge": "hvm", - "c4.8xlarge": "hvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "d2.xlarge": "hvm", - "d2.2xlarge": "hvm", - "d2.4xlarge": "hvm", - "d2.8xlarge": "hvm", - "g2.2xlarge": "hvm", - "g2.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.xlarge": "hvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", - "m4.large": "hvm", - "m4.xlarge": "hvm", - "m4.2xlarge": "hvm", - "m4.4xlarge": "hvm", - "m4.10xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "t1.micro": "pvm", - "t2.micro": "hvm", - "t2.small": "hvm", - "t2.medium": "hvm", - "t2.large": "hvm", -} - - -def get_tachyon_version(spark_version): - return SPARK_TACHYON_MAP.get(spark_version, "") - - -# Attempt to resolve an appropriate AMI given the architecture and region of the request. -def get_spark_ami(opts): - if opts.instance_type in EC2_INSTANCE_TYPES: - instance_type = EC2_INSTANCE_TYPES[opts.instance_type] - else: - instance_type = "pvm" - print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) - - # URL prefix from which to fetch AMI information - ami_prefix = "{r}/{b}/ami-list".format( - r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), - b=opts.spark_ec2_git_branch) - - ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) - reader = codecs.getreader("ascii") - try: - ami = reader(urlopen(ami_path)).read().strip() - except: - print("Could not resolve AMI at: " + ami_path, file=stderr) - sys.exit(1) - - print("Spark AMI: " + ami) - return ami - - -# Launch a cluster of the given name, by setting up its security groups, -# and then starting new instances in them. -# Returns a tuple of EC2 reservation objects for the master and slaves -# Fails if there already instances running in the cluster's groups. -def launch_cluster(conn, opts, cluster_name): - if opts.identity_file is None: - print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) - sys.exit(1) - - if opts.key_pair is None: - print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) - sys.exit(1) - - user_data_content = None - if opts.user_data: - with open(opts.user_data) as user_data_file: - user_data_content = user_data_file.read() - - print("Setting up security groups...") - master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) - slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) - authorized_address = opts.authorized_address - if master_group.rules == []: # Group was just now created - if opts.vpc_id is None: - master_group.authorize(src_group=master_group) - master_group.authorize(src_group=slave_group) - else: - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize('tcp', 22, 22, authorized_address) - master_group.authorize('tcp', 8080, 8081, authorized_address) - master_group.authorize('tcp', 18080, 18080, authorized_address) - master_group.authorize('tcp', 19999, 19999, authorized_address) - master_group.authorize('tcp', 50030, 50030, authorized_address) - master_group.authorize('tcp', 50070, 50070, authorized_address) - master_group.authorize('tcp', 60070, 60070, authorized_address) - master_group.authorize('tcp', 4040, 4045, authorized_address) - # Rstudio (GUI for R) needs port 8787 for web access - master_group.authorize('tcp', 8787, 8787, authorized_address) - # HDFS NFS gateway requires 111,2049,4242 for tcp & udp - master_group.authorize('tcp', 111, 111, authorized_address) - master_group.authorize('udp', 111, 111, authorized_address) - master_group.authorize('tcp', 2049, 2049, authorized_address) - master_group.authorize('udp', 2049, 2049, authorized_address) - master_group.authorize('tcp', 4242, 4242, authorized_address) - master_group.authorize('udp', 4242, 4242, authorized_address) - # RM in YARN mode uses 8088 - master_group.authorize('tcp', 8088, 8088, authorized_address) - if opts.ganglia: - master_group.authorize('tcp', 5080, 5080, authorized_address) - if slave_group.rules == []: # Group was just now created - if opts.vpc_id is None: - slave_group.authorize(src_group=master_group) - slave_group.authorize(src_group=slave_group) - else: - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize('tcp', 22, 22, authorized_address) - slave_group.authorize('tcp', 8080, 8081, authorized_address) - slave_group.authorize('tcp', 50060, 50060, authorized_address) - slave_group.authorize('tcp', 50075, 50075, authorized_address) - slave_group.authorize('tcp', 60060, 60060, authorized_address) - slave_group.authorize('tcp', 60075, 60075, authorized_address) - - # Check if instances are already running in our groups - existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if existing_slaves or (existing_masters and not opts.use_existing_master): - print("ERROR: There are already instances running in group %s or %s" % - (master_group.name, slave_group.name), file=stderr) - sys.exit(1) - - # Figure out Spark AMI - if opts.ami is None: - opts.ami = get_spark_ami(opts) - - # we use group ids to work around https://github.com/boto/boto/issues/350 - additional_group_ids = [] - if opts.additional_security_group: - additional_group_ids = [sg.id - for sg in conn.get_all_security_groups() - if opts.additional_security_group in (sg.name, sg.id)] - print("Launching instances...") - - try: - image = conn.get_all_images(image_ids=[opts.ami])[0] - except: - print("Could not find AMI " + opts.ami, file=stderr) - sys.exit(1) - - # Create block device mapping so that we can add EBS volumes if asked to. - # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz - block_map = BlockDeviceMapping() - if opts.ebs_vol_size > 0: - for i in range(opts.ebs_vol_num): - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.volume_type = opts.ebs_vol_type - device.delete_on_termination = True - block_map["/dev/sd" + chr(ord('s') + i)] = device - - # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342). - if opts.instance_type.startswith('m3.'): - for i in range(get_num_disks(opts.instance_type)): - dev = BlockDeviceType() - dev.ephemeral_name = 'ephemeral%d' % i - # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.ascii_letters[i + 1] - block_map[name] = dev - - # Launch slaves - if opts.spot_price is not None: - # Launch spot instances with the requested price - print("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - my_req_ids = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - slave_reqs = conn.request_spot_instances( - price=opts.spot_price, - image_id=opts.ami, - launch_group="launch-group-%s" % cluster_name, - placement=zone, - count=num_slaves_this_zone, - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_profile_name=opts.instance_profile_name) - my_req_ids += [req.id for req in slave_reqs] - i += 1 - - print("Waiting for spot instances to be granted...") - try: - while True: - time.sleep(10) - reqs = conn.get_all_spot_instance_requests() - id_to_req = {} - for r in reqs: - id_to_req[r.id] = r - active_instance_ids = [] - for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - if len(active_instance_ids) == opts.slaves: - print("All %d slaves granted" % opts.slaves) - reservations = conn.get_all_reservations(active_instance_ids) - slave_nodes = [] - for r in reservations: - slave_nodes += r.instances - break - else: - print("%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves)) - except: - print("Canceling spot instance requests") - conn.cancel_spot_instance_requests(my_req_ids) - # Log a warning if any of these requests actually launched instances: - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - running = len(master_nodes) + len(slave_nodes) - if running: - print(("WARNING: %d instances are still running" % running), file=stderr) - sys.exit(0) - else: - # Launch non-spot instances - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - slave_nodes = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - if num_slaves_this_zone > 0: - slave_res = image.run( - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - slave_nodes += slave_res.instances - print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( - s=num_slaves_this_zone, - plural_s=('' if num_slaves_this_zone == 1 else 's'), - z=zone, - r=slave_res.id)) - i += 1 - - # Launch or resume masters - if existing_masters: - print("Starting master...") - for inst in existing_masters: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - master_nodes = existing_masters - else: - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run( - key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - - master_nodes = master_res.instances - print("Launched master in %s, regid = %s" % (zone, master_res.id)) - - # This wait time corresponds to SPARK-4983 - print("Waiting for AWS to propagate instance metadata...") - time.sleep(15) - - # Give the instances descriptive names and set additional tags - additional_tags = {} - if opts.additional_tags.strip(): - additional_tags = dict( - map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') - ) - - for master in master_nodes: - master.add_tags( - dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) - ) - - for slave in slave_nodes: - slave.add_tags( - dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) - ) - - # Return all the instances - return (master_nodes, slave_nodes) - - -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - """ - Get the EC2 instances in an existing cluster if available. - Returns a tuple of lists of EC2 instance objects for the masters and slaves. - """ - print("Searching for existing cluster {c} in region {r}...".format( - c=cluster_name, r=opts.region)) - - def get_instances(group_names): - """ - Get all non-terminated instances that belong to any of the provided security groups. - - EC2 reservation filters and instance states are documented here: - http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options - """ - reservations = conn.get_all_reservations( - filters={"instance.group-name": group_names}) - instances = itertools.chain.from_iterable(r.instances for r in reservations) - return [i for i in instances if i.state not in ["shutting-down", "terminated"]] - - master_instances = get_instances([cluster_name + "-master"]) - slave_instances = get_instances([cluster_name + "-slaves"]) - - if any((master_instances, slave_instances)): - print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( - m=len(master_instances), - plural_m=('' if len(master_instances) == 1 else 's'), - s=len(slave_instances), - plural_s=('' if len(slave_instances) == 1 else 's'))) - - if not master_instances and die_on_error: - print("ERROR: Could not find a master for cluster {c} in region {r}.".format( - c=cluster_name, r=opts.region), file=sys.stderr) - sys.exit(1) - - return (master_instances, slave_instances) - - -# Deploy configuration files and run setup scripts on a newly launched -# or started EC2 cluster. -def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = get_dns_name(master_nodes[0], opts.private_ips) - if deploy_ssh_key: - print("Generating cluster's SSH key on master...") - key_setup = """ - [ -f ~/.ssh/id_rsa ] || - (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) - """ - ssh(master, opts, key_setup) - dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print("Transferring cluster's SSH key to slaves...") - for slave in slave_nodes: - slave_address = get_dns_name(slave, opts.private_ips) - print(slave_address) - ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) - - modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] - - if opts.hadoop_major_version == "1": - modules = list(filter(lambda x: x != "mapreduce", modules)) - - if opts.ganglia: - modules.append('ganglia') - - # Clear SPARK_WORKER_INSTANCES if running on YARN - if opts.hadoop_major_version == "yarn": - opts.worker_instances = "" - - # NOTE: We should clone the repository before running deploy_files to - # prevent ec2-variables.sh from being overwritten - print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) - ssh( - host=master, - opts=opts, - command="rm -rf spark-ec2" - + " && " - + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, - b=opts.spark_ec2_git_branch) - ) - - print("Deploying files to master...") - deploy_files( - conn=conn, - root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", - opts=opts, - master_nodes=master_nodes, - slave_nodes=slave_nodes, - modules=modules - ) - - if opts.deploy_root_dir is not None: - print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) - deploy_user_files( - root_dir=opts.deploy_root_dir, - opts=opts, - master_nodes=master_nodes - ) - - print("Running setup on master...") - setup_spark_cluster(master, opts) - print("Done!") - - -def setup_spark_cluster(master, opts): - ssh(master, opts, "chmod u+x spark-ec2/setup.sh") - ssh(master, opts, "spark-ec2/setup.sh") - print("Spark standalone cluster started at http://%s:8080" % master) - - if opts.ganglia: - print("Ganglia started at http://%s:5080/ganglia" % master) - - -def is_ssh_available(host, opts, print_ssh_output=True): - """ - Check if SSH is available on a host. - """ - s = subprocess.Popen( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order - ) - cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout - - if s.returncode != 0 and print_ssh_output: - # extra leading newline is for spacing in wait_for_cluster_state() - print(textwrap.dedent("""\n - Warning: SSH connection error. (This could be temporary.) - Host: {h} - SSH return code: {r} - SSH output: {o} - """).format( - h=host, - r=s.returncode, - o=cmd_output.strip() - )) - - return s.returncode == 0 - - -def is_cluster_ssh_available(cluster_instances, opts): - """ - Check if SSH is available on all the instances in a cluster. - """ - for i in cluster_instances: - dns_name = get_dns_name(i, opts.private_ips) - if not is_ssh_available(host=dns_name, opts=opts): - return False - else: - return True - - -def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): - """ - Wait for all the instances in the cluster to reach a designated state. - - cluster_instances: a list of boto.ec2.instance.Instance - cluster_state: a string representing the desired state of all the instances in the cluster - value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as - 'running', 'terminated', etc. - (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) - """ - sys.stdout.write( - "Waiting for cluster to enter '{s}' state.".format(s=cluster_state) - ) - sys.stdout.flush() - - start_time = datetime.now() - num_attempts = 0 - - while True: - time.sleep(5 * num_attempts) # seconds - - for i in cluster_instances: - i.update() - - max_batch = 100 - statuses = [] - for j in xrange(0, len(cluster_instances), max_batch): - batch = [i.id for i in cluster_instances[j:j + max_batch]] - statuses.extend(conn.get_all_instance_status(instance_ids=batch)) - - if cluster_state == 'ssh-ready': - if all(i.state == 'running' for i in cluster_instances) and \ - all(s.system_status.status == 'ok' for s in statuses) and \ - all(s.instance_status.status == 'ok' for s in statuses) and \ - is_cluster_ssh_available(cluster_instances, opts): - break - else: - if all(i.state == cluster_state for i in cluster_instances): - break - - num_attempts += 1 - - sys.stdout.write(".") - sys.stdout.flush() - - sys.stdout.write("\n") - - end_time = datetime.now() - print("Cluster is now in '{s}' state. Waited {t} seconds.".format( - s=cluster_state, - t=(end_time - start_time).seconds - )) - - -# Get number of local disks available for a given EC2 instance type. -def get_num_disks(instance_type): - # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-06-19 - # For easy maintainability, please keep this manually-inputted dictionary sorted by key. - disks_by_instance = { - "c1.medium": 1, - "c1.xlarge": 4, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "c4.large": 0, - "c4.xlarge": 0, - "c4.2xlarge": 0, - "c4.4xlarge": 0, - "c4.8xlarge": 0, - "cc1.4xlarge": 2, - "cc2.8xlarge": 4, - "cg1.4xlarge": 2, - "cr1.8xlarge": 2, - "d2.xlarge": 3, - "d2.2xlarge": 6, - "d2.4xlarge": 12, - "d2.8xlarge": 24, - "g2.2xlarge": 1, - "g2.8xlarge": 2, - "hi1.4xlarge": 2, - "hs1.8xlarge": 24, - "i2.xlarge": 1, - "i2.2xlarge": 2, - "i2.4xlarge": 4, - "i2.8xlarge": 8, - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "m4.large": 0, - "m4.xlarge": 0, - "m4.2xlarge": 0, - "m4.4xlarge": 0, - "m4.10xlarge": 0, - "r3.large": 1, - "r3.xlarge": 1, - "r3.2xlarge": 1, - "r3.4xlarge": 1, - "r3.8xlarge": 2, - "t1.micro": 0, - "t2.micro": 0, - "t2.small": 0, - "t2.medium": 0, - "t2.large": 0, - } - if instance_type in disks_by_instance: - return disks_by_instance[instance_type] - else: - print("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type, file=stderr) - return 1 - - -# Deploy the configuration file templates in a given local directory to -# a cluster, filling in any template parameters with information about the -# cluster (e.g. lists of masters and slaves). Files are only deployed to -# the first master instance in the cluster, and we expect the setup -# script to be run on that instance to copy them to other nodes. -# -# root_dir should be an absolute path to the directory with the files we want to deploy. -def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - - num_disks = get_num_disks(opts.instance_type) - hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" - mapred_local_dirs = "/mnt/hadoop/mrlocal" - spark_local_dirs = "/mnt/spark" - if num_disks > 1: - for i in range(2, num_disks + 1): - hdfs_data_dirs += ",/mnt%d/ephemeral-hdfs/data" % i - mapred_local_dirs += ",/mnt%d/hadoop/mrlocal" % i - spark_local_dirs += ",/mnt%d/spark" % i - - cluster_url = "%s:7077" % active_master - - if "." in opts.spark_version: - # Pre-built Spark deploy - spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - tachyon_v = get_tachyon_version(spark_v) - else: - # Spark-only custom deploy - spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) - tachyon_v = "" - print("Deploying Spark via git hash; Tachyon won't be set up") - modules = filter(lambda x: x != "tachyon", modules) - - master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] - slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] - worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" - template_vars = { - "master_list": '\n'.join(master_addresses), - "active_master": active_master, - "slave_list": '\n'.join(slave_addresses), - "cluster_url": cluster_url, - "hdfs_data_dirs": hdfs_data_dirs, - "mapred_local_dirs": mapred_local_dirs, - "spark_local_dirs": spark_local_dirs, - "swap": str(opts.swap), - "modules": '\n'.join(modules), - "spark_version": spark_v, - "tachyon_version": tachyon_v, - "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": worker_instances_str, - "spark_master_opts": opts.master_opts - } - - if opts.copy_aws_credentials: - template_vars["aws_access_key_id"] = conn.aws_access_key_id - template_vars["aws_secret_access_key"] = conn.aws_secret_access_key - else: - template_vars["aws_access_key_id"] = "" - template_vars["aws_secret_access_key"] = "" - - # Create a temp directory in which we will place all the files to be - # deployed after we substitue template parameters in them - tmp_dir = tempfile.mkdtemp() - for path, dirs, files in os.walk(root_dir): - if path.find(".svn") == -1: - dest_dir = os.path.join('/', path[len(root_dir):]) - local_dir = tmp_dir + dest_dir - if not os.path.exists(local_dir): - os.makedirs(local_dir) - for filename in files: - if filename[0] not in '#.~' and filename[-1] != '~': - dest_file = os.path.join(dest_dir, filename) - local_file = tmp_dir + dest_file - with open(os.path.join(path, filename)) as src: - with open(local_file, "w") as dest: - text = src.read() - for key in template_vars: - text = text.replace("{{" + key + "}}", template_vars[key]) - dest.write(text) - dest.close() - # rsync the whole directory over to the master machine - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s/" % tmp_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - # Remove the temp directory we created above - shutil.rmtree(tmp_dir) - - -# Deploy a given local directory to a cluster, WITHOUT parameter substitution. -# Note that unlike deploy_files, this works for binary files. -# Also, it is up to the user to add (or not) the trailing slash in root_dir. -# Files are only deployed to the first master instance in the cluster. -# -# root_dir should be an absolute path. -def deploy_user_files(root_dir, opts, master_nodes): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s" % root_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - - -def stringify_command(parts): - if isinstance(parts, str): - return parts - else: - return ' '.join(map(pipes.quote, parts)) - - -def ssh_args(opts): - parts = ['-o', 'StrictHostKeyChecking=no'] - parts += ['-o', 'UserKnownHostsFile=/dev/null'] - if opts.identity_file is not None: - parts += ['-i', opts.identity_file] - return parts - - -def ssh_command(opts): - return ['ssh'] + ssh_args(opts) - - -# Run a command on a host through ssh, retrying up to five times -# and then throwing an exception if ssh continues to fail. -def ssh(host, opts, command): - tries = 0 - while True: - try: - return subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), - stringify_command(command)]) - except subprocess.CalledProcessError as e: - if tries > 5: - # If this was an ssh failure, provide the user with hints. - if e.returncode == 255: - raise UsageError( - "Failed to SSH to remote host {0}.\n" - "Please check that you have provided the correct --identity-file and " - "--key-pair parameters and try again.".format(host)) - else: - raise e - print("Error executing remote command, retrying after 30 seconds: {0}".format(e), - file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) -def _check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - -def ssh_read(host, opts, command): - return _check_output( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) - - -def ssh_write(host, opts, command, arguments): - tries = 0 - while True: - proc = subprocess.Popen( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], - stdin=subprocess.PIPE) - proc.stdin.write(arguments) - proc.stdin.close() - status = proc.wait() - if status == 0: - break - elif tries > 5: - raise RuntimeError("ssh_write failed with error %s" % proc.returncode) - else: - print("Error {0} while executing remote command, retrying after 30 seconds". - format(status), file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Gets a list of zones to launch instances in -def get_zones(conn, opts): - if opts.zone == 'all': - zones = [z.name for z in conn.get_all_zones()] - else: - zones = [opts.zone] - return zones - - -# Gets the number of items in a partition -def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total // num_partitions - if (total % num_partitions) - current_partitions > 0: - num_slaves_this_zone += 1 - return num_slaves_this_zone - - -# Gets the IP address, taking into account the --private-ips flag -def get_ip_address(instance, private_ips=False): - ip = instance.ip_address if not private_ips else \ - instance.private_ip_address - return ip - - -# Gets the DNS name, taking into account the --private-ips flag -def get_dns_name(instance, private_ips=False): - dns = instance.public_dns_name if not private_ips else \ - instance.private_ip_address - if not dns: - raise UsageError("Failed to determine hostname of {0}.\n" - "Please check that you provided --private-ips if " - "necessary".format(instance)) - return dns - - -def real_main(): - (opts, action, cluster_name) = parse_args() - - # Input parameter validation - get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - - if opts.wait is not None: - # NOTE: DeprecationWarnings are silent in 2.7+ by default. - # To show them, run Python with the -Wdefault switch. - # See: https://docs.python.org/3.5/whatsnew/2.7.html - warnings.warn( - "This option is deprecated and has no effect. " - "spark-ec2 automatically waits as long as necessary for clusters to start up.", - DeprecationWarning - ) - - if opts.identity_file is not None: - if not os.path.exists(opts.identity_file): - print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - file_mode = os.stat(opts.identity_file).st_mode - if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print("ERROR: The identity file must be accessible only by you.", file=stderr) - print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - if opts.instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type), file=stderr) - - if opts.master_instance_type != "": - if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type), file=stderr) - # Since we try instance types even if we can't resolve them, we check if they resolve first - # and, if they do, see if they resolve to the same virtualization type. - if opts.instance_type in EC2_INSTANCE_TYPES and \ - opts.master_instance_type in EC2_INSTANCE_TYPES: - if EC2_INSTANCE_TYPES[opts.instance_type] != \ - EC2_INSTANCE_TYPES[opts.master_instance_type]: - print("Error: spark-ec2 currently does not support having a master and slaves " - "with different AMI virtualization types.", file=stderr) - print("master instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) - print("slave instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) - sys.exit(1) - - if opts.ebs_vol_num > 8: - print("ebs-vol-num cannot be greater than 8", file=stderr) - sys.exit(1) - - # Prevent breaking ami_prefix (/, .git and startswith checks) - # Prevent forks with non spark-ec2 names for now. - if opts.spark_ec2_git_repo.endswith("/") or \ - opts.spark_ec2_git_repo.endswith(".git") or \ - not opts.spark_ec2_git_repo.startswith("https://github.com") or \ - not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " - "Furthermore, we currently only support forks named spark-ec2.", file=stderr) - sys.exit(1) - - if not (opts.deploy_root_dir is None or - (os.path.isabs(opts.deploy_root_dir) and - os.path.isdir(opts.deploy_root_dir) and - os.path.exists(opts.deploy_root_dir))): - print("--deploy-root-dir must be an absolute path to a directory that exists " - "on the local file system", file=stderr) - sys.exit(1) - - try: - if opts.profile is None: - conn = ec2.connect_to_region(opts.region) - else: - conn = ec2.connect_to_region(opts.region, profile_name=opts.profile) - except Exception as e: - print((e), file=stderr) - sys.exit(1) - - # Select an AZ at random if it was not specified. - if opts.zone == "": - opts.zone = random.choice(conn.get_all_zones()).name - - if action == "launch": - if opts.slaves <= 0: - print("ERROR: You have to start at least 1 slave", file=sys.stderr) - sys.exit(1) - if opts.resume: - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - else: - (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - setup_cluster(conn, master_nodes, slave_nodes, opts, True) - - elif action == "destroy": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - - if any(master_nodes + slave_nodes): - print("The following instances will be terminated:") - for inst in master_nodes + slave_nodes: - print("> %s" % get_dns_name(inst, opts.private_ips)) - print("ALL DATA ON ALL NODES WILL BE LOST!!") - - msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) - response = raw_input(msg) - if response == "y": - print("Terminating master...") - for inst in master_nodes: - inst.terminate() - print("Terminating slaves...") - for inst in slave_nodes: - inst.terminate() - - # Delete security groups as well - if opts.delete_groups: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='terminated' - ) - print("Deleting security groups (this will take some time)...") - attempt = 1 - while attempt <= 3: - print("Attempt %d" % attempt) - groups = [g for g in conn.get_all_security_groups() if g.name in group_names] - success = True - # Delete individual rules in all groups before deleting groups to - # remove dependencies between them - for group in groups: - print("Deleting rules in security group " + group.name) - for rule in group.rules: - for grant in rule.grants: - success &= group.revoke(ip_protocol=rule.ip_protocol, - from_port=rule.from_port, - to_port=rule.to_port, - src_group=grant) - - # Sleep for AWS eventual-consistency to catch up, and for instances - # to terminate - time.sleep(30) # Yes, it does have to be this long :-( - for group in groups: - try: - # It is needed to use group_id to make it work with VPC - conn.delete_security_group(group_id=group.id) - print("Deleted security group %s" % group.name) - except boto.exception.EC2ResponseError: - success = False - print("Failed to delete security group %s" % group.name) - - # Unfortunately, group.revoke() returns True even if a rule was not - # deleted, so this needs to be rerun if something fails - if success: - break - - attempt += 1 - - if not success: - print("Failed to delete all security groups after 3 tries.") - print("Try re-running in a few minutes.") - - elif action == "login": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - master = get_dns_name(master_nodes[0], opts.private_ips) - print("Logging into master " + master + "...") - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) - - elif action == "reboot-slaves": - response = raw_input( - "Are you sure you want to reboot the cluster " + - cluster_name + " slaves?\n" + - "Reboot cluster slaves " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Rebooting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - print("Rebooting " + inst.id) - inst.reboot() - - elif action == "get-master": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - print(get_dns_name(master_nodes[0], opts.private_ips)) - - elif action == "stop": - response = raw_input( - "Are you sure you want to stop the cluster " + - cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + - "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + - "AMAZON EBS IF IT IS EBS-BACKED!!\n" + - "All data on spot-instance slaves will be lost.\n" + - "Stop cluster " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Stopping master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.stop() - print("Stopping slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - if inst.spot_instance_request_id: - inst.terminate() - else: - inst.stop() - - elif action == "start": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print("Starting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - print("Starting master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - - # Determine types of running instances - existing_master_type = master_nodes[0].instance_type - existing_slave_type = slave_nodes[0].instance_type - # Setting opts.master_instance_type to the empty string indicates we - # have the same instance type for the master and the slaves - if existing_master_type == existing_slave_type: - existing_master_type = "" - opts.master_instance_type = existing_master_type - opts.instance_type = existing_slave_type - - setup_cluster(conn, master_nodes, slave_nodes, opts, False) - - else: - print("Invalid action: %s" % action, file=stderr) - sys.exit(1) - - -def main(): - try: - real_main() - except UsageError as e: - print("\nError:\n", e, file=stderr) - sys.exit(1) - - -if __name__ == "__main__": - logging.basicConfig() - main() diff --git a/make-distribution.sh b/make-distribution.sh index a38fd8df17206..327659298e4d8 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -212,7 +212,6 @@ cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" -cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Copy SparkR if it exists if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib From b78e028e37193a4e27b012f0b3c8343d850c5674 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sun, 10 Jan 2016 10:36:01 +0000 Subject: [PATCH 1511/2089] =?UTF-8?q?[SPARK-12736][CORE][DEPLOY]=20Standal?= =?UTF-8?q?one=20Master=20cannot=20be=20started=20due=20t=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …o NoClassDefFoundError: org/spark-project/guava/collect/Maps /cc srowen rxin Author: Jacek Laskowski Closes #10674 from jaceklaskowski/SPARK-12736. --- network/common/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/network/common/pom.xml b/network/common/pom.xml index 92ca0046d4f53..eda2b7307088f 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -55,6 +55,7 @@ com.google.guava guava + compile From e5904bb5e7d83b3731b312c40f7904c0511019f5 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 10 Jan 2016 12:38:57 -0800 Subject: [PATCH 1512/2089] [SPARK-12692][BUILD][MLLIB] Scala style: Fix the style violation (Space before "," or ":") Fix the style violation (space before , and :). This PR is a followup for #10643. Author: Kousuke Saruta Closes #10684 from sarutak/SPARK-12692-followup-mllib. --- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/NaiveBayesExample.scala | 2 +- .../spark/examples/mllib/RegressionMetricsExample.scala | 2 +- .../org/apache/spark/ml/classification/OneVsRest.scala | 4 ++-- .../main/scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 2 +- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 4 ++-- .../spark/mllib/clustering/GaussianMixtureModel.scala | 2 +- .../main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 2 +- .../pmml/export/BinaryClassificationPMMLModelExport.scala | 6 +++--- .../spark/mllib/pmml/export/KMeansPMMLModelExport.scala | 4 ++-- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- .../main/scala/org/apache/spark/mllib/tree/model/Node.scala | 2 +- .../org/apache/spark/mllib/util/LinearDataGenerator.scala | 2 +- .../org/apache/spark/mllib/classification/SVMSuite.scala | 2 +- .../org/apache/spark/mllib/stat/StreamingTestSuite.scala | 2 +- 17 files changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 3834ea807acbf..c4336639d7c0b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegression object IsotonicRegressionExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("IsotonicRegressionExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index 8bae1b9d1832d..0187ad603a654 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint object NaiveBayesExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("NaiveBayesExample") val sc = new SparkContext(conf) // $example on$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala index ace16ff1ea225..add634c957b40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.SQLContext object RegressionMetricsExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("RegressionMetricsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 08a51109d6c62..c41a611f1cc60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -113,13 +113,13 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } - val transformedDataset = model.transform(df).select(columns : _*) + val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f9952434d2982..6cc9d025445c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -238,7 +238,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def transform(dataset: DataFrame): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) - dataset.select(columnsToKeep.map(dataset.col) : _*) + dataset.select(columnsToKeep.map(dataset.col): _*) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0b215659b3672..716bc63e00995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -102,7 +102,7 @@ class VectorAssembler(override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) + dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 6e87302c7779b..d3376a7dff938 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -474,7 +474,7 @@ private[ml] object RandomForest extends Logging { val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator @@ -825,7 +825,7 @@ private[ml] object RandomForest extends Logging { protected[tree] def findSplits( input: RDD[LabeledPoint], metadata: DecisionTreeMetadata, - seed : Long): Array[Array[Split]] = { + seed: Long): Array[Array[Split]] = { logDebug("isMulticlass = " + metadata.isMulticlass) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 5c9bc62cb09bb..16bc45bcb627f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -177,7 +177,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { } @Since("1.4.0") - override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + override def load(sc: SparkContext, path: String): GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats val k = (metadata \ "k").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 5273ed4d76650..ffae0e7ed0ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -134,7 +134,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { loadImpl(freqItemsets, sample) } - def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { + def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => val items = x.getAs[Seq[Item]](0).toArray val freq = x.getLong(1) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d7a74db0b1fd8..b08da4fb55034 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -279,7 +279,7 @@ class DenseMatrix @Since("1.3.0") ( } override def hashCode: Int = { - com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray) + com.google.common.base.Objects.hashCode(numRows: Integer, numCols: Integer, toArray) } private[mllib] def toBreeze: BM[Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 7abb1bf7ce967..a8c32f72bfdeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, - description : String, - normalizationMethod : RegressionNormalizationMethodType, + model: GeneralizedLinearModel, + description: String, + normalizationMethod: RegressionNormalizationMethodType, threshold: Double) extends PMMLModelExport { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index b5b824bb9c9b6..255c6140e5410 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -26,14 +26,14 @@ import org.apache.spark.mllib.clustering.KMeansModel /** * PMML Model Export for KMeansModel class */ -private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ +private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModelExport{ populateKMeansPMML(model) /** * Export the input KMeansModel model to PMML format. */ - private def populateKMeansPMML(model : KMeansModel): Unit = { + private def populateKMeansPMML(model: KMeansModel): Unit = { pmml.getHeader.setDescription("k-means clustering") if (model.clusterCenters.length > 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index af1f7e74c004d..c73774fcd8c46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -600,7 +600,7 @@ object DecisionTree extends Serializable with Logging { val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 66f0908c1250f..b373c2de3ea96 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node @Since("1.2.0") ( * @return predicted value */ @Since("1.1.0") - def predict(features: Vector) : Double = { + def predict(features: Vector): Double = { if (isLeaf) { predict.predict } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 094528e2ece06..240781bcd335b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -175,7 +175,7 @@ object LinearDataGenerator { nfeatures: Int, eps: Double, nparts: Int = 2, - intercept: Double = 0.0) : RDD[LabeledPoint] = { + intercept: Double = 0.0): RDD[LabeledPoint] = { val random = new Random(42) // Random values distributed uniformly in [-0.5, 0.5] val w = Array.fill(nfeatures)(random.nextDouble() - 0.5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index ee3c85d09a463..1a47344b68937 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -45,7 +45,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) + val weightsMat = new DoubleMatrix(1, weights.length, weights: _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index 1142102bb040e..50441816ece3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { - override def maxWaitTimeMillis : Int = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for null hypothesis using welch t-test") { // set parameters From 3119206b7188c23055621dfeaf6874f21c711a82 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 10 Jan 2016 15:41:22 -0800 Subject: [PATCH 1513/2089] [SPARK-12692][BUILD][GRAPHX] Scala style: Fix the style violation (Space before "," or ":") Fix the style violation (space before `,` and `:`). This PR is a followup for #10643. Author: Kousuke Saruta Closes #10683 from sarutak/SPARK-12692-followup-graphx. --- graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala | 5 ++--- .../org/apache/spark/graphx/impl/ReplicatedVertexView.scala | 4 ++-- .../apache/spark/graphx/impl/ShippableVertexPartition.scala | 4 ++-- .../apache/spark/graphx/impl/VertexPartitionBaseOps.scala | 2 +- .../main/scala/org/apache/spark/graphx/lib/PageRank.scala | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index fc36e12dd2aed..d048fb5d561f3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.SparkException -import org.apache.spark.SparkContext._ import org.apache.spark.graphx.lib._ import org.apache.spark.rdd.RDD @@ -379,7 +378,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]] */ def personalizedPageRank(src: VertexId, tol: Double, - resetProb: Double = 0.15) : Graph[Double, Double] = { + resetProb: Double = 0.15): Graph[Double, Double] = { PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src)) } @@ -392,7 +391,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]] */ def staticPersonalizedPageRank(src: VertexId, numIter: Int, - resetProb: Double = 0.15) : Graph[Double, Double] = { + resetProb: Double = 0.15): Graph[Double, Double] = { PageRank.runWithOptions(graph, numIter, resetProb, Some(src)) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index f79f9c7ec448f..b4bec7cba5207 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -41,8 +41,8 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. */ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { - new ReplicatedVertexView(edges_, hasSrcId, hasDstId) + _edges: EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + new ReplicatedVertexView(_edges, hasSrcId, hasDstId) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index 3f203c4eca485..96d807f9f9ceb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -102,8 +102,8 @@ class ShippableVertexPartition[VD: ClassTag]( extends VertexPartitionBase[VD] { /** Return a new ShippableVertexPartition with the specified routing table. */ - def withRoutingTable(routingTable_ : RoutingTablePartition): ShippableVertexPartition[VD] = { - new ShippableVertexPartition(index, values, mask, routingTable_) + def withRoutingTable(_routingTable: RoutingTablePartition): ShippableVertexPartition[VD] = { + new ShippableVertexPartition(index, values, mask, _routingTable) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index f508b483a2f1b..7c680dcb99cd2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.collection.BitSet * example, [[VertexPartition.VertexPartitionOpsConstructor]]). */ private[graphx] abstract class VertexPartitionBaseOps - [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor] + [VD: ClassTag, Self[X] <: VertexPartitionBase[X]: VertexPartitionBaseOpsConstructor] (self: Self[VD]) extends Serializable with Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 35b26c998e1d9..46faad2e68c50 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -138,7 +138,7 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId , id: VertexId) => resetProb * delta(src, id) + (src: VertexId, id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } From 3ab0138b0fe0f9208b4b476855294a7c729583b7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 Jan 2016 19:59:01 -0800 Subject: [PATCH 1514/2089] [SPARK-12734][BUILD] Fix Netty exclusion and use Maven Enforcer to prevent future bugs Netty classes are published under multiple artifacts with different names, so our build needs to exclude the `io.netty:netty` and `org.jboss.netty:netty` versions of the Netty artifact. However, our existing exclusions were incomplete, leading to situations where duplicate Netty classes would wind up on the classpath and cause compile errors (or worse). This patch fixes the exclusion issue by adding more exclusions and uses Maven Enforcer's [banned dependencies](https://maven.apache.org/enforcer/enforcer-rules/bannedDependencies.html) rule to prevent these classes from accidentally being reintroduced. I also updated `dev/test-dependencies.sh` to run `mvn validate` so that the enforcer rules can run as part of pull request builds. /cc rxin srowen pwendell. I'd like to backport at least the exclusion portion of this fix to `branch-1.5` in order to fix the documentation publishing job, which fails nondeterministically due to incompatible versions of Netty classes taking precedence on the compile-time classpath. Author: Josh Rosen Author: Josh Rosen Closes #10672 from JoshRosen/enforce-netty-exclusions. --- dev/deps/spark-deps-hadoop-2.2 | 1 - dev/deps/spark-deps-hadoop-2.3 | 1 - dev/deps/spark-deps-hadoop-2.4 | 1 - dev/deps/spark-deps-hadoop-2.6 | 1 - dev/test-dependencies.sh | 17 +++------- examples/pom.xml | 4 +++ pom.xml | 57 +++++++++++++++++++++++++++++++++- 7 files changed, 64 insertions(+), 18 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index e4373f79f7922..13d1b0e950480 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -142,7 +142,6 @@ metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar -netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 7478181406d07..d7deaa0a24541 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -133,7 +133,6 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index faffb8bf398a5..7ad2212ed5ae7 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -134,7 +134,6 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e703c7acd3876..7f8518927aec4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -140,7 +140,6 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 424ce6ad7663c..def87aa4087e3 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -70,19 +70,10 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de # Generate manifests for each Hadoop profile: for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do echo "Performing Maven install for $HADOOP_PROFILE" - $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar install:install -q \ - -pl '!assembly' \ - -pl '!examples' \ - -pl '!external/flume-assembly' \ - -pl '!external/kafka-assembly' \ - -pl '!external/twitter' \ - -pl '!external/flume' \ - -pl '!external/mqtt' \ - -pl '!external/mqtt-assembly' \ - -pl '!external/zeromq' \ - -pl '!external/kafka' \ - -pl '!tags' \ - -DskipTests + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install -q + + echo "Performing Maven validate for $HADOOP_PROFILE" + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE validate -q echo "Generating dependency manifest for $HADOOP_PROFILE" mkdir -p dev/pr-deps diff --git a/examples/pom.xml b/examples/pom.xml index 1a0d5e5854642..6013085b10e84 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -111,6 +111,10 @@ org.jruby jruby-complete + + io.netty + netty +
    diff --git a/pom.xml b/pom.xml index 0eac212754320..cbed36c1eac16 100644 --- a/pom.xml +++ b/pom.xml @@ -519,6 +519,12 @@ ${akka.group} akka-remote_${scala.binary.version} ${akka.version} + + + io.netty + netty + + ${akka.group} @@ -762,6 +768,10 @@ org.jboss.netty netty + + io.netty + netty + @@ -822,6 +832,10 @@ junit junit + + io.netty + netty + @@ -922,6 +936,10 @@ org.jboss.netty netty + + io.netty + netty + commons-logging commons-logging @@ -946,6 +964,10 @@ org.jboss.netty netty + + io.netty + netty + javax.servlet servlet-api @@ -975,6 +997,10 @@ org.jboss.netty netty + + io.netty + netty + javax.servlet servlet-api @@ -1003,6 +1029,10 @@ org.jboss.netty netty + + io.netty + netty + javax.servlet servlet-api @@ -1031,6 +1061,10 @@ org.jboss.netty netty + + io.netty + netty + javax.servlet servlet-api @@ -1046,6 +1080,16 @@ zookeeper ${zookeeper.version} ${hadoop.deps.scope} + + + org.jboss.netty + netty + + + io.netty + netty + + org.codehaus.jackson @@ -1771,7 +1815,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4 + 1.4.1 enforce-versions @@ -1786,6 +1830,17 @@ ${java.version} + + + io.netty:netty + org.jboss.netty + + + + io.netty:netty:3.4.0.Final:*:test + + true + From 6439a82503e900ae2e5c3cda5d10ac20dfd6e53f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sun, 10 Jan 2016 20:04:50 -0800 Subject: [PATCH 1515/2089] [SPARK-3873][BUILD] Enable import ordering error checking. Turn import ordering violations into build errors, plus a few adjustments to account for how the checker behaves. I'm a little on the fence about whether the existing code is right, but it's easier to appease the checker than to discuss what's the more correct order here. Plus a few fixes to imports that cropped in since my recent cleanups. Author: Marcelo Vanzin Closes #10612 from vanzin/SPARK-3873-enable. --- .../streaming/KinesisWordCountASL.scala | 2 +- .../kinesis/KinesisInputDStream.scala | 2 +- .../streaming/kinesis/KinesisReceiver.scala | 2 +- .../streaming/kinesis/KinesisUtils.scala | 2 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../kinesis/KinesisCheckpointerSuite.scala | 4 ++-- .../kinesis/KinesisReceiverSuite.scala | 2 +- .../kinesis/KinesisStreamSuite.scala | 4 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 2 +- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 7 +++---- .../spark/mllib/tree/DecisionTree.scala | 2 +- .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../spark/mllib/tree/RandomForest.scala | 2 +- scalastyle-config.xml | 21 +++++++++---------- .../spark/sql/catalyst/CatalystQl.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 3 +-- .../apache/spark/sql/execution/SparkQl.scala | 2 +- .../apache/spark/sql/execution/Window.scala | 2 +- .../datasources/WriterContainer.scala | 2 +- .../sql/execution/datasources/bucket.scala | 3 ++- .../datasources/parquet/ParquetRelation.scala | 10 ++++----- .../sql/execution/datasources/rules.scala | 2 +- .../parquet/ParquetReadBenchmark.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 5 +++-- .../apache/spark/sql/hive/SQLBuilder.scala | 2 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 8 +++---- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../spark/sql/hive/ErrorPositionSuite.scala | 2 +- .../sql/hive/LogicalPlanToSQLSuite.scala | 2 +- .../spark/sql/hive/SQLBuilderTest.scala | 2 +- .../hive/execution/HiveComparisonTest.scala | 4 ++-- .../sql/sources/BucketedWriteSuite.scala | 4 ++-- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../spark/streaming/util/StateMap.scala | 2 +- .../spark/streaming/StateMapSuite.scala | 5 ++--- 36 files changed, 62 insertions(+), 64 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index de749626ec09c..6a73bc0e30d05 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.util.Random -import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.auth.{BasicAWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 3321c7527edb4..5223c81a8e0e0 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -24,10 +24,10 @@ import com.amazonaws.services.kinesis.model.Record import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.streaming.{Duration, StreamingContext, Time} private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index abb9b6cd32f1c..48ee2a959786b 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2de6195716e5c..15ac588b82587 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -24,9 +24,9 @@ import com.amazonaws.services.kinesis.model.Record import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, StreamingContext} import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.{Duration, StreamingContext} object KinesisUtils { /** diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index d85b4cda8ce98..e6f504c4e54dd 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.kinesis import org.scalatest.BeforeAndAfterAll -import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) extends KinesisFunSuite with BeforeAndAfterAll { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala index 645e64a0bc3a0..e1499a8220991 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.kinesis -import java.util.concurrent.{TimeoutException, ExecutorService} +import java.util.concurrent.{ExecutorService, TimeoutException} import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ @@ -28,7 +28,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ import org.scalatest.mock.MockitoSugar diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index e5c70db554a27..fd15b6ccdc889 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -27,8 +27,8 @@ import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} import org.apache.spark.util.Utils diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 78263f9dca65c..ee6a5f0390d04 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -25,10 +25,11 @@ import scala.util.Random import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} @@ -38,7 +39,6 @@ import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext} abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 7443097492d82..7a651a37ac77e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} -import org.apache.spark.sql.types.{DoubleType, DataType, StructType} +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** * Parameters for Decision Tree-based algorithms. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index ffae0e7ed0ca4..1250bc1a07cb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.fpm -import java.lang.{Iterable => JavaIterable} import java.{util => ju} +import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,16 +29,15 @@ import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c73774fcd8c46..07ba0d8ccb2a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -25,10 +25,10 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl._ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 729a211574822..1b71256c585bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -22,8 +22,8 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index a684cdd18c2fc..570a76f960796 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -26,9 +26,9 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 9714c46fe99a0..2439a1f715aba 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -187,6 +187,16 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + @@ -207,17 +217,6 @@ This file is divided into 3 sections: - - - - java,scala,3rdParty,spark - javax?\..* - scala\..* - (?!org\.apache\.spark\.).* - org\.apache\.spark\..* - - - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 1eda4a9a97644..2e3cc0bfde7c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -22,10 +22,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0eb915fdc1691..17351ef0685a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,9 +21,9 @@ import java.sql.{Date, Timestamp} import org.json4s.JsonAST._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9b5c86a8984be..3bd733fa2d26c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -25,8 +25,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.{LegacyTypeStringParser, DataTypeParser} - +import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser} /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index a322688a259e2..f3e89ef4a71f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} -import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index be885397a7d40..168b5ab0316d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -22,6 +22,7 @@ import java.util import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -29,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} -import org.apache.spark.{SparkEnv, TaskContext} /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 4f8524f4b967c..40ecdb8e4403e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{IntegerType, StructType, StringType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 82287c8967134..9976829638d70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.mapreduce.TaskAttemptContext + import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} +import org.apache.spark.sql.sources.{HadoopFsRelation, HadoopFsRelationProvider, OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index ca8d010090401..7754edc803d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Logger => JLogger} import java.util.{List => JList} +import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -32,24 +32,24 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl +import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType -import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} +import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index d484403d1c641..1c773e69275db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cab6abde6da23..ae95b50e1ee76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -21,9 +21,9 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Try +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.util.{Benchmark, Utils} -import org.apache.spark.{SparkConf, SparkContext} /** * Benchmark to measure parquet read performance. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 5b13dbe47370e..d1b1c0d8d8bc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -24,11 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions._ @@ -38,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, AnalyzeTable, DropTable, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} import org.apache.spark.sql.types._ import org.apache.spark.sql.AnalysisException diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1c910051faccf..61e3f183bb42d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing @@ -26,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.{DataFrame, SQLContext} /** * A builder class used to convert a resolved logical plan into a SQL query string. Note that this diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e76c18fa528f3..56cab1aee89df 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -22,21 +22,21 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.hive.ql.exec._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} +import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{analysis, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.sequenceOption -import org.apache.spark.sql.catalyst.{InternalRow, analysis} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index d26cb48479066..033746d42f557 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -37,8 +37,8 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.hive.client.ClientWrapper +import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index a2d283622ca52..e72a18a716b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -21,8 +21,8 @@ import scala.util.Try import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 0e81acf532a03..9a8a9c51183da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala index cf4a3fdd88806..a5e209ac9db3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{DataFrame, QueryTest} abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { protected def checkSQL(e: Expression, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 57358a07840e2..fd3339a66bec0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,10 +27,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} +import org.apache.spark.sql.hive.test.TestHive /** * Allows the creations of tests that execute the same query against both hive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 579da0291f291..7f1745705aaaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest} class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b186d297610e2..86f01d2168729 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.Utils import org.apache.spark.streaming.scheduler.JobGenerator +import org.apache.spark.util.Utils private[streaming] class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 4e5baebaae04b..4ccc905b275d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -25,7 +25,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} +import org.apache.spark.serializer.{KryoInputObjectInputBridge, KryoOutputObjectOutputBridge} import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ import org.apache.spark.util.collection.OpenHashMap diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index ea32bbf95ce59..da0430e263b5f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.rdd.MapWithStateRDDRecord - import scala.collection.{immutable, mutable, Map} import scala.reflect.ClassTag import scala.util.Random import com.esotericsoftware.kryo.{Kryo, KryoSerializable} -import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer._ +import org.apache.spark.streaming.rdd.MapWithStateRDDRecord import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} class StateMapSuite extends SparkFunSuite { From 008a55828512056313da2626fd378e8aa1534790 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 10 Jan 2016 23:33:57 -0800 Subject: [PATCH 1516/2089] [SPARK-4628][BUILD] Add a resolver to MiMaBuild.scala for mqttv3(1.0.1). #10659 removed the repository `https://repo.eclipse.org/content/repositories/paho-releases` but it's needed by MiMa because `spark-streaming-mqtt(1.6.0)` depends on `mqttv3(1.0.1)` and it is provided by the removed repository and maven-central provide only `mqttv3(1.0.2)` for now. Otherwise, if `mqttv3(1.0.1)` is absent from the local repository, dev/mima should fail. JoshRosen Do you have any other better idea? Author: Kousuke Saruta Closes #10688 from sarutak/SPARK-4628-followup. --- project/MimaBuild.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 9ba9f8286f10c..41856443af49b 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,11 +91,16 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" + // The resolvers setting for MQTT Repository is needed for mqttv3(1.0.1) + // because spark-streaming-mqtt(1.6.0) depends on it. + // Remove the setting on updating previousSparkVersion. val previousSparkVersion = "1.6.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), - binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) + binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value), + sbt.Keys.resolvers += + "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases") } } From f13c7f8f7dc8766b0a42406b5c3639d6be55cf33 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 Jan 2016 00:31:29 -0800 Subject: [PATCH 1517/2089] [SPARK-12734][HOTFIX][TEST-MAVEN] Fix bug in Netty exclusions This is a hotfix for a build bug introduced by the Netty exclusion changes in #10672. We can't exclude `io.netty:netty` because Akka depends on it. There's not a direct conflict between `io.netty:netty` and `io.netty:netty-all`, because the former puts classes in the `org.jboss.netty` namespace while the latter uses the `io.netty` namespace. However, there still is a conflict between `org.jboss.netty:netty` and `io.netty:netty`, so we need to continue to exclude the JBoss version of that artifact. While the diff here looks somewhat large, note that this is only a revert of a some of the changes from #10672. You can see the net changes in pom.xml at https://github.com/apache/spark/compare/3119206b7188c23055621dfeaf6874f21c711a82...5211ab8#diff-600376dffeb79835ede4a0b285078036 Author: Josh Rosen Closes #10693 from JoshRosen/netty-hotfix. --- dev/deps/spark-deps-hadoop-2.2 | 1 + dev/deps/spark-deps-hadoop-2.3 | 1 + dev/deps/spark-deps-hadoop-2.4 | 1 + dev/deps/spark-deps-hadoop-2.6 | 1 + examples/pom.xml | 4 --- pom.xml | 50 +++++----------------------------- 6 files changed, 11 insertions(+), 47 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 13d1b0e950480..e4373f79f7922 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -142,6 +142,7 @@ metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar +netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index d7deaa0a24541..7478181406d07 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -133,6 +133,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar mx4j-3.0.2.jar +netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 7ad2212ed5ae7..faffb8bf398a5 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -134,6 +134,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar mx4j-3.0.2.jar +netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 7f8518927aec4..e703c7acd3876 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -140,6 +140,7 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.2.jar mx4j-3.0.2.jar +netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar objenesis-1.2.jar opencsv-2.3.jar diff --git a/examples/pom.xml b/examples/pom.xml index 6013085b10e84..1a0d5e5854642 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -111,10 +111,6 @@ org.jruby jruby-complete - - io.netty - netty - diff --git a/pom.xml b/pom.xml index cbed36c1eac16..06cccf1df0bb2 100644 --- a/pom.xml +++ b/pom.xml @@ -519,12 +519,6 @@ ${akka.group} akka-remote_${scala.binary.version} ${akka.version} - - - io.netty - netty - - ${akka.group} @@ -768,10 +762,6 @@ org.jboss.netty netty - - io.netty - netty - @@ -832,10 +822,6 @@ junit junit - - io.netty - netty - @@ -936,10 +922,6 @@ org.jboss.netty netty - - io.netty - netty - commons-logging commons-logging @@ -964,10 +946,6 @@ org.jboss.netty netty - - io.netty - netty - javax.servlet servlet-api @@ -997,10 +975,6 @@ org.jboss.netty netty - - io.netty - netty - javax.servlet servlet-api @@ -1029,10 +1003,6 @@ org.jboss.netty netty - - io.netty - netty - javax.servlet servlet-api @@ -1061,10 +1031,6 @@ org.jboss.netty netty - - io.netty - netty - javax.servlet servlet-api @@ -1085,10 +1051,6 @@ org.jboss.netty netty - - io.netty - netty - @@ -1832,13 +1794,15 @@ - io.netty:netty + org.jboss.netty - - - io.netty:netty:3.4.0.Final:*:test - true From f253feff62f3eb3cce22bbec0874f317a61b0092 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 11 Jan 2016 00:44:33 -0800 Subject: [PATCH 1518/2089] [SPARK-12539][FOLLOW-UP] always sort in partitioning writer address comments in #10498 , especially https://github.com/apache/spark/pull/10498#discussion_r49021259 Author: Wenchen Fan This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10638 from cloud-fan/bucket-write. --- .../datasources/WriterContainer.scala | 192 +++++------------- .../apache/spark/sql/sources/interfaces.scala | 3 - 2 files changed, 48 insertions(+), 147 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 40ecdb8e4403e..fff72872c13b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.SerializableConfiguration @@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer( } } - private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = { - val bucketIdIndex = partitionColumns.length - if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) { - false - } else { - var i = partitionColumns.length - 1 - while (i >= 0) { - val dt = partitionColumns(i).dataType - if (key1.get(i, dt) != key2.get(i, dt)) return false - i -= 1 - } - true - } - } - - private def sortBasedWrite( - sorter: UnsafeKVExternalSorter, - iterator: Iterator[InternalRow], - getSortingKey: UnsafeProjection, - getOutputRow: UnsafeProjection, - getPartitionString: UnsafeProjection, - outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) { - (key1, key2) => key1 != key2 - } else { - (key1, key2) => key1 == null || !sameBucket(key1, key2) - } - - val sortedIterator = sorter.sortedIterator() - var currentKey: UnsafeRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (needNewWriter(currentKey, sortedIterator.getKey)) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey, getPartitionString) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } - } - /** * Open and returns a new OutputWriter given a partition key and optional bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the @@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer( } def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) - var outputWritersCleared = false - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val getSortingKey = - UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema) - - val sortingKeySchema = if (bucketSpec.isEmpty) { - StructType.fromAttributes(partitionColumns) - } else { // If it's bucketed, we should also consider bucket id as part of the key. - val fields = StructType.fromAttributes(partitionColumns) - .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns) - StructType(fields) - } + val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns + + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) + }) // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) @@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // If there is no sorting columns, we set sorter to null and try the hash-based writing first, - // and fill the sorter if there are too many writers and we need to fall back on sorting. - // If there are sorting columns, then we have to sort the data anyway, and no need to try the - // hash-based writing first. - var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { - new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) + // Sorts the data before write, so that we only need one writer at the same time. + // TODO: inject a local sort operator in planning. + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity } else { - null + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) } - while (iterator.hasNext && sorter == null) { - val inputRow = iterator.next() - // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key. - val currentKey = getSortingKey(inputRow) - var currentWriter = outputWriters.get(currentKey) - - if (currentWriter == null) { - if (outputWriters.size < maxOpenFiles) { + + val sortedIterator = sorter.sortedIterator() + var currentKey: UnsafeRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey, getPartitionString) - outputWriters.put(currentKey.copy(), currentWriter) - currentWriter.writeInternal(getOutputRow(inputRow)) - } else { - logInfo(s"Maximum partitions reached, falling back on sorting.") - sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - sorter.insertKV(currentKey, getOutputRow(inputRow)) } - } else { - currentWriter.writeInternal(getOutputRow(inputRow)) - } - } - // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort, or there are sorting columns and we need to sort the whole data set. - if (sorter != null) { - sortBasedWrite( - sorter, - iterator, - getSortingKey, - getOutputRow, - getPartitionString, - outputWriters) + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } } commitTask() @@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer( abortTask() throw new SparkException("Task failed while writing rows.", cause) } - - def clearOutputWriters(): Unit = { - if (!outputWritersCleared) { - outputWriters.asScala.values.foreach(_.close()) - outputWriters.clear() - outputWritersCleared = true - } - } - - def commitTask(): Unit = { - try { - clearOutputWriters() - super.commitTask() - } catch { - case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - - def abortTask(): Unit = { - try { - clearOutputWriters() - } finally { - super.abortTask() - } - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c35f33132f602..9f3607369c30f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -162,7 +162,6 @@ trait HadoopFsRelationProvider { partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation - // TODO: expose bucket API to users. private[sql] def createRelation( sqlContext: SQLContext, paths: Array[String], @@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable { dataSchema: StructType, context: TaskAttemptContext): OutputWriter - // TODO: expose bucket API to users. private[sql] def newInstance( path: String, bucketId: Option[Int], @@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ - // TODO: expose bucket API to users. private[sql] def bucketSpec: Option[BucketSpec] = None private class FileStatusCache { From bd723bd53d9a28239b60939a248a4ea13340aad8 Mon Sep 17 00:00:00 2001 From: Udo Klein Date: Mon, 11 Jan 2016 09:30:08 +0000 Subject: [PATCH 1519/2089] removed lambda from sortByKey() According to the documentation the sortByKey method does not take a lambda as an argument, thus the example is flawed. Removed the argument completely as this will default to ascending sort. Author: Udo Klein Closes #10640 from udoklein/patch-1. --- examples/src/main/python/sort.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index f6b0ecb02c100..b6c2916254056 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -30,7 +30,7 @@ lines = sc.textFile(sys.argv[1], 1) sortedCount = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (int(x), 1)) \ - .sortByKey(lambda x: x) + .sortByKey() # This is just a demo on how to bring all the sorted data back to a single node. # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() From 8fe928b4fe380ba527164bd413402abfed13c0e1 Mon Sep 17 00:00:00 2001 From: BrianLondon Date: Mon, 11 Jan 2016 09:32:06 +0000 Subject: [PATCH 1520/2089] [SPARK-12269][STREAMING][KINESIS] Update aws-java-sdk version The current Spark Streaming kinesis connector references a quite old version 1.9.40 of the AWS Java SDK (1.10.40 is current). Numerous AWS features including Kinesis Firehose are unavailable in 1.9. Those two versions of the AWS SDK in turn require conflicting versions of Jackson (2.4.4 and 2.5.3 respectively) such that one cannot include the current AWS SDK in a project that also uses the Spark Streaming Kinesis ASL. Author: BrianLondon Closes #10256 from BrianLondon/master. --- dev/deps/spark-deps-hadoop-2.2 | 8 ++++---- dev/deps/spark-deps-hadoop-2.3 | 8 ++++---- dev/deps/spark-deps-hadoop-2.4 | 8 ++++---- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- pom.xml | 6 +++--- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index e4373f79f7922..cd3ff293502ae 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -84,13 +84,13 @@ hadoop-yarn-server-web-proxy-2.2.0.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 7478181406d07..0985089ccea61 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -79,13 +79,13 @@ hadoop-yarn-server-web-proxy-2.3.0.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index faffb8bf398a5..50f062601c02b 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -79,13 +79,13 @@ hadoop-yarn-server-web-proxy-2.4.0.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e703c7acd3876..2b6ca983ad65e 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -85,13 +85,13 @@ htrace-core-3.0.4.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/pom.xml b/pom.xml index 06cccf1df0bb2..fc5cf970e0601 100644 --- a/pom.xml +++ b/pom.xml @@ -152,9 +152,9 @@ 1.7.7 hadoop2 0.7.1 - 1.4.0 + 1.6.1 - 0.10.1 + 0.10.2 4.3.2 @@ -167,7 +167,7 @@ ${scala.version} org.scala-lang 1.9.13 - 2.4.4 + 2.5.3 1.1.2 1.1.2 1.2.0-incubating From 9559ac5f74434cf4bf611bdcde9a216d39799826 Mon Sep 17 00:00:00 2001 From: Anatoliy Plastinin Date: Mon, 11 Jan 2016 10:28:57 -0800 Subject: [PATCH 1521/2089] [SPARK-12744][SQL] Change parsing JSON integers to timestamps to treat integers as number of seconds JIRA: https://issues.apache.org/jira/browse/SPARK-12744 This PR makes parsing JSON integers to timestamps consistent with casting behavior. Author: Anatoliy Plastinin Closes #10687 from antlypls/fix-json-timestamp-parsing. --- .../datasources/json/JacksonParser.scala | 2 +- .../execution/datasources/json/JsonSuite.scala | 17 +++++++++++++++-- .../datasources/json/TestJsonData.scala | 4 ++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 2e3fe3da15389..b2f5c1e96421d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -90,7 +90,7 @@ object JacksonParser { DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => - parser.getLongValue * 1000L + parser.getLongValue * 1000000L case (_, StringType) => val writer = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b3b6b7df0c1d1..4ab148065a476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -83,9 +83,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber * 1000L)), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), @@ -1465,4 +1465,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("Casting long as timestamp") { + withTempTable("jsonTable") { + val schema = (new StructType).add("ts", TimestampType) + val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select ts from jsonTable"), + Row(java.sql.Timestamp.valueOf("2016-01-02 03:04:05")) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index cb61f7eeca0de..a0836058d3c74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -205,6 +205,10 @@ private[json] trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) + def timestampAsLong: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"ts":1451732645}""" :: Nil) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) From b313badaa049f847f33663c61cd70ee2f2cbebac Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 11 Jan 2016 11:29:15 -0800 Subject: [PATCH 1522/2089] [STREAMING][MINOR] Typo fixes Author: Jacek Laskowski Closes #10698 from jaceklaskowski/streaming-kafka-typo-fixes. --- .../scala/org/apache/spark/streaming/kafka/KafkaCluster.scala | 2 +- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index c4e18d92eefa9..d7885d7cc1ae1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -385,7 +385,7 @@ object KafkaCluster { val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => val hpa = hp.split(":") if (hpa.size == 1) { - throw new SparkException(s"Broker not the in correct format of : [$brokers]") + throw new SparkException(s"Broker not in the correct format of : [$brokers]") } (hpa(0), hpa(1).toInt) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 603be22818206..4eb155645867b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -156,7 +156,7 @@ class KafkaRDD[ var requestOffset = part.fromOffset var iter: Iterator[MessageAndOffset] = null - // The idea is to use the provided preferred host, except on task retry atttempts, + // The idea is to use the provided preferred host, except on task retry attempts, // to minimize number of kafka metadata requests private def connectLeader: SimpleConsumer = { if (context.attemptNumber > 0) { From a44991453a43615028083ba9546f5cd93112f6bd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 Jan 2016 12:56:43 -0800 Subject: [PATCH 1523/2089] [SPARK-12734][HOTFIX] Build changes must trigger all tests; clean after install in dep tests This patch fixes a build/test issue caused by the combination of #10672 and a latent issue in the original `dev/test-dependencies` script. First, changes which _only_ touched build files were not triggering full Jenkins runs, making it possible for a build change to be merged even though it could cause failures in other tests. The `root` build module now depends on `build`, so all tests will now be run whenever a build-related file is changed. I also added a `clean` step to the Maven install step in `dev/test-dependencies` in order to address an issue where the dummy JARs stuck around and caused "multiple assembly JARs found" errors in tests. /cc zsxwing Author: Josh Rosen Closes #10704 from JoshRosen/fix-build-test-problems. --- dev/sparktestsupport/modules.py | 2 +- dev/test-dependencies.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 1fc6596164124..93a8c15e3ec30 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -426,7 +426,7 @@ def contains_file(self, filename): # No other modules should directly depend on this module. root = Module( name="root", - dependencies=[], + dependencies=[build], # Changes to build should trigger all tests. source_file_regexes=[], # In order to run all of the tests, enable every test profile: build_profile_flags=list(set( diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index def87aa4087e3..3cb5d2be2a91a 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -70,7 +70,7 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de # Generate manifests for each Hadoop profile: for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do echo "Performing Maven install for $HADOOP_PROFILE" - $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install -q + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install clean -q echo "Performing Maven validate for $HADOOP_PROFILE" $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE validate -q From a767ee8a0599f5482717493a3298413c65d8ff89 Mon Sep 17 00:00:00 2001 From: Brandon Bradley Date: Mon, 11 Jan 2016 14:21:50 -0800 Subject: [PATCH 1524/2089] [SPARK-12758][SQL] add note to Spark SQL Migration guide about TimestampType casting Warning users about casting changes. Author: Brandon Bradley Closes #10708 from blbradley/spark-12758. --- docs/sql-programming-guide.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b058833616433..bc89c781562bd 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2151,6 +2151,11 @@ options. ... {% endhighlight %} + - From Spark 1.6, LongType casts to TimestampType expect seconds instead of microseconds. This + change was made to match the behavior of Hive 1.2 for more consistent type casting to TimestampType + from numeric types. See [SPARK-11724](https://issues.apache.org/jira/browse/SPARK-11724) for + details. + ## Upgrading From Spark SQL 1.4 to 1.5 - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with From ee4ee02b86be8756a6d895a2e23e80862134a6d3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 11 Jan 2016 14:43:25 -0800 Subject: [PATCH 1525/2089] [SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft PySpark MLlib ```GaussianMixtureModel``` should support single instance ```predict/predictSoft``` just like Scala do. Author: Yanbo Liang Closes #10552 from yanboliang/spark-12603. --- .../python/mllib/gaussian_mixture_model.py | 4 +++ .../examples/mllib/DenseGaussianMixture.scala | 6 ++++ .../python/GaussianMixtureModelWrapper.scala | 4 +++ .../clustering/GaussianMixtureModel.scala | 2 +- python/pyspark/mllib/clustering.py | 35 ++++++++++++------- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index 2cb8010cdc07f..69e836fc1d06a 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -62,5 +62,9 @@ def parseVector(line): for i in range(args.k): print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, "sigma = ", model.gaussians[i].sigma.toArray())) + print("\n") + print(("The membership value of each vector to all mixture components (first 100): ", + model.predictSoft(data).take(100))) + print("\n") print(("Cluster labels (first 100): ", model.predict(data).take(100))) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index 1fce4ba7efd60..90b817b23e156 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -58,6 +58,12 @@ object DenseGaussianMixture { (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } + println("The membership value of each vector to all mixture components (first <= 100):") + val membership = clusters.predictSoft(data) + membership.take(100).foreach { x => + print(" " + x.mkString(",")) + } + println() println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 6a3b20c88d2d2..a689b09341450 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -40,5 +40,9 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava) } + def predictSoft(point: Vector): Vector = { + Vectors.dense(model.predictSoft(point)) + } + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 16bc45bcb627f..42fe27024f8fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -75,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") ( */ @Since("1.5.0") def predict(point: Vector): Int = { - val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + val r = predictSoft(point) r.indexOf(r.max) } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index d22a7f4c3b167..580cb512d8025 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -202,16 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, - ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) + ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2) >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, ... maxIterations=50, seed=10) >>> labels = model.predict(clusterdata_1).collect() >>> labels[0]==labels[1] False >>> labels[1]==labels[2] - True + False >>> labels[4]==labels[5] True + >>> model.predict([-0.1,-0.05]) + 0 + >>> softPredicted = model.predictSoft([-0.1,-0.05]) + >>> abs(softPredicted[0] - 1.0) < 0.001 + True + >>> abs(softPredicted[1] - 0.0) < 0.001 + True + >>> abs(softPredicted[2] - 0.0) < 0.001 + True >>> path = tempfile.mkdtemp() >>> model.save(sc, path) @@ -277,26 +286,27 @@ def k(self): @since('1.3.0') def predict(self, x): """ - Find the cluster to which the points in 'x' has maximum membership - in this model. + Find the cluster to which the point 'x' or each point in RDD 'x' + has maximum membership in this model. - :param x: RDD of data points. - :return: cluster_labels. RDD of cluster labels. + :param x: vector or RDD of vector represents data points. + :return: cluster label or RDD of cluster labels. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels else: - raise TypeError("x should be represented by an RDD, " - "but got %s." % type(x)) + z = self.predictSoft(x) + return z.argmax() @since('1.3.0') def predictSoft(self, x): """ - Find the membership of each point in 'x' to all mixture components. + Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: RDD of data points. - :return: membership_matrix. RDD of array of double values. + :param x: vector or RDD of vector represents data points. + :return: the membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -304,8 +314,7 @@ def predictSoft(self, x): _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: - raise TypeError("x should be represented by an RDD, " - "but got %s." % type(x)) + return self.call("predictSoft", _convert_to_vector(x)).toArray() @classmethod @since('1.5.0') From 4f8eefa36bb90812aac61ac7a762c9452de666bf Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 11 Jan 2016 14:48:35 -0800 Subject: [PATCH 1526/2089] [SPARK-12685][MLLIB] word2vec trainWordsCount gets overflow jira: https://issues.apache.org/jira/browse/SPARK-12685 the log of `word2vec` reports trainWordsCount = -785727483 during computation over a large dataset. Update the priority as it will affect the computation process. `alpha = learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))` Author: Yuhao Yang Closes #10627 from hhbyyh/w2voverflow. --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a7e1b76df6a7d..dc5d070890d5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -151,7 +151,7 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private var window = 5 - private var trainWordsCount = 0 + private var trainWordsCount = 0L private var vocabSize = 0 @transient private var vocab: Array[VocabWord] = null @transient private var vocabHash = mutable.HashMap.empty[String, Int] @@ -159,13 +159,13 @@ class Word2Vec extends Serializable with Logging { private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) + .filter(_._2 >= minCount) .map(x => VocabWord( x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) - .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) @@ -179,7 +179,7 @@ class Word2Vec extends Serializable with Logging { trainWordsCount += vocab(a).cn a += 1 } - logInfo("trainWordsCount = " + trainWordsCount) + logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount") } private def createExpTable(): Array[Float] = { @@ -332,7 +332,7 @@ class Word2Vec extends Serializable with Logging { val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount From bbea88852ce6a3127d071ca40dbca2d042f9fbcf Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 11 Jan 2016 14:55:44 -0800 Subject: [PATCH 1527/2089] [SPARK-10809][MLLIB] Single-document topicDistributions method for LocalLDAModel jira: https://issues.apache.org/jira/browse/SPARK-10809 We could provide a single-document topicDistributions method for LocalLDAModel to allow for quick queries which avoid RDD operations. Currently, the user must use an RDD of documents. add some missing assert too. Author: Yuhao Yang Closes #9484 from hhbyyh/ldaTopicPre. --- .../spark/mllib/clustering/LDAModel.scala | 26 +++++++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 15 ++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 2fce3ff641101..b30ecb80209d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -387,6 +387,32 @@ class LocalLDAModel private[spark] ( } } + /** + * Predicts the topic mixture distribution for a document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * Note this means to allow quick query for single document. For batch documents, please refer + * to [[topicDistributions()]] to avoid overhead. + * + * @param document document to predict topic mixture distributions for + * @return topic mixture distribution for the document + */ + @Since("2.0.0") + def topicDistribution(document: Vector): Vector = { + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + if (document.numNonzeros == 0) { + Vectors.zeros(this.k) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + document, + expElogbeta, + this.docConcentration.toBreeze, + gammaShape, + this.k) + Vectors.dense(normalize(gamma, 1.0).toArray) + } + } + /** * Java-friendly version of [[topicDistributions]] */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index faef60e084cc1..ea23196d2c801 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -366,7 +366,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { (0, 0.99504), (1, 0.99504), (1, 0.99504), (1, 0.99504)) - val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) => + val actualPredictions = ldaModel.topicDistributions(docs).cache() + val topTopics = actualPredictions.map { case (id, topics) => // convert results to expectedPredictions format, which only has highest probability topic val topicsBz = topics.toBreeze.toDenseVector (id, (argmax(topicsBz), max(topicsBz))) @@ -374,9 +375,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { .values .collect() - expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => - expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + expectedPredictions.zip(topTopics).foreach { case (expected, actual) => + assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)) } + + docs.collect() + .map(doc => ldaModel.topicDistribution(doc._2)) + .zip(actualPredictions.map(_._2).collect()) + .foreach { case (single, batch) => + assert(single ~== batch relTol 1E-3D) + } + actualPredictions.unpersist() } test("OnlineLDAOptimizer with asymmetric prior") { From fe9eb0b0ce397aeb40a32f8231d2ce8c17d7a609 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 11 Jan 2016 16:29:37 -0800 Subject: [PATCH 1528/2089] [SPARK-12576][SQL] Enable expression parsing in CatalystQl The PR allows us to use the new SQL parser to parse SQL expressions such as: ```1 + sin(x*x)``` We enable this functionality in this PR, but we will not start using this actively yet. This will be done as soon as we have reached grammar parity with the existing parser stack. cc rxin Author: Herman van Hovell Closes #10649 from hvanhovell/SPARK-12576. --- .../sql/catalyst/parser/SelectClauseParser.g | 7 + .../spark/sql/catalyst/CatalystQl.scala | 59 ++++--- .../sql/catalyst/parser/ParseDriver.scala | 24 ++- .../spark/sql/catalyst/CatalystQlSuite.scala | 151 ++++++++++++++++-- .../spark/sql/hive/ExtendedHiveQlParser.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 19 +-- .../spark/sql/hive/ErrorPositionSuite.scala | 5 +- .../apache/spark/sql/hive/HiveQlSuite.scala | 2 +- 9 files changed, 217 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g index 2d2bafb1ee34f..f18b6ec496f8f 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g @@ -131,6 +131,13 @@ selectItem : (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) | + namedExpression + ; + +namedExpression +@init { gParent.pushMsg("select named expression", state); } +@after { gParent.popMsg(state); } + : ( expression ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? ) -> ^(TOK_SELEXPR expression identifier*) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 2e3cc0bfde7c7..c87b6c8e95436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -30,6 +30,12 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler +private[sql] object CatalystQl { + val parser = new CatalystQl + def parseExpression(sql: String): Expression = parser.parseExpression(sql) + def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql) +} + /** * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. */ @@ -41,16 +47,13 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { } } - /** - * Returns the AST for the given SQL string. + * The safeParse method allows a user to focus on the parsing/AST transformation logic. This + * method will take care of possible errors during the parsing process. */ - protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf) - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String): LogicalPlan = { + protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = { try { - createPlan(sql, ParseDriver.parse(sql, conf)) + toResult(ast) } catch { case e: MatchError => throw e case e: AnalysisException => throw e @@ -58,26 +61,39 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { throw new AnalysisException(e.getMessage) case e: NotImplementedError => throw new AnalysisException( - s""" - |Unsupported language features in query: $sql - |${getAst(sql).treeString} + s"""Unsupported language features in query + |== SQL == + |$sql + |== AST == + |${ast.treeString} + |== Error == |$e + |== Stacktrace == |${e.getStackTrace.head} """.stripMargin) } } - protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree) - - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = getAst(ddl) - assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.children - val colList = tableOps - .find(_.text == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")) - - colList.children.map(nodeToAttribute) + /** Creates LogicalPlan for a given SQL string. */ + def parsePlan(sql: String): LogicalPlan = + safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan) + + /** Creates Expression for a given SQL string. */ + def parseExpression(sql: String): Expression = + safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get) + + /** Creates TableIdentifier for a given SQL string. */ + def parseTableIdentifier(sql: String): TableIdentifier = + safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent) + + def parseDdl(sql: String): Seq[Attribute] = { + safeParse(sql, ParseDriver.parseExpression(sql, conf)) { ast => + val Token("TOK_CREATETABLE", children) = ast + children + .find(_.text == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")) + .flatMap(_.children.map(nodeToAttribute)) + } } protected def getClauses( @@ -187,7 +203,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val keyMap = keyASTs.zipWithIndex.toMap val bitmasks: Seq[Int] = setASTs.map { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => columns.foldLeft(0)((bitmap, col) => { val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 0e93af8b92cd2..f8e4f21451192 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -28,7 +28,25 @@ import org.apache.spark.sql.AnalysisException * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver */ object ParseDriver extends Logging { - def parse(command: String, conf: ParserConf): ASTNode = { + /** Create an LogicalPlan ASTNode from a SQL command. */ + def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => + parser.statement().getTree + } + + /** Create an Expression ASTNode from a SQL command. */ + def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => + parser.namedExpression().getTree + } + + /** Create an TableIdentifier ASTNode from a SQL command. */ + def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => + parser.tableName().getTree + } + + private def parse( + command: String, + conf: ParserConf)( + toTree: SparkSqlParser => CommonTree): ASTNode = { logInfo(s"Parsing command: $command") // Setup error collection. @@ -44,7 +62,7 @@ object ParseDriver extends Logging { parser.configure(conf, reporter) try { - val result = parser.statement() + val result = toTree(parser) // Check errors. reporter.checkForErrors() @@ -57,7 +75,7 @@ object ParseDriver extends Logging { if (tree.token != null || tree.getChildCount == 0) tree else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) } - val tree = nonNullToken(result.getTree) + val tree = nonNullToken(result) // Make sure all boundaries are set. tree.setUnknownTokenBoundaries() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index d7204c3488313..ba9d2524a9551 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -17,36 +17,157 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { val parser = new CatalystQl() + test("test case insensitive") { + val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation) + assert(result === parser.parsePlan("seLect 1")) + assert(result === parser.parsePlan("select 1")) + assert(result === parser.parsePlan("SELECT 1")) + } + + test("test NOT operator with comparison operations") { + val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") + val expected = Project( + UnresolvedAlias( + Not( + GreaterThan(Literal(true), Literal(true))) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + test("support hive interval literal") { + def checkInterval(sql: String, result: CalendarInterval): Unit = { + val parsed = parser.parsePlan(sql) + val expected = Project( + UnresolvedAlias( + Literal(result) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + def checkYearMonth(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' YEAR TO MONTH", + CalendarInterval.fromYearMonthString(lit)) + } + + def checkDayTime(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' DAY TO SECOND", + CalendarInterval.fromDayTimeString(lit)) + } + + def checkSingleUnit(lit: String, unit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' $unit", + CalendarInterval.fromSingleUnitString(unit, lit)) + } + + checkYearMonth("123-10") + checkYearMonth("496-0") + checkYearMonth("-2-3") + checkYearMonth("-123-0") + + checkDayTime("99 11:22:33.123456789") + checkDayTime("-99 11:22:33.123456789") + checkDayTime("10 9:8:7.123456789") + checkDayTime("1 0:0:0") + checkDayTime("-1 0:0:0") + checkDayTime("1 0:0:1") + + for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { + checkSingleUnit("7", unit) + checkSingleUnit("-7", unit) + checkSingleUnit("0", unit) + } + + checkSingleUnit("13.123456789", "second") + checkSingleUnit("-13.123456789", "second") + } + + test("support scientific notation") { + def assertRight(input: String, output: Double): Unit = { + val parsed = parser.parsePlan("SELECT " + input) + val expected = Project( + UnresolvedAlias( + Literal(output) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + assertRight("9.0e1", 90) + assertRight("0.9e+2", 90) + assertRight("900e-1", 90) + assertRight("900.0E-1", 90) + assertRight("9.e+1", 90) + + intercept[AnalysisException](parser.parsePlan("SELECT .e3")) + } + + test("parse expressions") { + compareExpressions( + parser.parseExpression("prinln('hello', 'world')"), + UnresolvedFunction( + "prinln", Literal("hello") :: Literal("world") :: Nil, false)) + + compareExpressions( + parser.parseExpression("1 + r.r As q"), + Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")()) + + compareExpressions( + parser.parseExpression("1 - f('o', o(bar))"), + Subtract(Literal(1), + UnresolvedFunction("f", + Literal("o") :: + UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) :: + Nil, false))) + } + + test("table identifier") { + assert(TableIdentifier("q") === parser.parseTableIdentifier("q")) + assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q")) + intercept[AnalysisException](parser.parseTableIdentifier("")) + // TODO parser swallows third identifier. + // intercept[AnalysisException](parser.parseTableIdentifier("d.q.g")) + } + test("parse union/except/intersect") { - parser.createPlan("select * from t1 union all select * from t2") - parser.createPlan("select * from t1 union distinct select * from t2") - parser.createPlan("select * from t1 union select * from t2") - parser.createPlan("select * from t1 except select * from t2") - parser.createPlan("select * from t1 intersect select * from t2") - parser.createPlan("(select * from t1) union all (select * from t2)") - parser.createPlan("(select * from t1) union distinct (select * from t2)") - parser.createPlan("(select * from t1) union (select * from t2)") - parser.createPlan("select * from ((select * from t1) union (select * from t2)) t") + parser.parsePlan("select * from t1 union all select * from t2") + parser.parsePlan("select * from t1 union distinct select * from t2") + parser.parsePlan("select * from t1 union select * from t2") + parser.parsePlan("select * from t1 except select * from t2") + parser.parsePlan("select * from t1 intersect select * from t2") + parser.parsePlan("(select * from t1) union all (select * from t2)") + parser.parsePlan("(select * from t1) union distinct (select * from t2)") + parser.parsePlan("(select * from t1) union (select * from t2)") + parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t") } test("window function: better support of parentheses") { - parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + + parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + "order by 2) from windowData") - parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + + parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + "order by 2) from windowData") - parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + + parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + "order by 2) from windowData") - parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + + parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + "from windowData") - parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + + parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + "from windowData") - parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + + parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + "from windowData") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 395c8bff53f47..b22f424981325 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -38,7 +38,7 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim) + case statement => HiveQl.parsePlan(statement.trim) } protected lazy val dfs: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 43d84d507b20e..67228f3f3c9c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -414,8 +414,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." - case None => Subquery(table.name, HiveQl.createPlan(viewText)) - case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText)) + case None => Subquery(table.name, HiveQl.parsePlan(viewText)) + case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText)) } } else { MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d1b1c0d8d8bc2..ca9ddf94c11a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -230,15 +230,16 @@ private[hive] object HiveQl extends SparkQl with Logging { CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) } - protected override def createPlan( - sql: String, - node: ASTNode): LogicalPlan = { - if (nativeCommands.contains(node.text)) { - HiveNativeCommand(sql) - } else { - nodeToPlan(node) match { - case NativePlaceholder => HiveNativeCommand(sql) - case plan => plan + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sql: String): LogicalPlan = { + safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast => + if (nativeCommands.contains(ast.text)) { + HiveNativeCommand(sql) + } else { + nodeToPlan(ast) match { + case NativePlaceholder => HiveNativeCommand(sql) + case plan => plan + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index e72a18a716b5c..14a466cfe9486 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -117,9 +117,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = ParseDriver.parse(query, hiveContext.conf) - def parseTree = - Try(quietly(ast.treeString)).getOrElse("") + def ast = ParseDriver.parsePlan(query, hiveContext.conf) + def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { val error = intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f4a1a17422483..53d15c14cb3d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, M class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { private def extractTableDesc(sql: String): (HiveTable, Boolean) = { - HiveQl.createPlan(sql).collect { + HiveQl.parsePlan(sql).collect { case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) }.head } From 473907adf6e37855ee31d0703b43d7170e26b4b9 Mon Sep 17 00:00:00 2001 From: wangfei Date: Mon, 11 Jan 2016 18:18:44 -0800 Subject: [PATCH 1529/2089] [SPARK-12742][SQL] org.apache.spark.sql.hive.LogicalPlanToSQLSuite failure due to Table already exists exception ``` [info] Exception encountered when attempting to run a suite with class name: org.apache.spark.sql.hive.LogicalPlanToSQLSuite *** ABORTED *** (325 milliseconds) [info] org.apache.spark.sql.AnalysisException: Table `t1` already exists.; [info] at org.apache.spark.sql.DataFrameWriter.saveAsTable(DataFrameWriter.scala:296) [info] at org.apache.spark.sql.DataFrameWriter.saveAsTable(DataFrameWriter.scala:285) [info] at org.apache.spark.sql.hive.LogicalPlanToSQLSuite.beforeAll(LogicalPlanToSQLSuite.scala:33) [info] at org.scalatest.BeforeAndAfterAll$class.beforeAll(BeforeAndAfterAll.scala:187) [info] at org.apache.spark.sql.hive.LogicalPlanToSQLSuite.beforeAll(LogicalPlanToSQLSuite.scala:23) [info] at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:253) [info] at org.apache.spark.sql.hive.LogicalPlanToSQLSuite.run(LogicalPlanToSQLSuite.scala:23) [info] at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:462) [info] at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:671) [info] at sbt.ForkMain$Run$2.call(ForkMain.java:296) [info] at sbt.ForkMain$Run$2.call(ForkMain.java:286) [info] at java.util.concurrent.FutureTask.run(FutureTask.java:266) [info] at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) [info] at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) [info] at java.lang.Thread.run(Thread.java:745) ``` /cc liancheng Author: wangfei Closes #10682 from scwf/fix-test. --- .../org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 9a8a9c51183da..2ee8150fb80d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -24,6 +24,9 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { import testImplicits._ protected override def beforeAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") sqlContext.range(10).write.saveAsTable("t0") sqlContext From 36d493509d32d14b54af62f5f65e8fa750e7413d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 11 Jan 2016 18:42:26 -0800 Subject: [PATCH 1530/2089] [SPARK-12498][SQL][MINOR] BooleanSimplication simplification Scala syntax allows binary case classes to be used as infix operator in pattern matching. This PR makes use of this syntax sugar to make `BooleanSimplification` more readable. Author: Cheng Lian Closes #10445 from liancheng/boolean-simplification-simplification. --- .../sql/catalyst/expressions/literals.scala | 4 + .../sql/catalyst/optimizer/Optimizer.scala | 190 ++++++++---------- 2 files changed, 92 insertions(+), 102 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 17351ef0685a9..e0b020330278b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -28,6 +28,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ object Literal { + val TrueLiteral: Literal = Literal(true, BooleanType) + + val FalseLiteral: Literal = Literal(false, BooleanType) + def apply(v: Any): Literal = v match { case i: Int => Literal(i, IntegerType) case l: Long => Literal(l, LongType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f8121a733a8d2..b70bc184d0a5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} @@ -519,112 +520,97 @@ object OptimizeIn extends Rule[LogicalPlan] { object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case and @ And(left, right) => (left, right) match { - // true && r => r - case (Literal(true, BooleanType), r) => r - // l && true => l - case (l, Literal(true, BooleanType)) => l - // false && r => false - case (Literal(false, BooleanType), _) => Literal(false) - // l && false => false - case (_, Literal(false, BooleanType)) => Literal(false) - // a && a => a - case (l, r) if l fastEquals r => l - // a && (not(a) || b) => a && b - case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) - case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) - case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) - case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) - // (a || b) && (a || c) => a || (b && c) - case _ => - // 1. Split left and right to get the disjunctive predicates, - // i.e. lhs = (a, b), rhs = (a, c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) - val lhs = splitDisjunctivePredicates(left) - val rhs = splitDisjunctivePredicates(right) - val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) - if (common.isEmpty) { - // No common factors, return the original predicate - and + case TrueLiteral And e => e + case e And TrueLiteral => e + case FalseLiteral Or e => e + case e Or FalseLiteral => e + + case FalseLiteral And _ => FalseLiteral + case _ And FalseLiteral => FalseLiteral + case TrueLiteral Or _ => TrueLiteral + case _ Or TrueLiteral => TrueLiteral + + case a And b if a.semanticEquals(b) => a + case a Or b if a.semanticEquals(b) => a + + case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c) + case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b) + case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c) + case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c) + + case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c) + case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b) + case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c) + case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c) + + // Common factor elimination for conjunction + case and @ (left And right) => + // 1. Split left and right to get the disjunctive predicates, + // i.e. lhs = (a, b), rhs = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) + val lhs = splitDisjunctivePredicates(left) + val rhs = splitDisjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals)) + if (common.isEmpty) { + // No common factors, return the original predicate + and + } else { + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a || b || c || ...) && (a || b) => (a || b) + common.reduce(Or) } else { - val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) - val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a || b || c || ...) && (a || b) => (a || b) - common.reduce(Or) - } else { - // (a || b || c || ...) && (a || b || d || ...) => - // ((c || ...) && (d || ...)) || a || b - (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) - } + // (a || b || c || ...) && (a || b || d || ...) => + // ((c || ...) && (d || ...)) || a || b + (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) } - } // end of And(left, right) - - case or @ Or(left, right) => (left, right) match { - // true || r => true - case (Literal(true, BooleanType), _) => Literal(true) - // r || true => true - case (_, Literal(true, BooleanType)) => Literal(true) - // false || r => r - case (Literal(false, BooleanType), r) => r - // l || false => l - case (l, Literal(false, BooleanType)) => l - // a || a => a - case (l, r) if l fastEquals r => l - // (a && b) || (a && c) => a && (b || c) - case _ => - // 1. Split left and right to get the conjunctive predicates, - // i.e. lhs = (a, b), rhs = (a, c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) - val lhs = splitConjunctivePredicates(left) - val rhs = splitConjunctivePredicates(right) - val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) - if (common.isEmpty) { - // No common factors, return the original predicate - or + } + + // Common factor elimination for disjunction + case or @ (left Or right) => + // 1. Split left and right to get the conjunctive predicates, + // i.e. lhs = (a, b), rhs = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) + val lhs = splitConjunctivePredicates(left) + val rhs = splitConjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals)) + if (common.isEmpty) { + // No common factors, return the original predicate + or + } else { + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a && b) || (a && b && c && ...) => a && b + common.reduce(And) } else { - val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) - val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a && b) || (a && b && c && ...) => a && b - common.reduce(And) - } else { - // (a && b && c && ...) || (a && b && d && ...) => - // ((c && ...) || (d && ...)) && a && b - (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) - } + // (a && b && c && ...) || (a && b && d && ...) => + // ((c && ...) || (d && ...)) && a && b + (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) } - } // end of Or(left, right) - - case not @ Not(exp) => exp match { - // not(true) => false - case Literal(true, BooleanType) => Literal(false) - // not(false) => true - case Literal(false, BooleanType) => Literal(true) - // not(l > r) => l <= r - case GreaterThan(l, r) => LessThanOrEqual(l, r) - // not(l >= r) => l < r - case GreaterThanOrEqual(l, r) => LessThan(l, r) - // not(l < r) => l >= r - case LessThan(l, r) => GreaterThanOrEqual(l, r) - // not(l <= r) => l > r - case LessThanOrEqual(l, r) => GreaterThan(l, r) - // not(l || r) => not(l) && not(r) - case Or(l, r) => And(Not(l), Not(r)) - // not(l && r) => not(l) or not(r) - case And(l, r) => Or(Not(l), Not(r)) - // not(not(e)) => e - case Not(e) => e - case _ => not - } // end of Not(exp) - - // if (true) a else b => a - // if (false) a else b => b - case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue + } + + case Not(TrueLiteral) => FalseLiteral + case Not(FalseLiteral) => TrueLiteral + + case Not(a GreaterThan b) => LessThanOrEqual(a, b) + case Not(a GreaterThanOrEqual b) => LessThan(a, b) + + case Not(a LessThan b) => GreaterThanOrEqual(a, b) + case Not(a LessThanOrEqual b) => GreaterThan(a, b) + + case Not(a Or b) => And(Not(a), Not(b)) + case Not(a And b) => Or(Not(a), Not(b)) + + case Not(Not(e)) => e + + case If(TrueLiteral, trueValue, _) => trueValue + case If(FalseLiteral, _, falseValue) => falseValue } } } From aaa2c3b628319178ca1f3f68966ff253c2de49cb Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 11 Jan 2016 19:59:15 -0800 Subject: [PATCH 1531/2089] [SPARK-11823] Ignores HiveThriftBinaryServerSuite's test jdbc cancel https://issues.apache.org/jira/browse/SPARK-11823 This test often hangs and times out, leaving hanging processes. Let's ignore it for now and improve the test. Author: Yin Huai Closes #10715 from yhuai/SPARK-11823-ignore. --- .../spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index e598284ab22f8..ba3b26e1b7d49 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -347,7 +347,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { ) } - test("test jdbc cancel") { + // This test often hangs and then times out, leaving the hanging processes. + // Let's ignore it and improve the test. + ignore("test jdbc cancel") { withJdbcStatement { statement => val queries = Seq( "DROP TABLE IF EXISTS test_map", From 39ae04e6b714e085a1341aa84d8fc5fc827d5f35 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 11 Jan 2016 21:06:22 -0800 Subject: [PATCH 1532/2089] [SPARK-12692][BUILD][STREAMING] Scala style: Fix the style violation (Space before "," or ":") Fix the style violation (space before , and :). This PR is a followup for #10643. Author: Kousuke Saruta Closes #10685 from sarutak/SPARK-12692-followup-streaming. --- .../clickstream/PageViewGenerator.scala | 14 ++++---- .../spark/streaming/flume/sink/Logging.scala | 8 ++--- .../streaming/flume/FlumeInputDStream.scala | 18 +++++----- .../kafka/DirectKafkaInputDStream.scala | 4 +-- .../streaming/kafka/KafkaInputDStream.scala | 4 +-- .../kafka/ReliableKafkaStreamSuite.scala | 2 +- .../streaming/mqtt/MQTTInputDStream.scala | 4 +-- .../twitter/TwitterInputDStream.scala | 4 +-- project/MimaExcludes.scala | 12 +++++++ .../apache/spark/streaming/Checkpoint.scala | 12 +++---- .../spark/streaming/StreamingContext.scala | 36 +++++++++---------- .../streaming/api/java/JavaDStreamLike.scala | 2 +- .../dstream/ConstantInputDStream.scala | 4 +-- .../dstream/DStreamCheckpointData.scala | 2 +- .../streaming/dstream/FileInputDStream.scala | 18 +++++----- .../streaming/dstream/InputDStream.scala | 6 ++-- .../dstream/PluggableInputDStream.scala | 4 +-- .../streaming/dstream/RawInputDStream.scala | 4 +-- .../dstream/ReceiverInputDStream.scala | 6 ++-- .../dstream/SocketInputDStream.scala | 4 +-- .../streaming/dstream/StateDStream.scala | 6 ++-- .../spark/streaming/receiver/Receiver.scala | 8 ++--- .../streaming/BasicOperationsSuite.scala | 2 +- .../spark/streaming/CheckpointSuite.scala | 2 +- .../spark/streaming/MasterFailureTest.scala | 4 +-- .../spark/streaming/StateMapSuite.scala | 2 +- .../streaming/StreamingContextSuite.scala | 2 +- .../spark/streaming/TestSuiteBase.scala | 4 +-- .../scheduler/ReceiverTrackerSuite.scala | 4 +-- .../streaming/util/WriteAheadLogSuite.scala | 2 +- 30 files changed, 108 insertions(+), 96 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index ce1a62060ef6c..50216b9bd40f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -23,15 +23,15 @@ import java.net.ServerSocket import java.util.Random /** Represents a page view on a website with associated dimension data. */ -class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) +class PageView(val url: String, val status: Int, val zipCode: Int, val userID: Int) extends Serializable { - override def toString() : String = { + override def toString(): String = { "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) } } object PageView extends Serializable { - def fromString(in : String) : PageView = { + def fromString(in: String): PageView = { val parts = in.split("\t") new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) } @@ -58,9 +58,9 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01) : _*) + val userID = Map((1 to 100).map(_ -> .01): _*) - def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { + def pickFromDistribution[T](inputMap: Map[T, Double]): T = { val rand = new Random().nextDouble() var total = 0.0 for ((item, prob) <- inputMap) { @@ -72,7 +72,7 @@ object PageViewGenerator { inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 } - def getNextClickEvent() : String = { + def getNextClickEvent(): String = { val id = pickFromDistribution(userID) val page = pickFromDistribution(pages) val status = pickFromDistribution(httpStatus) @@ -80,7 +80,7 @@ object PageViewGenerator { new PageView(page, status, zipCode, id).toString() } - def main(args : Array[String]) { + def main(args: Array[String]) { if (args.length != 2) { System.err.println("Usage: PageViewGenerator ") System.exit(1) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index d87b86932dd41..aa530a7121bd0 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -26,20 +26,20 @@ import org.slf4j.{Logger, LoggerFactory} private[sink] trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient private var log_ : Logger = null + @transient private var _log: Logger = null // Method to get or create the logger for this object protected def log: Logger = { - if (log_ == null) { + if (_log == null) { initializeIfNecessary() var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects if (className.endsWith("$")) { className = className.substring(0, className.length - 1) } - log_ = LoggerFactory.getLogger(className) + _log = LoggerFactory.getLogger(className) } - log_ + _log } // Log methods that take only a String diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 1bfa35a8b3d1d..74bd0165c6209 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -41,12 +41,12 @@ import org.apache.spark.util.Utils private[streaming] class FlumeInputDStream[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, host: String, port: Int, storageLevel: StorageLevel, enableDecompression: Boolean -) extends ReceiverInputDStream[SparkFlumeEvent](ssc_) { +) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { override def getReceiver(): Receiver[SparkFlumeEvent] = { new FlumeReceiver(host, port, storageLevel, enableDecompression) @@ -60,7 +60,7 @@ class FlumeInputDStream[T: ClassTag]( * which are not serializable. */ class SparkFlumeEvent() extends Externalizable { - var event : AvroFlumeEvent = new AvroFlumeEvent() + var event: AvroFlumeEvent = new AvroFlumeEvent() /* De-serialize from bytes. */ def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -75,12 +75,12 @@ class SparkFlumeEvent() extends Externalizable { val keyLength = in.readInt() val keyBuff = new Array[Byte](keyLength) in.readFully(keyBuff) - val key : String = Utils.deserialize(keyBuff) + val key: String = Utils.deserialize(keyBuff) val valLength = in.readInt() val valBuff = new Array[Byte](valLength) in.readFully(valBuff) - val value : String = Utils.deserialize(valBuff) + val value: String = Utils.deserialize(valBuff) headers.put(key, value) } @@ -109,7 +109,7 @@ class SparkFlumeEvent() extends Externalizable { } private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + def fromAvroFlumeEvent(in: AvroFlumeEvent): SparkFlumeEvent = { val event = new SparkFlumeEvent event.event = in event @@ -118,13 +118,13 @@ private[streaming] object SparkFlumeEvent { /** A simple server that implements Flume's Avro protocol. */ private[streaming] -class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { - override def append(event : AvroFlumeEvent) : Status = { +class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol { + override def append(event: AvroFlumeEvent): Status = { receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event)) Status.OK } - override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + override def appendBatch(events: java.util.List[AvroFlumeEvent]): Status = { events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 8a087474d3169..54d8c8b03f206 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -58,11 +58,11 @@ class DirectKafkaInputDStream[ U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R - ) extends InputDStream[R](ssc_) with Logging { + ) extends InputDStream[R](_ssc) with Logging { val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 67f2360896b16..89d1811c99971 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -48,12 +48,12 @@ class KafkaInputDStream[ V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], useReliableReceiver: Boolean, storageLevel: StorageLevel - ) extends ReceiverInputDStream[(K, V)](ssc_) with Logging { + ) extends ReceiverInputDStream[(K, V)](_ssc) with Logging { def getReceiver(): Receiver[(K, V)] = { if (!useReliableReceiver) { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 80e2df62de3fe..7b9aee39ffb76 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -50,7 +50,7 @@ class ReliableKafkaStreamSuite extends SparkFunSuite private var ssc: StreamingContext = _ private var tempDirectory: File = null - override def beforeAll() : Unit = { + override def beforeAll(): Unit = { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 116c170489e96..079bd8a9a87ea 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -38,11 +38,11 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class MQTTInputDStream( - ssc_ : StreamingContext, + _ssc: StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel - ) extends ReceiverInputDStream[String](ssc_) { + ) extends ReceiverInputDStream[String](_ssc) { private[streaming] override def name: String = s"MQTT stream [$id]" diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index a48eec70b9f78..bdd57fdde3b89 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -39,11 +39,11 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class TwitterInputDStream( - ssc_ : StreamingContext, + _ssc: StreamingContext, twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel - ) extends ReceiverInputDStream[Status](ssc_) { + ) extends ReceiverInputDStream[Status](_ssc) { private def createOAuthAuthorization(): Authorization = { new OAuthAuthorization(new ConfigurationBuilder().build()) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0d5f938d9ef5c..4206d1fada421 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -135,6 +135,18 @@ object MimaExcludes { ) ++ Seq( // SPARK-12510 Refactor ActorReceiver to support Java ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") + ) ++ Seq( + // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") ) case v if v.startsWith("1.6") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 86f01d2168729..298cdc05acfa9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -183,7 +183,7 @@ class CheckpointWriter( val executor = Executors.newFixedThreadPool(1) val compressionCodec = CompressionCodec.createCodec(conf) private var stopped = false - private var fs_ : FileSystem = _ + private var _fs: FileSystem = _ @volatile private var latestCheckpointTime: Time = null @@ -298,12 +298,12 @@ class CheckpointWriter( } private def fs = synchronized { - if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf) - fs_ + if (_fs == null) _fs = new Path(checkpointDir).getFileSystem(hadoopConf) + _fs } private def reset() = synchronized { - fs_ = null + _fs = null } } @@ -370,8 +370,8 @@ object CheckpointReader extends Logging { } private[streaming] -class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) - extends ObjectInputStream(inputStream_) { +class ObjectInputStreamWithLoader(_inputStream: InputStream, loader: ClassLoader) + extends ObjectInputStream(_inputStream) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ba509a1030af7..157ee92fd71b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -58,9 +58,9 @@ import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookMan * of the context by `stop()` or by an exception. */ class StreamingContext private[streaming] ( - sc_ : SparkContext, - cp_ : Checkpoint, - batchDur_ : Duration + _sc: SparkContext, + _cp: Checkpoint, + _batchDur: Duration ) extends Logging { /** @@ -126,18 +126,18 @@ class StreamingContext private[streaming] ( } - if (sc_ == null && cp_ == null) { + if (_sc == null && _cp == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") } - private[streaming] val isCheckpointPresent = (cp_ != null) + private[streaming] val isCheckpointPresent = (_cp != null) private[streaming] val sc: SparkContext = { - if (sc_ != null) { - sc_ + if (_sc != null) { + _sc } else if (isCheckpointPresent) { - SparkContext.getOrCreate(cp_.createSparkConf()) + SparkContext.getOrCreate(_cp.createSparkConf()) } else { throw new SparkException("Cannot create StreamingContext without a SparkContext") } @@ -154,13 +154,13 @@ class StreamingContext private[streaming] ( private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) - cp_.graph.restoreCheckpointData() - cp_.graph + _cp.graph.setContext(this) + _cp.graph.restoreCheckpointData() + _cp.graph } else { - require(batchDur_ != null, "Batch duration for StreamingContext cannot be null") + require(_batchDur != null, "Batch duration for StreamingContext cannot be null") val newGraph = new DStreamGraph() - newGraph.setBatchDuration(batchDur_) + newGraph.setBatchDuration(_batchDur) newGraph } } @@ -169,15 +169,15 @@ class StreamingContext private[streaming] ( private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { - sc.setCheckpointDir(cp_.checkpointDir) - cp_.checkpointDir + sc.setCheckpointDir(_cp.checkpointDir) + _cp.checkpointDir } else { null } } private[streaming] val checkpointDuration: Duration = { - if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration + if (isCheckpointPresent) _cp.checkpointDuration else graph.batchDuration } private[streaming] val scheduler = new JobScheduler(this) @@ -246,7 +246,7 @@ class StreamingContext private[streaming] ( } private[streaming] def initialCheckpoint: Checkpoint = { - if (isCheckpointPresent) cp_ else null + if (isCheckpointPresent) _cp else null } private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() @@ -460,7 +460,7 @@ class StreamingContext private[streaming] ( def binaryRecordsStream( directory: String, recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { - val conf = sc_.hadoopConfiguration + val conf = _sc.hadoopConfiguration conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 733147f63ea2e..a791a474c673d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -101,7 +101,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of elements in a window over this DStream. windowDuration and slideDuration are as defined in * the window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[jl.Long] = { + def countByWindow(windowDuration: Duration, slideDuration: Duration): JavaDStream[jl.Long] = { dstream.countByWindow(windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index 695384deb32d7..b5f86fe7794fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -25,8 +25,8 @@ import org.apache.spark.streaming.{StreamingContext, Time} /** * An input stream that always returns the same RDD on each timestep. Useful for testing. */ -class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc_) { +class ConstantInputDStream[T: ClassTag](_ssc: StreamingContext, rdd: RDD[T]) + extends InputDStream[T](_ssc) { require(rdd != null, "parameter rdd null is illegal, which will lead to NPE in the following transformation") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 3eff174c2b66c..a9ce1131ce0c1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -39,7 +39,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) // in that batch's checkpoint data @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time] - @transient private var fileSystem : FileSystem = null + @transient private var fileSystem: FileSystem = null protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]] /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index cb5b1f252e90c..1c2325409b53e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -73,13 +73,13 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - ssc_ : StreamingContext, + _ssc: StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, conf: Option[Configuration] = None) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) - extends InputDStream[(K, V)](ssc_) { + extends InputDStream[(K, V)](_ssc) { private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) @@ -128,8 +128,8 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Timestamp of the last round of finding files @transient private var lastNewFileFindingTime = 0L - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null + @transient private var _path: Path = null + @transient private var _fs: FileSystem = null override def start() { } @@ -289,17 +289,17 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( } private def directoryPath: Path = { - if (path_ == null) path_ = new Path(directory) - path_ + if (_path == null) _path = new Path(directory) + _path } private def fs: FileSystem = { - if (fs_ == null) fs_ = directoryPath.getFileSystem(ssc.sparkContext.hadoopConfiguration) - fs_ + if (_fs == null) _fs = directoryPath.getFileSystem(ssc.sparkContext.hadoopConfiguration) + _fs } private def reset() { - fs_ = null + _fs = null } @throws(classOf[IOException]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d60f418e5c4de..76f6230f36226 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -38,10 +38,10 @@ import org.apache.spark.util.Utils * that requires running a receiver on the worker nodes, use * [[org.apache.spark.streaming.dstream.ReceiverInputDStream]] as the parent class. * - * @param ssc_ Streaming context that will execute this input stream + * @param _ssc Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) - extends DStream[T](ssc_) { +abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext) + extends DStream[T](_ssc) { private[streaming] var lastValidTime: Time = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 2442e4c01a0c0..e003ddb96c860 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -24,8 +24,8 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - ssc_ : StreamingContext, - receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { + _ssc: StreamingContext, + receiver: Receiver[T]) extends ReceiverInputDStream[T](_ssc) { def getReceiver(): Receiver[T] = { receiver diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index ac73dca05a674..409c565380f06 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -38,11 +38,11 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, host: String, port: Int, storageLevel: StorageLevel - ) extends ReceiverInputDStream[T](ssc_ ) with Logging { + ) extends ReceiverInputDStream[T](_ssc) with Logging { def getReceiver(): Receiver[T] = { new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[Receiver[T]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 565b137228d00..49d8f14f4c390 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -35,11 +35,11 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils * define [[getReceiver]] function that gets the receiver object of type * [[org.apache.spark.streaming.receiver.Receiver]] that will be sent * to the workers to receive data. - * @param ssc_ Streaming context that will execute this input stream + * @param _ssc Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) - extends InputDStream[T](ssc_) { +abstract class ReceiverInputDStream[T: ClassTag](_ssc: StreamingContext) + extends InputDStream[T](_ssc) { /** * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index e70fc87c39d95..441477479167a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -31,12 +31,12 @@ import org.apache.spark.util.NextIterator private[streaming] class SocketInputDStream[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], storageLevel: StorageLevel - ) extends ReceiverInputDStream[T](ssc_) { + ) extends ReceiverInputDStream[T](_ssc) { def getReceiver(): Receiver[T] = { new SocketReceiver(host, port, bytesToObjects, storageLevel) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index ebbe139a2cdf8..fedffb23952a4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -31,7 +31,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, preservePartitioning: Boolean, - initialRDD : Option[RDD[(K, S)]] + initialRDD: Option[RDD[(K, S)]] ) extends DStream[(K, S)](parent.ssc) { super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -43,7 +43,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( override val mustCheckpoint = true private [this] def computeUsingPreviousRDD ( - parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = { + parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = { // Define the function for the mapPartition operation on cogrouped RDD; // first map the cogrouped tuple to tuples of required type, // and then apply the update function @@ -98,7 +98,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val finalFunc = (iterator : Iterator[(K, Iterable[V])]) => { + val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => { updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None))) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 639f4259e2e73..3376cd557d72f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -108,7 +108,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable def onStop() /** Override this to specify a preferred location (hostname). */ - def preferredLocation : Option[String] = None + def preferredLocation: Option[String] = None /** * Store a single item of received data to Spark's memory. @@ -257,11 +257,11 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - @transient private var _supervisor : ReceiverSupervisor = null + @transient private var _supervisor: ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ - private[streaming] def setReceiverId(id_ : Int) { - id = id_ + private[streaming] def setReceiverId(_id: Int) { + id = _id } /** Attach Network Receiver executor to this receiver. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 9d296c6d3ef8b..25e7ae8262a5f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -186,7 +186,7 @@ class BasicOperationsSuite extends TestSuiteBase { val output = Seq(1 to 8, 101 to 108, 201 to 208) testOperation( input, - (s: DStream[Int]) => s.union(s.map(_ + 4)) , + (s: DStream[Int]) => s.union(s.map(_ + 4)), output ) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 4d04138da01f7..4a6b91fbc745e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSy * A input stream that records the times of restore() invoked */ private[streaming] -class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) { +class CheckpointInputDStream(_ssc: StreamingContext) extends InputDStream[Int](_ssc) { protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData override def start(): Unit = { } override def stop(): Unit = { } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 4e56dfbd424b0..7bbbdebd9b19f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -200,12 +200,12 @@ object MasterFailureTest extends Logging { * the last expected output is generated. Finally, return */ private def runStreams[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, lastExpectedOutput: T, maxTimeToRun: Long ): Seq[T] = { - var ssc = ssc_ + var ssc = _ssc var totalTimeRan = 0L var isLastOutputGenerated = false var isTimedOut = false diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index da0430e263b5f..7a76cafc9a11c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -280,7 +280,7 @@ class StateMapSuite extends SparkFunSuite { testSerialization(new KryoSerializer(conf), map, msg) } - private def testSerialization[T : ClassTag]( + private def testSerialization[T: ClassTag]( serializer: Serializer, map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 0ae4c45988032..197b3d143995a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -896,7 +896,7 @@ object SlowTestReceiver { package object testPackage extends Assertions { def test() { val conf = new SparkConf().setMaster("local").setAppName("CreationSite test") - val ssc = new StreamingContext(conf , Milliseconds(100)) + val ssc = new StreamingContext(conf, Milliseconds(100)) try { val inputStream = ssc.receiverStream(new TestReceiver) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 54eff2b214290..239b10894ad2c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -58,8 +58,8 @@ private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputD * replayable, reliable message queue like Kafka. It requires a sequence as input, and * returns the i_th element at the i_th batch unde manual clock. */ -class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) - extends InputDStream[T](ssc_) { +class TestInputStream[T: ClassTag](_ssc: StreamingContext, input: Seq[Seq[T]], numPartitions: Int) + extends InputDStream[T](_ssc) { def start() {} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 3bd8d086abf7f..b67189fbd7f03 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -107,8 +107,8 @@ class ReceiverTrackerSuite extends TestSuiteBase { } /** An input DStream with for testing rate controlling */ -private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext) - extends ReceiverInputDStream[Int](ssc_) { +private[streaming] class RateTestInputDStream(@transient _ssc: StreamingContext) + extends ReceiverInputDStream[Int](_ssc) { override def getReceiver(): Receiver[Int] = new RateTestReceiver(id) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index b5d6a24ce8dd6..734dd93cda471 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -154,7 +154,7 @@ abstract class CommonWriteAheadLogTests( // Recover old files and generate a second set of log files val dataToWrite2 = generateRandomData() manualClock.advance(100000) - writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching , + writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching, manualClock) val logFiles2 = getLogFilesInDirectory(testDir) assert(logFiles2.size > logFiles1.size) From 112abf9100f05be436e449817468c50174712c78 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 11 Jan 2016 21:37:54 -0800 Subject: [PATCH 1533/2089] [SPARK-12692][BUILD][YARN] Scala style: Fix the style violation (Space before "," or ":") Fix the style violation (space before , and :). This PR is a followup for #10643. Author: Kousuke Saruta Closes #10686 from sarutak/SPARK-12692-followup-yarn. --- .../org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index e286aed9f9781..272f1299e0ea9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -357,7 +357,7 @@ object YarnSparkHadoopUtil { * * @return The correct OOM Error handler JVM option, platform dependent. */ - def getOutOfMemoryErrorArgument : String = { + def getOutOfMemoryErrorArgument: String = { if (Utils.isWindows) { escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p") } else { From 8cfa218f4f1b05f4d076ec15dd0a033ad3e4500d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 12 Jan 2016 00:51:00 -0800 Subject: [PATCH 1534/2089] [SPARK-12692][BUILD][SQL] Scala style: Fix the style violation (Space before "," or ":") Fix the style violation (space before , and :). This PR is a followup for #10643. Author: Kousuke Saruta Closes #10718 from sarutak/SPARK-12692-followup-sql. --- scalastyle-config.xml | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 6 ++-- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 6 ++-- .../catalyst/analysis/FunctionRegistry.scala | 4 +-- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 4 ++- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../spark/sql/catalyst/encoders/package.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/stringExpressions.scala | 6 ++-- .../plans/logical/basicOperators.scala | 6 ++-- .../sql/catalyst/util/NumberConverter.scala | 2 +- .../apache/spark/sql/types/ArrayType.scala | 2 +- .../org/apache/spark/sql/types/Decimal.scala | 2 ++ .../encoders/EncoderErrorMessageSuite.scala | 2 +- .../encoders/ExpressionEncoderSuite.scala | 6 ++-- .../BooleanSimplificationSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 4 ++- .../org/apache/spark/sql/DataFrame.scala | 36 +++++++++---------- .../apache/spark/sql/DataFrameHolder.scala | 2 +- .../spark/sql/DataFrameNaFunctions.scala | 8 ++--- .../apache/spark/sql/DataFrameReader.scala | 6 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 18 +++++----- .../org/apache/spark/sql/GroupedData.scala | 10 +++--- .../org/apache/spark/sql/GroupedDataset.scala | 8 ++--- .../org/apache/spark/sql/SQLContext.scala | 12 +++---- .../org/apache/spark/sql/SQLImplicits.scala | 10 +++--- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../spark/sql/execution/Queryable.scala | 2 +- .../aggregate/TypedAggregateExpression.scala | 2 +- .../datasources/SqlNewHadoopRDD.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 2 +- .../execution/joins/CartesianProduct.scala | 2 +- .../sql/execution/metric/SQLMetrics.scala | 2 +- .../sql/execution/stat/FrequentItems.scala | 4 +-- .../apache/spark/sql/expressions/Window.scala | 8 ++--- .../org/apache/spark/sql/functions.scala | 6 ++-- .../spark/sql/jdbc/AggregatedDialect.scala | 2 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 10 +++--- .../apache/spark/sql/jdbc/MySQLDialect.scala | 7 ++-- .../spark/sql/DatasetAggregatorSuite.scala | 4 +-- .../apache/spark/sql/DatasetCacheSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 26 +++++++------- .../datasources/json/JsonSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 +-- .../hive/thriftserver/ReflectionUtils.scala | 2 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +-- .../spark/sql/hive/HiveInspectors.scala | 8 ++--- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- .../spark/sql/hive/HiveInspectorSuite.scala | 4 +-- .../sql/hive/InsertIntoHiveTableSuite.scala | 4 +-- 54 files changed, 150 insertions(+), 141 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 2439a1f715aba..b873b627219f2 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -218,7 +218,7 @@ This file is divided into 3 sections: - + COLON, COMMA diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 79f723cf9b8a0..23fea0e2832a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -49,7 +49,7 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + def dataTypeFor[T: TypeTag]: DataType = dataTypeFor(localTypeOf[T]) private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { tpe match { @@ -116,7 +116,7 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = { + def constructorFor[T: TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil @@ -386,7 +386,7 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + def extractorsFor[T: TypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 2a132d8b82bef..6ec408a673c79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -203,7 +203,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(expression ~ direction.? , ",") ^^ { + ( rep1sep(expression ~ direction.?, ",") ^^ { case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8a33af8207350..d16880bc4a9c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -84,7 +84,7 @@ class Analyzer( ResolveAggregateFunctions :: DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*), + extendedResolutionRules: _*), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, @@ -110,7 +110,7 @@ class Analyzer( // Taking into account the reasonableness and the implementation complexity, // here use the CTE definition first, check table name only and ignore database name // see https://github.com/apache/spark/pull/4929#discussion_r27186638 for more info - case u : UnresolvedRelation => + case u: UnresolvedRelation => val substituted = cteRelations.get(u.tableIdentifier.table).map { relation => val withAlias = u.alias.map(Subquery(_, relation)) withAlias.getOrElse(relation) @@ -889,7 +889,7 @@ class Analyzer( _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). - case wf : WindowFunction => + case wf: WindowFunction => val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5c2aa3c06b3e7..7c3d45b1e40c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -323,13 +323,13 @@ object FunctionRegistry { } else { // Otherwise, find an ctor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { + val f = Try(tag.runtimeClass.getDeclaredConstructor(params: _*)) match { case Success(e) => e case Failure(e) => throw new AnalysisException(s"Invalid number of arguments for function $name") } - Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { + Try(f.newInstance(expressions: _*).asInstanceOf[Expression]) match { case Success(e) => e case Failure(e) => throw new AnalysisException(e.getMessage) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index dbcbd6854b474..e326ea782700c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -529,7 +529,7 @@ object HiveTypeCoercion { if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) case EqualTo(left @ BooleanType(), right @ NumericType()) => - transform(left , right) + transform(left, right) case EqualTo(left @ NumericType(), right @ BooleanType()) => transform(right, left) case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5ac1984043d87..c4dbcb7b60628 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -61,9 +61,11 @@ package object dsl { trait ImplicitOperators { def expr: Expression + // scalastyle:off whitespacebeforetoken def unary_- : Expression = UnaryMinus(expr) def unary_! : Predicate = Not(expr) def unary_~ : Expression = BitwiseNot(expr) + // scalastyle:on whitespacebeforetoken def + (other: Expression): Expression = Add(expr, other) def - (other: Expression): Expression = Subtract(expr, other) @@ -141,7 +143,7 @@ package object dsl { // Note that if we make ExpressionConversions an object rather than a trait, we can // then make this a value class to avoid the small penalty of runtime instantiation. def $(args: Any*): analysis.UnresolvedAttribute = { - analysis.UnresolvedAttribute(sc.s(args : _*)) + analysis.UnresolvedAttribute(sc.s(args: _*)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 05f746e72b498..fa4c2d93eccec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](): ExpressionEncoder[T] = { + def apply[T: TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 9e283f5eb6342..08ada1f38ba96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -27,7 +27,7 @@ package object encoders { * references from a specific schema.) This requirement allows us to preserve whether a given * object type is being bound by name or by ordinal when doing resolution. */ - private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + private[sql] def encoderFor[A: Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { case e: ExpressionEncoder[A] => e.assertUnresolved() e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d6219514b752b..4ffbfa57e726d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -164,7 +164,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns the hash for this expression. Expressions that compute the same result, even if * they differ cosmetically should return the same hash. */ - def semanticHash() : Int = { + def semanticHash(): Int = { def computeHash(e: Seq[Any]): Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var hash: Int = 17 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 931f752b4dc1a..bf41f85f79096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -46,7 +46,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) + UTF8String.concat(inputs: _*) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -99,7 +99,7 @@ case class ConcatWs(children: Seq[Expression]) case null => Iterator(null.asInstanceOf[UTF8String]) } } - UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) + UTF8String.concatWs(flatInputs.head, flatInputs.tail: _*) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -990,7 +990,7 @@ case class FormatNumber(x: Expression, d: Expression) def typeHelper(p: String): String = { x.dataType match { - case _ : DecimalType => s"""$p.toJavaBigDecimal()""" + case _: DecimalType => s"""$p.toJavaBigDecimal()""" case _ => s"$p" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 64957db6b4013..5489051e9501b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -496,7 +496,7 @@ case class MapPartitions[T, U]( /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { - def apply[T, U : Encoder]( + def apply[T, U: Encoder]( func: T => U, tEncoder: ExpressionEncoder[T], child: LogicalPlan): AppendColumns[T, U] = { @@ -522,7 +522,7 @@ case class AppendColumns[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { - def apply[K, T, U : Encoder]( + def apply[K, T, U: Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], @@ -557,7 +557,7 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[Key, Left, Right, Result : Encoder]( + def apply[Key, Left, Right, Result: Encoder]( func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 9fefc5656aac0..e4417e0955143 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -122,7 +122,7 @@ object NumberConverter { * unsigned, otherwise it is signed. * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv */ - def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 520e344361625..be7573b95d841 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -90,7 +90,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] { private[this] val elementOrdering: Ordering[Any] = elementType match { case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]] - case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => throw new IllegalArgumentException(s"Type $other does not support ordered operations") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ede..864b47a2a08aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -310,6 +310,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def remainder(that: Decimal): Decimal = this % that + // scalastyle:off whitespacebeforetoken def unary_- : Decimal = { if (decimalVal.ne(null)) { Decimal(-decimalVal, precision, scale) @@ -317,6 +318,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { Decimal(-longVal, precision, scale) } } + // scalastyle:on whitespacebeforetoken def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index 8c766ef829923..a1c4a861c610f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -98,5 +98,5 @@ class EncoderErrorMessageSuite extends SparkFunSuite { s"""array element class: "${clsName[NonEncodable]}"""")) } - private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName + private def clsName[T: ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 88c558d80a79a..67f4dc98be231 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -80,7 +80,7 @@ class JavaSerializable(val value: Int) extends Serializable { class ExpressionEncoderSuite extends SparkFunSuite { OuterScopes.outerScopes.put(getClass.getName, this) - implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + implicit def encoder[T: TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() // test flat encoders encodeDecodeTest(false, "primitive boolean") @@ -145,7 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { encoderFor(Encoders.javaSerialization[JavaSerializable])) // test product encoders - private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + private def productTest[T <: Product: ExpressionEncoder](input: T): Unit = { encodeDecodeTest(input, input.getClass.getSimpleName) } @@ -286,7 +286,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { } } - private def encodeDecodeTest[T : ExpressionEncoder]( + private def encodeDecodeTest[T: ExpressionEncoder]( input: T, testName: String): Unit = { test(s"encode/decode for $testName: $input") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 000a3b7ecb7c6..6932f185b9d62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -80,7 +80,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) - checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) + checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5), 'a < 2) checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e8c61d6e01dc3..a434d03332459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -152,7 +152,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) + def as[U: Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. @@ -171,6 +171,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { UnresolvedExtractValue(expr, lit(extraction).expr) } + // scalastyle:off whitespacebeforetoken /** * Unary minus, i.e. negate the expression. * {{{ @@ -202,6 +203,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ def unary_! : Column = withExpr { Not(expr) } + // scalastyle:on whitespacebeforetoken /** * Equality test. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 60d2f05b8605b..fac8950aee12d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -204,7 +204,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @Experimental - def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) + def as[U: Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion @@ -227,7 +227,7 @@ class DataFrame private[sql]( val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } - select(newCols : _*) + select(newCols: _*) } /** @@ -579,7 +579,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { - sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) + sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*) } /** @@ -608,7 +608,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply) : _*) + sort((sortCol +: sortCols).map(apply): _*) } /** @@ -631,7 +631,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols: _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -640,7 +640,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs: _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -720,7 +720,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)): _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -948,7 +948,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs : _*) + groupBy().agg(aggExpr, aggExprs: _*) } /** @@ -986,7 +986,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function @@ -1118,7 +1118,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val elementTypes = schema.toAttributes.map { @@ -1147,7 +1147,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) + def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil @@ -1186,7 +1186,7 @@ class DataFrame private[sql]( Column(field) } } - select(columns : _*) + select(columns: _*) } else { select(Column("*"), col.as(colName)) } @@ -1207,7 +1207,7 @@ class DataFrame private[sql]( Column(field) } } - select(columns : _*) + select(columns: _*) } else { select(Column("*"), col.as(colName, metadata)) } @@ -1231,7 +1231,7 @@ class DataFrame private[sql]( Column(col) } } - select(columns : _*) + select(columns: _*) } else { this } @@ -1244,7 +1244,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ def drop(colName: String): DataFrame = { - drop(Seq(colName) : _*) + drop(Seq(colName): _*) } /** @@ -1283,7 +1283,7 @@ class DataFrame private[sql]( val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) - select(colsAfterDrop : _*) + select(colsAfterDrop: _*) } /** @@ -1479,7 +1479,7 @@ class DataFrame private[sql]( * @group action * @since 1.6.0 */ - def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) + def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n): _*) /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. @@ -1505,7 +1505,7 @@ class DataFrame private[sql]( */ def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) + java.util.Arrays.asList(rdd.collect(): _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index 3b30337f1f877..4441a634be407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -33,5 +33,5 @@ case class DataFrameHolder private[sql](private val df: DataFrame) { // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = df - def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) + def toDF(colNames: String*): DataFrame = df.toDF(colNames: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f7be5f6b370ab..43500b09e0f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -164,7 +164,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.col(f.name) } } - df.select(projections : _*) + df.select(projections: _*) } /** @@ -191,7 +191,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.col(f.name) } } - df.select(projections : _*) + df.select(projections: _*) } /** @@ -364,7 +364,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.col(f.name) } } - df.select(projections : _*) + df.select(projections: _*) } private def fill0(values: Seq[(String, Any)]): DataFrame = { @@ -395,7 +395,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } }.getOrElse(df.col(f.name)) } - df.select(projections : _*) + df.select(projections: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d948e4894253c..1ed451d5a8bab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -203,7 +203,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { predicates: Array[String], connectionProperties: Properties): DataFrame = { val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) => - JDBCPartition(part, i) : Partition + JDBCPartition(part, i): Partition } jdbc(url, table, parts, connectionProperties) } @@ -262,7 +262,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def json(paths: String*): DataFrame = format("json").load(paths : _*) + def json(paths: String*): DataFrame = format("json").load(paths: _*) /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and @@ -355,7 +355,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.6.0 */ @scala.annotation.varargs - def text(paths: String*): DataFrame = format("text").load(paths : _*) + def text(paths: String*): DataFrame = format("text").load(paths: _*) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 42f01e9359c64..9ffb5b94b2d18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -131,7 +131,7 @@ class Dataset[T] private[sql]( * along with `alias` or `as` to rearrange or rename as required. * @since 1.6.0 */ - def as[U : Encoder]: Dataset[U] = { + def as[U: Encoder]: Dataset[U] = { new Dataset(sqlContext, queryExecution, encoderFor[U]) } @@ -318,7 +318,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + def map[U: Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) /** * (Java-specific) @@ -333,7 +333,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ - def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapPartitions[T, U]( @@ -360,7 +360,7 @@ class Dataset[T] private[sql]( * and then flattening the results. * @since 1.6.0 */ - def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = + def flatMap[U: Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) /** @@ -432,7 +432,7 @@ class Dataset[T] private[sql]( * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ - def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { + def groupBy[K: Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) @@ -566,14 +566,14 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by sampling a fraction of records. * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withPlan(Sample(0.0, fraction, withReplacement, seed, _)) /** * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { + def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { sample(withReplacement, fraction, Utils.random.nextLong) } @@ -731,7 +731,7 @@ class Dataset[T] private[sql]( * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ - def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num): _*) /** * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). @@ -786,7 +786,7 @@ class Dataset[T] private[sql]( private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) - private[sql] def withPlan[R : Encoder]( + private[sql] def withPlan[R: Encoder]( other: Dataset[_])( f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c74ef2c03541e..f5cbf013bce9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -229,7 +229,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) + aggregateNumericColumns(colNames: _*)(Average) } /** @@ -241,7 +241,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Max) + aggregateNumericColumns(colNames: _*)(Max) } /** @@ -253,7 +253,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) + aggregateNumericColumns(colNames: _*)(Average) } /** @@ -265,7 +265,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Min) + aggregateNumericColumns(colNames: _*)(Min) } /** @@ -277,7 +277,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Sum) + aggregateNumericColumns(colNames: _*)(Sum) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a819ddceb1b1b..12179367fa012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -73,7 +73,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def keyAs[L : Encoder]: GroupedDataset[L, V] = + def keyAs[L: Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], unresolvedVEncoder, @@ -110,7 +110,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( @@ -158,7 +158,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) flatMapGroups(func) } @@ -302,7 +302,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def cogroup[U, R : Encoder]( + def cogroup[U, R: Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { new Dataset[R]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e827427c19e25..61c74f83409e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -409,7 +409,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ @Experimental - def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { + def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes @@ -425,7 +425,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ @Experimental - def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = { SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes @@ -498,7 +498,7 @@ class SQLContext private[sql]( } - def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { + def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d).copy()) @@ -507,7 +507,7 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } - def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { + def createDataset[T: Encoder](data: RDD[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d)) @@ -516,7 +516,7 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } - def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -945,7 +945,7 @@ class SQLContext private[sql]( } } - // Register a succesfully instantiatd context to the singleton. This should be at the end of + // Register a successfully instantiated context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. sparkContext.addSparkListener(new SparkListener { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index ab414799f1a42..a7f7997df1a8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,7 +37,7 @@ abstract class SQLImplicits { protected def _sqlContext: SQLContext /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = ExpressionEncoder() /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() @@ -67,7 +67,7 @@ abstract class SQLImplicits { * Creates a [[Dataset]] from an RDD. * @since 1.6.0 */ - implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { + implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) } @@ -75,7 +75,7 @@ abstract class SQLImplicits { * Creates a [[Dataset]] from a local Seq. * @since 1.6.0 */ - implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { + implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(s)) } @@ -89,7 +89,7 @@ abstract class SQLImplicits { * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). * @since 1.3.0 */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + implicit def rddToDataFrameHolder[A <: Product: TypeTag](rdd: RDD[A]): DataFrameHolder = { DataFrameHolder(_sqlContext.createDataFrame(rdd)) } @@ -97,7 +97,7 @@ abstract class SQLImplicits { * Creates a DataFrame from a local Seq of Product. * @since 1.3.0 */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + implicit def localSeqToDataFrameHolder[A <: Product: TypeTag](data: Seq[A]): DataFrameHolder = { DataFrameHolder(_sqlContext.createDataFrame(data)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d912aeb70d517..a8e6a40169d81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -39,7 +39,7 @@ private[r] object SQLUtils { new JavaSparkContext(sqlCtx.sparkContext) } - def createStructType(fields : Seq[StructField]): StructType = { + def createStructType(fields: Seq[StructField]): StructType = { StructType(fields) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6b100577077c6..058d147c7d65d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -223,7 +223,7 @@ case class Exchange( new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } - protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { coordinator match { case Some(exchangeCoordinator) => val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 38263af0f7e30..bb551614779b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -71,7 +71,7 @@ private[sql] trait Queryable { private[sql] def formatString ( rows: Seq[Seq[String]], numRows: Int, - hasMoreData : Boolean, + hasMoreData: Boolean, truncate: Boolean = true): String = { val sb = new StringBuilder val numCols = schema.fieldNames.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 1df38f7ff59cd..b5ac530444b79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { - def apply[A, B : Encoder, C : Encoder]( + def apply[A, B: Encoder, C: Encoder]( aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index d45d2db62f3a9..d5e0d80076cbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -256,7 +256,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) } catch { - case e : Exception => + case e: Exception => logDebug("Failed to use InputSplit#getLocationInfo.", e) None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index fb97a03df60f4..c4b125e9d5f00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -557,7 +557,7 @@ private[parquet] object CatalystSchemaConverter { } } - private def computeMinBytesForPrecision(precision : Int) : Int = { + private def computeMinBytesForPrecision(precision: Int): Int = { var numBytes = 1 while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { numBytes += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 93d32e1fb93ae..a567457dba3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter * materialize the right RDD (in case of the right RDD is nondeterministic). */ private[spark] -class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) +class UnsafeCartesianRDD(left: RDD[UnsafeRow], right: RDD[UnsafeRow], numFieldsOfRight: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 52735c9d7f8c4..8c68d9ee0a1ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -64,7 +64,7 @@ private[sql] trait SQLMetricValue[T] extends Serializable { /** * A wrapper of Long to avoid boxing and unboxing when using Accumulator */ -private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { +private[sql] class LongSQLMetricValue(private var _value: Long) extends SQLMetricValue[Long] { def add(incr: Long): LongSQLMetricValue = { _value += incr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index a191759813de1..a4cb54e2bf2a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -94,7 +94,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) }.toArray - val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)): _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -115,7 +115,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toArray) - val resultRow = Row(justItems : _*) + val resultRow = Row(justItems: _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index e9b60841fc28c..05a9f377b9897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -44,7 +44,7 @@ object Window { */ @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { - spec.partitionBy(colName, colNames : _*) + spec.partitionBy(colName, colNames: _*) } /** @@ -53,7 +53,7 @@ object Window { */ @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { - spec.partitionBy(cols : _*) + spec.partitionBy(cols: _*) } /** @@ -62,7 +62,7 @@ object Window { */ @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { - spec.orderBy(colName, colNames : _*) + spec.orderBy(colName, colNames: _*) } /** @@ -71,7 +71,7 @@ object Window { */ @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { - spec.orderBy(cols : _*) + spec.orderBy(cols: _*) } private def spec: WindowSpec = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 592d79df3109a..1ac62883a68ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -306,7 +306,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) + countDistinct(Column(columnName), columnNames.map(Column.apply): _*) /** * Aggregate function: returns the first value in a group. @@ -768,7 +768,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { - array((colName +: colNames).map(col) : _*) + array((colName +: colNames).map(col): _*) } /** @@ -977,7 +977,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { - struct((colName +: colNames).map(col) : _*) + struct((colName +: colNames).map(col): _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 467d8d62d1b7f..d2c31d6e04107 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -30,7 +30,7 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect require(dialects.nonEmpty) - override def canHandle(url : String): Boolean = + override def canHandle(url: String): Boolean = dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index ca2d909e2cccc..8d58321d4887d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi -case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) +case class JdbcType(databaseTypeDefinition: String, jdbcNullType: Int) /** * :: DeveloperApi :: @@ -60,7 +60,7 @@ abstract class JdbcDialect extends Serializable { * @return True if the dialect can be applied on the given jdbc url. * @throws NullPointerException if the url is null. */ - def canHandle(url : String): Boolean + def canHandle(url: String): Boolean /** * Get the custom datatype mapping for the given jdbc meta information. @@ -130,7 +130,7 @@ object JdbcDialects { * * @param dialect The new dialect. */ - def registerDialect(dialect: JdbcDialect) : Unit = { + def registerDialect(dialect: JdbcDialect): Unit = { dialects = dialect :: dialects.filterNot(_ == dialect) } @@ -139,7 +139,7 @@ object JdbcDialects { * * @param dialect The jdbc dialect. */ - def unregisterDialect(dialect : JdbcDialect) : Unit = { + def unregisterDialect(dialect: JdbcDialect): Unit = { dialects = dialects.filterNot(_ == dialect) } @@ -169,5 +169,5 @@ object JdbcDialects { * NOOP dialect object, always returning the neutral element. */ private object NoopDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = true + override def canHandle(url: String): Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index e1717049f383d..faae54e605c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -23,10 +23,13 @@ import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuil private case object MySQLDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as // byte arrays instead of longs. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3258f3782d8cc..f952fc07fd387 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext /** An `Aggregator` that adds up any numeric type returned by the given function. */ -class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { +class SumOf[I, N: Numeric](f: I => N) extends Aggregator[I, N, N] { val numeric = implicitly[Numeric[N]] override def zero: N = numeric.zero @@ -113,7 +113,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + def sum[I, N: Numeric: Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn test("typed aggregation: TypedAggregator") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 3a283a4e1f610..848f1af65508b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -27,7 +27,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("persist and unpersist") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) val cached = ds.cache() // count triggers the caching action. It should not throw. cached.count() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 53b5f45c2d4a6..a3ed2e06165ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -30,7 +30,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("toDS") { - val data = Seq(("a", 1) , ("b", 2), ("c", 3)) + val data = Seq(("a", 1), ("b", 2), ("c", 3)) checkAnswer( data.toDS(), data: _*) @@ -87,7 +87,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("as case class / collect") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] checkAnswer( ds, ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) @@ -105,7 +105,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("map") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.map(v => (v._1, v._2 + 1)), ("a", 2), ("b", 3), ("c", 4)) @@ -124,23 +124,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select(expr("_2 + 1").as[Int]), 2, 3, 4) } test("select 2") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], - expr("_2").as[Int]) : Dataset[(String, Int)], + expr("_2").as[Int]): Dataset[(String, Int)], ("a", 1), ("b", 2), ("c", 3)) } test("select 2, primitive and tuple") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -149,7 +149,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -158,7 +158,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class, fields reordered") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDecoding( ds.select( expr("_1").as[String], @@ -167,28 +167,28 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("filter") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.filter(_._1 == "b"), ("b", 2)) } test("foreach") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreach(v => acc += v._2) assert(acc.value == 6) } test("foreachPartition") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreachPartition(_.foreach(v => acc += v._2)) assert(acc.value == 6) } test("reduce") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4ab148065a476..860e07c68cef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -206,7 +206,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructType( StructField("f1", IntegerType, true) :: StructField("f2", IntegerType, true) :: Nil), - StructType(StructField("f1", LongType, true) :: Nil) , + StructType(StructField("f1", LongType, true) :: Nil), StructType( StructField("f1", LongType, true) :: StructField("f2", IntegerType, true) :: Nil)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index ab48e971b507a..f2e0a868f4b1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -72,7 +72,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { /** * Writes `data` to a Parquet file, reads it back and check file contents. */ - protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { + protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 1fa22e2933318..984e3fbc05e48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -46,7 +46,7 @@ class JDBCSuite extends SparkFunSuite val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) @@ -489,7 +489,7 @@ class JDBCSuite extends SparkFunSuite test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2:") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala index 599294dfbb7d7..d1d8a68f6d196 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver private[hive] object ReflectionUtils { - def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + def setSuperField(obj: Object, fieldName: String, fieldValue: Object) { setAncestorField(obj, 1, fieldName, fieldValue) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 03bc830df2034..9f9efe33e12a3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -325,7 +325,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (ret != 0) { // For analysis exception, only the error is printed out to the console. rc.getException() match { - case e : AnalysisException => + case e: AnalysisException => err.println(s"""Error in query: ${e.getMessage}""") case _ => err.println(rc.getErrorMessage()) } @@ -369,7 +369,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (counter != 0) { responseMsg += s", Fetched $counter row(s)" } - console.printInfo(responseMsg , null) + console.printInfo(responseMsg, null) // Destroy the driver to release all the locks. driver.destroy() } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7a260e72eb459..c9df3c4a82c88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -657,8 +657,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) : _*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) + java.util.Arrays.asList(fields.map(f => f.name): _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)): _*)) } /** @@ -905,8 +905,8 @@ private[hive] trait HiveInspectors { getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) : _*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) + java.util.Arrays.asList(fields.map(_.name): _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo): _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 56cab1aee89df..912cd41173a2a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -181,7 +181,7 @@ private[hive] case class HiveSimpleUDF( val ret = FunctionRegistry.invoke( method, function, - conversionHelper.convertIfNecessary(inputs : _*): _*) + conversionHelper.convertIfNecessary(inputs: _*): _*) unwrap(ret, returnInspector) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3b867bbfa1817..ad28345a667d0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -118,8 +118,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) : _*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) + java.util.Arrays.asList(fields.map(f => f.name): _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)): _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index da7303c791064..40e9c9362cf5e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -154,8 +154,8 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } val expected = List( "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil, "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) From c48f2a3a5fd714ad2ff19b29337e55583988431e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 12 Jan 2016 11:50:33 +0000 Subject: [PATCH 1535/2089] [SPARK-7615][MLLIB] MLLIB Word2Vec wordVectors divided by Euclidean Norm equals to zero Cosine similarity with 0 vector should be 0 Related to https://github.com/apache/spark/pull/10152 Author: Sean Owen Closes #10696 from srowen/SPARK-7615. --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index dc5d070890d5d..dee898827f30f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -543,7 +543,12 @@ class Word2VecModel private[spark] ( val cosVec = cosineVec.map(_.toDouble) var ind = 0 while (ind < numWords) { - cosVec(ind) /= wordVecNorms(ind) + val norm = wordVecNorms(ind) + if (norm == 0.0) { + cosVec(ind) = 0.0 + } else { + cosVec(ind) /= norm + } ind += 1 } wordList.zip(cosVec) From 9c7f34af37ef328149c1d66b4689d80a1589e1cc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 12 Jan 2016 12:13:32 +0000 Subject: [PATCH 1536/2089] [SPARK-5273][MLLIB][DOCS] Improve documentation examples for LinearRegression Use a much smaller step size in LinearRegressionWithSGD MLlib examples to achieve a reasonable RMSE. Our training folks hit this exact same issue when concocting an example and had the same solution. Author: Sean Owen Closes #10675 from srowen/SPARK-5273. --- docs/mllib-linear-methods.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 20b35612cab95..aac8f7560a4f8 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -590,7 +590,8 @@ val parsedData = data.map { line => // Building the model val numIterations = 100 -val model = LinearRegressionWithSGD.train(parsedData, numIterations) +val stepSize = 0.00000001 +val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) // Evaluate model on training examples and compute training error val valuesAndPreds = parsedData.map { point => @@ -655,8 +656,9 @@ public class LinearRegression { // Building the model int numIterations = 100; + double stepSize = 0.00000001; final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); // Evaluate model on training examples and compute training error JavaRDD> valuesAndPreds = parsedData.map( @@ -706,7 +708,7 @@ data = sc.textFile("data/mllib/ridge-data/lpsa.data") parsedData = data.map(parsePoint) # Build the model -model = LinearRegressionWithSGD.train(parsedData) +model = LinearRegressionWithSGD.train(parsedData, iterations=100, step=0.00000001) # Evaluate the model on training data valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) From 9f0995bb0d0bbe5d9b15a1ca9fa18e246ff90d66 Mon Sep 17 00:00:00 2001 From: Tommy YU Date: Tue, 12 Jan 2016 13:20:04 +0000 Subject: [PATCH 1537/2089] [SPARK-12638][API DOC] Parameter explanation not very accurate for rdd function "aggregate" Currently, RDD function aggregate's parameter doesn't explain well, especially parameter "zeroValue". It's helpful to let junior scala user know that "zeroValue" attend both "seqOp" and "combOp" phase. Author: Tommy YU Closes #10587 from Wenpei/rdd_aggregate_doc. --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index de7102f5b6245..53e01a0dbfc06 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -970,6 +970,13 @@ abstract class RDD[T: ClassTag]( * apply the fold to each element sequentially in some defined ordering. For functions * that are not commutative, the result may differ from that of a fold applied to a * non-distributed collection. + * + * @param zeroValue the initial value for the accumulated result of each partition for the `op` + * operator, and also the initial value for the combine results from different + * partitions for the `op` operator - this will typically be the neutral + * element (e.g. `Nil` for list concatenation or `0` for summation) + * @param op an operator used to both accumulate results within a partition and combine results + * from different partitions */ def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks @@ -988,6 +995,13 @@ abstract class RDD[T: ClassTag]( * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are * allowed to modify and return their first argument instead of creating a new U to avoid memory * allocation. + * + * @param zeroValue the initial value for the accumulated result of each partition for the + * `seqOp` operator, and also the initial value for the combine results from + * different partitions for the `combOp` operator - this will typically be the + * neutral element (e.g. `Nil` for list concatenation or `0` for summation) + * @param seqOp an operator used to accumulate results within a partition + * @param combOp an associative operator used to combine results from different partitions */ def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope { // Clone the zero value since we will also be serializing it as part of tasks From 7e15044d9d9f9839c8d422bae71f27e855d559b4 Mon Sep 17 00:00:00 2001 From: Yucai Yu Date: Tue, 12 Jan 2016 13:23:23 +0000 Subject: [PATCH 1538/2089] [SPARK-12582][TEST] IndexShuffleBlockResolverSuite fails in windows [SPARK-12582][Test] IndexShuffleBlockResolverSuite fails in windows * IndexShuffleBlockResolverSuite fails in windows due to file is not closed. * mv IndexShuffleBlockResolverSuite.scala from "test/java" to "test/scala". https://issues.apache.org/jira/browse/SPARK-12582 Author: Yucai Yu Closes #10526 from yucai/master. --- .../sort/IndexShuffleBlockResolverSuite.scala | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) rename core/src/test/{java => scala}/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala (87%) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala similarity index 87% rename from core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index f200ff36c7dd5..d21ce73f4021e 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -19,18 +19,18 @@ package org.apache.spark.shuffle.sort import java.io.{File, FileInputStream, FileOutputStream} +import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfterEach +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkFunSuite} class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -64,12 +64,15 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } test("commit shuffle files multiple times") { - val lengths = Array[Long](10, 0, 20) val resolver = new IndexShuffleBlockResolver(conf, blockManager) + val lengths = Array[Long](10, 0, 20) val dataTmp = File.createTempFile("shuffle", null, tempDir) val out = new FileOutputStream(dataTmp) - out.write(new Array[Byte](30)) - out.close() + Utils.tryWithSafeFinally { + out.write(new Array[Byte](30)) + } { + out.close() + } resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) val dataFile = resolver.getDataFile(1, 2) @@ -77,12 +80,15 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa assert(dataFile.length() === 30) assert(!dataTmp.exists()) + val lengths2 = new Array[Long](3) val dataTmp2 = File.createTempFile("shuffle", null, tempDir) val out2 = new FileOutputStream(dataTmp2) - val lengths2 = new Array[Long](3) - out2.write(Array[Byte](1)) - out2.write(new Array[Byte](29)) - out2.close() + Utils.tryWithSafeFinally { + out2.write(Array[Byte](1)) + out2.write(new Array[Byte](29)) + } { + out2.close() + } resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) assert(lengths2.toSeq === lengths.toSeq) assert(dataFile.exists()) @@ -90,20 +96,27 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa assert(!dataTmp2.exists()) // The dataFile should be the previous one - val in = new FileInputStream(dataFile) val firstByte = new Array[Byte](1) - in.read(firstByte) + val in = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + in.read(firstByte) + } { + in.close() + } assert(firstByte(0) === 0) // remove data file dataFile.delete() + val lengths3 = Array[Long](10, 10, 15) val dataTmp3 = File.createTempFile("shuffle", null, tempDir) val out3 = new FileOutputStream(dataTmp3) - val lengths3 = Array[Long](10, 10, 15) - out3.write(Array[Byte](2)) - out3.write(new Array[Byte](34)) - out3.close() + Utils.tryWithSafeFinally { + out3.write(Array[Byte](2)) + out3.write(new Array[Byte](34)) + } { + out3.close() + } resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) @@ -111,9 +124,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa assert(!dataTmp2.exists()) // The dataFile should be the previous one - val in2 = new FileInputStream(dataFile) val firstByte2 = new Array[Byte](1) - in2.read(firstByte2) + val in2 = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + in2.read(firstByte2) + } { + in2.close() + } assert(firstByte2(0) === 2) } } From 1d8887953018b2e12b6ee47a76e50e542c836b80 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Jan 2016 10:58:57 -0800 Subject: [PATCH 1539/2089] [SPARK-12762][SQL] Add unit test for SimplifyConditionals optimization rule This pull request does a few small things: 1. Separated if simplification from BooleanSimplification and created a new rule SimplifyConditionals. In the future we can also simplify other conditional expressions here. 2. Added unit test for SimplifyConditionals. 3. Renamed SimplifyCaseConversionExpressionsSuite to SimplifyStringCaseConversionSuite Author: Reynold Xin Closes #10716 from rxin/SPARK-12762. --- .../expressions/conditionalExpressions.scala | 10 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 10 ++++ .../optimizer/CombiningLimitsSuite.scala | 3 +- .../optimizer/SimplifyConditionalSuite.scala | 50 +++++++++++++++++++ ...> SimplifyStringCaseConversionSuite.scala} | 3 +- 5 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{SimplifyCaseConversionExpressionsSuite.scala => SimplifyStringCaseConversionSuite.scala} (96%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 19da849d2bec9..379e62a26eb47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -45,7 +45,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def dataType: DataType = trueValue.dataType override def eval(input: InternalRow): Any = { - if (true == predicate.eval(input)) { + if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) } else { falseValue.eval(input) @@ -141,8 +141,8 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { } } - /** Written in imperative fashion for performance considerations. */ override def eval(input: InternalRow): Any = { + // Written in imperative fashion for performance considerations val len = branchesArr.length var i = 0 // If all branches fail and an elseVal is not provided, the whole statement @@ -389,7 +389,7 @@ case class Least(children: Seq[Expression]) extends Expression { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = + def updateEval(eval: GeneratedExpressionCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || @@ -398,6 +398,7 @@ case class Least(children: Seq[Expression]) extends Expression { ${ev.value} = ${eval.value}; } """ + } s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; @@ -447,7 +448,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = + def updateEval(eval: GeneratedExpressionCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || @@ -456,6 +457,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ${ev.value} = ${eval.value}; } """ + } s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b70bc184d0a5e..487431f8925a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -63,6 +63,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ConstantFolding, LikeSimplification, BooleanSimplification, + SimplifyConditionals, RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, @@ -608,7 +609,16 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case Not(a And b) => Or(Not(a), Not(b)) case Not(Not(e)) => e + } + } +} +/** + * Simplifies conditional expressions (if / case). + */ +object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 9fe2b2d1f48ca..87ad81db11b64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -34,7 +34,8 @@ class CombiningLimitsSuite extends PlanTest { Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, - BooleanSimplification) :: Nil + BooleanSimplification, + SimplifyConditionals) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala new file mode 100644 index 0000000000000..8e5d7ef3c9d49 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class SimplifyConditionalSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil + } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + test("simplify if") { + assertEquivalent( + If(TrueLiteral, Literal(10), Literal(20)), + Literal(10)) + + assertEquivalent( + If(FalseLiteral, Literal(10), Literal(20)), + Literal(20)) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala similarity index 96% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala index 41455221cfdc6..24413e7a2a3f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -25,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -class SimplifyCaseConversionExpressionsSuite extends PlanTest { +class SimplifyStringCaseConversionSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = From 508592b1bae3b2c88350ddfc1d909892f236ce5f Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 12 Jan 2016 11:09:28 -0800 Subject: [PATCH 1540/2089] [SPARK-9843][SQL] Make catalyst optimizer pass pluggable at runtime Let me know whether you'd like to see it in other place Author: Robert Kruszewski Closes #10210 from robert3005/feature/pluggable-optimizer. --- .../spark/sql/ExperimentalMethods.scala | 5 ++++ .../org/apache/spark/sql/SQLContext.scala | 4 +-- .../spark/sql/execution/SparkOptimizer.scala | 27 +++++++++++++++++++ .../apache/spark/sql/SQLContextSuite.scala | 12 +++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index 717709e4f9312..deed45d273c33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule /** * :: Experimental :: @@ -42,4 +44,7 @@ class ExperimentalMethods protected[sql](sqlContext: SQLContext) { @Experimental var extraStrategies: Seq[Strategy] = Nil + @Experimental + var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 61c74f83409e9..6721d9c40748b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} +import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ @@ -202,7 +202,7 @@ class SQLContext private[sql]( } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala new file mode 100644 index 0000000000000..edaf3b36aa52e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.optimizer._ + +class SparkOptimizer(val sqlContext: SQLContext) + extends Optimizer { + override def batches: Seq[Batch] = super.batches :+ Batch( + "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 1994dacfc4dfa..14b9448d260f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -18,9 +18,15 @@ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ + object DummyRule extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + test("getOrCreate instantiates SQLContext") { val sqlContext = SQLContext.getOrCreate(sc) assert(sqlContext != null, "SQLContext.getOrCreate returned null") @@ -65,4 +71,10 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ session2.sql("select myadd(1, 2)").explain() } } + + test("Catalyst optimization passes are modifiable at runtime") { + val sqlContext = SQLContext.getOrCreate(sc) + sqlContext.experimental.extraOptimizations = Seq(DummyRule) + assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule)) + } } From 0ed430e315b9a409490a3604a619321b476cb520 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Jan 2016 11:13:08 -0800 Subject: [PATCH 1541/2089] [SPARK-12768][SQL] Remove CaseKeyWhen expression This patch removes CaseKeyWhen expression and replaces it with a factory method that generates the equivalent CaseWhen. This reduces the amount of code we'd need to maintain in the future for both code generation and optimizer. Note that we introduced CaseKeyWhen to avoid duplicate evaluations of the key. This is no longer a problem because we now have common subexpression elimination. Author: Reynold Xin Closes #10722 from rxin/SPARK-12768. --- .../catalyst/analysis/HiveTypeCoercion.scala | 20 +- .../expressions/conditionalExpressions.scala | 187 ++++-------------- .../analysis/HiveTypeCoercionSuite.scala | 2 +- 3 files changed, 38 insertions(+), 171 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e326ea782700c..75c36d93108df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -638,8 +638,7 @@ object HiveTypeCoercion { */ object CaseWhenCoercion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => - logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { @@ -649,22 +648,7 @@ object HiveTypeCoercion { Seq(Cast(elseVal, commonType)) case other => other }.reduce(_ ++ _) - c match { - case _: CaseWhen => CaseWhen(castedBranches) - case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) - } - }.getOrElse(c) - - case c: CaseKeyWhen if c.childrenResolved && !c.resolved => - val maybeCommonType = - findWiderCommonType((c.key +: c.whenList).map(_.dataType)) - maybeCommonType.map { commonType => - val castedBranches = c.branches.grouped(2).map { - case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => - Seq(Cast(whenExpr, commonType), thenExpr) - case other => other - }.reduce(_ ++ _) - CaseKeyWhen(Cast(c.key, commonType), castedBranches) + CaseWhen(castedBranches) }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 379e62a26eb47..5a1462433d583 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -78,17 +78,23 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" } -trait CaseWhenLike extends Expression { +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * When a = true, returns b; when c = true, returns d; else returns e. + */ +case class CaseWhen(branches: Seq[Expression]) extends Expression { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - // element is the value for the default catch-all case (if provided). - // Hence, `branches` consists of at least two elements, and can have an odd or even length. - def branches: Seq[Expression] + override def children: Seq[Expression] = branches @transient lazy val whenList = branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq + @transient lazy val thenList = branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq + val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) // both then and else expressions should be considered. @@ -97,47 +103,26 @@ trait CaseWhenLike extends Expression { case Seq(dt1, dt2) => dt1.sameType(dt2) } - override def checkInputDataTypes(): TypeCheckResult = { - if (valueTypesEqual) { - checkTypesInternal() - } else { - TypeCheckResult.TypeCheckFailure( - "THEN and ELSE expressions should all be same type or coercible to a common type") - } - } - - protected def checkTypesInternal(): TypeCheckResult - override def dataType: DataType = thenList.head.dataType override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) } -} - -// scalastyle:off -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - override def children: Seq[Expression] = branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if (whenList.forall(_.dataType == BooleanType)) { - TypeCheckResult.TypeCheckSuccess + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } } else { - val index = whenList.indexWhere(_.dataType != BooleanType) TypeCheckResult.TypeCheckFailure( - s"WHEN expressions in CaseWhen should all be boolean type, " + - s"but the ${index + 1}th when expression's type is ${whenList(index)}") + "THEN and ELSE expressions should all be same type or coercible to a common type") } } @@ -227,125 +212,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { } } -// scalastyle:off /** * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + * When a = b, returns c; when a = d, returns e; else returns f. */ -// scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = key +: branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if ((key +: whenList).map(_.dataType).distinct.size > 1) { - TypeCheckResult.TypeCheckFailure( - "key and WHEN expressions should all be same type or coercible to a common type") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - - private def evalElse(input: InternalRow): Any = { - if (branchesArr.length % 2 == 0) { - null - } else { - branchesArr(branchesArr.length - 1).eval(input) - } - } - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: InternalRow): Any = { - val evaluatedKey = key.eval(input) - // If key is null, we can just return the else part or null if there is no else. - // If key is not null but doesn't match any when part, we need to return - // the else part or null if there is no else, according to Hive's semantics. - if (evaluatedKey != null) { - val len = branchesArr.length - var i = 0 - while (i < len - 1) { - if (evaluatedKey == branchesArr(i).eval(input)) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - } - evalElse(input) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val keyEval = key.gen(ctx) - val len = branchesArr.length - val got = ctx.freshName("got") - - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) - s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.value, cond.value)}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } - } - """ - }.mkString("\n") - - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } - """ - } else { - "" - } - - s""" - boolean $got = false; - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${keyEval.code} - if (!${keyEval.isNull}) { - $cases +object CaseKeyWhen { + def apply(key: Expression, branches: Seq[Expression]): CaseWhen = { + val newBranches = branches.zipWithIndex.map { case (expr, i) => + if (i % 2 == 0 && i != branches.size - 1) { + // If this expression is at even position, then it is either a branch condition, or + // the very last value that is the "else value". The "i != branches.size - 1" makes + // sure we are not adding an EqualTo to the "else value". + EqualTo(key, expr) + } else { + expr } - $other - """ - } - - override def toString: String = { - s"CASE $key" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } - - override def sql: String = { - val keySQL = key.sql - val branchesSQL = branches.map(_.sql) - val (cases, maybeElse) = if (branches.length % 2 == 0) { - (branchesSQL, None) - } else { - (branchesSQL.init, Some(branchesSQL.last)) } - - val head = s"CASE $keySQL " - val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" - val body = cases.grouped(2).map { - case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" - }.mkString(" ") - - head + body + tail + CaseWhen(newBranches) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 58d808c55860d..23b11af9ac087 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -299,7 +299,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - ruleTest(HiveTypeCoercion.CaseWhenCoercion, + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) From 0d543b98f3e3da5053f0476f4647a765460861f3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Jan 2016 12:56:52 -0800 Subject: [PATCH 1542/2089] Revert "[SPARK-12692][BUILD][SQL] Scala style: Fix the style violation (Space before "," or ":")" This reverts commit 8cfa218f4f1b05f4d076ec15dd0a033ad3e4500d. --- scalastyle-config.xml | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 6 ++-- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 6 ++-- .../catalyst/analysis/FunctionRegistry.scala | 4 +-- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 4 +-- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../spark/sql/catalyst/encoders/package.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/stringExpressions.scala | 6 ++-- .../plans/logical/basicOperators.scala | 6 ++-- .../sql/catalyst/util/NumberConverter.scala | 2 +- .../apache/spark/sql/types/ArrayType.scala | 2 +- .../org/apache/spark/sql/types/Decimal.scala | 2 -- .../encoders/EncoderErrorMessageSuite.scala | 2 +- .../encoders/ExpressionEncoderSuite.scala | 6 ++-- .../BooleanSimplificationSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 4 +-- .../org/apache/spark/sql/DataFrame.scala | 36 +++++++++---------- .../apache/spark/sql/DataFrameHolder.scala | 2 +- .../spark/sql/DataFrameNaFunctions.scala | 8 ++--- .../apache/spark/sql/DataFrameReader.scala | 6 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 18 +++++----- .../org/apache/spark/sql/GroupedData.scala | 10 +++--- .../org/apache/spark/sql/GroupedDataset.scala | 8 ++--- .../org/apache/spark/sql/SQLContext.scala | 12 +++---- .../org/apache/spark/sql/SQLImplicits.scala | 10 +++--- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../spark/sql/execution/Queryable.scala | 2 +- .../aggregate/TypedAggregateExpression.scala | 2 +- .../datasources/SqlNewHadoopRDD.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 2 +- .../execution/joins/CartesianProduct.scala | 2 +- .../sql/execution/metric/SQLMetrics.scala | 2 +- .../sql/execution/stat/FrequentItems.scala | 4 +-- .../apache/spark/sql/expressions/Window.scala | 8 ++--- .../org/apache/spark/sql/functions.scala | 6 ++-- .../spark/sql/jdbc/AggregatedDialect.scala | 2 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 10 +++--- .../apache/spark/sql/jdbc/MySQLDialect.scala | 7 ++-- .../spark/sql/DatasetAggregatorSuite.scala | 4 +-- .../apache/spark/sql/DatasetCacheSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 26 +++++++------- .../datasources/json/JsonSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 +-- .../hive/thriftserver/ReflectionUtils.scala | 2 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +-- .../spark/sql/hive/HiveInspectors.scala | 8 ++--- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- .../spark/sql/hive/HiveInspectorSuite.scala | 4 +-- .../sql/hive/InsertIntoHiveTableSuite.scala | 4 +-- 54 files changed, 141 insertions(+), 150 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index b873b627219f2..2439a1f715aba 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -218,7 +218,7 @@ This file is divided into 3 sections: - + COLON, COMMA diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 23fea0e2832a1..79f723cf9b8a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -49,7 +49,7 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor[T: TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { tpe match { @@ -116,7 +116,7 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T: TypeTag]: Expression = { + def constructorFor[T : TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil @@ -386,7 +386,7 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def extractorsFor[T: TypeTag](inputObject: Expression): CreateNamedStruct = { + def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 6ec408a673c79..2a132d8b82bef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -203,7 +203,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(expression ~ direction.?, ",") ^^ { + ( rep1sep(expression ~ direction.? , ",") ^^ { case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d16880bc4a9c9..8a33af8207350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -84,7 +84,7 @@ class Analyzer( ResolveAggregateFunctions :: DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ - extendedResolutionRules: _*), + extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, @@ -110,7 +110,7 @@ class Analyzer( // Taking into account the reasonableness and the implementation complexity, // here use the CTE definition first, check table name only and ignore database name // see https://github.com/apache/spark/pull/4929#discussion_r27186638 for more info - case u: UnresolvedRelation => + case u : UnresolvedRelation => val substituted = cteRelations.get(u.tableIdentifier.table).map { relation => val withAlias = u.alias.map(Subquery(_, relation)) withAlias.getOrElse(relation) @@ -889,7 +889,7 @@ class Analyzer( _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). - case wf: WindowFunction => + case wf : WindowFunction => val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7c3d45b1e40c0..5c2aa3c06b3e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -323,13 +323,13 @@ object FunctionRegistry { } else { // Otherwise, find an ctor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = Try(tag.runtimeClass.getDeclaredConstructor(params: _*)) match { + val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { case Success(e) => e case Failure(e) => throw new AnalysisException(s"Invalid number of arguments for function $name") } - Try(f.newInstance(expressions: _*).asInstanceOf[Expression]) match { + Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e case Failure(e) => throw new AnalysisException(e.getMessage) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 75c36d93108df..e9e20670817fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -529,7 +529,7 @@ object HiveTypeCoercion { if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) case EqualTo(left @ BooleanType(), right @ NumericType()) => - transform(left, right) + transform(left , right) case EqualTo(left @ NumericType(), right @ BooleanType()) => transform(right, left) case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index c4dbcb7b60628..5ac1984043d87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -61,11 +61,9 @@ package object dsl { trait ImplicitOperators { def expr: Expression - // scalastyle:off whitespacebeforetoken def unary_- : Expression = UnaryMinus(expr) def unary_! : Predicate = Not(expr) def unary_~ : Expression = BitwiseNot(expr) - // scalastyle:on whitespacebeforetoken def + (other: Expression): Expression = Add(expr, other) def - (other: Expression): Expression = Subtract(expr, other) @@ -143,7 +141,7 @@ package object dsl { // Note that if we make ExpressionConversions an object rather than a trait, we can // then make this a value class to avoid the small penalty of runtime instantiation. def $(args: Any*): analysis.UnresolvedAttribute = { - analysis.UnresolvedAttribute(sc.s(args: _*)) + analysis.UnresolvedAttribute(sc.s(args : _*)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index fa4c2d93eccec..05f746e72b498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { - def apply[T: TypeTag](): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 08ada1f38ba96..9e283f5eb6342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -27,7 +27,7 @@ package object encoders { * references from a specific schema.) This requirement allows us to preserve whether a given * object type is being bound by name or by ordinal when doing resolution. */ - private[sql] def encoderFor[A: Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { case e: ExpressionEncoder[A] => e.assertUnresolved() e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4ffbfa57e726d..d6219514b752b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -164,7 +164,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns the hash for this expression. Expressions that compute the same result, even if * they differ cosmetically should return the same hash. */ - def semanticHash(): Int = { + def semanticHash() : Int = { def computeHash(e: Seq[Any]): Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var hash: Int = 17 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bf41f85f79096..931f752b4dc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -46,7 +46,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs: _*) + UTF8String.concat(inputs : _*) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -99,7 +99,7 @@ case class ConcatWs(children: Seq[Expression]) case null => Iterator(null.asInstanceOf[UTF8String]) } } - UTF8String.concatWs(flatInputs.head, flatInputs.tail: _*) + UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -990,7 +990,7 @@ case class FormatNumber(x: Expression, d: Expression) def typeHelper(p: String): String = { x.dataType match { - case _: DecimalType => s"""$p.toJavaBigDecimal()""" + case _ : DecimalType => s"""$p.toJavaBigDecimal()""" case _ => s"$p" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5489051e9501b..64957db6b4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -496,7 +496,7 @@ case class MapPartitions[T, U]( /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { - def apply[T, U: Encoder]( + def apply[T, U : Encoder]( func: T => U, tEncoder: ExpressionEncoder[T], child: LogicalPlan): AppendColumns[T, U] = { @@ -522,7 +522,7 @@ case class AppendColumns[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { - def apply[K, T, U: Encoder]( + def apply[K, T, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], @@ -557,7 +557,7 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[Key, Left, Right, Result: Encoder]( + def apply[Key, Left, Right, Result : Encoder]( func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index e4417e0955143..9fefc5656aac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -122,7 +122,7 @@ object NumberConverter { * unsigned, otherwise it is signed. * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv */ - def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { + def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index be7573b95d841..520e344361625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -90,7 +90,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] { private[this] val elementOrdering: Ordering[Any] = elementType match { case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]] - case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => throw new IllegalArgumentException(s"Type $other does not support ordered operations") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 864b47a2a08aa..38ce1604b1ede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -310,7 +310,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { def remainder(that: Decimal): Decimal = this % that - // scalastyle:off whitespacebeforetoken def unary_- : Decimal = { if (decimalVal.ne(null)) { Decimal(-decimalVal, precision, scale) @@ -318,7 +317,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { Decimal(-longVal, precision, scale) } } - // scalastyle:on whitespacebeforetoken def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index a1c4a861c610f..8c766ef829923 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -98,5 +98,5 @@ class EncoderErrorMessageSuite extends SparkFunSuite { s"""array element class: "${clsName[NonEncodable]}"""")) } - private def clsName[T: ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 67f4dc98be231..88c558d80a79a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -80,7 +80,7 @@ class JavaSerializable(val value: Int) extends Serializable { class ExpressionEncoderSuite extends SparkFunSuite { OuterScopes.outerScopes.put(getClass.getName, this) - implicit def encoder[T: TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() // test flat encoders encodeDecodeTest(false, "primitive boolean") @@ -145,7 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { encoderFor(Encoders.javaSerialization[JavaSerializable])) // test product encoders - private def productTest[T <: Product: ExpressionEncoder](input: T): Unit = { + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { encodeDecodeTest(input, input.getClass.getSimpleName) } @@ -286,7 +286,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { } } - private def encodeDecodeTest[T: ExpressionEncoder]( + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { test(s"encode/decode for $testName: $input") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 6932f185b9d62..000a3b7ecb7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -80,7 +80,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) - checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5), 'a < 2) + checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index a434d03332459..e8c61d6e01dc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -152,7 +152,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[U: Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. @@ -171,7 +171,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { UnresolvedExtractValue(expr, lit(extraction).expr) } - // scalastyle:off whitespacebeforetoken /** * Unary minus, i.e. negate the expression. * {{{ @@ -203,7 +202,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ def unary_! : Column = withExpr { Not(expr) } - // scalastyle:on whitespacebeforetoken /** * Equality test. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fac8950aee12d..60d2f05b8605b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -204,7 +204,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @Experimental - def as[U: Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) + def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion @@ -227,7 +227,7 @@ class DataFrame private[sql]( val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } - select(newCols: _*) + select(newCols : _*) } /** @@ -579,7 +579,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { - sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*) + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) } /** @@ -608,7 +608,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply): _*) + sort((sortCol +: sortCols).map(apply) : _*) } /** @@ -631,7 +631,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols: _*) + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -640,7 +640,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs: _*) + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -720,7 +720,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)): _*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -948,7 +948,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs: _*) + groupBy().agg(aggExpr, aggExprs : _*) } /** @@ -986,7 +986,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function @@ -1118,7 +1118,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val elementTypes = schema.toAttributes.map { @@ -1147,7 +1147,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) + def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil @@ -1186,7 +1186,7 @@ class DataFrame private[sql]( Column(field) } } - select(columns: _*) + select(columns : _*) } else { select(Column("*"), col.as(colName)) } @@ -1207,7 +1207,7 @@ class DataFrame private[sql]( Column(field) } } - select(columns: _*) + select(columns : _*) } else { select(Column("*"), col.as(colName, metadata)) } @@ -1231,7 +1231,7 @@ class DataFrame private[sql]( Column(col) } } - select(columns: _*) + select(columns : _*) } else { this } @@ -1244,7 +1244,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ def drop(colName: String): DataFrame = { - drop(Seq(colName): _*) + drop(Seq(colName) : _*) } /** @@ -1283,7 +1283,7 @@ class DataFrame private[sql]( val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) - select(colsAfterDrop: _*) + select(colsAfterDrop : _*) } /** @@ -1479,7 +1479,7 @@ class DataFrame private[sql]( * @group action * @since 1.6.0 */ - def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n): _*) + def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. @@ -1505,7 +1505,7 @@ class DataFrame private[sql]( */ def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => withNewExecutionId { - java.util.Arrays.asList(rdd.collect(): _*) + java.util.Arrays.asList(rdd.collect() : _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index 4441a634be407..3b30337f1f877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -33,5 +33,5 @@ case class DataFrameHolder private[sql](private val df: DataFrame) { // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = df - def toDF(colNames: String*): DataFrame = df.toDF(colNames: _*) + def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 43500b09e0f38..f7be5f6b370ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -164,7 +164,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.col(f.name) } } - df.select(projections: _*) + df.select(projections : _*) } /** @@ -191,7 +191,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.col(f.name) } } - df.select(projections: _*) + df.select(projections : _*) } /** @@ -364,7 +364,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.col(f.name) } } - df.select(projections: _*) + df.select(projections : _*) } private def fill0(values: Seq[(String, Any)]): DataFrame = { @@ -395,7 +395,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } }.getOrElse(df.col(f.name)) } - df.select(projections: _*) + df.select(projections : _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1ed451d5a8bab..d948e4894253c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -203,7 +203,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { predicates: Array[String], connectionProperties: Properties): DataFrame = { val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) => - JDBCPartition(part, i): Partition + JDBCPartition(part, i) : Partition } jdbc(url, table, parts, connectionProperties) } @@ -262,7 +262,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def json(paths: String*): DataFrame = format("json").load(paths: _*) + def json(paths: String*): DataFrame = format("json").load(paths : _*) /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and @@ -355,7 +355,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.6.0 */ @scala.annotation.varargs - def text(paths: String*): DataFrame = format("text").load(paths: _*) + def text(paths: String*): DataFrame = format("text").load(paths : _*) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9ffb5b94b2d18..42f01e9359c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -131,7 +131,7 @@ class Dataset[T] private[sql]( * along with `alias` or `as` to rearrange or rename as required. * @since 1.6.0 */ - def as[U: Encoder]: Dataset[U] = { + def as[U : Encoder]: Dataset[U] = { new Dataset(sqlContext, queryExecution, encoderFor[U]) } @@ -318,7 +318,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def map[U: Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) /** * (Java-specific) @@ -333,7 +333,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ - def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapPartitions[T, U]( @@ -360,7 +360,7 @@ class Dataset[T] private[sql]( * and then flattening the results. * @since 1.6.0 */ - def flatMap[U: Encoder](func: T => TraversableOnce[U]): Dataset[U] = + def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) /** @@ -432,7 +432,7 @@ class Dataset[T] private[sql]( * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ - def groupBy[K: Encoder](func: T => K): GroupedDataset[K, T] = { + def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) @@ -566,14 +566,14 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by sampling a fraction of records. * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = + def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = withPlan(Sample(0.0, fraction, withReplacement, seed, _)) /** * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { + def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { sample(withReplacement, fraction, Utils.random.nextLong) } @@ -731,7 +731,7 @@ class Dataset[T] private[sql]( * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ - def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num): _*) + def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) /** * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). @@ -786,7 +786,7 @@ class Dataset[T] private[sql]( private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) - private[sql] def withPlan[R: Encoder]( + private[sql] def withPlan[R : Encoder]( other: Dataset[_])( f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f5cbf013bce9d..c74ef2c03541e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -229,7 +229,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames: _*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -241,7 +241,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames: _*)(Max) + aggregateNumericColumns(colNames : _*)(Max) } /** @@ -253,7 +253,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames: _*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -265,7 +265,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames: _*)(Min) + aggregateNumericColumns(colNames : _*)(Min) } /** @@ -277,7 +277,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames: _*)(Sum) + aggregateNumericColumns(colNames : _*)(Sum) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 12179367fa012..a819ddceb1b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -73,7 +73,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def keyAs[L: Encoder]: GroupedDataset[L, V] = + def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], unresolvedVEncoder, @@ -110,7 +110,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( @@ -158,7 +158,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) flatMapGroups(func) } @@ -302,7 +302,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def cogroup[U, R: Encoder]( + def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { new Dataset[R]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6721d9c40748b..2dd82358fbfdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -409,7 +409,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ @Experimental - def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { + def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes @@ -425,7 +425,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ @Experimental - def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = { + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes @@ -498,7 +498,7 @@ class SQLContext private[sql]( } - def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = { + def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d).copy()) @@ -507,7 +507,7 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } - def createDataset[T: Encoder](data: RDD[T]): Dataset[T] = { + def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d)) @@ -516,7 +516,7 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } - def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = { + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -945,7 +945,7 @@ class SQLContext private[sql]( } } - // Register a successfully instantiated context to the singleton. This should be at the end of + // Register a succesfully instantiatd context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. sparkContext.addSparkListener(new SparkListener { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index a7f7997df1a8b..ab414799f1a42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,7 +37,7 @@ abstract class SQLImplicits { protected def _sqlContext: SQLContext /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = ExpressionEncoder() + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() @@ -67,7 +67,7 @@ abstract class SQLImplicits { * Creates a [[Dataset]] from an RDD. * @since 1.6.0 */ - implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] = { + implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) } @@ -75,7 +75,7 @@ abstract class SQLImplicits { * Creates a [[Dataset]] from a local Seq. * @since 1.6.0 */ - implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = { + implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(s)) } @@ -89,7 +89,7 @@ abstract class SQLImplicits { * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). * @since 1.3.0 */ - implicit def rddToDataFrameHolder[A <: Product: TypeTag](rdd: RDD[A]): DataFrameHolder = { + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { DataFrameHolder(_sqlContext.createDataFrame(rdd)) } @@ -97,7 +97,7 @@ abstract class SQLImplicits { * Creates a DataFrame from a local Seq of Product. * @since 1.3.0 */ - implicit def localSeqToDataFrameHolder[A <: Product: TypeTag](data: Seq[A]): DataFrameHolder = + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { DataFrameHolder(_sqlContext.createDataFrame(data)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a8e6a40169d81..d912aeb70d517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -39,7 +39,7 @@ private[r] object SQLUtils { new JavaSparkContext(sqlCtx.sparkContext) } - def createStructType(fields: Seq[StructField]): StructType = { + def createStructType(fields : Seq[StructField]): StructType = { StructType(fields) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 058d147c7d65d..6b100577077c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -223,7 +223,7 @@ case class Exchange( new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { coordinator match { case Some(exchangeCoordinator) => val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index bb551614779b5..38263af0f7e30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -71,7 +71,7 @@ private[sql] trait Queryable { private[sql] def formatString ( rows: Seq[Seq[String]], numRows: Int, - hasMoreData: Boolean, + hasMoreData : Boolean, truncate: Boolean = true): String = { val sb = new StringBuilder val numCols = schema.fieldNames.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index b5ac530444b79..1df38f7ff59cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { - def apply[A, B: Encoder, C: Encoder]( + def apply[A, B : Encoder, C : Encoder]( aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index d5e0d80076cbe..d45d2db62f3a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -256,7 +256,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) } catch { - case e: Exception => + case e : Exception => logDebug("Failed to use InputSplit#getLocationInfo.", e) None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index c4b125e9d5f00..fb97a03df60f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -557,7 +557,7 @@ private[parquet] object CatalystSchemaConverter { } } - private def computeMinBytesForPrecision(precision: Int): Int = { + private def computeMinBytesForPrecision(precision : Int) : Int = { var numBytes = 1 while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { numBytes += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index a567457dba3c5..93d32e1fb93ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter * materialize the right RDD (in case of the right RDD is nondeterministic). */ private[spark] -class UnsafeCartesianRDD(left: RDD[UnsafeRow], right: RDD[UnsafeRow], numFieldsOfRight: Int) +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 8c68d9ee0a1ef..52735c9d7f8c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -64,7 +64,7 @@ private[sql] trait SQLMetricValue[T] extends Serializable { /** * A wrapper of Long to avoid boxing and unboxing when using Accumulator */ -private[sql] class LongSQLMetricValue(private var _value: Long) extends SQLMetricValue[Long] { +private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { def add(incr: Long): LongSQLMetricValue = { _value += incr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index a4cb54e2bf2a2..a191759813de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -94,7 +94,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) }.toArray - val freqItems = df.select(cols.map(Column(_)): _*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -115,7 +115,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toArray) - val resultRow = Row(justItems: _*) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 05a9f377b9897..e9b60841fc28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -44,7 +44,7 @@ object Window { */ @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { - spec.partitionBy(colName, colNames: _*) + spec.partitionBy(colName, colNames : _*) } /** @@ -53,7 +53,7 @@ object Window { */ @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { - spec.partitionBy(cols: _*) + spec.partitionBy(cols : _*) } /** @@ -62,7 +62,7 @@ object Window { */ @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { - spec.orderBy(colName, colNames: _*) + spec.orderBy(colName, colNames : _*) } /** @@ -71,7 +71,7 @@ object Window { */ @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { - spec.orderBy(cols: _*) + spec.orderBy(cols : _*) } private def spec: WindowSpec = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1ac62883a68ee..592d79df3109a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -306,7 +306,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - countDistinct(Column(columnName), columnNames.map(Column.apply): _*) + countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) /** * Aggregate function: returns the first value in a group. @@ -768,7 +768,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { - array((colName +: colNames).map(col): _*) + array((colName +: colNames).map(col) : _*) } /** @@ -977,7 +977,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { - struct((colName +: colNames).map(col): _*) + struct((colName +: colNames).map(col) : _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index d2c31d6e04107..467d8d62d1b7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -30,7 +30,7 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect require(dialects.nonEmpty) - override def canHandle(url: String): Boolean = + override def canHandle(url : String): Boolean = dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8d58321d4887d..ca2d909e2cccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi -case class JdbcType(databaseTypeDefinition: String, jdbcNullType: Int) +case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) /** * :: DeveloperApi :: @@ -60,7 +60,7 @@ abstract class JdbcDialect extends Serializable { * @return True if the dialect can be applied on the given jdbc url. * @throws NullPointerException if the url is null. */ - def canHandle(url: String): Boolean + def canHandle(url : String): Boolean /** * Get the custom datatype mapping for the given jdbc meta information. @@ -130,7 +130,7 @@ object JdbcDialects { * * @param dialect The new dialect. */ - def registerDialect(dialect: JdbcDialect): Unit = { + def registerDialect(dialect: JdbcDialect) : Unit = { dialects = dialect :: dialects.filterNot(_ == dialect) } @@ -139,7 +139,7 @@ object JdbcDialects { * * @param dialect The jdbc dialect. */ - def unregisterDialect(dialect: JdbcDialect): Unit = { + def unregisterDialect(dialect : JdbcDialect) : Unit = { dialects = dialects.filterNot(_ == dialect) } @@ -169,5 +169,5 @@ object JdbcDialects { * NOOP dialect object, always returning the neutral element. */ private object NoopDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = true + override def canHandle(url : String): Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index faae54e605c68..e1717049f383d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -23,13 +23,10 @@ import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuil private case object MySQLDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql") + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( - sqlType: Int, - typeName: String, - size: Int, - md: MetadataBuilder): Option[DataType] = { + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as // byte arrays instead of longs. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index f952fc07fd387..3258f3782d8cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext /** An `Aggregator` that adds up any numeric type returned by the given function. */ -class SumOf[I, N: Numeric](f: I => N) extends Aggregator[I, N, N] { +class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { val numeric = implicitly[Numeric[N]] override def zero: N = numeric.zero @@ -113,7 +113,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - def sum[I, N: Numeric: Encoder](f: I => N): TypedColumn[I, N] = + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn test("typed aggregation: TypedAggregator") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 848f1af65508b..3a283a4e1f610 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -27,7 +27,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("persist and unpersist") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) val cached = ds.cache() // count triggers the caching action. It should not throw. cached.count() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a3ed2e06165ea..53b5f45c2d4a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -30,7 +30,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("toDS") { - val data = Seq(("a", 1), ("b", 2), ("c", 3)) + val data = Seq(("a", 1) , ("b", 2), ("c", 3)) checkAnswer( data.toDS(), data: _*) @@ -87,7 +87,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("as case class / collect") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] checkAnswer( ds, ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) @@ -105,7 +105,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("map") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.map(v => (v._1, v._2 + 1)), ("a", 2), ("b", 3), ("c", 4)) @@ -124,23 +124,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select(expr("_2 + 1").as[Int]), 2, 3, 4) } test("select 2") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], - expr("_2").as[Int]): Dataset[(String, Int)], + expr("_2").as[Int]) : Dataset[(String, Int)], ("a", 1), ("b", 2), ("c", 3)) } test("select 2, primitive and tuple") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -149,7 +149,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -158,7 +158,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class, fields reordered") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkDecoding( ds.select( expr("_1").as[String], @@ -167,28 +167,28 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("filter") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.filter(_._1 == "b"), ("b", 2)) } test("foreach") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreach(v => acc += v._2) assert(acc.value == 6) } test("foreachPartition") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreachPartition(_.foreach(v => acc += v._2)) assert(acc.value == 6) } test("reduce") { - val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 860e07c68cef1..4ab148065a476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -206,7 +206,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructType( StructField("f1", IntegerType, true) :: StructField("f2", IntegerType, true) :: Nil), - StructType(StructField("f1", LongType, true) :: Nil), + StructType(StructField("f1", LongType, true) :: Nil) , StructType( StructField("f1", LongType, true) :: StructField("f2", IntegerType, true) :: Nil)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f2e0a868f4b1a..ab48e971b507a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -72,7 +72,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { /** * Writes `data` to a Parquet file, reads it back and check file contents. */ - protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { + protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 984e3fbc05e48..1fa22e2933318 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -46,7 +46,7 @@ class JDBCSuite extends SparkFunSuite val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) @@ -489,7 +489,7 @@ class JDBCSuite extends SparkFunSuite test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2:") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala index d1d8a68f6d196..599294dfbb7d7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver private[hive] object ReflectionUtils { - def setSuperField(obj: Object, fieldName: String, fieldValue: Object) { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { setAncestorField(obj, 1, fieldName, fieldValue) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 9f9efe33e12a3..03bc830df2034 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -325,7 +325,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (ret != 0) { // For analysis exception, only the error is printed out to the console. rc.getException() match { - case e: AnalysisException => + case e : AnalysisException => err.println(s"""Error in query: ${e.getMessage}""") case _ => err.println(rc.getErrorMessage()) } @@ -369,7 +369,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (counter != 0) { responseMsg += s", Fetched $counter row(s)" } - console.printInfo(responseMsg, null) + console.printInfo(responseMsg , null) // Destroy the driver to release all the locks. driver.destroy() } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c9df3c4a82c88..7a260e72eb459 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -657,8 +657,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name): _*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)): _*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) } /** @@ -905,8 +905,8 @@ private[hive] trait HiveInspectors { getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name): _*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo): _*)) + java.util.Arrays.asList(fields.map(_.name) : _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 912cd41173a2a..56cab1aee89df 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -181,7 +181,7 @@ private[hive] case class HiveSimpleUDF( val ret = FunctionRegistry.invoke( method, function, - conversionHelper.convertIfNecessary(inputs: _*): _*) + conversionHelper.convertIfNecessary(inputs : _*): _*) unwrap(ret, returnInspector) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index ad28345a667d0..3b867bbfa1817 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -118,8 +118,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name): _*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)): _*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 40e9c9362cf5e..da7303c791064 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -154,8 +154,8 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } val expected = List( "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil, - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) From 8ed5f12d2bb408bd37e4156b5f1bad9a6b8c3cb5 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 12 Jan 2016 14:19:53 -0800 Subject: [PATCH 1543/2089] [SPARK-12724] SQL generation support for persisted data source tables This PR implements SQL generation support for persisted data source tables. A new field `metastoreTableIdentifier: Option[TableIdentifier]` is added to `LogicalRelation`. When a `LogicalRelation` representing a persisted data source relation is created, this field holds the database name and table name of the relation. Author: Cheng Lian Closes #10712 from liancheng/spark-12724-datasources-sql-gen. --- .../scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../datasources/DataSourceStrategy.scala | 16 ++++++++-------- .../execution/datasources/LogicalRelation.scala | 8 +++++--- .../datasources/parquet/ParquetRelation.scala | 10 ++-------- .../spark/sql/execution/datasources/rules.scala | 16 ++++++++-------- .../datasources/parquet/ParquetFilterSuite.scala | 2 +- .../parquet/ParquetPartitionDiscoverySuite.scala | 2 +- .../spark/sql/sources/FilteredScanSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 6 ++++-- .../org/apache/spark/sql/hive/SQLBuilder.scala | 14 +++++--------- .../spark/sql/hive/execution/commands.scala | 2 +- .../spark/sql/hive/LogicalPlanToSQLSuite.scala | 10 ++++++++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- .../spark/sql/hive/orc/OrcFilterSuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 8 ++++---- .../sql/sources/hadoopFsRelationSuites.scala | 2 +- 17 files changed, 55 insertions(+), 51 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 60d2f05b8605b..91bf2f8ce4d2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1728,7 +1728,7 @@ class DataFrame private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: FileRelation, _) => + case LogicalRelation(fsBasedRelation: FileRelation, _, _) => fsBasedRelation.inputFiles case fr: FileRelation => fr.inputFiles diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1d6290e027f3d..da9320ffb61c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} */ private[sql] object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( l, projects, @@ -49,14 +49,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (requestedColumns, allPredicates, _) => toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) => pruneFilterProject( l, projects, filters, (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _)) => pruneFilterProject( l, projects, @@ -64,7 +64,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil // Scanning partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _)) if t.partitionSpec.partitionColumns.nonEmpty => // We divide the filter expressions into 3 parts val partitionColumns = AttributeSet( @@ -118,7 +118,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { ).getOrElse(scan) :: Nil // Scanning non-partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _)) => // See buildPartitionedTableScan for the reason that we need to create a shard // broadcast HadoopConf. val sharedHadoopConf = SparkHadoopUtil.get.conf @@ -130,16 +130,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters, (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil - case l @ LogicalRelation(baseRelation: TableScan, _) => + case l @ LogicalRelation(baseRelation: TableScan, _, _) => execution.PhysicalRDD.createFromDataSource( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil - case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _), + case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), part, query, overwrite, false) if part.isEmpty => execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _), part, query, overwrite, false) => + l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false) => val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append execution.ExecutedCommand(InsertIntoHadoopFsRelation(t, query, mode)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 219dae88e515d..fa97f3d7199ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -30,7 +31,8 @@ import org.apache.spark.sql.sources.BaseRelation */ case class LogicalRelation( relation: BaseRelation, - expectedOutputAttributes: Option[Seq[Attribute]] = None) + expectedOutputAttributes: Option[Seq[Attribute]] = None, + metastoreTableIdentifier: Option[TableIdentifier] = None) extends LeafNode with MultiInstanceRelation { override val output: Seq[AttributeReference] = { @@ -49,7 +51,7 @@ case class LogicalRelation( // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _) => relation == otherRelation && output == l.output + case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output case _ => false } @@ -58,7 +60,7 @@ case class LogicalRelation( } override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { - case LogicalRelation(otherRelation, _) => relation == otherRelation + case LogicalRelation(otherRelation, _, _) => relation == otherRelation case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 7754edc803d10..991a5d5aef2db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -44,9 +44,9 @@ import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -147,12 +147,6 @@ private[sql] class ParquetRelation( .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) - // If this relation is converted from a Hive metastore table, this method returns the name of the - // original Hive metastore table. - private[sql] def metastoreTableName: Option[TableIdentifier] = { - parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier) - } - private lazy val metadataCache: MetadataCache = { val meta = new MetadataCache meta.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1c773e69275db..dd3e66d8a9434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,7 +61,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // We are inserting into an InsertableRelation or HadoopFsRelation. case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _), _, child, _, _) => { + l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) => // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -70,7 +70,6 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { s"statement generates the same number of columns as its schema.") } castAndRenameChildOutput(i, l.output, child) - } } /** If necessary, cast data types and rename fields to the expected types and names. */ @@ -108,14 +107,15 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation, _), partition, query, overwrite, ifNotExists) => + l @ LogicalRelation(t: InsertableRelation, _, _), + partition, query, overwrite, ifNotExists) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") } else { // Get all input data source relations of the query. val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation, _) => src + case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(t)) { failAnalysis( @@ -126,7 +126,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } case logical.InsertIntoTable( - LogicalRelation(r: HadoopFsRelation, _), part, query, overwrite, _) => + LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) => // We need to make sure the partition columns specified by users do match partition // columns of the relation. val existingPartitionColumns = r.partitionColumns.fieldNames.toSet @@ -145,7 +145,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // Get all input data source relations of the query. val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation, _) => src + case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(r)) { failAnalysis( @@ -173,10 +173,10 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). - case l @ LogicalRelation(dest: BaseRelation, _) => + case l @ LogicalRelation(dest: BaseRelation, _, _) => // Get all input data source relations of the query. val srcRelations = c.child.collect { - case LogicalRelation(src: BaseRelation, _) => src + case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(dest)) { failAnalysis( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 587aa5fd30d2d..97c5313f0feff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -59,7 +59,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex var maybeRelation: Option[ParquetRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(relation: ParquetRelation, _)) => + case PhysicalOperation(_, filters, LogicalRelation(relation: ParquetRelation, _, _)) => maybeRelation = Some(relation) filters }.flatten.reduceLeftOption(_ && _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 0feb945fbb79a..3d1677bed4770 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -563,7 +563,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation, _) => + case LogicalRelation(relation: ParquetRelation, _, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 398b8a1a661c6..7196b6dc13394 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -317,7 +317,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic val table = caseInsensitiveContext.table("oneToTenFiltered") val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _) => r + case LogicalRelation(r, _, _) => r }.get assert( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 67228f3f3c9c9..daaa5a5709bdc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -184,7 +184,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive table.properties("spark.sql.sources.provider"), options) - LogicalRelation(resolvedRelation.relation) + LogicalRelation( + resolvedRelation.relation, + metastoreTableIdentifier = Some(TableIdentifier(in.name, Some(in.database)))) } } @@ -447,7 +449,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical @ LogicalRelation(parquetRelation: ParquetRelation, _) => + case logical @ LogicalRelation(parquetRelation: ParquetRelation, _, _) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 61e3f183bb42d..e83b4bffff857 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation /** * A builder class used to convert a resolved logical plan into a SQL query string. Note that this @@ -135,13 +135,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi rightSQL <- toSQL(right) } yield s"$leftSQL UNION ALL $rightSQL" - // ParquetRelation converted from Hive metastore table - case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => - // There seems to be a bug related to `ParquetConversions` analysis rule. The problem is - // that, the metastore database name and table name are not always propagated to converted - // `ParquetRelation` instances via data source options. Here we use subquery alias as a - // workaround. - Some(s"`$alias`") + // Persisted data source relation + case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) => + Some(s"`$database`.`$table`") case Subquery(alias, child) => toSQL(child).map(childSQL => s"($childSQL) AS $alias") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 612f01cda88ba..07a352873d087 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -216,7 +216,7 @@ case class CreateMetastoreDataSourceAsSelect( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { - case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _) => + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => if (l.relation != createdRelation.relation) { val errorDescription = s"Cannot append to table $tableName because the resolved relation does not " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 2ee8150fb80d5..0604d9f47c5da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -146,4 +146,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { ignore("distinct and non-distinct aggregation") { checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") } + + test("persisted data source relations") { + Seq("orc", "json", "parquet").foreach { format => + val tableName = s"${format}_t0" + withTable(tableName) { + sqlContext.range(10).write.format(format).saveAsTable(tableName) + checkHiveQl(s"SELECT id FROM $tableName") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 202851ae1366e..253f13c598520 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -571,7 +571,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation, _) => // OK + case LogicalRelation(p: ParquetRelation, _, _) => // OK case _ => fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 593fac2c32817..f6c687aab7a1b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -268,7 +268,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { val relation = EliminateSubQueries(catalog.lookupRelation(TableIdentifier(tableName))) relation match { - case LogicalRelation(r: ParquetRelation, _) => + case LogicalRelation(r: ParquetRelation, _, _) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 5afc7e77ab775..c94e73c4aa300 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -42,7 +42,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { var maybeRelation: Option[OrcRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: OrcRelation, _)) => + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: OrcRelation, _, _)) => maybeRelation = Some(orcRelation) filters }.flatten.reduceLeftOption(_ && _) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 2ceb836681901..ed544c638058c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -282,7 +282,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation, _) => // OK + case LogicalRelation(_: ParquetRelation, _, _) => // OK case _ => fail( "test_parquet_ctas should be converted to " + s"${classOf[ParquetRelation].getCanonicalName }") @@ -369,7 +369,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation, _) => r + case r @ LogicalRelation(_: ParquetRelation, _, _) => r }.size } } @@ -378,7 +378,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { def collectParquetRelation(df: DataFrame): ParquetRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: ParquetRelation, _) => r + case LogicalRelation(r: ParquetRelation, _, _) => r }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$plan") } @@ -428,7 +428,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation, _) => // OK + case logical @ LogicalRelation(parquetRelation: ParquetRelation, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index efbf9988ddc13..3f9ecf6965e1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -500,7 +500,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: HadoopFsRelation, _) => + case LogicalRelation(relation: HadoopFsRelation, _, _) => relation.paths.toSet }.getOrElse { fail("Expect an FSBasedRelation, but none could be found") From 4f60651cbec1b4c9cc2e6d832ace77e89a233f3a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 12 Jan 2016 14:27:05 -0800 Subject: [PATCH 1544/2089] [SPARK-12652][PYSPARK] Upgrade Py4J to 0.9.1 - [x] Upgrade Py4J to 0.9.1 - [x] SPARK-12657: Revert SPARK-12617 - [x] SPARK-12658: Revert SPARK-12511 - Still keep the change that only reading checkpoint once. This is a manual change and worth to take a look carefully. https://github.com/zsxwing/spark/commit/bfd4b5c040eb29394c3132af3c670b1a7272457c - [x] Verify no leak any more after reverting our workarounds Author: Shixiong Zhu Closes #10692 from zsxwing/py4j-0.9.1. --- LICENSE | 2 +- bin/pyspark | 2 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../apache/spark/api/python/PythonUtils.scala | 2 +- dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.9-src.zip | Bin 44846 -> 0 bytes python/lib/py4j-0.9.1-src.zip | Bin 0 -> 47035 bytes python/pyspark/streaming/context.py | 89 +----------------- python/pyspark/streaming/util.py | 3 +- sbin/spark-config.sh | 2 +- .../streaming/api/python/PythonDStream.scala | 10 -- .../org/apache/spark/deploy/yarn/Client.scala | 4 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 4 +- 18 files changed, 20 insertions(+), 112 deletions(-) delete mode 100644 python/lib/py4j-0.9-src.zip create mode 100644 python/lib/py4j-0.9.1-src.zip diff --git a/LICENSE b/LICENSE index a2f75b817ab37..9c944ac610afe 100644 --- a/LICENSE +++ b/LICENSE @@ -264,7 +264,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.9 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.9.1 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/bin/pyspark b/bin/pyspark index 5eaa17d3c2016..2ac4a8be250d6 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -67,7 +67,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.1-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a97d884f0bf39..51d6d15f66c69 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 34ecb19654f1a..3bec5debc2968 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -350,7 +350,7 @@ net.sf.py4j py4j - 0.9 + 0.9.1 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 2d97cd9a9a208..bda872746c8b8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9.1-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index cd3ff293502ae..53034a25d46ab 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -160,7 +160,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.jar +py4j-0.9.1.jar pyrolite-4.9.jar quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 0985089ccea61..a23e260641aeb 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -151,7 +151,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.jar +py4j-0.9.1.jar pyrolite-4.9.jar quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 50f062601c02b..6bedbed1e3355 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -152,7 +152,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.jar +py4j-0.9.1.jar pyrolite-4.9.jar quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2b6ca983ad65e..7bfad57b4a4a6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -158,7 +158,7 @@ pmml-agent-1.2.7.jar pmml-model-1.2.7.jar pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar -py4j-0.9.jar +py4j-0.9.1.jar pyrolite-4.9.jar quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar diff --git a/python/docs/Makefile b/python/docs/Makefile index 4cec74f057fbe..b6d24d8599cf7 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.1-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.9-src.zip b/python/lib/py4j-0.9-src.zip deleted file mode 100644 index dace2d0fe3b0bd01d24c07e7747aa68980896598..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 44846 zcma&NbC7RMmnB@bZQZiXTej``m2KO$ZR3`0+tw}Hw!i1?nVyb!V!r6kh?8;t$h~q$ zWUP~CuUx4h4GIPW^shER_a^oqFaNs(2f_og_h7YTP*s5j0)-1{R1H+PP<3&K1p)>; z1qA~7k51v=!j<6ljnV!I|JQ@~pW&`1j!qVKw)FNM{|D$_VE-48p5Ff+r>956OGLp) z&qV?AFYtdsLH$Q=H}z?rs_;NSApc~E^`GFD2CfErMt0WLCPvQxOycz4%yOk;W4FPH z@_nt>@6;=yy(-c1v4>=hF=o2e`IyKWz6Fl4B~)U=twtlwrzp2i^W3{Lbt_t|i$@{1 zL8|~NYU60inT^etOk}AWKoqwgZ_|N+T?fW zO(QMj{-f3*IAh)FXXfAyDRL6Yj1s$KTe46Ci4o?7VJuxsV-e4XHESH1Hs#(eowgx| z{F=LJ9qEBGlgR;FRwT9_nP~6g8Qw`C<{Zqu?n837_qq|o%kv9&4rouH;yvgy&1_V= z#E{{%?}0kWQGGGh)V+0op?M3nT6;JN21BZ*Zk;O5V6tfPP{^1EiC1)-Noo}gGF4>k zK5cVaDBCIHYj!mZ!ZYW=B-vWQXy|?~O@f@r?Bg0lnl~j3DcwLU2KlN|T$9evV?F;k zhP3YMuY%B9J7Dh8eT>%=o%_PEeu#R!TG;-F?Am3DOOVE#f<~Lx!{h11Xe+AwU(%%f8XNv?kckCb{U?=}`_9Jv@W! ze0GMI7kRpf_rDj7`Zh6t=ZHz)7F&*kbJE4scr<`?9`949>PJz^rO}X<5rM%=re95~ zg{4bw!^ z7m!N2g3OnuHK}r&a0OFWVaWyZ!9`hD1Pe_0=!0Ah+8&dT391c%sR?HG8%oxS z)MSYgf)bSUO$4OJEw+Uh3Pc*O7=h2u7c+4xe(+|J)=Nx3-YNer?qzE=5gYIgM#d@e zCSD`H*pI6ra~~1I#vNrmAkB>!_XLUJ&+2ai0@h(U_|3#8jr;3>8FQ5}_t%wW;hN|& zw&ybzoRptgBz3XGdLq_S6-po#bV*%Y#jP0!7X(!7D{vNn`ZK3$g27|F7BVh2IrTfM zk_NMsJt@#sG6->Qzn{#e9M~T5BzcnR49I-%`Ab#TgLPZx6eN(deoR0CJaBJ1CU-0z zR*0ds$tiuI3#@Z?)5xGqUVau-x-jNb2TU`l8mfv46e_OUm*((2F?D@Z$k zk))H$E zYWaH~FqbgWGko8x@SvLdEoX{$`c(Y4c-2KJ&9%a>KzNS`R)Q?s%3H+&5QG=S{11OJ#Fa#l$28(zc>Z(6$U!T~;KdO<}Le5Y8 zrVcnr+NTsC=5wLok?oq3`@!^PsspuNRZ`&-&!X6(L5)%Kg!A@{0{~R>Ah(-Adg4c| zXX2OyBJ7ndTCdiZ+^;ELV-jEml8L%X6CgKM5WzC>q5yPIjCHBR`mI(mXucfwLa-*c z%BFb!+pdI6Ps>NkH#X*Sddx6=fvxp${Bm=J;TQ{Air+h9WyJXt6{T@q)wzINc zXjK+}dpE1$vp1RoTFV7pD{G8L*ZOOSJUEtYBJ5Rh*$i}SJmcUUSRlyDTUN+AU>jp) zn;*pm_iZG0-K_8kR)FQBTG#gaH;`9?3HmZTn;`7K%tUH;Z74DBVBy13FZ`jwb>b__ zm~&xdf_SL;(7;njjRH)3en)}Rx>OxAX+qEd(4CU1Iv=s-leofgMkCtKR-)x8C$*37Na**iDBJSOC&1N@nTt6z zza!vgmFq@4sfbcmSKi;}PU-xI6KBX5D%xCFaQVL45{IeMhWBU?Dz^Y_b)P_2*3lgS zouj^dBEWUvM4CpgLam^reZ{9QW8dDiQ~?59)hH4{CoX~|3)R-RL^1_i&{{%!0Cm;F zwZ8oKQGrRO(vE5JVB>ryd8uL9VCCPeG?N_MCDNg4d8$BkDfeGRlaodKg$EpFX5uHk zq9@{JpwFN3@>Xfrdj3#@t-ZFzt=P87ikkw#V8mFBA7+26EABq)Tmw-~&z4+#ZIcQ5 z=EOpxiY60E4lBBl+u4E%(|oA(q3CGmepdS`JiSpVGi-*l^M1scBPdLHz%B+6&p0hP z4iM$#V=e8FJQUDNZZdA~Ha4vF?Zc{uh0Byr(|` z=v#j>)QR|F89^ms`Te1OP=Uk~q@uOfpw;lCc^H;uYrfi2Nb}=`BJj0>szJ7jVvFZ9 z;aod!Gm%yifYwlDC8r}TLuQB5<|Y3HV@^~n5F!qh-Of_DUy-0=l^B`!cquW$AHM&d z-_nMz^BJ2B!(fT_dUg*2*REm)r!=B&kwGmiCfu&lnt_@R-mV`J*f^|$y~^A^N>{EQ zw^Z$KjejWdU}>=gCy+2g=r&JNfCyw5g?hcN)0xRe%RauF7Q8>Q;Km`qji1fWO~lK) zYfrtTe5?Ebts|Xtt~>eG>hQuD{_FZ&63+{qC%^4zt@!#EV@o}RY8u1PlkLy$=nHXY zgh4ViK%AxPO}9qYK)R5{MgHaeg+-MUp*bJxCPpoB0FfLks9E1SJ4`2L9L-=>rH)7) zeIC^%DJZE~OAxp_bLuV;eROZpqUt&_|IfhHV)B`k{GUWK7Dfg|GY@a{vfso8#%*?f zOadJ9FWi|Q6BP1@^j!d1nbnsCuakw1Cv07ybHa-B{zoM7c1}iS{O3&lYbD)uHN_KuRbcWnQh^$jCR(z)%O9Q-NgaV+D9B$xLF|_s+QVmov-VKhs(ye6S`MT z(afw!#cm1_Wg4$Od4>7QbzmC+d zN(6jfKWBIIZY|F@!^>7I;pxdQ4QJ7wcWNlIx4Da#HPMeGQ5+P8d5s|I0~*-*9m*pA zhL#aM5f6M#OumylxE*a66?JJPw=wxosLo!%Z6&SgvZittVD06@i}6V%_7PKx(l(k{ zoDDwER1cG}d~ z%@X?X(t!W)#nUu)J)17^y4CDuO#-H0-QCS7LhX`{lz!{mRs%b6z53X5ZB{6KDLHN3 zI=`{7OB_*;s=J2n0yik}?=-O!yqbu`V1XCIGiFT-5SBcqm|gfqx>ZO}uWK3Qc? zE7;9AWTkc#Sc#F|khV1~T7tCE#SBq{sged2a?(vNH8F9Hssm8OR?D>ZZQT28rw`s* zR(nA|nC(_s`Dt`hc1m6ZTHKjTDI<*@wyopa?M6`+?w7jMtk~(&hOp=EJWUDt8Gg^z zM71AedPtO1`v7aU7xYr^7ysDNpUiF9^jeaaNzeprm)^&QOf_?*D_9&r2JDp;3oWip zS%+6~-Rb&qc<~vHIx#866?63R(w?kSD^@y2!zN=-9`sZ6fM1U(#Zf_k|BxzS@vgF} z>D!1M%Q{^Q*R)ICgkr#Dr_D{{)z6Caik|jnUn^M!YOS@^*7CD@o^C-rifVsq1a^{8 zDhknlfd+=#AIwk+QR9l4dB0IsHeAb&4HZ8dCskgp8@%j|j5b4~t~EIqe9#XvNCcAC zXyiER7&&;WD?u0p?FkO?oDQWWqf>oZHu1%~a@+XX65R@*^fji;J#3H*5mF$`L4GHkWJL>Px=fbTigs;6bb=NbS5?JGUv% zfh^-_zvg1KW#d8gX2#JWh}qWJ_HzC>p=YnS_51fSO*qosL~(MCv)A5?hznN`bOrDE~{K2E|E49N$Ka#x-tiC2>`CbxF=J0()y1 zef7y+(itU!P|xx(9cR|#UPF|`h3UpqE_-(_ixg^bt42U$>OLWDt3_#WEf{#SMsw z$Rz8E$SD=3Ccj2{o_mGY35b80x{#8j*!VChmB;>86-#+U;?)_#Uz(XJ61X(Yj;r)bAe~5Ny)VFz4{*+kf;ym39U0 zqo0W_Rob~(0dO;~A}yfC$_=SZxWa$tPCP@zL2cqq%1Ac3q0Q~YnP&f9$^S9@%<()) zHf}8eYfQ9&_vL%9KzS>Hm2tA+``wL^4Y;&&&`!NqM0{|BAcg^*WUhicgo`Wwc^gB6 zXC?{UCBF`ewvj}slt`2?(zMo>HMbn}yt*vSIsd%cOG4w6h_MtOgFD6_N~tQGgRRKm zg~3fb5Vzi;5`FMFE-Va!0|BU)7g`LTDwV0I-^IUK%=GKac;6PB09X0^COzlqC;WF> z=;d}<$0Gs(but10A^j&U>>cf#?fy|O{ySpeSpB0~+;{qd%YO`Sw+Sskbi(eq0*@*x zyWww`@%ZFxk2T3`SUxd6N}SWZ@7RSFQz|7c+B#sO$%@QGu?%`hw%BM-^8N~A9Y^Jd zmQ1#ej8A~yfMA-Gn0X1L91HStFTM+9sGlkeBic?NF~V)&(`r08q7fac_cx74eifIT z;X9FELEA%1@|JhnG}*A(g`}MsXO|q09i}bWOA<~XNQ~A*o_rY(ePx9UP>rjK%HC{z zln`JJc||((LRoMgmCzBfN;NHRH01;Dr;UG?jYJwMN-Z$EIwm2J9!*>=vZ?d z2b~Vj8`7*ZpbEzy5`t%)Y)4WbXvONmlo_3Rqd!Y9Z2uMbX&zM_ z;l-j6tJ?`NmBc-#A2^B1T_2FuiG^TV!{N=9esHem*S{}66yx;hjke@SG5N-0#}4k$ zb+%07*7hj23YE37V$Yq?H(<_vXZmn9mBX-;!PsL%mrAsV^t^>x)Y|s!miOTP^uE&J zst0#HB8~W+3UwyFp^*b;3(yG@5RPIzNuN9jFj-mdEa|*Wy)pGBoD&du8Of1!ET z9FOrN*jQP1=5~cdZEk`vUF@u^uT*0n^X2Tok6vD2%QfIzpfqU#FLOA#Hr6mGZDC8F zg}hk;f$FqRJ9j|bkH%PbGS_Ym%J3|1VQa@#*D>12gb(#Yu%$+TRCsl-D@TC;+wGFZ zP_}AG212cTv(nw|*v^Z(;mM1y`tg!qTEas}Tb7dR+INPpu*dGLJ>GjXkvo11>DvcC zWIdTPTbBK9#<{m;e@c(>>MY{-9zjb)dgz;Ac9cso3(QFNh&?}B3iNgcS9jiZf6DO9 za@SXM{Armlvu*F`#pA!jQXg)Z8}Ms1lUxcbsOhh0L6{_(Amw<~9+0&{B5lsBxFv(w z^=n6c^Dtc-#4fR$@RB7yk%0g3WarV-)bLhO59&UG`-aAL_?<&2|0oa}Pj4 zzs4h*^{92A96&Km9Nd=sHdGzy<@}9t#@xX$8{jFc{70^f(St3hGdlIl)F|}`jAvSi zA0B`b3w?1|(fFsf4rifpx?;o>Gn0j-Fl|5&pM`A?h$i9rHdQeexzxuk%rZM`x84S9 zK&8;*7qluDik0O;gVc0Y`Zo>7Yy-m4hECBW@|ilangXp$w)MbhH?jU8y(n=jRQ^DP zi}Hw&L+>^Ft-` zx4IyJ@P^bapUK?FRRI@*@V55-6P6Z7O@(7#gc=e~jR-SFyaULzw$^-t`xoeKganmR z9!FLo!@!ENLVUgj0+|P)1l0gY-EZ)p9AbETHjFz+2p(@~@$shMJ7NO<~5Xmf1p?@6Ai z4U$&I`SAD>C<7T2nOA2%Kxhtaz+dK>cp5kacT9yIHOb?+KO`8x(%=bUUh^I1*RsTt z0$(bLWlT^84cqko?jmPd016+FI|gDKfGm4jg84V~4Zn7cAx|H9Jj*;nuasIh*zEuA z^%bIHS)Td_VR)3PprxP>V9 z$00L)v9DAT`{FIfmN{9BF1Tdj!z7wLmmzf{wZeRR9=CHyn*hwfA5IBVz#DEC0eqs~`NeWTl7}>e{ignk1 zH<){OSd>4+sG*yHt#q2uXCQ}YR^kdE=F@wX&5;!-i3su4m@%cKJ3v#f;?zY zi;a)Yg8j2Xh#jkd3RN=obBG1(2pxjt8J2#Ti?#rj>ET%w9rrE(#V9bgFNKdlFjygS zmB>HVZl62L$8$Jz)SKULUGfC7&^Fbx;D1!wG#Go7OKZxjpQ)bhCBhN!CXsl{#ek6# z5Gn~*PqT_ZR}%7{@}LWMOeuSHIByZ9MIh8N>AOX~3>OPM?62V)+gbGvk`$qw;6!`Z55n z!#zWxEs~pXQwp?H`y9OVnDG!9|6BexC9`(*`%k%UT%k>_()8eL=6X#|z+2@0&v41| z#jYz`6Wp8H0Of|F^!O{+p))& z0R+7KLoX>v!*|`(jt3saqdOG2GE_{+#{TuQfnMPcpCxz@#!H%#T~lmgAePd^SL2lH zsUD)vMq+Bntg7(!MrXf_!A>x!DX)_4q@)HWkFu11nM9zjHXu{0HHAGjUX+*CfJZ2a z8p?L*oV+na)-m`46yZ{wQ!uxnM8af@%o&Xs1yc8TY&CZ!25-;FO|(-7z|D@l>mp6im4DMu zVFf(FZaUZheD4RdTp10#LB0qgegC%j5&mk(q~jRNKQ)Ro2}M^XAfPb^;-yt{ zIaQgh>N|aG>^j9 z21~PzGsm}jGZLFBo%+G`C?NAk+WMjSzQ4SP8`F){9p6s=$YS>E?VCjLT6m&Lg}vTDHYemuY@@-vW7?si#gpoZ*l6tKjI5q3;c%)02;$ zm7mq)?gFOi52mHs5R@^lNJHa7gVW8RTZq-^6Mj-bk$jRY?mIkF1X9;O3gg4FXNtcX zFou`4!>`JgMfgiuZ12q4`Kly+J()kog6KOG9PRIyF~->};=M%L$SA~R`p!!Z&NvW| zW(cx=@Jo$flpq+v?;z@iLDZ_M#QAreFh+F3a-?5kkVJlyLXXC@8jY^Lq?)D;J)ojV z9-?sZa0aTZuq4HJBEg{RbF*;5z|(D+?LKS)A7uxl-hFwJzj5M>f*ej2L$m9ofDKVq z#=P`Ncz=>JpM=!9Wi5yC6^tsPs*XX&p;k3nO?A9p$@DzaLfljQOX#k9gl;hz>aLgD zd8;63?a}NB(-&bBfp+`%$;)1_xYMH@w==xd!wE#s*^HsKXX9~2+-u|g+s(b{w1J7{ zghex_2ke&}cN^Y2mfW}Ltiml3qB+)Q%sl=@J38HcG3%gHv~ndu`0B|+E#FHlZgt>i zw;?rPabDnC19+Zs=v;Bpu*CN_ObmF*gx{n*^8snC1BS0hWX0N-5m|J6pjRbCoukP^ zD6YB!3!fBFwFeZq2YDs;XCYH%O@#b`k~F?Ubq9^_gW9oWZK|@G7*u@j76r>3?6%i; zjVb036S>BjEqrxBl7tar$7Tf{Aj~ab`v$t zQh0qwtEJ=;k%pQ-g%aL=bwLXHQD}BE3Gpa>nGqJ_Z$zJ!2X1WG*4@lA`XoByp9QrW z=sOA;1SFZ{-*dshTC0jGNWIia_KSPYj4F2wZco2tHNNAE^M=OJrtR`*G2A`5RQdDT zKr~HnE2(#H%rGYR=*!cY7L%Beu<^dX6T zW zqK2 zP?XVNW`C*m1OD%(hkpzT|EK+*%3e>;!q&oBPw&5}?En9fVEZ}a>x^Ur78xA zTR4wt?|rzXyD4Qz&$E32ir=nsX1crY_tRG2F4ErJnd9;SS-MWq`bTBk%(14dk^;OD z9y56ouatXplB^#>%FxdChj@9hIr)x7S3~*@<)U{2LThz`y z^eYsGj2h+hk^0UNhDo+I8ob>`(wI(CFAKK_!Vq;p=*J5>U_N$0e=nrw=1a~9@Lk`T zFMgi_^1a^ekV6xmuI*MW{e-doIiTmfo`#AeLj~n1s%*BW=r_X;fp(bfWC8YBa5K*c zK6w%!jaW8Ykt545yp~@Wi%DLAJ5!EI$I5(SdkNx#BYcbB(Gxd9u=pA}a-LO7gSa;J zNjZNBBq#R^N)Ax*@N~VQh$o&8;VUATTwBK49kb7{4?RTB?!QSD4FV4<4=rnYoybkx z9WsIv_)FbjPtnZE%IbH=D`SWK%q%^*vBYI}zkVGXb_~Tis@KIQVW(YAD06$2_b*_TqD}3k#T-fN`Wu3nP|@R?{YF$M++qi|<#N8LFeR>)oWe5V_f*-RRAB!gv-kS1%_I zUd-s>Yq*anA;_co)b8jFBBFTFs*Y)?ahQ*8AdM&tiy)9W$6z86$bumQ7(v;5>+|bs zx=jz|k z1?pTV7Scfw9+4pN>^ySQ^NcM*jY*C1F$^w`XZTpD!FEbLx};*Pp_@t=!Hal%ywVrz zVL_kr-LLTv=P6n~T|t!vj?YKQa?sCHfu^xXM-jMaEQJvvw>r+uRC8*AN#8)2{u|F z^%6BygsBQ)!1IjwY$FKp6(UC{@CeDxn2b6QrT~gKP8ur3ike*p6s&xx9SE)(Z{7|Z z80xf(3^vyb4I6F;6M2m+{r8Syjv+1}#9R+9&llv#uu-ewcj}DjRLBBM_&$L49*eGm9zsBA+Fh;UWe`3C zLri4h)RMKfx29umFUFd=|Bh zLG$1{5)VQR$@KwZffYSw73DNVgV4JwVqrtf5-~;l>HEe>#%?&L-ASZx(1VNUrqraQ zB8(fi57XYs1#mRsO<3Xw;JET^J7;H$}L|4*6353>xXPya4+ zG%{l_IpSl;B~G=P=TwE=LZ(D}KFj%5vK9h{0{oFZfr{lU8Tdqpdh<(lKm((hA+lvv ze^u!4s6uJI6edq?v>Ys}Y2gRab`qf*GBUzn0R!9tP$ey#PB6HLAP+;wUPNwVDKq-m z$b7~Dkg?Hz&L$GfhJ($&wBe9ndtWO4<_#~o%jSEHj;Jwb*jzoGy$oIU4wy2UJAZ1A z(o1wuRLmUU5&VpCbWM{e1{Xx4&T1CIXV>B8d%%&;S|ibXr0ga&xlwC2#S>|3eNo%p zh@9-qK!NT3L)8%db|Do>>hp=4M1+Z_;1<|<*36mST8Gwb@~1=k15nTB6-hmtk*)T` zZ7p$-#70bqBAGOp(CTWwOMwTO(V>PL17*#_X)bNnXR9SI$Sh+4Ax937~qYKI_kh9`I6i zWKZjWq=#P7e4{8mu7h2j3`yN(r4mXJV}!j@O-BA4-1iOjpt2KuH=>UkS{*H^>qub3 zGImTX-XD1LHjaP$OJRYPFjn zdlA4jBz0I+>s^JYn|%Sv`0D9kXF9UU4Or{`Q*MAfhed@~h=z|2

    encyYo)zHI9u zo5qy0N6^f}S#Ew^R_97A|Km2K(8q2<^}T(u{5~dq*wBa_-mUZ2Q)Ae}9h2Klj4!Bu zzniMEr1I2IdWCQQqO}`YF_PQBk8s;LgtQqG+fp?mh5bye9SiBxOnz z9bKMvkI|;oLEc>5&fbUkr5kM7SR?D=`$1W5_i)aqVarp0Zsoibc=MTy9bf?1MI5|O=3eg>S<%=v|32{mi+1|q?WiQ)dPhh{JB zY7l*exDVsX{IOlACc}nWisdFl5xILGg5RRSnPdFA!4e%WQ)iLJJExS_#VfAh>Fg0! z2V~kdSOF98VW?rk{%JHg$}~4xC8JP6XFs_#4qcI$ByH-e&8L{2FE{T0u08vzuH{zc zo0^Fvl>_O;!!Q>ZC@au(uqNf#PTXuYZm-=q7IJF#&q!7C(PD1laDadkx3{*$03TJq z1hF!ZXi&ZirKgeuQVFTs1nSaV7^Lh1Mu}@#`5vug+!ozicF_3S_ffYXo6$N{Jt`a2 z34s55ebck)!{%Q9@rM^7t5Ojeml6B2rf0)wE_E4Vj{5KA`fe=AS5>}6&wHgC8Nvr| zVJE4maQd8Q{fE*9;^NKGiciGDicR>{Bjn2Ub`%Nm(HFg+-qIZu$I;<*O*tk1JqCSd zBs5R6iHsLtk?%Bx)c09 zmUBxCht?Qq$!DK41ksFSB`Ttu8egL;A>u?Er&}t2<58pE7mMFSZ2t3`41TfV?GWT@ zSh&nPKK@9Je|?q`uL`|D@uU>iOi8_&fW0<&YgS(omSM?H{R=cBr^9C#@b-6GR)6EH zChW%r_t$l1TD00>;QR6m%L_td1#I}bpy|V}nxNg}P3g}Q)wDG)uUq*aa4+kP7=m*b zQTH72h%`?QG?QZoc9;2f+{!xcOjV7ZF|s97(3>UcN$W~BG>^5Co&xJ6s?rHo6gTShP@WtM+~Ku& zB@1M%7X`7%(mB*C5bHb1i;CCpCa%fBkPG0cio^);Yyh;G6{&Pf`tu|QN_wJ05^|X< zZy|tsW1g8a)4Pc@sFP>tu~vq$TPX3Z*todo7$~C376hjgSQQ2&YbRP!`Iih}y2uq= zl;egktNVE0Lys#TQ|x z0Rj}^yQzQXk6Fme<=3XcP;v-nlX}k+4at)1N?V2%MhL9;@Jv#eVl>AjC0+Qkn1s8NBMq)iP-hu(iN4riZDLzT7D)rh9e*0oy;hbtqABsmcA2eLb*xfjB`X zO3`hDRh)XZah}K}9i}lo&;JU*bi`C8@_9xhUpyFZl2fW^Vcvw=0TVR_wrtR^q z$4@@Te3X#Q@AQ7z-*-PhN{Oc=t3lZYMo#dQMbcEte)QHv;ojb}sXoJ)$-~lQq2%u{ z8ZNnBxImYcAYE2|^nCVE8(--X^!5ri;pI!RX3a9b$NbWb0VO(3isC~!=EKl`&%>Dl zrkfdUmwhhK``a5n7>Jz*#Nd&>GuUj9li9x%OFX2iTSXjpHUxVb0XdBYlJ%))5u?E( zyhDIR*Ro=a4a0GuN@Z%)!paBoH6^mJ{%GKU_&)W}JhU$4!W^-A(91Fz)RQ z8CR#*Bt_?{tDOu|QjLBYI%ve+s;AVP92~X(yCaFc6XIpfj$pi%3rd@tpjcglP>#?2 zv1N?9xb?Md46&9ODPHgI zh-gU_*%TATzl{tHu`WJ0%2B11YA7?lXsvCNj8W6ECIQS9>1!XFJDK%>_Kn8ZQH8tmrU(QDl#9-D7=udO5WV7!1W`B{kCq(zPBN zHIOjko7fnVr77mL>Ta>52$cX{8OP~`&d_4^Vex#qo}$0EnPpjB;eHEqa(Cs<#mkmU zekisp@BP$`5?qpnW92Zt?y!chhJV~^{jq?04Yc8xuj6sKb=)QHowL3slimSD6(Kgq zfYMY(bGzz%Y2DqKkGI2KM`+WEZ_A_7^(j=7op7)Q62?%dH8il@)3;CO4yNKb>dT^& ze+D^M1hEaRIA*KQI5eoa)epE*vC~k)SG%uLfI1Du90}n*ZA<$`0tTtahV1V!>d zm1ud7tMh3K!?^+%-v(pHoAb2o?jaYNxCIg8M|NjX|x-`6W%{?720_Tpri?HZ_N)jPTnQN zG&Pm9Ge5klJ3v=SsYwVmvu%8?m+n6F17G=F2R2aRyM-ykBJY3E z40LsIJ*ic)@9WI9V+uFqiImaw zHod*`J&(PM6TEkJ=jpjV*YEWN+(dMZ;4C~Qj(leh2?1r+ZlQ5iDwA_E3|<4>2`~1= zQ|el!&qD634qls5Hnv{$_H#Ad8}ks?_6B& zY!KPp`nSiLRUYUcArdrU$l;!lh`LrXpU0Zaj_}a5BjG`Y(CnfeWX_lLGXpsKi12}x zQAum>;t-d?pg#shj)!?7EociMmSV<1c65fBjbU| zKrFQ(c!cd%!{~ok;thPJU9ZX3FQ6r71-3)CGUeY@}%cKFC#X%R);|l_D6$FI_yeF}|qLQsp zgRLIE z@IT(nU;nNP|Lc~Yo{@pIwV{ELm7a;afsMVj$$z5_rzpru4+MeA2nQ6+Z`!z})_ z&`4A;N31!rR)exKKt5ZNBo;x&ROT;11mUt=|rHUL7ogcj$&xaFJREMkC z2V5M2@GaI!>**{X77C2&Pse2ei&wp(_p zfQ#q>4}KH|t-WWxVwZ8}ugCMk&e)-tf$3@J z9P1-y>0wjs$^@N*>q60-?+Sx0XBs3ZtBV!O?)Ty?`kr@YZ1NH_IbAXNzYY@qOSzZX zM>o9qXLWF(fq-!TeTDu*xwkQ}|DUM#!)iNr|JZ1KXY~vQq^VUEvj(I^$>v4`BqY52 z%Pbo?5f3(#(&49ETGEF4nQCu6{55lcndoyaBD`W(6N15}0QMMN%>t)mPxJRp zl0*y^zNCMztXvgbhjKV2RayWf(hwYPwR-%hrC=6D0KSlN!eH38|5mHqn9)U-Rbd7G z13$`rVni5&;gmyKyh(e3{u`Y?ey;%~{8HHJFvR*=Aj=MnV*M!RZ!K)V+((lE(v({Y zkxHfRDbb?YpS4Se%W)!=Z6?l#ap`An*DEX~FLE=AO>z~_@ihGBb&jn8|JOAGHljx| zCKr&^vU;B1Qe<={^RMZ2i;}xIG2$Bz!Cg(p?h&O&fGe21meTTtS&&sdoBH2TTq&oa zQw;*-+1B3inwfEb-jNE+68}n}U`y(j-WHXR%R8E6oWaT!{MZ%z$R_BVKP%-o5;kN~ zRtF z^En@?0f+CFwO`n=Hf!YfhcWP!yNZlNG}!q&!gH<8)?o95)Uxhd{3)u(V^WLc!PPVm zAwSL@mE;5`I}?_Z z4gY~9+3&%~VLf|%^nkN2c*NG3>|awLDw;vOA=T~h0%iiglmKFqR9a!?P<n z^xSm1PW73MHdEW|@0g9bBg8mQ)@zc6-gIoHP=U>BeO93+qwoWYo-RptTpVoh11;VO zK~-D}8Z?gVn!Nx+`xce+uS`>r6L?RbI!QVi;lc?H7Trbpnw0ur3*a+ zq5g9*SsfYZANKt3%gg`k_{`M8*1+1r)5P(ALZD^!9jlFh2=t;yyos3wj+!-)kEAia z5+=y8)={PyY@MfF){r7~FZ8l{gZ7hr(Ja|XUrmT2>DKXa&9gP`3Ru;Y(#KIEk#d|N z_2BE9A0Xbouki|)VNv@kGoml$E^#dq09L94}3$4p3bk^N1n;u zD)i|w{d9Z%MAwbv15PR}c_tLNRU=6ybaXSVeB0c23WR<$x(6jM&aD{;Zs6GQPH^aX znSl(Y!zN5m5?#SUT-*`Z60Q-kiZXw*=?;%6XDds=VFQr-5>7s8%2?19W!JSn+8AxmR!BC z8k?!}p>V4OuiC7;Lz!L7v}t-w8moWuSMFNKy|c#20)cBfJ>bnv7nftUJip zDn0UCeiu}=cF`vRYbgUKW}DP~{UPr#HS^|aUaX$4BtDPKUWG|&XS=I^FKjZytW)d? z1$%~D-1|&l{j~IL`Lw1mp-PMcMkwiSt%8(VO?#|DP0_J`{;yw&PBGxnKrkR6(0}%> z|JO!kZ1F$h*_DR0-2ppN&xbnFOo6KF}ch5FJ#c^c8GCDwqKWG z5nU?+Im70xQd0|-KHfaNCNyHor3D5T{x&lpn&iuH2n@uXsLzxKV}UAP=jFcbc8eLz?aSbJ{MW#C`7C z7;+zn?Yi9shG2~8HQ*dxzTA8sZa(%A<*R(dGt#t=sZRCU7x3ptzKN}~V4B*N64TBy zlC8Jqzm?8&sSQ11$N1;=;()~KxLnK;H*`EynBPy=3bm+#fl_Nbci5Ri^ z@zMnz4cF3U9t3Fc2c8;8X{mqfz;b#CgNt$P5 zGMn-ZWK^!~x$&bbA*~+VA>mHYLVXV2{(YN>3v*u9E_h_$*VWb8lM&p6i^5 zxqFvOd2TJa*7K0T|Kn`&(q>|CQ)Z1N(1XEI6BcyaRvStpVg>JMF8LX-#|0G9NtlLCvp#UBI4nubkf`j4r@HldDsWNG5$<4HBc;%-mzy=?|t@sm7;`B%0bv_tAeaKU$amXrG#`Diho&|5) zCU1*36f-u#cgtaG!d_@7mK*X`H9IrX!-yho|9UFwgT;sI#OK!#-paMh>f2)dg^61& zvheEEW`p|c{F?p*JpzW9X~$zrY=gx>Lnf|)=g!`WgV$HYXsorIsn)QKr)X_k*X_mV z?Ac14ESE4z<-^yuUW-yH%>_RZC?bi*WuhNob5~IBFk~}!iEMlm;Leado52k>nru9C z#SxTigImgy1Lu0$sL5)Hmoc~lzd;T$e*WVR47sx}koEFQKbjk%MLCn|-U!yo;Bkhr z+5l2^I|x08RRJpYA~F#p*<{0kxdXTZohR@!oa9y;Xa8%2N& zEkt~TmIg$G4gHeY>_$ZUXwdk(*IA|i)VEfQb2U!@t*aK-8^6GVLMj<9jSsv7E=(9v$V^uEejnkN%$L7C$I4l~=vq9j*6>!RLe zfej!HtWi-*`)a;@H0vJ)4+i&-J`P%gtqiFVXU)jiEfa_DW{mOl1OA)w-j~N>^Xsv; zi!VrGNM4aQmVQ}B{sRR_C6=5L=eD#tYT>0;K$fXiGlts#yF3Z?pYi$skURZvyi>Z` zzUT=C0I){>fAnkr!aG*xj!yqczc!_zZM(sW?EO-!w}ke`8ALrsb22UTiReNDS-Z}~ z6B63LzSi(hLz#k-!Ous0l93{Ed7C-uikM#f3-0~xtxiP25xdL{4CthuzgWO_KQ@dm ztlaKL&!b1rK~^TL5{$>+tVl5oGAWbDkH_SYp<)RG14I^OUe4$I~rJ9_BpQ3%1++fc30sgfMm4P zw0#&muyi093VkD){axPI_^(UK53DU5FIsr#jCNa3m+y_08HPHOhng-EM7IZD- z)j*6OL?#K_KLnGw!LT*LJ)jn>br9K(J47Y_=r8}1T$MxK)rXNp6G_XY_fup8cGpmq z^yoe(Ik(6Q+p6nYL#kI{adYiGSgJs^B|F%}z(N~KjW`W;y+Fqt%DYmsKMt_as?`jp ztvugf1~a6YNG{I-wX~(NfT3u#;OSQlFNje#QhJWGsf|2NPE*R2gBhMB47( zSlTq8!u%-z5QlRYj{Gn}V)p_8*L}q8uT`Ab?G67)WyqY%B=B_Sg*UO@yV8n<@uf3e zu<9vHu`T{hSvrC{M%J%$gFZSPRlkm(Dl~``xANuyx!6(2?p*6i=qH3Eybf|;uj`~? z5{*7&zEIpDht?Hwr;~EA12_0Fl7jxeih6{2W*?8Q-ZniuXD1@*KlUncael2)jOdO= zxVVXqsM`MlOk6`!u>-XGrx85N5=YK1zZaM}5G>o|>u076keqycBXH4PE`e(i8wwV4 zch~J!z43x+?FIJ-`78WHcaCM!J_pdCl3ef;j_40id@;IB(x%ZfESp=$iip-u?WAn`T@wO{S@hMGC5jX$8_5r{0_d6pn&WX7);*Sefk>D!K59=>qKleqf6_x(!zMw>j%$ zRgRC=RQ7gb%7`=qH#JmevNp;f%AF}8UoVsgh-FiTMd$p!tA4a-QJmv~(s(T&hJ`q{ ze@5M37+M1|Pz(z{FD$&6Qz6$*pNY-&r_)H5uc~U&;V|5B1siqm>oHlMCkf2tsz0`# zvpTm3KUqLs4a1PmM!=FDW;GB|42_81U}nPPB07{dX(97qh8E}E65d7I^2X~L+FlQ~qo*q45<*I%XH~-pnEGLV- zk$vr9@wPg?ZrGqBZ*pkzaN=u3@QzLt@?Xs*wD{t|$D}({6-p=!l<(omy@aMQ3j~9y zayTWdU<{(j3!*=)vTNyIVPM}c4VH9gyvU~u%CqdeMOHK91g*=M8F0Q>2zAm-C9or3 z%n6loHee-pw!YIu^Rb+&^(k?;;aO0k!7)m7PgDK2lB+&%{9yird6g~7}uzwCu6t>7^6O1hwO2kgt5b0vA}R zqj`-IeS)8z9e~YP(eB>Oe3-Ddtd=Y_8=@YlxEiNP?`qkSZQ!dVGvTIEX>Akmvl~n{ zb#?P$1h{eIiC+2KVcChhwHLKBW;?dB&qLM%^cKlI zUH}Rmu|sfEWMdHMC3?C-W}zrP3_u5Fb{a`TZQxZ3UFPK(%r-i~zBlD)&NIUls+`!| zH#^6Fa6#VZ^rCw36+)Y1MBSskzF%i$wcg;a_?aVx#r0H*x}uAfvrM|+vC@KCI^qE~ z!iJf*Lpie!eVuE7#~U1Cj=AkhNzA>Z`0OBG-X;CQV=DV=J0v+W!_A#cqkvxOR;1Hz zwAiKJF0haoIvc3TH(hkmq4(cb-`lvFQLz4t!24U&aQ;~!|5ecR9UT5M;cc{%WNhZ& zP52L0wJbgXa)L2jLSc)JjtY(Pt|lJgcou0=$ascD*39D_*GVcqNE)``9m|luoRfOf zXuZ2$rlB6c<2%N~J1btpOC03qtTUD$fb|@C1CIkbI|0yeeO!iEPRT#;BmJnD!Moxa zrt9|tBp;W40;H%cVn`0J=8WoU$S48znC6MdYuUy$n4Hwc4;%N0+Pg4=;q%&i%VoYL z>mwD5{G26cd%;vJ#JTn!QNv(@`c)?n*YO<28LF^K=_3 zF=P1#X3Y!2j$@%jkF{0<4}QZg-Z8dhjn7uu%(84Qi5NdgAF3BvISG8A)YhWZ8G^~! z%1mku0QRV7mnj}G^9_3+k{O5AF#>x3UVMI11)SFn(ki&8D%1mNy97AMyY@srqd|qD zt7^18*a|wW2pQUVt07ubD93{X=;Y4~DoHF2hlqLfMaUT@BJYf7kP(@B^QiZ%LQ;J& zcUNr83U&SJux=Du;$do+Xrd~3{{fqDa+f6H?9@z=yhp#F^%D!=6kg(d!p`XStYWR( z3Ca7QL+AVs+?Yaf(CFTy<6)dF7`VXYuODWoaIWOdp znd1AZT60Fay4Q>0sUnOj4F#Yti9+c z#VW&d?TG2){^EPRcIKW*GgS%Az^H^`a_tbF1Q=x2$%Xu)tl%TVbt(BQj4spKoXVudod6T2l#k!8tqNqEvM|hZP?YZ%5~Hxn zBIXLm1i51Z3rvI-@PB{@fkA}Or3KlAtA>@;5Nw6AzdoMW}g+P>Jc zWe0`Ri@J%{6`NmPg0R*Hi-KoGknMBo+E#twm-%Ax7Uf1B{579cAh6~?Sf0H?`|>{R z;w3H{INks!wly z+f2nR+b&nYc}$ZPJR=8?6Ab-~V*~p&CgKG$j<2)Rni>nN>KP||>i&Tr^NKs%UHgqz z@*BNyQg`ihZeqiW@Dom8Nb*P5j=KcGkOU3IP(IH3Jqu5`PxhabEip`e!uemw75f)* z5&!eY{BMU+jQ@|Rl(bxj&0p}f`GgWg=6SbR?pF#e%w}HoMPL%k<-61a)nsrQcr(lsPvbW z!~JI*4cwiKbse1?%xz5n6EGF5G!c_YkG%DO(mDoK0ckNyT_LvCBu(W!r(#i(DR0bS z2PrgM@=SunTfQ@gcfkT8Fh7Ei`uxN(!aEYnL9IiJ&!S%bLglvFt=svWjNe)UL~tL1 z87aW0zBG(Dn$1)}fb9JnV){uWRp`|0R8d`jHgfUmi-32u5ZP1wm3V{yf@p>Ju^6(U zB}B?%(a#~diJ0+_WFO?1%u%(WKNU;`9l$q+8%?&M0_B(bO2Cq$(=}|?qT^-8Dw&oeVhI2p0@opUM`E369%#e-9^)O4-505i?^e%^0 zCg92wsoe^tXL%6q@Nv zq~LzT+gy!W!qfFO6`{M{nAuGn`9RodeoHv{L^e}@3)Pzf{3So80L5_8>fA}vJ#PL3 z_RmL2zvW>=2Mz#WM*skT|IcLlR~z%cCdL0jUAa>GYq-h^|6hiy?Yc&bC^3^f%>0+#(AM~~h*$LSLeT{oX2L%oF~}r|3H!|I9aS4a1@)r`TLULCBV5iMThnPs z)M9>WQ6i9{vL9w`46CP3K7HxzhMyek!&+JKaY7vHeUgZV87}R0c75tW+_v%IQAjg_ z%o6BmNW;@G%2TvA5jD`lMm1Vonjs@~FQu?2Z$ei|-n=e6*sR9xXGbnx*O` zvH&-G<8sRZdB*Bxk5_xh$&h3W_u{E~8iB|_#3iA{zYjN`)RXk&Y4_N55DlNPefzf& z(c1^zDGr?AxxwihHKiWs>|u?D0KO@In^0dCGQbK#8cV#A^~W-dAOH%?=j3Y=iqk;l zU^h%*|2!Y|a-nzOxQ`UB>EIQ+@pnE_BLQ|q?XeIOpAr43QP^)}C+*IozINu!_M?`g z><#AwiAjK0p^}oh`NLROu=)hp^uQlIUP_hk{89F?D!VHSGbSNrEkY%`C4Og^L`?8h zFEq!a9cuX6<%dlI?-yVXV+Mv?E!yCYMSN3OSsJMCWU2JQ5Yz_*q&(0LiztIA8%VCt zBXqOVBg!Ru`ZNdK#m%{vJvn;`@zwrwj*@XDJ$pAc=f;>_@RfS*%$h#U=fS9azrc2I zXVg%LOk!5aZ3c_3jaa`lELscivXrIRS?;#bQwzz-GXY5)f=^#;>m(Y_inxdnzVt|4R) z4^)p=Ge)?z5D}(59I$b(cd1KMc#6RJL80D1mOo(ph zURg!!TB;l!Rm0U_8E+=%mC%Q>BxNcByNu$1( zYFncLQ76UI60`^jbRWypN2}*H%+%$0bYBG_2psr`p1iG*-dJw^swjRoGR=OJbCq&} z#ZXn98kj!?v0>72++#kO$S`Z1S(SlKq8#&U-LpAiE^NR5j*`)Nlmd?f%bR)Cue7vK zf@40BTJ<5bOL1Zznvxx)^4q?A)$RZZ`bC}wAqKV zNZ;Ij)m!che;b9q({1Telak*I9e7P0-#ZVAM8FqGr7@QRUdI@=8s2jI8RQbAWh=H6KEBU*0Z9f3DS|8)N8x%d~;N)63*hXb9M5bAKN$>YS>uj$fInjIILk1NBsyW z+Nze5L1_mGway4tUt)16<4Tn)%10vPseVqp4d$A;Y8;>7DX>zRp%Eenv~uCx2^CNg zoNk|Izj9Zx0=Jr&JB8_6bJ8AUQD)} zs53(crv@AYUV`nG$;-$$?aAFSc-Z{;oq^shpv)bnxqQ z+epPO>7XfW=xFzzQSf2H z!|ff=WM;BlBfhD&$EwEvBgH9VoY^`U|;$ur&J zZoP9=nCCvn=G7*M2y2~z-^b(GsHW?9pi(&aD(ilR@=cUmaijSAmcnwe282%EeAKS4 zn0D27m%$2Vm^IW}Tj)p>J0;Bv#n;Bup1&AqyNrMFkin9T7 z)k!2nt76ch!5n`ptKy~JCDgJ{v<rZ%vZ9TX#iN^@NnuDLNElM*c-+#Rg{W>H~mm>Hi^UZ{rBikH(p3+2-*ifBS3G79 zuJY{h!y*L4o@bn)!k2|g4p5AH5CxGG2e%9Y)z!c2!^>bOb`e8>!gL06x8b6(p?TZV z;Og7h<~OR>PNG%5c&F#~%C)xgS72h%G}vEa|CLH#zm$ ze{Vsx>};k{77x2aCES$4M^TZ~k-jOrwPS1Hr~eeG&-*1t{KQx2iKO!vHCm;(fsK&^F*Y zYR`V59%mZ}7Jfx!C04SxjAEQok(A?dD`O7fw}p5a(j;~c@vGze2zB}KhfDSd`T#Be zQ)sK!M>|tVN;Zq03Hxw&>qXg=>X>uuvwg+a#i?1g#ap0Sx^eJ~)-QUxEMXhPTBEJi zO1f!rL8p!ZkzWe$9D8YPX8x~j)IrR`F9Ez3+4;en_B>Q@XLFylM3I!;!keZi5;}> zzv+DzFlEWD|2~f=sQ+Gu;As3GjgPieb|N$R{>pu7(Jvzf;ufg9*&&*yl2ZiE_!mkG zsT0}7k)2yREayL8a9Bt&t<7TRFS2E?*?YQg97CM7SJKS|BpF75c=JO}mp}?CZ^!L* z49(dTc6v1zT?8?XJllp?P}0yxjz*JHI3VgNt;@wB#2iSPgB!$B7)j$vQkYFlfZFI8 zsc6E72kd2XmJ{>^vn8%{OZX!+ku9&yLO)m!L|eLSrI6S=H}i+H+5*l%RL7*~g3F|= zac2O1qKvr%eDrD<5zCU~G-$`{3c*x~Crg^@H&MJM!;)^;rx^teccsmbq);{IU8dE$ zTE6l)r)J}P716@XK|?N8k1)nnzzwz2WBmipYBg^^yB#G5n;gQ3hvM8VGDZyh)L z@S7)6zxtUx1Y5NVVCrFeUiagy_Qw>HrUTy$K-S%~HV&&#NeCVH9q7I5+}oaG9OX70 zSHxD{^(>gygi|Jf^I5s+)n_h?q>@-RJ{k0jFgtyu4_ApgZ!1^0VJMGUoyzXsV!okm za$pb(YbON%#w~y%$Vjr0Y`T%>!u|S2EBvCvZS}3$fmDwQ>hu%P;8T;t01i$0(*Rlf zp>4os*Lp&t+v{wH(|ZE*gIpFUa|nDkpEN-eP0sth*F${eKbwl7=*G%OG;9j+M;ret z<(mu2$mrx}4v6D5geYfEZJLa8F);2H89rxIbmTrOli|V9|I=yY&UKG-EW&mVRQKr> zmNz?qN+?F0AZTN9|L&KfI>zI_(H|3QLwt6;002mkfPZgh{7=Q&zcAMSbZPj1&>dIa zmaZEj38$`Kzo;E&{uG6ZTuL2pYiG;57>COwsl_HA8aeXt;tR2o2od)JD$&uV{F)YMeejF62zn9d}Mv(SuRY#m-5U7Y+x$+~!SZQHYGPYKVf zh|9MWOp{qV=r*y5J3j@cHNkL8-ydFl(#m$zoVz5?Bp!G3mL1Puu)Gjj_}*Mg`zWbZl@Jhkd+KFc}GEY8?ueYuu>ZQ6Zc{ zGGw6653XVi*ht;r?(fK98zSzrr1YRn#YkH}z(Tg<1-k%MXs)zW{%TL+av1CG> z^1Cvmk_NA2Nk~~uz}G^#00Bbwr-HP#;D~eT+(KOyY3?*-x}umobZl*IO-$7-j;4u? zo+ym8;+BL5o%F;OpoG$g?tVjK(nYD3_lF?3T)Urg0_ZCngx6U|CsN-bk5 z`s`qB;6yj}yuBjy@kyxZMcJDwH_Rk~PFN+vQ*910#DQiVx}KJlZ8oCW}C z{E{}eHcxBLD1d))zugP~qrqZ{R#gq_wk-sFphz7c$chr)@G^hk1ME)gnQ7uMka}Zn zjo-*7A#us;8^y_h_Pl#AS?L@1QX#j*ASFzA?irBDL~xz@NvV5a2C?_@G6{gJV=rPX z2iT>q5MbbTyI0k1dEomA%$zs7!hjW|Mj(#!yySOp%5;tX%+}wv_P*)q%GO%l*y+Z| z!XLH(johBs)SpI2c8{)~=PBbBUpF`2_BOwfuGHMGok!5m8)MlX9xk4qAl2NjqZ!}t zM)%jRBh^x}wY9f1UoYF#SFKd`Ro{o>rSCqrwV~5bO*6NjT|Fn4bZuOmSeYGzyPnG| z@1M7owIMc9)3v^{Ot!+niNXb>bKytf*a_X7o*_Wyf~%ejp0E-&DG#_;S;?P8syR=% zfvrB{^av7a@cC??bUFY{J#o$E(E+0+fO>L3)B7_WIE$H~?6jM5NFMAikB=icy28&z zzEmMnUa4mt;3Wiy>_`j*38*~vA7RI6!}qhH8NVu(R6j2n|Cs3$K3YIa)FmA5KtVo& zBUw7%>#wQ5(NcaQv?VgckmCTvc_+Ur314xzu_j8$pRSK$?3 z_U!7P^y>?XU~>mK5z7EDZU1odva|t@LXcsK5#(A!w`k=o4lJW)-5{qM(1ESDMw94n zj>WzLMfu(8ud5qVe^IiyFl#!R-8=AcnNMl@M}&mrJoHf5YH6>ImnZW(?DNIl-S?Qj zOM@@B#J7=-iRg$JNke37Z6_mGyj0*2^7-?(%JXilE8}A|$x<*hw2d9etTL79$4TlV zCtapzLE5$5JQzBb^Wx)??y$&5QJ=^5-hyeEAlxoJNBJ$tfUh+>EPHsf-Lp(1-aTea z+jEJn-D3l7TxMmNpO&s(+qJC_%_|x1kQ_qmb7z5fHs?h65!UHr$4gYh>~K}u0sW~P zKm`Dq(3T?B3g=q}#fAO9%ct1Zpj4%jid2a+9A6N#jPiq=1)YWZtbci2*l-GlovW(5*_4Pr^ zx~G3q!7!y0A>z?r$*XaN@y2DBCtg!Z>jcw-BSGgJ6c2E0ILNQUo0SZvg3l^Ew(V;>InL#DMd#! z{IZwVlsN}llae?V4iC9WLI9+#V+(<^JCHb3BUSd^KlA1szPlQeu0+E7Yewd>GN8Dq z+-T+w%#yejD9p3F>SZ%X$Y**m(UTRXP{G9(&fdVG4pe&5buJWcDWt?|`$TC+o;nzD zM5%n|rhb^Cv~8tK`){;J^4tM{AhsZ=i3rBTr2IDvAY7U%T1{rm$C$i;7Rlt0kw?*L zdE^!2ga<_74f{}TrdOB&w5k|K1YVirXVjntRyx0TMMi*bW!SuP*#<+7@K=GDLZyV% zlV|q$Ii@Zs@W4k$`}zf)%USZi^YSWt59_o|nf#NeF|Zn#MR0)XIqALneu19nEL4RE z-1jhvNi({oy^~3gWUM`#L{Lo8PFv}$w4vi0dz)Qw`=1a`0h1$F-_rG+aZRNsx0kE< z3HM%PuoR}};YK+J*AsvX8XS*QARRdAy-XE2bg=gwNXBst3o_x0%b^_YMT0cozS27P ztYfJm{$Z~m?zb}A$~bK44RYaqTku)1C z=CrGk8p8RNY3}fA_uSK{+g)^Edig0lGY!a91#JTL=!W}EK5<+NNfMjOK`lsHh!4VW z!$99nc^fpEv^_fz?EEegJbLlw-)kzL84phMf4pf{Yz%7iBJ_~|g8456}X3N|claF;T9A}&a{k;;E9>_#o zO*33Oa9fbh&kh(gYm%xcXA|%>w)g@CJ~uRI`myql8gyc%a%C~&=P2k1Cc9ukH?RT( z;W6+zlwE$fr`!qu7B6+MGjdYpQ1Yy4vpzX|Oo2m0PYPlW;h9L-9(b>*XD3hV8{}hN znx+v?x=KK(Xq1)(@QWHZq{lAQutxpz4Gw(uKszzJ83o`h9J|ld6_r@KqAj%aaUUa+ zF>}Tyo#5n)(g|k0(X)-xp~fo0Bh=7PMcmy`|Z*a1F>8-v!;VK zZCPR0i^Z^2dll7q3k%ZECkL`8?8uR__EyCGTZ39e&VsZuGKiExIC#vw0cDNKiuaiF zK0!ek4+D)w(X$p+xWY_x5x<3CuX>3x-gOB-wFqa+MIv3A*Qvn{JWbl{N6tJTU|vCf zVtZDYPaT5#57gs$_Lj49SQQN67XA=x>-u2(8UH*eulTjUkJ8bP4GC6g2R9uas#|Pc z{6=OpPX@Fl>fP%)edK+Rqv;&oEK4@^INfRfdvYZNR7|M~z`DK9KZ zvMDhxW|0hon>#-mS6|<}(!14ckwrvv<=9neFQv}IoSnqvkAKzHO~$3Z!IQw4tLL6Q zo0$3NS6@b#Kaag?OC=J56C;j3@=gtrDK}C(Xb^+cwtJvg4+^&?>CEeJ#Gr!`6JRQ* zX!C@24udDX)Tw;PI}FihPGk%1F1PJrId$BtV7NIW>(EGer!C^-yhNBG&10VurMP$m zP2*Bl+LBRL?Rtb~BilfsqQpwXxpHS|_?2;84(Vdk$Vn9;qePvO7dwmo3iH=V#H)1n zB)8l?fcxk4*kvWDE?)FD_ApA~=lKM}841@?f(f)am}TgNAr-l_Y;IYs6XvP}M>%US zP=1JqLk3Q4jBfO_vL&_TElj$S>fhyLx1Nwdaw?84CtEqnW?&tWGNh?B6_dw-m~_lR zS~B8oa7tzHmigM6^t>98P@HU~70{^?X;!H+L6chTFqZ*V;7cb~bc55hJf~|G)JA`v zp{XEb?YyjB8m(|~SkL%z0AD4oRH$NaQBn?Ds{MK5AKo#^F+o*3y~#?ii2eiHHnjW; z7aUABG}A8*;UeX6M(BBu_@s6kTyw z_J5uTqFy!(6HWnxwtKnv*f<&u2H6N>5DJb~QEWI!VCZ_Y$=VPco66)r)Sjc)C^>e1 z*Z3KyUy}+?u#F~&W;?WQY9wYYDa#hHLl4(jNZW`x8dYjT`w1Z%NyLCkJc0>R*TGU+ z4M(zw6~lu*hhyrgOuaYARpB7IRb!nwNMziAk)bxr)*1XpPl;eYU$&S~_rvj2l`75| z@pDe)b1_i0Bulo@&D|s~hlCNiCdjfDwIhq)B7W{8F<4jSOwXMc2b& zS&O}w!Yy|`t~Jdgj#OS)kKCK|pzq^RMUF>L{I+Wq6W;G2yp^TIy?#R_ z|6x~7MCK@N5vC(X8UoeN8{)1-Ip7q5H>9C8FJHh5wh98g1KFKtc_;pLEtfRpkjrXH zC)g7C0&S<;#h*92Qn0RC$Yz?IV#Exqc=@=rQW)4AfaQ}1UcTY;+eAMkmybRff@@W6 zc2`x-CAtFQcPyJOQOIDf-mVJoauaUnT2I@mg-ZUMZ(Dj2VN{tX@cPZRQ8b8Np&b;=}-?j<6~CdX&<^P1s| z!X*ynDb*J%<&lcd4>uQ8*q*^R07P+2vyULDEF#7^Bq2FFJ`99PUA-kps&&<^)^rc% zz7?3NIq!4E1XsZFhVFpebSL$ACvrW*C>@r;=AH-^6*_xeJ|{XI4hSd0uUiZ(CjAk1Sc3k29bSi57AQLR4d{GKRL za~2Fvoy%tK!Q35%(oSHik4S54gqwFJ^t+B{dxDFspJIOzHYpsi&a@S69oIY+xe zhgsv+rse{wIFNx`MOy~DL8!h>B&1FjKJZB}a z1EBvvi!6sI<7_>Muxgli^vBk%)xNReC{Q8U_@Tg@?MAisC3drJ9vS>+)$HoAFMCrq z;U+124n#D3qf$8)4kpuvb-f8sqXmLsY#p}=Pp)1=D;33ak9yvYB72+9Hhbr|(-pmc zdi-X)R#=9I>dae*90}!<{ZR|;o`#M~5>TNntm>)WZ&F`$%U?*j3)A-q(zS$#?mf90 z6+G&rW*DdD^KOy99}4WG=&GrY?rk4Xx=&npsMJb%I>NGZdpxn;`HDd&aHY=y8fgsv z#yzG%u%T|f5yvPr+*y1C69F8zt5{rLxX`PYc6?l2B=0hAZ>#IzsGz;7Thci85_@NH z6mQf{3{Q|TgiMuJ0dpy6uiAJb3w_yUBoyae%*xLKu#t5rpJ9aLWslixbP{u8exYrI z9R&#m!{3cxN9znU;5eWJ!P;LvxeL`8$bqq}Vq+p4HNOvCmz^Q!ewl=#MmXfDkY43abD0tF|PJ^B`#L+Qe%y|KQ7(9 zrea9D0gR&zEKnoHZ36kmplJd<^{oWG<${miqe17p;>!O0_78UBBKp3>R8R z$hEkcX|j0d5$+n)!VFIx=HP4A{fd?204;m}Dw-D5S|x-di(*+;%lyncH;T~t3yApT z^LT_KI}IbQy z|77E-+${0UJ{X7F+BH8;8mr(;j)JKa79|u-aTYu#8Xe)RPuSU?m1yB;eJr|M9~t1` z`o5(>K@(5fN+*45YLIcPHd>*YPpC_t79tF203e$^-PlEtTWoUkqQpQ+zzA>H8}_gq z)}yOp!tFs);YjTOJVe18G+MgvO0XB0*S5=wtC@Ans=c=pauYo14dmz;#!bMJ!q$pb zku{#%9Gu3_Jj8W3&1WN4I$1?Jx1ln3slCd^hQjnt)4+K(Gm?+R*gn6(5x#tCj5Mu( zT8{w2twD_h|0Q%d-^m9Ly}wze@D@}xQq&{{1xvQFJu&@jKGTP1ks8TO_fPpiH(C*?t*oasVUSb5Gdx zOs3axR2U4DyjMY?2~X~?dl|;rQJNX{Q(m@^-jFe45nI5fs@_OTOD4oDP!9}U&+m&> zxN>UvMWQhi_)={94HB_a40zd$McaRjLwSq zSxXp|e9T@QH>q=t?8iqPKn`h}hpzqfTTVAT6LY&#oCp<60v=1uTMH)D^Kv;1^tf5I z4Kxns&11a75t{&(Eh^a!-=`&lX6O*{`8GRjE~HGUKePsiH+-*mC%&qhbDO{Fo45y( z&Fu=STQ4!Z(F;-k5<&~IpgEnPKdXa=O7`X zP*=)467vyQJG}3k@mEK#Zt$!+g?eb|AP#J>P89@8ejm4Q&^?5`fvPQr3wTt{MPup< z692_W*$!)eK7QCn6Qx&~5)G79h7YX9StA@~g%pcx)hxfD*9}fEznaoD4HZkAkaS^~ zAgj5ovux;JMp3D;0Ph$eBP?Q6H_>7*mnrH&QRq4-C`7AzdaS>;vV(gO2A9GFQ7;i8 z+2L>-oz08b-D&V0XD|eCoR@qW4;?t5k7cw+1wLg37H{n|obh4V|DgPw<{rVYfgE#X z&rFCOD|aK^5me2@K0KDdH@qV+FFBxMjte~8uP+w;3$BdfnLu^2**w`c*TkK{Rm4O% zf0ihL)Placp!YUvJy11PJY|<(ZQjj@eRxOgc1luSk@rxzM^UH;3 z6|o&b$WX;2HOJxM_0d6&bWh_7X>UKSZDJKU$vw<5_tK1d8EHN~-BN<3eyChi8U z5o1WiW{&-99UfvaY=Hv}*FYul1G{t*A*)uk&e7k7qJm_G(y+pjTagT3JWSNJlp(_Bod4BDp#1sFvJ%r*J6&nkcs0D?XAa)prU_ zb%Fa&ZD@D&UJeT$ZMIv~QrT2TLLE9n=&kz!(RMNcsaR&QX@u5feG3iF0b)#{8$25R z(YW~Pmrtk~_bN3F`!8f*-An|HYMOgZfBD7tuJ>PRv1nWSl zoLQv_@_zc`827b_@gdKjWg7IN)O62U+k)w=~NZNvi z3Z59i9V9vWJ9T4eT^a7WhM$LMr`t_E$wqjy%;!sq#GXXc8@)M%YnuC=Fc#HDu{ntH zQ`vZRt1rItoLo!bji?ffOb_k&F>Ri;wu6{mrxCNjUTMd)BY=;#JltQ`#2*en4v$;i z7Xoqp>A}mHUCLaZ-!nY1E>?qGPp_pCRjO9r3D^qHs8I?pvL$jvz_f@{y=j)5j}`Au z?pSV0q!m+cQ9XdXnzeneo@O*kO_6+-`futSo~JK`Rj(c7h+0-0cHpfW2$1rv_1AnV zAF^W|maR$8XO2;&hA!)0eH~jTbUSd%yPw@$=GWv#@LB$?Yu$t}P6Ly6HfwqisICk~ zxYFzprJcZc38*r+Lk-@ydSiWB;pxoGDp@VY@xdVTTdkZ~a4~HF;qU|Rj~@9#E9HPg zvV4I@M)3QyeY41l4s7{3o)Qq{db>W{A;pkwFshD}%6Q2d$VpvaC0y6ue;ICu#=i%A zsLZ6y0P2XhbS40uM^CzSiwKqT{lKn`o9@=|8V9`8w=D0)JRpprC0cL1E}}x8Em;|l z{BE&#h6_}XKfra$>3PHp4oS}ztfTcCF@B5o&*A2!xaQvgoNZ($E}i27+Uc6{3?E@<#}Fq34O%4N1lsB)GN9OL&doJR&Ai3a0joGFyJwOc^tEy9C@laE zAlp%$+X*P%+tSO+{7l84x;^DdnSI?u$~(L%2=B%toZg*s8iZUeG=>V=zm%JiiRCHN zLqPo!{w2p?80WzNnh+nd7!04%;ND>e&p*0UQ~%fN!WF+?E`yWuN53WU1tp7LaT|w; zGnJer{-SgqNV#x@)cKbW{fE>ma?i}ifdtHCJ@}# zF%nMOw$tj2p)opHiRA?b8I=cx;pH)l)>chfbn*j!Y zEKCU78|tR=Gx;;6qZc!)y@)n1cDlxL*I{XHmn!MG36Ucg>?a~RK82>r#1i+-VxFsm z!s{OtIQh3(xHO9bJ+~kuLum|nA_3pYjI{4E?;qx^PooC3ivf5!*BMKPz#5$pjC85G zYdt;PqY!^#J%_3?Li%D#nTBQplD%8E#Tu7`!rPEtEXWLaX&ow9;8W>$JpXK+W%Tl% zm!;FGOhw7wucbeir$C>GtXQd#g)r<1C8m?&h{W{2BG}H~FejNf6b}Q2b6g>CAbBeg zPD!x?uDAV2f^xrH@_-p(W#eM>r zyJz^~^a`j-LVF7lphr3=EFo(L#O3Wm)Vh4rgrpE}UZ{ld_VwkBG^t)T@;Bt_JxZ%s zj1Gidsu>3ZrY{!48jbJqqXJ=wVmfGkV3q_~Q{C1HH${;vJnLUYR4Pxvsk zfCI_x{N^RR7ws{7-I~>>*2+m1Y-SQY0Pb}`j)Hw7WjBh8arP^u2&3n`1_AQ-tETo2 z3F(fGyrXEVJqJUBG-*l?W2}uFx0fN&L&8rjLpgL)Ua%0Z8S6bzO!_sVsupr2MHCOd zNZC1H-xYcm3CgE}xBPKC{TOycFM&|y=s!lJ KIT9YYK$giZ6&Z#Gw;~_PO{a#W zfn+%A&&voYgGG@gBIg$OsoJ(JlLQav74_;H)%+EFVWHXau_;?O_vww=h~scnGEn%? zY>@b#HjfPGYuNeQl|m3@vt6aVnEJ`SxccgURCbn8ab$}Y1_Cq$hv1gr8eD_ByAvch zjXMMg?(W*SLlfLxf&~o(mteu&3GkTA+)*t5Rcb>Jdkc2K9`{5Gy|Kx(xZqz%D9It?K40QW6x2ja-Us_3K2h;zF^w_)3jW-s{lXdEjdDs5 z*=DW6KGN_YN{i&>$SfxO)gUX)jKoz2JT$Dy{%(hf3NN}CX9i9hEIH$XeK#A2Je3P$ zZQwUVWNv#5dp2hWja3LtIJwq>2u5M&+PbEfOcP+WFyp)>%AKn5O$wC*n>Qc39Od;w zW^=0Y$JX9LxAECz3e$`a*(m|$cq3bY)R%aUt{FvJ-zJ3J`p&{Xaw-GL>PoYQt9#EmBN!dwQ8BeBs<@$tdi)FHM6g- zJubQ%QjYnJM6=6?yaxg4M`lLKz1`3!+;nvdTN-qu`-1HqE({yDwGs{J56Nk$1KjR8 zBQiZPn^eWfKYp+!wGa{jkl$G-x3~5!q&30d9H%}ss_d5?k5BR{x*fQXNjJiwfrj)x zY3jW_&?6$W9My2?6XCFZ@TOynDMB`xTaL$10RoPVT3wnCGqX&fb#Ha#iE4OvUJ_^U zW)5qW{y1S;_c*&5-dZ`?Ix|8D#6rN}r)}i#h#$sIvKU)GCfH{XByblyrqDVm9>5-u zPKiYHdNvQlSq!TO0CK~^1T%?WcVGsy}SNvMPc&S!P7vg%361B206I8J@B1Jir} z7~QI5^eg^?@dN3?Nz=^T)%M-l>i4t4B!2hp(QSiN8_wF!{N*w)?6W@R&4K&2%{LZI zRIrC$doFcV%aC&4&yHBRu&W7-;f!FIvW$=kh5&DG(poOS{UMft14pbMt0JfsE~0=4 z@w~%0TW9O5CqMQ(H@4r8(9`7p0E!7lt$oh0#(aTlnX)D`KDs=Dvt>oP^HJbtMz0oA zj4Cs%=p*A@ueR(A<&cMc7;^LQ?)>i2w-*zkjpc`9A$J{2Riy(t=Anxg1vuD%?aJc zRl09Sys69KP_~br98%c*Jj)vi9Lb-%Ti@K0TspG=m(~P&YY&&`!QK1t#mSA&bU56M zk9TbVwuuoeQ=M0t9Um3nspwRt5p6@qyLQ+VcwT1Xlv+hfuhghUU3;Tcb*>>B&a+`w4fJ5{9%*CuDYvAyjyH6MWp%k5iSTAW-MI`28J=f-;!kYgPd^buQN`?Mu$F&gi;?l_lTk-V=-0!WMm6nclWHf(QFYe?#yNHWCkd zLuhNe2#Z&jhn4x7Q|T6~K+I?QOa5^-g}QcafaOQh zbqO4Df^2$Kp?PsN{1)EpP3SvWj%_5B`qdIhfZYd{eV^JF_fE-^=mPkFa2*o^nX~ngzYdpm|=iYgjSE}Q9**Yc9EMYo!z4!&4R<1V=Ze zxaieNa0Z-3qL4nYiE*6l=mxadbwRsktEJ-%%Wr?FC0Pliq#wvC>bR|-ZV!v}QlyRk z)Ur$EOCPwzLWqf*7(@^*0SV(iF93g%${o!o`8sc(tao5)ucXaqSDoftX`&{8JG~qb z%6i0|Be0_fVHOBN*TUL}C_-?CK3bWdef=6E~C9jttIM5vX}@f*e?HO+UEr>i4e~M0&=l7b}~LqHeSB@YVF)HP=73Y zWp`aZVgX0bTFiG}zuLT>;GeWC6AFerCAYE(=hFZL7k2OolditshG3Vf{&a;cyrAlk zSvgBH8p(S5(hpf1=2g&CzOO>d%R*L&oj7bY6F0;k+MNtC$?qAccmgHofR!GN8|;^K z@B-Q?av)BD#*s82rU5@Myk-!1tV5BP1J=K2ysEOp1B!R>2R=;&lLi|~kjKzl zSG|c_Pz!mnVV*VpnZUP!Re$RjWNELjld_G9jm6oSzA`zdmJmtT^IT-a&=8(fy+d@m z?fI&4nf#z;3}*fEO{Tce6yJhvxTeX!;Kq;X8v+45OnOebhi~>)EEZ>}8W0_Z(h4Ns zdkSR#(l5$EV(in0SSr8-5}g{HL`=3cpV~Tu(k@i9-wu93{&ZW2Q!4az-wv(CC>JQsR4zN6kEU#V$=QpGJBY243jzS;XK#`awV zo^pcIipD^}`S?p^{2ZW~d9%)gUQ%$7v>_m55`{3a1ZYT$#~h_N)5$F^XHl)X=Qr4C zaw~3@{q9xcHTO~AQUsza!_xIcW+)ryM&&wX0%(0-ck*i}2j&qb`n8RhhdI^}06AG{ zu;yM-ie=Tu30n7_1PyPK@SUF`Rnt$Dt;bkVy$33rZBCA&f<^>rwzPc?^Rfe@Y^nIK zMlDGkAag*H0WZFku;-jZ*SFm|pmv>}-t_vt0%(6iLC-F;A*r4aCBvz3nI(vuvvdier;$zhlZ9Y1M!fiVRY{U+3q zWPRhTgCs76zLqb0RVo*Mu@R(oi;^2z(lr4TQZ7f1Fo=KG#-PSdT-&bboD`YnelHcXoC`cF`r{q_e0fiD16P)o*#<6BtpWBz-v0;4dUwO_Jw3&Q-rJ}^mQ zBHq3zn~tKb$E)4ES2~iP7ib-wnA)!K2%~QNPRH*zsFc;WEi7P? zHXKVH64k#MrIpBH_@*5_<`rj;JICjrs4U=f(lsGF?q~!BKpECPQkOwlNFyBAJuSz6Q62LA0wrz+40<leo0+cT-rjAUmdjn}z^L8gs1FP{E7tJs}|5!rz{0c#j z6xlv}IXc?vh<_YGHFV7Gb%xBbMFOH!kyMH*v0iO7psyPeq52-virQ3;fqW|O-d+G> zlElA)1>Ga(Mq1AZwk^;^Dpqptg(w=gjg9^~=M*_2Yid*VT|U9Fkp~d5kv2(x(#sj! z^<{tSbtoK7CvGlvFOCZQ^M0UD937{$6qiD(8EiaX%@2W1^HabkC8xRY7duy{qu_f3 zXUb|FhU09}IF3z$*?!CPySz7sVd6JVSP~ z=i@~;{0%t$wAnL9D3Shnjlgj*wx;xTUuAf@w_ONXO<8~p*Dj4*7EXZ%tju8?KkEJqhFsRv zy6}i79g(`wbDqy$+hCIUpeAkD#K1&P9W!ga(p~azF>?uKNlp1zP9o!>y*xF>a4xEm zr`jpOQOTBW?83%|ZSp`mb6IQJ%>b(o;kV$p5NU3c!7>TwGwK1|)JpN#K$>}K>bDL( zPPEh0s7}D?I2E-D7f6Jamr=~mF*kIaO1Z00!%8-3y$uu<-mtZsm(oCf*cObO+hyrhJ+1z81@ySMifb-dLEOG+1}*@)kp@+*iKhrHY`no8cs zOO4x(^`RWu1}~Sjf3$aYMuh9Z4AAgHPj`_$4@PILi^Jbi{}3^2qv;Y*MOhF)Q$9p6 zhf)6M2(l{WF+Wve+hWvKR0q4$Q^*%s(nJ%12On+lx^H!{fhI5-f2#qw6rd8xl~-{eWKJHxu%)!UfF2JGknNZ*zcImFF{w6 zf`eRHt13)2#Tat?RYIyg0OFY>C`Tgv5g|y}^$P41QF7DrlaF-CIb7qF5;+Wf z-rK2Wz*+BH3am!&i9|NJBxqcmJ~{%jx`&GNSGb-;A}uun1ecCNq%2_r9vDhbxQHsh zv}E{BzKob*{Iny(gg(i@kHf$7(bN&IexSk4OzAV9t_RL(0{(_KHH_YC1rAUFuk7t4 z`@XVEInw3FG(-Dn%Q#k;y-k<2g{dGfI^$mQzt6B>I+6fDN!HVB8QfiB9q zRak5ml+eMu;+d%4*L|tcM{i%Lu+c$t%avjEFt4(;nusm<$qA-R3kD_&z{1XsqS#4_kO3V3O1Ox~ll*@t-h&1m1-W&U|>R zMAgQaZGzU+uz+Jj6(*Y*P1ucb0$kdt0pWQ%4Rr2C(t7OcxJX_3OKUf{Sw8Nyj+42~ zmUzfDZtLLvq75{YV}U?Oyg9Ou?vl}m#~)Ch)%EFhbGG4HQ+XJ%7v@yEN6M@B@RPA_ zQ65B@n4&I`+irtEAbE6|zDwZ4O4|DBV=Tnk&6a>$j9VRE8!4`oNQQmFmUuNZ9gUGI zaa>Ij5>SHF@J#eglePdsYTa?HPP-~@Yebvm=1l&xnWUX_8R^M#uK5s8P2=Jc{k0W& z^K9~v9#YTn8P4ACJH%^Td2|V?$9LV`r{lf7436iWAE-Z6roIXDUB&H}t8XBSEs%UJ zy(hl4TPrtd0oF;o9p-sgU&_1tPfM>YS}_*5NOzqh5r(K{ zt-2p#T2P5L2y!n%suq%#m?ails1v_;QqWrWnFTLe;hteR8B4b7tOf6mc#F2CU*%_# z+0b+_VP+BJu5(}(>_I@*XF)*Yz&ta@3>d30!-nsO%*GLe7Hm!j()80vXcBE%5Ttx> zqd12jfa!1<1-;ZG_^~2LYejMW-0jhdthoQXr9(oKU)mbnMXUosh*!QMYl@QV?s`;e zSrxg?I`gNQtoLxh(inO)iJeleh{8a!dQx;m#6|CMF*Rea=)g;smL`w+Knr<2rTZ^Q z?H78c;*_~V+3(%$-XF<2+qv=0@n3x8wX)s9@>%wta4igiI;9X+KwjVTIJ+?L|K7pq z);J(AZgo&XRe0AH@maW<5_VDDyJ-33;1CPxQ~v(`kcCWJQ<>9By&7PC!lbp+3G_^U z2?pr6C;SX8ce5iFeFzC-`ujeoR|7f?6(Dtl&GHQ(4B7G$m{OMbuu~pN*G!3Fbiw5U zi!TGI%Q=RE#7{7wPBRO=Ia~J;!cM9S0 zP6p&5)pjlTL`}wmFPG3tn%L#yJn`d4k7vs-PO4zu6w>94pF&~kypSPRK<1GJR^ZoW z4_wr@FnoXGEZWsIv@7@A1iQ+n;k-DxzcE|YuQ&n}deF8Ju>W8@2G7Ie$B?X#o9;dJ z@VJ4gm6fW**B&rnyD0aT1uS&#;Kg|4ko>C2#0V20uAy<;`leol)bTZs1(xCX$SHM` zI0^=0NcVE^y22_=Q53fF4@zs1EcZac^s9%^6tx^#XQoGI4#$jrVuJ9I*VRM0YvbGV z(l-8Aiq01n?%~A9euQZ|{z9Kh6rIT6a=C_H5TV4$D5ASax`w$^EH{1Jus24gL!#4J zuZK!NzSH%D;aLcCwF)!AVc!$r-;+};a09)%)W?aW=)!j4_O&L-I$pAwYk41OqHXYQ zhIFV$@Y733RY7Fm}^z%9PX>JsIKby?l@QV^DX?|TJBQ|R<={q1njMj%?;XZ_RWp1YeJ&lygpNjukrqK zxw^HVNAX3EzJT6AX8YWOJ#?;{pB-sEk@gG49P&bgrD@hZ!l*0HGIt%1T!Z8p_wQD2 z+Og!dg4Betiq|QYL$DtVdTnXZdKT<8&T|sWv30DlI+ zxs(wb^;=pFIXSfAoUR^oYcp}NCkap%1~hx<(jc0wpZlR}G%&P=jEu^UR-cY<%%8^X zxjNcGHkcA411Xa>XHH`~uZ+<7W8u%uW)qFVCE61%Fbceh&)={c;h^C7(fM=o#oytb zWioZhDVQX_M*+=lKRcvvCKhJV%RsDhezrmxFaCNvBPIuO=Jm_)#Z|xLvUqU=bQ1fc z?EzY@(51CmCB2tkSOr>+g}p;H6xc!IqvCH`lfQzXH&Nb`SxZDPs{NoGLK^|`v@L}- zbWV^c?H8@;qY6FldcLzq;K{zV$HT2L>AQO=6~kDf8^e^EO`7See3Xr<<|^k|IromB zO`ZS$=0r<@TmY? z$+}e_NA0Cib$?FxVIU7i))4Tx$vaRqu(_%)0^}(m^7w13Rdtil5Qe!j4}CwLp3npw zJB1LGh!Ky>Cs*H+0I96pDnGv>mZAX`lUvcOBY8+Ov#r1-iIENz@AB7dw>gQmN87N&rcV@tO7vA*nG6Vq&VaFJrsPLC=;=M zz5U_G?%rxzE(^ID0Vv~bKbel1H}zn14uyl6fz#(Wm+{oYcS;B2xiUUkN=e^s@-1Tr z9L*WQt!A!}D~$L2Pwh#RZ*u*=Jm#%Tph7^P{cHF7zmrz}udVCqi(&I(s9w{0OmB&S z*xGsIVQpv|uKE8s1l>iG&z>wcP6r9)P`a~bwqFk; z-oPd<)4x)ONlHV!Z)w%MUnd*r?G-%{!vqUdU{)*W50K;O+jf&<{2&GP!?3`zriw{+ zPv>dT$Ubl^XiTo7CeukEq~!5S&f-!GFyN*1LdE-Y$aF^`Tz%`kWQZj*1cT#8R#MHi zvpe1c+N+SHK1IkEPnHBul~`Jp+SNkoa!6^Lkr$%dl5@hCbkTW(BhbOUs$&wq`CzIF z+FBNmTdj@DW8-S~>(|4oiwcd$Y>lO|v1Hj3+BSCpTDz!D8SwHu>+hOSX9|Jpwj9~c zfCa&NG~lpE510+wTh+ZqKWy*Oc1R;fECnM2^z0c+hJlomTdmJv;(TZx7ly&`$*o(e zD{(!hD56h8lxiJR9i~)QdU$e;$&2*!N`Zt)PPha&Q$Dc~r={lww=)z@c1QjjGr+gW zKB+b|2?&;W*FQ5-@K3)qa}V0TNpG8egScb2jP6~l235&jO7$dq=eep4o7H06^^j#i?mnwyK zMeX4d_cU_bAS@2luxN}eUTeTc;`V~i|N><*^tMw+c&%=j=x85KQyJG7!I`6XYwMcWnNXReQTIn{mWAHQVp7p9mX6}pHc zqD5}qi;-Pvy>}1|d{#&*Znqs-i^Li#y;~kU4~Nd*uxeH=Og#K!ul&ef;dX0SWVD4M z;z0oh{@zenoN8KUFnt2C&yxLuoxd&=K4#qVATed_Q_WWqdw_pv&}pZG+kH$rPxL^R zo?g`Y!968+Qf(8kiz*wFD& z+~NPBE*$P5HTvjIGXMtxf&4^W_|eYkXYCDbtZd&o{y#MJGFmL7Q6M0i86hB0pJ-}5 zW?uc&w6(W!v@x`?{6AIyzMuZ3`q%)Ejm!Gz0Qpn(Nxs8BweH`mBK%d>KT3Tk_~Bow ze`a(2^BDL~cQuVwoVQ~_K#bTxKoCEveD_nqUkz0LV|n^drT!1(X-eXsWA@L1HqY@_ zTl`t-Kbqz5GZTLg^fc`6SCA*y6VT7vJc<7MJ;2j6v|jVYMuKl14iu$<7c1z zw~_AeQJ(q&{z{3|`ESZom%!g6JhfK)l>pZNZ^9F!wci6gb-wx)U}N|f;Aa#4>Vx%r n%1@t4epRN=`AhlVGu}VmBw=9xm<;~-Vle!83zshVy{ZZ%05BNuz_4O_7EM4^V=^Q-&57IZ!|A$0Opy0_OMWA44Y;9<63iU7G ze{BHqKj&JRz>4by1pt7<`hS3p{tR^rH^CbqnG2fIrM`svp_Ea6?A1sPwsL2=eu4E*OLd%w-%q zQ2Z-s7k)9pTM%g>NL6lW4@_2o=-#Y`!7yzFek}tL-_1CH?IGa&u;u}KoZBRii_n6u ztD9gg0Gh7k2Ak}{t}U?KN< z8>nxgMAR;J`Pm4@n!teZ4g$m+^FCf$-vElC7rzrMaoR(ff@#&73t@P?g&Ylqp4p9v zQ;L8_Mz8Zz&+DY6Dz_{m6L>??HevR{3xXI3JCE)3?RupGfxN+mK4yz6I0#rb;9ter z7$-!T36_Hy$*3Jx^GLt(gZ5Nl8mlTL%DRn3UnsFgksilJ#0A~ROXP_^H*IDNZxCZ6 zo1;;;649O>x&>FrSo*rT5qa@* zeOPiV!<@VrlPCsP_sif?D5H?X>j>WDoNGu#qa#Vw{-671_X?ZxwKSh)n3)5`k!u5> zEfBG!WH>KJa5&jtnmd5l0x$Xu$}^Y2rY0$`9Vh{CoYdO$il@)L#oXIsLU3%i-aEJO zj{JMOj?u&F>FB{&m840<%UYH2*XB}o*Z*zL>C^Ov^i7fnf8mYB>r(|=T6-M zv6x5buAa!Kh*U)vQr!69o1P{CeI(Q*6k{u3+AC>l5Z0oy=qqcTfz_< zWhA%KSJ41ss;t6yh(q?TlgTCrGjwnQh5|F@=APJPeC%{E?cvV|ZODYmS#PGToUu)) z#V0Xj_Q`DAE2RHE*`J?^ZHc*~lq)2iL>pM^VDi@ZAli%>@)rBFzI^>NG;Gv~yA8%M z6ZG)o);|_n+wfTvqE?9E5zH-A*`t|!tj@7eZuJcpeX^FG)5M6PtV}YE_PcRGC!~{K zWP32k903QLBg$zFm~9He#| zm#cFsse&tz0bc-W*(@WtkR7_NtEn6C6w;+=Qv9oI+hYRefdl|krj&Vk|52enh9# z_&zqWS9sSY{I0cm?BM*{<7yfjh12 zLVb2x>>h=mA5$OKFY_nP99F#++64%P-6q_BkU`^OV+SP)m79Kp+{wIv@0L84*^J^<>oZM?YMfv0rqBiO^e`aqoRRJgnp74N@`F(Xr2GdXHisFlvC<8TTEB zk@pyo1@}F+;6P8Ew4xsYGRrtRRY^Y^PE^M{N+1Dr(B`5=~%{p>u|$ zyvTbY_L**07QfEyO4!Fu0CA&e$t5u*N@r+|TgWI9;EVyx&wfe0W=E(~xA*LX4B)F1 z+UV(uOiI)9?D|H-AHVG<4D0|Vw9^0p9YdT4&O|+hZ#?VN*B#jAun%mqGP|c2rl}qJ zO;J1pIYwl!1!i2!HLcE!AJ77M$=MT|U1CZ_R-53Z{wb^9(}J3))%!d-9M2H5L2)xx z6b4z*<8I{AWf}dKKsl1t8oS6FL&v@Hxp;ai;bHxgZ9-TK4^Vg4#Tw$d8&N0eA2Xv2 zC6OYE!LF4i|BXWa!IzlwNT19(IJe<-OxuH`PyCYmza|_Of;W2u<$=+*$ z`q45Bfpg%#DA6X4i#z(G)C0Ooc18L{0}KDr0#)+mmNBtZgY2r*7aQKFK9E0%mq{s) zUmn{FU&2KgNxh|zW6&ZPloX~$``K*r9(Z~jHXR0&PD z-U7?YLHoT8?3*q$PZ>Qs(xy$tK#QSf0N#%OMyRhj!g5cptg|ifG0KOYE=q{+%r9P^ zg2d3kGO{EfdO!E74PItGn3NF0u8?lrf!5Im))B{w<^eSuIjeU6YwJD6zA~mQWLC)_ z4r>`E8FCstTuL((8h*jljiy>RCa>PJ4(0nkJ(%L0fH%JiNtNRUqi3mN$#hBGPs4eR zI|5b|2gDC_740%%}M#AD0uN7b6P+E06_e=Ihk478QNHSnL7QKch$A}cGRXA z!k<-rMjd=BbnRl&IUI^l{U6jjx9iX>{8kmq-97Bq#U4TgEam%F-SEd zs&f*7B{dZF)Gf?DpIXPS_ohwW?{w3u+X@}`tQ}Rdi6jLJ+VUu4FV}k91tm<=6qr^t~*erpKD~U zj(5SwYgRhx0h#tx&`7&U%UHf2AmH>Bq8PY7)?m36;odeh)$Vdb-ptiED#r8dBK%+d z7h-#Z?~xHYA^p93*0bZ60Y_GXP^C&W6arMzhV_7G7kI+56I)m)rZDb}ivecR!t_2I zfxFX?QdN*12*aG>z9yc>xD}xq)*XozhwbRFTz??*_`EaPx?r+%rh|57dq-Ip)==~W zf}pA?)V#r)wx)sU(R#B{Q;^rOFhALp`lC4zEJ(-c&@OYU4KtTXCVh52_px@^tHUb8 zeYpe}c%PO`QND18w&$5~$+~}WA;@Gd{LihO;*wqvenz7DAlvCSfX6JWm%VwwZV{Nu zSiz!wDzZOoP97WW0A__|X=iOm9w(r~Y7hd=OjPs*yAp(Wk<*$Lr>7VX*|nhH5ht!4 zwu_8_XauE_ygh}v(d7+9#@%g6Y`_NM;@3YX_f#bpOSHSSJ-~x34X|F9hi-$DcbtNv z6jP)kw=qPtNbN?D3mWXMZsM(g%ZSAt&tqF)St6w`Yr{7okoemVENX@Er@yXh@0=7q z&!$CYyUAjH6+j_>jD^K1R&^(`7hpzh*k0K9TO%P8=520LGY)-0{c-Qeb!)-v&>%l!7zr+`?=Q*ooG@L?gKh@s zX_Bk>gKGE#KXwS&1b*l_e9o`D?aXG$H7H|{kQ>VDaBP|R`#)^%|B|lThBPl#SO5SZ zF#rIJe|rZjLpMWxV|yDLQ)3rPdpqa<<|K4%?Ke4)es1&!ockoSS0y?>_YrK+#>}?6 zo)THYw?Waig-UF>)TpKT6y*-6U;1{Z?nI0AaLDC0X%v7(ZJo?GvN8Fx2}^j4>HRzg z$ktip>fqKk=c~|vhg0R+Hm|cy7gdQg=#H4AO?g6Zn`pofp0tiY8R}QRGKX$S5R-^z zl-MNOlZ6_IjL|QRVrg5Oinu>*SmH>vDfaJZwT;;2*F03~Ne-16O%GYKA~E$zMf(=d zaZUr#=b#?+9+P`~){W_3UtYO#fO`WKAAnz|XQMhKhK;8E4%LZ{8;U8X9&82*En2D6 zI>L$2=uQ%9Zl0}n;L&iLbyrbhxQ>&p6DI;SKXj;-jS6XZl@pVq+Ad?=tuXa{4_$X1o&nst7k8u-W2r1jp23qtSg0l7*K(B4jU9ty_> zz#4Gspa&ka>y{ZBbP|{U47+I_@v@yD_4&sZFwclPXuE)t_3T?SV|6}EzY$Nhito-2 z8_@_1rce}y08}V1`$g-~m}0(|=AvSzM>$sXau2QZ*&Cr>=IOyd@GKhlZ=v(#h)LfS zTaANq(8knyHUf2>98jeiL{Z44QInJrfWS(oUr(!rtl|p2DeNUYke}TUO$Ir&?gK*l z8vvS1{gvG#BZ!NRd%~F05$#%q8;F=H>M+39dB-;>Adz$frUvY67NM?gR^>9~45q5a zkPGC4iL$8(7MSui0J)(@Cc2F2^@0H-^~XGts#sz@5#zZUDUcGf zq&}|V&YYbS3?lXoFpEF^g+n#L@F`vk5gU_?>Vri|gGtJP1Yjx|h%k5HkIa@F$UfmD zS(54u(0t#;YjyXdO?&1PIDm^mOh5rFU|%{qS1b-jh>?xeVNkx2nzLpmDhK*)s<02T)uCp^(?ooHd~{ksNFDh;BliFD@>6TNPF^P;!A3OQnSH7zrV-EhsLu%@G94XrA zQ}I9I)t9N%Hwwgou$~btcv*H;cZ!2RaIcE_pPsjMz%QP^b$0jr#mg+ihZyd^v1|AkFSo&GGzq-3ggqR!>&%tW4!}=wSu| z+w0-D`S&04yxF!hB~%haj=feU}WX3E2N#!O|h~qPvU|Hwi0{p);M@8fbvmoYx@J6 zh%3Q({TW`(VD=#9B6WMV6lnL*uwki}{*a(L@l|C^xll4e+*EuhpeZEA0j9n@QJ^%g z)hCRaVAQ&(&Pml>PZ;w_oMBjhD~kU5JV3^wWh;khBFu#fe^^8M9Wi5YM9HC4b*{6o9+SYBa{{pzSiS7*O8ujB70jvil&@_G%Y6B+eFFt!6 z`|+Wn3=rU~L6QhMbrmdGsIkE&kSW-P)Dqg)RaZS)>(Bof6_{i!?VKhHHpyp{ml}}` zR_0-$o@D1Lkq%YMQw5++dH61xoGjumJY+XF7eDP2Jry?xe)*D@w@$m!_lFp2>$59v z!?a6Q+!6=|A;f6SR?H%Oqe5dm$Xn7qqp#rrZiAbw&j ztrWZ3chs!~hu`V66wXJqfyeS+0y8md0dN?~K?p@rk9T^tXbS|{OeJfP#h(V?+Iztj zy@Q>lD5R)1BZRC#-yf)K6B|R}b^QV9YGqb@pgYzzu=%2|7xBX|hDgHrQ%Uux0*)g{ zNn@iyqv1vUI3mm1a=opP_JwZ!gSqgaB0alApf^V^Ttif@P+S{uMr)98nv?SA)0UkSS+43i;s#hJU`^=f4er3;x|O9TJkY&W7HA{;mI(9S`2Km!*pWCQ4D8Q>Iu|Q=aF5Lf|6Ra1Oa<8r|uI`NB0*k zt8XInD+jL^lh38(D-+F`8R!+wJ$+EicnA$m+U@@^3b4<=a%Fx_kjul$(knFj=o+e6OiaLRED#U$Zb$}go|#M zboP~C17!e62Wm^>YS0IeUs_EC%AZQ|IbEgGs|gmy%eRh#M9&)0>7ahSt%40bim z#mAC`0VAFT7Lu$P6c&E8yY(#^3~kcYMOrGZ!z!q}W3Q3Muj4zOBsJrC&4Hk*NM5ti zChgsq8x*^*?bq0r+l;@rT(RaTh1GePEJ)eu{L;obStX*&qY|GGgqv(>{Oxs3Pkpvf zKhOp4DGqqiK4!nh&I(~ywbD-OdRs3%S~kI+(7SG)E}{2prT%mOc!!kUj$Z)tSH3C0 zl5BT!|F6^AkN8U&`OrLC31h-K@aOI6^X}{={%QA*HBIC+e6TO%5c?9++d5s>s61_5 zC}?r8U77*^Sb|$~+1c<$5PA8z?raW1Oq!NlnVNvw6d|uKjtP&Sv!4&k1}N72K)Hf& zAH4`NziII#d&YVkr)Y&m_l9*6#DVH~h@ox#O{8{pqVCttOLhp=AWmgoED`lOJS`?#G+PMcrD- z?Tr2tsvG)<Tajq2JC@$;^!v- zx3uTaoy~LN*(c-L%TSl|h{)xq;q-Fdn{?o3&(_(~3idOOS*e``)?y^LB<;$9JJF*&5RtQ>bgi`t7TdTwjTX<(}(Y^t9`(qO!liR{M33WyCtsztsabK z6p_Y{J2vqh_M=D(4@=!@)@-zC!;64U`Zr4O+oQ!Sk73YLf90sCdeLW?U?HsRHr_j-TWz4?qsof#G5irM>k zX-?Ov6sw$~p_4Hu4+qG5L2t&C;wZsDe@T=ucvo4}3~a?tWSuWZYC9xvL(yQe)8?je z8fL|LMbG-OZiAi_&bGmxM76&(0=tMP6oqKMfdj)G4reHYsIbM%ecmam z8gFFBhKpZJlB%xP4PW>FjTgP!d)y=w!Y4;N?J5EQ0E9?@04V=K)65NBOx+DV{~JZ~wRYbWOFVP`hN5wv z)m)G!b|dq?t)8LkY7`-tsveba>h8e9MlT%G-mmUr{&*>PmmofchClT>UioFTV$(r;#$ zbbT&LZ-(WOeK^|uqLc5Xzi>;LO*-^(QcfnySY9R?cb~!A$f|tWE#>CkJ#|gcpB7DO zX7|jZiM-5W%b0pR3Dvf+7kRjT9hF=>X|Z9UPg3PXb!D!wYFI2YlYanx;^Hiy;8tt} zWnryJY_v+!lBpaKN>;w|A&)OTU@O2@*AVQV^mL>?fd@W26r_DnlR@_Pq97a33mr$N z*ZZ2BnX&ta?k#tXQv-MtpBUZ8|KkMAj+5Qmdh-jVWRke%8C zlv4T8KWu7ExvA9keldnr=nhg%0)6Lz@VV;gMHxCJv3qW==))vcbtB4Z1HTc{P^c;F z{Wc;|Q+&Vy=!d%*0Oq2a^3ne&FWJS>J!E*)=w%#uIbBrN7^Eh3@86<{e9xR;tq-N` z#$I$*hCMxtG{36)Qs;%6CDMzkMS5#&AaQxoZotgb)0eFyIhaT=+*#5BkWO6D9?d6iIi}SsUxb*Qb z()Ujaa%leYyUm`i>3eACaef4p=cecD;J285?+?BA$MyR8J8AdFLEi6E8Gl9?zmA`? zS7J_1-p?q!TLispM9??8z1}sxPnvG@OkY^&D?FT-dD{B%{QcPzE!yvCP7{9b7hBic z%^3|hhwr1Ai-V_&mrzf5@0sKKf%E;IHL*o)@VXzCk;8JBdk3!NQ~@f=EcAtiNZ&dgkOmI$Wl$nKrOY`3M?S|3{O|k(Z6n!LaeQNg6ie$_9We-gfgk#CYFd2k6<9i=-S~r4HfKrw(cZa=I$+cQA4Wq7Z(5 z@6xw*a*vcU7ew~N#+VA6fVdwd*QF7wPPeoK!?XGTBzU@=uCv*iCw5F768rLbWc^r z2sieT<+^)NEbjZ05@do#P%d76v|je&+_7G3NS8xkVC)<~=2WT8KTl<(9vVB!lqYcI zHUQKsVqbf|Cf6+TZZu}O?>=E1&j4}A$WfhO-|ApT0m&QC-f}EojP-!E=Xai9WA#=* z8y8Sf<1eANZoXs6L4QSFz^M*aa@Cvbbd{oxvCp19;iDVp{#K_SGMq_(RRWYbZ7XB1 z8nJ_w`eo7wcxS(;UXN8(O5VGF7^^0Kj5e1$T0cEEmZLjhOk0*?4g~!Lb5JNfKmEdg zx#BdsIBkFET;8o+KOw6FW$kJZU7Q@RtToKZTkQ`8j=Nmhnp55D2_QvnIv{q*?{Wnpdgx)1Xc3^VVZlo?;1JNv{P&4S0+t=ACszI!>Mh79wgeF_W znJGhR|B>n0?E7A;_SW-PsOLG-x>>98^2XjBgu36jC)1vT^J72-r^C;2nS zR5H6(%FYg}06uUhKaC!KVjWHMt>8qD0tXKFkZ>pffWTe0@ZgufzNI!e(V3OTvvPVV zOcDBTEu2^ho!^5vu!Yf4(LJKP3D=Gk!Xc65%BD6StL$kGI46*aHyssd2?E37g?w-b zI$~+6*3i#2g#=o)boK@?V{*(AF4KX|E(E`8dw?|g%u)k0eIa3L3&0+sZ4F`0`=nV? z%n=3H&W#d1vzD|D)Yg{#|f+Xq)NoV zK7*&gK*_*$xw0Wkj3qV{NEt8fB9s9Sp;u28+;C1?Q51VSf>lt0!7Kloycc-4_GB0d zHAOLJt+LXALA2s;a>?&;S}FmcNHx8s?LFn3Oi694Q1=rVxK3xz%P7GA?HpW-4`^$rcca1Iw((sXew~EPTT|{-^f6X zOTD3T%V#2)E?=()v1KJGmehxe17@c0rY|5(yXbFVLUpIuEgPrwuewcmJ;8n<>GmPSVmaHnD?h+rzT)m3R$)Alv`!-R*CWsU#_3LJ2 z;6OB}L<4%$>010@@wXU6=1!Jocv7d`1D;B(FU9Br>j~2X*aw5+$WWKy?EqHltG7}U z@WDgkTC6)UVw9;x2%65#JmSa`piqDev*wvS2Kwa1E944&-LuKnH%2tlSW;AMgO@hU zK{~c_BzJNI+CL%!3p0+aL}JIRXmQ-ECEs zFp@poCFlmbeVJ@U7Cw>gPSXhc2h42%-CgI^=)V-4ATzmp4b0BsTR@7^_+(_iKS zZ2^p-qaOA#cStqsp?>kl>@w7KP9n5ph+86ayC}X)2h~kovZ0PlVqnBA}dsTO(PG? z>M|hK>wCsZi*O`LRB$=jHCEZ*4kABgaC-P;4}QEp=erNC@m296_Atj7lD-cw5vXxF z-(pPR%z=&lkF`w6#YJ)|q8zZ+1*uBe0)g^^+-$SaqC@o}=VdHuL?2wr=j5PQdcJBy zgb67*dR>fUseS)8go}Qs))dU|dgD^C25rcQ*8sZZfEtu4YqIjHc|&pXYwPxwESs@c zgPr5e4O_6aa~p=?IYzwWfbf`qdgG1ZI^TH9URm%1rTwA6fG*u?atRV(ruSYi+21>YXj`p z_p}$+1kXt!kuQY0YRf+Lk1R@#_t=S0=69V^;DGNfC`Y&P{x&3-dBHlt#X!Uk|@zS*M2Q6e6$TaZ;&S0mI{gRle#e9elFbK6`b@x-1 z7p{S}JOi;A#kt8^;!JnbYZ4#77ojv#&udu#)2Z-WC$^oeWu4Sicq0xNFuE#PI|!!} zYTOtBkQC!7IngNUu~8cOxytJ?=~oC+gizMVt^74vFDHc>+$3^aY6~YR&%3d+HI}#s zBEKy(s1{1Jt7y#u-L7;QQ?T%WS9F2aBXn!hRJ3eCZLhryLzLIUY9E%)=yLfz-i{zQ zy-0mUn!=9CRQb$psjrg?SZH|qrD{DW_oZ>QijM>8^cqN%7LhmirW!m3;chi`W1)uF zaJFY#3>Iu&lsiyq@MgRW{9Tt!_$Qb>i&rxGasv`+)MTycLzyoJKwfneMT}%V4!4O& zZK2{BE!)!d_-ID7n>i^mi8-ehHFt62gKn*?0Z0d$aF#uNztC+(V1Ir=cd15FDoxdp)o@DW#aJn2sNR!|LE@v?cst#S`2ITD@Qn| z`9R9(_G_nh=WOkT+_=V;i){-adowUPXv_gRk315|G1N-II6=2xkL4oH>YsafS<(~ z4W49vbhEs&JL?Kg%}j+=Z>M<`m)h~}sPn7S8|e&qyo6HwCb`K0PC`U3L0};8M7YK` z)TOG%eD6gs?NGjVn4LjPrFvRIH#EfjGwRn=v3_!~Fd(+^l0X-s`ad-!sr6LeMLv-m znGHX7J?*n~C1869Y@16e7r?6&cHY*uAVu(>q1>Iuw9~UnCl{xR@RxbQ8@gXT?ug|d zuV1GpyI=5Qs0Kj))!RZr^WsDq{Vm<|DSL6-FqdE%lR}Gy`?EI`u`{Hp>0@|tdm%3D z8nbhaqT2v4Xo4Z&Q&V>PTE0t9mx%U_g18cVv0c5T%4N$h|w4(wEmJ3$|+-3K96%F;F(#-C4U!BH?Io5#8|ZkOu>zynM{{rw(~aP-0~tl z*jUvRV20?3F%j;@%(_;>&ig_vZQ2sN{k`DJUw6T#NFC`0RH&1xv|c8%{)+DIxQ}(E zAE@(%=52n434i9jbP!AdS@#SaJ_A~3S@jJY&8N$hJ;Tni*$DxX0wOPhbqBiieqg?v zreU{JYjsnsoUqYjxjy;6c>YPiR?|}7$H%V8;cHL&)O?W3h?R9nZtu}@-CwFNuq(pl z%r}ifNE{ppdLgD-jJj+zN~;pZqAOX%KC5}uMtAQxn3=fcuNIZ#r9Su7qi7=e;&{?Q zf4H{cmQG;Jj2RQd*AwHqfFr5gS$(6HZ)d(XU#z>TLg;UIXT5Pa_NjnF(qrHtw?)VG z8&8w9c$1gn<}V``-PM07wn0*w%PgLE`y$;}$sgBZAgXX+D zqm_lS);FP$A~Hfis27i}{_&L$T5c(V zy{u+hCz=qVBUO3f)^ap?7m6c@*9?}hkLpqud@)@GZY@VbT?%^(R(z(6hl`XO3euy0 z)SsP*%GGL#xlUe1N(9yN6u*Unq8jvE7o%jWrT!=ohyGqeDXiNN%1Q1XzYgOeiPCkz zEQx{kmY?2YS?rAR&XhIM{P(u0?856eAQm#NY$e-!4qnWrjl7%bCVMYLZ%g#&Nbj z-4iruI>8eorsKAeZ20QJz!2r4b)rqZHd%@9&*yi8mG?uqGh#qPoe0-sFlx>5bKVuUNpgcLO#V@ZlUN7<4Ei1hS>VjA__p& z1%851be1<-=_r(B@P;$OJgtQqNMr{7N4(Ra(O0c`%I1^3v(9^iBF}R%ld@V_VuN<; znZ(gpV=bgF$h!);+JiTa);QZ~YHqUbr$@UvJAH5EAyGsl>k>(x+KOtEq+1UWEu^-Y zC!h`nOJ|h_2ns7LUY=BFu?bi4)l0)pCgU3PW%TH6u2%^eyJKbAExG#acJ9zw>YBr_tk->O;2l~e^}MXL4c572Mav{NRWV1r3xQmlprPT zLg3JN%WGp(AWTj@o>eMjck}j=hghtxmf8rkiRBP1e{G;ut&NjFF;V{WX`$vI50bP$ z6QTn=Cm5(xJGHAFHjAT5mXu%yc9y-!>8UKAvDk2E2nFB#{dulDC}%dN#Zj>3Rzn{4 zq5aSiQ&Ag$mfGHRPYWjG3{Z`JZI1=$8hix}!V>Yg8A2J3;{urk22g;vbmyrFP@+?d zgxfmN9O6JSn(FASu{QL>)=Y_Us_9O-07n(Gnqt`6@Bn?0x#QHNJLnNZ$M|iqVdwe7 zpVXp+{pHG*S%`DVWf%FZ7AaxRzx^H#y%VitB_(v;$<$83)$3ZW6CXuca0P}0BB5Ts zyZt#fzCT^z34Z|gNArGlZKYhOhZiD&^QQ)|ZF$XS&xTCwKzrg__`WfX^}hrGvzy85 zw|~-UFM*!NxZ5ah#_*`OKU#s(XmO`PlB5qyik#=3>7;?UgN^&laW_*XK%y)tZG%4L z7BNcz&BK07Tbh?_A69N4B-X{PY?0%m5$Ex$0?2O1zHLR3QC1#L#w@7`!V2c{9n38h zOwa1R8G2`eB#L5#j!Fk~(A|lO(!+KBo=jssgz@sb^FfCrFIe%_%%66r5QC#}aCgx4r2CLD{S5S`K#<8O$I3ONhAXIZCKxj#sJJV(9=MZtik1eMo` zF`vSNT*d~~;WF1@h_I<*SYKCzPphsNA5o0is&O_CEkqdb3U;#GAhpdMqLWAiqFajx zfx>DQc)+4VhT^eM@f4J=glp5S7pZpnrXP%MqNku}3dKpTpA57_0~$D5R$K9sXa<7n zi|@xAs#AVhvTDo}A!Z%CcU?CSxW;hEKdm8NKd#000Seuus4S>&s>2<5pHfTM;f@;o zu)T^-bUPfo_i6LC>1$MCv}J8??l~J0=86z^0`-o1z%5k*bfqMy*E|#lu zm|3M>>kw!|Sxz*^ZIg;FE+S&)F)OMKE9NdlQLbN90Jp@ze81CjDIe0bW=5ZyW(Ygb zD>o}AExJ4VMi;K!%Ps%|N>XHYRY&z`M8dYpB*uB44mRN&c-W9$>U*%`8;aBI-xY z0HoWet`B&5U#P7kicb_56>qrpDh+3yrh?3sU6!cmZ#aSZP?wYt`Xc0gm)i|}{`+;1 zww{W2JT0uP`FA#KWf^aZ5Fo&Y-BvX*< zeF0Swx#!0d`VQF+t3FDVLXx6|lDzO;%54|UHgL|^L?miRU94b1RZl`|Fu9*fAtKl( zAzBF}3cF=(jeKA*bq_jmp_e4Sph3~3*jaY^Y01;9-^58p4z+hCj@D8*rf??(>L`Uo zYAA{BVGWZ7`@<=h=$77+&FK3_aaSnkO8XVU|5FIqShVmoLRvIhc{kURfe!JUhBaY} zMJ?${+!r`2y25`mGPcFcYbMUevLUl-Vskhh`tNkbZYCT?slT=&cvfa2c`&5o=WXoO zXcf4r5O>CGc;aLl6<^YnjKBCQBC}0i9xqO0*Q_YFAW)5wxrPEoePAR-@KO|Ie6(;$ zZESt=^kzs9#6g=&Ov@%=7cgSz0)5!in^AHE4uJ$a4 zWZ1nt?}>{eZ4%oKYRqXqGXu z`O@)qLWSwJwBOH&x2aJh8`JbfKYhS`aieLAj$>_Hd2LTp(%@{3NBBgTABQe7@h~w7 zw)KPpA>xPe%J7r;^=FyoJ$X`(27ai2r?*}bcJl`l)uc(FD>(4y0y|6(P?GV8ZUCjK zVsi~FC1luWky{Hkeh53rZuXo=^MjTnhJgLTZ&=Pv=DGfm{Gpct0FZd`(-pU*3P`K~ zqad%oswY`laQDI*CVT<0skkqoZtY}edTrp;;UioOj{l-1O}_mG4@Lk=%@S%NUXSF) zCzWRH1*x9Mo39tA0(8z(0waN|+@n3P#=KPoVF3k|nW9`}CXn2m6*|t@s$Boyxx`@a z8YId?gSm(e>X#9Mb$JE|YHRwV3gOb4#zFt;E5&uj<8VUg030INjv8N8rNI+4k4mB+ zQ5Le#dbg3~%40P5Kv!ekO6fsh$+Fu#WJJLteU>k~GY14Z=Ly8m4A*-=18yj>FaqRk zGKqSw{&^G;4lH+|0Hprh6gv^&-K&vy%q?}-MJJ<5Gop!ZI;v6q6n|NNy26jRLl4<1 zR8MMPX88bWoL56kO4b+* zfE{z^z?hTUrgY8(Y64umoVt{_rP1V45eV!{$DlF}Mp!kht{gkYYX~> z4)MB;GEGKOZx`BFWeiP7vM5q@n6~|w(XuX>xPoi9-8`v7*~x-n(Nuc6!|7A!R!m%h&%eIK2pJGy=gKXUd|68%5z?k6MbuNScjow&Sv1Q`@l*5R5O9h}I1$G_h$6zOLWowL? zy?kPJLUm*IFj6=L<-H~A4G3lvb?YkVMD%T|77eir-%>kNTRWj+zqUndnb)W7G+Jcs zOl~>Jed?PR&IzO*9CFwEzlY(?sc3bA_-7H3m2^3&?lQA)bj+^(R=`l5%fY0TNJ+o( zAOcEkXCU)pbuX#1ETM25o~%~O%?kqPE;;u+!;?b54(4C!#J#e4{NG&G}QZZuBX^p7jcylG0D(l&{y_ zbG1oV9xlMu>&RH#6c+mMpEtdyKr6KU5$oRbHWz2G@?^G=Q%-*b8k3_C<~zPzv!6$f ze~#w|UZP#yt37~UZ{thIkd_BzKC_&Nm9GhMt6%+Ihme<|`>p$K$AenMagTlv*YyoFmZM8a&m13=8Oyv%t>Ft(;MGIonSM0a*MuM!;NPt^f;yYSl_8UDVrAd z?G3bHi@fh&_6%Gc5h_&df1-ZvgBknKRyte-~=t zj4rSwt#e6%8?(DUZ+`a~LVX_d9*<9(&N*YxA zR9`l|g1c}0L$vkMmjU>D^Yy=ye!d>j=_h%21)fyPuG0)yuk}B^d^K-P{I$7rTQN|A z|MG`%X=7dilskKQWaUI_#xa331yM^0w^8sNE$yGQ+SAFIFwaC*c&6Ig_%;1jiSchP z*Z9fF)6vcE-VgN_%_Xuw_gZPbq-kN>cRO-Z|+Hrl_hZxaX$DG-lcMn;rf_g==;8?DjDwK zdn8E#mjFf~l8jBWRMYr8Y3o*ptyZ8_Ri!PfXM8{A?1=(s+N(lv1eyfrIH;08gtfcO zc$c{I{M~ceQ9@#9QHO`xSh#v)?(*xNj_!n7k8P;jo+K4Foh(=W;@3?-cV3%`ID2Bw zvnhaHo??;4!k+q)jO0-zj|T0WUSj^Kv(}jZP%EdjULi`gPF1R=C)JK$bs1^r^{!~j z8DiWN&m03*G%R9_OclaW5s1fk$fD9R^zO6{nxqj%v{UOZlwlEK&1c9#`{sIs^xix~ zOkTHQtfA|k%DorkC(!&-4}Dd&J^WPX6^AA_FOr5Q*XPdSkRp}yqHx2OGH)PfB0RlSOZwa_~O!yqbW>9U@8JcNW8glC@-e* z->isFWcl>X;VTbFkx#4@jJgKHSS=cl4i_WK7qi+QTA^^T?_#lAJyEW1OSshw!GSCe zA#N#^viKaa3T`Lg&n+;%BW53#0D!3SM+JDEWp|G zB1|1OXKro7&LMBD#!Gm{>!;EpZjp8ks0Hx&PDfS&&(pv2+aPhXMaxxG+7psWTQ)L{ zPqYz#-Z)SfQ!gO+_37kY zU{pso&NIs;oHgrj#b7o6jx4$SiYO;^`v4#A$)f;qQ7L0>-}Ze@Id2#Z5vD(;^e{V-JYW> zN}dXjG+T$c{f>s$r2!Z4an7KC?#^?1EjI<}X9&OODd4~)kFOfNJ5BT7X(9!* zCFOGg`a9)Htx}D^c_KB%$rlC|d)dK2C>EO<2U3E0m`O}G=yoHqRZ^*VfdqMs0rb|n zxDKhcn#O>{-^N@?p1O#Up>B2ISJ*DzcGV7ff{?>%_AShYp`erCv!>Ma-i(of@sF)3 zuNq&1Pw4R|wTkTV^(gP~HL7m#mYu0Wvd%VfNMB!mj%n$#)Y27J{us!3K#MB_mUEtu z8c=kVfX-E=X5oxUYM(1V^kmJlBrVU%XdXC@S82?&D7G?LgPa9wv2r~Pu4(4$c#(-_ zy=+3Hx2!K{T@%L{ex4|ZG%$-6kT_N8C~TLoxC~l0Gv<;@yX{*K(l_i0vxjtL-=xba zA+{3B*wr^C$>+jl?dX(u^r45gF!ki66R8SguF|p@gQC$x-J{nvTOPT44pf)abb;FI{86wfe;`S@?yJ&P*#nxI( zjEEV6U*jBN&8RiP%YbL!;un@igr1yT4`6$-GaY>UoJ;#ttsSJ0W96v0{d9UOYE$p2 z42(O}Ib&eb_oZV}r)|6KB5)<2fi8=JkP`?dK9{%oZw|%=)Xh?;-y=zIho`4mV&`aH zIe99*jI``GBeP^aR>=Eh_6_g@?mg>xlao)MzRpfy#am2BnYO7imwgdqMJBz>y#|50 zav_9{is$y9Bw_?qjq7fpFyr5dNwd#qIw~a&zYfvZu=d;G@{SJ*taunH*XxBLi5FO= zteVO-CAduztmA?xgxkg557ue5;2ygq8JoJZl}Ep=%(El3OEf1%w-sJg$`)!GaR|F` z0biEbi*6-@5%Ucd;vDlgN6vsoWROWK%XDu?tRWM8-pva_%*%6P#4qF6VDeLiHUUM= zNTDYh44m2eBM59;J&g#_kF|OF151#nnKL9Zx~;+AP!0=u4=PN!fnwTwms9=qtutcb zlFs?<5>So>pr*X7RdCg?0A-Azp(wfL`oej8hdbbBlFr&A%#5bjZW|jLTf)^xns!ccij$*jURK|agQIIM9**@(hyg8( z>1*WtHvnkl0_Xg7fp)(DB)NMa+UB_~x~^xz8gM0&K7hU~(O5h@NW9Vz2c1ibyNy~a zp$xK2B)g|8!E->JK%CQEKETZWNpaw-hT8GUhs(_%P*FcvoD2M|rc= z_~0A%&G=ZJo0SEt+=_BB9WQ0aV+1JziI_5;43nc~hvL+4e5b~+LIoq?Bj>z{P+n#f zSN%2aS-kw9bU|ZD626W{LMz~jUaVb}JpW*ID!6_-{btC4)K#y;fH^k}B5P^Ns1A7m zXWo!UQ*U03o4#5;IfD=6cc%V+T8VkU92H#VRSXP0|g6qaXgy)Wc(z=ryoE_%2QDM}zRy-Sy79skb`phzeJ?**))WaHA=NR1Jq@~5VJe* zHOv8uqh@baWbPPdM?@}EX>}_Ob*Mp2v>P^&0X)Nl?+_a~GMdCSwWu4|2;P6ZwU808 z%WJuTjDuWXfeUxNQ`nzcp0G7oyA)lowA&=r%)#kI3EbT!;VMvSzZpE)K<_F>xI@dL zDHg~IKSDk&Jo(uO#!6_|6~+e1m`vR-*CygD@7`~*txY6la|4x@ojyJ9KC5V#;YGY* z-a}All-s*}$B(bp_xo}D9t>$}!vT8Bsv$*TF_cAH_yg7#Yt7>A^2u)XV_kb-)FsrL zy9U&a%WYZFb9q2M9yF@M(x2$rcBF`P%k$WC0D~^YpbMOG2w~pK0JNsa{j=}$^Svu2 zkFXWfj%|M~aP1aEKk_S*N;Xw@N1R8sY|h$>nf zW6Ds<6iJ;4hT-KT4rwS)qb-XzjXwVMvT44};%DR^fS}8cAFd7(e?)KQ5_$AqM2H#g zWdQS1Haos1-|-X-)G2u_ReW&F09?t$lBc;`6mTbHalu#nYJkDLB$d%E%`*11W7 zIIhF(y&>A()ECkd?nH183&LrPhofcr6iH60l zl*iL(@Q{Hq+NSru&iv4vKhR#!_>gidUaJ?^7U;`hnJ+XFVekIL#iq6 zRMWUAq%G>GcFNhG%83{mo-W3x!Ul(4N|N;U|C|!qD3cvw;g(!1xvr5>FdYX^-l@j) zCRwg-qDpS!2JdfQs)JQDb?22d1@%U6@PR*~D^^h4Q*EyzhHk8|>cgQ#18r4)%5JMd zG8o4@k%*lddL_SMaBX&&%CbN+B?1=TvXOW2SA$N_-~|W%*Xd!V8T+2M5&?)(9=ZNY z2ma%j*c%)B<~9GbjpUDeWCsfOBt=% z3Px0WW(dF;~2AW{0h?~iBQG20z}LC@CAyma$rbcG9$Q{%4I9`jtCw3`kFxl z*WmgL6#3OM1DZ6_hurQP6+E*7fpt97pQr>Gu2CURy7%KGTTu@>9kZ!fDZ=d|ZY@KN zJyc8Dx{t>8_TJ`~vNV;$eamJo$JRudnI~1w&PAmdYVvvP%$L95GSkLuq zeTj%F)9sty-vcJ~Wg!z)NEw5|Sk6>c+qD72wxqJMJ(o9z$|ZUQj~8EqB8N9n7QR3jJR;kQKQ3i-tR7$7h9-jK)3>jx*?9DY4}! zKCXhhE?KY7Z_3+ein|3>y+gxm{4a?sjZlf3$F^=X$P-9 zsg#y`PEr=qMC#9{MkE(N$HShxu~9kK8n0G3a6?Em?|%a7%jPJb)w`1>cvgtgW{$%Q zE>DQ6mLS>>3C1UF1EUdVGQ@4%5FKbY0yFkAzNA!!cx?>96>LiSN)YyPaaIeyFxtum zlvpoSc2XS%O2>9GbFkAWZa5hje7P6$&6Lokq0AG9IVmnE>LQs2>sdU*pM);G8`(3| z(ir<_d!z~WWMilBS8H++3TpFxgaFR=*F!?ap$e-G2^NY?S_$G1V*yy`kMf+Mrc&h4 zsD{koj^cRuQ*blFhXR9{j#!+)rPd~;)8=|CEtB}&g8Qc>8o#Ymc z-7M#f@0xOy+G3@tYF1t1^kzkaK6Q&u;+09nD4>U3%M0#)m;MoFB%x;GYdzzntc{S4 z(8)fL=St-*?LXZ$?PiGuwsy@XR~E~^ej{n)O)t+L#l>8+uuvYpTNR3%_~eDbAoBM- zuCRD3`IETW9tCpG`Qw3@pS1uP^d^**E9en9SnBTP#wX2SepLM(e4iBEyg0acT?~8* zc!3%DCKo_K?Wv_pL_3SrWgvgXT z+t*WXuGOW!o>Mv2N7|gC(NO-UyoE9O`8d%}(+2hU=`;SJO=)en>h0(7@crWGNWDwk zupfn5b2e90a1genN?Qj>@{Z(zbdKaXeyc#~0p@1?j8O!pL5YveU&Fo1uDM&l^^IeB zb$?UXqEk8uK{ct#XW=go?6tD>QhmY+K+gM4w%|%el+PTS zekJ!9XPf~&=Xm0qJDzAVWo(DsN@XRr_M@nsbqs7xczIjiCQ?#~P_}ZNm4uTz(eA(n zHU@*4>MOXVN8oVuOuKIc&@<^Nzwf*CudAC2Qx{HEJ9;_zzA1Y8dGWXL@75|Qy_mh zuw7ML$GaVPokwrQv9>cw#5dbLzr%td$cN7QLpIMWq{O{@f!&g2e5c_(P#GL0olOjF-#HioT{ zj!fi@h3Q)@s*Vpn4J&nGG!HD|*;cGCpR6Wi%b^gcks-v;rR^o+mic{VUnD!pE9d$c&*U*+p1w%IRK>zxgeV{ zfm&b+$jENCvLxdZ+Ik8J-0JHc{v{kjpodM4P>W;M&31uL<#1daVLecz^As)sZm$dbc8WmW{vkFew{@;<>qn)}p7+R;bM<^lT=nlbyw?ft7JirgXEX;pZ<7 zX@KK#+#YCt8W_33OW?xmQb7e#m*CnBEBN@pd|<>#I0AZcTzD&ghqE+!*j%aQtPzd2 z5Z`D6*f5B&)dM?M0L26D%DA=i4>G&sBJVJP!)w7*}tF8b4_Qx4Kt8eJ3i zpjOEry&q5~5$F$FXtNAO1A%@Uxkn05+roE)3^=kPY}YOadoDO*-u;$OSDab4k)ek1 zG{87BIyQdp=V}A-l;|8&Z~1p6&{t2(RRNK!Xl86SSx+yzhg&vs&1F|&)~&P6=Lv#F zLnM$L^OEM%d)==)9xRTvuGxnG?P0beINOiw{3_L_$8#$t_#sW5?xc6u<0(ZQ*I%NO z2*8L<)lG>xDdTaq(8;{59Wk!su}wGJArD$l-FEdNaGa0fP>U0p66rFqM5D2oJ6%nR zm-8=EG$P$I(PnEG8>6{<6Y)&lCSNd^2!o3bMamzh6Aw zzDv5n>>F}y&xoOCSG_U2C3|b{rJX1nFlbPJr3kD0FyxcokMI}ilEZ*tW|Rr(Nm^m2 z<_@glyr3jo_o>gn+w}nBfQ}`Ig>$%ZfFhpOUtOuCJjgzWK;q`+=jYk;=3Z~!%-0O0 zka-k@n(3*P(z_B3vu>qQW%i=dN;XIzKxhd=2z(2;UcMw^U#Ou)lO zV~S%t7u|^vh0g$2R0)I(!e<#NBHyJp!|#aVN()Q<)8zav4(~@|e4h)ijR)(Jqve@l z2=%7eA3PBXxW)J?Ex{hvq#yjl47LF|HKTM-n1zYXA4p%Z8()}7FJGq}{rm=L+9S<` zmoTMv*hwE5Z`7ppGb_q~J9*6g8KKNeOSBqiWduful3SYS*9n@p;L8~cqZQ5J$I>aGV_iRi z(JPnmcbM``1AGzdJ5b0=LK@_znXlFF*GPK*kdB7 z!b4J6Y^MN|#;I8GxnystK}>A6v8S-0u2P=s$&LNy2l8Y#7LfdZZircEnwHaf+g!d; z)VMN4&0A)#pkSv0QzR(KS@Ll87Gs* zSW@TmRzrUmZb)@8Ln>p%d*v1vFkk1_8 zw(*Y&Y$J`5;3?6Bn6a`+C)dSk2H|RTA;pCbB+du8j@P3Pz;*@lP?v@3GF~jyPti|Y zZLMYDS3|>Imy$BKgA2jsltx62%j);vlGY2ecU(gBEEt?3@f(RRL9xSPC9nWhz{Lf- zRpZf6?zfY9CBKk}0U`i8 zJ8VGB=CToT*tZE(Hs>D=0B+I!Mr9ubsW}eNNTi&B`$KqDqt)8rD^sO&rfe*koS`xE zG`h+n|CW@?!L{Eq3-!IO33XqyPb4G4Gq3;bUR1#XW`&t?CZ;P8z98@=1X z$==2OA4u;1flknhW&IBz_t@-3DuvmY6df zL@^QU?NRa&M&CG79!{{ANMwxN#HZDKd`c}k(imWtfcPdZImdS{zlQRFlI$bz9MW{l z_6Hd4+$6iiWZWoi>Ca@L1ia)}O~mP^iLeh=*g%!|nyB2JrWXl8=GGg+i8s=c%eaJ& zh;^D-NwXOr=paqPvxLPI3uYI=wkm>hK~c#igZcUM$Wzyb+a&O8ME;0ogCS)E?uZa9 z%XBA#8p1hEtfj;n&Ff!8S^h$*Cf~;awnrOIa-=31eiH=k2F59gCQ{8m?8sLzL}v*~ z`q@AwvK;-iFV2FoNeq*;J1Bv45^3cF^vC4J`QN;;?gQ<_tEkHnFA?}k2DCqY<(2~L zgw1~#uMyKe;jRobn6ifFNXV28Tpc%w<>%?=XshL+2~uAtQDp^_5=6zFWCwXzCTOsE z%(BT4A9`#P(J7yJkVL{sW+f59{hAE*FitEa3`Ey8k0UlC={{UVOwIxKV^zr=NCZT84?ewUwLajfB!(sH z;3-7L?&yp@G&u7bP7mg+`*}g9`4QRS5T6$pq;+q~>3=d)Mi|f5>vam3ws(o`=(L3` zPllw4DMPkr%b(kMf6R?M-bRZ!>BH=KLc9BFUiDo$GP9F!JuBmpHc0Z)upXI=58#LTOCy?z# z*;w=7@+{07Z*yhMjoTd*p{)(ve7>`~sY-=y!jH29H*R%_HP47=k;=3kwA|_B##G~0 z-o~0aclhH37^2%D{n7#MC>DLi*;2bbD8s9?owXB7UDs$k3pUITPEQ2~&M@cEPyq*< z*5R5?U%qZd3Ph!Rzt-F9)Om}%?Zu0$`o~*gc_|k@ZCOg5``{U_!oj=0_GI75bl#)_ zB>xc1sO@y_Tv_g$Ip>cZhYLEi4_DFsw+I>npHu%tGt$|lGk{#Auh?@lrHCIFa4lCo z&&RAk8J@>Vjy`e9B(|KrJ@_AQGBm{*W`_@rWsykZ1T_T|uM82yklYz8xB)QI$)$In z5;tb2@pae*G(iW`0pYu_id8P50Ep;x_v!1)>_~OJI4-p?jUC>r1O&nZTVqf1>pv)m z4hfe%lB_}-((XLg0vtWt$70EpdqC?r$a7Pc`1$$$Bezoe1WKaA!v@x9Rq9QUgGzER`xzTp?I)vfuI~+yXCP zt&N$URal1v%`hH+9_{!7rJFP&P>7P^WB!;PSs17q<{d{i&}Fcw1fI@(K_xFbaE&o| z138YzJK&atwJ~lvf?KyPw!uUY$+sVgkjB-d$Z=8@r8stip6B7ac(v#3K0rgW`%gv( z(Dz8Gj*8~H2w~N-r}*3@8t2O-Ww?rxR;P=jZt=oGk14+E%R%(YZCxgvYaEw=HlNz5 z+O*n^GW87OcO&W2H*dL3eC+L%tb5pI7j&?)tEd`bcU{o)s2;78)JP4t&2BRkvf63o zZnPCgcs|YaWbg`?>gP{54eC&rLVM?UNQjBwb{6JZd*v7-6SmGYq<;% zMhF*h&JeG59^Nc=$BeQz5wSWo9aNeC2etpG^t@Q2FVg)DJp(< zjUON{B9-{j6NuW2*47*pBqHlfvHK|?g@_h$O8Iu5B4~|iJe+-&*P!5PK#mH2YoNV6 zd#Z6{qQyPAgb?D`KhND|7HD;AD`V~~up z5UKu=?3LCeX>C#fiz|sV#5tRFd+7^+;%ETGYnO$ifraBvY&1LVkA%8>al<_6kIM!2;pf0LX_z85sm}%{Ps{p`lS2 z4yFaSN0F)^Epn50rz%C!AKnUV`JKg>;xlF*YoeJ8c@lSWOSI4TF?%=FhIuVPf-s#N z#TZQ8L^8OBmFzb{TvUS7jd&Syg=xz0 z&$0Hx&FP8Mp$6-ik;R&4J(+&4Yx;gC>xZp#z`Ei|H@;M0&ssG$GmwyuG7(@{z63CQdb}gU2{E)0c48XM&nNMXbrf`QQIK2D4oC4;0{f%ub}0 zDD#~_Wpy+v)f)PY<4slODbxf>gjj;;R6?%lB=in)j~`o*tffU5+hl*x!FZs-CDcHMAjz@wlUksu+2b{E)gr$>T`F@oRfr~ z5`gi<=P-}~`x2>!SZ>OF3COCbls z2m1*JFvPQcwg%M{EvtyEmDY6>b1Y+c;(mxIy@pGdS~CgZIG?J5Qg~cJX+j{Xt6{j< zW^k6hxU-u;JYN10_X32`2VGjHW6zTDeZ@Q(N=8JJfX4YC?}(=_QG76y6-~*W8CEeM zD`~>pNs7%hPf-_RF*QUMRoF)3OJWm{b2KW7+Z13@(!HNk% zTa)1D@jb2*k%3B)$119+9IL+$W^!`z6qy>-X#Ud_O){2Lp$wVHy1_ zD&m3KHX27CEZrj38q@mEL~N$K`xDc*fFKlo`j6`8`F|DpK1Au0O$4#=-|Tjn^b!d<^ln{ba)EhKsO1^k}wmAY{8 z$i4CSGV=EI{8c)P-GC%Jk<^4cB9i7MnJ7GXkh)29I|0h0B2SvAk}ndK%fkzG0#*Cv z0Rc=0)})(JQ+QchoZ37oh#iU&7iWg9?=RT_6UFm%5Uw}@@s=JLqbxlVuEYdQoMv3M zZ>$7BjN?H`mcYwr^rCzsc=3mQ_$^PlWlT$I{E5S45JR)}&77|TG6ii4yS_Sk1e=LNXRT&H!@g00cj%YYbfUouTE=&W3$gGk~=H8qW)&zZ&1K z?p_UN4UIHsEgHFf;jb)sx-kCm6fil+Nv{)6^6M;8OFa z37FTQ_x%3qO%TTz2=#7lA*%~g^{v>N9&lA;(n(QC{UK)oPRg5(JPxgAWnU~_f{pmL zv8IunRw}C1h{mpVD~S~X4fOyrCEUZ>!c>%#@SIj+f^qsXBMe4f_`hM~+?Y^phgp~O z$+QIj-s-<0Z^@|O5M)w#7DEQw>xwE4{;Fhz3x_TYDi8I6FT}DMKM5uIBaFOdK{C*KR_sb`Qb|5%E zat^SVBjaarwj8vGyfBrcsH!k(N!L;Hfd^3DKy&e}N}Eyr@+VPI4X!(kFYg=?=l(+# zl)$t%a6(@qos@>;&_n?bJPrX63G@#1^wr$Dw7I!X)fC$=5tMy*Di>rIoCl=2Y%0je za=;P+(7)LT&e#SS(5y_SP3Pj+Upb}U+Fi|OCuLsQx@Rd-;m4Ikx_*yA_w3gA;#Xck z*OIK3Sd7IueUMgo^AB>W^hKP0dc{X32)a|~Fu7$=SM+T?O~&7Mg-H?Es+f9n8hOMlqFt3;+MoiWOs1$oxlTrUC*0fcIb5F!FRU)pvGrva~bT|IerY@5dDX z&03)<({b4hD7#On?UP_tkX8#cRT7&mved4NYF1_0il&?nkRoGcucXL)mHUhM*Q_8y zOXCD+uP>bAeB<$)G4(B+$|P@y!}gI1MY?%;A}6g!FvXLv3tAc9mR&ktm!-%@G98%)b2mjHiF09ubr0-+`^Q75Ec zBi7*Elrje4oQES;EMh4ika%lL(2MindEKBI;nEjhlEN<5msdO~i?M|)_Z?CEF|y89 z;P*0O;mcOmmE!0)_n_tum(Op_M)AC_#j4RrZZZbsc!8&I{)AXAXWAV~ONTBdj1+_A z`3xcJKOAa|?sjsJ-x={uW%uz)?$Bm(Buiy(mZo_$PEta157g&3L(6MXYtN4=_niBd ze`Tw{hS_-rOa25&+94e~ho!OcV^s=;Zbk0|+CMjo%nv40^1S2kY{abM>-(CE(cf*) zA0&@|BJQ`oC!K$xSZKb78_WViDNZUuGhVm5_K@~XS^k3k*Hf+xp;rLpe=-$yM1X&x z^uOGB|98~?f0Zi#lTcb!-?!doK=AvkPq>4g4T_vSRDhr{xfU+SzR^{#7-EyJUEY)` z^(gdr{T}5v<+@d}o30igNz%RR^NxFW(oJ_=Q_28KiAc(6j>MCnc#$WEdk$Rf(wb0S zVy2`I3~n^{xX6!8RFGnpN>{QVZ6pz*h1Sy0s({QwR=f0y4!-m*kRFjKY0SyleA>k#E1?WGJ$=4pS27wL zS5H+a65p3?Mm>PYJL*-(J|0gG2F;=fvQC*YNi9P_0xNFb7|pFz1rXS^g0~$uyj#1#1s&WIoxGq(MSuC+rgfLYV^3BvRK0Gy~!%&?E&Y$?ha!A*xUj6B0`t zEU_i86+r|SZn~rzZ8^sL7~PmjIC6*UyXa@VYi68oV|zpb7L0(jugF>a@&6) z`q>@oeRB>!5qZ!>JIJHzc65C*kB4qSHpY0-?lc*^$659f?^OEaIsLCG>+NIC12>1q0>>z5|5yhrwVU->hnb;b{-4Lm9o(&`*gyaPSYQAExc@cLto~ysF?4b=^!%S> zTC9q6d^W>>%#a#60zwpolXyg;R^8oIT9v&myrPM$vSg5nj4N!}XZ!9m)B=#S>|^`Z zVS{<+jpnfi4+G4j{Q+kWOeYUEe8xApD6ct}tiJ$Td5T6}$Mgni?pmL5*0J$tat-rnFdGG^S76k4U-)up^O6 zx`%5O{$*R^RV#vAWtWGc)T|`=j$SdPVjw?ZEa?;k^P+XlyjaK#W3Nathxe^v+wMjm zvUwA|yupJ8x=}cPCSbSlJM9(lu8k_?>J{hM<8A-(D7xo}i4mRs(Q1SGB$gk_wrLWz z{i6%?o2s$m1%?+aiz3eAp~cU1HiD1&;Z`1)+H$5AYU~zRcUHwrUt~`-i)>tkK2huI zQ5%fGduE~>&tx8LiCH!;Q@LCvqQ>~%OfG;p8XN?#)v6;qMGEyX5YOU zy=#!wpDaCRxiMe{T)1)6UuIU0Kg1AIixn6Lx z`n+q{8ur5qKIt*IeuB4WQJu8<4(WNBme`h7B(8by$})oLf5Pm|x!c!hIG0pw30%=% z!d~SgP>@JN)P2KVyPQ=&nt%Qe6OI!#7}hm70DuV@006;%6LH3dHa7pUk^WDD?pEuc z5r+-oKSrD`eG?hvuckTyc3@|mcG-0nu^U_5t$+g3EkmM6EK$~%u%ExUmFUSyM=Y9M zcGn>VZ4<{2htA{2xVv`k+nG>OB(MnaizJ8_iP2%-g8j?yf&YM`(c~8a5T0_vShkCV zn+q8YEz6;S)`P?$4_wKa1sKF=q&gBQqci7#qlC#S2@I}?nx#B5B%_@oA`KcwAA9#V1Co-5P-?Plc%Jn8Jp=V+ z{PGUoP2P7zDK(XI^30hI)ERriQpo6(j;Tob;*GDQ$By;Nm){{2KmAH%3V0^U8=phtZ;SFUHe~=-zo<8A}>?XrEMIRuyC9Si|1w4|KsHNh0Y4P#E%x z)a&<)g_B>bP9pL0a{97q7o9?WJY1YgxA^c0KP@L;+B6D?%fDPZwP^%B+Z3Ln^UIbu zZx*Vh6gYPJ;%2xApLWE?>(Sj@@{#*5&uRMfj|oTX2@>v%P;^?l4$rst9PT&wEuL%b zU+k@&hm*)IDsvEDkU(Uh9aTdk(5IC_KzQ^Ws+da+w?=0~)9JWy`i&ZEK~#m;B945c zRG#Tt&n0Z(VpAc2e_bc&^EP}yOB9a}9~S={FB&MkCjFG%rhpt8W_=zLJKT(oN1v@&0LyA~B@{s3^< zrm9XIv{WhXxr)2dBEd$EI|TY4c{H4FZ7^jHagip(zpu)2RC6TuWiYiD{nXSm(c+1u0HDlae`g5 zmbftp^yLU0lR7qddk+UOvw#(zu~?%xYhGCnDa=E0^t5m0N-3e9QJtAmsuuS*=+Ppw_HeNGyv_`rw#4T2MT7Zzc!jE4 z3LA|&t@2;xYMRZdS|;jP(n?&7T?KRuFF?;7X~(VT0LQ9GoL0#dM(V#(1)+QSTiIm^ z!bs&L+{4!Ho=ctBenF!fE`&-XO99Sl`E+i~I+uXP8#JATe(S~WeshGHu>mr*X^95W+1LYS%s z^JspxiXM}^Jc34ecE%#6P*$f26Z(#tbRH=rL8qvP=7q1V_{Ys4 z>Jyv3#QA07)Y1YpHWb`Eq|u-ib{H%G7emv|0T8+r#Cwa2#mm3 z%9*y($@~*!KZHo%{=swgV&$kEp4U~6(fNdKj~rcRzRaZ#pO|+X6-s#I*g^XmzJ*8` zPvQ0O)cBXbk)5lXbNNr(T+a4khK^E-SrCz@9EnzQM#znOsy>L0c%ACRw{-NtVJ|bs zPJ6?mE}EI7eoypsX*RkL(Ftu^Wl3WD18#IgdP)y45{5_ql#-I3PU`B8!%CEjITd~+ zl5c!iE3DH*v#P{xa&@06mS#KWs>8lvWnRCBf!b3oBy??UM?F|Ia>-9+?N^3S7PgEh zOr`o+b=7lNMHge0wML7o&3y?6?#e%&Zcs0Kz}TN~u3N!GFmo!^Qw0E5mK&tTS}03o zW#Pzv?d(EJ4{wj#<`qu0D$P9S3z-$E6FJHfi_+e5!r+8=dbxjp_Pyhmz1+3z$xj6< z;X()whCCoMY;`st=&`iA+1f6S(X~<^(uzs|XMMxjtBchH;lc>bz5&ZUR#A&E0*$}X zUR@@jZp;*9nf@c7%?_}_fXCp6rG#WINxpGC?05Fcvb<~3$QDg+Wb)Ci?39C{#!pH^ zN;y0J>dmR%eS_x@eI)~N|1B~p!ZifNQ)Gz-NXm&9YAY!`KfXrv_3r6?UV<6Jo?E?* zj_M^ev}^u$gb$g{X?$z-C0MHMdtKp=?5|DyO?%tTdT=vZI`!-ANZ)nmH>}h<^xjB^ zD5r%UfgniEN+cbf2iaV1z5UcAe9UCE7^*9uC=Av&xsie~K21r$@ROQn$LX3>yt6@3 z-?luO+0kypC7~?H%j)mf+sxYR`1%Z&UCEbzN~&Jcemk_U_YeJ^YphmM5-bcXC*;`2 z!0q=F%+6JeE!ZbV{M+@%J?4UM2Zb`f?3jUR%&NVuZVd5|QA~Gd)0D9)HOw3Gf2|K6 z=L&So|J8@7|LQ~R|7H_mV(Iih+COMW+aLeanP00TOpOZtGYTZ}p-m7%G+kJ419f{- z1O*tBB#zS1R*{&F|Mva6jU^!$Jh1VmfNwMB_a_i`hn|j@G2Z;?yGBK15lGuMYX^?)3!v)k z7*~G5``CDYsj#;88*eAllO}KMNRuXS{K#F8N6-g2$WcjXGyj*5?UxYE-^c95c&4Yv zybtIFJGF{@YlA<`wqgMdKeb@?N(rLr+S1=nrGyH?n5!8BpYm}qlS58xGyWal{jWj= zZ&$3ZhB~sQ{-qQI)7S2$6GYNcoIwJ1okeSOsq|Qok#q$X`G~l*ZBCFJeq$t)^%8S* zUWx*tGHOQ&kn{B(jko!2yUMEa3*y;dwmuy*Od_OOJaS!M%eLxJxV8I+m%f>%36W+n z-S|DySz|DG%`iH;1rbIWWjDVmb!PIeMUI$g`J&)r?jzO2R!rs9A6P8*TwwB0)7f^U zddXtt*xLq4GYZB^7iu=M-`nJydKII;b||L>tR#HE`9(?DB$zD#DTVSdsHD#K zF0s4N{c^Fm9N|8V$F(H6+;-4sMBFx8@KBWq7A>NUkr)cv(pVZcLA2pEok?s5;2l-8 z{aoUv&n2)_f28Eo=$odRO%@Vmn{^jsZ6~Z%ma2dNV}(HNQ-0#O7a01}`39-?AXU3z z__248hgm%}gsEGdr013aYcH=f;HHUV*3%|^LQ=fY?eqo6f3;lZ#$R@lj-WLM6_s6X zN~w>}>4VS@WixoY7UOi%?Ml&Jtszx9-ui z7r(8|*Q$0IyJC;Dm|o7_Lcp9lj3Jh|^HJBbX;)0;HIjc{#8C>3Fg}I?`yXX2Z^0$f zJ&MR6Qz>1z#zDq$MT7;%D{(_2Me^tAq^Wj8ss7SH{v`U*A)b~>GCe@+d$nUGq=LYKRdQ<6EwLPAo{TEI3 zps}w$43oSLGrAU4 zh@B5d@f6A#GcE$rK`L&jvHVLLKhCC?z}b!oqy)mlgNwi2$%X$OeK}277Q!HNy$sas zG!h(n9W&~d(7Bnl$f&*q+Lj<{^wmX2ufA>z;)pMJeMLM{>(5(wVUOmIA~SeC*Y(h2 zf?}cNA!Wt%g1QQOLFZuv-Ap(9srbJ-`^tbgwrp!0g1fuBJHg#8xVyVsa0mo?WkIt{p?H<*WL84J&F{=Giccl~STliaLszty99o15(G5PE z++Azu9V0LHESJ^Uo1dffoiz(jw1b}d3_Zk$2G1FKu;kFRQ5LzB1?dQ70W;UA^S30q$F0yre2<)Q`qI!S8Ng-4a2DwvXKgj^ z`%IXMXwcgTV|0#;I&ZNb*7mhhWK&kA!`Bo>3avWrwT53Asmo)*j?vK7B9fJu${xIF zkX@v3x(~H|bog*TUu<2`d`8qCg=|PhNJwZKw7=J;Sl^qxCqn0o$I9mQ>B#r!aTQ0& zNT)spUQxrknw_odCJY>%ca@&fv;J1e{DfDx%C1dY#VnfXs~_%uMVz}!0#D?+DfZHs zZ%a~|R*a+a9S3F;K37$6v}|S|M?C?wl2R_h4xpS!uxi5pql|D_%?OgmOK3AivJf+W zWQQ-|3gH%-oDikAw_i@K8*k7!y#S{0_B-XqAQ!+a7Ku#rn>Rjk;inO< z12Gr*34LXOXHEl%hf}rMw5lVE#7yF3`B}M<^Y${6Vs$4SJzVCOT+1ZmmEtDb$qmnHU zx@k1EU>LR8B+x+^bTNsd@o_ZR`G)z(f-FuLD8{A}n2b9f@`0$JZ(GW>!_FMhgSW$O z^2)W3i(RFqVb{CbS4+1_x+);G!PoEraGvi|cLVAx=co*4-%vxM+*AniKXG1fz7pkm z)PVRfqW+*(Ys~EYqlsm5ypEC49&Kjf{6qSjce_6Eam77k+wrktj0gM z9|F5u6?VtK!Dwb?Nio*=Y`VNHmt$ZMmt2EkATdze4fAwu>2v6Jvef|nYq1f84)A7? ze?p^Ftj$XKH$#h;6jgsZduC%<7uA1UC#}BPlLV$e*Ade5oAVOx?NIEG$_w3M7k1N4 zLl8vjN}j&0;edl5FOILECElB66W;kTSPBLOuDiJ8h3#3F`q=H9Grauv(we73jH&{{^Nf1z z2$I`vwP^LnFm=cu!>0Kk(Ao)06p&p_XZgl5Ft9h1t zLokf^l)BEt3c#uzWvtH1$x6T1E5nz)jj~h^b<&gqJI4?Q1;QRIptkl3BlggmBkyZc zgdNc?4rfu(Y2lq0+5oBKz};a~m#V?Km;< zo69($)5|7!Io|LpcB)g4Rv<`S#I%#W74zQGxMnP+gK&r~p(b0fj_I%{GKwIt{EHxW zWNGFFWTTvGcazQ2hdwWpZxD0FTB(Omx5{#7SBI>Y6$xi~ep z7RQ`I^0#_kQJeS>V$30Af+BLGq7)^PhNLV~yl4%gEP8{R3e=aogAicGg>ky-qm;q8 z8HJx;LBSoFzO)uw(G2-yJI_)?yCGM~EUk5NSkpa4q6-fd)>f0q;Kq~D-OkKnt*;U& zB^IQnsrv*+2F~^Q9h#~(lP)K<>Y1q1hEpjd@IO}w43jA3p~2ufVh0YOVrOj@&fcbS z_N|vAUMe{aqNR`l_>^3=3h+&R61ahZ%dM{dmRCE_4l%HHyQ#OpPNhKjO-9BHP!Csl{MNXF7yBbaTt^{24S#_+9Pf zacFB=WF~8BEH26xTbpJhvOYBwYDL>T?Yb9j!nrKgWJha@3f@)0fF5zFc~+t&&t=Sim1Y zG+A-YE&iA^DtluButnd&d11TPZ?fEQkrfu~VDi8u%JaCudsTQ$UOQ%Th#w{(`|S>> zYOwcopI!zuYzFx-+a8cmOS#*YHar=y3kPM@3@B3nt@%Rmj7ZmZ>)s9K+m#-dzey$& z+@o9kXO~lwa+GG>$HS}9wlT{cKz(QYwI`n=e0GvM$XB=TByW2vJ{fcvM&cA!2fJ(O zX`{(+6~3s1GCFW!xh*zZlvZ(Swql^=TXj0m>vmZ8p0epPJu~atuo>j)V^z6eRx^1? zJ~owBNV=BaXm(r9UnJ!0Aez_uH6G?0waj4)<2o0+#7Q%D%O#6!Nx;pVix^ntqGKsg z;`%zn8RgjqW}HJ!_KDIcUOby2p$0=ONF@=GAVRShB92dqB{(wFjIe1W#n{Cnipp^3 zC{nJ-GqI<5*5UiNa`j$ULesYiL_HJ<5o;U)gp^qWwOs!AL;->QnL$!Ioc^*7*NKwc z-R-Agc_L|nv0Ndhka*VM_u&2q*i-P=?>b4ipJP`koyUs3dKplFi}3eWV>OE{()ac? ze$Y>L;VtZ|ov2NDj9ahi^~MFF-BYdqL_s(tp;{efiRh+i*I$4RkPkz{qQLDvY|ZoI z_f|=g5=rw0y|kp$q_)xN9Khp=L3g1hW|rbUN!uoMh5;)msNY#l=*oh)HQ&@G@O&XQ zjvGtaD90$WYv@&o2uXAd60Kc2uwB&%A+%g>%WOGD2Rt8Y%Nxd^+vQ@USsFBpHu38*XT$c--T=Ot})?8x1<+N}*M|%*L4hG2mS@@MB zGsQDaZrQp}D^}=Rz$8n5Zd)Iok=mhU2o_$%WZ$LLk3hweFt(0yO5P`+@p0vfQP|Yo zE{{m{GiXCZA>Ty^zBDP(2V`=xiPx;9_(`tWFfys2hJS9VA*9eH|r%HWR1@`bbRm*YmDzS zfXHW}!`%#r*5~#$UqOM~VnE`PfO#~aa@1}*B-{WVxYge^@Igj~sO1U#VQ?ji0A9=^TA!pLA+3WoH9iFI2Gkt8f8Tp>&a^ig>~o1!lxywG;W$QO9_Bh zz22ge`6EObNdnW-$95nMlp1F`S=~wm8gNs@u}!8h{_QU++W-(hcl3Syo8dsw9}WCD zJGX(-Rt)iur(cAzQNtcUT#JJrSQIVhi!r25&~si2C@;P9S@fl08fMtkEz5@pk*)hn zEwNWJ)vQ-wI;x990R5Op&#u!|=yY5-NRF=WN4+|{FRdrth;J@%J_>kGG1f1-S;<4j zH7Hk&&_yn-el9M~$Ja#Oo8mI_TG+sSydipE1=o`t@RvYIfQX$lS?QfB)hDC4TMh*_ zzfmVAh5!*UV+fjrzqUo#ERXnkQX|Z;p4onPU?txawYIgLNM{vcD6))*To@BoQL-W{ z52@jiO#+I|zAWDdTUfbj+NQWVL*|)>e#tkDlu8+0oP?yr9eyKa+;-*=rJSRvPaQuQ zVo<-QUK=l9Z(qX|1q)9D#Kx<$p zn~BKg5chs3F|zbWt*&SdO;V0qWM4ZqLXa@)$9+L{MR2Z4wWxS9e45X*Jho6v3+^mJ zjCB{wK=&F=Ms?6-R|h3*cH#w7%cE&D27T)b=r{3Z6F5}`k9`IX&ala@T1H5I<}WH1 zv#Vc7D1iM4^>`Mxm_6cn(DQ&P#YgJt8xcWKaMGsWIjGE?$bS^|oqrySqJha$w|~n$ z+>r~3&&dW+wN$}wxJV$phl4-iT=w8zj38fwRaaC5VE9V#Ca#%EBYBwn*TARFVZ}8E zr5GbEiw%&B+z!{Ydk0OH%VGn?K&(DRGS!7&#D|tODb=z|KYVHie2Pu*hovQQ18jB2 z_!MT#OMcb_w<+X%zHZ|=H*mVleSgrsp>C!9wzEi_JVK=2$U2S}vJ&aEH+pm_$C_Aq z9leK&d4?s<=9oc;IYpv+(V$44?N$#a;qw5#zD43b72^EIvxSe1U-u-4lZOYMx3=~5 zql9uNj&##d>Ak>+1pUq%kZdp_Y)n_T3Q}%ql~_jPwD?s-00tMxkRoZ?f6& zAhcNtqHoKxt0Jt1_xT|yB|LsO3j0=TFO~3{WS%4rZM4~i(f39uTFjy215oc?-xcuA zw$r?5@_J<1G)i3sp|(vvC77wWMP581m{6}@Y8yXoLzxDHOFD9m;Dl9vF%-4fth1w4 zFihk%SgAf%^Qb39^V`m{1x?G*Vrp@07gK?u7ZdHAMIRz~A*jS8((~RawIuEvcx8zh zV2BZ}YpRV#wv{;NZkO@BUh&DUv5*cNV@9Fwb8H1M-Chqr+(;0)K$|UlFqvKN?(LA4 zZ=wK|_mRbl)6_L41c2dZVg~J@Ww4$*)deHc0|5t&1!SR3T8c>R%05F`c}RwZiQrX32TL0mqEHES#5GJC>R8n12GodVcv>%;Z-P6{X?qZxt3V zJLt05ch5SRzu`SE+W2L0e`AXUtwbOFFqm0yo76MX`^vH#H<^W!dabxyr0C0}nQzo1 zo1k8)Hw(<*7>fs+;a(C z)+_YTL1%7|gl*))x>>@$K%n;ML8xt(t;e+VtlQ-AH2g8KO|m-T&I(-84cHC+bJEkY zp-)m0+UQCEdtZl@g4n_3O^vWW<%=naXhMpD<);IUm&e2tt&xl5P7x*@{iDxuBVtpg`x?`o!7*QkEvgi6@h}VZznEpPN^h!dNZU3;$`ioiA+AZ zy-T@;W(BN^>aYpo*GS^e*xC6!{iE8hpOjm?V7ll70sXQJ0g|$0nTafq+o8pj`lVWU zW*h3K55?D1lRuwv$#Ou#51@X`0+=&ZZC7nia2{6TH-1_g!$?_aD-fKJrGt!Y_z=l| z-Zgp2&-)<;M=4mP7+Y3eA}u(8Cozf1X}zGDOD<4PPSwwbfgpQW)F6(7ordIO!=;?U zI>=`~_X;PbK!8v6{Go^Tb#Q>Q$FmrZHBV&k%qM0Yx|FPN5{wlKZDD)U(X{Xf&}O95 zg1G1a5;EQ^UeY5McE)CD4Q6Xty!6)v{Sdb~CFEs_6y99fz3-Xc@)K%QA%rHA6fyFIpOu;WR%sbp;1ltA^=ZM{w@wK4MlCBJ3 zVvlXwHh_|jz+t6v3@$R4Zx`ko8c5A?5UqGYms}oo#85(Hk}q33U=n8-Hio#{)Pj}v zB5P4QsN|2`#m(_$8RTtU7;!XTX_<6i^Q^%xYRclS-KHd`W_e&6bzCY)bxX|8PQ7{y z6{x<;_SP}5&_+@tjzgW!&@l(|EEH@F1I#pP)`Mv&O?MZ;^l8MB%X2`@eOI5sP&AzJ z@F|BE#3&jlJVaX6LLMfkDdft)3`rA)U6S43-yjeytm{!`zLvj=!np`Rz8WC0y@P;j z-~H&TnIGHk1^+~4z?{h>aC6~_H?rKh(1?ZctUaEy=pjt8CjLTMIDk7u)~$VpJ~$p$ zy^NnI)Qc3g@L>O8wl$aCsnUheM+ixH8D!f|$5Gul9KFwMCcjk^u z0sZG9>Mr7eT{OO0)A;0+t%#)W(7nLX;i-B)q8l3F>?%50IdT)Vv9Ps_Lg@yIq6}Q4i8q8cD7^6h|~ht)s&~Q*2*Bt z9ehGQohkMcOQ-x2p22%jer?{MIK>5}eqY=V3vp<7i@G(_w*;iG_$A~pH}_~tnOrM* zA~MsLPCZ_}tgKF(!(iP7Y|yQ%!+3d`#6Oj*`iIq&<)L}V-VEyE7Yyli1T5)(*0)XN zgG0ktn3*uSi1uX1Pwb58orI6|iiKT*HI)bMyYntGz9WiOA9-p=f7H%GX$K`X*UkP& z)8hr z!7Dsg$agW5(EOPjACqoJMJOiMU%rDo^B9`O)E^9{%zmG+gfW03D}a8d%(kI>fq{Li zFi_Hs@hF=vAj{(0IkKt&Cumj5M32+aOt7OyB7rUWY(}t*lRhiClhuU=+GmTg%FhLE z*4#6nXmAW;-I7#ZEal1%Yaf}LF;CJ(nY_3!bVqud9GB)xCi7Lti@99~z5#A<$F4@WOH%ysRQAYd z(?rjSwd`W6FEdb0j%eO}-NREmm3tRv82VVop~o_ zXf`YQ2AN^5 zxlj^+oKw8DmoILUzT-BLg7=^ zX7xu4skHvKgEH9i>-dtmM(YWywAe->=icMRcb7efZT1LSW(Z)8SeRBnP!BG{CDV_g zId?fM7R7~@1a)eHpf_H~u`d!MDR%AV(xtu*%Ci#YJ%VdYHFnzmHS58{L( zP%!yjn}tP6pqk`;@d=WAK%x4e*bC)DyA9bBkbE#X)?Mf7MTYcF+RSqEFpoH2 zU5EPxQR(*CBt`2qrfFYLc%wFJ5WgG?TI>W`p7N#IfRHWirY~1Q`A$C8>A{b>#N&S| z)!xUSF)3L*wmTle|Fp)yel;X{%jtZAuHZpxLbgi!$$cmh=XQzhd$-U1q8TeVcS@^iqMgA8Q-5(B#|&lpRYd55L$y!`OO)4{-N0mDjWmi#l3bfMQ_ zZv`k!o3!opx}`}iuMec2yUa;wD4gE1i;J z=;B>&^lX+BKX$J$m-u&CTH|~wvW10{305T9?CwB}U}mF%jA9FlP3+2d1*l6rOoR`O z$EuWWS*TMqOqK_&OzpvjxKi(9RdvQ9QUtQC?yFOC)aV7T5OuT((;_3G0LIl1FgFmzlChwjQ-)E-6c(|7ZXG45lo()!X@V z1NHakkAu-~<%8dq*T1Iny(z0I(T~3hM9olnu|w32#V5RVq|Ur;vk=QJj{H{Hx0rsq z!(k!Cv@(sHK1!E8W$$P|a|m+MT1Yk%kYpGH;>ix$p99G$Js-ByHZWsTSnt%Bbr!_f z^=Rs2K}kX%*d2^d;DD&6v?>;d5VI$#53CVOU?hzuNnkcM25O>bq@oGw@3E7~n2*sD zOqV#(Dc}#)K(;tF4SrQxla(Jygl?Bk4r1{>05yHcnc)6U~+Z4J*^oMV$w-il~n z%s_*VRdz9kmB9_Pk|TZn4k|UypV}QHd+SDgU9O8WP>VPb8FA!xs^b)>Q`ZU|v|=Wz zmbGcJ+l@N?om(xL@g}lBilsuWCs6dX>FTR-mcdY_L&IPv@xTA5|IBX|OMUNSd=+Tf zD1fPp?QzZU8|Ro0Xsa9;#KKxJfxNgm zPy{J)){=E+@?5yjFKD^XbhwS)73&|=!UDee2&nU^N@4&9Cp6c5sC?AYXR~eGBhl$} zvc>7#gLy?R@|W2GKA4Uhp$RAFdD-kBKJlGQ#87l)WhCl10ch63KS_Auf-*GRdz}K} zxDO)A*i;=SoR8+&bmVJK*7+T3UF?)+($2XPie`uDhoGcy998K+w^bCJPXgZgm^`;I)4m!J` zV%AxLm%0MP3hWeNgEaE3Hg3fbB)_%Gm1+neW*h5Rk47PrX(5IhQQi9rFS*ApdPa0W zERP(G1-!~Hqa=LBKl5^gi=%l{3E?(k!KU&fWxUDGAxf@7EWoe3D27pawt}@29VK76 zf2svBdDvHcv(whhBW|KBrXCoTP)x3c19epRW)iY+3}C6i4%QN@TqRdHrC7mXp&v_| zo6p3#3y;Ht4^Fb(+~Cyqm=)TI&`Y}&Y2rzaD)d9&0B7CEMo6MM(6&00)h94i7_|>s zNCIM{&Gn9<%dM5MpRuut7s7dhfes+*x>POsQMp@a8(`06AWsK`>U z3o88dQ>)~HASp}u$Z(wto^r#BG*`z`X>muugem2t*DA~W8OO{Gvg{S*yM@FkEYpa& zLNGxtn85sFp#^-e;6Y#zA#`X#HsH!(Wz~gi_8x=)Q+g+$IKJ)AN4z}DXEeKvmpKnP z#w%Pr53x}YS)jMt}?h|KNuClhwc6_&m z!s$d^MQe-9E-pY=>VieVvm{9WyzkOfzU`BGXZ{f8N*?$;9hW1pZ5FnTn^_CnzX3w0L!JyK$I8ssHl{Gi_>BXQ zN+UoWa#8YfjSEJap}8A)pC5_Uqe@tJHo6S%fRVeDKPizq2czW?P1oXQPZFs0FIFJV zcau)!$BG1#QrYbpTY!DGr^aWKWr^7r|Eh%G8T+^O5g}% z7G(>%o7_44Lak8|vt!(@PK#Pr80_<%3$5|fi3Q&Yb>2DZHOWTdz8n^u8w}`K*n?b} z#VgHTQ~|tgMZjztxc|ben!B3Hy@tl7<-ASug^v!`9V>l{d@`E5zGH-!h`F0po-@OT zbSx*b`YdNi3>jm8erav=2v+Yq&o?*x14E;`@R70Nkyp_7R~TZL>X^eft$gI$Q6m10 zBgFW>93pACR_h*m=+zsP05Xq@*)a!W8)lwu6w+ONn2s{A z7otz1Y+d83XtxMrhppv@4=3^s&6cVcAaFdD?o#g~88KH}Dye~Q(pXAt4Eh1#y)Fd; zD*Pd*ruGp^3GXv=S44_(W5{m}PojV7gi)g{+A;rT9)B;u zts;*-D3@u1}fR+YCgbn?e+4M|AYq!_vrPE2dd*8cKjB_zd0IjVO*9$+# zokA)eE{gEmB^k8}-1#CiRepw9kr$=K+d2fdw;ehPOI@X}5gGk>8E&U+OW2 zlehWLhCA=B^UY3&nvR}75P#qid0^?5b>QDt`0$A(qrj;tX^L8St`U%BtkIOA@{c(C zx4DBORnOvEC-a*p!<(Vwf6b4xk-dYNjrFe~_53kb*~t|O01)I4hXs&7Y#elS{`{Vf z4n8+N89glrEdv>(oFpI+F7Uqz-oE&+7bpPwpP%6YPWb=$_{Y`e%|!wf0GNaz`acQ& z6Do-_AmYt;@@<@d7*GBa>TNIpdA_{=1@+EL;(tQ@;}1Xp=&@6X%@1Qnwu1ryz+(O< zYSMq9{vignwXy%FYsBxNUOYy)g1^Z_^xvKlVc(54>#Y>}kH2kTV`Z!Nf5BAznHw`=1>$$3_Pz7TufzQM9`OGocHcXm{EB_{9{X?oPyVL~`Ssgf_pcOaPJgBNSNij>4F306 z-fO~sWzqEdE6Y2D*zXa%7oPk|pz8Zqg1^X6evjn+y5FxPJV}2gdA|zydkpVagnnfR zO8G0p`}LvUV|ZVp`;|c@`>zb|3U$MY z^Y4v(LYn$-(z~8?)}OHT>e+4Uz5N83IE@B nhxh5yulPh2f8zg|JpFf|pl>+<0073@o9OM%?f3Rm0D%7o(m9di literal 0 HcmV?d00001 diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 0f1f005ce3edf..ec3ad9933cf60 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -19,7 +19,6 @@ import os import sys -from threading import RLock, Timer from py4j.java_gateway import java_import, JavaObject @@ -33,63 +32,6 @@ __all__ = ["StreamingContext"] -class Py4jCallbackConnectionCleaner(object): - - """ - A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617. - It will scan all callback connections every 30 seconds and close the dead connections. - """ - - def __init__(self, gateway): - self._gateway = gateway - self._stopped = False - self._timer = None - self._lock = RLock() - - def start(self): - if self._stopped: - return - - def clean_closed_connections(): - from py4j.java_gateway import quiet_close, quiet_shutdown - - callback_server = self._gateway._callback_server - if callback_server: - with callback_server.lock: - try: - closed_connections = [] - for connection in callback_server.connections: - if not connection.isAlive(): - quiet_close(connection.input) - quiet_shutdown(connection.socket) - quiet_close(connection.socket) - closed_connections.append(connection) - - for closed_connection in closed_connections: - callback_server.connections.remove(closed_connection) - except Exception: - import traceback - traceback.print_exc() - - self._start_timer(clean_closed_connections) - - self._start_timer(clean_closed_connections) - - def _start_timer(self, f): - with self._lock: - if not self._stopped: - self._timer = Timer(30.0, f) - self._timer.daemon = True - self._timer.start() - - def stop(self): - with self._lock: - self._stopped = True - if self._timer: - self._timer.cancel() - self._timer = None - - class StreamingContext(object): """ Main entry point for Spark Streaming functionality. A StreamingContext @@ -105,9 +47,6 @@ class StreamingContext(object): # Reference to a currently active StreamingContext _activeContext = None - # A cleaner to clean leak sockets of callback server every 30 seconds - _py4j_cleaner = None - def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -155,34 +94,12 @@ def _ensure_initialized(cls): # get the GatewayServer object in JVM by ID jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port - gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) - _py4j_cleaner = Py4jCallbackConnectionCleaner(gw) - _py4j_cleaner.start() + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing - if cls._transformerSerializer is None: - transformer_serializer = TransformFunctionSerializer() - transformer_serializer.init( - SparkContext._active_spark_context, CloudPickleSerializer(), gw) - # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM - # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever. - # (https://github.com/bartdag/py4j/pull/184) - # - # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when - # calling "registerSerializer". If we call "registerSerializer" twice, the second - # PythonProxyHandler will override the first one, then the first one will be GCed and - # trigger "PythonProxyHandler.finalize". To avoid that, we should not call - # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't - # be GCed. - # - # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version. - transformer_serializer.gateway.jvm.PythonDStream.registerSerializer( - transformer_serializer) - cls._transformerSerializer = transformer_serializer - else: - cls._transformerSerializer.init( - SparkContext._active_spark_context, CloudPickleSerializer(), gw) + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) @classmethod def getOrCreate(cls, checkpointPath, setupFunc): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index e617fc9ce9eec..abbbf6eb9394f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -89,10 +89,11 @@ class TransformFunctionSerializer(object): it uses this class to invoke Python, which returns the serialized function as a byte array. """ - def init(self, ctx, serializer, gateway=None): + def __init__(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) self.failure = None def dumps(self, id): diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index d8d9d00d64ebc..0c37985a670b2 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -27,4 +27,4 @@ fi export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:${PYTHONPATH}" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.1-src.zip:${PYTHONPATH}" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 953fe95177f02..8c9beccc2922c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -169,16 +169,6 @@ private[python] object PythonDStream { PythonTransformFunctionSerializer.register(ser) } - /** - * Update the port of callback client to `port` - */ - def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { - val cl = gws.getCallbackClient - val f = cl.getClass.getDeclaredField("port") - f.setAccessible(true) - f.setInt(cl, port) - } - /** * helper function for DStream.foreachRDD(), * cannot be `foreachRDD`, it will confusing py4j diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8cf438be587dc..d4ca255953a48 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1044,9 +1044,9 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), "pyspark.zip not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.9-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.9.1-src.zip") require(py4jFile.exists(), - "py4j-0.9-src.zip not found; cannot run pyspark application in YARN mode.") + "py4j-0.9.1-src.zip not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 6db012a77a936..b91c4be2ea875 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -151,9 +151,9 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // When running tests, let's not assume the user has built the assembly module, which also // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the // needed locations. - val sparkHome = sys.props("spark.test.home"); + val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.9-src.zip", + s"$sparkHome/python/lib/py4j-0.9.1-src.zip", s"$sparkHome/python") val extraEnv = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), From 9247084962259ebbbac4c5a80a6ccb271776f019 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 12 Jan 2016 18:21:04 -0800 Subject: [PATCH 1545/2089] [SPARK-12785][SQL] Add ColumnarBatch, an in memory columnar format for execution. There are many potential benefits of having an efficient in memory columnar format as an alternate to UnsafeRow. This patch introduces ColumnarBatch/ColumnarVector which starts this effort. The remaining implementation can be done as follow up patches. As stated in the in the JIRA, there are useful external components that operate on memory in a simple columnar format. ColumnarBatch would serve that purpose and could server as a zero-serialization/zero-copy exchange for this use case. This patch supports running the underlying data either on heap or off heap. On heap runs a bit faster but we would need offheap for zero-copy exchanges. Currently, this mode is hidden behind one interface (ColumnVector). This differs from Parquet or the existing columnar cache because this is *not* intended to be used as a storage format. The focus is entirely on CPU efficiency as we expect to only have 1 of these batches in memory per task. The layout of the values is just dense arrays of the value type. Author: Nong Li Author: Nong Closes #10628 from nongli/spark-12635. --- .../execution/vectorized/ColumnVector.java | 176 ++++++++++ .../execution/vectorized/ColumnarBatch.java | 296 ++++++++++++++++ .../vectorized/OffHeapColumnVector.java | 179 ++++++++++ .../vectorized/OnHeapColumnVector.java | 175 ++++++++++ .../vectorized/ColumnarBatchBenchmark.scala | 320 ++++++++++++++++++ .../vectorized/ColumnarBatchSuite.scala | 317 +++++++++++++++++ 6 files changed, 1463 insertions(+) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java new file mode 100644 index 0000000000000..d9dde92ceb6d7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.types.DataType; + +/** + * This class represents a column of values and provides the main APIs to access the data + * values. It supports all the types and contains get/put APIs as well as their batched versions. + * The batched versions are preferable whenever possible. + * + * Most of the APIs take the rowId as a parameter. This is the local 0-based row id for values + * in the current RowBatch. + * + * A ColumnVector should be considered immutable once originally created. In other words, it is not + * valid to call put APIs after reads until reset() is called. + */ +public abstract class ColumnVector { + /** + * Allocates a column with each element of size `width` either on or off heap. + */ + public static ColumnVector allocate(int capacity, DataType type, boolean offHeap) { + if (offHeap) { + return new OffHeapColumnVector(capacity, type); + } else { + return new OnHeapColumnVector(capacity, type); + } + } + + public final DataType dataType() { return type; } + + /** + * Resets this column for writing. The currently stored values are no longer accessible. + */ + public void reset() { + numNulls = 0; + if (anyNullsSet) { + putNotNulls(0, capacity); + anyNullsSet = false; + } + } + + /** + * Cleans up memory for this column. The column is not usable after this. + * TODO: this should probably have ref-counted semantics. + */ + public abstract void close(); + + /** + * Returns the number of nulls in this column. + */ + public final int numNulls() { return numNulls; } + + /** + * Returns true if any of the nulls indicator are set for this column. This can be used + * as an optimization to prevent setting nulls. + */ + public final boolean anyNullsSet() { return anyNullsSet; } + + /** + * Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid + * to call for off heap columns. + */ + public abstract long nullsNativeAddress(); + public abstract long valuesNativeAddress(); + + /** + * Sets the value at rowId to null/not null. + */ + public abstract void putNotNull(int rowId); + public abstract void putNull(int rowId); + + /** + * Sets the values from [rowId, rowId + count) to null/not null. + */ + public abstract void putNulls(int rowId, int count); + public abstract void putNotNulls(int rowId, int count); + + /** + * Returns whether the value at rowId is NULL. + */ + public abstract boolean getIsNull(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putInt(int rowId, int value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putInts(int rowId, int count, int value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putInts(int rowId, int count, int[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * The data in src must be 4-byte little endian ints. + */ + public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the integer for rowId. + */ + public abstract int getInt(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putDouble(int rowId, double value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putDoubles(int rowId, int count, double value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * The data in src must be ieee formated doubles. + */ + public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the double for rowId. + */ + public abstract double getDouble(int rowId); + + /** + * Maximum number of rows that can be stored in this column. + */ + protected final int capacity; + + /** + * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. + */ + protected int numNulls; + + /** + * True if there is at least one NULL byte set. This is an optimization for the writer, to skip + * having to clear NULL bits. + */ + protected boolean anyNullsSet; + + /** + * Data type for this column. + */ + protected final DataType type; + + protected ColumnVector(int capacity, DataType type) { + this.capacity = capacity; + this.type = type; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java new file mode 100644 index 0000000000000..47defac4534dc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +import org.apache.commons.lang.NotImplementedException; + +/** + * This class is the in memory representation of rows as they are streamed through operators. It + * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that + * each operator allocates one of thee objects, the storage footprint on the task is negligible. + * + * The layout is a columnar with values encoded in their native format. Each RowBatch contains + * a horizontal partitioning of the data, split into columns. + * + * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. + * + * TODO: + * - There are many TODOs for the existing APIs. They should throw a not implemented exception. + * - Compaction: The batch and columns should be able to compact based on a selection vector. + */ +public final class ColumnarBatch { + private static final int DEFAULT_BATCH_SIZE = 4 * 1024; + + private final StructType schema; + private final int capacity; + private int numRows; + private final ColumnVector[] columns; + + // True if the row is filtered. + private final boolean[] filteredRows; + + // Total number of rows that have been filtered. + private int numRowsFiltered = 0; + + public static ColumnarBatch allocate(StructType schema, boolean offHeap) { + return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, offHeap); + } + + public static ColumnarBatch allocate(StructType schema, boolean offHeap, int maxRows) { + return new ColumnarBatch(schema, maxRows, offHeap); + } + + /** + * Called to close all the columns in this batch. It is not valid to access the data after + * calling this. This must be called at the end to clean up memory allcoations. + */ + public void close() { + for (ColumnVector c: columns) { + c.close(); + } + } + + /** + * Adapter class to interop with existing components that expect internal row. A lot of + * performance is lost with this translation. + */ + public final class Row extends InternalRow { + private int rowId; + + /** + * Marks this row as being filtered out. This means a subsequent iteration over the rows + * in this batch will not include this row. + */ + public final void markFiltered() { + ColumnarBatch.this.markFiltered(rowId); + } + + @Override + public final int numFields() { + return ColumnarBatch.this.numCols(); + } + + @Override + public final InternalRow copy() { + throw new NotImplementedException(); + } + + @Override + public final boolean anyNull() { + throw new NotImplementedException(); + } + + @Override + public final boolean isNullAt(int ordinal) { + return ColumnarBatch.this.column(ordinal).getIsNull(rowId); + } + + @Override + public final boolean getBoolean(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final byte getByte(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final short getShort(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final int getInt(int ordinal) { + return ColumnarBatch.this.column(ordinal).getInt(rowId); + } + + @Override + public final long getLong(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final float getFloat(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final double getDouble(int ordinal) { + return ColumnarBatch.this.column(ordinal).getDouble(rowId); + } + + @Override + public final Decimal getDecimal(int ordinal, int precision, int scale) { + throw new NotImplementedException(); + } + + @Override + public final UTF8String getUTF8String(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final byte[] getBinary(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final CalendarInterval getInterval(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final InternalRow getStruct(int ordinal, int numFields) { + throw new NotImplementedException(); + } + + @Override + public final ArrayData getArray(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public final Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + } + + /** + * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + */ + public Iterator rowIterator() { + final int maxRows = ColumnarBatch.this.numRows(); + final Row row = new Row(); + return new Iterator() { + int rowId = 0; + + @Override + public boolean hasNext() { + while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { + ++rowId; + } + return rowId < maxRows; + } + + @Override + public Row next() { + assert(hasNext()); + while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { + ++rowId; + } + row.rowId = rowId++; + return row; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Resets the batch for writing. + */ + public void reset() { + for (int i = 0; i < numCols(); ++i) { + columns[i].reset(); + } + if (this.numRowsFiltered > 0) { + Arrays.fill(filteredRows, false); + } + this.numRows = 0; + this.numRowsFiltered = 0; + } + + /** + * Sets the number of rows that are valid. + */ + public void setNumRows(int numRows) { + assert(numRows <= this.capacity); + this.numRows = numRows; + } + + /** + * Returns the number of columns that make up this batch. + */ + public int numCols() { return columns.length; } + + /** + * Returns the number of rows for read, including filtered rows. + */ + public int numRows() { return numRows; } + + /** + * Returns the number of valid rowss. + */ + public int numValidRows() { + assert(numRowsFiltered <= numRows); + return numRows - numRowsFiltered; + } + + /** + * Returns the max capacity (in number of rows) for this batch. + */ + public int capacity() { return capacity; } + + /** + * Returns the column at `ordinal`. + */ + public ColumnVector column(int ordinal) { return columns[ordinal]; } + + /** + * Marks this row as being filtered out. This means a subsequent iteration over the rows + * in this batch will not include this row. + */ + public final void markFiltered(int rowId) { + assert(filteredRows[rowId] == false); + filteredRows[rowId] = true; + ++numRowsFiltered; + } + + private ColumnarBatch(StructType schema, int maxRows, boolean offHeap) { + this.schema = schema; + this.capacity = maxRows; + this.columns = new ColumnVector[schema.size()]; + this.filteredRows = new boolean[maxRows]; + + for (int i = 0; i < schema.fields().length; ++i) { + StructField field = schema.fields()[i]; + columns[i] = ColumnVector.allocate(maxRows, field.dataType(), offHeap); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java new file mode 100644 index 0000000000000..2a9a2d1104b22 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.nio.ByteOrder; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.unsafe.Platform; + + +import org.apache.commons.lang.NotImplementedException; + +/** + * Column data backed using offheap memory. + */ +public final class OffHeapColumnVector extends ColumnVector { + // The data stored in these two allocations need to maintain binary compatible. We can + // directly pass this buffer to external components. + private long nulls; + private long data; + + protected OffHeapColumnVector(int capacity, DataType type) { + super(capacity, type); + if (!ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) { + throw new NotImplementedException("Only little endian is supported."); + } + + this.nulls = Platform.allocateMemory(capacity); + if (type instanceof IntegerType) { + this.data = Platform.allocateMemory(capacity * 4); + } else if (type instanceof DoubleType) { + this.data = Platform.allocateMemory(capacity * 8); + } else { + throw new RuntimeException("Unhandled " + type); + } + reset(); + } + + @Override + public final long valuesNativeAddress() { + return data; + } + + @Override + public long nullsNativeAddress() { + return nulls; + } + + @Override + public final void close() { + Platform.freeMemory(nulls); + Platform.freeMemory(data); + nulls = 0; + data = 0; + } + + // + // APIs dealing with nulls + // + + @Override + public final void putNotNull(int rowId) { + Platform.putByte(null, nulls + rowId, (byte) 0); + } + + @Override + public final void putNull(int rowId) { + Platform.putByte(null, nulls + rowId, (byte) 1); + ++numNulls; + anyNullsSet = true; + } + + @Override + public final void putNulls(int rowId, int count) { + long offset = nulls + rowId; + for (int i = 0; i < count; ++i, ++offset) { + Platform.putByte(null, offset, (byte) 1); + } + anyNullsSet = true; + numNulls += count; + } + + @Override + public final void putNotNulls(int rowId, int count) { + long offset = nulls + rowId; + for (int i = 0; i < count; ++i, ++offset) { + Platform.putByte(null, offset, (byte) 0); + } + } + + @Override + public final boolean getIsNull(int rowId) { + return Platform.getByte(null, nulls + rowId) == 1; + } + + // + // APIs dealing with ints + // + + @Override + public final void putInt(int rowId, int value) { + Platform.putInt(null, data + 4 * rowId, value); + } + + @Override + public final void putInts(int rowId, int count, int value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putInt(null, offset, value); + } + } + + @Override + public final void putInts(int rowId, int count, int[] src, int srcIndex) { + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 4 * rowId, count * 4); + } + + @Override + public final int getInt(int rowId) { + return Platform.getInt(null, data + 4 * rowId); + } + + // + // APIs dealing with doubles + // + + @Override + public final void putDouble(int rowId, double value) { + Platform.putDouble(null, data + rowId * 8, value); + } + + @Override + public final void putDoubles(int rowId, int count, double value) { + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8) { + Platform.putDouble(null, offset, value); + } + } + + @Override + public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, + null, data + 8 * rowId, count * 8); + } + + @Override + public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 8, count * 8); + } + + @Override + public final double getDouble(int rowId) { + return Platform.getDouble(null, data + rowId * 8); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java new file mode 100644 index 0000000000000..a7b3addf11b14 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.unsafe.Platform; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.util.Arrays; + +/** + * A column backed by an in memory JVM array. This stores the NULLs as a byte per value + * and a java array for the values. + */ +public final class OnHeapColumnVector extends ColumnVector { + // The data stored in these arrays need to maintain binary compatible. We can + // directly pass this buffer to external components. + + // This is faster than a boolean array and we optimize this over memory footprint. + private byte[] nulls; + + // Array for each type. Only 1 is populated for any type. + private int[] intData; + private double[] doubleData; + + protected OnHeapColumnVector(int capacity, DataType type) { + super(capacity, type); + if (type instanceof IntegerType) { + this.intData = new int[capacity]; + } else if (type instanceof DoubleType) { + this.doubleData = new double[capacity]; + } else { + throw new RuntimeException("Unhandled " + type); + } + this.nulls = new byte[capacity]; + reset(); + } + + @Override + public final long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + @Override + public final long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + + @Override + public final void close() { + nulls = null; + intData = null; + doubleData = null; + } + + + // + // APIs dealing with nulls + // + + @Override + public final void putNotNull(int rowId) { + nulls[rowId] = (byte)0; + } + + @Override + public final void putNull(int rowId) { + nulls[rowId] = (byte)1; + ++numNulls; + anyNullsSet = true; + } + + @Override + public final void putNulls(int rowId, int count) { + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)1; + } + anyNullsSet = true; + numNulls += count; + } + + @Override + public final void putNotNulls(int rowId, int count) { + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)0; + } + } + + @Override + public final boolean getIsNull(int rowId) { + return nulls[rowId] == 1; + } + + // + // APIs dealing with Ints + // + + @Override + public final void putInt(int rowId, int value) { + intData[rowId] = value; + } + + @Override + public final void putInts(int rowId, int count, int value) { + for (int i = 0; i < count; ++i) { + intData[i + rowId] = value; + } + } + + @Override + public final void putInts(int rowId, int count, int[] src, int srcIndex) { + System.arraycopy(src, srcIndex, intData, rowId, count); + } + + @Override + public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i) { + intData[i + rowId] = Platform.getInt(src, srcOffset);; + srcIndex += 4; + srcOffset += 4; + } + } + + @Override + public final int getInt(int rowId) { + return intData[rowId]; + } + + // + // APIs dealing with doubles + // + + @Override + public final void putDouble(int rowId, double value) { + doubleData[rowId] = value; + } + + @Override + public final void putDoubles(int rowId, int count, double value) { + Arrays.fill(doubleData, rowId, rowId + count, value); + } + + @Override + public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { + System.arraycopy(src, srcIndex, doubleData, rowId, count); + } + + @Override + public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + } + + @Override + public final double getDouble(int rowId) { + return doubleData[rowId]; + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala new file mode 100644 index 0000000000000..e28153d12a354 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.ColumnVector +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.unsafe.Platform +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.BitSet + +/** + * Benchmark to low level memory access using different ways to manage buffers. + */ +object ColumnarBatchBenchmark { + + // This benchmark reads and writes an array of ints. + // TODO: there is a big (2x) penalty for a random access API for off heap. + // Note: carefully if modifying this code. It's hard to reason about the JIT. + def intAccess(iters: Long): Unit = { + val count = 8 * 1000 + + // Accessing a java array. + val javaArray = { i: Int => + val data = new Array[Int](count) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + data(i) = i + i += 1 + } + i = 0 + while (i < count) { + sum += data(i) + i += 1 + } + } + } + + // Accessing ByteBuffers + val byteBufferUnsafe = { i: Int => + val data = ByteBuffer.allocate(count * 4) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + Platform.putInt(data.array(), Platform.BYTE_ARRAY_OFFSET + i * 4, i) + i += 1 + } + i = 0 + while (i < count) { + sum += Platform.getInt(data.array(), Platform.BYTE_ARRAY_OFFSET + i * 4) + i += 1 + } + } + } + + // Accessing offheap byte buffers + val directByteBuffer = { i: Int => + val data = ByteBuffer.allocateDirect(count * 4).asIntBuffer() + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + data.put(i) + i += 1 + } + data.rewind() + i = 0 + while (i < count) { + sum += data.get() + i += 1 + } + data.rewind() + } + } + + // Accessing ByteBuffer using the typed APIs + val byteBufferApi = { i: Int => + val data = ByteBuffer.allocate(count * 4) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + data.putInt(i) + i += 1 + } + data.rewind() + i = 0 + while (i < count) { + sum += data.getInt() + i += 1 + } + data.rewind() + } + } + + // Using unsafe memory + val unsafeBuffer = { i: Int => + val data: Long = Platform.allocateMemory(count * 4) + var sum = 0L + for (n <- 0L until iters) { + var ptr = data + var i = 0 + while (i < count) { + Platform.putInt(null, ptr, i) + ptr += 4 + i += 1 + } + ptr = data + i = 0 + while (i < count) { + sum += Platform.getInt(null, ptr) + ptr += 4 + i += 1 + } + } + } + + // Access through the column API with on heap memory + val columnOnHeap = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, false) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.putInt(i, i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + } + col.close + } + + // Access through the column API with off heap memory + def columnOffHeap = { i: Int => { + val col = ColumnVector.allocate(count, IntegerType, true) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.putInt(i, i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + } + col.close + }} + + // Access by directly getting the buffer backing the column. + val columnOffheapDirect = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, true) + var sum = 0L + for (n <- 0L until iters) { + var addr = col.valuesNativeAddress() + var i = 0 + while (i < count) { + Platform.putInt(null, addr, i) + addr += 4 + i += 1 + } + i = 0 + addr = col.valuesNativeAddress() + while (i < count) { + sum += Platform.getInt(null, addr) + addr += 4 + i += 1 + } + } + col.close + } + + // Access by going through a batch of unsafe rows. + val unsafeRowOnheap = { i: Int => + val buffer = new Array[Byte](count * 16) + var sum = 0L + for (n <- 0L until iters) { + val row = new UnsafeRow(1) + var i = 0 + while (i < count) { + row.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET + i * 16, 16) + row.setInt(0, i) + i += 1 + } + i = 0 + while (i < count) { + row.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET + i * 16, 16) + sum += row.getInt(0) + i += 1 + } + } + } + + // Access by going through a batch of unsafe rows. + val unsafeRowOffheap = { i: Int => + val buffer = Platform.allocateMemory(count * 16) + var sum = 0L + for (n <- 0L until iters) { + val row = new UnsafeRow(1) + var i = 0 + while (i < count) { + row.pointTo(null, buffer + i * 16, 16) + row.setInt(0, i) + i += 1 + } + i = 0 + while (i < count) { + row.pointTo(null, buffer + i * 16, 16) + sum += row.getInt(0) + i += 1 + } + } + Platform.freeMemory(buffer) + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Int Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + Java Array 248.8 1317.04 1.00 X + ByteBuffer Unsafe 435.6 752.25 0.57 X + ByteBuffer API 1752.0 187.03 0.14 X + DirectByteBuffer 595.4 550.35 0.42 X + Unsafe Buffer 235.2 1393.20 1.06 X + Column(on heap) 189.8 1726.45 1.31 X + Column(off heap) 408.4 802.35 0.61 X + Column(off heap direct) 237.6 1379.12 1.05 X + UnsafeRow (on heap) 414.6 790.35 0.60 X + UnsafeRow (off heap) 487.2 672.58 0.51 X + */ + val benchmark = new Benchmark("Int Read/Write", count * iters) + benchmark.addCase("Java Array")(javaArray) + benchmark.addCase("ByteBuffer Unsafe")(byteBufferUnsafe) + benchmark.addCase("ByteBuffer API")(byteBufferApi) + benchmark.addCase("DirectByteBuffer")(directByteBuffer) + benchmark.addCase("Unsafe Buffer")(unsafeBuffer) + benchmark.addCase("Column(on heap)")(columnOnHeap) + benchmark.addCase("Column(off heap)")(columnOffHeap) + benchmark.addCase("Column(off heap direct)")(columnOffheapDirect) + benchmark.addCase("UnsafeRow (on heap)")(unsafeRowOnheap) + benchmark.addCase("UnsafeRow (off heap)")(unsafeRowOffheap) + benchmark.run() + } + + def booleanAccess(iters: Int): Unit = { + val count = 8 * 1024 + val benchmark = new Benchmark("Boolean Read/Write", iters * count) + benchmark.addCase("Bitset") { i: Int => { + val b = new BitSet(count) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + if (i % 2 == 0) b.set(i) + i += 1 + } + i = 0 + while (i < count) { + if (b.get(i)) sum += 1 + i += 1 + } + } + }} + + benchmark.addCase("Byte Array") { i: Int => { + val b = new Array[Byte](count) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + if (i % 2 == 0) b(i) = 1; + i += 1 + } + i = 0 + while (i < count) { + if (b(i) == 1) sum += 1 + i += 1 + } + } + }} + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Boolean Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + Bitset 895.88 374.54 1.00 X + Byte Array 578.96 579.56 1.55 X + */ + benchmark.run() + } + + def main(args: Array[String]): Unit = { + intAccess(1024 * 40) + booleanAccess(1024 * 40) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala new file mode 100644 index 0000000000000..305a83e3e45c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.unsafe.Platform + +class ColumnarBatchSuite extends SparkFunSuite { + test("Null Apis") { + (false :: true :: Nil).foreach { offHeap => { + val reference = mutable.ArrayBuffer.empty[Boolean] + + val column = ColumnVector.allocate(1024, IntegerType, offHeap) + var idx = 0 + assert(column.anyNullsSet() == false) + + column.putNotNull(idx) + reference += false + idx += 1 + assert(column.anyNullsSet() == false) + + column.putNull(idx) + reference += true + idx += 1 + assert(column.anyNullsSet() == true) + assert(column.numNulls() == 1) + + column.putNulls(idx, 3) + reference += true + reference += true + reference += true + idx += 3 + assert(column.anyNullsSet() == true) + + column.putNotNulls(idx, 4) + reference += false + reference += false + reference += false + reference += false + idx += 4 + assert(column.anyNullsSet() == true) + assert(column.numNulls() == 4) + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getIsNull(v._2)) + if (offHeap) { + val addr = column.nullsNativeAddress() + assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) + } + } + column.close + }} + } + + test("Int Apis") { + (false :: true :: Nil).foreach { offHeap => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Int] + + val column = ColumnVector.allocate(1024, IntegerType, offHeap) + var idx = 0 + + val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray + column.putInts(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putInts(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + val littleEndian = new Array[Byte](8) + littleEndian(0) = 7 + littleEndian(1) = 1 + littleEndian(4) = 6 + littleEndian(6) = 1 + + column.putIntsLittleEndian(idx, 1, littleEndian, 4) + column.putIntsLittleEndian(idx + 1, 1, littleEndian, 0) + reference += 6 + (1 << 16) + reference += 7 + (1 << 8) + idx += 2 + + column.putIntsLittleEndian(idx, 2, littleEndian, 0) + reference += 7 + (1 << 8) + reference += 6 + (1 << 16) + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextInt() + column.putInt(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + column.putInts(idx, n, n + 1) + var i = 0 + while (i < n) { + reference += (n + 1) + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Off Heap=" + offHeap) + if (offHeap) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) + } + } + column.close + }} + } + + test("Double APIs") { + (false :: true :: Nil).foreach { offHeap => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Double] + + val column = ColumnVector.allocate(1024, DoubleType, offHeap) + var idx = 0 + + val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray + column.putDoubles(idx, 2, values, 0) + reference += 1.0 + reference += 2.0 + idx += 2 + + column.putDoubles(idx, 3, values, 2) + reference += 3.0 + reference += 4.0 + reference += 5.0 + idx += 3 + + val buffer = new Array[Byte](16) + Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234) + Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) + + column.putDoubles(idx, 1, buffer, 8) + column.putDoubles(idx + 1, 1, buffer, 0) + reference += 1.123 + reference += 2.234 + idx += 2 + + column.putDoubles(idx, 2, buffer, 0) + reference += 2.234 + reference += 1.123 + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextDouble() + column.putDouble(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + val v = random.nextDouble() + column.putDoubles(idx, n, v) + var i = 0 + while (i < n) { + reference += v + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " Off Heap=" + offHeap) + if (offHeap) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) + } + } + column.close + }} + } + + test("ColumnarBatch basic") { + (false :: true :: Nil).foreach { offHeap => { + val schema = new StructType() + .add("intCol", IntegerType) + .add("doubleCol", DoubleType) + .add("intCol2", IntegerType) + + val batch = ColumnarBatch.allocate(schema, offHeap) + assert(batch.numCols() == 3) + assert(batch.numRows() == 0) + assert(batch.numValidRows() == 0) + assert(batch.capacity() > 0) + assert(batch.rowIterator().hasNext == false) + + // Add a row [1, 1.1, NULL] + batch.column(0).putInt(0, 1) + batch.column(1).putDouble(0, 1.1) + batch.column(2).putNull(0) + batch.setNumRows(1) + + // Verify the results of the row. + assert(batch.numCols() == 3) + assert(batch.numRows() == 1) + assert(batch.numValidRows() == 1) + assert(batch.rowIterator().hasNext == true) + assert(batch.rowIterator().hasNext == true) + + assert(batch.column(0).getInt(0) == 1) + assert(batch.column(0).getIsNull(0) == false) + assert(batch.column(1).getDouble(0) == 1.1) + assert(batch.column(1).getIsNull(0) == false) + assert(batch.column(2).getIsNull(0) == true) + + // Verify the iterator works correctly. + val it = batch.rowIterator() + assert(it.hasNext()) + val row = it.next() + assert(row.getInt(0) == 1) + assert(row.isNullAt(0) == false) + assert(row.getDouble(1) == 1.1) + assert(row.isNullAt(1) == false) + assert(row.isNullAt(2) == true) + assert(it.hasNext == false) + assert(it.hasNext == false) + + // Filter out the row. + row.markFiltered() + assert(batch.numRows() == 1) + assert(batch.numValidRows() == 0) + assert(batch.rowIterator().hasNext == false) + + // Reset and add 3 throws + batch.reset() + assert(batch.numRows() == 0) + assert(batch.numValidRows() == 0) + assert(batch.rowIterator().hasNext == false) + + // Add rows [NULL, 2.2, 2], [3, NULL, 3], [4, 4.4, 4] + batch.column(0).putNull(0) + batch.column(1).putDouble(0, 2.2) + batch.column(2).putInt(0, 2) + + batch.column(0).putInt(1, 3) + batch.column(1).putNull(1) + batch.column(2).putInt(1, 3) + + batch.column(0).putInt(2, 4) + batch.column(1).putDouble(2, 4.4) + batch.column(2).putInt(2, 4) + batch.setNumRows(3) + + def rowEquals(x: InternalRow, y: Row): Unit = { + assert(x.isNullAt(0) == y.isNullAt(0)) + if (!x.isNullAt(0)) assert(x.getInt(0) == y.getInt(0)) + + assert(x.isNullAt(1) == y.isNullAt(1)) + if (!x.isNullAt(1)) assert(x.getDouble(1) == y.getDouble(1)) + + assert(x.isNullAt(2) == y.isNullAt(2)) + if (!x.isNullAt(2)) assert(x.getInt(2) == y.getInt(2)) + } + // Verify + assert(batch.numRows() == 3) + assert(batch.numValidRows() == 3) + val it2 = batch.rowIterator() + rowEquals(it2.next(), Row(null, 2.2, 2)) + rowEquals(it2.next(), Row(3, null, 3)) + rowEquals(it2.next(), Row(4, 4.4, 4)) + assert(!it.hasNext) + + // Filter out some rows and verify + batch.markFiltered(1) + assert(batch.numValidRows() == 2) + val it3 = batch.rowIterator() + rowEquals(it3.next(), Row(null, 2.2, 2)) + rowEquals(it3.next(), Row(4, 4.4, 4)) + assert(!it.hasNext) + + batch.markFiltered(2) + assert(batch.numValidRows() == 1) + val it4 = batch.rowIterator() + rowEquals(it4.next(), Row(null, 2.2, 2)) + + batch.close + }} + } +} From b3b9ad23cffc1c6d83168487093e4c03d49e1c2c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Jan 2016 18:45:55 -0800 Subject: [PATCH 1546/2089] [SPARK-12788][SQL] Simplify BooleanEquality by using casts. Author: Reynold Xin Closes #10730 from rxin/SPARK-12788. --- .../catalyst/analysis/HiveTypeCoercion.scala | 30 ++++--------------- .../analysis/HiveTypeCoercionSuite.scala | 28 ++++++++++++++++- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e9e20670817fe..980b5d52fa8f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -482,27 +482,6 @@ object HiveTypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { - CaseKeyWhen(numericExpr, Seq( - Literal(trueValues.head), booleanExpr, - Literal(falseValues.head), Not(booleanExpr), - Literal(false))) - } - - private def transform(booleanExpr: Expression, numericExpr: Expression) = { - If(Or(IsNull(booleanExpr), IsNull(numericExpr)), - Literal.create(null, BooleanType), - buildCaseKeyWhen(booleanExpr, numericExpr)) - } - - private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { - CaseWhen(Seq( - And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), - Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), - buildCaseKeyWhen(booleanExpr, numericExpr) - )) - } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -511,6 +490,7 @@ object HiveTypeCoercion { // all other cases are considered as false. // We may simplify the expression if one side is literal numeric values + // TODO: Maybe these rules should go into the optimizer. case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) if trueValues.contains(value) => bool case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) @@ -529,13 +509,13 @@ object HiveTypeCoercion { if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) case EqualTo(left @ BooleanType(), right @ NumericType()) => - transform(left , right) + EqualTo(Cast(left, right.dataType), right) case EqualTo(left @ NumericType(), right @ BooleanType()) => - transform(right, left) + EqualTo(left, Cast(right, left.dataType)) case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => - transformNullSafe(left, right) + EqualNullSafe(Cast(left, right.dataType), right) case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => - transformNullSafe(right, left) + EqualNullSafe(left, Cast(right, left.dataType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 23b11af9ac087..40378c6727667 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -320,7 +320,33 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("type coercion simplification for equal to") { + test("BooleanEquality type cast") { + val be = HiveTypeCoercion.BooleanEquality + // Use something more than a literal to avoid triggering the simplification rules. + val one = Add(Literal(Decimal(1)), Literal(Decimal(0))) + + ruleTest(be, + EqualTo(Literal(true), one), + EqualTo(Cast(Literal(true), one.dataType), one) + ) + + ruleTest(be, + EqualTo(one, Literal(true)), + EqualTo(one, Cast(Literal(true), one.dataType)) + ) + + ruleTest(be, + EqualNullSafe(Literal(true), one), + EqualNullSafe(Cast(Literal(true), one.dataType), one) + ) + + ruleTest(be, + EqualNullSafe(one, Literal(true)), + EqualNullSafe(one, Cast(Literal(true), one.dataType)) + ) + } + + test("BooleanEquality simplification") { val be = HiveTypeCoercion.BooleanEquality ruleTest(be, From f14922cff84b1e0984ba4597d764615184126bdc Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 12 Jan 2016 19:24:50 -0800 Subject: [PATCH 1547/2089] [SPARK-12692][BUILD][CORE] Scala style: Fix the style violation (Space before ",") Fix the style violation (space before , and :). This PR is a followup for #10643 Author: Kousuke Saruta Closes #10719 from sarutak/SPARK-12692-followup-core. --- core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala | 2 +- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- core/src/main/scala/org/apache/spark/status/api/v1/api.scala | 2 +- core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala | 2 +- .../scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala | 2 +- scalastyle-config.xml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 18e8cddbc40db..57108dcedcf0c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -50,7 +50,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1 : RDD[T], var rdd2 : RDD[U]) - extends RDD[Pair[T, U]](sc, Nil) + extends RDD[(T, U)](sc, Nil) with Serializable { val numPartitionsInRdd2 = rdd2.partitions.length diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 53e01a0dbfc06..9dad7944144d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -95,7 +95,7 @@ abstract class RDD[T: ClassTag]( /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = - this(oneParent.context , List(new OneToOneDependency(oneParent))) + this(oneParent.context, List(new OneToOneDependency(oneParent))) private[spark] def conf = sc.conf // ======================================================================= diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 5feb1dc2e5b74..9cd52d6c2bef5 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -115,7 +115,7 @@ class StageData private[spark]( val status: StageStatus, val stageId: Int, val attemptId: Int, - val numActiveTasks: Int , + val numActiveTasks: Int, val numCompleteTasks: Int, val numFailedTasks: Int, diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 4e72b89bfcc40..76451788d2406 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -178,7 +178,7 @@ class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { test("WorksWithOutOfRangeWithInfiniteBuckets") { // Verify that out of range works with two buckets val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) - val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0) + val buckets = Array(-1.0/0.0, 0.0, 1.0/0.0) val histogramResults = rdd.histogram(buckets) val expectedHistogramResults = Array(1, 1) assert(histogramResults === expectedHistogramResults) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 504e5780f3d8a..e111e2e9f6163 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -76,7 +76,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi test("check spark-class location correctly") { val conf = new SparkConf - conf.set("spark.mesos.executor.home" , "/mesos-home") + conf.set("spark.mesos.executor.home", "/mesos-home") val listenerBus = mock[LiveListenerBus] listenerBus.post( diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 2439a1f715aba..bc209ee6aa873 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -220,7 +220,7 @@ This file is divided into 3 sections: - COLON, COMMA + COMMA From dc7b3870fcfc2723319dbb8c53d721211a8116be Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 12 Jan 2016 21:41:38 -0800 Subject: [PATCH 1548/2089] [SPARK-12558][SQL] AnalysisException when multiple functions applied in GROUP BY clause cloud-fan Can you please take a look ? In this case, we are failing during check analysis while validating the aggregation expression. I have added a semanticEquals for HiveGenericUDF to fix this. Please let me know if this is the right way to address this issue. Author: Dilip Biswal Closes #10520 from dilipbiswal/spark-12558. --- .../org/apache/spark/sql/hive/HiveShim.scala | 23 +++++++++++++++++++ .../sql/hive/execution/HiveUDFSuite.scala | 7 ++++++ 2 files changed, 30 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index b8cced0b80969..087b0c087c111 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -26,11 +26,13 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} +import com.google.common.base.Objects import org.apache.avro.Schema import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro import org.apache.hadoop.hive.serde2.ColumnProjectionUtils import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector @@ -45,6 +47,7 @@ private[hive] object HiveShim { // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) val UNLIMITED_DECIMAL_PRECISION = 38 val UNLIMITED_DECIMAL_SCALE = 18 + val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" /* * This function in hive-0.13 become private, but we have to do this to walkaround hive bug @@ -123,6 +126,26 @@ private[hive] object HiveShim { // for Serialization def this() = this(null) + override def hashCode(): Int = { + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody()) + } else { + functionClassName.hashCode() + } + } + + override def equals(other: Any): Boolean = other match { + case a: HiveFunctionWrapper if functionClassName == a.functionClassName => + // In case of udf macro, check to make sure they point to the same underlying UDF + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + a.instance.asInstanceOf[GenericUDFMacro].getBody() == + instance.asInstanceOf[GenericUDFMacro].getBody() + } else { + true + } + case _ => false + } + @transient def deserializeObjectByKryo[T: ClassTag]( kryo: Kryo, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index c5ff8825abd7f..dfe33ba8b0502 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -350,6 +350,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { sqlContext.dropTempTable("testUDF") } + test("Hive UDF in group by") { + Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1") + val count = sql("select date(cast(test_date as timestamp))" + + " from tab1 group by date(cast(test_date as timestamp))").count() + assert(count == 1) + } + test("SPARK-11522 select input_file_name from non-parquet table"){ withTempDir { tempDir => From cb7b864a24db4826e2942c186afe3cb8bd788b03 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 12 Jan 2016 22:25:20 -0800 Subject: [PATCH 1549/2089] [SPARK-12692][BUILD][SQL] Scala style: Fix the style violation (Space before ",") Fix the style violation (space before , and :). This PR is a followup for #10643 and rework of #10685 . Author: Kousuke Saruta Closes #10732 from sarutak/SPARK-12692-followup-sql. --- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../sql/catalyst/util/NumberConverter.scala | 2 +- .../BooleanSimplificationSuite.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../apache/spark/sql/DatasetCacheSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++---------- .../datasources/json/JsonSuite.scala | 2 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 4 ++-- 10 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 2a132d8b82bef..6ec408a673c79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -203,7 +203,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(expression ~ direction.? , ",") ^^ { + ( rep1sep(expression ~ direction.?, ",") ^^ { case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 9fefc5656aac0..e4417e0955143 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -122,7 +122,7 @@ object NumberConverter { * unsigned, otherwise it is signed. * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv */ - def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 000a3b7ecb7c6..6932f185b9d62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -80,7 +80,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) - checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) + checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5), 'a < 2) checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2dd82358fbfdf..b909765a7c6dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -945,7 +945,7 @@ class SQLContext private[sql]( } } - // Register a succesfully instantiatd context to the singleton. This should be at the end of + // Register a successfully instantiated context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. sparkContext.addSparkListener(new SparkListener { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6b100577077c6..058d147c7d65d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -223,7 +223,7 @@ case class Exchange( new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } - protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { coordinator match { case Some(exchangeCoordinator) => val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 3a283a4e1f610..848f1af65508b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -27,7 +27,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("persist and unpersist") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) val cached = ds.cache() // count triggers the caching action. It should not throw. cached.count() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 53b5f45c2d4a6..693f5aea2d015 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -30,7 +30,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("toDS") { - val data = Seq(("a", 1) , ("b", 2), ("c", 3)) + val data = Seq(("a", 1), ("b", 2), ("c", 3)) checkAnswer( data.toDS(), data: _*) @@ -87,7 +87,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("as case class / collect") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] checkAnswer( ds, ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) @@ -105,7 +105,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("map") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.map(v => (v._1, v._2 + 1)), ("a", 2), ("b", 3), ("c", 4)) @@ -124,14 +124,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select(expr("_2 + 1").as[Int]), 2, 3, 4) } test("select 2") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -140,7 +140,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and tuple") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -149,7 +149,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], @@ -158,7 +158,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class, fields reordered") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDecoding( ds.select( expr("_1").as[String], @@ -167,28 +167,28 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("filter") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkAnswer( ds.filter(_._1 == "b"), ("b", 2)) } test("foreach") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreach(v => acc += v._2) assert(acc.value == 6) } test("foreachPartition") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreachPartition(_.foreach(v => acc += v._2)) assert(acc.value == 6) } test("reduce") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4ab148065a476..860e07c68cef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -206,7 +206,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructType( StructField("f1", IntegerType, true) :: StructField("f2", IntegerType, true) :: Nil), - StructType(StructField("f1", LongType, true) :: Nil) , + StructType(StructField("f1", LongType, true) :: Nil), StructType( StructField("f1", LongType, true) :: StructField("f2", IntegerType, true) :: Nil)) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 03bc830df2034..f279b78f47c7d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -369,7 +369,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (counter != 0) { responseMsg += s", Fetched $counter row(s)" } - console.printInfo(responseMsg , null) + console.printInfo(responseMsg, null) // Destroy the driver to release all the locks. driver.destroy() } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index da7303c791064..40e9c9362cf5e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -154,8 +154,8 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } val expected = List( "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil, "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) From 3d81d63f4499478ef7861bf77383c30aed14bb19 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 13 Jan 2016 00:51:24 -0800 Subject: [PATCH 1550/2089] [SPARK-12692][BUILD] Enforce style checking about white space before comma This is the final PR about SPARK-12692. We have removed all of white spaces before comma from code so let's enforce style checking. Author: Kousuke Saruta Closes #10736 from sarutak/SPARK-12692-followup-enforce-checking. --- scalastyle-config.xml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index bc209ee6aa873..967a482ba4f9b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -197,6 +197,12 @@ This file is divided into 3 sections: + + + COMMA + + + @@ -217,13 +223,6 @@ This file is divided into 3 sections: - - - - COMMA - - - From d6fd9b376b7071aecef34dc82a33eba42b183bc9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 13 Jan 2016 10:01:15 -0800 Subject: [PATCH 1551/2089] [SPARK-12692][BUILD][HOT-FIX] Fix the scala style of KinesisBackedBlockRDDSuite.scala. https://github.com/apache/spark/pull/10736 was merged yesterday and caused the master start to fail because of the style issue. Author: Yin Huai Closes #10742 from yhuai/fixStyle. --- .../spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index e6f504c4e54dd..e916f1ee0893b 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -158,9 +158,9 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) testBlockRemove: Boolean = false ): Unit = { require(shardIds.size > 1, "Need at least 2 shards to test") - require(numPartitionsInBM <= shardIds.size , + require(numPartitionsInBM <= shardIds.size, "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") - require(numPartitionsInKinesis <= shardIds.size , + require(numPartitionsInKinesis <= shardIds.size, "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") require(numPartitionsInBM <= numPartitions, "Number of partitions in BlockManager cannot be more than that in RDD") From 63eee86cc652c108ca7712c8c0a73db1ca89ae90 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Jan 2016 10:26:55 -0800 Subject: [PATCH 1552/2089] [SPARK-9297] [SQL] Add covar_pop and covar_samp JIRA: https://issues.apache.org/jira/browse/SPARK-9297 Add two aggregation functions: covar_pop and covar_samp. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10029 from viirya/covar-funcs. --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/aggregate/Covariance.scala | 198 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 40 ++++ .../execution/AggregationQuerySuite.scala | 32 +++ 4 files changed, 272 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5c2aa3c06b3e7..d9009e3848e58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -182,6 +182,8 @@ object FunctionRegistry { expression[Average]("avg"), expression[Corr]("corr"), expression[Count]("count"), + expression[CovPopulation]("covar_pop"), + expression[CovSample]("covar_samp"), expression[First]("first"), expression[First]("first_value"), expression[Last]("last"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala new file mode 100644 index 0000000000000..f53b01be2a0d5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +/** + * Compute the covariance between two expressions. + * When applied on empty data (i.e., count is zero), it returns NULL. + * + */ +abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate + with Serializable { + override def children: Seq[Expression] = Seq(left, right) + + override def nullable: Boolean = true + + override def dataType: DataType = DoubleType + + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"covariance requires that both arguments are double type, " + + s"not (${left.dataType}, ${right.dataType}).") + } + } + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override def inputAggBufferAttributes: Seq[AttributeReference] = { + aggBufferAttributes.map(_.newInstance()) + } + + override val aggBufferAttributes: Seq[AttributeReference] = Seq( + AttributeReference("xAvg", DoubleType)(), + AttributeReference("yAvg", DoubleType)(), + AttributeReference("Ck", DoubleType)(), + AttributeReference("count", LongType)()) + + // Local cache of mutableAggBufferOffset(s) that will be used in update and merge + val xAvgOffset = mutableAggBufferOffset + val yAvgOffset = mutableAggBufferOffset + 1 + val CkOffset = mutableAggBufferOffset + 2 + val countOffset = mutableAggBufferOffset + 3 + + // Local cache of inputAggBufferOffset(s) that will be used in update and merge + val inputXAvgOffset = inputAggBufferOffset + val inputYAvgOffset = inputAggBufferOffset + 1 + val inputCkOffset = inputAggBufferOffset + 2 + val inputCountOffset = inputAggBufferOffset + 3 + + override def initialize(buffer: MutableRow): Unit = { + buffer.setDouble(xAvgOffset, 0.0) + buffer.setDouble(yAvgOffset, 0.0) + buffer.setDouble(CkOffset, 0.0) + buffer.setLong(countOffset, 0L) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val leftEval = left.eval(input) + val rightEval = right.eval(input) + + if (leftEval != null && rightEval != null) { + val x = leftEval.asInstanceOf[Double] + val y = rightEval.asInstanceOf[Double] + + var xAvg = buffer.getDouble(xAvgOffset) + var yAvg = buffer.getDouble(yAvgOffset) + var Ck = buffer.getDouble(CkOffset) + var count = buffer.getLong(countOffset) + + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + + buffer.setDouble(xAvgOffset, xAvg) + buffer.setDouble(yAvgOffset, yAvg) + buffer.setDouble(CkOffset, Ck) + buffer.setLong(countOffset, count) + } + } + + // Merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val count2 = buffer2.getLong(inputCountOffset) + + // We only go to merge two buffers if there is at least one record aggregated in buffer2. + // We don't need to check count in buffer1 because if count2 is more than zero, totalCount + // is more than zero too, then we won't get a divide by zero exception. + if (count2 > 0) { + var xAvg = buffer1.getDouble(xAvgOffset) + var yAvg = buffer1.getDouble(yAvgOffset) + var Ck = buffer1.getDouble(CkOffset) + var count = buffer1.getLong(countOffset) + + val xAvg2 = buffer2.getDouble(inputXAvgOffset) + val yAvg2 = buffer2.getDouble(inputYAvgOffset) + val Ck2 = buffer2.getDouble(inputCkOffset) + + val totalCount = count + count2 + val deltaX = xAvg - xAvg2 + val deltaY = yAvg - yAvg2 + Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 + xAvg = (xAvg * count + xAvg2 * count2) / totalCount + yAvg = (yAvg * count + yAvg2 * count2) / totalCount + count = totalCount + + buffer1.setDouble(xAvgOffset, xAvg) + buffer1.setDouble(yAvgOffset, yAvg) + buffer1.setDouble(CkOffset, Ck) + buffer1.setLong(countOffset, count) + } + } +} + +case class CovSample( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends Covariance(left, right) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def eval(buffer: InternalRow): Any = { + val count = buffer.getLong(countOffset) + if (count > 1) { + val Ck = buffer.getDouble(CkOffset) + val cov = Ck / (count - 1) + if (cov.isNaN) { + null + } else { + cov + } + } else { + null + } + } +} + +case class CovPopulation( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends Covariance(left, right) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def eval(buffer: InternalRow): Any = { + val count = buffer.getLong(countOffset) + if (count > 0) { + val Ck = buffer.getDouble(CkOffset) + if (Ck.isNaN) { + null + } else { + Ck / count + } + } else { + null + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 592d79df3109a..71fea2716bd9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -308,6 +308,46 @@ object functions extends LegacyFunctions { def countDistinct(columnName: String, columnNames: String*): Column = countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) + /** + * Aggregate function: returns the population covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction { + CovPopulation(column1.expr, column2.expr) + } + + /** + * Aggregate function: returns the population covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_pop(columnName1: String, columnName2: String): Column = { + covar_pop(Column(columnName1), Column(columnName2)) + } + + /** + * Aggregate function: returns the sample covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction { + CovSample(column1.expr, column2.expr) + } + + /** + * Aggregate function: returns the sample covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_samp(columnName1: String, columnName2: String): Column = { + covar_samp(Column(columnName1), Column(columnName2)) + } + /** * Aggregate function: returns the first value in a group. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 5550198c02fbf..76b36aa89182e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) } + test("covariance: covar_pop and covar_samp") { + // non-trivial example. To reproduce in python, use: + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> np.cov(a, b, bias = 0)[0][1] + // 595.0 + // >>> np.cov(a, b, bias = 1)[0][1] + // 565.25 + val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp - 595.0) < 1e-12) + + val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop - 565.25) < 1e-12) + + val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp2 - 11564.0) < 1e-12) + + val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12) + + // one row test + val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0) + assert(cov_samp3 == null) + + val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(cov_pop3 == 0.0) + } + test("no aggregation function (SPARK-11486)") { val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() From cc91e21879e031bcd05316eabb856e67a51b191d Mon Sep 17 00:00:00 2001 From: Luc Bourlier Date: Wed, 13 Jan 2016 11:45:13 -0800 Subject: [PATCH 1553/2089] [SPARK-12805][MESOS] Fixes documentation on Mesos run modes The default run has changed, but the documentation didn't fully reflect the change. Author: Luc Bourlier Closes #10740 from skyluc/issue/mesos-modes-doc. --- docs/running-on-mesos.md | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3193e17853483..ed720f1039f94 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -202,7 +202,7 @@ where each application gets more or fewer machines as it ramps up and down, but additional overhead in launching each task. This mode may be inappropriate for low-latency requirements like interactive queries or serving web requests. -To run in coarse-grained mode, set the `spark.mesos.coarse` property to false in your +To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your [SparkConf](configuration.html#spark-properties): {% highlight scala %} @@ -266,13 +266,11 @@ See the [configuration page](configuration.html) for information on Spark config

    - + From 38148f7373ee678cd538ce5eae0a75e15c62db8a Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 13 Jan 2016 11:53:59 -0800 Subject: [PATCH 1554/2089] [SPARK-12761][CORE] Remove duplicated code Removes some duplicated code that was reintroduced during a merge. Author: Jakob Odersky Closes #10711 from jodersky/repl-2.11-duplicate. --- .../src/main/scala/org/apache/spark/repl/Main.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 44650f25f7a18..bb3081d12938e 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -30,11 +30,7 @@ object Main extends Logging { val conf = new SparkConf() val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") - val s = new Settings() - s.processArguments(List("-Yrepl-class-based", - "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", - "-classpath", getAddedJars.mkString(File.pathSeparator)), true) - // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed + var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. From 97e0c7c5af4d002937f9ee679568bb501d8818fc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 13 Jan 2016 11:56:30 -0800 Subject: [PATCH 1555/2089] [SPARK-9383][PROJECT-INFRA] PR merge script should reset back to previous branch when possible This patch modifies our PR merge script to reset back to a named branch when restoring the original checkout upon exit. When the committer is originally checked out to a detached head, then they will be restored back to that same ref (the same as today's behavior). This is a slightly updated version of #7569, with an extra fix to handle the detached head corner-case. Author: Josh Rosen Closes #10709 from JoshRosen/SPARK-9383. --- dev/merge_spark_pr.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index bf1a000f46791..5ab285eae99b7 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -355,11 +355,21 @@ def standardize_jira_ref(text): return clean_text + +def get_current_ref(): + ref = run_cmd("git rev-parse --abbrev-ref HEAD").strip() + if ref == 'HEAD': + # The current ref is a detached HEAD, so grab its SHA. + return run_cmd("git rev-parse HEAD").strip() + else: + return ref + + def main(): global original_head os.chdir(SPARK_HOME) - original_head = run_cmd("git rev-parse HEAD")[:8] + original_head = get_current_ref() branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) @@ -449,5 +459,8 @@ def main(): (failure_count, test_count) = doctest.testmod() if failure_count: exit(-1) - - main() + try: + main() + except: + clean_up() + raise From e4e0b3f7b2945aae5ec7c3d68296010bbc5160cf Mon Sep 17 00:00:00 2001 From: Erik Selin Date: Wed, 13 Jan 2016 12:21:45 -0800 Subject: [PATCH 1556/2089] [SPARK-12268][PYSPARK] Make pyspark shell pythonstartup work under python3 This replaces the `execfile` used for running custom python shell scripts with explicit open, compile and exec (as recommended by 2to3). The reason for this change is to make the pythonstartup option compatible with python3. Author: Erik Selin Closes #10255 from tyro89/pythonstartup-python3. --- python/pyspark/shell.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 99331297c19f0..26cafca8b8381 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -76,4 +76,6 @@ # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + with open(_pythonstartup) as f: + code = compile(f.read(), _pythonstartup, 'exec') + exec(code) From c2ea79f96acd076351b48162644ed1cff4c8e090 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 Jan 2016 12:29:02 -0800 Subject: [PATCH 1557/2089] [SPARK-12642][SQL] improve the hash expression to be decoupled from unsafe row https://issues.apache.org/jira/browse/SPARK-12642 Author: Wenchen Fan Closes #10694 from cloud-fan/hash-expr. --- python/pyspark/sql/functions.py | 2 +- .../sql/catalyst/expressions/UnsafeRow.java | 4 - .../spark/sql/catalyst/expressions/misc.scala | 251 +++++++++++++++++- .../expressions/MiscFunctionsSuite.scala | 6 +- .../sql/sources/BucketedWriteSuite.scala | 26 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 28 +- 6 files changed, 288 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b0390cb9942e6..719eca8f5559e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1023,7 +1023,7 @@ def hash(*cols): """Calculates the hash code of given columns, and returns the result as a int column. >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() - [Row(hash=1358996357)] + [Row(hash=-757602832)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column)) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index b8d3c49100476..1a351933a366c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -566,10 +566,6 @@ public int hashCode() { return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); } - public int hashCode(int seed) { - return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed); - } - @Override public boolean equals(Object other) { if (other instanceof UnsafeRow) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index cc406a39f0408..4751fbe4146fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -25,8 +25,11 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.Platform /** * A function that calculates an MD5 128-bit checksum and returns it as a hex string @@ -184,8 +187,31 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp * A function that calculates hash value for a group of expressions. Note that the `seed` argument * is not exposed to users and should only be set inside spark SQL. * - * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of - * the unsafe row using murmur3 hasher with a seed. + * The hash value for an expression depends on its type and seed: + * - null: seed + * - boolean: turn boolean into int, 1 for true, 0 for false, and then use murmur3 to + * hash this int with seed. + * - byte, short, int: use murmur3 to hash the input as int with seed. + * - long: use murmur3 to hash the long input with seed. + * - float: turn it into int: java.lang.Float.floatToIntBits(input), and hash it. + * - double: turn it into long: java.lang.Double.doubleToLongBits(input), and hash it. + * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into long and hash + * it. Else, turn it into bytes and hash it. + * - calendar interval: hash `microseconds` first, and use the result as seed to hash `months`. + * - binary: use murmur3 to hash the bytes with seed. + * - string: get the bytes of string and hash it. + * - array: The `result` starts with seed, then use `result` as seed, recursively + * calculate hash value for each element, and assign the element hash value + * to `result`. + * - map: The `result` starts with seed, then use `result` as seed, recursively + * calculate hash value for each key-value, and assign the key-value hash + * value to `result`. + * - struct: The `result` starts with seed, then use `result` as seed, recursively + * calculate hash value for each field, and assign the field hash value to + * `result`. + * + * Finally we aggregate the hash values for each expression by the same way of struct. + * * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle * and bucketing have same data distribution. */ @@ -206,22 +232,225 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } - private lazy val unsafeProjection = UnsafeProjection.create(children) + override def prettyName: String = "hash" + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" override def eval(input: InternalRow): Any = { - unsafeProjection(input).hashCode(seed) + var hash = seed + var i = 0 + val len = children.length + while (i < len) { + hash = computeHash(children(i).eval(input), children(i).dataType, hash) + i += 1 + } + hash } + private def computeHash(value: Any, dataType: DataType, seed: Int): Int = { + def hashInt(i: Int): Int = Murmur3_x86_32.hashInt(i, seed) + def hashLong(l: Long): Int = Murmur3_x86_32.hashLong(l, seed) + + value match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0) + case b: Byte => hashInt(b) + case s: Short => hashInt(s) + case i: Int => hashInt(i) + case l: Long => hashLong(l) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f)) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d)) + case d: Decimal => + val precision = dataType.asInstanceOf[DecimalType].precision + if (precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(d.toUnscaledLong) + } else { + val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray + Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed) + } + case c: CalendarInterval => Murmur3_x86_32.hashInt(c.months, hashLong(c.microseconds)) + case a: Array[Byte] => + Murmur3_x86_32.hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) + case s: UTF8String => + Murmur3_x86_32.hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + + case array: ArrayData => + val elementType = dataType match { + case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType + case ArrayType(et, _) => et + } + var result = seed + var i = 0 + while (i < array.numElements()) { + result = computeHash(array.get(i, elementType), elementType, result) + i += 1 + } + result + + case map: MapData => + val (kt, vt) = dataType match { + case udt: UserDefinedType[_] => + val mapType = udt.sqlType.asInstanceOf[MapType] + mapType.keyType -> mapType.valueType + case MapType(kt, vt, _) => kt -> vt + } + val keys = map.keyArray() + val values = map.valueArray() + var result = seed + var i = 0 + while (i < map.numElements()) { + result = computeHash(keys.get(i, kt), kt, result) + result = computeHash(values.get(i, vt), vt, result) + i += 1 + } + result + + case struct: InternalRow => + val types: Array[DataType] = dataType match { + case udt: UserDefinedType[_] => + udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray + case StructType(fields) => fields.map(_.dataType) + } + var result = seed + var i = 0 + val len = struct.numFields + while (i < len) { + result = computeHash(struct.get(i, types(i)), types(i), result) + i += 1 + } + result + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children) ev.isNull = "false" + val childrenHash = children.zipWithIndex.map { + case (child, dt) => + val childGen = child.gen(ctx) + val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) + s""" + ${childGen.code} + if (!${childGen.isNull}) { + ${childHash.code} + ${ev.value} = ${childHash.value}; + } + """ + }.mkString("\n") s""" - ${unsafeRow.code} - final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); + int ${ev.value} = $seed; + $childrenHash """ } - override def prettyName: String = "hash" - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" + private def computeHash( + input: String, + dataType: DataType, + seed: String, + ctx: CodeGenContext): GeneratedExpressionCode = { + val hasher = classOf[Murmur3_x86_32].getName + def hashInt(i: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashInt($i, $seed)") + def hashLong(l: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashLong($l, $seed)") + def inlineValue(v: String): GeneratedExpressionCode = + GeneratedExpressionCode(code = "", isNull = "false", value = v) + + dataType match { + case NullType => inlineValue(seed) + case BooleanType => hashInt(s"$input ? 1 : 0") + case ByteType | ShortType | IntegerType | DateType => hashInt(input) + case LongType | TimestampType => hashLong(input) + case FloatType => hashInt(s"Float.floatToIntBits($input)") + case DoubleType => hashLong(s"Double.doubleToLongBits($input)") + case d: DecimalType => + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(s"$input.toUnscaledLong()") + } else { + val bytes = ctx.freshName("bytes") + val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" + val offset = "Platform.BYTE_ARRAY_OFFSET" + val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" + GeneratedExpressionCode(code, "false", result) + } + case CalendarIntervalType => + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" + val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" + inlineValue(monthsHash) + case BinaryType => + val offset = "Platform.BYTE_ARRAY_OFFSET" + inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + case StringType => + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + + case ArrayType(et, _) => + val result = ctx.freshName("result") + val index = ctx.freshName("index") + val element = ctx.freshName("element") + val elementHash = computeHash(element, et, result, ctx) + val code = + s""" + int $result = $seed; + for (int $index = 0; $index < $input.numElements(); $index++) { + if (!$input.isNullAt($index)) { + final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; + ${elementHash.code} + $result = ${elementHash.value}; + } + } + """ + GeneratedExpressionCode(code, "false", result) + + case MapType(kt, vt, _) => + val result = ctx.freshName("result") + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val key = ctx.freshName("key") + val value = ctx.freshName("value") + val keyHash = computeHash(key, kt, result, ctx) + val valueHash = computeHash(value, vt, result, ctx) + val code = + s""" + int $result = $seed; + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; + ${keyHash.code} + $result = ${keyHash.value}; + if (!$values.isNullAt($index)) { + final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; + ${valueHash.code} + $result = ${valueHash.value}; + } + } + """ + GeneratedExpressionCode(code, "false", result) + + case StructType(fields) => + val result = ctx.freshName("result") + val fieldsHash = fields.map(_.dataType).zipWithIndex.map { + case (dt, index) => + val field = ctx.freshName("field") + val fieldHash = computeHash(field, dt, result, ctx) + s""" + if (!$input.isNullAt($index)) { + final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; + ${fieldHash.code} + $result = ${fieldHash.value}; + } + """ + }.mkString("\n") + val code = + s""" + int $result = $seed; + $fieldsHash + """ + GeneratedExpressionCode(code, "false", result) + + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 64161bebdcbe8..75131a6170222 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -79,7 +79,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { .add("long", LongType) .add("float", FloatType) .add("double", DoubleType) - .add("decimal", DecimalType.SYSTEM_DEFAULT) + .add("bigDecimal", DecimalType.SYSTEM_DEFAULT) + .add("smallDecimal", DecimalType.USER_DEFAULT) .add("string", StringType) .add("binary", BinaryType) .add("date", DateType) @@ -126,7 +127,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { case (value, dt) => Literal.create(value, dt) } - checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed)) + // Only test the interpreted version has same result with codegen version. + checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 7f1745705aaaf..b718b7cefb2a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -70,6 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } } + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private def testBucketing( dataDir: File, source: String, @@ -82,27 +85,30 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle assert(groupedBucketFiles.size <= 8) for ((bucketId, bucketFiles) <- groupedBucketFiles) { - for (bucketFile <- bucketFiles) { - val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath) - .select((bucketCols ++ sortCols).map(col): _*) + for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) { + val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType) + val columns = (bucketCols ++ sortCols).zip(types).map { + case (colName, dt) => col(colName).cast(dt) + } + val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*) if (sortCols.nonEmpty) { - checkAnswer(df.sort(sortCols.map(col): _*), df.collect()) + checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect()) } - val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect() + val qe = readBack.select(bucketCols.map(col): _*).queryExecution + val rows = qe.toRdd.map(_.copy()).collect() + val getHashCode = + UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output) for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 + val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8) assert(actualBucketId == bucketId) } } } } - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - test("write bucketed data") { for (source <- Seq("parquet", "json", "orc")) { withTable("bucketed_table") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 4276f25c2165b..5e7ee480cafd1 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -38,6 +38,10 @@ public String toString() { } public int hashInt(int input) { + return hashInt(input, seed); + } + + public static int hashInt(int input, int seed) { int k1 = mixK1(input); int h1 = mixH1(seed, k1); @@ -51,16 +55,38 @@ public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = Platform.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } - return fmix(h1, lengthInBytes); + return h1; } public int hashLong(long input) { + return hashLong(input, seed); + } + + public static int hashLong(long input, int seed) { int low = (int) input; int high = (int) (input >>> 32); From cbbcd8e4250aeec700f04c231f8be2f787243f1f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 13 Jan 2016 12:44:35 -0800 Subject: [PATCH 1558/2089] [SPARK-12791][SQL] Simplify CaseWhen by breaking "branches" into "conditions" and "values" This pull request rewrites CaseWhen expression to break the single, monolithic "branches" field into a sequence of tuples (Seq[(condition, value)]) and an explicit optional elseValue field. Prior to this pull request, each even position in "branches" represents the condition for each branch, and each odd position represents the value for each branch. The use of them have been pretty confusing with a lot sliding windows or grouped(2) calls. Author: Reynold Xin Closes #10734 from rxin/simplify-case. --- python/pyspark/sql/column.py | 24 +-- .../spark/sql/catalyst/CatalystQl.scala | 2 +- .../apache/spark/sql/catalyst/SqlParser.scala | 3 +- .../catalyst/analysis/HiveTypeCoercion.scala | 26 +++- .../expressions/conditionalExpressions.scala | 137 +++++++++--------- .../spark/sql/catalyst/trees/TreeNode.scala | 9 ++ .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 4 +- .../analysis/HiveTypeCoercionSuite.scala | 15 +- .../ConditionalExpressionSuite.scala | 51 +++---- .../scala/org/apache/spark/sql/Column.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 2 +- 12 files changed, 156 insertions(+), 138 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 900def59d23a5..320451c52c706 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -368,12 +368,12 @@ def when(self, condition, value): >>> from pyspark.sql import functions as F >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() - +-----+--------------------------------------------------------+ - | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| - +-----+--------------------------------------------------------+ - |Alice| -1| - | Bob| 1| - +-----+--------------------------------------------------------+ + +-----+------------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END| + +-----+------------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+------------------------------------------------------------+ """ if not isinstance(condition, Column): raise TypeError("condition should be a Column") @@ -393,12 +393,12 @@ def otherwise(self, value): >>> from pyspark.sql import functions as F >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() - +-----+---------------------------------+ - | name|CASE WHEN (age > 3) THEN 1 ELSE 0| - +-----+---------------------------------+ - |Alice| 0| - | Bob| 1| - +-----+---------------------------------+ + +-----+-------------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END| + +-----+-------------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+-------------------------------------+ """ v = value._jc if isinstance(value, Column) else value jc = self._jc.otherwise(v) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index c87b6c8e95436..d0fbdacf6eafd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -752,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Case statements */ case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen(branches.map(nodeToExpr)) + CaseWhen.createFromParser(branches.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => val keyExpr = nodeToExpr(branches.head) CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 6ec408a673c79..85ff4ea0c946b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -305,7 +305,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { throw new AnalysisException(s"invalid function approximate($s) $udfName") } } - | CASE ~> whenThenElse ^^ CaseWhen + | CASE ~> whenThenElse ^^ + { case branches => CaseWhen.createFromParser(branches) } | CASE ~> expression ~ whenThenElse ^^ { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 980b5d52fa8f7..2737fe32cd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -621,14 +621,24 @@ object HiveTypeCoercion { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => - val castedBranches = c.branches.grouped(2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case other => other - }.reduce(_ ++ _) - CaseWhen(castedBranches) + var changed = false + val newBranches = c.branches.map { case (condition, value) => + if (value.dataType.sameType(commonType)) { + (condition, value) + } else { + changed = true + (condition, Cast(value, commonType)) + } + } + val newElseValue = c.elseValue.map { value => + if (value.dataType.sameType(commonType)) { + value + } else { + changed = true + Cast(value, commonType) + } + } + if (changed) CaseWhen(newBranches, newElseValue) else c }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 5a1462433d583..8cc7bc1da2fc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -81,44 +81,39 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi /** * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". * When a = true, returns b; when c = true, returns d; else returns e. + * + * @param branches seq of (branch condition, branch value) + * @param elseValue optional value for the else branch */ -case class CaseWhen(branches: Seq[Expression]) extends Expression { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = branches - - @transient lazy val whenList = - branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq - - @transient lazy val thenList = - branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq +case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) + extends Expression { - val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) + def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) + def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { case Seq(dt1, dt2) => dt1.sameType(dt2) } - override def dataType: DataType = thenList.head.dataType + override def dataType: DataType = branches.head._2.dataType override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) + // Result is nullable if any of the branch is nullable, or if the else value is nullable + branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { + // Make sure all branch conditions are boolean types. if (valueTypesEqual) { - if (whenList.forall(_.dataType == BooleanType)) { + if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { - val index = whenList.indexWhere(_.dataType != BooleanType) + val index = branches.indexWhere(_._1.dataType != BooleanType) TypeCheckResult.TypeCheckFailure( s"WHEN expressions in CaseWhen should all be boolean type, " + - s"but the ${index + 1}th when expression's type is ${whenList(index)}") + s"but the ${index + 1}th when expression's type is ${branches(index)._1}") } } else { TypeCheckResult.TypeCheckFailure( @@ -127,31 +122,26 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { } override def eval(input: InternalRow): Any = { - // Written in imperative fashion for performance considerations - val len = branchesArr.length var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - return branchesArr(i + 1).eval(input) + while (i < branches.size) { + if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) { + return branches(i)._2.eval(input) } - i += 2 + i += 1 } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) + if (elseValue.isDefined) { + return elseValue.get.eval(input) + } else { + return null } - return res } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val len = branchesArr.length val got = ctx.freshName("got") - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) + val cases = branches.map { case (condition, value) => + val cond = condition.gen(ctx) + val res = value.gen(ctx) s""" if (!$got) { ${cond.code} @@ -165,17 +155,19 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { """ }.mkString("\n") - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" + val elseCase = { + if (elseValue.isDefined) { + val res = elseValue.get.gen(ctx) + s""" if (!$got) { ${res.code} ${ev.isNull} = ${res.isNull}; ${ev.value} = ${res.value}; } - """ - } else { - "" + """ + } else { + "" + } } s""" @@ -183,32 +175,42 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $cases - $other + $elseCase """ } override def toString: String = { - "CASE" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString + val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString + val elseCase = elseValue.map(" ELSE " + _).getOrElse("") + "CASE" + cases + elseCase + " END" } override def sql: String = { - val branchesSQL = branches.map(_.sql) - val (cases, maybeElse) = if (branches.length % 2 == 0) { - (branchesSQL, None) - } else { - (branchesSQL.init, Some(branchesSQL.last)) - } + val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString + val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") + "CASE" + cases + elseCase + " END" + } +} - val head = s"CASE " - val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" - val body = cases.grouped(2).map { - case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" - }.mkString(" ") +/** Factory methods for CaseWhen. */ +object CaseWhen { - head + body + tail + def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { + CaseWhen(branches, Option(elseValue)) + } + + /** + * A factory method to faciliate the creation of this expression when used in parsers. + * @param branches Expressions at even position are the branch conditions, and expressions at odd + * position are branch values. + */ + def createFromParser(branches: Seq[Expression]): CaseWhen = { + val cases = branches.grouped(2).flatMap { + case cond :: value :: Nil => Some((cond, value)) + case value :: Nil => None + }.toArray.toSeq // force materialization to make the seq serializable + val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + CaseWhen(cases, elseValue) } } @@ -218,17 +220,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { */ object CaseKeyWhen { def apply(key: Expression, branches: Seq[Expression]): CaseWhen = { - val newBranches = branches.zipWithIndex.map { case (expr, i) => - if (i % 2 == 0 && i != branches.size - 1) { - // If this expression is at even position, then it is either a branch condition, or - // the very last value that is the "else value". The "i != branches.size - 1" makes - // sure we are not adding an EqualTo to the "else value". - EqualTo(key, expr) - } else { - expr - } - } - CaseWhen(newBranches) + val cases = branches.grouped(2).flatMap { + case cond :: value :: Nil => Some((EqualTo(key, cond), value)) + case value :: Nil => None + }.toArray.toSeq // force materialization to make the seq serializable + val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index d4be545a35ab2..d0b29aa01f640 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -315,6 +315,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { arg } + case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule) + val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule) + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } case other => other } case nonChild: AnyRef => nonChild diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cf84855885a37..975cd87d090e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -239,7 +239,7 @@ class AnalysisSuite extends AnalysisTest { test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) - val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) + val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) assertAnalysisSuccess(plan) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 0521ed848c793..59549e3998e7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -132,13 +132,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) assertError( - CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)), + CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( - CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))), "WHEN expressions in CaseWhen should all be boolean type") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 40378c6727667..b1f6c0b802d8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -308,15 +308,14 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) ruleTest(HiveTypeCoercion.CaseWhenCoercion, - CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))), - CaseWhen(Seq( - Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))) + CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), + CaseWhen(Seq((Literal(true), Literal(1.2))), + Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) ruleTest(HiveTypeCoercion.CaseWhenCoercion, - CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))), - CaseWhen(Seq( - Literal(true), Cast(Literal(100L), DecimalType(22, 2)), - Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))) + CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), + CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), + Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) ) } @@ -452,7 +451,7 @@ class HiveTypeCoercionSuite extends PlanTest { val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), DecimalType(25, 5), DoubleType, DoubleType) - rightTypes.zip(expectedTypes).map { case (rType, expectedType) => + rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) => val plan2 = LocalRelation( AttributeReference("r", rType)()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 4029da5925580..3c581ecdaf068 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -80,38 +80,39 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val c5 = 'a.string.at(4) val c6 = 'a.string.at(5) - checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) - - checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((c3, c4)), c6), "a", row) + checkEvaluation(CaseWhen(Seq((Literal.create(null, BooleanType), c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((Literal.create(false, BooleanType), c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((Literal.create(true, BooleanType), c4)), c6), "a", row) + + checkEvaluation(CaseWhen(Seq((c3, c4), (c2, c5)), c6), "a", row) + checkEvaluation(CaseWhen(Seq((c2, c4), (c3, c5)), c6), "b", row) + checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5))), null, row) + + assert(CaseWhen(Seq((c2, c4)), c6).nullable === true) + assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable === true) + assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable === true) val c4_notNull = 'a.boolean.notNull.at(3) val c5_notNull = 'a.boolean.notNull.at(4) val c6_notNull = 'a.boolean.notNull.at(5) - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false) + assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull))).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull)), c6).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6_notNull).nullable === false) + assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull)), c6_notNull).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5)), c6_notNull).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull))).nullable === true) + assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull))).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } test("case key when") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e8c61d6e01dc3..6a020f9f2883e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.4.0 */ def when(condition: Column, value: Any): Column = this.expr match { - case CaseWhen(branches: Seq[Expression]) => - withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) } + case CaseWhen(branches, None) => + withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) } + case CaseWhen(branches, Some(_)) => + throw new IllegalArgumentException( + "when() cannot be applied once otherwise() is applied") case _ => throw new IllegalArgumentException( "when() can only be applied on a Column previously generated by when() function") @@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.4.0 */ def otherwise(value: Any): Column = this.expr match { - case CaseWhen(branches: Seq[Expression]) => - if (branches.size % 2 == 0) { - withExpr { CaseWhen(branches :+ lit(value).expr) } - } else { - throw new IllegalArgumentException( - "otherwise() can only be applied once on a Column previously generated by when()") - } + case CaseWhen(branches, None) => + withExpr { CaseWhen(branches, Option(lit(value).expr)) } + case CaseWhen(branches, Some(_)) => + throw new IllegalArgumentException( + "otherwise() can only be applied once on a Column previously generated by when()") case _ => throw new IllegalArgumentException( "otherwise() can only be applied on a Column previously generated by when()") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 71fea2716bd9f..b8ea2261e94e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def when(condition: Column, value: Any): Column = withExpr { - CaseWhen(Seq(condition.expr, lit(value).expr)) + CaseWhen(Seq((condition.expr, lit(value).expr))) } /** From eabc7b8ee7e809bab05361ed154f87bff467bd88 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Wed, 13 Jan 2016 13:28:39 -0800 Subject: [PATCH 1559/2089] [SPARK-12690][CORE] Fix NPE in UnsafeInMemorySorter.free() I hit the exception below. The `UnsafeKVExternalSorter` does pass `null` as the consumer when creating an `UnsafeInMemorySorter`. Normally the NPE doesn't occur because the `inMemSorter` is set to null later and the `free()` method is not called. It happens when there is another exception like OOM thrown before setting `inMemSorter` to null. Anyway, we can add the null check to avoid it. ``` ERROR spark.TaskContextImpl: Error in TaskCompletionListener java.lang.NullPointerException at org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.free(UnsafeInMemorySorter.java:110) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.cleanupResources(UnsafeExternalSorter.java:288) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter$1.onTaskCompletion(UnsafeExternalSorter.java:141) at org.apache.spark.TaskContextImpl$$anonfun$markTaskCompleted$1.apply(TaskContextImpl.scala:79) at org.apache.spark.TaskContextImpl$$anonfun$markTaskCompleted$1.apply(TaskContextImpl.scala:77) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at org.apache.spark.TaskContextImpl.markTaskCompleted(TaskContextImpl.scala:77) at org.apache.spark.scheduler.Task.run(Task.scala:91) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1110) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:603) at java.lang.Thread.run(Thread.java:722) ``` Author: Carson Wang Closes #10637 from carsonwang/FixNPE. --- .../util/collection/unsafe/sort/UnsafeInMemorySorter.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f71b8d154cc24..d1b0bc5d11f46 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -116,8 +116,10 @@ public UnsafeInMemorySorter( * Free the memory used by pointer array. */ public void free() { - consumer.freeArray(array); - array = null; + if (consumer != null) { + consumer.freeArray(array); + array = null; + } } public void reset() { From cd81fc9e8652c07b84f0887a24d67381b4e605fa Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 13 Jan 2016 16:34:23 -0800 Subject: [PATCH 1560/2089] [SPARK-12400][SHUFFLE] Avoid generating temp shuffle files for empty partitions This problem lies in `BypassMergeSortShuffleWriter`, empty partition will also generate a temp shuffle file with several bytes. So here change to only create file when partition is not empty. This problem only lies in here, no such issue in `HashShuffleWriter`. Please help to review, thanks a lot. Author: jerryshao Closes #10376 from jerryshao/SPARK-12400. --- .../sort/BypassMergeSortShuffleWriter.java | 25 ++++++------ .../BypassMergeSortShuffleWriterSuite.scala | 38 ++++++++++++++++++- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index a1a1fb01426a0..56cdc22f36261 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -138,7 +138,7 @@ public void write(Iterator> records) throws IOException { final File file = tempShuffleBlockIdPlusFile._2(); final BlockId blockId = tempShuffleBlockIdPlusFile._1(); partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be @@ -185,16 +185,19 @@ private long[] writePartitionedFile(File outputFile) throws IOException { boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { - final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); - boolean copyThrewException = true; - try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - if (!partitionWriters[i].fileSegment().file().delete()) { - logger.error("Unable to delete file for partition {}", i); + final File file = partitionWriters[i].fileSegment().file(); + if (file.exists()) { + final FileInputStream in = new FileInputStream(file); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } } threwException = false; diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index e33408b94e2cf..ef6ce04e3ff28 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -105,7 +105,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte new Answer[(TempShuffleBlockId, File)] { override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = File.createTempFile(blockId.toString, null, tempDir) + val file = new File(tempDir, blockId.name) blockIdToFileMap.put(blockId, file) temporaryFilesCreated.append(file) (blockId, file) @@ -166,6 +166,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ true) assert(temporaryFilesCreated.nonEmpty) assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.filter(_ == 0L).size === 4) // should be 4 zero length files assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) @@ -174,6 +175,41 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } + test("only generate temp shuffle file for non-empty partition") { + // Using exception to test whether only non-empty partition creates temp shuffle file, + // because temp shuffle file will only be cleaned after calling stop(false) in the failure + // case, so we could use it to validate the temp shuffle files. + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ + (0 until 100000).iterator.map { i => + if (i == 99990) { + throw new SparkException("intentional failure") + } else { + (2, 2) + } + } + + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf + ) + + intercept[SparkException] { + writer.write(records) + } + + assert(temporaryFilesCreated.nonEmpty) + // Only 3 temp shuffle files will be created + assert(temporaryFilesCreated.count(_.exists()) === 3) + + writer.stop( /* success = */ false) + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + } + test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, From 021dafc6a05a31dc22c9f9110dedb47a1f913087 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 13 Jan 2016 17:43:27 -0800 Subject: [PATCH 1561/2089] [SPARK-12026][MLLIB] ChiSqTest gets slower and slower over time when number of features is large jira: https://issues.apache.org/jira/browse/SPARK-12026 The issue is valid as features.toArray.view.zipWithIndex.slice(startCol, endCol) becomes slower as startCol gets larger. I tested on local and the change can improve the performance and the running time was stable. Author: Yuhao Yang Closes #10146 from hhbyyh/chiSq. --- .../scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index f22f2df320f0d..4a3fb06469818 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -109,7 +109,9 @@ private[stat] object ChiSqTest extends Logging { } i += 1 distinctLabels += label - features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) => + val brzFeatures = features.toBreeze + (startCol until endCol).map { col => + val feature = brzFeatures(col) allDistinctFeatures(col) += feature (col, feature, label) } @@ -122,7 +124,7 @@ private[stat] object ChiSqTest extends Logging { pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap } val numLabels = labels.size - pairCounts.keys.groupBy(_._1).map { case (col, keys) => + pairCounts.keys.groupBy(_._1).foreach { case (col, keys) => val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap val numRows = features.size val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) From 20d8ef858af6e13db59df118b562ea33cba5464d Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 13 Jan 2016 18:01:29 -0800 Subject: [PATCH 1562/2089] [SPARK-12703][MLLIB][DOC][PYTHON] Fixed pyspark.mllib.clustering.KMeans user guide example Fixed WSSSE computeCost in Python mllib KMeans user guide example by using new computeCost method API in Python. Author: Joseph K. Bradley Closes #10707 from jkbradley/kmeans-doc-fix. --- docs/mllib-clustering.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 93cd0c1c61ae9..d0be03286849a 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -152,11 +152,7 @@ clusters = KMeans.train(parsedData, 2, maxIterations=10, runs=10, initializationMode="random") # Evaluate clustering by computing Within Set Sum of Squared Errors -def error(point): - center = clusters.centers[clusters.predict(point)] - return sqrt(sum([x**2 for x in (point - center)])) - -WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) +WSSSE = clusters.computeCost(parsedData) print("Within Set Sum of Squared Error = " + str(WSSSE)) # Save and load model From e2ae7bd046f6d8d6a375c2e81e5a51d7d78ca984 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 13 Jan 2016 21:02:54 -0800 Subject: [PATCH 1563/2089] [SPARK-12819] Deprecate TaskContext.isRunningLocally() We've already removed local execution but didn't deprecate `TaskContext.isRunningLocally()`; we should deprecate it for 2.0. Author: Josh Rosen Closes #10751 from JoshRosen/remove-local-exec-from-taskcontext. --- core/src/main/scala/org/apache/spark/CacheManager.scala | 5 ----- core/src/main/scala/org/apache/spark/TaskContext.scala | 3 ++- .../main/scala/org/apache/spark/TaskContextImpl.scala | 3 +-- .../src/main/scala/org/apache/spark/scheduler/Task.scala | 3 +-- .../test/scala/org/apache/spark/CacheManagerSuite.scala | 9 --------- 5 files changed, 4 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 4d20c7369376e..36b536e89c3a4 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -68,11 +68,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { logInfo(s"Partition $key not found, computing it") val computedValues = rdd.computeOrReadCheckpoint(partition, context) - // If the task is running locally, do not persist the result - if (context.isRunningLocally) { - return computedValues - } - // Otherwise, cache the values and keep track of any updates in block statuses val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index e25ed0fdd7fd2..7704abc134096 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -97,8 +97,9 @@ abstract class TaskContext extends Serializable { /** * Returns true if the task is running locally in the driver program. - * @return + * @return false */ + @deprecated("Local execution was removed, so this always returns false", "2.0.0") def isRunningLocally(): Boolean /** diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 6c493630997eb..94ff884b742b8 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -33,7 +33,6 @@ private[spark] class TaskContextImpl( override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, internalAccumulators: Seq[Accumulator[Long]], - val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext with Logging { @@ -85,7 +84,7 @@ private[spark] class TaskContextImpl( override def isCompleted(): Boolean = completed - override def isRunningLocally(): Boolean = runningLocally + override def isRunningLocally(): Boolean = false override def isInterrupted(): Boolean = interrupted diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 0379ca2af6ab3..fca57928eca1b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -74,8 +74,7 @@ private[spark] abstract class Task[T]( attemptNumber, taskMemoryManager, metricsSystem, - internalAccumulators, - runningLocally = false) + internalAccumulators) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index cb8bd04e496a7..30aa94c8a5971 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -82,15 +82,6 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before assert(value.toList === List(5, 6, 7)) } - test("get uncached local rdd") { - // Local computation should not persist the resulting value, so don't expect a put(). - when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - - val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } - test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager val context = TaskContext.empty() From 962e9bcf94da6f5134983f2bf1e56c5cd84f2bf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 Jan 2016 22:43:28 -0800 Subject: [PATCH 1564/2089] [SPARK-12756][SQL] use hash expression in Exchange This PR makes bucketing and exchange share one common hash algorithm, so that we can guarantee the data distribution is same between shuffle and bucketed data source, which enables us to only shuffle one side when join a bucketed table and a normal one. This PR also fixes the tests that are broken by the new hash behaviour in shuffle. Author: Wenchen Fan Closes #10703 from cloud-fan/use-hash-expr-in-shuffle. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 26 +++++++-------- python/pyspark/sql/group.py | 6 ++-- .../plans/physical/partitioning.scala | 7 +++- .../apache/spark/sql/execution/Exchange.scala | 12 +++++-- .../datasources/WriterContainer.scala | 20 +++++------ .../apache/spark/sql/JavaDataFrameSuite.java | 4 +-- .../apache/spark/sql/JavaDatasetSuite.java | 33 +++++++++++-------- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++----- .../org/apache/spark/sql/DatasetSuite.scala | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../sql/sources/BucketedWriteSuite.scala | 11 ++++--- 12 files changed, 84 insertions(+), 64 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 97625b94a0e23..40d5066a93f4c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1173,7 +1173,7 @@ test_that("group by, agg functions", { expect_equal(3, count(mean(gd))) expect_equal(3, count(max(gd))) - expect_equal(30, collect(max(gd))[1, 2]) + expect_equal(30, collect(max(gd))[2, 2]) expect_equal(1, collect(count(gd))[1, 2]) mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a7bc288e38861..90a6b5d9c0dda 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -403,10 +403,10 @@ def repartition(self, numPartitions, *cols): +---+-----+ |age| name| +---+-----+ - | 2|Alice| - | 2|Alice| | 5| Bob| | 5| Bob| + | 2|Alice| + | 2|Alice| +---+-----+ >>> data = data.repartition(7, "age") >>> data.show() @@ -552,7 +552,7 @@ def alias(self, alias): >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() - [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)] """ assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @@ -573,14 +573,14 @@ def join(self, other, on=None, how=None): One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() - [Row(name=u'Tom', height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() - [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] @@ -880,9 +880,9 @@ def groupBy(self, *cols): >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] - >>> df.groupBy('name').agg({'age': 'mean'}).collect() + >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] - >>> df.groupBy(df.name).avg().collect() + >>> sorted(df.groupBy(df.name).avg().collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] @@ -901,11 +901,11 @@ def rollup(self, *cols): +-----+----+-----+ | name| age|count| +-----+----+-----+ - |Alice|null| 1| + |Alice| 2| 1| | Bob| 5| 1| | Bob|null| 1| | null|null| 2| - |Alice| 2| 1| + |Alice|null| 1| +-----+----+-----+ """ jgd = self._jdf.rollup(self._jcols(*cols)) @@ -923,12 +923,12 @@ def cube(self, *cols): | name| age|count| +-----+----+-----+ | null| 2| 1| - |Alice|null| 1| + |Alice| 2| 1| | Bob| 5| 1| - | Bob|null| 1| | null| 5| 1| + | Bob|null| 1| | null|null| 2| - |Alice| 2| 1| + |Alice|null| 1| +-----+----+-----+ """ jgd = self._jdf.cube(self._jcols(*cols)) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 9ca303a974cd4..ee734cb439287 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -74,11 +74,11 @@ def agg(self, *exprs): or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() + >>> sorted(gdf.agg({"*": "count"}).collect()) [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() + >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" @@ -96,7 +96,7 @@ def agg(self, *exprs): def count(self): """Counts the number of records for each group. - >>> df.groupBy(df.age).count().collect() + >>> sorted(df.groupBy(df.age).count().collect()) [Row(age=2, count=1), Row(age=5, count=1)] """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1bfe0ecb1e20b..d6e10c412ca1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, Unevaluable} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -249,6 +249,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + /** + * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less + * than numPartitions) based on hashing expressions. + */ + def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 058d147c7d65d..3770883af1e2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -143,7 +143,13 @@ case class Exchange( val rdd = child.execute() val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For HashPartitioning, the partitioning key is already a valid partition ID, as we use + // `HashPartitioning.partitionIdExpression` to produce partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. @@ -173,7 +179,9 @@ case class Exchange( position += 1 position } - case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case h: HashPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output) + row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index fff72872c13b1..fc77529b7db32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.JavaConverters._ - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} @@ -30,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} @@ -322,9 +321,12 @@ private[sql] class DynamicPartitionWriterContainer( spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) } - private def bucketIdExpression: Option[Expression] = for { - BucketSpec(numBuckets, _, _) <- bucketSpec - } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) + private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } // Expressions that given a partition key build a string like: col1=val/col2=val/... private def partitionStringExpression: Seq[Expression] = { @@ -341,12 +343,8 @@ private[sql] class DynamicPartitionWriterContainer( } } - private def getBucketIdFromKey(key: InternalRow): Option[Int] = { - if (bucketSpec.isDefined) { - Some(key.getInt(partitionColumns.length)) - } else { - None - } + private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ => + key.getInt(partitionColumns.length) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 8e0b2dbca4a98..ac1607ba3521a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -237,8 +237,8 @@ public void testCrosstab() { DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); - Assert.assertEquals("1", columnNames[1]); - Assert.assertEquals("2", columnNames[2]); + Assert.assertEquals("2", columnNames[1]); + Assert.assertEquals("1", columnNames[2]); Row[] rows = crosstab.collect(); Arrays.sort(rows, crosstabRowComparator); Integer count = 1; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 9f8db39e33d7e..1a3df1b117b68 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -187,7 +187,7 @@ public String call(Integer key, Iterator values) throws Exception { } }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); Dataset flatMapped = grouped.flatMapGroups( new FlatMapGroupsFunction() { @@ -202,7 +202,7 @@ public Iterable call(Integer key, Iterator values) throws Except }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); Dataset> reduced = grouped.reduce(new ReduceFunction() { @Override @@ -212,8 +212,8 @@ public String call(String v1, String v2) throws Exception { }); Assert.assertEquals( - Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")), - reduced.collectAsList()); + asSet(tuple2(1, "a"), tuple2(3, "foobar")), + toSet(reduced.collectAsList())); List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, Encoders.INT()); @@ -245,7 +245,7 @@ public Iterable call( }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); + Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList())); } @Test @@ -268,7 +268,7 @@ public String call(Integer key, Iterator data) throws Exception { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); } @Test @@ -290,9 +290,7 @@ public void testSetOperation() { List data = Arrays.asList("abc", "abc", "xyz"); Dataset ds = context.createDataset(data, Encoders.STRING()); - Assert.assertEquals( - Arrays.asList("abc", "xyz"), - sort(ds.distinct().collectAsList().toArray(new String[0]))); + Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List data2 = Arrays.asList("xyz", "foo", "foo"); Dataset ds2 = context.createDataset(data2, Encoders.STRING()); @@ -302,16 +300,23 @@ public void testSetOperation() { Dataset unioned = ds.union(ds2); Assert.assertEquals( - Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), - sort(unioned.collectAsList().toArray(new String[0]))); + Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"), + unioned.collectAsList()); Dataset subtracted = ds.subtract(ds2); Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); } - private > List sort(T[] data) { - Arrays.sort(data); - return Arrays.asList(data); + private Set toSet(List records) { + Set set = new HashSet(); + for (T record : records) { + set.add(record); + } + return set; + } + + private Set asSet(T... records) { + return toSet(Arrays.asList(records)); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 983dfbdedeefe..d6c140dfea9ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1083,17 +1083,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // Walk each partition and verify that it is sorted descending and does not contain all // the values. df4.rdd.foreachPartition { p => - var previousValue: Int = -1 - var allSequential: Boolean = true - p.foreach { r => - val v: Int = r.getInt(1) - if (previousValue != -1) { - if (previousValue < v) throw new SparkException("Partition is not ordered.") - if (v + 1 != previousValue) allSequential = false + // Skip empty partition + if (p.hasNext) { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue < v) throw new SparkException("Partition is not ordered.") + if (v + 1 != previousValue) allSequential = false + } + previousValue = v } - previousValue = v + if (allSequential) throw new SparkException("Partition should not be globally ordered") } - if (allSequential) throw new SparkException("Partition should not be globally ordered") } // Distribute and order by with multiple order bys diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 693f5aea2d015..d7b86e381108e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -456,8 +456,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() - assert(ds.groupBy(p => p).count().collect().toSeq == - Seq((KryoData(1), 1L), (KryoData(2), 1L))) + assert(ds.groupBy(p => p).count().collect().toSet == + Set((KryoData(1), 1L), (KryoData(2), 1L))) } test("Kryo encoder self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5de0979606b88..03d67c4e91f7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -806,7 +806,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") - sql("SELECT DISTINCT n FROM lowerCaseData") + sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC") .limit(2) .registerTempTable("subset2") checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index b718b7cefb2a4..3ea9826544edb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -98,11 +98,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val qe = readBack.select(bucketCols.map(col): _*).queryExecution val rows = qe.toRdd.map(_.copy()).collect() - val getHashCode = - UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output) + val getHashCode = UnsafeProjection.create( + HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil, + qe.analyzed.output) for (row <- rows) { - val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8) + val actualBucketId = getHashCode(row).getInt(0) assert(actualBucketId == bucketId) } } From 8f13cd4cc8dcf638b178774418669a2e247d0652 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 13 Jan 2016 23:50:08 -0800 Subject: [PATCH 1565/2089] =?UTF-8?q?[SPARK-12707][SPARK=20SUBMIT]=20Remov?= =?UTF-8?q?e=20submit=20python/R=20scripts=20through=20py=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …spark/sparkR Author: Jeff Zhang Closes #10658 from zjffdu/SPARK-12707. --- .../spark/launcher/SparkSubmitCommandBuilder.java | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index a95f0f17517d1..269c89c310550 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -231,11 +231,9 @@ private List buildPySparkShellCommand(Map env) throws IO // the pyspark command line, then run it using spark-submit. if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".py")) { System.err.println( - "WARNING: Running python applications through 'pyspark' is deprecated as of Spark 1.0.\n" + + "Running python applications through 'pyspark' is not supported as of Spark 2.0.\n" + "Use ./bin/spark-submit "); - appResource = appArgs.get(0); - appArgs.remove(0); - return buildCommand(env); + System.exit(-1); } checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); @@ -258,9 +256,10 @@ private List buildPySparkShellCommand(Map env) throws IO private List buildSparkRCommand(Map env) throws IOException { if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { - appResource = appArgs.get(0); - appArgs.remove(0); - return buildCommand(env); + System.err.println( + "Running R applications through 'sparkR' is not supported as of Spark 2.0.\n" + + "Use ./bin/spark-submit "); + System.exit(-1); } // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS // env variable. From 56cdbd654d54bf07a063a03a5c34c4165818eeb2 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Jan 2016 10:59:02 +0000 Subject: [PATCH 1566/2089] [SPARK-9844][CORE] File appender race condition during shutdown When an Executor process is destroyed, the FileAppender that is asynchronously reading the stderr stream of the process can throw an IOException during read because the stream is closed. Before the ExecutorRunner destroys the process, the FileAppender thread is flagged to stop. This PR wraps the inputStream.read call of the FileAppender in a try/catch block so that if an IOException is thrown and the thread has been flagged to stop, it will safely ignore the exception. Additionally, the FileAppender thread was changed to use Utils.tryWithSafeFinally to better log any exception that do occur. Added unit tests to verify a IOException is thrown and logged if FileAppender is not flagged to stop, and that no IOException when the flag is set. Author: Bryan Cutler Closes #10714 from BryanCutler/file-appender-read-ioexception-SPARK-9844. --- .../spark/util/logging/FileAppender.scala | 28 ++++--- .../apache/spark/util/FileAppenderSuite.scala | 77 +++++++++++++++++++ 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 58c8560a3d049..86bbaa20f6cf2 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.logging -import java.io.{File, FileOutputStream, InputStream} +import java.io.{File, FileOutputStream, InputStream, IOException} import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.{IntParam, Utils} @@ -58,20 +58,28 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi protected def appendStreamToFile() { try { logDebug("Started appending thread") - openFile() - val buf = new Array[Byte](bufferSize) - var n = 0 - while (!markedForStop && n != -1) { - n = inputStream.read(buf) - if (n != -1) { - appendToFile(buf, n) + Utils.tryWithSafeFinally { + openFile() + val buf = new Array[Byte](bufferSize) + var n = 0 + while (!markedForStop && n != -1) { + try { + n = inputStream.read(buf) + } catch { + // An InputStream can throw IOException during read if the stream is closed + // asynchronously, so once appender has been flagged to stop these will be ignored + case _: IOException if markedForStop => // do nothing and proceed to stop appending + } + if (n > 0) { + appendToFile(buf, n) + } } + } { + closeFile() } } catch { case e: Exception => logError(s"Error writing stream to file $file", e) - } finally { - closeFile() } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 98d1b28d5a167..b367cc8358342 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.util import java.io._ +import java.util.concurrent.CountDownLatch import scala.collection.mutable.HashSet import scala.reflect._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.log4j.{Appender, Level, Logger} +import org.apache.log4j.spi.LoggingEvent +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.{atLeast, mock, verify} import org.scalatest.BeforeAndAfter import org.apache.spark.{Logging, SparkConf, SparkFunSuite} @@ -188,6 +193,67 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { testAppenderSelection[FileAppender, Any](rollingStrategy("xyz")) } + test("file appender async close stream abruptly") { + // Test FileAppender reaction to closing InputStream using a mock logging appender + val mockAppender = mock(classOf[Appender]) + val loggingEventCaptor = new ArgumentCaptor[LoggingEvent] + + // Make sure only logging errors + val logger = Logger.getRootLogger + logger.setLevel(Level.ERROR) + logger.addAppender(mockAppender) + + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) + + // Close the stream before appender tries to read will cause an IOException + testInputStream.close() + testOutputStream.close() + val appender = FileAppender(testInputStream, testFile, new SparkConf) + + appender.awaitTermination() + + // If InputStream was closed without first stopping the appender, an exception will be logged + verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture) + val loggingEvent = loggingEventCaptor.getValue + assert(loggingEvent.getThrowableInformation !== null) + assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + } + + test("file appender async close stream gracefully") { + // Test FileAppender reaction to closing InputStream using a mock logging appender + val mockAppender = mock(classOf[Appender]) + val loggingEventCaptor = new ArgumentCaptor[LoggingEvent] + + // Make sure only logging errors + val logger = Logger.getRootLogger + logger.setLevel(Level.ERROR) + logger.addAppender(mockAppender) + + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream + + // Close the stream before appender tries to read will cause an IOException + testInputStream.close() + testOutputStream.close() + val appender = FileAppender(testInputStream, testFile, new SparkConf) + + // Stop the appender before an IOException is called during read + testInputStream.latchReadStarted.await() + appender.stop() + testInputStream.latchReadProceed.countDown() + + appender.awaitTermination() + + // Make sure no IOException errors have been logged as a result of appender closing gracefully + verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture) + import scala.collection.JavaConverters._ + loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent => + assert(loggingEvent.getThrowableInformation === null + || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + } + } + /** * Run the rolling file appender with data and see whether all the data was written correctly * across rolled over files. @@ -228,4 +294,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { file.getName.startsWith(testFile.getName) }.foreach { _.delete() } } + + /** Used to synchronize when read is called on a stream */ + private trait LatchedInputStream extends PipedInputStream { + val latchReadStarted = new CountDownLatch(1) + val latchReadProceed = new CountDownLatch(1) + abstract override def read(): Int = { + latchReadStarted.countDown() + latchReadProceed.await() + super.read() + } + } } From 501e99ef0fbd2f2165095548fe67a3447ccbfc91 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 14 Jan 2016 09:50:57 -0800 Subject: [PATCH 1567/2089] [SPARK-12784][UI] Fix Spark UI IndexOutOfBoundsException with dynamic allocation Add `listener.synchronized` to get `storageStatusList` and `execInfo` atomically. Author: Shixiong Zhu Closes #10728 from zsxwing/SPARK-12784. --- .../spark/status/api/v1/ExecutorListResource.scala | 10 +++++++--- .../org/apache/spark/ui/exec/ExecutorsPage.scala | 13 ++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index 8ad4656b4dada..3bdba922328c2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -28,9 +28,13 @@ private[v1] class ExecutorListResource(ui: SparkUI) { @GET def executorList(): Seq[ExecutorSummary] = { val listener = ui.executorsListener - val storageStatusList = listener.storageStatusList - (0 until storageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId) + listener.synchronized { + // The follow codes should be protected by `listener` to make sure no executors will be + // removed before we query their status. See SPARK-12784. + val storageStatusList = listener.storageStatusList + (0 until storageStatusList.size).map { statusId => + ExecutorsPage.getExecInfo(listener, statusId) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 1a29b0f412603..7072a152d6b69 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -52,12 +52,19 @@ private[ui] class ExecutorsPage( private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val storageStatusList = listener.storageStatusList + val (storageStatusList, execInfo) = listener.synchronized { + // The follow codes should be protected by `listener` to make sure no executors will be + // removed before we query their status. See SPARK-12784. + val _storageStatusList = listener.storageStatusList + val _execInfo = { + for (statusId <- 0 until _storageStatusList.size) + yield ExecutorsPage.getExecInfo(listener, statusId) + } + (_storageStatusList, _execInfo) + } val maxMem = storageStatusList.map(_.maxMem).sum val memUsed = storageStatusList.map(_.memUsed).sum val diskUsed = storageStatusList.map(_.diskUsed).sum - val execInfo = for (statusId <- 0 until storageStatusList.size) yield - ExecutorsPage.getExecInfo(listener, statusId) val execInfoSorted = execInfo.sortBy(_.id) val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty From 902667fd2766f0472a15851b1ed8fb5859593f97 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Jan 2016 10:09:03 -0800 Subject: [PATCH 1568/2089] [SPARK-12771][SQL] Simplify CaseWhen code generation The generated code for CaseWhen uses a control variable "got" to make sure we do not evaluate more branches once a branch is true. Changing that to generate just simple "if / else" would be slightly more efficient. This closes #10737. Author: Reynold Xin Closes #10755 from rxin/SPARK-12771. --- .../expressions/conditionalExpressions.scala | 60 +++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 8cc7bc1da2fc3..83abbcdc61175 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -137,45 +137,55 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val got = ctx.freshName("got") - - val cases = branches.map { case (condition, value) => - val cond = condition.gen(ctx) - val res = value.gen(ctx) + // Generate code that looks like: + // + // condA = ... + // if (condA) { + // valueA + // } else { + // condB = ... + // if (condB) { + // valueB + // } else { + // condC = ... + // if (condC) { + // valueC + // } else { + // elseValue + // } + // } + // } + val cases = branches.map { case (condExpr, valueExpr) => + val cond = condExpr.gen(ctx) + val res = valueExpr.gen(ctx) s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${cond.value}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } + ${cond.code} + if (!${cond.isNull} && ${cond.value}) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.value} = ${res.value}; } """ - }.mkString("\n") + } - val elseCase = { - if (elseValue.isDefined) { - val res = elseValue.get.gen(ctx) + var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") + + elseValue.foreach { elseExpr => + val res = elseExpr.gen(ctx) + generatedCode += s""" - if (!$got) { ${res.code} ${ev.isNull} = ${res.isNull}; ${ev.value} = ${res.value}; - } """ - } else { - "" - } } + generatedCode += "}\n" * cases.size + s""" - boolean $got = false; boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $cases - $elseCase + $generatedCode """ } From bcc7373f673d1a51b48fb95432ba5c4644dd5d23 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 14 Jan 2016 10:43:39 -0800 Subject: [PATCH 1569/2089] [SPARK-12821][BUILD] Style checker should run when some configuration files for style are modified but any source files are not. When running the `run-tests` script, style checkers run only when any source files are modified but they should run when configuration files related to style are modified. Author: Kousuke Saruta Closes #10754 from sarutak/SPARK-12821. --- dev/run-tests.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 8726889cbc777..795db0dcfbab9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -529,9 +529,14 @@ def main(): run_apache_rat_checks() # style checks - if not changed_files or any(f.endswith(".scala") for f in changed_files): + if not changed_files or any(f.endswith(".scala") + or f.endswith("scalastyle-config.xml") + for f in changed_files): run_scala_style_checks() - if not changed_files or any(f.endswith(".java") for f in changed_files): + if not changed_files or any(f.endswith(".java") + or f.endswith("checkstyle.xml") + or f.endswith("checkstyle-suppressions.xml") + for f in changed_files): # run_java_style_checks() pass if not changed_files or any(f.endswith(".py") for f in changed_files): From 25782981cf58946dc7c186acadd2beec5d964461 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 14 Jan 2016 17:37:27 -0800 Subject: [PATCH 1570/2089] [SPARK-12174] Speed up BlockManagerSuite getRemoteBytes() test This patch significantly speeds up the BlockManagerSuite's "SPARK-9591: getRemoteBytes from another location when Exception throw" test, reducing the test time from 45s to ~250ms. The key change was to set `spark.shuffle.io.maxRetries` to 0 (the code previously set `spark.network.timeout` to `2s`, but this didn't make a difference because the slowdown was not due to this timeout). Along the way, I also cleaned up the way that we handle SparkConf in BlockManagerSuite: previously, each test would mutate a shared SparkConf instance, while now each test gets a fresh SparkConf. Author: Josh Rosen Closes #10759 from JoshRosen/SPARK-12174. --- .../spark/storage/BlockManagerSuite.scala | 71 ++++++++----------- 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 67210e5d4c50e..62e6c4f7932df 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -45,20 +45,18 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - private val conf = new SparkConf(false).set("spark.app.id", "test") + var conf: SparkConf = null var store: BlockManager = null var store2: BlockManager = null var store3: BlockManager = null var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null - conf.set("spark.authenticate", "false") - val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) - val shuffleManager = new HashShuffleManager(conf) + val securityMgr = new SecurityManager(new SparkConf(false)) + val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false)) + val shuffleManager = new HashShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer", "1m") - val serializer = new KryoSerializer(conf) + val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m")) // Implicitly convert strings to BlockIds for test clarity. implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) @@ -79,15 +77,17 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def beforeEach(): Unit = { super.beforeEach() - rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") - conf.set("os.arch", "amd64") - conf.set("spark.test.useCompressedOops", "true") + conf = new SparkConf(false) + .set("spark.app.id", "test") + .set("spark.kryoserializer.buffer", "1m") + .set("spark.test.useCompressedOops", "true") + .set("spark.storage.unrollFraction", "0.4") + .set("spark.storage.unrollMemoryThreshold", "512") + + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - conf.set("spark.storage.unrollFraction", "0.4") - conf.set("spark.storage.unrollMemoryThreshold", "512") master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) @@ -98,6 +98,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def afterEach(): Unit = { try { + conf = null if (store != null) { store.stop() store = null @@ -473,34 +474,22 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { - val origTimeoutOpt = conf.getOption("spark.network.timeout") - try { - conf.set("spark.network.timeout", "2s") - store = makeBlockManager(8000, "executor1") - store2 = makeBlockManager(8000, "executor2") - store3 = makeBlockManager(8000, "executor3") - val list1 = List(new Array[Byte](4000)) - store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - var list1Get = store.getRemoteBytes("list1") - assert(list1Get.isDefined, "list1Get expected to be fetched") - // block manager exit - store2.stop() - store2 = null - list1Get = store.getRemoteBytes("list1") - // get `list1` block - assert(list1Get.isDefined, "list1Get expected to be fetched") - store3.stop() - store3 = null - // exception throw because there is no locations - intercept[BlockFetchException] { - list1Get = store.getRemoteBytes("list1") - } - } finally { - origTimeoutOpt match { - case Some(t) => conf.set("spark.network.timeout", t) - case None => conf.remove("spark.network.timeout") - } + conf.set("spark.shuffle.io.maxRetries", "0") + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store3 = makeBlockManager(8000, "executor3") + val list1 = List(new Array[Byte](4000)) + store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") + store2.stop() + store2 = null + assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") + store3.stop() + store3 = null + // exception throw because there is no locations + intercept[BlockFetchException] { + store.getRemoteBytes("list1") } } From cc7af86afd3e769d1e2a581f31bb3db5a3d0229f Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 14 Jan 2016 17:44:56 -0800 Subject: [PATCH 1571/2089] [SPARK-12813][SQL] Eliminate serialization for back to back operations The goal of this PR is to eliminate unnecessary translations when there are back-to-back `MapPartitions` operations. In order to achieve this I also made the following simplifications: - Operators no longer have hold encoders, instead they have only the expressions that they need. The benefits here are twofold: the expressions are visible to transformations so go through the normal resolution/binding process. now that they are visible we can change them on a case by case basis. - Operators no longer have type parameters. Since the engine is responsible for its own type checking, having the types visible to the complier was an unnecessary complication. We still leverage the scala compiler in the companion factory when constructing a new operator, but after this the types are discarded. Deferred to a follow up PR: - Remove as much of the resolution/binding from Dataset/GroupedDataset as possible. We should still eagerly check resolution and throw an error though in the case of mismatches for an `as` operation. - Eliminate serializations in more cases by adding more cases to `EliminateSerialization` Author: Michael Armbrust Closes #10747 from marmbrus/encoderExpressions. --- .../sql/catalyst/analysis/Analyzer.scala | 4 + .../sql/catalyst/analysis/unresolved.scala | 5 + .../catalyst/encoders/ExpressionEncoder.scala | 10 + .../catalyst/expressions/BoundAttribute.scala | 4 +- .../expressions/namedExpressions.scala | 6 + .../sql/catalyst/expressions/objects.scala | 4 + .../sql/catalyst/optimizer/Optimizer.scala | 16 +- .../plans/logical/basicOperators.scala | 119 ----------- .../sql/catalyst/plans/logical/object.scala | 185 ++++++++++++++++++ .../EliminateSerializationSuite.scala | 76 +++++++ .../scala/org/apache/spark/sql/Dataset.scala | 9 +- .../org/apache/spark/sql/GroupedDataset.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 19 +- .../spark/sql/execution/basicOperators.scala | 127 ------------ .../apache/spark/sql/execution/objects.scala | 182 +++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 ++ .../org/apache/spark/sql/QueryTest.scala | 8 +- 17 files changed, 518 insertions(+), 274 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8a33af8207350..dadea6b54a946 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1214,6 +1214,10 @@ object CleanupAliases extends Rule[LogicalPlan] { Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + // Operators that operate on objects should only have expressions from encoders, which should + // never have extra aliases. + case o: ObjectOperator => o + case other => var stop = false other transformExpressionsDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index fc0e87aa68ed4..79eebbf9b1ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -160,6 +160,7 @@ abstract class Star extends LeafExpression with NamedExpression { override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override lazy val resolved = false def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] @@ -246,6 +247,8 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") + override lazy val resolved = false override def toString: String = s"$child AS $names" @@ -259,6 +262,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * @param expressions Expressions to expand. */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } @@ -298,6 +302,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def name: String = throw new UnresolvedException(this, "name") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 05f746e72b498..64832dc114e67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -207,6 +207,16 @@ case class ExpressionEncoder[T]( resolve(attrs, OuterScopes.outerScopes).bind(attrs) } + + /** + * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form + * of this object. + */ + def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map { + case (_, ne: NamedExpression) => ne.newInstance() + case (name, e) => Alias(e, name)() + } + /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 7293d5d4472af..c94b2c0e270b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal, $dataType]" + override def toString: String = s"input[$ordinal, ${dataType.simpleString}]" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { @@ -66,6 +66,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException + override def newInstance(): NamedExpression = this + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index eee708cb02f9d..b6d7a7f5e8d01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -79,6 +79,9 @@ trait NamedExpression extends Expression { /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty + /** Returns a copy of this expression with a new `exprId`. */ + def newInstance(): NamedExpression + protected def typeSuffix = if (resolved) { dataType match { @@ -144,6 +147,9 @@ case class Alias(child: Expression, name: String)( } } + def newInstance(): NamedExpression = + Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata) + override def toAttribute: Attribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index c0c3e6e891669..8385f7e1da591 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -172,6 +172,8 @@ case class Invoke( $objNullCheck """ } + + override def toString: String = s"$targetObject.$functionName" } object NewInstance { @@ -253,6 +255,8 @@ case class NewInstance( """ } } + + override def toString: String = s"newInstance($cls)" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 487431f8925a3..cc3371c08fac4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -67,7 +67,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + EliminateSerialization) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -96,6 +97,19 @@ object SamplePushDown extends Rule[LogicalPlan] { } } +/** + * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) + * representation of data item. For example back to back map operations. + */ +object EliminateSerialization extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case m @ MapPartitions(_, input, _, child: ObjectOperator) + if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType => + val childWithoutSerialization = child.withObjectOutput + m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization) + } +} + /** * Pushes certain operations to both sides of a Union, Intersect or Except operator. * Operations that are safe to pushdown are listed as follows. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 64957db6b4013..2a1b1b131d813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -480,120 +478,3 @@ case object OneRowRelation extends LeafNode { */ override def statistics: Statistics = Statistics(sizeInBytes = 1) } - -/** - * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are - * used respectively to decode/encode from the JVM object representation expected by `func.` - */ -case class MapPartitions[T, U]( - func: Iterator[T] => Iterator[U], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = outputSet -} - -/** Factory for constructing new `AppendColumn` nodes. */ -object AppendColumns { - def apply[T, U : Encoder]( - func: T => U, - tEncoder: ExpressionEncoder[T], - child: LogicalPlan): AppendColumns[T, U] = { - val attrs = encoderFor[U].schema.toAttributes - new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child) - } -} - -/** - * A relation produced by applying `func` to each partition of the `child`, concatenating the - * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to - * decode/encode from the JVM object representation expected by `func.` - */ -case class AppendColumns[T, U]( - func: T => U, - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - newColumns: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output ++ newColumns - override def producedAttributes: AttributeSet = AttributeSet(newColumns) -} - -/** Factory for constructing new `MapGroups` nodes. */ -object MapGroups { - def apply[K, T, U : Encoder]( - func: (K, Iterator[T]) => TraversableOnce[U], - kEncoder: ExpressionEncoder[K], - tEncoder: ExpressionEncoder[T], - groupingAttributes: Seq[Attribute], - child: LogicalPlan): MapGroups[K, T, U] = { - new MapGroups( - func, - kEncoder, - tEncoder, - encoderFor[U], - groupingAttributes, - encoderFor[U].schema.toAttributes, - child) - } -} - -/** - * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. - * Func is invoked with an object representation of the grouping key an iterator containing the - * object representation of all the rows with that key. - */ -case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => TraversableOnce[U], - kEncoder: ExpressionEncoder[K], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - groupingAttributes: Seq[Attribute], - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = outputSet -} - -/** Factory for constructing new `CoGroup` nodes. */ -object CoGroup { - def apply[Key, Left, Right, Result : Encoder]( - func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], - keyEnc: ExpressionEncoder[Key], - leftEnc: ExpressionEncoder[Left], - rightEnc: ExpressionEncoder[Right], - leftGroup: Seq[Attribute], - rightGroup: Seq[Attribute], - left: LogicalPlan, - right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { - CoGroup( - func, - keyEnc, - leftEnc, - rightEnc, - encoderFor[Result], - encoderFor[Result].schema.toAttributes, - leftGroup, - rightGroup, - left, - right) - } -} - -/** - * A relation produced by applying `func` to each grouping key and associated values from left and - * right children. - */ -case class CoGroup[Key, Left, Right, Result]( - func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], - keyEnc: ExpressionEncoder[Key], - leftEnc: ExpressionEncoder[Left], - rightEnc: ExpressionEncoder[Right], - resultEnc: ExpressionEncoder[Result], - output: Seq[Attribute], - leftGroup: Seq[Attribute], - rightGroup: Seq[Attribute], - left: LogicalPlan, - right: LogicalPlan) extends BinaryNode { - override def producedAttributes: AttributeSet = outputSet -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala new file mode 100644 index 0000000000000..760348052739c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.ObjectType + +/** + * A trait for logical operators that apply user defined functions to domain objects. + */ +trait ObjectOperator extends LogicalPlan { + + /** The serializer that is used to produce the output of this operator. */ + def serializer: Seq[NamedExpression] + + /** + * The object type that is produced by the user defined function. Note that the return type here + * is the same whether or not the operator is output serialized data. + */ + def outputObject: NamedExpression = + Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")() + + /** + * Returns a copy of this operator that will produce an object instead of an encoded row. + * Used in the optimizer when transforming plans to remove unneeded serialization. + */ + def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { + this + } else { + withNewSerializer(outputObject) + } + + /** Returns a copy of this operator with a different serializer. */ + def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy { + productIterator.map { + case c if c == serializer => newSerializer :: Nil + case other: AnyRef => other + }.toArray + } +} + +object MapPartitions { + def apply[T : Encoder, U : Encoder]( + func: Iterator[T] => Iterator[U], + child: LogicalPlan): MapPartitions = { + MapPartitions( + func.asInstanceOf[Iterator[Any] => Iterator[Any]], + encoderFor[T].fromRowExpression, + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`. + * @param input used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapPartitions( + func: Iterator[Any] => Iterator[Any], + input: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) +} + +/** Factory for constructing new `AppendColumn` nodes. */ +object AppendColumns { + def apply[T : Encoder, U : Encoder]( + func: T => U, + child: LogicalPlan): AppendColumns = { + new AppendColumns( + func.asInstanceOf[Any => Any], + encoderFor[T].fromRowExpression, + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`, concatenating the + * resulting columns at the end of the input row. + * @param input used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class AppendColumns( + func: Any => Any, + input: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = child.output ++ newColumns + def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) +} + +/** Factory for constructing new `MapGroups` nodes. */ +object MapGroups { + def apply[K : Encoder, T : Encoder, U : Encoder]( + func: (K, Iterator[T]) => TraversableOnce[U], + groupingAttributes: Seq[Attribute], + child: LogicalPlan): MapGroups = { + new MapGroups( + func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], + encoderFor[K].fromRowExpression, + encoderFor[T].fromRowExpression, + encoderFor[U].namedExpressions, + groupingAttributes, + child) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. + * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + * @param keyObject used to extract the key object for each group. + * @param input used to extract the items in the iterator from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapGroups( + func: (Any, Iterator[Any]) => TraversableOnce[Any], + keyObject: Expression, + input: Expression, + serializer: Seq[NamedExpression], + groupingAttributes: Seq[Attribute], + child: LogicalPlan) extends UnaryNode with ObjectOperator { + + def output: Seq[Attribute] = serializer.map(_.toAttribute) +} + +/** Factory for constructing new `CoGroup` nodes. */ +object CoGroup { + def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan): CoGroup = { + CoGroup( + func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], + encoderFor[Key].fromRowExpression, + encoderFor[Left].fromRowExpression, + encoderFor[Right].fromRowExpression, + encoderFor[Result].namedExpressions, + leftGroup, + rightGroup, + left, + right) + } +} + +/** + * A relation produced by applying `func` to each grouping key and associated values from left and + * right children. + */ +case class CoGroup( + func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], + keyObject: Expression, + leftObject: Expression, + rightObject: Expression, + serializer: Seq[NamedExpression], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode with ObjectOperator { + override def producedAttributes: AttributeSet = outputSet + + override def output: Seq[Attribute] = serializer.map(_.toAttribute) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala new file mode 100644 index 0000000000000..91777375608fd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.NewInstance +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +case class OtherTuple(_1: Int, _2: Int) + +class EliminateSerializationSuite extends PlanTest { + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Serialization", FixedPoint(100), + EliminateSerialization) :: Nil + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + private val func = identity[Iterator[(Int, Int)]] _ + private val func2 = identity[Iterator[OtherTuple]] _ + + def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = { + val newInstances = plan.flatMap(_.expressions.collect { + case n: NewInstance => n + }) + + if (newInstances.size != count) { + fail( + s""" + |Wrong number of object creations in plan: ${newInstances.size} != $count + |$plan + """.stripMargin) + } + } + + test("back to back MapPartitions") { + val input = LocalRelation('_1.int, '_2.int) + val plan = + MapPartitions(func, + MapPartitions(func, input)) + + val optimized = Optimize.execute(plan.analyze) + assertObjectCreations(1, optimized) + } + + test("back to back with object change") { + val input = LocalRelation('_1.int, '_2.int) + val plan = + MapPartitions(func, + MapPartitions(func2, input)) + + val optimized = Optimize.execute(plan.analyze) + assertObjectCreations(2, optimized) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 42f01e9359c64..9a9f7d111cf4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -336,12 +336,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sqlContext, - MapPartitions[T, U]( - func, - resolvedTEncoder, - encoderFor[U], - encoderFor[U].schema.toAttributes, - logicalPlan)) + MapPartitions[T, U](func, logicalPlan)) } /** @@ -434,7 +429,7 @@ class Dataset[T] private[sql]( */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = logicalPlan - val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) + val withGroupingKey = AppendColumns(func, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a819ddceb1b1b..b3f8284364782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -115,8 +115,6 @@ class GroupedDataset[K, V] private[sql]( sqlContext, MapGroups( f, - resolvedKEncoder, - resolvedVEncoder, groupingAttributes, logicalPlan)) } @@ -305,13 +303,11 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + implicit val uEncoder = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( f, - this.resolvedKEncoder, - this.resolvedVEncoder, - other.resolvedVEncoder, this.groupingAttributes, other.groupingAttributes, this.logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 482130a18d939..910519d0e6814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -309,16 +309,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - case logical.MapPartitions(f, tEnc, uEnc, output, child) => - execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil - case logical.AppendColumns(f, tEnc, uEnc, newCol, child) => - execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil - case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => - execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil - case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, - leftGroup, rightGroup, left, right) => - execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup, - planLater(left), planLater(right)) :: Nil + case logical.MapPartitions(f, in, out, child) => + execution.MapPartitions(f, in, out, planLater(child)) :: Nil + case logical.AppendColumns(f, in, out, child) => + execution.AppendColumns(f, in, out, planLater(child)) :: Nil + case logical.MapGroups(f, key, in, out, grouping, child) => + execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil + case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) => + execution.CoGroup( + f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 95bef683238a7..92c9a561312ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -21,9 +21,7 @@ import org.apache.spark.{HashPartitioner, SparkEnv} import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -329,128 +327,3 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl protected override def doExecute(): RDD[InternalRow] = child.execute() } - -/** - * Applies the given function to each input row and encodes the result. - */ -case class MapPartitions[T, U]( - func: Iterator[T] => Iterator[U], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - output: Seq[Attribute], - child: SparkPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = outputSet - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => - val tBoundEncoder = tEncoder.bind(child.output) - func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) - } - } -} - -/** - * Applies the given function to each input row, appending the encoded result at the end of the row. - */ -case class AppendColumns[T, U]( - func: T => U, - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - newColumns: Seq[Attribute], - child: SparkPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = AttributeSet(newColumns) - - override def output: Seq[Attribute] = child.output ++ newColumns - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => - val tBoundEncoder = tEncoder.bind(child.output) - val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) - iter.map { row => - val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row))) - combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow - } - } - } -} - -/** - * Groups the input rows together and calls the function with each group and an iterator containing - * all elements in the group. The result of this function is encoded and flattened before - * being output. - */ -case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => TraversableOnce[U], - kEncoder: ExpressionEncoder[K], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - groupingAttributes: Seq[Attribute], - output: Seq[Attribute], - child: SparkPlan) extends UnaryNode { - override def producedAttributes: AttributeSet = outputSet - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => - val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val groupKeyEncoder = kEncoder.bind(groupingAttributes) - val groupDataEncoder = tEncoder.bind(child.output) - - grouped.flatMap { case (key, rowIter) => - val result = func( - groupKeyEncoder.fromRow(key), - rowIter.map(groupDataEncoder.fromRow)) - result.map(uEncoder.toRow) - } - } - } -} - -/** - * Co-groups the data from left and right children, and calls the function with each group and 2 - * iterators containing all elements in the group from left and right side. - * The result of this function is encoded and flattened before being output. - */ -case class CoGroup[Key, Left, Right, Result]( - func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], - keyEnc: ExpressionEncoder[Key], - leftEnc: ExpressionEncoder[Left], - rightEnc: ExpressionEncoder[Right], - resultEnc: ExpressionEncoder[Result], - output: Seq[Attribute], - leftGroup: Seq[Attribute], - rightGroup: Seq[Attribute], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - override def producedAttributes: AttributeSet = outputSet - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil - - override protected def doExecute(): RDD[InternalRow] = { - left.execute().zipPartitions(right.execute()) { (leftData, rightData) => - val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) - val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val boundKeyEnc = keyEnc.bind(leftGroup) - val boundLeftEnc = leftEnc.bind(left.output) - val boundRightEnc = rightEnc.bind(right.output) - - new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { - case (key, leftResult, rightResult) => - val result = func( - boundKeyEnc.fromRow(key), - leftResult.map(boundLeftEnc.fromRow), - rightResult.map(boundRightEnc.fromRow)) - result.map(resultEnc.toRow) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala new file mode 100644 index 0000000000000..2acca1743cbb9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.ObjectType + +/** + * Helper functions for physical operators that work with user defined objects. + */ +trait ObjectOperator extends SparkPlan { + def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = { + val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema) + (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType) + } + + def generateToRow(serializer: Seq[Expression]): Any => InternalRow = { + val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) { + GenerateSafeProjection.generate(serializer) + } else { + GenerateUnsafeProjection.generate(serializer) + } + val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head + val outputRow = new SpecificMutableRow(inputType :: Nil) + (o: Any) => { + outputRow(0) = o + outputProjection(outputRow) + } + } +} + +/** + * Applies the given function to each input row and encodes the result. + */ +case class MapPartitions( + func: Iterator[Any] => Iterator[Any], + input: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(input, child.output) + val outputObject = generateToRow(serializer) + func(iter.map(getObject)).map(outputObject) + } + } +} + +/** + * Applies the given function to each input row, appending the encoded result at the end of the row. + */ +case class AppendColumns( + func: Any => Any, + input: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute) + + private def newColumnSchema = serializer.map(_.toAttribute).toStructType + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(input, child.output) + val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) + val outputObject = generateToRow(serializer) + + iter.map { row => + val newColumns = outputObject(func(getObject(row))) + + // This operates on the assumption that we always serialize the result... + combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow + } + } + } +} + +/** + * Groups the input rows together and calls the function with each group and an iterator containing + * all elements in the group. The result of this function is encoded and flattened before + * being output. + */ +case class MapGroups( + func: (Any, Iterator[Any]) => TraversableOnce[Any], + keyObject: Expression, + input: Expression, + serializer: Seq[NamedExpression], + groupingAttributes: Seq[Attribute], + child: SparkPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + + val getKey = generateToObject(keyObject, groupingAttributes) + val getValue = generateToObject(input, child.output) + val outputObject = generateToRow(serializer) + + grouped.flatMap { case (key, rowIter) => + val result = func( + getKey(key), + rowIter.map(getValue)) + result.map(outputObject) + } + } + } +} + +/** + * Co-groups the data from left and right children, and calls the function with each group and 2 + * iterators containing all elements in the group from left and right side. + * The result of this function is encoded and flattened before being output. + */ +case class CoGroup( + func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], + keyObject: Expression, + leftObject: Expression, + rightObject: Expression, + serializer: Seq[NamedExpression], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with ObjectOperator { + + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + + val getKey = generateToObject(keyObject, leftGroup) + val getLeft = generateToObject(leftObject, left.output) + val getRight = generateToObject(rightObject, right.output) + val outputObject = generateToRow(serializer) + + new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { + case (key, leftResult, rightResult) => + val result = func( + getKey(key), + leftResult.map(getLeft), + rightResult.map(getRight)) + result.map(outputObject) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d7b86e381108e..b69bb21db532b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +case class OtherTuple(_1: String, _2: Int) + class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -111,6 +113,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } + test("map with type change") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + + checkAnswer( + ds.map(identity[(String, Int)]) + .as[OtherTuple] + .map(identity[OtherTuple]), + OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3)) + } + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index fac26bd0c0269..ce12f788b786c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -192,10 +192,10 @@ abstract class QueryTest extends PlanTest { val logicalPlan = df.queryExecution.analyzed // bypass some cases that we can't handle currently. logicalPlan.transform { - case _: MapPartitions[_, _] => return - case _: MapGroups[_, _, _] => return - case _: AppendColumns[_, _] => return - case _: CoGroup[_, _, _, _] => return + case _: MapPartitions => return + case _: MapGroups => return + case _: AppendColumns => return + case _: CoGroup => return case _: LogicalRelation => return }.transformAllExpressions { case a: ImperativeAggregate => return From 32cca933546b4aaf0fc040b9cfd1a5968171b423 Mon Sep 17 00:00:00 2001 From: Koyo Yoshida Date: Fri, 15 Jan 2016 13:32:47 +0900 Subject: [PATCH 1572/2089] [SPARK-12708][UI] Sorting task error in Stages Page when yarn mode. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If sort column contains slash(e.g. "Executor ID / Host") when yarn mode,sort fail with following message. ![spark-12708](https://cloud.githubusercontent.com/assets/6679275/12193320/80814f8c-b62a-11e5-9914-7bf3907029df.png) It's similar to SPARK-4313 . Author: root Author: Koyo Yoshida Closes #10663 from yoshidakuy/SPARK-12708. --- .../main/scala/org/apache/spark/ui/UIUtils.scala | 16 ++++++++++++++++ .../spark/ui/exec/ExecutorThreadDumpPage.scala | 15 ++------------- .../org/apache/spark/ui/jobs/PoolPage.scala | 11 ++++++++--- .../org/apache/spark/ui/jobs/PoolTable.scala | 4 +++- .../org/apache/spark/ui/jobs/StagePage.scala | 4 +++- .../scala/org/apache/spark/ui/UIUtilsSuite.scala | 14 ++++++++++++++ 6 files changed, 46 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 81a6f07ec836a..1949c4b3cbf42 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -451,4 +452,19 @@ private[spark] object UIUtils extends Logging { {desc} } } + + /** + * Decode URLParameter if URL is encoded by YARN-WebAppProxyServlet. + * Due to YARN-2844: WebAppProxyServlet cannot handle urls which contain encoded characters + * Therefore we need to decode it until we get the real URLParameter. + */ + def decodeURLParameter(urlParam: String): String = { + var param = urlParam + var decodedParam = URLDecoder.decode(param, "UTF-8") + while (param != decodedParam) { + param = decodedParam + decodedParam = URLDecoder.decode(param, "UTF-8") + } + param + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 1a6f0fdd50df7..edc66709e229a 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.ui.exec -import java.net.URLDecoder import javax.servlet.http.HttpServletRequest import scala.util.Try @@ -30,18 +29,8 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage private val sc = parent.sc def render(request: HttpServletRequest): Seq[Node] = { - val executorId = Option(request.getParameter("executorId")).map { - executorId => - // Due to YARN-2844, "" in the url will be encoded to "%25253Cdriver%25253E" when - // running in yarn-cluster mode. `request.getParameter("executorId")` will return - // "%253Cdriver%253E". Therefore we need to decode it until we get the real id. - var id = executorId - var decodedId = URLDecoder.decode(id, "UTF-8") - while (id != decodedId) { - id = decodedId - decodedId = URLDecoder.decode(id, "UTF-8") - } - id + val executorId = Option(request.getParameter("executorId")).map { executorId => + UIUtils.decodeURLParameter(executorId) }.getOrElse { throw new IllegalArgumentException(s"Missing executorId parameter") } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index fa30f2bda4272..6cd25919ca5fd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -31,8 +31,11 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val poolName = request.getParameter("poolname") - require(poolName != null && poolName.nonEmpty, "Missing poolname parameter") + val poolName = Option(request.getParameter("poolname")).map { poolname => + UIUtils.decodeURLParameter(poolname) + }.getOrElse { + throw new IllegalArgumentException(s"Missing poolname parameter") + } val poolToActiveStages = listener.poolToActiveStages val activeStages = poolToActiveStages.get(poolName) match { @@ -44,7 +47,9 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { killEnabled = parent.killEnabled) // For now, pool information is only accessible in live UIs - val pools = sc.map(_.getPoolForName(poolName).get).toSeq + val pools = sc.map(_.getPoolForName(poolName).getOrElse { + throw new IllegalArgumentException(s"Unknown poolname: $poolName") + }).toSeq val poolTable = new PoolTable(pools, parent) val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 9ba2af54dacf4..ea02968733cac 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.net.URLEncoder + import scala.collection.mutable.HashMap import scala.xml.Node @@ -59,7 +61,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { case None => 0 } val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), p.name) + .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) From ba4a641902f95c5a9b3a6bebcaa56039eca2720d Mon Sep 17 00:00:00 2001 From: "Oscar D. Lara Yejas" Date: Fri, 15 Jan 2016 07:37:54 -0800 Subject: [PATCH 1577/2089] [SPARK-11031][SPARKR] Method str() on a DataFrame Author: Oscar D. Lara Yejas Author: Oscar D. Lara Yejas Author: Oscar D. Lara Yejas Author: Oscar D. Lara Yejas Closes #9613 from olarayej/SPARK-11031. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 73 +++++++++++++++++++++++ R/pkg/R/generics.R | 36 +++++------ R/pkg/R/types.R | 21 +++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 31 ++++++++++ 5 files changed, 140 insertions(+), 22 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 34be7f0ebd752..34d14373b9027 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -278,6 +278,7 @@ export("as.DataFrame", "read.parquet", "read.text", "sql", + "str", "table", "tableNames", "tables", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3bf5bc924f7dc..35695b9df1974 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2299,3 +2299,76 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) + +#' Display the structure of a DataFrame, including column names, column types, as well as a +#' a small sample of rows. +#' @name str +#' @title Compactly display the structure of a dataset +#' @rdname str +#' @family DataFrame functions +#' @param object a DataFrame +#' @examples \dontrun{ +#' # Create a DataFrame from the Iris dataset +#' irisDF <- createDataFrame(sqlContext, iris) +#' +#' # Show the structure of the DataFrame +#' str(irisDF) +#' } +setMethod("str", + signature(object = "DataFrame"), + function(object) { + + # TODO: These could be made global parameters, though in R it's not the case + MAX_CHAR_PER_ROW <- 120 + MAX_COLS <- 100 + + # Get the column names and types of the DataFrame + names <- names(object) + types <- coltypes(object) + + # Get the first elements of the dataset. Limit number of columns accordingly + localDF <- if (ncol(object) > MAX_COLS) { + head(object[, c(1:MAX_COLS)]) + } else { + head(object) + } + + # The number of observations will not be displayed as computing the + # number of rows is a very expensive operation + cat(paste0("'", class(object), "': ", length(names), " variables:\n")) + + if (nrow(localDF) > 0) { + for (i in 1 : ncol(localDF)) { + # Get the first elements for each column + + firstElements <- if (types[i] == "character") { + paste(paste0("\"", localDF[,i], "\""), collapse = " ") + } else { + paste(localDF[,i], collapse = " ") + } + + # Add the corresponding number of spaces for alignment + spaces <- paste(rep(" ", max(nchar(names) - nchar(names[i]))), collapse="") + + # Get the short type. For 'character', it would be 'chr'; + # 'for numeric', it's 'num', etc. + dataType <- SHORT_TYPES[[types[i]]] + if (is.null(dataType)) { + dataType <- substring(types[i], 1, 3) + } + + # Concatenate the colnames, coltypes, and first + # elements of each column + line <- paste0(" $ ", names[i], spaces, ": ", + dataType, " ",firstElements) + + # Chop off extra characters if this is too long + cat(substr(line, 1, MAX_CHAR_PER_ROW)) + cat("\n") + } + + if (ncol(localDF) < ncol(object)) { + cat(paste0("\nDisplaying first ", ncol(localDF), " columns only.")) + } + } + }) \ No newline at end of file diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5ba68e3a4f378..860329988f97c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -378,7 +378,6 @@ setGeneric("subtractByKey", setGeneric("value", function(bcast) { standardGeneric("value") }) - #################### DataFrame Methods ######################## #' @rdname agg @@ -389,6 +388,14 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) +#' @rdname as.data.frame +#' @export +setGeneric("as.data.frame") + +#' @rdname attach +#' @export +setGeneric("attach") + #' @rdname columns #' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) @@ -525,13 +532,12 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) -#' @rdname withColumn #' @export -setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) +setGeneric("str") -#' @rdname write.df +#' @rdname mutate #' @export -setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df #' @export @@ -593,6 +599,10 @@ setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) +#' @rdname with +#' @export +setGeneric("with") + #' @rdname withColumn #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) @@ -602,6 +612,9 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname write.df +#' @export +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) ###################### Column Methods ########################## @@ -1109,7 +1122,6 @@ setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) - #' @rdname glm #' @export setGeneric("glm") @@ -1121,15 +1133,3 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind #' @export setGeneric("rbind", signature = "...") - -#' @rdname as.data.frame -#' @export -setGeneric("as.data.frame") - -#' @rdname attach -#' @export -setGeneric("attach") - -#' @rdname with -#' @export -setGeneric("with") diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index 1f06af7e904fe..ad048b1cd1795 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -47,10 +47,23 @@ COMPLEX_TYPES <- list( # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) +SHORT_TYPES <- as.environment(list( + "character" = "chr", + "logical" = "logi", + "POSIXct" = "POSIXct", + "integer" = "int", + "numeric" = "num", + "raw" = "raw", + "Date" = "Date", + "map" = "map", + "array" = "array", + "struct" = "struct" +)) + # An environment for mapping R to Scala, names are R types and values are Scala types. rToSQLTypes <- as.environment(list( - "integer" = "integer", # in R, integer is 32bit - "numeric" = "double", # in R, numeric == double which is 64bit - "double" = "double", + "integer" = "integer", # in R, integer is 32bit + "numeric" = "double", # in R, numeric == double which is 64bit + "double" = "double", "character" = "string", - "logical" = "boolean")) + "logical" = "boolean")) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 40d5066a93f4c..27ad9f3958362 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1799,6 +1799,37 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { "Only atomic type is supported for column types") }) +test_that("Method str()", { + # Structure of Iris + iris2 <- iris + colnames(iris2) <- c("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species") + iris2$col <- TRUE + irisDF2 <- createDataFrame(sqlContext, iris2) + + out <- capture.output(str(irisDF2)) + expect_equal(length(out), 7) + expect_equal(out[1], "'DataFrame': 6 variables:") + expect_equal(out[2], " $ Sepal_Length: num 5.1 4.9 4.7 4.6 5 5.4") + expect_equal(out[3], " $ Sepal_Width : num 3.5 3 3.2 3.1 3.6 3.9") + expect_equal(out[4], " $ Petal_Length: num 1.4 1.4 1.3 1.5 1.4 1.7") + expect_equal(out[5], " $ Petal_Width : num 0.2 0.2 0.2 0.2 0.2 0.4") + expect_equal(out[6], paste0(" $ Species : chr \"setosa\" \"setosa\" \"", + "setosa\" \"setosa\" \"setosa\" \"setosa\"")) + expect_equal(out[7], " $ col : logi TRUE TRUE TRUE TRUE TRUE TRUE") + + # A random dataset with many columns. This test is to check str limits + # the number of columns. Therefore, it will suffice to check for the + # number of returned rows + x <- runif(200, 1, 10) + df <- data.frame(t(as.matrix(data.frame(x,x,x,x,x,x,x,x,x)))) + DF <- createDataFrame(sqlContext, df) + out <- capture.output(str(DF)) + expect_equal(length(out), 103) + + # Test utils:::str + expect_equal(capture.output(utils:::str(iris)), capture.output(str(iris))) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) From c5e7076da72657ea35a0aa388f8d2e6411d39280 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 15 Jan 2016 08:26:20 -0800 Subject: [PATCH 1578/2089] [MINOR] [SQL] GeneratedExpressionCode -> ExprCode GeneratedExpressionCode is too long Author: Davies Liu Closes #10767 from davies/renaming. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 20 +++--- .../sql/catalyst/expressions/Expression.scala | 44 ++++++------- .../catalyst/expressions/InputFileName.scala | 4 +- .../MonotonicallyIncreasingID.scala | 4 +- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +- .../sql/catalyst/expressions/SortOrder.scala | 4 +- .../expressions/SparkPartitionID.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 22 +++---- .../expressions/bitwiseExpressions.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 16 ++--- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateOrdering.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 26 ++++---- .../codegen/GenerateUnsafeProjection.scala | 20 +++--- .../expressions/collectionOperations.scala | 6 +- .../expressions/complexTypeCreator.scala | 10 +-- .../expressions/complexTypeExtractors.scala | 10 +-- .../expressions/conditionalExpressions.scala | 12 ++-- .../expressions/datetimeExpressions.scala | 54 ++++++++-------- .../expressions/decimalExpressions.scala | 12 ++-- .../sql/catalyst/expressions/literals.scala | 2 +- .../expressions/mathExpressions.scala | 36 +++++------ .../spark/sql/catalyst/expressions/misc.scala | 28 ++++---- .../expressions/namedExpressions.scala | 4 +- .../expressions/nullExpressions.scala | 14 ++-- .../sql/catalyst/expressions/objects.scala | 28 ++++---- .../sql/catalyst/expressions/predicates.scala | 18 +++--- .../expressions/randomExpressions.scala | 6 +- .../expressions/regexpExpressions.scala | 10 +-- .../expressions/stringExpressions.scala | 64 +++++++++---------- .../expressions/NonFoldableLiteral.scala | 2 +- 32 files changed, 249 insertions(+), 249 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c94b2c0e270b6..397abc7391ec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ /** @@ -68,7 +68,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def newInstance(): NamedExpression = this - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 6f199cfc5d8cd..1072158f04585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -446,7 +446,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = cast(input) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) eval.code + @@ -460,7 +460,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def nullSafeCastFunction( from: DataType, to: DataType, - ctx: CodeGenContext): CastFunction = to match { + ctx: CodegenContext): CastFunction = to match { case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" @@ -491,7 +491,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String, resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultNull = $childNull; @@ -502,7 +502,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } - private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" @@ -524,7 +524,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def castToDateCode( from: DataType, - ctx: CodeGenContext): CastFunction = from match { + ctx: CodegenContext): CastFunction = from match { case StringType => val intOpt = ctx.freshName("intOpt") (c, evPrim, evNull) => s""" @@ -556,7 +556,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def castToDecimalCode( from: DataType, target: DecimalType, - ctx: CodeGenContext): CastFunction = { + ctx: CodegenContext): CastFunction = { val tmp = ctx.freshName("tmpDecimal") from match { case StringType => @@ -614,7 +614,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def castToTimestampCode( from: DataType, - ctx: CodeGenContext): CastFunction = from match { + ctx: CodegenContext): CastFunction = from match { case StringType => val longOpt = ctx.freshName("longOpt") (c, evPrim, evNull) => @@ -826,7 +826,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] def castArrayCode( - fromType: DataType, toType: DataType, ctx: CodeGenContext): CastFunction = { + fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") @@ -861,7 +861,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } - private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = { val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) @@ -889,7 +889,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] def castStructCode( - from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { + from: StructType, to: StructType, ctx: CodegenContext): CastFunction = { val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d6219514b752b..25cf210c4b527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -86,22 +86,22 @@ abstract class Expression extends TreeNode[Expression] { def eval(input: InternalRow = null): Any /** - * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * Returns an [[ExprCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. * - * @param ctx a [[CodeGenContext]] - * @return [[GeneratedExpressionCode]] + * @param ctx a [[CodegenContext]] + * @return [[ExprCode]] */ - def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + def gen(ctx: CodegenContext): ExprCode = { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. val code = s"/* ${this.toCommentSafeString} */" - GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) + ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) + val ve = ExprCode("", isNull, primitive) ve.code = genCode(ctx, ve) // Add `this` in the comment. ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) @@ -113,11 +113,11 @@ abstract class Expression extends TreeNode[Expression] { * The default behavior is to call the eval method of the expression. Concrete expression * implementations should override this to do actual code generation. * - * @param ctx a [[CodeGenContext]] - * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @param ctx a [[CodegenContext]] + * @param ev an [[ExprCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String + protected def genCode(ctx: CodegenContext, ev: ExprCode): String /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -245,7 +245,7 @@ trait Unevaluable extends Expression { final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - final override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + final override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } @@ -330,8 +330,8 @@ abstract class UnaryExpression extends Expression { * @param f function that accepts a variable name and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: String => String): String = { nullSafeCodeGen(ctx, ev, eval => { s"${ev.value} = ${f(eval)};" @@ -346,8 +346,8 @@ abstract class UnaryExpression extends Expression { * code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: String => String): String = { val eval = child.gen(ctx) if (nullable) { @@ -420,8 +420,8 @@ abstract class BinaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.value} = ${f(eval1, eval2)};" @@ -437,8 +437,8 @@ abstract class BinaryExpression extends Expression { * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -560,8 +560,8 @@ abstract class TernaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { s"${ev.value} = ${f(eval1, eval2, eval3)};" @@ -577,8 +577,8 @@ abstract class TernaryExpression extends Expression { * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) val resultCode = f(evals(0).value, evals(1).value, evals(2).value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 827dce8af100e..c49c601c3034b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { SqlNewHadoopRDDState.getInputFileName() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 94f8801dec369..5d28f8fbde8be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,7 +65,7 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with partitionMask + currentCount } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3a6c909fffce7..4035c9dfa4f8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -974,7 +974,7 @@ case class ScalaUDF( // scalastyle:on line.size.limit // Generate codes used to convert the arguments to Scala type for user-defined funtions - private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { val converterClassName = classOf[Any => Any].getName val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" val expressionClassName = classOf[Expression].getName @@ -990,8 +990,8 @@ case class ScalaUDF( } override def genCode( - ctx: CodeGenContext, - ev: GeneratedExpressionCode): String = { + ctx: CodegenContext, + ev: ExprCode): String = { ctx.references += this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 1cb1b9da3049b..bd1d91487275b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator @@ -69,7 +69,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childCode = child.child.gen(ctx) val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index aa3951480c503..377f08eb105fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,7 +44,7 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm override protected def evalInternal(input: InternalRow): Int = partitionId - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val idTerm = ctx.freshName("partitionId") ctx.addMutableState(ctx.JAVA_INT, idTerm, s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7bd851c059d0e..1cacd3f76aa36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -34,7 +34,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -65,7 +65,7 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects override def dataType: DataType = child.dataType - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + override def genCode(ctx: CodegenContext, ev: ExprCode): String = defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input @@ -87,7 +87,7 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => @@ -109,7 +109,7 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -141,7 +141,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => @@ -170,7 +170,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => @@ -225,7 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic /** * Special case handling due to division by 0 => null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -287,7 +287,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet /** * Special case handling for x % 0 ==> null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -344,7 +344,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) @@ -398,7 +398,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) @@ -449,7 +449,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { dataType match { case dt: DecimalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a1e48c4210877..a97bd9edcef84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -118,7 +118,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6daa8ee2f42bf..1c7083bbdacb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,14 +42,14 @@ import org.apache.spark.util.Utils * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: String, var value: String) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported * by codegen, then they are evaluated directly. The unsupported expression is appended at the * end of `references`, the position of it is kept in the code, used to access and evaluate it. */ -class CodeGenContext { +class CodegenContext { /** * Holding all the expressions those do not support codegen, will be evaluated directly. @@ -454,7 +454,7 @@ class CodeGenContext { * expression will be combined in the `expressions` order. */ def generateExpressions(expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { if (doSubexpressionElimination) subexpressionElimination(expressions) expressions.map(e => e.gen(this)) } @@ -479,17 +479,17 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodeGenContext): String = { + protected def declareMutableStates(ctx: CodegenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" }.mkString("\n") } - protected def initMutableStates(ctx: CodeGenContext): String = { + protected def initMutableStates(ctx: CodegenContext): String = { ctx.mutableStates.map(_._3).mkString("\n") } - protected def declareAddedFunctions(ctx: CodeGenContext): String = { + protected def declareAddedFunctions(ctx: CodegenContext): String = { ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } @@ -591,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * Create a new codegen context for expression evaluator, used to store those * expressions that don't support codegen */ - def newCodeGenContext(): CodeGenContext = { - new CodeGenContext + def newCodeGenContext(): CodegenContext = { + new CodegenContext } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 3353580148799..c98b7350b3594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic} */ trait CodegenFallback extends Expression { - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { foreach { case n: Nondeterministic => n.setInitialValues() case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1af7c73cd4bf5..88bcf5b4ed369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -55,7 +55,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR * Generates the code for comparing a struct type according to its natural ordering * (i.e. ascending order by field 1, then field 2, ..., then field n. */ - def genComparisons(ctx: CodeGenContext, schema: StructType): String = { + def genComparisons(ctx: CodegenContext, schema: StructType): String = { val ordering = schema.fields.map(_.dataType).zipWithIndex.map { case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) } @@ -65,7 +65,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR /** * Generates the code for ordering based on the given order. */ - def genComparisons(ctx: CodeGenContext, ordering: Seq[SortOrder]): String = { + def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val comparisons = ordering.map { order => val eval = order.child.gen(ctx) val asc = order.direction == Ascending diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 364dbb770f5e5..865170764640e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -40,9 +40,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] in.map(BindReferences.bindReference(_, inputSchema)) private def createCodeForStruct( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - schema: StructType): GeneratedExpressionCode = { + schema: StructType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") @@ -68,13 +68,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final InternalRow $output = new $rowClass($values); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def createCodeForArray( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - elementType: DataType): GeneratedExpressionCode = { + elementType: DataType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeArray") val values = ctx.freshName("values") @@ -96,14 +96,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def createCodeForMap( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, keyType: DataType, - valueType: DataType): GeneratedExpressionCode = { + valueType: DataType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeMap") val mapClass = classOf[ArrayBasedMapData].getName @@ -117,20 +117,20 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def convertToSafe( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - dataType: DataType): GeneratedExpressionCode = dataType match { + dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => GeneratedExpressionCode("", "false", input) + case _ => ExprCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d0e031f27990c..3a929927c3f60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,7 +48,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], bufferHolder: String): String = { @@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldName = ctx.freshName("fieldName") val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};" val isNull = s"$input.isNullAt($i)" - GeneratedExpressionCode(code, isNull, fieldName) + ExprCode(code, isNull, fieldName) } s""" @@ -69,9 +69,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } private def writeExpressionsToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, row: String, - inputs: Seq[GeneratedExpressionCode], + inputs: Seq[ExprCode], inputTypes: Seq[DataType], bufferHolder: String): String = { val rowWriter = ctx.freshName("rowWriter") @@ -160,7 +160,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, elementType: DataType, bufferHolder: String): String = { @@ -232,7 +232,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, keyType: DataType, valueType: DataType, @@ -270,7 +270,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); @@ -282,9 +282,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } def createCode( - ctx: CodeGenContext, + ctx: CodegenContext, expressions: Seq[Expression], - useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) @@ -305,7 +305,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7aac2e5e6c1b8..e36c9852491bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -35,7 +35,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType case _: MapType => value.asInstanceOf[MapData].numElements() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") } } @@ -170,7 +170,7 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") val getValue = ctx.getValue(arr, right.dataType, i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d71bbd63c8e89..0df8101d9417b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -46,7 +46,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { new GenericArrayData(children.map(_.eval(input)).toArray) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName val values = ctx.freshName("values") s""" @@ -94,7 +94,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") s""" @@ -171,7 +171,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") s""" @@ -223,7 +223,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, children) ev.isNull = eval.isNull ev.value = eval.value @@ -263,7 +263,7 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ev.isNull = eval.isNull ev.value = eval.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5bd97cc7467ab..5256baaf432a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -113,7 +113,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { if (nullable) { s""" @@ -170,7 +170,7 @@ case class GetArrayStructFields( new GenericArrayData(result) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { s""" @@ -225,7 +225,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int index = (int) $eval2; @@ -285,7 +285,7 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 83abbcdc61175..2a24235a29c9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -52,7 +52,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) @@ -136,7 +136,7 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Generate code that looks like: // // condA = ... @@ -275,11 +275,11 @@ case class Least(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = { + def updateEval(eval: ExprCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || @@ -334,11 +334,11 @@ case class Greatest(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = { + def updateEval(eval: ExprCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 17f1df06f2fad..1d0ea68d7a7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,8 +23,8 @@ import java.util.{Calendar, TimeZone} import scala.util.Try import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, - GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, + ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -84,7 +84,7 @@ case class DateAdd(startDate: Expression, days: Expression) start.asInstanceOf[Int] + d.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd + $d;""" }) @@ -109,7 +109,7 @@ case class DateSub(startDate: Expression, days: Expression) start.asInstanceOf[Int] - d.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd - $d;""" }) @@ -128,7 +128,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } @@ -144,7 +144,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } @@ -160,7 +160,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } @@ -176,7 +176,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } @@ -193,7 +193,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } @@ -209,7 +209,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI DateTimeUtils.getQuarter(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } @@ -225,7 +225,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp DateTimeUtils.getMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } @@ -241,7 +241,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } @@ -265,7 +265,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa c.get(Calendar.WEEK_OF_YEAR) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") @@ -295,7 +295,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) @@ -386,7 +386,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { left.dataType match { case StringType if right.foldable => val sdf = classOf[SimpleDateFormat].getName @@ -503,7 +503,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val sdf = classOf[SimpleDateFormat].getName if (format.foldable) { if (constFormat == null) { @@ -555,7 +555,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") } @@ -591,7 +591,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") val dayOfWeekTerm = ctx.freshName("dayOfWeek") @@ -643,7 +643,7 @@ case class TimeAdd(start: Expression, interval: Expression) start.asInstanceOf[Long], itvl.months, itvl.microseconds) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" @@ -666,7 +666,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() @@ -718,7 +718,7 @@ case class TimeSub(start: Expression, interval: Expression) start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" @@ -743,7 +743,7 @@ case class AddMonths(startDate: Expression, numMonths: Expression) DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, m) => { s"""$dtu.dateAddMonths($sd, $m)""" @@ -770,7 +770,7 @@ case class MonthsBetween(date1: Expression, date2: Expression) DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { s"""$dtu.monthsBetween($l, $r)""" @@ -795,7 +795,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() @@ -840,7 +840,7 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def eval(input: InternalRow): Any = child.eval(input) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, d => d) } @@ -882,7 +882,7 @@ case class TruncDate(date: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { @@ -933,7 +933,7 @@ case class DateDiff(endDate: Expression, startDate: Expression) end.asInstanceOf[Int] - start.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 5f8b544edb511..74e86f40c0364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ /** @@ -34,7 +34,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Decimal].toUnscaledLong - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -53,7 +53,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un protected override def nullSafeEval(input: Any): Any = Decimal(input.asInstanceOf[Long], precision, scale) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); @@ -70,8 +70,8 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" override def prettyName: String = "promote_precision" override def sql: String = child.sql } @@ -93,7 +93,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { val tmp = ctx.freshName("tmp") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e0b020330278b..db30845fdab6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -171,7 +171,7 @@ case class Literal protected (value: Any, dataType: DataType) override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 66d8631a846ab..8b9a60f97ce6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -67,7 +67,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } @@ -87,7 +87,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) if (d <= yAsymptote) null else f(d) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -119,7 +119,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -172,7 +172,7 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -207,7 +207,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre toBase.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val numconv = NumberConverter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (num, from, to) => s""" @@ -240,7 +240,7 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -299,7 +299,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" if ($eval > 20 || $eval < 0) { @@ -317,7 +317,7 @@ case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -369,7 +369,7 @@ case class Bin(child: Expression) protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c) => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } @@ -464,7 +464,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { @@ -489,7 +489,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp protected override def nullSafeEval(num: Any): Any = Hex.unhex(num.asInstanceOf[UTF8String].getBytes) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s""" @@ -516,14 +516,14 @@ case class Atan2(left: Expression, right: Expression) math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -549,7 +549,7 @@ case class ShiftLeft(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } } @@ -575,7 +575,7 @@ case class ShiftRight(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } } @@ -601,7 +601,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } } @@ -635,7 +635,7 @@ case class Logarithm(left: Expression, right: Expression) if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (left.isInstanceOf[EulerNumber]) { nullSafeCodeGen(ctx, ev, (c1, c2) => s""" @@ -758,7 +758,7 @@ case class Round(child: Expression, scale: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val ce = child.gen(ctx) val evaluationCode = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4751fbe4146fe..2c12de08f4115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -47,7 +47,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } @@ -100,7 +100,7 @@ case class Sha2(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val digestUtils = "org.apache.commons.codec.digest.DigestUtils" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -145,7 +145,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" ) @@ -171,7 +171,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp checksum.getValue } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val CRC32 = "java.util.zip.CRC32" nullSafeCodeGen(ctx, ev, value => { s""" @@ -323,7 +323,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" val childrenHash = children.zipWithIndex.map { case (child, dt) => @@ -347,12 +347,12 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression input: String, dataType: DataType, seed: String, - ctx: CodeGenContext): GeneratedExpressionCode = { + ctx: CodegenContext): ExprCode = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): GeneratedExpressionCode = - GeneratedExpressionCode(code = "", isNull = "false", value = v) + def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") + def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") + def inlineValue(v: String): ExprCode = + ExprCode(code = "", isNull = "false", value = v) dataType match { case NullType => inlineValue(seed) @@ -369,7 +369,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" val offset = "Platform.BYTE_ARRAY_OFFSET" val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) } case CalendarIntervalType => val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" @@ -400,7 +400,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) case MapType(kt, vt, _) => val result = ctx.freshName("result") @@ -427,7 +427,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) case StructType(fields) => val result = ctx.freshName("result") @@ -448,7 +448,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression int $result = $seed; $fieldsHash """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index b6d7a7f5e8d01..7983501ada9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -133,8 +133,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple passthrough for code generation. */ - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 89aec2b20fd0c..667d3513d32b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -61,7 +61,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val first = children(0) val rest = children.drop(1) val firstEval = first.gen(ctx) @@ -110,7 +110,7 @@ case class IsNaN(child: Expression) extends UnaryExpression } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) child.dataType match { case DoubleType | FloatType => @@ -150,7 +150,7 @@ case class NaNvl(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val leftGen = left.gen(ctx) val rightGen = right.gen(ctx) left.dataType match { @@ -189,7 +189,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) == null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.value = eval.isNull @@ -210,7 +210,7 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) != null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.value = s"(!(${eval.isNull}))" @@ -250,7 +250,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 8385f7e1da591..79fe0033b71ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -56,7 +56,7 @@ case class StaticInvoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -145,7 +145,7 @@ case class Invoke( case _ => identity[String] _ } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val obj = targetObject.gen(ctx) val argGen = arguments.map(_.gen(ctx)) @@ -214,7 +214,7 @@ case class NewInstance( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -277,7 +277,7 @@ case class UnwrapOption( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val inputObject = child.gen(ctx) @@ -309,7 +309,7 @@ case class WrapOption(child: Expression, optType: DataType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val inputObject = child.gen(ctx) s""" @@ -332,8 +332,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext override def nullable: Boolean = true - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - GeneratedExpressionCode(code = "", value = value, isNull = isNull) + override def gen(ctx: CodegenContext): ExprCode = { + ExprCode(code = "", value = value, isNull = isNull) } } @@ -415,7 +415,7 @@ case class MapObjects( override def dataType: DataType = ArrayType(lambdaFunction.dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) val genInputData = inputData.gen(ctx) @@ -491,7 +491,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericRow].getName val values = ctx.freshName("values") s""" @@ -521,7 +521,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -560,7 +560,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression { - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -605,7 +605,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val instanceGen = beanInstance.gen(ctx) val initialize = setters.map { @@ -648,7 +648,7 @@ case class AssertNotNull( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childGen = child.gen(ctx) ev.isNull = "false" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bca12a8d21023..a3c10c81c35e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -98,7 +98,7 @@ case class Not(child: Expression) protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } @@ -154,7 +154,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val valueGen = value.gen(ctx) val listGen = list.map(_.gen(ctx)) val listCode = listGen.map(x => @@ -213,7 +213,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with def getHSet(): Set[Any] = hset - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.gen(ctx) @@ -267,7 +267,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -318,7 +318,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -347,7 +347,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType @@ -394,7 +394,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } @@ -428,7 +428,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8de47e9ddc28d..2e703671fcd66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -65,7 +65,7 @@ case class Rand(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, @@ -88,7 +88,7 @@ case class Randn(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index db266639b8560..b68009331b0ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -76,7 +76,7 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" val pattern = ctx.freshName("pattern") @@ -125,7 +125,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val patternClass = classOf[Pattern].getName val pattern = ctx.freshName("pattern") @@ -182,7 +182,7 @@ case class StringSplit(str: Expression, pattern: Expression) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. @@ -238,7 +238,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def children: Seq[Expression] = subject :: regexp :: rep :: Nil override def prettyName: String = "regexp_replace" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") @@ -318,7 +318,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def children: Seq[Expression] = subject :: regexp :: idx :: Nil override def prettyName: String = "regexp_extract" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 931f752b4dc1a..b965212f27777 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -49,7 +49,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas UTF8String.concat(inputs : _*) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.value}" @@ -102,7 +102,7 @@ case class ConcatWs(children: Seq[Expression]) UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.gen(ctx)) @@ -183,7 +183,7 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -198,7 +198,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def convert(v: UTF8String): UTF8String = v.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -223,7 +223,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -234,7 +234,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -245,7 +245,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } @@ -291,7 +291,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac srcEval.asInstanceOf[UTF8String].translate(dict) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastMatching = ctx.freshName("lastMatching") val termLastReplace = ctx.freshName("lastReplace") val termDict = ctx.freshName("dict") @@ -338,7 +338,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override protected def nullSafeEval(word: Any, set: Any): Any = set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);" ) @@ -359,7 +359,7 @@ case class StringTrim(child: Expression) override def prettyName: String = "trim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trim()") } } @@ -374,7 +374,7 @@ case class StringTrimLeft(child: Expression) override def prettyName: String = "ltrim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trimLeft()") } } @@ -389,7 +389,7 @@ case class StringTrimRight(child: Expression) override def prettyName: String = "rtrim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trimRight()") } } @@ -415,7 +415,7 @@ case class StringInstr(str: Expression, substr: Expression) override def prettyName: String = "instr" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } @@ -441,7 +441,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: count.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") } } @@ -484,7 +484,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val substrGen = substr.gen(ctx) val strGen = str.gen(ctx) val startGen = start.gen(ctx) @@ -526,7 +526,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } @@ -547,7 +547,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } @@ -583,7 +583,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val pattern = children.head.gen(ctx) val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) @@ -634,7 +634,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def nullSafeEval(string: Any): Any = { string.asInstanceOf[UTF8String].toTitleCase } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") } } @@ -656,7 +656,7 @@ case class StringRepeat(str: Expression, times: Expression) override def prettyName: String = "repeat" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } } @@ -669,7 +669,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 override def prettyName: String = "reverse" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).reverse()") } } @@ -688,7 +688,7 @@ case class StringSpace(child: Expression) UTF8String.blankString(if (length < 0) 0 else length) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (length) => s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } @@ -723,7 +723,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { @@ -746,7 +746,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => value.asInstanceOf[Array[Byte]].length } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") @@ -766,7 +766,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } @@ -783,7 +783,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } } @@ -805,7 +805,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { val bytes = ctx.freshName("bytes") s""" @@ -833,7 +833,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn bytes.asInstanceOf[Array[Byte]])) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); @@ -852,7 +852,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s""" ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); @@ -878,7 +878,7 @@ case class Decode(bin: Expression, charset: Expression) UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (bytes, charset) => s""" try { @@ -908,7 +908,7 @@ case class Encode(value: Expression, charset: Expression) input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (string, charset) => s""" try { @@ -985,7 +985,7 @@ case class FormatNumber(x: Expression, d: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (num, d) => { def typeHelper(p: String): String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 118fd695fe2f5..ff34b1e37be93 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -35,7 +35,7 @@ case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpres override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { Literal.create(value, dataType).genCode(ctx, ev) } } From 5f83c6991c95616ecbc2878f8860c69b2826f56c Mon Sep 17 00:00:00 2001 From: Hossein Date: Fri, 15 Jan 2016 11:46:46 -0800 Subject: [PATCH 1579/2089] [SPARK-12833][SQL] Initial import of spark-csv CSV is the most common data format in the "small data" world. It is often the first format people want to try when they see Spark on a single node. Having to rely on a 3rd party component for this leads to poor user experience for new users. This PR merges the popular spark-csv data source package (https://github.com/databricks/spark-csv) with SparkSQL. This is a first PR to bring the functionality to spark 2.0 master. We will complete items outlines in the design document (see JIRA attachment) in follow up pull requests. Author: Hossein Author: Reynold Xin Closes #10766 from rxin/csv. --- .rat-excludes | 2 + NOTICE | 38 +- dev/deps/spark-deps-hadoop-2.2 | 1 + dev/deps/spark-deps-hadoop-2.3 | 1 + dev/deps/spark-deps-hadoop-2.4 | 1 + dev/deps/spark-deps-hadoop-2.6 | 1 + sql/core/pom.xml | 6 + ...pache.spark.sql.sources.DataSourceRegister | 1 + .../datasources/csv/CSVInferSchema.scala | 227 ++++++++++++ .../datasources/csv/CSVParameters.scala | 107 ++++++ .../execution/datasources/csv/CSVParser.scala | 243 +++++++++++++ .../datasources/csv/CSVRelation.scala | 298 +++++++++++++++ .../datasources/csv/DefaultSource.scala | 48 +++ .../datasources/json/InferSchema.scala | 13 +- .../src/test/resources/cars-alternative.csv | 5 + sql/core/src/test/resources/cars-null.csv | 6 + .../test/resources/cars-unbalanced-quotes.csv | 4 + sql/core/src/test/resources/cars.csv | 6 + sql/core/src/test/resources/cars.tsv | 4 + .../src/test/resources/cars_iso-8859-1.csv | 6 + sql/core/src/test/resources/comments.csv | 6 + .../src/test/resources/disable_comments.csv | 2 + sql/core/src/test/resources/empty.csv | 0 .../datasources/csv/CSVInferSchemaSuite.scala | 71 ++++ .../datasources/csv/CSVParserSuite.scala | 125 +++++++ .../execution/datasources/csv/CSVSuite.scala | 341 ++++++++++++++++++ .../datasources/csv/CSVTypeCastSuite.scala | 98 +++++ 27 files changed, 1653 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala create mode 100644 sql/core/src/test/resources/cars-alternative.csv create mode 100644 sql/core/src/test/resources/cars-null.csv create mode 100644 sql/core/src/test/resources/cars-unbalanced-quotes.csv create mode 100644 sql/core/src/test/resources/cars.csv create mode 100644 sql/core/src/test/resources/cars.tsv create mode 100644 sql/core/src/test/resources/cars_iso-8859-1.csv create mode 100644 sql/core/src/test/resources/comments.csv create mode 100644 sql/core/src/test/resources/disable_comments.csv create mode 100644 sql/core/src/test/resources/empty.csv create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala diff --git a/.rat-excludes b/.rat-excludes index bf071eba652b1..a4f316a4aaa04 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -86,3 +86,5 @@ org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet LZ4BlockInputStream.java spark-deps-.* +.*csv +.*tsv diff --git a/NOTICE b/NOTICE index 571f8c2fff7ff..e416aadce9911 100644 --- a/NOTICE +++ b/NOTICE @@ -610,7 +610,43 @@ Vis.js uses and redistributes the following third-party libraries: =============================================================================== -The CSS style for the navigation sidebar of the documentation was originally +The CSS style for the navigation sidebar of the documentation was originally submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project is distributed under the 3-Clause BSD license. =============================================================================== + +For CSV functionality: + +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2015 Ayasdi Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 53034a25d46ab..fb2e91e1ee4b0 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -184,6 +184,7 @@ tachyon-underfs-hdfs-0.8.2.jar tachyon-underfs-local-0.8.2.jar tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar +univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index a23e260641aeb..59e4d4f839788 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -175,6 +175,7 @@ tachyon-underfs-hdfs-0.8.2.jar tachyon-underfs-local-0.8.2.jar tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar +univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 6bedbed1e3355..e4395c872c230 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -176,6 +176,7 @@ tachyon-underfs-hdfs-0.8.2.jar tachyon-underfs-local-0.8.2.jar tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar +univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 7bfad57b4a4a6..89fd15da7d0b3 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -182,6 +182,7 @@ tachyon-underfs-hdfs-0.8.2.jar tachyon-underfs-local-0.8.2.jar tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar +univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 6db7a8a2dc526..31b364f351d56 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -36,6 +36,12 @@ + + com.univocity + univocity-parsers + 1.5.6 + jar + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1ca2044057e56..226d59d0eae88 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,4 @@ +org.apache.spark.sql.execution.datasources.csv.DefaultSource org.apache.spark.sql.execution.datasources.jdbc.DefaultSource org.apache.spark.sql.execution.datasources.json.DefaultSource org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala new file mode 100644 index 0000000000000..0aa4539e60516 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.sql.{Date, Timestamp} +import java.text.NumberFormat +import java.util.Locale + +import scala.util.control.Exception._ +import scala.util.Try + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.types._ + + +private[sql] object CSVInferSchema { + + /** + * Similar to the JSON schema inference + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. Replace any null types with string type + * TODO(hossein): Can we reuse JSON schema inference? [SPARK-12670] + */ + def apply( + tokenRdd: RDD[Array[String]], + header: Array[String], + nullValue: String = ""): StructType = { + + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val rootTypes: Array[DataType] = + tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) + + val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => + StructField(thisHeader, rootType, nullable = true) + } + + StructType(structFields) + } + + private def inferRowType(nullValue: String) + (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + var i = 0 + while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. + rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue) + i+=1 + } + rowSoFar + } + + private[csv] def mergeRowTypes( + first: Array[DataType], + second: Array[DataType]): Array[DataType] = { + + first.zipAll(second, NullType, NullType).map { case ((a, b)) => + val tpe = findTightestCommonType(a, b).getOrElse(StringType) + tpe match { + case _: NullType => StringType + case other => other + } + } + } + + /** + * Infer type of string field. Given known type Double, and a string "1", there is no + * point checking if it is an Int, as the final type must be Double or higher. + */ + private[csv] def inferField( + typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { + if (field == null || field.isEmpty || field == nullValue) { + typeSoFar + } else { + typeSoFar match { + case NullType => tryParseInteger(field) + case IntegerType => tryParseInteger(field) + case LongType => tryParseLong(field) + case DoubleType => tryParseDouble(field) + case TimestampType => tryParseTimestamp(field) + case StringType => StringType + case other: DataType => + throw new UnsupportedOperationException(s"Unexpected data type $other") + } + } + } + + private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) { + IntegerType + } else { + tryParseLong(field) + } + + private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) { + LongType + } else { + tryParseDouble(field) + } + + private def tryParseDouble(field: String): DataType = { + if ((allCatch opt field.toDouble).isDefined) { + DoubleType + } else { + tryParseTimestamp(field) + } + } + + def tryParseTimestamp(field: String): DataType = { + if ((allCatch opt Timestamp.valueOf(field)).isDefined) { + TimestampType + } else { + stringType() + } + } + + // Defining a function to return the StringType constant is necessary in order to work around + // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions; + // see issue #128 for more details. + private def stringType(): DataType = { + StringType + } + + private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence + + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + Some(numericPrecedence(index)) + + case _ => None + } +} + +object CSVTypeCast { + + /** + * Casts given string datum to specified type. + * Currently we do not support complex types (ArrayType, MapType, StructType). + * + * For string types, this is simply the datum. For other types. + * For other nullable types, this is null if the string datum is empty. + * + * @param datum string value + * @param castType SparkSQL type + */ + private[csv] def castTo( + datum: String, + castType: DataType, + nullable: Boolean = true, + nullValue: String = ""): Any = { + + if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + case _: DoubleType => Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + case _: BooleanType => datum.toBoolean + case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) + // TODO(hossein): would be good to support other common timestamp formats + case _: TimestampType => Timestamp.valueOf(datum) + // TODO(hossein): would be good to support other common date formats + case _: DateType => Date.valueOf(datum) + case _: StringType => datum + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } + } + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. + * + */ + @throws[IllegalArgumentException] + private[csv] def toChar(str: String): Char = { + if (str.charAt(0) == '\\') { + str.charAt(1) + match { + case 't' => '\t' + case 'r' => '\r' + case 'b' => '\b' + case 'f' => '\f' + case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options + case '\'' => '\'' + case 'u' if str == """\u0000""" => '\u0000' + case _ => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + } + } else if (str.length == 1) { + str.charAt(0) + } else { + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala new file mode 100644 index 0000000000000..ba44121244163 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.Charset + +import org.apache.spark.Logging + +private[sql] case class CSVParameters(parameters: Map[String, String]) extends Logging { + + private def getChar(paramName: String, default: Char): Char = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(value) if value.length == 0 => '\0' + case Some(value) if value.length == 1 => value.charAt(0) + case _ => throw new RuntimeException(s"$paramName cannot be more than one character") + } + } + + private def getBool(paramName: String, default: Boolean = false): Boolean = { + val param = parameters.getOrElse(paramName, default.toString) + if (param.toLowerCase() == "true") { + true + } else if (param.toLowerCase == "false") { + false + } else { + throw new Exception(s"$paramName flag can be true or false") + } + } + + val delimiter = CSVTypeCast.toChar(parameters.getOrElse("delimiter", ",")) + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val charset = parameters.getOrElse("charset", Charset.forName("UTF-8").name()) + + val quote = getChar("quote", '\"') + val escape = getChar("escape", '\\') + val comment = getChar("comment", '\0') + + val headerFlag = getBool("header") + val inferSchemaFlag = getBool("inferSchema") + val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") + val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + + // Limit the number of lines we'll search for a header row that isn't comment-prefixed + val MAX_COMMENT_LINES_IN_HEADER = 10 + + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + + val failFast = ParseModes.isFailFastMode(parseMode) + val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + val permissive = ParseModes.isPermissiveMode(parseMode) + + val nullValue = parameters.getOrElse("nullValue", "") + + val maxColumns = 20480 + + val maxCharsPerColumn = 100000 + + val inputBufferSize = 128 + + val isCommentSet = this.comment != '\0' + + val rowSeparator = "\n" +} + +private[csv] object ParseModes { + + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive is the mode string is not valid + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala new file mode 100644 index 0000000000000..ba1cc42f3e446 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader} + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings} + +import org.apache.spark.Logging + +/** + * Read and parse CSV-like input + * + * @param params Parameters object + * @param headers headers for the columns + */ +private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) { + + protected lazy val parser: CsvParser = { + val settings = new CsvParserSettings() + val format = settings.getFormat + format.setDelimiter(params.delimiter) + format.setLineSeparator(params.rowSeparator) + format.setQuote(params.quote) + format.setQuoteEscape(params.escape) + format.setComment(params.comment) + settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag) + settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag) + settings.setReadInputOnSeparateThread(false) + settings.setInputBufferSize(params.inputBufferSize) + settings.setMaxColumns(params.maxColumns) + settings.setNullValue(params.nullValue) + settings.setMaxCharsPerColumn(params.maxCharsPerColumn) + if (headers != null) settings.setHeaders(headers: _*) + + new CsvParser(settings) + } +} + +/** + * Converts a sequence of string to CSV string + * + * @param params Parameters object for configuration + * @param headers headers for columns + */ +private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging { + private val writerSettings = new CsvWriterSettings + private val format = writerSettings.getFormat + + format.setDelimiter(params.delimiter) + format.setLineSeparator(params.rowSeparator) + format.setQuote(params.quote) + format.setQuoteEscape(params.escape) + format.setComment(params.comment) + + writerSettings.setNullValue(params.nullValue) + writerSettings.setEmptyValue(params.nullValue) + writerSettings.setSkipEmptyLines(true) + writerSettings.setQuoteAllFields(false) + writerSettings.setHeaders(headers: _*) + + def writeRow(row: Seq[String], includeHeader: Boolean): String = { + val buffer = new ByteArrayOutputStream() + val outputWriter = new OutputStreamWriter(buffer) + val writer = new CsvWriter(outputWriter, writerSettings) + + if (includeHeader) { + writer.writeHeaders() + } + writer.writeRow(row.toArray: _*) + writer.close() + buffer.toString.stripLineEnd + } +} + +/** + * Parser for parsing a line at a time. Not efficient for bulk data. + * + * @param params Parameters object + */ +private[sql] class LineCsvReader(params: CSVParameters) + extends CsvReader(params, null) { + /** + * parse a line + * + * @param line a String with no newline at the end + * @return array of strings where each string is a field in the CSV record + */ + def parseLine(line: String): Array[String] = { + parser.beginParsing(new StringReader(line)) + val parsed = parser.parseNext() + parser.stopParsing() + parsed + } +} + +/** + * Parser for parsing lines in bulk. Use this when efficiency is desired. + * + * @param iter iterator over lines in the file + * @param params Parameters object + * @param headers headers for the columns + */ +private[sql] class BulkCsvReader( + iter: Iterator[String], + params: CSVParameters, + headers: Seq[String]) + extends CsvReader(params, headers) with Iterator[Array[String]] { + + private val reader = new StringIteratorReader(iter) + parser.beginParsing(reader) + private var nextRecord = parser.parseNext() + + /** + * get the next parsed line. + * @return array of strings where each string is a field in the CSV record + */ + override def next(): Array[String] = { + val curRecord = nextRecord + if(curRecord != null) { + nextRecord = parser.parseNext() + } else { + throw new NoSuchElementException("next record is null") + } + curRecord + } + + override def hasNext: Boolean = nextRecord != null + +} + +/** + * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at + * end of each line Univocity parser requires a Reader that provides access to the data to be + * parsed and needs the newlines to be present + * @param iter iterator over RDD[String] + */ +private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader { + + private var next: Long = 0 + private var length: Long = 0 // length of input so far + private var start: Long = 0 + private var str: String = null // current string from iter + + /** + * fetch next string from iter, if done with current one + * pretend there is a new line at the end of every string we get from from iter + */ + private def refill(): Unit = { + if (length == next) { + if (iter.hasNext) { + str = iter.next() + start = length + length += (str.length + 1) // allowance for newline removed by SparkContext.textFile() + } else { + str = null + } + } + } + + /** + * read the next character, if at end of string pretend there is a new line + */ + override def read(): Int = { + refill() + if (next >= length) { + -1 + } else { + val cur = next - start + next += 1 + if (cur == str.length) '\n' else str.charAt(cur.toInt) + } + } + + /** + * read from str into cbuf + */ + override def read(cbuf: Array[Char], off: Int, len: Int): Int = { + refill() + var n = 0 + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException() + } else if (len == 0) { + n = 0 + } else { + if (next >= length) { // end of input + n = -1 + } else { + n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size + if (n == length - next) { + str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off) + cbuf(off + n - 1) = '\n' + } else { + str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off) + } + next += n + if (n < len) { + val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter + if(m != -1) n += m + } + } + } + + n + } + + override def skip(ns: Long): Long = { + throw new IllegalArgumentException("Skip not implemented") + } + + override def ready: Boolean = { + refill() + true + } + + override def markSupported: Boolean = false + + override def mark(readAheadLimit: Int): Unit = { + throw new IllegalArgumentException("Mark not implemented") + } + + override def reset(): Unit = { + throw new IllegalArgumentException("Mark and hence reset not implemented") + } + + override def close(): Unit = { } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala new file mode 100644 index 0000000000000..9267479755e82 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.Charset + +import scala.util.control.NonFatal + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce.RecordWriter +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat + +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +private[csv] class CSVRelation( + private val inputRDD: Option[RDD[String]], + override val paths: Array[String], + private val maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + private val parameters: Map[String, String]) + (@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable { + + override lazy val dataSchema: StructType = maybeDataSchema match { + case Some(structType) => structType + case None => inferSchema(paths) + } + + private val params = new CSVParameters(parameters) + + @transient + private var cachedRDD: Option[RDD[String]] = None + + private def readText(location: String): RDD[String] = { + if (Charset.forName(params.charset) == Charset.forName("UTF-8")) { + sqlContext.sparkContext.textFile(location) + } else { + sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location) + .mapPartitions { _.map { pair => + new String(pair._2.getBytes, 0, pair._2.getLength, params.charset) + } + } + } + } + + private def baseRdd(inputPaths: Array[String]): RDD[String] = { + inputRDD.getOrElse { + cachedRDD.getOrElse { + val rdd = readText(inputPaths.mkString(",")) + cachedRDD = Some(rdd) + rdd + } + } + } + + private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = { + val rdd = baseRdd(inputPaths) + // Make sure firstLine is materialized before sending to executors + val firstLine = if (params.headerFlag) findFirstLine(rdd) else null + CSVRelation.univocityTokenizer(rdd, header, firstLine, params) + } + + /** + * This supports to eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. This reads all the tokens of each line + * and then drop unneeded tokens without casting and type-checking by mapping + * both the indices produced by `requiredColumns` and the ones of tokens. + * TODO: Switch to using buildInternalScan + */ + override def buildScan(requiredColumns: Array[String], inputs: Array[FileStatus]): RDD[Row] = { + val pathsString = inputs.map(_.getPath.toUri.toString) + val header = schema.fields.map(_.name) + val tokenizedRdd = tokenRdd(header, pathsString) + CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params) + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new CSVOutputWriterFactory(params) + } + + override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns) + + override def equals(other: Any): Boolean = other match { + case that: CSVRelation => { + val equalPath = paths.toSet == that.paths.toSet + val equalDataSchema = dataSchema == that.dataSchema + val equalSchema = schema == that.schema + val equalPartitionColums = partitionColumns == that.partitionColumns + + equalPath && equalDataSchema && equalSchema && equalPartitionColums + } + case _ => false + } + + private def inferSchema(paths: Array[String]): StructType = { + val rdd = baseRdd(Array(paths.head)) + val firstLine = findFirstLine(rdd) + val firstRow = new LineCsvReader(params).parseLine(firstLine) + + val header = if (params.headerFlag) { + firstRow + } else { + firstRow.zipWithIndex.map { case (value, index) => s"C$index" } + } + + val parsedRdd = tokenRdd(header, paths) + if (params.inferSchemaFlag) { + CSVInferSchema(parsedRdd, header, params.nullValue) + } else { + // By default fields are assumed to be StringType + val schemaFields = header.map { fieldName => + StructField(fieldName.toString, StringType, nullable = true) + } + StructType(schemaFields) + } + } + + /** + * Returns the first line of the first non-empty file in path + */ + private def findFirstLine(rdd: RDD[String]): String = { + if (params.isCommentSet) { + rdd.take(params.MAX_COMMENT_LINES_IN_HEADER) + .find(!_.startsWith(params.comment.toString)) + .getOrElse(sys.error(s"No uncommented header line in " + + s"first ${params.MAX_COMMENT_LINES_IN_HEADER} lines")) + } else { + rdd.first() + } + } +} + +object CSVRelation extends Logging { + + def univocityTokenizer( + file: RDD[String], + header: Seq[String], + firstLine: String, + params: CSVParameters): RDD[Array[String]] = { + // If header is set, make sure firstLine is materialized before sending to executors. + file.mapPartitionsWithIndex({ + case (split, iter) => new BulkCsvReader( + if (params.headerFlag) iter.filterNot(_ == firstLine) else iter, + params, + headers = header) + }, true) + } + + def parseCsv( + tokenizedRDD: RDD[Array[String]], + schema: StructType, + requiredColumns: Array[String], + inputs: Array[FileStatus], + sqlContext: SQLContext, + params: CSVParameters): RDD[Row] = { + + val schemaFields = schema.fields + val requiredFields = StructType(requiredColumns.map(schema(_))).fields + val safeRequiredFields = if (params.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredFields ++ schemaFields.filterNot(requiredFields.contains(_)) + } else { + requiredFields + } + if (requiredColumns.isEmpty) { + sqlContext.sparkContext.emptyRDD[Row] + } else { + val safeRequiredIndices = new Array[Int](safeRequiredFields.length) + schemaFields.zipWithIndex.filter { + case (field, _) => safeRequiredFields.contains(field) + }.foreach { + case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index + } + val rowArray = new Array[Any](safeRequiredIndices.length) + val requiredSize = requiredFields.length + tokenizedRDD.flatMap { tokens => + if (params.dropMalformed && schemaFields.length != tokens.size) { + logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } else if (params.failFast && schemaFields.length != tokens.size) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: " + + s"${tokens.mkString(params.delimiter.toString)}") + } else { + val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) { + tokens ++ new Array[String](schemaFields.length - tokens.size) + } else if (params.permissive && schemaFields.length < tokens.size) { + tokens.take(schemaFields.length) + } else { + tokens + } + try { + var index: Int = 0 + var subIndex: Int = 0 + while (subIndex < safeRequiredIndices.length) { + index = safeRequiredIndices(subIndex) + val field = schemaFields(index) + rowArray(subIndex) = CSVTypeCast.castTo( + indexSafeTokens(index), + field.dataType, + field.nullable, + params.nullValue) + subIndex = subIndex + 1 + } + Some(Row.fromSeq(rowArray.take(requiredSize))) + } catch { + case NonFatal(e) if params.dropMalformed => + logWarning("Parse exception. " + + s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } + } + } + } + } +} + +private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new CsvOutputWriter(path, dataSchema, context, params) + } +} + +private[sql] class CsvOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext, + params: CSVParameters) extends OutputWriter with Logging { + + // create the Generator without separator inserted between 2 records + private[this] val text = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + private var firstRow: Boolean = params.headerFlag + + private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) + + private def rowToString(row: Seq[Any]): Seq[String] = row.map { field => + if (field != null) { + field.toString + } else { + params.nullValue + } + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + // TODO: Instead of converting and writing every row, we should use the univocity buffer + val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow) + if (firstRow) { + firstRow = false + } + text.set(resultString) + recordWriter.write(NullWritable.get(), text) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala new file mode 100644 index 0000000000000..2fffae452c2f7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType + +/** + * Provides access to CSV data from pure SQL statements. + */ +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "csv" + + /** + * Creates a new relation for data store in CSV given parameters and user supported schema. + */ + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + + new CSVRelation( + None, + paths, + dataSchema, + partitionColumns, + parameters)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 59ba4ae2cba0a..44d5e4ff7ec8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -145,7 +145,7 @@ private[json] object InferSchema { /** * Convert NullType to StringType and remove StructTypes with no fields */ - private def canonicalizeType: DataType => Option[DataType] = { + private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) @@ -154,15 +154,15 @@ private[json] object InferSchema { } case StructType(fields) => - val canonicalFields = for { + val canonicalFields: Array[StructField] = for { field <- fields - if field.name.nonEmpty + if field.name.length > 0 canonicalType <- canonicalizeType(field.dataType) } yield { field.copy(dataType = canonicalType) } - if (canonicalFields.nonEmpty) { + if (canonicalFields.length > 0) { Some(StructType(canonicalFields)) } else { // per SPARK-8093: empty structs should be deleted @@ -217,10 +217,9 @@ private[json] object InferSchema { (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough // in most case, also have better precision. - case (DoubleType, t: DecimalType) => - DoubleType - case (t: DecimalType, DoubleType) => + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => DoubleType + case (t1: DecimalType, t2: DecimalType) => val scale = math.max(t1.scale, t2.scale) val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/cars-alternative.csv new file mode 100644 index 0000000000000..646f7c456c866 --- /dev/null +++ b/sql/core/src/test/resources/cars-alternative.csv @@ -0,0 +1,5 @@ +year|make|model|comment|blank +'2012'|'Tesla'|'S'| 'No comment'| + +1997|Ford|E350|'Go get one now they are going fast'| +2015|Chevy|Volt diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/cars-null.csv new file mode 100644 index 0000000000000..130c0b40bbe78 --- /dev/null +++ b/sql/core/src/test/resources/cars-null.csv @@ -0,0 +1,6 @@ +year,make,model,comment,blank +"2012","Tesla","S",null, + +1997,Ford,E350,"Go get one now they are going fast", +null,Chevy,Volt + diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/cars-unbalanced-quotes.csv new file mode 100644 index 0000000000000..5ea39fcbfadcc --- /dev/null +++ b/sql/core/src/test/resources/cars-unbalanced-quotes.csv @@ -0,0 +1,4 @@ +year,make,model,comment,blank +"2012,Tesla,S,No comment +1997,Ford,E350,Go get one now they are going fast" +"2015,"Chevy",Volt, diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/cars.csv new file mode 100644 index 0000000000000..2b9d74ca607ad --- /dev/null +++ b/sql/core/src/test/resources/cars.csv @@ -0,0 +1,6 @@ +year,make,model,comment,blank +"2012","Tesla","S","No comment", + +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt + diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/cars.tsv new file mode 100644 index 0000000000000..a7bfa9a91f961 --- /dev/null +++ b/sql/core/src/test/resources/cars.tsv @@ -0,0 +1,4 @@ +year make model price comment blank +2012 Tesla S "80,000.65" +1997 Ford E350 35,000 "Go get one now they are going fast" +2015 Chevy Volt 5,000.10 diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/cars_iso-8859-1.csv new file mode 100644 index 0000000000000..c51b6c59010f0 --- /dev/null +++ b/sql/core/src/test/resources/cars_iso-8859-1.csv @@ -0,0 +1,6 @@ +yearþmakeþmodelþcommentþblank +"2012"þ"Tesla"þ"S"þ"No comment"þ + +1997þFordþE350þ"Go get one now they are þoing fast"þ +2015þChevyþVolt + diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/comments.csv new file mode 100644 index 0000000000000..6275be7285b36 --- /dev/null +++ b/sql/core/src/test/resources/comments.csv @@ -0,0 +1,6 @@ +~ Version 1.0 +~ Using a non-standard comment char to test CSV parser defaults are overridden +1,2,3,4,5.01,2015-08-20 15:57:00 +6,7,8,9,0,2015-08-21 16:58:01 +~0,9,8,7,6,2015-08-22 17:59:02 +1,2,3,4,5,2015-08-23 18:00:42 diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/disable_comments.csv new file mode 100644 index 0000000000000..304d406e4d980 --- /dev/null +++ b/sql/core/src/test/resources/disable_comments.csv @@ -0,0 +1,2 @@ +#1,2,3 +4,5,6 diff --git a/sql/core/src/test/resources/empty.csv b/sql/core/src/test/resources/empty.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala new file mode 100644 index 0000000000000..a1796f1326007 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class InferSchemaSuite extends SparkFunSuite { + + test("String fields types are inferred correctly from null types") { + assert(CSVInferSchema.inferField(NullType, "") == NullType) + assert(CSVInferSchema.inferField(NullType, null) == NullType) + assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType) + assert(CSVInferSchema.inferField(NullType, "60") == IntegerType) + assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType) + assert(CSVInferSchema.inferField(NullType, "test") == StringType) + assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + } + + test("String fields types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType) + assert(CSVInferSchema.inferField(LongType, "test") == StringType) + assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, "test") == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + } + + test("Timestamp field types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) + } + + test("Type arrays are merged to highest common type") { + assert( + CSVInferSchema.mergeRowTypes(Array(StringType), + Array(DoubleType)).deep == Array(StringType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(IntegerType), + Array(LongType)).deep == Array(LongType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(DoubleType), + Array(LongType)).deep == Array(DoubleType).deep) + } + + test("Null fields are handled properly when a nullValue is specified") { + assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType) + assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType) + assert(CSVInferSchema.inferField(LongType, "null", "null") == LongType) + assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) + assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) + assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala new file mode 100644 index 0000000000000..c0c38c6787789 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite + +/** + * test cases for StringIteratorReader + */ +class CSVParserSuite extends SparkFunSuite { + + private def readAll(iter: Iterator[String]) = { + val reader = new StringIteratorReader(iter) + var c: Int = -1 + val read = new scala.collection.mutable.StringBuilder() + do { + c = reader.read() + read.append(c.toChar) + } while (c != -1) + + read.dropRight(1).toString + } + + private def readBufAll(iter: Iterator[String], bufSize: Int) = { + val reader = new StringIteratorReader(iter) + val cbuf = new Array[Char](bufSize) + val read = new scala.collection.mutable.StringBuilder() + + var done = false + do { // read all input one cbuf at a time + var numRead = 0 + var n = 0 + do { // try to fill cbuf + var off = 0 + var len = cbuf.length + n = reader.read(cbuf, off, len) + + if (n != -1) { + off += n + len -= n + } + + assert(len >= 0 && len <= cbuf.length) + assert(off >= 0 && off <= cbuf.length) + read.appendAll(cbuf.take(n)) + } while (n > 0) + if(n != -1) { + numRead += n + } else { + done = true + } + } while (!done) + + read.toString + } + + test("Hygiene") { + val reader = new StringIteratorReader(List("").toIterator) + assert(reader.ready === true) + assert(reader.markSupported === false) + intercept[IllegalArgumentException] { reader.skip(1) } + intercept[IllegalArgumentException] { reader.mark(1) } + intercept[IllegalArgumentException] { reader.reset() } + } + + test("Regular case") { + val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") + val read = readAll(input.toIterator) + assert(read === input.mkString("\n") ++ ("\n")) + } + + test("Empty iter") { + val input = List[String]() + val read = readAll(input.toIterator) + assert(read === "") + } + + test("Embedded new line") { + val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") + val read = readAll(input.toIterator) + assert(read === input.mkString("\n") ++ ("\n")) + } + + test("Buffer Regular case") { + val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") + val output = input.mkString("\n") ++ ("\n") + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, i) + assert(read === output) + } + } + + test("Buffer Empty iter") { + val input = List[String]() + val output = "" + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, 1) + assert(read === "") + } + } + + test("Buffer Embedded new line") { + val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") + val output = input.mkString("\n") ++ ("\n") + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, 1) + assert(read === output) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala new file mode 100644 index 0000000000000..8fdd31aa4334f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File +import java.nio.charset.UnsupportedCharsetException +import java.sql.Timestamp + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + private val carsFile = "cars.csv" + private val carsFile8859 = "cars_iso-8859-1.csv" + private val carsTsvFile = "cars.tsv" + private val carsAltFile = "cars-alternative.csv" + private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv" + private val carsNullFile = "cars-null.csv" + private val emptyFile = "empty.csv" + private val commentsFile = "comments.csv" + private val disableCommentsFile = "disable_comments.csv" + + private def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + + /** Verifies data and schema. */ + private def verifyCars( + df: DataFrame, + withHeader: Boolean, + numCars: Int = 3, + numFields: Int = 5, + checkHeader: Boolean = true, + checkValues: Boolean = true, + checkTypes: Boolean = false): Unit = { + + val numColumns = numFields + val numRows = if (withHeader) numCars else numCars + 1 + // schema + assert(df.schema.fieldNames.length === numColumns) + assert(df.collect().length === numRows) + + if (checkHeader) { + if (withHeader) { + assert(df.schema.fieldNames === Array("year", "make", "model", "comment", "blank")) + } else { + assert(df.schema.fieldNames === Array("C0", "C1", "C2", "C3", "C4")) + } + } + + if (checkValues) { + val yearValues = List("2012", "1997", "2015") + val actualYears = if (!withHeader) "year" :: yearValues else yearValues + val years = if (withHeader) df.select("year").collect() else df.select("C0").collect() + + years.zipWithIndex.foreach { case (year, index) => + if (checkTypes) { + assert(year === Row(actualYears(index).toInt)) + } else { + assert(year === Row(actualYears(index))) + } + } + } + } + + test("simple csv test") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("simple csv test with type inference") { + val cars = sqlContext + .read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = true, checkTypes = true) + } + + test("test with alternative delimiter and quote") { + val cars = sqlContext.read + .format("csv") + .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true")) + .load(testFile(carsAltFile)) + + verifyCars(cars, withHeader = true) + } + + test("bad encoding name") { + val exception = intercept[UnsupportedCharsetException] { + sqlContext + .read + .format("csv") + .option("charset", "1-9588-osi") + .load(testFile(carsFile8859)) + } + + assert(exception.getMessage.contains("1-9588-osi")) + } + + ignore("test different encoding") { + // scalastyle:off + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsFile8859)}", header "true", + |charset "iso-8859-1", delimiter "þ") + """.stripMargin.replaceAll("\n", " ")) + //scalstyle:on + + verifyCars(sqlContext.table("carsTable"), withHeader = true) + } + + test("DDL test with tab separated file") { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") + """.stripMargin.replaceAll("\n", " ")) + + verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + } + + test("DDL test parsing decimal type") { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, priceTag decimal, + | comments string, grp string) + |USING csv + |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") + """.stripMargin.replaceAll("\n", " ")) + + assert( + sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + } + + test("test for DROPMALFORMED parsing mode") { + val cars = sqlContext.read + .format("csv") + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) + + assert(cars.select("year").collect().size === 2) + } + + test("test for FAILFAST parsing mode") { + val exception = intercept[SparkException]{ + sqlContext.read + .format("csv") + .options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } + + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } + + test("test with null quote character") { + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .option("quote", "") + .load(testFile(carsUnbalancedQuotesFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + + } + + test("test with empty file and known schema") { + val result = sqlContext.read + .format("csv") + .schema(StructType(List(StructField("column", StringType, false)))) + .load(testFile(emptyFile)) + + assert(result.collect.size === 0) + assert(result.schema.fieldNames.size === 1) + } + + + test("DDL test with empty file") { + sqlContext.sql(s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, comments string, grp string) + |USING csv + |OPTIONS (path "${testFile(emptyFile)}", header "false") + """.stripMargin.replaceAll("\n", " ")) + + assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + } + + test("DDL test with schema") { + sqlContext.sql(s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, comments string, blank string) + |USING csv + |OPTIONS (path "${testFile(carsFile)}", header "true") + """.stripMargin.replaceAll("\n", " ")) + + val cars = sqlContext.table("carsTable") + verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) + assert( + cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) + } + + test("save csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .save(csvDir) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("save csv with quote") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("quote", "\"") + .save(csvDir) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .option("quote", "\"") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("commented lines in CSV data") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false")) + .load(testFile(commentsFile)) + .collect() + + val expected = + Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), + Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), + Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("inferring schema with commented lines in CSV data") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) + .load(testFile(commentsFile)) + .collect() + + val expected = + Seq(Seq(1, 2, 3, 4, 5.01D, Timestamp.valueOf("2015-08-20 15:57:00")), + Seq(6, 7, 8, 9, 0, Timestamp.valueOf("2015-08-21 16:58:01")), + Seq(1, 2, 3, 4, 5, Timestamp.valueOf("2015-08-23 18:00:42"))) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("setting comment to null disables comment support") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "", "header" -> "false")) + .load(testFile(disableCommentsFile)) + .collect() + + val expected = + Seq( + Seq("#1", "2", "3"), + Seq("4", "5", "6")) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("nullable fields with user defined null value of \"null\"") { + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = sqlContext.read + .format("csv") + .schema(dataSchema) + .options(Map("header" -> "true", "nullValue" -> "null")) + .load(testFile(carsNullFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) + assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala new file mode 100644 index 0000000000000..40c5ccd0f7a4a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.sql.{Date, Timestamp} +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class CSVTypeCastSuite extends SparkFunSuite { + + test("Can parse decimal type values") { + val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") + val decimalValues = Seq(10.05, 1000.01, 158058049.001) + val decimalType = new DecimalType() + + stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => + assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString)) + } + } + + test("Can parse escaped characters") { + assert(CSVTypeCast.toChar("""\t""") === '\t') + assert(CSVTypeCast.toChar("""\r""") === '\r') + assert(CSVTypeCast.toChar("""\b""") === '\b') + assert(CSVTypeCast.toChar("""\f""") === '\f') + assert(CSVTypeCast.toChar("""\"""") === '\"') + assert(CSVTypeCast.toChar("""\'""") === '\'') + assert(CSVTypeCast.toChar("""\u0000""") === '\u0000') + } + + test("Does not accept delimiter larger than one character") { + val exception = intercept[IllegalArgumentException]{ + CSVTypeCast.toChar("ab") + } + assert(exception.getMessage.contains("cannot be more than one character")) + } + + test("Throws exception for unsupported escaped characters") { + val exception = intercept[IllegalArgumentException]{ + CSVTypeCast.toChar("""\1""") + } + assert(exception.getMessage.contains("Unsupported special character for delimiter")) + } + + test("Nullable types are handled") { + assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null) + } + + test("String type should always return the same as the input") { + assert(CSVTypeCast.castTo("", StringType, nullable = true) == "") + assert(CSVTypeCast.castTo("", StringType, nullable = false) == "") + } + + test("Throws exception for empty string with non null type") { + val exception = intercept[NumberFormatException]{ + CSVTypeCast.castTo("", IntegerType, nullable = false) + } + assert(exception.getMessage.contains("For input string: \"\"")) + } + + test("Types are cast correctly") { + assert(CSVTypeCast.castTo("10", ByteType) == 10) + assert(CSVTypeCast.castTo("10", ShortType) == 10) + assert(CSVTypeCast.castTo("10", IntegerType) == 10) + assert(CSVTypeCast.castTo("10", LongType) == 10) + assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) + assert(CSVTypeCast.castTo("true", BooleanType) == true) + val timestamp = "2015-01-01 00:00:00" + assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp)) + assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01")) + } + + test("Float and Double Types are cast correctly with Locale") { + val locale : Locale = new Locale("fr", "FR") + Locale.setDefault(locale) + assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + } +} From ad1503f92e1f6e960a24f9f5d36b1735d1f5073a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 15 Jan 2016 12:03:28 -0800 Subject: [PATCH 1580/2089] [SPARK-12667] Remove block manager's internal "external block store" API This pull request removes the external block store API. This is rarely used, and the file system interface is actually a better, more standard way to interact with external storage systems. There are some other things to remove also, as pointed out by JoshRosen. We will do those as follow-up pull requests. Author: Reynold Xin Closes #10752 from rxin/remove-offheap. --- core/pom.xml | 27 -- .../scala/org/apache/spark/SparkContext.scala | 6 - .../spark/rdd/LocalRDDCheckpointData.scala | 8 +- .../org/apache/spark/status/api/v1/api.scala | 2 - .../apache/spark/storage/BlockManager.scala | 55 +-- .../spark/storage/BlockManagerMaster.scala | 6 +- .../storage/BlockManagerMasterEndpoint.scala | 41 +-- .../spark/storage/BlockManagerMessages.scala | 7 +- .../spark/storage/BlockStatusListener.scala | 9 +- .../spark/storage/BlockUpdatedInfo.scala | 6 +- .../spark/storage/ExternalBlockManager.scala | 122 ------- .../spark/storage/ExternalBlockStore.scala | 211 ------------ .../org/apache/spark/storage/RDDInfo.scala | 9 +- .../apache/spark/storage/StorageLevel.scala | 4 +- .../apache/spark/storage/StorageUtils.scala | 31 +- .../spark/storage/TachyonBlockManager.scala | 324 ------------------ .../apache/spark/ui/storage/StoragePage.scala | 8 - .../org/apache/spark/util/JsonProtocol.scala | 17 +- .../spark/memory/MemoryManagerSuite.scala | 4 +- .../spark/rdd/LocalCheckpointSuite.scala | 6 - .../spark/storage/BlockManagerSuite.scala | 26 +- .../storage/BlockStatusListenerSuite.scala | 18 +- .../storage/StorageStatusListenerSuite.scala | 18 +- .../apache/spark/storage/StorageSuite.scala | 97 +++--- .../spark/ui/storage/StoragePageSuite.scala | 60 +--- .../spark/ui/storage/StorageTabSuite.scala | 30 +- .../apache/spark/util/JsonProtocolSuite.scala | 32 +- dev/deps/spark-deps-hadoop-2.2 | 6 +- dev/deps/spark-deps-hadoop-2.3 | 4 - dev/deps/spark-deps-hadoop-2.4 | 4 - dev/deps/spark-deps-hadoop-2.6 | 4 - .../spark/examples/SparkTachyonHdfsLR.scala | 93 ----- .../spark/examples/SparkTachyonPi.scala | 50 --- project/MimaExcludes.scala | 6 +- 34 files changed, 139 insertions(+), 1212 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala diff --git a/core/pom.xml b/core/pom.xml index 3bec5debc2968..2071a58de92e4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -267,33 +267,6 @@ oro ${oro.version} - - org.tachyonproject - tachyon-client - 0.8.2 - - - org.apache.hadoop - hadoop-client - - - org.apache.curator - curator-client - - - org.apache.curator - curator-framework - - - org.apache.curator - curator-recipes - - - org.tachyonproject - tachyon-underfs-glusterfs - - - org.seleniumhq.selenium selenium-java diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 98075cef112db..77acb7052ddf5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -243,10 +243,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec - // Generate the random name for a temp folder in external block store. - // Add a timestamp as the suffix here to make it more safe - val externalBlockStoreFolderName = "spark-" + randomUUID.toString() - def isLocal: Boolean = (master == "local" || master.startsWith("local[")) /** @@ -423,8 +419,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - _conf.set("spark.externalBlockStore.folderName", externalBlockStoreFolderName) - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") // "_jobProgressListener" should be set up before creating SparkEnv because when creating diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index c115e0ff74d3c..dad90fc220849 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -72,12 +72,6 @@ private[spark] object LocalRDDCheckpointData { * This method is idempotent. */ def transformStorageLevel(level: StorageLevel): StorageLevel = { - // If this RDD is to be cached off-heap, fail fast since we cannot provide any - // correctness guarantees about subsequent computations after the first one - if (level.useOffHeap) { - throw new SparkException("Local checkpointing is not compatible with off-heap caching.") - } - StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 9cd52d6c2bef5..fe372116f1b69 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -85,8 +85,6 @@ class JobData private[spark]( val numSkippedStages: Int, val numFailedStages: Int) -// Q: should Tachyon size go in here as well? currently the UI only shows it on the overall storage -// page ... does anybody pay attention to it? class RDDStorageInfo private[spark]( val id: Int, val name: String, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 4479e6875a731..e49d79b8ad66e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -83,13 +83,8 @@ private[spark] class BlockManager( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) // Actual storage of where blocks are kept - private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, memoryManager) private[spark] val diskStore = new DiskStore(this, diskBlockManager) - private[spark] lazy val externalBlockStore: ExternalBlockStore = { - externalBlockStoreInitialized = true - new ExternalBlockStore(this, executorId) - } memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time. @@ -313,8 +308,7 @@ private[spark] class BlockManager( blockInfo.asScala.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L - // Assume that block is not in external block store - BlockStatus(info.level, memSize, diskSize, 0L) + BlockStatus(info.level, memSize = memSize, diskSize = diskSize) } } @@ -363,10 +357,8 @@ private[spark] class BlockManager( if (info.tellMaster) { val storageLevel = status.storageLevel val inMemSize = Math.max(status.memSize, droppedMemorySize) - val inExternalBlockStoreSize = status.externalBlockStoreSize val onDiskSize = status.diskSize - master.updateBlockInfo( - blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inExternalBlockStoreSize) + master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) } else { true } @@ -381,20 +373,17 @@ private[spark] class BlockManager( info.synchronized { info.level match { case null => - BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) + BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) - val inExternalBlockStore = level.useOffHeap && externalBlockStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false - val replication = if (inMem || inExternalBlockStore || onDisk) level.replication else 1 + val replication = if (inMem || onDisk) level.replication else 1 val storageLevel = - StorageLevel(onDisk, inMem, inExternalBlockStore, deserialized, replication) + StorageLevel(onDisk, inMem, deserialized, replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L - val externalBlockStoreSize = - if (inExternalBlockStore) externalBlockStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - BlockStatus(storageLevel, memSize, diskSize, externalBlockStoreSize) + BlockStatus(storageLevel, memSize, diskSize) } } } @@ -475,25 +464,6 @@ private[spark] class BlockManager( } } - // Look for the block in external block store - if (level.useOffHeap) { - logDebug(s"Getting block $blockId from ExternalBlockStore") - if (externalBlockStore.contains(blockId)) { - val result = if (asBlockResult) { - externalBlockStore.getValues(blockId) - .map(new BlockResult(_, DataReadMethod.Memory, info.size)) - } else { - externalBlockStore.getBytes(blockId) - } - result match { - case Some(values) => - return result - case None => - logDebug(s"Block $blockId not found in ExternalBlockStore") - } - } - } - // Look for block on disk, potentially storing it back in memory if required if (level.useDisk) { logDebug(s"Getting block $blockId from disk") @@ -786,9 +756,6 @@ private[spark] class BlockManager( // Put it in memory first, even if it also has useDisk set to true; // We will drop it to disk later if the memory store can't hold it. (true, memoryStore) - } else if (putLevel.useOffHeap) { - // Use external block store - (false, externalBlockStore) } else if (putLevel.useDisk) { // Don't get back the bytes from put unless we replicate them (putLevel.replication > 1, diskStore) @@ -909,8 +876,7 @@ private[spark] class BlockManager( val peersForReplication = new ArrayBuffer[BlockManagerId] val peersReplicatedTo = new ArrayBuffer[BlockManagerId] val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] - val tLevel = StorageLevel( - level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) val startTime = System.currentTimeMillis val random = new Random(blockId.hashCode) @@ -1120,9 +1086,7 @@ private[spark] class BlockManager( // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) - val removedFromExternalBlockStore = - if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false - if (!removedFromMemory && !removedFromDisk && !removedFromExternalBlockStore) { + if (!removedFromMemory && !removedFromDisk) { logWarning(s"Block $blockId could not be removed as it was not found in either " + "the disk, memory, or external block store") } @@ -1212,9 +1176,6 @@ private[spark] class BlockManager( blockInfo.clear() memoryStore.clear() diskStore.clear() - if (externalBlockStoreInitialized) { - externalBlockStore.clear() - } futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index da1de11d605c9..0b7aa599e9ded 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -54,11 +54,9 @@ class BlockManagerMaster( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long): Boolean = { + diskSize: Long): Boolean = { val res = driverEndpoint.askWithRetry[Boolean]( - UpdateBlockInfo(blockManagerId, blockId, storageLevel, - memSize, diskSize, externalBlockStoreSize)) + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logDebug(s"Updated info of block $blockId") res } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 4db400a3442ca..fbb3df8c3e90c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -59,10 +59,9 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case _updateBlockInfo @ UpdateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => - context.reply(updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + case _updateBlockInfo @ + UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + context.reply(updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)) listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => @@ -325,8 +324,7 @@ class BlockManagerMasterEndpoint( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long): Boolean = { + diskSize: Long): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { if (blockManagerId.isDriver && !isLocal) { @@ -343,8 +341,7 @@ class BlockManagerMasterEndpoint( return true } - blockManagerInfo(blockManagerId).updateBlockInfo( - blockId, storageLevel, memSize, diskSize, externalBlockStoreSize) + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) var locations: mutable.HashSet[BlockManagerId] = null if (blockLocations.containsKey(blockId)) { @@ -404,17 +401,13 @@ class BlockManagerMasterEndpoint( } @DeveloperApi -case class BlockStatus( - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) { - def isCached: Boolean = memSize + diskSize + externalBlockStoreSize > 0 +case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) { + def isCached: Boolean = memSize + diskSize > 0 } @DeveloperApi object BlockStatus { - def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L) } private[spark] class BlockManagerInfo( @@ -443,8 +436,7 @@ private[spark] class BlockManagerInfo( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) { + diskSize: Long) { updateLastSeenMs() @@ -468,7 +460,7 @@ private[spark] class BlockManagerInfo( * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ var blockStatus: BlockStatus = null if (storageLevel.useMemory) { - blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( @@ -476,17 +468,11 @@ private[spark] class BlockManagerInfo( Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } - if (storageLevel.useOffHeap) { - blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) - _blocks.put(blockId, blockStatus) - logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) - } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } @@ -504,11 +490,6 @@ private[spark] class BlockManagerInfo( logInfo("Removed %s on %s on disk (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) } - if (blockStatus.storageLevel.useOffHeap) { - logInfo("Removed %s on %s on externalBlockStore (size: %s)".format( - blockId, blockManagerId.hostPort, - Utils.bytesToString(blockStatus.externalBlockStoreSize))) - } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index f392a4a0cd9be..6bded92700504 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -63,12 +63,11 @@ private[spark] object BlockManagerMessages { var blockId: BlockId, var storageLevel: StorageLevel, var memSize: Long, - var diskSize: Long, - var externalBlockStoreSize: Long) + var diskSize: Long) extends ToBlockManagerMaster with Externalizable { - def this() = this(null, null, null, 0, 0, 0) // For deserialization only + def this() = this(null, null, null, 0, 0) // For deserialization only override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { blockManagerId.writeExternal(out) @@ -76,7 +75,6 @@ private[spark] object BlockManagerMessages { storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) - out.writeLong(externalBlockStoreSize) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -85,7 +83,6 @@ private[spark] object BlockManagerMessages { storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() - externalBlockStoreSize = in.readLong() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala index 2789e25b8d3ab..0a14fcadf53e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -26,8 +26,7 @@ private[spark] case class BlockUIData( location: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) + diskSize: Long) /** * The aggregated status of stream blocks in an executor @@ -41,8 +40,6 @@ private[spark] case class ExecutorStreamBlockStatus( def totalDiskSize: Long = blocks.map(_.diskSize).sum - def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum - def numStreamBlocks: Int = blocks.size } @@ -62,7 +59,6 @@ private[spark] class BlockStatusListener extends SparkListener { val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel val memSize = blockUpdated.blockUpdatedInfo.memSize val diskSize = blockUpdated.blockUpdatedInfo.diskSize - val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize synchronized { // Drop the update info if the block manager is not registered @@ -74,8 +70,7 @@ private[spark] class BlockStatusListener extends SparkListener { blockManagerId.hostPort, storageLevel, memSize, - diskSize, - externalBlockStoreSize) + diskSize) ) } else { // If isValid is not true, it means we should drop the block. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala index a5790e4454a89..e070bf658acb8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -30,8 +30,7 @@ case class BlockUpdatedInfo( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) + diskSize: Long) private[spark] object BlockUpdatedInfo { @@ -41,7 +40,6 @@ private[spark] object BlockUpdatedInfo { updateBlockInfo.blockId, updateBlockInfo.storageLevel, updateBlockInfo.memSize, - updateBlockInfo.diskSize, - updateBlockInfo.externalBlockStoreSize) + updateBlockInfo.diskSize) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala deleted file mode 100644 index f39325a12d244..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -/** - * An abstract class that the concrete external block manager has to inherit. - * The class has to have a no-argument constructor, and will be initialized by init, - * which is invoked by ExternalBlockStore. The main input parameter is blockId for all - * the methods, which is the unique identifier for Block in one Spark application. - * - * The underlying external block manager should avoid any name space conflicts among multiple - * Spark applications. For example, creating different directory for different applications - * by randomUUID - * - */ -private[spark] abstract class ExternalBlockManager { - - protected var blockManager: BlockManager = _ - - override def toString: String = {"External Block Store"} - - /** - * Initialize a concrete block manager implementation. Subclass should initialize its internal - * data structure, e.g, file system, in this function, which is invoked by ExternalBlockStore - * right after the class is constructed. The function should throw IOException on failure - * - * @throws java.io.IOException if there is any file system failure during the initialization. - */ - def init(blockManager: BlockManager, executorId: String): Unit = { - this.blockManager = blockManager - } - - /** - * Drop the block from underlying external block store, if it exists.. - * @return true on successfully removing the block - * false if the block could not be removed as it was not found - * - * @throws java.io.IOException if there is any file system failure in removing the block. - */ - def removeBlock(blockId: BlockId): Boolean - - /** - * Used by BlockManager to check the existence of the block in the underlying external - * block store. - * @return true if the block exists. - * false if the block does not exists. - * - * @throws java.io.IOException if there is any file system failure in checking - * the block existence. - */ - def blockExists(blockId: BlockId): Boolean - - /** - * Put the given block to the underlying external block store. Note that in normal case, - * putting a block should never fail unless something wrong happens to the underlying - * external block store, e.g., file system failure, etc. In this case, IOException - * should be thrown. - * - * @throws java.io.IOException if there is any file system failure in putting the block. - */ - def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit - - def putValues(blockId: BlockId, values: Iterator[_]): Unit = { - val bytes = blockManager.dataSerialize(blockId, values) - putBytes(blockId, bytes) - } - - /** - * Retrieve the block bytes. - * @return Some(ByteBuffer) if the block bytes is successfully retrieved - * None if the block does not exist in the external block store. - * - * @throws java.io.IOException if there is any file system failure in getting the block. - */ - def getBytes(blockId: BlockId): Option[ByteBuffer] - - /** - * Retrieve the block data. - * @return Some(Iterator[Any]) if the block data is successfully retrieved - * None if the block does not exist in the external block store. - * - * @throws java.io.IOException if there is any file system failure in getting the block. - */ - def getValues(blockId: BlockId): Option[Iterator[_]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) - } - - /** - * Get the size of the block saved in the underlying external block store, - * which is saved before by putBytes. - * @return size of the block - * 0 if the block does not exist - * - * @throws java.io.IOException if there is any file system failure in getting the block size. - */ - def getSize(blockId: BlockId): Long - - /** - * Clean up any information persisted in the underlying external block store, - * e.g., the directory, files, etc,which is invoked by the shutdown hook of ExternalBlockStore - * during system shutdown. - * - */ - def shutdown() -} diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala deleted file mode 100644 index 94883a54a74e4..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -import scala.util.control.NonFatal - -import org.apache.spark.Logging -import org.apache.spark.util.{ShutdownHookManager, Utils} - - -/** - * Stores BlockManager blocks on ExternalBlockStore. - * We capture any potential exception from underlying implementation - * and return with the expected failure value - */ -private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: String) - extends BlockStore(blockManager: BlockManager) with Logging { - - lazy val externalBlockManager: Option[ExternalBlockManager] = createBlkManager() - - logInfo("ExternalBlockStore started") - - override def getSize(blockId: BlockId): Long = { - try { - externalBlockManager.map(_.getSize(blockId)).getOrElse(0) - } catch { - case NonFatal(t) => - logError(s"Error in getSize($blockId)", t) - 0L - } - } - - override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { - putIntoExternalBlockStore(blockId, bytes, returnValues = true) - } - - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIntoExternalBlockStore(blockId, values.toIterator, returnValues) - } - - override def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIntoExternalBlockStore(blockId, values, returnValues) - } - - private def putIntoExternalBlockStore( - blockId: BlockId, - values: Iterator[_], - returnValues: Boolean): PutResult = { - logTrace(s"Attempting to put block $blockId into ExternalBlockStore") - // we should never hit here if externalBlockManager is None. Handle it anyway for safety. - try { - val startTime = System.currentTimeMillis - if (externalBlockManager.isDefined) { - externalBlockManager.get.putValues(blockId, values) - val size = getSize(blockId) - val data = if (returnValues) { - Left(getValues(blockId).get) - } else { - null - } - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(size), finishTime - startTime)) - PutResult(size, data) - } else { - logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured") - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } catch { - case NonFatal(t) => - logError(s"Error in putValues($blockId)", t) - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } - - private def putIntoExternalBlockStore( - blockId: BlockId, - bytes: ByteBuffer, - returnValues: Boolean): PutResult = { - logTrace(s"Attempting to put block $blockId into ExternalBlockStore") - // we should never hit here if externalBlockManager is None. Handle it anyway for safety. - try { - val startTime = System.currentTimeMillis - if (externalBlockManager.isDefined) { - val byteBuffer = bytes.duplicate() - byteBuffer.rewind() - externalBlockManager.get.putBytes(blockId, byteBuffer) - val size = bytes.limit() - val data = if (returnValues) { - Right(bytes) - } else { - null - } - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(size), finishTime - startTime)) - PutResult(size, data) - } else { - logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured") - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } catch { - case NonFatal(t) => - logError(s"Error in putBytes($blockId)", t) - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } - - // We assume the block is removed even if exception thrown - override def remove(blockId: BlockId): Boolean = { - try { - externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true) - } catch { - case NonFatal(t) => - logError(s"Error in removeBlock($blockId)", t) - true - } - } - - override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - try { - externalBlockManager.flatMap(_.getValues(blockId)) - } catch { - case NonFatal(t) => - logError(s"Error in getValues($blockId)", t) - None - } - } - - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - try { - externalBlockManager.flatMap(_.getBytes(blockId)) - } catch { - case NonFatal(t) => - logError(s"Error in getBytes($blockId)", t) - None - } - } - - override def contains(blockId: BlockId): Boolean = { - try { - val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false) - if (!ret) { - logInfo(s"Remove block $blockId") - blockManager.removeBlock(blockId, true) - } - ret - } catch { - case NonFatal(t) => - logError(s"Error in getBytes($blockId)", t) - false - } - } - - // Create concrete block manager and fall back to Tachyon by default for backward compatibility. - private def createBlkManager(): Option[ExternalBlockManager] = { - val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) - .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) - - try { - val instance = Utils.classForName(clsName) - .newInstance() - .asInstanceOf[ExternalBlockManager] - instance.init(blockManager, executorId) - ShutdownHookManager.addShutdownHook { () => - logDebug("Shutdown hook called") - externalBlockManager.map(_.shutdown()) - } - Some(instance) - } catch { - case NonFatal(t) => - logError("Cannot initialize external block store", t) - None - } - } -} - -private[spark] object ExternalBlockStore extends Logging { - val MAX_DIR_CREATION_ATTEMPTS = 10 - val SUB_DIRS_PER_DIR = "64" - val BASE_DIR = "spark.externalBlockStore.baseDir" - val FOLD_NAME = "spark.externalBlockStore.folderName" - val MASTER_URL = "spark.externalBlockStore.url" - val BLOCK_MANAGER_NAME = "spark.externalBlockStore.blockManager" - val DEFAULT_BLOCK_MANAGER_NAME = "org.apache.spark.storage.TachyonBlockManager" -} diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 673f7ad79def0..083d78b59ebee 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.Utils @DeveloperApi class RDDInfo( @@ -37,15 +37,14 @@ class RDDInfo( var diskSize = 0L var externalBlockStoreSize = 0L - def isCached: Boolean = - (memSize + diskSize + externalBlockStoreSize > 0) && numCachedPartitions > 0 + def isCached: Boolean = (memSize + diskSize > 0) && numCachedPartitions > 0 override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + - "MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s").format( + "MemorySize: %s; DiskSize: %s").format( name, id, storageLevel.toString, numCachedPartitions, numPartitions, - bytesToString(memSize), bytesToString(externalBlockStoreSize), bytesToString(diskSize)) + bytesToString(memSize), bytesToString(diskSize)) } override def compare(that: RDDInfo): Int = { diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 703bce3e6b85b..38e9534251c3a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -150,7 +150,9 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2) - val OFF_HEAP = new StorageLevel(false, false, true, false) + + // Redirect to MEMORY_ONLY_SER for now. + val OFF_HEAP = MEMORY_ONLY_SER /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index c4ac30092f807..8e2cfb2441f00 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -48,14 +48,14 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { * non-RDD blocks for the same reason. In particular, RDD storage information is stored * in a map indexed by the RDD ID to the following 4-tuple: * - * (memory size, disk size, off-heap size, storage level) + * (memory size, disk size, storage level) * * We assume that all the blocks that belong to the same RDD have the same storage level. * This field is not relevant to non-RDD blocks, however, so the storage information for * non-RDD blocks contains only the first 3 fields (in the same order). */ - private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, Long, StorageLevel)] - private var _nonRddStorageInfo: (Long, Long, Long) = (0L, 0L, 0L) + private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, StorageLevel)] + private var _nonRddStorageInfo: (Long, Long) = (0L, 0L) /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { @@ -177,20 +177,14 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the off-heap space used by this block manager. */ - def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) - /** Return the off-heap space used by the given RDD in this block manager in O(1) time. */ - def offHeapUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._3).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._4) + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._3) /** * Update the relevant storage info, taking into account any existing status for this block. @@ -199,34 +193,31 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val oldBlockStatus = getBlock(blockId).getOrElse(BlockStatus.empty) val changeInMem = newBlockStatus.memSize - oldBlockStatus.memSize val changeInDisk = newBlockStatus.diskSize - oldBlockStatus.diskSize - val changeInExternalBlockStore = - newBlockStatus.externalBlockStoreSize - oldBlockStatus.externalBlockStoreSize val level = newBlockStatus.storageLevel // Compute new info from old info - val (oldMem, oldDisk, oldExternalBlockStore) = blockId match { + val (oldMem, oldDisk) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, externalBlockStore, _) => (mem, disk, externalBlockStore) } - .getOrElse((0L, 0L, 0L)) + .map { case (mem, disk, _) => (mem, disk) } + .getOrElse((0L, 0L)) case _ => _nonRddStorageInfo } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) - val newExternalBlockStore = math.max(oldExternalBlockStore + changeInExternalBlockStore, 0L) // Set the correct info blockId match { case RDDBlockId(rddId, _) => // If this RDD is no longer persisted, remove it - if (newMem + newDisk + newExternalBlockStore == 0) { + if (newMem + newDisk == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, newExternalBlockStore, level) + _rddStorageInfo(rddId) = (newMem, newDisk, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk, newExternalBlockStore) + _nonRddStorageInfo = (newMem, newDisk) } } @@ -248,13 +239,11 @@ private[spark] object StorageUtils { val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum val memSize = statuses.map(_.memUsedByRdd(rddId)).sum val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - val externalBlockStoreSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum rddInfo.storageLevel = storageLevel rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.diskSize = diskSize - rddInfo.externalBlockStoreSize = externalBlockStoreSize } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala deleted file mode 100644 index 6aa7e13901779..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.IOException -import java.nio.ByteBuffer -import java.text.SimpleDateFormat -import java.util.{Date, Random} - -import scala.util.control.NonFatal - -import com.google.common.io.ByteStreams -import tachyon.{Constants, TachyonURI} -import tachyon.client.ClientContext -import tachyon.client.file.{TachyonFile, TachyonFileSystem} -import tachyon.client.file.TachyonFileSystem.TachyonFileSystemFactory -import tachyon.client.file.options.DeleteOptions -import tachyon.conf.TachyonConf -import tachyon.exception.{FileAlreadyExistsException, FileDoesNotExistException} - -import org.apache.spark.Logging -import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils - -/** - * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By - * default, one block is mapped to one file with a name given by its BlockId. - * - */ -private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging { - - var rootDirs: String = _ - var master: String = _ - var client: TachyonFileSystem = _ - private var subDirsPerTachyonDir: Int = _ - - // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; - // then, inside this directory, create multiple subdirectories that we will hash files into, - // in order to avoid having really large inodes at the top level in Tachyon. - private var tachyonDirs: Array[TachyonFile] = _ - private var subDirs: Array[Array[TachyonFile]] = _ - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - override def init(blockManager: BlockManager, executorId: String): Unit = { - super.init(blockManager, executorId) - val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon") - val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME) - - rootDirs = s"$storeDir/$appFolderName/$executorId" - master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") - client = if (master != null && master != "") { - val tachyonConf = new TachyonConf() - tachyonConf.set(Constants.MASTER_ADDRESS, master) - ClientContext.reset(tachyonConf) - TachyonFileSystemFactory.get - } else { - null - } - // original implementation call System.exit, we change it to run without extblkstore support - if (client == null) { - logError("Failed to connect to the Tachyon as the master address is not configured") - throw new IOException("Failed to connect to the Tachyon as the master " + - "address is not configured") - } - subDirsPerTachyonDir = blockManager.conf.get("spark.externalBlockStore.subDirectories", - ExternalBlockStore.SUB_DIRS_PER_DIR).toInt - - // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; - // then, inside this directory, create multiple subdirectories that we will hash files into, - // in order to avoid having really large inodes at the top level in Tachyon. - tachyonDirs = createTachyonDirs() - subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(registerShutdownDeleteDir) - } - - override def toString: String = {"ExternalBlockStore-Tachyon"} - - override def removeBlock(blockId: BlockId): Boolean = { - val file = getFile(blockId) - if (fileExists(file)) { - removeFile(file) - true - } else { - false - } - } - - override def blockExists(blockId: BlockId): Boolean = { - val file = getFile(blockId) - fileExists(file) - } - - override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { - val file = getFile(blockId) - val os = client.getOutStream(new TachyonURI(client.getInfo(file).getPath)) - try { - Utils.writeByteBuffer(bytes, os) - } catch { - case NonFatal(e) => - logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) - os.cancel() - } finally { - os.close() - } - } - - override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { - val file = getFile(blockId) - val os = client.getOutStream(new TachyonURI(client.getInfo(file).getPath)) - try { - blockManager.dataSerializeStream(blockId, os, values) - } catch { - case NonFatal(e) => - logWarning(s"Failed to put values of block $blockId into Tachyon", e) - os.cancel() - } finally { - os.close() - } - } - - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val file = getFile(blockId) - if (file == null) { - return None - } - val is = try { - client.getInStream(file) - } catch { - case _: FileDoesNotExistException => - return None - } - try { - val size = client.getInfo(file).length - val bs = new Array[Byte](size.asInstanceOf[Int]) - ByteStreams.readFully(is, bs) - Some(ByteBuffer.wrap(bs)) - } catch { - case NonFatal(e) => - logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) - None - } finally { - is.close() - } - } - - override def getValues(blockId: BlockId): Option[Iterator[_]] = { - val file = getFile(blockId) - if (file == null) { - return None - } - val is = try { - client.getInStream(file) - } catch { - case _: FileDoesNotExistException => - return None - } - try { - Some(blockManager.dataDeserializeStream(blockId, is)) - } finally { - is.close() - } - } - - override def getSize(blockId: BlockId): Long = { - client.getInfo(getFile(blockId.name)).length - } - - def removeFile(file: TachyonFile): Unit = { - client.delete(file) - } - - def fileExists(file: TachyonFile): Boolean = { - try { - client.getInfo(file) - true - } catch { - case _: FileDoesNotExistException => false - } - } - - def getFile(filename: String): TachyonFile = { - // Figure out which tachyon directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(filename) - val dirId = hash % tachyonDirs.length - val subDirId = (hash / tachyonDirs.length) % subDirsPerTachyonDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") - client.mkdir(path) - val newDir = client.loadMetadata(path) - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - val filePath = new TachyonURI(s"$subDir/$filename") - try { - client.create(filePath) - } catch { - case _: FileAlreadyExistsException => client.loadMetadata(filePath) - } - } - - def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name) - - // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore. - private def createTachyonDirs(): Array[TachyonFile] = { - logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var tachyonDir: TachyonFile = null - var tachyonDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") - try { - foundLocalDir = client.mkdir(path) - tachyonDir = client.loadMetadata(path) - } catch { - case _: FileAlreadyExistsException => // continue - } - } catch { - case NonFatal(e) => - logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS - + " attempts to create tachyon dir in " + rootDir) - System.exit(ExecutorExitCode.EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created tachyon directory at " + tachyonDir) - tachyonDir - } - } - - override def shutdown() { - logDebug("Shutdown hook called") - tachyonDirs.foreach { tachyonDir => - try { - if (!hasRootAsShutdownDeleteDir(tachyonDir)) { - deleteRecursively(tachyonDir, client) - } - } catch { - case NonFatal(e) => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) - } - } - } - - /** - * Delete a file or directory and its contents recursively. - */ - private def deleteRecursively(dir: TachyonFile, client: TachyonFileSystem) { - client.delete(dir, new DeleteOptions.Builder(ClientContext.getConf).setRecursive(true).build()) - } - - // Register the tachyon path to be deleted via shutdown hook - private def registerShutdownDeleteDir(file: TachyonFile) { - val absolutePath = client.getInfo(file).getPath - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Remove the tachyon path to be deleted via shutdown hook - private def removeShutdownDeleteDir(file: TachyonFile) { - val absolutePath = client.getInfo(file).getPath - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths -= absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - private def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = client.getInfo(file).getPath - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - private def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = client.getInfo(file).getPath - val hasRoot = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists( - path => !absolutePath.equals(path) && absolutePath.startsWith(path)) - } - if (hasRoot) { - logInfo(s"path = $absolutePath, already present as root for deletion.") - } - hasRoot - } - -} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 04f584621e71e..c9bb49b83e9cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -54,7 +54,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Cached Partitions", "Fraction Cached", "Size in Memory", - "Size in ExternalBlockStore", "Size on Disk") /** Render an HTML row representing an RDD */ @@ -71,7 +70,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - // scalastyle:on @@ -104,7 +102,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Executor ID", "Address", "Total Size in Memory", - "Total Size in ExternalBlockStore", "Total Size on Disk", "Stream Blocks") @@ -119,9 +116,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - @@ -195,8 +189,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { ("Memory", block.memSize) } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { ("Memory Serialized", block.memSize) - } else if (block.storageLevel.useOffHeap) { - ("External", block.externalBlockStoreSize) } else { throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a62fd2f339285..a6460bc8b8202 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -409,14 +409,12 @@ private[spark] object JsonProtocol { ("Number of Partitions" -> rddInfo.numPartitions) ~ ("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~ ("Memory Size" -> rddInfo.memSize) ~ - ("ExternalBlockStore Size" -> rddInfo.externalBlockStoreSize) ~ ("Disk Size" -> rddInfo.diskSize) } def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ - ("Use ExternalBlockStore" -> storageLevel.useOffHeap) ~ ("Deserialized" -> storageLevel.deserialized) ~ ("Replication" -> storageLevel.replication) } @@ -425,7 +423,6 @@ private[spark] object JsonProtocol { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ ("Memory Size" -> blockStatus.memSize) ~ - ("ExternalBlockStore Size" -> blockStatus.externalBlockStoreSize) ~ ("Disk Size" -> blockStatus.diskSize) } @@ -867,15 +864,11 @@ private[spark] object JsonProtocol { val numPartitions = (json \ "Number of Partitions").extract[Int] val numCachedPartitions = (json \ "Number of Cached Partitions").extract[Int] val memSize = (json \ "Memory Size").extract[Long] - // fallback to tachyon for backward compatibility - val externalBlockStoreSize = (json \ "ExternalBlockStore Size").toSome - .getOrElse(json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize - rddInfo.externalBlockStoreSize = externalBlockStoreSize rddInfo.diskSize = diskSize rddInfo } @@ -883,22 +876,16 @@ private[spark] object JsonProtocol { def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] - // fallback to tachyon for backward compatability - val useExternalBlockStore = (json \ "Use ExternalBlockStore").toSome - .getOrElse(json \ "Use Tachyon").extract[Boolean] val deserialized = (json \ "Deserialized").extract[Boolean] val replication = (json \ "Replication").extract[Int] - StorageLevel(useDisk, useMemory, useExternalBlockStore, deserialized, replication) + StorageLevel(useDisk, useMemory, deserialized, replication) } def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - // fallback to tachyon for backward compatability - val externalBlockStoreSize = (json \ "ExternalBlockStore Size").toSome - .getOrElse(json \ "Tachyon Size").extract[Long] - BlockStatus(storageLevel, memorySize, diskSize, externalBlockStoreSize) + BlockStatus(storageLevel, memorySize, diskSize) } def executorInfoFromJson(json: JValue): ExecutorInfo = { diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index f2924a6a5c052..3b2368798c1dd 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -102,14 +102,14 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // We can evict enough blocks to fulfill the request for space mm.releaseStorageMemory(numBytesToFree) args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( - (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))) // We need to add this call so that that the suite-level `evictedBlocks` is updated when // execution evicts storage; in that case, args.last will not be equal to evictedBlocks // because it will be a temporary buffer created inside of the MemoryManager rather than // being passed in by the test code. if (!(evictedBlocks eq args.last)) { evictedBlocks.append( - (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))) } true } else { diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index e694f5e5e7ad2..2802cd975292c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.rdd -import org.mockito.Mockito.spy - import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -46,10 +44,6 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { assert(transform(StorageLevel.MEMORY_AND_DISK_SER) === StorageLevel.MEMORY_AND_DISK_SER) assert(transform(StorageLevel.MEMORY_AND_DISK_2) === StorageLevel.MEMORY_AND_DISK_2) assert(transform(StorageLevel.MEMORY_AND_DISK_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2) - // Off-heap is not supported and Spark should fail fast - intercept[SparkException] { - transform(StorageLevel.OFF_HEAP) - } } test("basic lineage truncation") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 62e6c4f7932df..0f3156117004b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -121,11 +121,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } test("StorageLevel object caching") { - val level1 = StorageLevel(false, false, false, false, 3) + val level1 = StorageLevel(false, false, false, 3) // this should return the same object as level1 - val level2 = StorageLevel(false, false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return a different object - val level3 = StorageLevel(false, false, false, false, 2) + val level3 = StorageLevel(false, false, false, 2) assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -562,26 +562,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("tachyon storage") { - // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar. - val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) - conf.set(ExternalBlockStore.BLOCK_MANAGER_NAME, ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) - if (tachyonUnitTestEnabled) { - store = makeBlockManager(1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.OFF_HEAP) - store.putSingle("a2", a2, StorageLevel.OFF_HEAP) - store.putSingle("a3", a3, StorageLevel.OFF_HEAP) - assert(store.getSingle("a3").isDefined, "a3 was in store") - assert(store.getSingle("a2").isDefined, "a2 was in store") - assert(store.getSingle("a1").isDefined, "a1 was in store") - } else { - info("tachyon storage test disabled.") - } - } - test("on-disk storage") { store = makeBlockManager(1200) val a1 = new Array[Byte](400) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala index d7ffde1e7864e..06acca3943c20 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala @@ -34,16 +34,14 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0))) + diskSize = 100))) // The new block status should be added to the listener val expectedBlock = BlockUIData( StreamBlockId(0, 100), "localhost:10000", StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0 + diskSize = 100 ) val expectedExecutorStreamBlockStatus = Seq( ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) @@ -60,15 +58,13 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0))) + diskSize = 100))) val expectedBlock2 = BlockUIData( StreamBlockId(0, 100), "localhost:10001", StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0 + diskSize = 100 ) // Each block manager should contain one block val expectedExecutorStreamBlockStatus2 = Set( @@ -84,8 +80,7 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.NONE, // StorageLevel.NONE means removing it memSize = 0, - diskSize = 0, - externalBlockStoreSize = 0))) + diskSize = 0))) // Only the first block manager contains a block val expectedExecutorStreamBlockStatus3 = Set( ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), @@ -102,8 +97,7 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0))) + diskSize = 100))) // The second block manager is removed so we should not see the new block val expectedExecutorStreamBlockStatus4 = Seq( ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 1a199beb3558f..355d80d06898b 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -82,9 +82,9 @@ class StorageStatusListenerSuite extends SparkFunSuite { listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics - val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) - val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L, 0L)) - val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L)) + val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L)) + val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L)) + val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L)) taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) taskMetrics2.updatedBlocks = Some(Seq(block3)) @@ -105,9 +105,9 @@ class StorageStatusListenerSuite extends SparkFunSuite { assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) // Task end with dropped blocks - val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) - val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) - val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) + val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L)) + val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L)) + val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L)) taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3)) taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3)) @@ -130,9 +130,9 @@ class StorageStatusListenerSuite extends SparkFunSuite { listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics - val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) - val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L, 0L)) - val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L)) + val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L)) + val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L)) + val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L)) taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) taskMetrics2.updatedBlocks = Some(Seq(block3)) listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index 1d5a813a4d336..e5733aebf607c 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -33,10 +33,9 @@ class StorageSuite extends SparkFunSuite { assert(status.memUsed === 0L) assert(status.memRemaining === 1000L) assert(status.diskUsed === 0L) - assert(status.offHeapUsed === 0L) - status.addBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(TestBlockId("faa"), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status.addBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("faa"), BlockStatus(memAndDisk, 10L, 20L)) status } @@ -50,18 +49,16 @@ class StorageSuite extends SparkFunSuite { assert(status.memUsed === 30L) assert(status.memRemaining === 970L) assert(status.diskUsed === 60L) - assert(status.offHeapUsed === 3L) } test("storage status update non-RDD blocks") { val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L, 1L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L, 0L)) + status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) + status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) assert(status.blocks.size === 3) assert(status.memUsed === 160L) assert(status.memRemaining === 840L) assert(status.diskUsed === 140L) - assert(status.offHeapUsed === 2L) } test("storage status remove non-RDD blocks") { @@ -73,20 +70,19 @@ class StorageSuite extends SparkFunSuite { assert(status.memUsed === 10L) assert(status.memRemaining === 990L) assert(status.diskUsed === 20L) - assert(status.offHeapUsed === 1L) } // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) assert(status.rddBlocks.isEmpty) - status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L, 0L)) - status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L, 0L)) - status.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 100L, 200L, 1L)) - status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L, 0L)) - status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L, 0L)) + status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 100L, 200L)) + status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L)) status } @@ -113,9 +109,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsedByRdd(0) === 20L) assert(status.diskUsedByRdd(1) === 200L) assert(status.diskUsedByRdd(2) === 80L) - assert(status.offHeapUsedByRdd(0) === 1L) - assert(status.offHeapUsedByRdd(1) === 1L) - assert(status.offHeapUsedByRdd(2) === 1L) assert(status.rddStorageLevel(0) === Some(memAndDisk)) assert(status.rddStorageLevel(1) === Some(memAndDisk)) assert(status.rddStorageLevel(2) === Some(memAndDisk)) @@ -124,15 +117,14 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocksById(10).isEmpty) assert(status.memUsedByRdd(10) === 0L) assert(status.diskUsedByRdd(10) === 0L) - assert(status.offHeapUsedByRdd(10) === 0L) assert(status.rddStorageLevel(10) === None) } test("storage status update RDD blocks") { val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L, 0L)) + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) + status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) assert(status.blocks.size === 7) assert(status.rddBlocks.size === 5) assert(status.rddBlocksById(0).size === 1) @@ -144,9 +136,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsedByRdd(0) === 0L) assert(status.diskUsedByRdd(1) === 200L) assert(status.diskUsedByRdd(2) === 1060L) - assert(status.offHeapUsedByRdd(0) === 0L) - assert(status.offHeapUsedByRdd(1) === 1L) - assert(status.offHeapUsedByRdd(2) === 0L) } test("storage status remove RDD blocks") { @@ -170,9 +159,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsedByRdd(0) === 20L) assert(status.diskUsedByRdd(1) === 0L) assert(status.diskUsedByRdd(2) === 20L) - assert(status.offHeapUsedByRdd(0) === 1L) - assert(status.offHeapUsedByRdd(1) === 0L) - assert(status.offHeapUsedByRdd(2) === 0L) } test("storage status containsBlock") { @@ -209,17 +195,17 @@ class StorageSuite extends SparkFunSuite { val status = storageStatus2 assert(status.blocks.size === status.numBlocks) assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) + status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) + status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) assert(status.blocks.size === status.numBlocks) assert(status.rddBlocks.size === status.numRddBlocks) assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L, 400L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) + status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) + status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) + status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) assert(status.blocks.size === status.numBlocks) assert(status.rddBlocks.size === status.numRddBlocks) assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) @@ -244,29 +230,24 @@ class StorageSuite extends SparkFunSuite { val status = storageStatus2 def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum - def actualOffHeapUsed: Long = status.blocks.values.map(_.externalBlockStoreSize).sum assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) - status.addBlock(TestBlockId("fire"), BlockStatus(memAndDisk, 4000L, 5000L, 6000L)) - status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L, 600L)) - status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L, 60L)) + status.addBlock(TestBlockId("fire"), BlockStatus(memAndDisk, 4000L, 5000L)) + status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L)) + status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L, 6L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L, 6L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L, 6L)) + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) status.removeBlock(TestBlockId("fire")) status.removeBlock(TestBlockId("man")) status.removeBlock(RDDBlockId(2, 2)) status.removeBlock(RDDBlockId(2, 3)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -274,14 +255,14 @@ class StorageSuite extends SparkFunSuite { val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) - status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(0, 3), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status3.addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status3.addBlock(RDDBlockId(1, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) + status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(0, 3), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 1L, 2L)) + status3.addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) + status3.addBlock(RDDBlockId(1, 2), BlockStatus(memAndDisk, 1L, 2L)) Seq(status1, status2, status3) } @@ -334,9 +315,9 @@ class StorageSuite extends SparkFunSuite { test("StorageUtils.getRddBlockLocations with multiple locations") { val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L, 0L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) + storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) + storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) + storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) assert(blockLocations0.size === 5) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 3dab15a9d4691..350c174e24742 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ui.storage -import scala.xml.Utility - import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -64,26 +62,24 @@ class StoragePageSuite extends SparkFunSuite { "Cached Partitions", "Fraction Cached", "Size in Memory", - "Size in ExternalBlockStore", "Size on Disk") assert((xmlNodes \\ "th").map(_.text) === headers) assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B")) + Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B")) + Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === - Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B", - "500.0 B")) + Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=3")) @@ -98,16 +94,14 @@ class StoragePageSuite extends SparkFunSuite { "localhost:1111", StorageLevel.MEMORY_ONLY, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0) + diskSize = 0) assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), "localhost:1111", StorageLevel.MEMORY_ONLY_SER, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0) + diskSize = 0) assert(("Memory Serialized", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) @@ -115,18 +109,8 @@ class StoragePageSuite extends SparkFunSuite { "localhost:1111", StorageLevel.DISK_ONLY, memSize = 0, - diskSize = 100, - externalBlockStoreSize = 0) + diskSize = 100) assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) - - val externalBlock = BlockUIData(StreamBlockId(0, 0), - "localhost:1111", - StorageLevel.OFF_HEAP, - memSize = 0, - diskSize = 0, - externalBlockStoreSize = 100) - assert(("External", 100) === - storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock)) } test("receiverBlockTables") { @@ -135,14 +119,12 @@ class StoragePageSuite extends SparkFunSuite { "localhost:10000", StorageLevel.MEMORY_ONLY, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0), + diskSize = 0), BlockUIData(StreamBlockId(1, 1), "localhost:10000", StorageLevel.DISK_ONLY, memSize = 0, - diskSize = 100, - externalBlockStoreSize = 0) + diskSize = 100) ) val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) @@ -151,20 +133,12 @@ class StoragePageSuite extends SparkFunSuite { "localhost:10001", StorageLevel.MEMORY_ONLY, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0), - BlockUIData(StreamBlockId(2, 2), - "localhost:10001", - StorageLevel.OFF_HEAP, - memSize = 0, - diskSize = 0, - externalBlockStoreSize = 200), + diskSize = 0), BlockUIData(StreamBlockId(1, 1), "localhost:10001", StorageLevel.MEMORY_ONLY_SER, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0) + diskSize = 0) ) val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) @@ -174,16 +148,15 @@ class StoragePageSuite extends SparkFunSuite { "Executor ID", "Address", "Total Size in Memory", - "Total Size in ExternalBlockStore", "Total Size on Disk", "Stream Blocks") assert((executorTable \\ "th").map(_.text) === executorHeaders) assert((executorTable \\ "tr").size === 2) assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2")) + Seq("0", "localhost:10000", "100.0 B", "100.0 B", "2")) assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3")) + Seq("1", "localhost:10001", "200.0 B", "0.0 B", "2")) val blockTable = (xmlNodes \\ "table")(1) val blockHeaders = Seq( @@ -194,7 +167,7 @@ class StoragePageSuite extends SparkFunSuite { "Size") assert((blockTable \\ "th").map(_.text) === blockHeaders) - assert((blockTable \\ "tr").size === 5) + assert((blockTable \\ "tr").size === 4) assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) === Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B")) // Check "rowspan=2" for the first 2 columns @@ -212,17 +185,10 @@ class StoragePageSuite extends SparkFunSuite { assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) === Seq("localhost:10001", "Memory Serialized", "100.0 B")) - - assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) === - Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B")) - // Check "rowspan=1" for the first 2 columns - assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1")) - assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1")) } test("empty receiverBlockTables") { assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) - val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 4b838a8ab1335..5ac922c2172ce 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -128,20 +128,17 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { // Task end with a few new persisted blocks, some from the same RDD val metrics1 = new TaskMetrics metrics1.updatedBlocks = Some(Seq( - (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L, 0L)), - (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L, 0L)), - (RDDBlockId(0, 102), BlockStatus(memAndDisk, 400L, 0L, 200L)), - (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L, 0L)) + (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L)), + (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L)), + (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L)) )) bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1)) - assert(storageListener._rddInfoMap(0).memSize === 800L) + assert(storageListener._rddInfoMap(0).memSize === 400L) assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).externalBlockStoreSize === 200L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 3) + assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) assert(storageListener._rddInfoMap(0).isCached) assert(storageListener._rddInfoMap(1).memSize === 0L) assert(storageListener._rddInfoMap(1).diskSize === 240L) - assert(storageListener._rddInfoMap(1).externalBlockStoreSize === 0L) assert(storageListener._rddInfoMap(1).numCachedPartitions === 1) assert(storageListener._rddInfoMap(1).isCached) assert(!storageListener._rddInfoMap(2).isCached) @@ -150,16 +147,15 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { // Task end with a few dropped blocks val metrics2 = new TaskMetrics metrics2.updatedBlocks = Some(Seq( - (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L, 0L)), - (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L, 0L)), - (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L, 0L)), // doesn't actually exist - (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L, 0L)) // doesn't actually exist + (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L)), + (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L)), + (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L)), // doesn't actually exist + (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L)) // doesn't actually exist )) bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2)) - assert(storageListener._rddInfoMap(0).memSize === 400L) + assert(storageListener._rddInfoMap(0).memSize === 0L) assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).externalBlockStoreSize === 200L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) + assert(storageListener._rddInfoMap(0).numCachedPartitions === 1) assert(storageListener._rddInfoMap(0).isCached) assert(!storageListener._rddInfoMap(1).isCached) assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) @@ -175,8 +171,8 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details") val taskMetrics0 = new TaskMetrics val taskMetrics1 = new TaskMetrics - val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L)) - val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L)) + val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L)) + val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L)) taskMetrics0.updatedBlocks = Some(Seq(block0)) taskMetrics1.updatedBlocks = Some(Seq(block1)) bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 6566400e63799..068e8397c89bb 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -801,7 +801,7 @@ class JsonProtocolSuite extends SparkFunSuite { } // Make at most 6 blocks t.updatedBlocks = Some((1 to (e % 5 + 1)).map { i => - (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i, c%i)) + (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i)) }.toSeq) t } @@ -867,14 +867,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 201, | "Number of Cached Partitions": 301, | "Memory Size": 401, - | "ExternalBlockStore Size": 0, | "Disk Size": 501 | } | ], @@ -1063,12 +1061,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } @@ -1149,12 +1145,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } @@ -1235,12 +1229,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } @@ -1270,14 +1262,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 200, | "Number of Cached Partitions": 300, | "Memory Size": 400, - | "ExternalBlockStore Size": 0, | "Disk Size": 500 | } | ], @@ -1314,14 +1304,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 400, | "Number of Cached Partitions": 600, | "Memory Size": 800, - | "ExternalBlockStore Size": 0, | "Disk Size": 1000 | }, | { @@ -1332,14 +1320,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 401, | "Number of Cached Partitions": 601, | "Memory Size": 801, - | "ExternalBlockStore Size": 0, | "Disk Size": 1001 | } | ], @@ -1376,14 +1362,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 600, | "Number of Cached Partitions": 900, | "Memory Size": 1200, - | "ExternalBlockStore Size": 0, | "Disk Size": 1500 | }, | { @@ -1394,14 +1378,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 601, | "Number of Cached Partitions": 901, | "Memory Size": 1201, - | "ExternalBlockStore Size": 0, | "Disk Size": 1501 | }, | { @@ -1412,14 +1394,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 602, | "Number of Cached Partitions": 902, | "Memory Size": 1202, - | "ExternalBlockStore Size": 0, | "Disk Size": 1502 | } | ], @@ -1456,14 +1436,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 800, | "Number of Cached Partitions": 1200, | "Memory Size": 1600, - | "ExternalBlockStore Size": 0, | "Disk Size": 2000 | }, | { @@ -1474,14 +1452,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 801, | "Number of Cached Partitions": 1201, | "Memory Size": 1601, - | "ExternalBlockStore Size": 0, | "Disk Size": 2001 | }, | { @@ -1492,14 +1468,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 802, | "Number of Cached Partitions": 1202, | "Memory Size": 1602, - | "ExternalBlockStore Size": 0, | "Disk Size": 2002 | }, | { @@ -1510,14 +1484,12 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 803, | "Number of Cached Partitions": 1203, | "Memory Size": 1603, - | "ExternalBlockStore Size": 0, | "Disk Size": 2003 | } | ], @@ -1723,12 +1695,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index fb2e91e1ee4b0..0760529b578ee 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -35,7 +35,7 @@ commons-configuration-1.6.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar -commons-io-2.4.jar +commons-io-2.1.jar commons-lang-2.6.jar commons-lang3-3.3.2.jar commons-logging-1.1.3.jar @@ -179,10 +179,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -tachyon-client-0.8.2.jar -tachyon-underfs-hdfs-0.8.2.jar -tachyon-underfs-local-0.8.2.jar -tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 59e4d4f839788..191f2a0e4e86f 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -170,10 +170,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -tachyon-client-0.8.2.jar -tachyon-underfs-hdfs-0.8.2.jar -tachyon-underfs-local-0.8.2.jar -tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index e4395c872c230..9134e997c8457 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -171,10 +171,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -tachyon-client-0.8.2.jar -tachyon-underfs-hdfs-0.8.2.jar -tachyon-underfs-local-0.8.2.jar -tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 89fd15da7d0b3..8c45832873848 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -177,10 +177,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -tachyon-client-0.8.2.jar -tachyon-underfs-hdfs-0.8.2.jar -tachyon-underfs-local-0.8.2.jar -tachyon-underfs-s3-0.8.2.jar uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala deleted file mode 100644 index 8b739c9d7c1db..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples - -import java.util.Random - -import scala.math.exp - -import breeze.linalg.{DenseVector, Vector} -import org.apache.hadoop.conf.Configuration - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -/** - * Logistic regression based classification. - * This example uses Tachyon to persist rdds during computation. - * - * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. - */ -object SparkTachyonHdfsLR { - val D = 10 // Numer of dimensions - val rand = new Random(42) - - def showWarning() { - System.err.println( - """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS - |for more conventional use. - """.stripMargin) - } - - case class DataPoint(x: Vector[Double], y: Double) - - def parsePoint(line: String): DataPoint = { - val tok = new java.util.StringTokenizer(line, " ") - var y = tok.nextToken.toDouble - var x = new Array[Double](D) - var i = 0 - while (i < D) { - x(i) = tok.nextToken.toDouble; i += 1 - } - DataPoint(new DenseVector(x), y) - } - - def main(args: Array[String]) { - - showWarning() - - val inputPath = args(0) - val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") - val conf = new Configuration() - val sc = new SparkContext(sparkConf) - val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint).persist(StorageLevel.OFF_HEAP) - val ITERATIONS = args(1).toInt - - // Initialize w to a random value - var w = DenseVector.fill(D){2 * rand.nextDouble - 1} - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - val gradient = points.map { p => - p.x * (1 / (1 + exp(-p.y * (w.dot(p.x)))) - 1) * p.y - }.reduce(_ + _) - w -= gradient - } - - println("Final w: " + w) - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala deleted file mode 100644 index e46ac655beb58..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples - -import scala.math.random - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -/** - * Computes an approximation to pi - * This example uses Tachyon to persist rdds during computation. - */ -object SparkTachyonPi { - def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("SparkTachyonPi") - val spark = new SparkContext(sparkConf) - - val slices = if (args.length > 0) args(0).toInt else 2 - val n = 100000 * slices - - val rdd = spark.parallelize(1 to n, slices) - rdd.persist(StorageLevel.OFF_HEAP) - val count = rdd.map { i => - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x * x + y * y < 1) 1 else 0 - }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / n) - - spark.stop() - } -} -// scalastyle:on println diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4206d1fada421..ccd3c34bb5c8c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -120,7 +120,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") ) ++ // SPARK-12665 Remove deprecated and unused classes Seq( From 513266c0426faf8b9cc7576963bb9a57d2fdbc54 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 15 Jan 2016 13:17:29 -0800 Subject: [PATCH 1581/2089] [SPARK-12833][HOT-FIX] Fix scala 2.11 compilation. Seems https://github.com/apache/spark/commit/5f83c6991c95616ecbc2878f8860c69b2826f56c breaks scala 2.11 compilation. Author: Yin Huai Closes #10774 from yhuai/fixScala211Compile. --- .../spark/sql/execution/datasources/csv/CSVParameters.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala index ba44121244163..ec16bdbd8bfb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -27,7 +27,7 @@ private[sql] case class CSVParameters(parameters: Map[String, String]) extends L val paramValue = parameters.get(paramName) paramValue match { case None => default - case Some(value) if value.length == 0 => '\0' + case Some(value) if value.length == 0 => '\u0000' case Some(value) if value.length == 1 => value.charAt(0) case _ => throw new RuntimeException(s"$paramName cannot be more than one character") } @@ -50,7 +50,7 @@ private[sql] case class CSVParameters(parameters: Map[String, String]) extends L val quote = getChar("quote", '\"') val escape = getChar("escape", '\\') - val comment = getChar("comment", '\0') + val comment = getChar("comment", '\u0000') val headerFlag = getBool("header") val inferSchemaFlag = getBool("inferSchema") @@ -77,7 +77,7 @@ private[sql] case class CSVParameters(parameters: Map[String, String]) extends L val inputBufferSize = 128 - val isCommentSet = this.comment != '\0' + val isCommentSet = this.comment != '\u0000' val rowSeparator = "\n" } From 0bb73554a96088adca9dcd62a5a37b13772b02d8 Mon Sep 17 00:00:00 2001 From: Julien Baley Date: Fri, 15 Jan 2016 13:53:20 -0800 Subject: [PATCH 1582/2089] Fix typo disvoered => discovered Author: Julien Baley Closes #10773 from julienbaley/patch-1. --- .../spark/sql/execution/datasources/PartitioningUtils.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 81962f8d63789..65a715caf1cee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -102,11 +102,11 @@ private[sql] object PartitioningUtils { // It will be recognised as conflicting directory structure: // "hdfs://host:9000/invalidPath" // "hdfs://host:9000/path" - val disvoeredBasePaths = optDiscoveredBasePaths.flatMap(x => x) + val discoveredBasePaths = optDiscoveredBasePaths.flatMap(x => x) assert( - disvoeredBasePaths.distinct.size == 1, + discoveredBasePaths.distinct.size == 1, "Conflicting directory structures detected. Suspicious paths:\b" + - disvoeredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + + discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + "If provided paths are partition directories, please set " + "\"basePath\" in the options of the data source to specify the " + "root directory of the table. If there are multiple root directories, " + From 61c45876fb532cdb7278dea48cc141208b63737c Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Fri, 15 Jan 2016 16:03:21 -0600 Subject: [PATCH 1583/2089] [SPARK-12716][WEB UI] Add a TOTALS row to the Executors Web UI Added a Totals table to the top of the page to display the totals of each applicable column in the executors table. Old Description: ~~Created a TOTALS row containing the totals of each column in the executors UI. By default the TOTALS row appears at the top of the table. When a column is sorted the TOTALS row will always sort to either the top or bottom of the table.~~ Author: Alex Bozarth Closes #10668 from ajbozarth/spark12716. --- .../apache/spark/ui/exec/ExecutorsPage.scala | 74 ++++++++++++++++--- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 7072a152d6b69..440dfa2679563 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -62,9 +62,6 @@ private[ui] class ExecutorsPage( } (_storageStatusList, _execInfo) } - val maxMem = storageStatusList.map(_.maxMem).sum - val memUsed = storageStatusList.map(_.memUsed).sum - val diskUsed = storageStatusList.map(_.diskUsed).sum val execInfoSorted = execInfo.sortBy(_.id) val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty @@ -100,18 +97,15 @@ private[ui] class ExecutorsPage(
    Property NameDefaultMeaning
    spark.mesos.coarsefalsetrue - If set to true, runs over Mesos clusters in - "coarse-grained" sharing mode, - where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per - Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use - for the whole duration of the Spark job. + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine. + If set to false, runs over Mesos cluster in "fine-grained" sharing mode, where one Mesos task is created per Spark task. + Detailed information in 'Mesos Run Modes'.
    {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2cc6c75a9ac12..6d4066a870cdd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -100,7 +100,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) - val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") + val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("Index") val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index dd8d5ec27f87e..bc8a5d494dbd3 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -67,6 +67,20 @@ class UIUtilsSuite extends SparkFunSuite { s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") } + test("decodeURLParameter (SPARK-12708: Sorting task error in Stages Page when yarn mode.)") { + val encoded1 = "%252F" + val decoded1 = "/" + val encoded2 = "%253Cdriver%253E" + val decoded2 = "" + + assert(decoded1 === decodeURLParameter(encoded1)) + assert(decoded2 === decodeURLParameter(encoded2)) + + // verify that no affect to decoded URL. + assert(decoded1 === decodeURLParameter(decoded1)) + assert(decoded2 === decodeURLParameter(decoded2)) + } + private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) From 591c88c9e2a6c2e2ca84f1b66c635f198a16d112 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Jan 2016 21:02:18 -0800 Subject: [PATCH 1573/2089] [SPARK-12829] Turn Java style checker on It was previously turned off because there was a problem with a pull request. We should turn it on now. Author: Reynold Xin Closes #10763 from rxin/SPARK-12829. --- dev/run-tests.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 795db0dcfbab9..c1646c77f1e53 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -537,8 +537,7 @@ def main(): or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - # run_java_style_checks() - pass + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): From fe7246fea67e1d71fba679dee3eb2c7386b4f4e2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Jan 2016 23:33:45 -0800 Subject: [PATCH 1574/2089] [SPARK-12830] Java style: disallow trailing whitespaces. Author: Reynold Xin Closes #10764 from rxin/SPARK-12830. --- checkstyle.xml | 8 +++++++- .../java/org/apache/spark/examples/JavaLogQuery.java | 2 +- .../java/org/apache/spark/examples/JavaSparkPi.java | 2 +- .../spark/mllib/random/JavaRandomRDDsSuite.java | 2 +- .../apache/spark/network/buffer/LazyFileRegion.java | 2 +- .../java/org/apache/spark/network/util/ByteUnit.java | 12 ++++++------ .../test/resources/data/conf/hive-log4j.properties | 2 +- .../apache/spark/unsafe/types/UTF8StringSuite.java | 2 +- 8 files changed, 19 insertions(+), 13 deletions(-) diff --git a/checkstyle.xml b/checkstyle.xml index a493ee443c752..b5d1617ba4ede 100644 --- a/checkstyle.xml +++ b/checkstyle.xml @@ -58,6 +58,12 @@ + + + + + + @@ -84,7 +90,7 @@ - + diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 812e9d5580cbf..0448a1a0c8cf3 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -35,7 +35,7 @@ /** * Executes a roll up-style query against Apache logs. - * + * * Usage: JavaLogQuery [logFile] */ public final class JavaLogQuery { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 0f07cb4098325..af874887445b1 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -26,7 +26,7 @@ import java.util.ArrayList; import java.util.List; -/** +/** * Computes an approximation to pi * Usage: JavaSparkPi [slices] */ diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 5728df5aeebdc..be58691f4d87e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -166,7 +166,7 @@ public void testNormalVectorRDD() { @SuppressWarnings("unchecked") public void testLogNormalVectorRDD() { double mean = 4.0; - double std = 2.0; + double std = 2.0; long m = 100L; int n = 10; int p = 2; diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java index 81bc8ec40fc82..162cf6da0dffe 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java @@ -32,7 +32,7 @@ /** * A FileRegion implementation that only creates the file descriptor when the region is being * transferred. This cannot be used with Epoll because there is no native support for it. - * + * * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we * should push this into Netty so the native Epoll transport can support this feature. */ diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java index 36d655017fb0d..a2f018373f2a4 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -33,8 +33,8 @@ private ByteUnit(long multiplier) { public long convertFrom(long d, ByteUnit u) { return u.convertTo(d, this); } - - // Convert the provided number (d) interpreted as this unit type to unit type (u). + + // Convert the provided number (d) interpreted as this unit type to unit type (u). public long convertTo(long d, ByteUnit u) { if (multiplier > u.multiplier) { long ratio = multiplier / u.multiplier; @@ -44,7 +44,7 @@ public long convertTo(long d, ByteUnit u) { } return d * ratio; } else { - // Perform operations in this order to avoid potential overflow + // Perform operations in this order to avoid potential overflow // when computing d * multiplier return d / (u.multiplier / multiplier); } @@ -54,14 +54,14 @@ public double toBytes(long d) { if (d < 0) { throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); } - return d * multiplier; + return d * multiplier; } - + public long toKiB(long d) { return convertTo(d, KiB); } public long toMiB(long d) { return convertTo(d, MiB); } public long toGiB(long d) { return convertTo(d, GiB); } public long toTiB(long d) { return convertTo(d, TiB); } public long toPiB(long d) { return convertTo(d, PiB); } - + private final long multiplier; } diff --git a/sql/hive/src/test/resources/data/conf/hive-log4j.properties b/sql/hive/src/test/resources/data/conf/hive-log4j.properties index 885c86f2b94f4..6a042472adb90 100644 --- a/sql/hive/src/test/resources/data/conf/hive-log4j.properties +++ b/sql/hive/src/test/resources/data/conf/hive-log4j.properties @@ -47,7 +47,7 @@ log4j.appender.DRFA.layout.ConversionPattern=%d{ISO8601} %-5p %c{2} (%F:%M(%L)) # # console -# Add "console" to rootlogger above if you want to use this +# Add "console" to rootlogger above if you want to use this # log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e21ffdcff9abf..bef5d712cfca5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -378,7 +378,7 @@ public void split() { assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), new UTF8String[]{fromString("ab"), fromString("def,ghi")})); } - + @Test public void levenshteinDistance() { assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8)); From d0a5c32bd05841f411a342a80c5da9f73f30d69a Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Fri, 15 Jan 2016 12:04:05 +0000 Subject: [PATCH 1575/2089] [SPARK-12655][GRAPHX] GraphX does not unpersist RDDs Some VertexRDD and EdgeRDD are created during the intermediate step of g.connectedComponents() but unnecessarily left cached after the method is done. The fix is to unpersist these RDDs once they are no longer in use. A test case is added to confirm the fix for the reported bug. Author: Jason Lee Closes #10713 from jasoncl/SPARK-12655. --- .../scala/org/apache/spark/graphx/Pregel.scala | 2 +- .../spark/graphx/lib/ConnectedComponents.scala | 4 +++- .../org/apache/spark/graphx/GraphSuite.scala | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index b908860310093..796082721d696 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -151,7 +151,7 @@ object Pregel extends Logging { // count the iteration i += 1 } - + messages.unpersist(blocking = false) g } // end of apply diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index 859f896039047..f72cbb15242ec 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -47,9 +47,11 @@ object ConnectedComponents { } } val initialMessage = Long.MaxValue - Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)( + val pregelGraph = Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)( vprog = (id, attr, msg) => math.min(attr, msg), sendMsg = sendMessage, mergeMsg = (a, b) => math.min(a, b)) + ccGraph.unpersist() + pregelGraph } // end of connectedComponents } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 1f5e27d5508b8..2fbc6f069d48d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -428,4 +428,20 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { } } + test("unpersist graph RDD") { + withSpark { sc => + val vert = sc.parallelize(List((1L, "a"), (2L, "b"), (3L, "c")), 1) + val edges = sc.parallelize(List(Edge[Long](1L, 2L), Edge[Long](1L, 3L)), 1) + val g0 = Graph(vert, edges) + val g = g0.partitionBy(PartitionStrategy.EdgePartition2D, 2) + val cc = g.connectedComponents() + assert(sc.getPersistentRDDs.nonEmpty) + cc.unpersist() + g.unpersist() + g0.unpersist() + vert.unpersist() + edges.unpersist() + assert(sc.getPersistentRDDs.isEmpty) + } + } } From 96fb894d4b33e293625fa92bbeccbbf5e688015e Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Fri, 15 Jan 2016 13:11:27 +0000 Subject: [PATCH 1576/2089] =?UTF-8?q?[SPARK-2930]=20clarify=20docs=20on=20?= =?UTF-8?q?using=20webhdfs=20with=20spark.yarn.access.nam=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …enodes Author: Tom Graves Closes #10699 from tgravescs/SPARK-2930. --- docs/running-on-yarn.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 06413f83c3a71..a148c867eb18f 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -260,10 +260,10 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A comma-separated list of secure HDFS namenodes your Spark application is going to access. For - example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. - The Spark application must have access to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the namenodes listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters.
    {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)}{Utils.bytesToString(rdd.externalBlockStoreSize)} {Utils.bytesToString(rdd.diskSize)}
    {Utils.bytesToString(status.totalMemSize)} - {Utils.bytesToString(status.totalExternalBlockStoreSize)} - {Utils.bytesToString(status.totalDiskSize)}
    val content = -
    +
    -
      -
    • Memory: - {Utils.bytesToString(memUsed)} Used - ({Utils.bytesToString(maxMem)} Total)
    • -
    • Disk: {Utils.bytesToString(diskUsed)} Used
    • -
    +

    Totals for {execInfo.size} Executors

    + {execSummary(execInfo)}
    +

    Active Executors

    {execTable}
    ; @@ -179,6 +173,66 @@ private[ui] class ExecutorsPage( } + private def execSummary(execInfo: Seq[ExecutorSummary]): Seq[Node] = { + val maximumMemory = execInfo.map(_.maxMemory).sum + val memoryUsed = execInfo.map(_.memoryUsed).sum + val diskUsed = execInfo.map(_.diskUsed).sum + val totalDuration = execInfo.map(_.totalDuration).sum + val totalInputBytes = execInfo.map(_.totalInputBytes).sum + val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum + val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum + + val sumContent = + + {execInfo.map(_.rddBlocks).sum} + + {Utils.bytesToString(memoryUsed)} / + {Utils.bytesToString(maximumMemory)} + + + {Utils.bytesToString(diskUsed)} + + {execInfo.map(_.activeTasks).sum} + {execInfo.map(_.failedTasks).sum} + {execInfo.map(_.completedTasks).sum} + {execInfo.map(_.totalTasks).sum} + + {Utils.msDurationToString(totalDuration)} + + + {Utils.bytesToString(totalInputBytes)} + + + {Utils.bytesToString(totalShuffleRead)} + + + {Utils.bytesToString(totalShuffleWrite)} + + ; + + + + + + + + + + + + + + + + + {sumContent} + +
    RDD BlocksStorage MemoryDisk UsedActive TasksFailed TasksComplete TasksTotal TasksTask TimeInputShuffle Read + + Shuffle Write + +
    + } } private[spark] object ExecutorsPage { From 3f1c58d60b85625ab3abf16456ce27c820453ecf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Jan 2016 14:20:22 -0800 Subject: [PATCH 1584/2089] [SQL][MINOR] BoundReference do not need to be NamedExpression We made it a `NamedExpression` to workaroud some hacky cases long time ago, and now seems it's safe to remove it. Author: Wenchen Fan Closes #10765 from cloud-fan/minor. --- .../sql/catalyst/expressions/BoundAttribute.scala | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 397abc7391ec6..dda822d05485b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends LeafExpression with NamedExpression { + extends LeafExpression { override def toString: String = s"input[$ordinal, ${dataType.simpleString}]" @@ -58,16 +58,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } } - override def name: String = s"i[$ordinal]" - - override def toAttribute: Attribute = throw new UnsupportedOperationException - - override def qualifiers: Seq[String] = throw new UnsupportedOperationException - - override def exprId: ExprId = throw new UnsupportedOperationException - - override def newInstance(): NamedExpression = this - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) From 7cd7f2202547224593517b392f56e49e4c94cabc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 15 Jan 2016 15:19:10 -0800 Subject: [PATCH 1585/2089] [SPARK-12575][SQL] Grammar parity with existing SQL parser In this PR the new CatalystQl parser stack reaches grammar parity with the old Parser-Combinator based SQL Parser. This PR also replaces all uses of the old Parser, and removes it from the code base. Although the existing Hive and SQL parser dialects were mostly the same, some kinks had to be worked out: - The SQL Parser allowed syntax like ```APPROXIMATE(0.01) COUNT(DISTINCT a)```. In order to make this work we needed to hardcode approximate operators in the parser, or we would have to create an approximate expression. ```APPROXIMATE_COUNT_DISTINCT(a, 0.01)``` would also do the job and is much easier to maintain. So, this PR **removes** this keyword. - The old SQL Parser supports ```LIMIT``` clauses in nested queries. This is **not supported** anymore. See https://github.com/apache/spark/pull/10689 for the rationale for this. - Hive has a charset name char set literal combination it supports, for instance the following expression ```_ISO-8859-1 0x4341464562616265``` would yield this string: ```CAFEbabe```. Hive will only allow charset names to start with an underscore. This is quite annoying in spark because as soon as you use a tuple names will start with an underscore. In this PR we **remove** this feature from the parser. It would be quite easy to implement such a feature as an Expression later on. - Hive and the SQL Parser treat decimal literals differently. Hive will turn any decimal into a ```Double``` whereas the SQL Parser would convert a non-scientific decimal into a ```BigDecimal```, and would turn a scientific decimal into a Double. We follow Hive's behavior here. The new parser supports a big decimal literal, for instance: ```81923801.42BD```, which can be used when a big decimal is needed. cc rxin viirya marmbrus yhuai cloud-fan Author: Herman van Hovell Closes #10745 from hvanhovell/SPARK-12575-2. --- python/pyspark/sql/tests.py | 3 +- .../sql/catalyst/parser/ExpressionParser.g | 57 +- .../sql/catalyst/parser/FromClauseParser.g | 7 +- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 32 +- .../sql/catalyst/parser/SparkSqlParser.g | 10 +- .../spark/sql/catalyst/parser/ParseUtils.java | 31 +- .../sql/catalyst/AbstractSparkSQLParser.scala | 4 +- .../spark/sql/catalyst/CatalystQl.scala | 139 +++-- .../spark/sql/catalyst/ParserDialect.scala | 46 +- .../apache/spark/sql/catalyst/SqlParser.scala | 509 ------------------ .../aggregate/HyperLogLogPlusPlus.scala | 1 + .../spark/sql/catalyst/CatalystQlSuite.scala | 1 + .../spark/sql/catalyst/SqlParserSuite.scala | 150 ------ .../scala/org/apache/spark/sql/Column.scala | 1 - .../org/apache/spark/sql/DataFrame.scala | 8 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 17 +- .../spark/sql/execution/SparkSQLParser.scala | 19 +- .../sql/execution/datasources/DDLParser.scala | 17 +- .../org/apache/spark/sql/functions.scala | 7 +- .../spark/sql/MathExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 73 +-- .../datasources/json/JsonSuite.scala | 12 +- .../execution/HiveCompatibilitySuite.scala | 4 +- .../spark/sql/hive/ExtendedHiveQlParser.scala | 18 +- .../apache/spark/sql/hive/HiveContext.scala | 21 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 7 +- .../apache/spark/sql/hive/HiveQlSuite.scala | 9 +- .../spark/sql/hive/StatisticsSuite.scala | 5 +- .../sql/hive/execution/SQLQuerySuite.scala | 20 +- .../spark/unsafe/types/CalendarInterval.java | 14 + 33 files changed, 286 insertions(+), 972 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e396cf41f2f7b..c03cb9338ae68 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1081,8 +1081,7 @@ def test_replace(self): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - # RuntimeException should not be captured - self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index aabb5d49582c8..047a7e56cb577 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -123,7 +123,6 @@ constant | SmallintLiteral | TinyintLiteral | DecimalLiteral - | charSetStringLiteral | booleanValue ; @@ -132,13 +131,6 @@ stringLiteralSequence StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) ; -charSetStringLiteral -@init { gParent.pushMsg("character string literal", state); } -@after { gParent.popMsg(state); } - : - csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) - ; - dateLiteral : KW_DATE StringLiteral -> @@ -163,22 +155,38 @@ timestampLiteral intervalLiteral : - KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> - { - adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH + -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant) + | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND + -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant) + | KW_INTERVAL + ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)? + ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)? + ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)? + ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)? + ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)? + ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)? + ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)? + (millisecond=intervalConstant KW_MILLISECOND)? + (microsecond=intervalConstant KW_MICROSECOND)? + -> ^(TOK_INTERVAL + ^(TOK_INTERVAL_YEAR_LITERAL $year?) + ^(TOK_INTERVAL_MONTH_LITERAL $month?) + ^(TOK_INTERVAL_WEEK_LITERAL $week?) + ^(TOK_INTERVAL_DAY_LITERAL $day?) + ^(TOK_INTERVAL_HOUR_LITERAL $hour?) + ^(TOK_INTERVAL_MINUTE_LITERAL $minute?) + ^(TOK_INTERVAL_SECOND_LITERAL $second?) + ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?) + ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?)) + ; + +intervalConstant + : + sign=(MINUS|PLUS)? value=Number -> { + adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText()) } - ; - -intervalQualifiers - : - KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL - | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL - | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL - | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL - | KW_DAY -> TOK_INTERVAL_DAY_LITERAL - | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL - | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL - | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + | StringLiteral ; expression @@ -219,7 +227,8 @@ nullCondition precedenceUnaryPrefixExpression : - (precedenceUnaryOperator^)* precedenceFieldExpression + (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression + | precedenceFieldExpression ; precedenceUnarySuffixExpression diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index 972c52e3ffcec..6d76afcd4ac07 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -206,11 +206,8 @@ tableName @init { gParent.pushMsg("table name", state); } @after { gParent.popMsg(state); } : - db=identifier DOT tab=identifier - -> ^(TOK_TABNAME $db $tab) - | - tab=identifier - -> ^(TOK_TABNAME $tab) + id1=identifier (DOT id2=identifier)? + -> ^(TOK_TABNAME $id1 $id2?) ; viewName diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 44a63fbef258c..ee2882e51c450 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -307,12 +307,12 @@ KW_AUTHORIZATION: 'AUTHORIZATION'; KW_CONF: 'CONF'; KW_VALUES: 'VALUES'; KW_RELOAD: 'RELOAD'; -KW_YEAR: 'YEAR'; -KW_MONTH: 'MONTH'; -KW_DAY: 'DAY'; -KW_HOUR: 'HOUR'; -KW_MINUTE: 'MINUTE'; -KW_SECOND: 'SECOND'; +KW_YEAR: 'YEAR'|'YEARS'; +KW_MONTH: 'MONTH'|'MONTHS'; +KW_DAY: 'DAY'|'DAYS'; +KW_HOUR: 'HOUR'|'HOURS'; +KW_MINUTE: 'MINUTE'|'MINUTES'; +KW_SECOND: 'SECOND'|'SECONDS'; KW_START: 'START'; KW_TRANSACTION: 'TRANSACTION'; KW_COMMIT: 'COMMIT'; @@ -324,6 +324,9 @@ KW_ISOLATION: 'ISOLATION'; KW_LEVEL: 'LEVEL'; KW_SNAPSHOT: 'SNAPSHOT'; KW_AUTOCOMMIT: 'AUTOCOMMIT'; +KW_WEEK: 'WEEK'|'WEEKS'; +KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; +KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. @@ -400,12 +403,6 @@ StringLiteral )+ ; -CharSetLiteral - : - StringLiteral - | '0' 'X' (HexDigit|Digit)+ - ; - BigintLiteral : (Digit)+ 'L' @@ -433,7 +430,7 @@ ByteLengthLiteral Number : - (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? + ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent? ; /* @@ -456,10 +453,10 @@ An Identifier can be: - macro name - hint name - window name -*/ +*/ Identifier : - (Letter | Digit) (Letter | Digit | '_')* + (Letter | Digit | '_')+ | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; at the API level only columns are allowed to be of this form */ | '`' RegexComponent+ '`' @@ -471,11 +468,6 @@ QuotedIdentifier '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } ; -CharSetName - : - '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ - ; - WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 2c13d3056f468..c146ca5914884 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -116,16 +116,20 @@ TOK_DATELITERAL; TOK_DATETIME; TOK_TIMESTAMP; TOK_TIMESTAMPLITERAL; +TOK_INTERVAL; TOK_INTERVAL_YEAR_MONTH; TOK_INTERVAL_YEAR_MONTH_LITERAL; TOK_INTERVAL_DAY_TIME; TOK_INTERVAL_DAY_TIME_LITERAL; TOK_INTERVAL_YEAR_LITERAL; TOK_INTERVAL_MONTH_LITERAL; +TOK_INTERVAL_WEEK_LITERAL; TOK_INTERVAL_DAY_LITERAL; TOK_INTERVAL_HOUR_LITERAL; TOK_INTERVAL_MINUTE_LITERAL; TOK_INTERVAL_SECOND_LITERAL; +TOK_INTERVAL_MILLISECOND_LITERAL; +TOK_INTERVAL_MICROSECOND_LITERAL; TOK_STRING; TOK_CHAR; TOK_VARCHAR; @@ -228,7 +232,6 @@ TOK_TMP_FILE; TOK_TABSORTCOLNAMEASC; TOK_TABSORTCOLNAMEDESC; TOK_STRINGLITERALSEQUENCE; -TOK_CHARSETLITERAL; TOK_CREATEFUNCTION; TOK_DROPFUNCTION; TOK_RELOADFUNCTION; @@ -509,7 +512,9 @@ import java.util.HashMap; xlateMap.put("KW_UPDATE", "UPDATE"); xlateMap.put("KW_VALUES", "VALUES"); xlateMap.put("KW_PURGE", "PURGE"); - + xlateMap.put("KW_WEEK", "WEEK"); + xlateMap.put("KW_MILLISECOND", "MILLISECOND"); + xlateMap.put("KW_MICROSECOND", "MICROSECOND"); // Operators xlateMap.put("DOT", "."); @@ -2078,6 +2083,7 @@ primitiveType | KW_SMALLINT -> TOK_SMALLINT | KW_INT -> TOK_INT | KW_BIGINT -> TOK_BIGINT + | KW_LONG -> TOK_BIGINT | KW_BOOLEAN -> TOK_BOOLEAN | KW_FLOAT -> TOK_FLOAT | KW_DOUBLE -> TOK_DOUBLE diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java index 5bc87b680f9ad..2520c7bb8dae4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java @@ -18,12 +18,10 @@ package org.apache.spark.sql.catalyst.parser; -import java.io.UnsupportedEncodingException; - /** * A couple of utility methods that help with parsing ASTs. * - * Both methods in this class were take from the SemanticAnalyzer in Hive: + * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive: * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java */ public final class ParseUtils { @@ -31,33 +29,6 @@ private ParseUtils() { super(); } - public static String charSetString(String charSetName, String charSetString) - throws UnsupportedEncodingException { - // The character set name starts with a _, so strip that - charSetName = charSetName.substring(1); - if (charSetString.charAt(0) == '\'') { - return new String(unescapeSQLString(charSetString).getBytes(), charSetName); - } else // hex input is also supported - { - assert charSetString.charAt(0) == '0'; - assert charSetString.charAt(1) == 'x'; - charSetString = charSetString.substring(2); - - byte[] bArray = new byte[charSetString.length() / 2]; - int j = 0; - for (int i = 0; i < charSetString.length(); i += 2) { - int val = Character.digit(charSetString.charAt(i), 16) * 16 - + Character.digit(charSetString.charAt(i + 1), 16); - if (val > 127) { - val = val - 256; - } - bArray[j++] = (byte)val; - } - - return new String(bArray, charSetName); - } - } - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; @SuppressWarnings("nls") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index bdc52c08acb66..9443369808984 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,9 +26,9 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser - extends StandardTokenParsers with PackratParsers { + extends StandardTokenParsers with PackratParsers with ParserDialect { - def parse(input: String): LogicalPlan = synchronized { + def parsePlan(input: String): LogicalPlan = synchronized { // Initialize the Keywords. initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index d0fbdacf6eafd..c1591ecfe2b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -30,16 +30,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler -private[sql] object CatalystQl { - val parser = new CatalystQl - def parseExpression(sql: String): Expression = parser.parseExpression(sql) - def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql) -} - /** * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. */ -private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { +private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserDialect { object Token { def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { CurrentOrigin.setPosition(node.line, node.positionInLine) @@ -611,13 +605,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case plainIdent => plainIdent } - val numericAstTypes = Seq( - SparkSqlParser.Number, - SparkSqlParser.TinyintLiteral, - SparkSqlParser.SmallintLiteral, - SparkSqlParser.BigintLiteral, - SparkSqlParser.DecimalLiteral) - /* Case insensitive matches */ val COUNT = "(?i)COUNT".r val SUM = "(?i)SUM".r @@ -635,6 +622,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r + val INTEGRAL = "[+-]?\\d+".r + protected def nodeToExpr(node: ASTNode): Expression = node match { /* Attribute References */ case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => @@ -650,8 +639,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) + case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => + UnresolvedStar(Some(target.map(_.text))) /* Aggregate Functions */ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => @@ -787,71 +776,71 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) - // This code is adapted from - // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 - case ast: ASTNode if numericAstTypes contains ast.tokenType => - var v: Literal = null - try { - if (ast.text.endsWith("L")) { - // Literal bigint. - v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - } else if (ast.text.endsWith("S")) { - // Literal smallint. - v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - } else if (ast.text.endsWith("Y")) { - // Literal tinyint. - v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) { - // Literal decimal - val strVal = ast.text.stripSuffix("D").stripSuffix("B") - v = Literal(Decimal(strVal)) - } else { - v = Literal.create(ast.text.toDouble, DoubleType) - v = Literal.create(ast.text.toLong, LongType) - v = Literal.create(ast.text.toInt, IntegerType) - } - } catch { - case nfe: NumberFormatException => // Do nothing - } - - if (v == null) { - sys.error(s"Failed to parse number '${ast.text}'.") - } else { - v - } - - case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral => - Literal(ParseUtils.unescapeSQLString(ast.text)) + case ast if ast.tokenType == SparkSqlParser.TinyintLiteral => + Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) + case ast if ast.tokenType == SparkSqlParser.SmallintLiteral => + Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL => - Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text)) + case ast if ast.tokenType == SparkSqlParser.BigintLiteral => + Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.text)) + case ast if ast.tokenType == SparkSqlParser.DecimalLiteral => + Literal(Decimal(ast.text.substring(0, ast.text.length() - 2))) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("year", ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("month", ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("day", ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("hour", ast.text)) + case ast if ast.tokenType == SparkSqlParser.Number => + val text = ast.text + text match { + case INTEGRAL() => + BigDecimal(text) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + case _ => + Literal(text.toDouble) + } + case ast if ast.tokenType == SparkSqlParser.StringLiteral => + Literal(ParseUtils.unescapeSQLString(ast.text)) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("minute", ast.text)) + case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("second", ast.text)) + case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.children.head.text)) + + case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.children.head.text)) + + case Token("TOK_INTERVAL", elements) => + var interval = new CalendarInterval(0, 0) + var updated = false + elements.foreach { + // The interval node will always contain children for all possible time units. A child node + // is only useful when it contains exactly one (numeric) child. + case e @ Token(name, Token(value, Nil) :: Nil) => + val unit = name match { + case "TOK_INTERVAL_YEAR_LITERAL" => "year" + case "TOK_INTERVAL_MONTH_LITERAL" => "month" + case "TOK_INTERVAL_WEEK_LITERAL" => "week" + case "TOK_INTERVAL_DAY_LITERAL" => "day" + case "TOK_INTERVAL_HOUR_LITERAL" => "hour" + case "TOK_INTERVAL_MINUTE_LITERAL" => "minute" + case "TOK_INTERVAL_SECOND_LITERAL" => "second" + case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond" + case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond" + case _ => noParseRule(s"Interval($name)", e) + } + interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value)) + updated = true + case _ => + } + if (!updated) { + throw new AnalysisException("at least one time unit should be given for interval literal") + } + Literal(interval) case _ => noParseRule("Expression", node) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index e21d3c05464b6..7d9fbf2f12ee6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -18,52 +18,22 @@ package org.apache.spark.sql.catalyst import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * Root class of SQL Parser Dialect, and we don't guarantee the binary * compatibility for the future release, let's keep it as the internal * interface for advanced user. - * */ @DeveloperApi -abstract class ParserDialect { - // this is the main function that will be implemented by sql parser. - def parse(sqlText: String): LogicalPlan -} +trait ParserDialect { + /** Creates LogicalPlan for a given SQL string. */ + def parsePlan(sqlText: String): LogicalPlan -/** - * Currently we support the default dialect named "sql", associated with the class - * [[DefaultParserDialect]] - * - * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: - * {{{ - *-- switch to "hiveql" dialect - * spark-sql>SET spark.sql.dialect=hiveql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- switch to "sql" dialect - * spark-sql>SET spark.sql.dialect=sql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- register the new SQL dialect - * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- register the non-exist SQL dialect - * spark-sql> SET spark.sql.dialect=NotExistedClass; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- Exception will be thrown and switch to dialect - *-- "sql" (for SQLContext) or - *-- "hiveql" (for HiveContext) - * }}} - */ -private[spark] class DefaultParserDialect extends ParserDialect { - @transient - protected val sqlParser = SqlParser + /** Creates Expression for a given SQL string. */ + def parseExpression(sqlText: String): Expression - override def parse(sqlText: String): LogicalPlan = { - sqlParser.parse(sqlText) - } + /** Creates TableIdentifier for a given SQL string. */ + def parseTableIdentifier(sqlText: String): TableIdentifier } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala deleted file mode 100644 index 85ff4ea0c946b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ /dev/null @@ -1,509 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import scala.language.implicitConversions - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * A very simple SQL parser. Based loosely on: - * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala - * - * Limitations: - * - Only supports a very limited subset of SQL. - * - * This is currently included mostly for illustrative purposes. Users wanting more complete support - * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. - */ -object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - - def parseExpression(input: String): Expression = synchronized { - // Initialize the Keywords. - initLexical - phrase(projection)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) - } - } - - def parseTableIdentifier(input: String): TableIdentifier = synchronized { - // Initialize the Keywords. - initLexical - phrase(tableIdentifier)(new lexical.Scanner(input)) match { - case Success(ident, _) => ident - case failureOrError => sys.error(failureOrError.toString) - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ALL = Keyword("ALL") - protected val AND = Keyword("AND") - protected val APPROXIMATE = Keyword("APPROXIMATE") - protected val AS = Keyword("AS") - protected val ASC = Keyword("ASC") - protected val BETWEEN = Keyword("BETWEEN") - protected val BY = Keyword("BY") - protected val CASE = Keyword("CASE") - protected val CAST = Keyword("CAST") - protected val DESC = Keyword("DESC") - protected val DISTINCT = Keyword("DISTINCT") - protected val ELSE = Keyword("ELSE") - protected val END = Keyword("END") - protected val EXCEPT = Keyword("EXCEPT") - protected val FALSE = Keyword("FALSE") - protected val FROM = Keyword("FROM") - protected val FULL = Keyword("FULL") - protected val GROUP = Keyword("GROUP") - protected val HAVING = Keyword("HAVING") - protected val IN = Keyword("IN") - protected val INNER = Keyword("INNER") - protected val INSERT = Keyword("INSERT") - protected val INTERSECT = Keyword("INTERSECT") - protected val INTERVAL = Keyword("INTERVAL") - protected val INTO = Keyword("INTO") - protected val IS = Keyword("IS") - protected val JOIN = Keyword("JOIN") - protected val LEFT = Keyword("LEFT") - protected val LIKE = Keyword("LIKE") - protected val LIMIT = Keyword("LIMIT") - protected val NOT = Keyword("NOT") - protected val NULL = Keyword("NULL") - protected val ON = Keyword("ON") - protected val OR = Keyword("OR") - protected val ORDER = Keyword("ORDER") - protected val SORT = Keyword("SORT") - protected val OUTER = Keyword("OUTER") - protected val OVERWRITE = Keyword("OVERWRITE") - protected val REGEXP = Keyword("REGEXP") - protected val RIGHT = Keyword("RIGHT") - protected val RLIKE = Keyword("RLIKE") - protected val SELECT = Keyword("SELECT") - protected val SEMI = Keyword("SEMI") - protected val TABLE = Keyword("TABLE") - protected val THEN = Keyword("THEN") - protected val TRUE = Keyword("TRUE") - protected val UNION = Keyword("UNION") - protected val WHEN = Keyword("WHEN") - protected val WHERE = Keyword("WHERE") - protected val WITH = Keyword("WITH") - - protected lazy val start: Parser[LogicalPlan] = - start1 | insert | cte - - protected lazy val start1: Parser[LogicalPlan] = - (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - - protected lazy val select: Parser[LogicalPlan] = - SELECT ~> DISTINCT.? ~ - repsep(projection, ",") ~ - (FROM ~> relations).? ~ - (WHERE ~> expression).? ~ - (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ - (HAVING ~> expression).? ~ - sortType.? ~ - (LIMIT ~> expression).? ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(OneRowRelation) - val withFilter = f.map(Filter(_, base)).getOrElse(base) - val withProjection = g - .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) - .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(_(withHaving)).getOrElse(withHaving) - val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) - withLimit - } - - protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false) - } - - protected lazy val cte: Parser[LogicalPlan] = - WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { - case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) - } - - protected lazy val projection: Parser[Expression] = - expression ~ (AS.? ~> ident.?) ^^ { - case e ~ a => a.fold(e)(Alias(e, _)()) - } - - // Based very loosely on the MySQL Grammar. - // http://dev.mysql.com/doc/refman/5.0/en/join.html - protected lazy val relations: Parser[LogicalPlan] = - ( relation ~ rep1("," ~> relation) ^^ { - case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } } - | relation - ) - - protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | relationFactor - - protected lazy val relationFactor: Parser[LogicalPlan] = - ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ { - case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias) - } - | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } - ) - - protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ rep1(joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.?) ^^ { - case r1 ~ joins => - joins.foldLeft(r1) { case (lhs, jt ~ rhs ~ cond) => - Join(lhs, rhs, joinType = jt.getOrElse(Inner), cond) - } - } - - protected lazy val joinConditions: Parser[Expression] = - ON ~> expression - - protected lazy val joinType: Parser[JoinType] = - ( INNER ^^^ Inner - | LEFT ~ SEMI ^^^ LeftSemi - | LEFT ~ OUTER.? ^^^ LeftOuter - | RIGHT ~ OUTER.? ^^^ RightOuter - | FULL ~ OUTER.? ^^^ FullOuter - ) - - protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] = - ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) } - | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) } - ) - - protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(expression ~ direction.?, ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) - } - ) - - protected lazy val direction: Parser[SortDirection] = - ( ASC ^^^ Ascending - | DESC ^^^ Descending - ) - - protected lazy val expression: Parser[Expression] = - orExpression - - protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) - - protected lazy val andExpression: Parser[Expression] = - notExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) - - protected lazy val notExpression: Parser[Expression] = - NOT.? ~ comparisonExpression ^^ { case maybeNot ~ e => maybeNot.map(_ => Not(e)).getOrElse(e) } - - protected lazy val comparisonExpression: Parser[Expression] = - ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } - | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } - | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } - | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } - | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } - | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ ("<=>" ~> termExpression) ^^ { case e1 ~ e2 => EqualNullSafe(e1, e2) } - | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { - case e ~ not ~ el ~ eu => - val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f => Not(betweenExpr)) - } - | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } - | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } - | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } - | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } - | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { - case e1 ~ e2 => In(e1, e2) - } - | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { - case e1 ~ e2 => Not(In(e1, e2)) - } - | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } - | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } - | termExpression - ) - - protected lazy val termExpression: Parser[Expression] = - productExpression * - ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } - | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } - ) - - protected lazy val productExpression: Parser[Expression] = - baseExpression * - ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } - | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } - | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } - | "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) } - | "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) } - | "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) } - ) - - protected lazy val function: Parser[Expression] = - ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => - if (lexical.normalizeKeyword(udfName) == "count") { - AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) - } else { - throw new AnalysisException(s"invalid expression $udfName(*)") - } - } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } - | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => - lexical.normalizeKeyword(udfName) match { - case "count" => - aggregate.Count(exprs).toAggregateExpression(isDistinct = true) - case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) - } - } - | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => - if (lexical.normalizeKeyword(udfName) == "count") { - AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false) - } else { - throw new AnalysisException(s"invalid function approximate $udfName") - } - } - | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => - if (lexical.normalizeKeyword(udfName) == "count") { - AggregateExpression( - HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), - mode = Complete, - isDistinct = false) - } else { - throw new AnalysisException(s"invalid function approximate($s) $udfName") - } - } - | CASE ~> whenThenElse ^^ - { case branches => CaseWhen.createFromParser(branches) } - | CASE ~> expression ~ whenThenElse ^^ - { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } - ) - - protected lazy val whenThenElse: Parser[List[Expression]] = - rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { - case altPart ~ elsePart => - altPart.flatMap { case whenExpr ~ thenExpr => - Seq(whenExpr, thenExpr) - } ++ elsePart - } - - protected lazy val cast: Parser[Expression] = - CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { - case exp ~ t => Cast(exp, t) - } - - protected lazy val literal: Parser[Literal] = - ( numericLiteral - | booleanLiteral - | stringLit ^^ { case s => Literal.create(s, StringType) } - | intervalLiteral - | NULL ^^^ Literal.create(null, NullType) - ) - - protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal.create(true, BooleanType) - | FALSE ^^^ Literal.create(false, BooleanType) - ) - - protected lazy val numericLiteral: Parser[Literal] = - ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ - { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } - ) - - protected lazy val unsignedFloat: Parser[String] = - ( "." ~> numericLit ^^ { u => "0." + u } - | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - ) - - protected lazy val sign: Parser[String] = ("+" | "-") - - protected lazy val integral: Parser[String] = - sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n } - - private def intervalUnit(unitName: String) = acceptIf { - case lexical.Identifier(str) => - val normalized = lexical.normalizeKeyword(str) - normalized == unitName || normalized == unitName + "s" - case _ => false - } {_ => "wrong interval unit"} - - protected lazy val month: Parser[Int] = - integral <~ intervalUnit("month") ^^ { case num => num.toInt } - - protected lazy val year: Parser[Int] = - integral <~ intervalUnit("year") ^^ { case num => num.toInt * 12 } - - protected lazy val microsecond: Parser[Long] = - integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong } - - protected lazy val millisecond: Parser[Long] = - integral <~ intervalUnit("millisecond") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_MILLI - } - - protected lazy val second: Parser[Long] = - integral <~ intervalUnit("second") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_SECOND - } - - protected lazy val minute: Parser[Long] = - integral <~ intervalUnit("minute") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE - } - - protected lazy val hour: Parser[Long] = - integral <~ intervalUnit("hour") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_HOUR - } - - protected lazy val day: Parser[Long] = - integral <~ intervalUnit("day") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_DAY - } - - protected lazy val week: Parser[Long] = - integral <~ intervalUnit("week") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_WEEK - } - - private def intervalKeyword(keyword: String) = acceptIf { - case lexical.Identifier(str) => - lexical.normalizeKeyword(str) == keyword - case _ => false - } {_ => "wrong interval keyword"} - - protected lazy val intervalLiteral: Parser[Literal] = - ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~ - intervalKeyword("month") ^^ { case s => - Literal(CalendarInterval.fromYearMonthString(s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~ - intervalKeyword("second") ^^ { case s => - Literal(CalendarInterval.fromDayTimeString(s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("year", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("month", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("day", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("hour", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("minute", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("second", s)) - } - | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ - millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ - millisecond ~ microsecond => - if (!Seq(year, month, week, day, hour, minute, second, - millisecond, microsecond).exists(_.isDefined)) { - throw new AnalysisException( - "at least one time unit should be given for interval literal") - } - val months = Seq(year, month).map(_.getOrElse(0)).sum - val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) - .map(_.getOrElse(0L)).sum - Literal(new CalendarInterval(months, microseconds)) - } - ) - - private def toNarrowestIntegerType(value: String): Any = { - val bigIntValue = BigDecimal(value) - - bigIntValue match { - case v if bigIntValue.isValidInt => v.toIntExact - case v if bigIntValue.isValidLong => v.toLongExact - case v => v.underlying() - } - } - - private def toDecimalOrDouble(value: String): Any = { - val decimal = BigDecimal(value) - // follow the behavior in MS SQL Server - // https://msdn.microsoft.com/en-us/library/ms179899.aspx - if (value.contains('E') || value.contains('e')) { - decimal.doubleValue() - } else { - decimal.underlying() - } - } - - protected lazy val baseExpression: Parser[Expression] = - ( "*" ^^^ UnresolvedStar(None) - | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))} - | primary - ) - - protected lazy val signedPrimary: Parser[Expression] = - sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e } - - protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", { - case lexical.Identifier(str) => str - case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str - }) - - protected lazy val primary: PackratParser[Expression] = - ( literal - | expression ~ ("[" ~> expression <~ "]") ^^ - { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) } - | (expression <~ ".") ~ ident ^^ - { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) } - | cast - | "(" ~> expression <~ ")" - | function - | dotExpressionHeader - | signedPrimary - | "~" ~> expression ^^ BitwiseNot - | attributeName ^^ UnresolvedAttribute.quoted - ) - - protected lazy val dotExpressionHeader: Parser[Expression] = - (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) - } - - protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? ~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index e1fd22e36764e..ec833d6789e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -447,6 +447,7 @@ object HyperLogLogPlusPlus { private def validateDoubleLiteral(exp: Expression): Double = exp match { case Literal(d: Double, DoubleType) => d + case Literal(dec: Decimal, _) => dec.toDouble case _ => throw new AnalysisException("The second argument should be a double literal.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index ba9d2524a9551..6d25de98cebc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -108,6 +108,7 @@ class CatalystQlSuite extends PlanTest { } assertRight("9.0e1", 90) + assertRight(".9e+2", 90) assertRight("0.9e+2", 90) assertRight("900e-1", 90) assertRight("900.0E-1", 90) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala deleted file mode 100644 index b0884f528742f..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project} -import org.apache.spark.unsafe.types.CalendarInterval - -private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty -} - -private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") - - override protected lazy val start: Parser[LogicalPlan] = set - - private lazy val set: Parser[LogicalPlan] = - EXECUTE ~> ident ^^ { - case fileName => TestCommand(fileName) - } -} - -private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") - - override protected lazy val start: Parser[LogicalPlan] = set - - private lazy val set: Parser[LogicalPlan] = - EXECUTE ~> ident ^^ { - case fileName => TestCommand(fileName) - } -} - -class SqlParserSuite extends PlanTest { - - test("test long keyword") { - val parser = new SuperLongKeywordTestParser - assert(TestCommand("NotRealCommand") === - parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) - } - - test("test case insensitive") { - val parser = new CaseInsensitiveTestParser - assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) - } - - test("test NOT operator with comparison operations") { - val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE") - val expected = Project( - UnresolvedAlias( - Not( - GreaterThan(Literal(true), Literal(true))) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - test("support hive interval literal") { - def checkInterval(sql: String, result: CalendarInterval): Unit = { - val parsed = SqlParser.parse(sql) - val expected = Project( - UnresolvedAlias( - Literal(result) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - def checkYearMonth(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' YEAR TO MONTH", - CalendarInterval.fromYearMonthString(lit)) - } - - def checkDayTime(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' DAY TO SECOND", - CalendarInterval.fromDayTimeString(lit)) - } - - def checkSingleUnit(lit: String, unit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' $unit", - CalendarInterval.fromSingleUnitString(unit, lit)) - } - - checkYearMonth("123-10") - checkYearMonth("496-0") - checkYearMonth("-2-3") - checkYearMonth("-123-0") - - checkDayTime("99 11:22:33.123456789") - checkDayTime("-99 11:22:33.123456789") - checkDayTime("10 9:8:7.123456789") - checkDayTime("1 0:0:0") - checkDayTime("-1 0:0:0") - checkDayTime("1 0:0:1") - - for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { - checkSingleUnit("7", unit) - checkSingleUnit("-7", unit) - checkSingleUnit("0", unit) - } - - checkSingleUnit("13.123456789", "second") - checkSingleUnit("-13.123456789", "second") - } - - test("support scientific notation") { - def assertRight(input: String, output: Double): Unit = { - val parsed = SqlParser.parse("SELECT " + input) - val expected = Project( - UnresolvedAlias( - Literal(output) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - assertRight("9.0e1", 90) - assertRight(".9e+2", 90) - assertRight("0.9e+2", 90) - assertRight("900e-1", 90) - assertRight("900.0E-1", 90) - assertRight("9.e+1", 90) - - intercept[RuntimeException](SqlParser.parse("SELECT .e3")) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6a020f9f2883e..97bf7a0cc4514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,7 +21,6 @@ import scala.language.implicitConversions import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.SqlParser._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 91bf2f8ce4d2f..3422d0ead4fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -737,7 +737,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(SqlParser.parseExpression(expr)) + Column(sqlContext.sqlParser.parseExpression(expr)) }: _*) } @@ -764,7 +764,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(SqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) } /** @@ -788,7 +788,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(SqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d948e4894253c..8f852e521668a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.SqlParser +import org.apache.spark.sql.catalyst.{CatalystQl} import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation @@ -337,7 +337,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { */ def table(tableName: String): DataFrame = { DataFrame(sqlContext, - sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) + sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 00f9817b53976..ab63fe4aa88b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,7 +22,7 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} @@ -192,7 +192,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(SqlParser.parseTableIdentifier(tableName)) + insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -282,7 +282,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(SqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b909765a7c6dd..a0939adb6d5ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ @@ -205,15 +206,17 @@ class SQLContext private[sql]( protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient - protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) + protected[sql] val ddlParser = new DDLParser(sqlParser) @transient - protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_)) + protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect()) protected[sql] def getSQLDialect(): ParserDialect = { try { val clazz = Utils.classForName(dialectClassName) - clazz.newInstance().asInstanceOf[ParserDialect] + clazz.getConstructor(classOf[ParserConf]) + .newInstance(conf) + .asInstanceOf[ParserDialect] } catch { case NonFatal(e) => // Since we didn't find the available SQL Dialect, it will fail even for SET command: @@ -237,7 +240,7 @@ class SQLContext private[sql]( new sparkexecution.QueryExecution(this, plan) protected[sql] def dialectClassName = if (conf.dialect == "sql") { - classOf[DefaultParserDialect].getCanonicalName + classOf[SparkQl].getCanonicalName } else { conf.dialect } @@ -682,7 +685,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -728,7 +731,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -833,7 +836,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(SqlParser.parseTableIdentifier(tableName)) + table(sqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index b3e8d0d84937e..1af2c756cdc5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StringType @@ -29,9 +29,16 @@ import org.apache.spark.sql.types.StringType * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. * - * @param fallback A function that parses an input string to a logical plan + * @param fallback A function that returns the next parser in the chain. This is a call-by-name + * parameter because this allows us to return a different dialect if we + * have to. */ -class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: => ParserDialect) extends AbstractSparkSQLParser { + + override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) + + override def parseTableIdentifier(sql: String): TableIdentifier = + fallback.parseTableIdentifier(sql) // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { @@ -74,7 +81,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) + CacheTableCommand(tableName, plan.map(fallback.parsePlan), isLazy.isDefined) } private lazy val uncache: Parser[LogicalPlan] = @@ -111,7 +118,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa private lazy val others: Parser[LogicalPlan] = wholeInput ^^ { - case input => fallback(input) + case input => fallback.parsePlan(input) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index d8d21b06b8b35..10655a85ccf89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -22,25 +22,30 @@ import scala.util.matching.Regex import org.apache.spark.Logging import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ - /** * A parser for foreign DDL commands. */ -class DDLParser(parseQuery: String => LogicalPlan) +class DDLParser(fallback: => ParserDialect) extends AbstractSparkSQLParser with DataTypeParser with Logging { + override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) + + override def parseTableIdentifier(sql: String): TableIdentifier = + + fallback.parseTableIdentifier(sql) def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { try { - parse(input) + parsePlan(input) } catch { case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) + case _ if !exceptionOnError => fallback.parsePlan(input) case x: Throwable => throw x } } @@ -104,7 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan) SaveMode.ErrorIfExists } - val queryPlan = parseQuery(query.get) + val queryPlan = fallback.parsePlan(query.get) CreateTableUsingAsSelect(tableIdent, provider, temp.isDefined, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b8ea2261e94e2..8c2530fd684a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystQl, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -1063,7 +1063,10 @@ object functions extends LegacyFunctions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) + def expr(expr: String): Column = { + val parser = SQLContext.getActive().map(_.getSQLDialect()).getOrElse(new CatalystQl()) + Column(parser.parseExpression(expr)) + } ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 58f982c2bc932..aec450e0a6084 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -212,7 +212,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) - val pi = 3.1415 + val pi = "3.1415BD" checkAnswer( sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03d67c4e91f7f..75e81b9c9174d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,10 +21,11 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.CatalystQl import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.catalyst.parser.ParserConf +import org.apache.spark.sql.execution.{aggregate, SparkQl} import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} @@ -32,7 +33,7 @@ import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ -class MyDialect extends DefaultParserDialect +class MyDialect(conf: ParserConf) extends CatalystQl(conf) class SQLQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -161,7 +162,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { newContext.sql("SELECT 1") } // test if the dialect set back to DefaultSQLDialect - assert(newContext.getSQLDialect().getClass === classOf[DefaultParserDialect]) + assert(newContext.getSQLDialect().getClass === classOf[SparkQl]) } test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { @@ -586,7 +587,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Allow only a single WITH clause per query") { - intercept[RuntimeException] { + intercept[AnalysisException] { sql( "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } @@ -602,8 +603,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("from follow multiple brackets") { checkAnswer(sql( """ - |select key from ((select * from testData limit 1) - | union all (select * from testData limit 1)) x limit 1 + |select key from ((select * from testData) + | union all (select * from testData)) x limit 1 """.stripMargin), Row(1) ) @@ -616,7 +617,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql( """ |select key from - | (select * from testData limit 1 union all select * from testData limit 1) x + | (select * from testData union all select * from testData) x | limit 1 """.stripMargin), Row(1) @@ -649,13 +650,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("approximate count distinct") { checkAnswer( - sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), + sql("SELECT APPROX_COUNT_DISTINCT(a) FROM testData2"), Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( - sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), + sql("SELECT APPROX_COUNT_DISTINCT(a, 0.04) FROM testData2"), Row(3)) } @@ -1192,19 +1193,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) + sql("SELECT 0.3"), Row(0.3) ) checkAnswer( - sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) + sql("SELECT -0.8"), Row(-0.8) ) checkAnswer( - sql("SELECT .5"), Row(BigDecimal(0.5)) + sql("SELECT .5"), Row(0.5) ) checkAnswer( - sql("SELECT -.18"), Row(BigDecimal(-0.18)) + sql("SELECT -.18"), Row(-0.18) ) } @@ -1218,11 +1219,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) checkAnswer( - sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) + sql("SELECT 9223372036854775808BD"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) + sql("SELECT -9223372036854775809BD"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } @@ -1237,11 +1238,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) checkAnswer( - sql("SELECT -5.2"), Row(BigDecimal(-5.2)) + sql("SELECT -5.2BD"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(BigDecimal(6.8)) + sql("SELECT +6.8"), Row(6.8d) ) checkAnswer( @@ -1616,20 +1617,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + checkAnswer(sql("select 10.3BD * 3.0BD"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000BD * 3.0BD"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000BD * 30.0BD"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000BD * 3.000000000000000000BD"), Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + checkAnswer(sql("select 10.300000000000000000BD * 3.0000000000000000000BD"), Row(null)) - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + checkAnswer(sql("select 10.3BD / 3.0BD"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000BD / 3.0BD"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000BD / 30.0BD"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000BD / 3.00000000000000000BD"), Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + checkAnswer(sql("select 10.3000000000000000000BD / 3.00000000000000000BD"), Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) } @@ -1655,13 +1656,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("precision smaller than scale") { - checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) - checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) - checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) - checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) - checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) - checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01"))) - checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + checkAnswer(sql("select 10.00BD"), Row(BigDecimal("10.00"))) + checkAnswer(sql("select 1.00BD"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10BD"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01BD"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001BD"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select -0.01BD"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001BD"), Row(BigDecimal("-0.001"))) } test("external sorting updates peak execution memory") { @@ -1750,7 +1751,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(e1.message.contains("Table not found")) val e2 = intercept[AnalysisException] { - sql("select * from no_db.no_table") + sql("select * from no_db.no_table").show() } assert(e2.message.contains("Table not found")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 860e07c68cef1..e70eb2a060309 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -442,13 +442,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 14"), + sql("select num_str + 1.2BD from jsonTable where num_str > 14"), Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + sql("select num_str + 1.2BD from jsonTable where num_str >= 92233720368547758060BD"), Row(new java.math.BigDecimal("92233720368547758071.2")) ) @@ -856,7 +856,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") checkAnswer( - sql("select map from jsonWithSimpleMap"), + sql("select `map` from jsonWithSimpleMap"), Row(Map("a" -> 1)) :: Row(Map("b" -> 2)) :: Row(Map("c" -> 3)) :: @@ -865,7 +865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - sql("select map['c'] from jsonWithSimpleMap"), + sql("select `map`['c'] from jsonWithSimpleMap"), Row(null) :: Row(null) :: Row(3) :: @@ -884,7 +884,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonWithComplexMap.registerTempTable("jsonWithComplexMap") checkAnswer( - sql("select map from jsonWithComplexMap"), + sql("select `map` from jsonWithComplexMap"), Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: Row(Map("b" -> Row(null, 2))) :: Row(Map("c" -> Row(Seq(), 4))) :: @@ -894,7 +894,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"), Row(Seq(1, 2, 3, null), null) :: Row(null, null) :: Row(null, 4) :: diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index afd2f611580fc..828ec9710550c 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -296,6 +296,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Odd changes to output "merge4", + // Unsupported underscore syntax. + "inputddl5", + // Thift is broken... "inputddl8", @@ -603,7 +606,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl2", "inputddl3", "inputddl4", - "inputddl5", "inputddl6", "inputddl7", "inputddl8", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index b22f424981325..313ba18f6aef0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -19,14 +19,23 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand} /** * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ -private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { +private[hive] class ExtendedHiveQlParser(sqlContext: HiveContext) extends AbstractSparkSQLParser { + + val parser = new HiveQl(sqlContext.conf) + + override def parseExpression(sql: String): Expression = parser.parseExpression(sql) + + override def parseTableIdentifier(sql: String): TableIdentifier = + parser.parseTableIdentifier(sql) + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val ADD = Keyword("ADD") @@ -38,7 +47,10 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.parsePlan(statement.trim) + case statement => + sqlContext.executionHive.withHiveState { + parser.parsePlan(statement.trim) + } } protected lazy val dfs: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cbaf00603e189..7bdca52200305 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,7 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -56,17 +56,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -/** - * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext - */ -private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect { - override def parse(sqlText: String): LogicalPlan = { - sqlContext.executionHive.withHiveState { - HiveQl.parseSql(sqlText) - } - } -} - /** * Returns the current database of metadataHive. */ @@ -342,12 +331,12 @@ class HiveContext private[hive]( * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -361,7 +350,7 @@ class HiveContext private[hive]( * @since 1.2.0 */ def analyze(tableName: String) { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent)) relation match { @@ -559,7 +548,7 @@ class HiveContext private[hive]( protected[sql] override def getSQLDialect(): ParserDialect = { if (conf.dialect == "hiveql") { - new HiveQLDialect(this) + new ExtendedHiveQlParser(this) } else { super.getSQLDialect() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index daaa5a5709bdc..3d54048c24782 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -416,8 +416,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." - case None => Subquery(table.name, HiveQl.parsePlan(viewText)) - case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText)) + case None => Subquery(table.name, hive.parseSql(viewText)) + case Some(aliasText) => Subquery(aliasText, hive.parseSql(viewText)) } } else { MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ca9ddf94c11a7..46246f8191db1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -79,7 +79,7 @@ private[hive] case class CreateViewAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl extends SparkQl with Logging { +private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", @@ -168,8 +168,6 @@ private[hive] object HiveQl extends SparkQl with Logging { "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. ) ++ nativeCommands - protected val hqlParser = new ExtendedHiveQlParser - /** * Returns the HiveConf */ @@ -186,9 +184,6 @@ private[hive] object HiveQl extends SparkQl with Logging { ss.getConf } - /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) - protected def getProperties(node: ASTNode): Seq[(String, String)] = node match { case Token("TOK_TABLEPROPLIST", list) => list.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 53d15c14cb3d5..137dadd6c6bb3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -23,12 +23,15 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.catalyst.plans.logical.Generate import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable} class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { + val parser = new HiveQl(SimpleParserConf()) + private def extractTableDesc(sql: String): (HiveTable, Boolean) = { - HiveQl.parsePlan(sql).collect { + parser.parsePlan(sql).collect { case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) }.head } @@ -173,7 +176,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { test("Invalid interval term should throw AnalysisException") { def assertError(sql: String, errorMessage: String): Unit = { val e = intercept[AnalysisException] { - HiveQl.parseSql(sql) + parser.parsePlan(sql) } assert(e.getMessage.contains(errorMessage)) } @@ -186,7 +189,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { } test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { - val plan = HiveQl.parseSql( + val plan = parser.parsePlan( """ |SELECT * |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 78f74cdc19ddb..91bedf9c5af5a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,6 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -28,9 +29,11 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton class StatisticsSuite extends QueryTest with TestHiveSingleton { import hiveContext.sql + val parser = new HiveQl(SimpleParserConf()) + test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = HiveQl.parseSql(analyzeCommand) + val parsed = parser.parsePlan(analyzeCommand) val operators = parsed.collect { case a: AnalyzeTable => a case o => o diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f6c687aab7a1b..61d5aa7ae6b31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -22,12 +22,14 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{DefaultParserDialect, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry} import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst.parser.ParserConf +import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.hive.{ExtendedHiveQlParser, HiveContext, HiveQl, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -56,7 +58,7 @@ case class WindowData( area: String, product: Int) /** A SQL Dialect for testing purpose, and it can not be nested type */ -class MyDialect extends DefaultParserDialect +class MyDialect(conf: ParserConf) extends HiveQl(conf) /** * A collection of hive query tests where we generate the answers ourselves instead of depending on @@ -339,20 +341,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val hiveContext = new HiveContext(sqlContext.sparkContext) val dialectConf = "spark.sql.dialect" checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql")) - assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect]) + assert(hiveContext.getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) } test("SQL Dialect Switching") { - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(getSQLDialect().getClass === classOf[MyDialect]) assert(sql("SELECT 1").collect() === Array(Row(1))) // set the dialect back to the DefaultSQLDialect sql("SET spark.sql.dialect=sql") - assert(getSQLDialect().getClass === classOf[DefaultParserDialect]) + assert(getSQLDialect().getClass === classOf[SparkQl]) sql("SET spark.sql.dialect=hiveql") - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) // set invalid dialect sql("SET spark.sql.dialect.abc=MyTestClass") @@ -361,14 +363,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("SELECT 1") } // test if the dialect set back to HiveQLDialect - getSQLDialect().getClass === classOf[HiveQLDialect] + getSQLDialect().getClass === classOf[ExtendedHiveQlParser] sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { sql("SELECT 1") } // test if the dialect set back to HiveQLDialect - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) } test("CTAS with serde") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 30e1758076361..62edf6c64bbc7 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -188,6 +188,11 @@ public static CalendarInterval fromSingleUnitString(String unit, String s) Integer.MIN_VALUE, Integer.MAX_VALUE); result = new CalendarInterval(month, 0L); + } else if (unit.equals("week")) { + long week = toLongWithRange("week", m.group(1), + Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK); + result = new CalendarInterval(0, week * MICROS_PER_WEEK); + } else if (unit.equals("day")) { long day = toLongWithRange("day", m.group(1), Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); @@ -206,6 +211,15 @@ public static CalendarInterval fromSingleUnitString(String unit, String s) } else if (unit.equals("second")) { long micros = parseSecondNano(m.group(1)); result = new CalendarInterval(0, micros); + + } else if (unit.equals("millisecond")) { + long millisecond = toLongWithRange("millisecond", m.group(1), + Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI); + result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI); + + } else if (unit.equals("microsecond")) { + long micros = Long.valueOf(m.group(1)); + result = new CalendarInterval(0, micros); } } catch (Exception e) { throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); From 5f843781e3e7581c61b7e235d4041d85e8e48c7e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 15 Jan 2016 15:54:19 -0800 Subject: [PATCH 1586/2089] [SPARK-11925][ML][PYSPARK] Add PySpark missing methods for ml.feature during Spark 1.6 QA Add PySpark missing methods and params for ml.feature: * ```RegexTokenizer``` should support setting ```toLowercase```. * ```MinMaxScalerModel``` should support output ```originalMin``` and ```originalMax```. * ```PCAModel``` should support output ```pc```. Author: Yanbo Liang Closes #9908 from yanboliang/spark-11925. --- python/pyspark/ml/feature.py | 72 +++++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b02d41b52ab25..141ec3492aa94 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -606,6 +606,10 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") >>> model = mmScaler.fit(df) + >>> model.originalMin + DenseVector([0.0]) + >>> model.originalMax + DenseVector([2.0]) >>> model.transform(df).show() +-----+------+ | a|scaled| @@ -688,6 +692,22 @@ class MinMaxScalerModel(JavaModel): .. versionadded:: 1.6.0 """ + @property + @since("2.0.0") + def originalMin(self): + """ + Min value for each original column during fitting. + """ + return self._call_java("originalMin") + + @property + @since("2.0.0") + def originalMax(self): + """ + Max value for each original column during fitting. + """ + return self._call_java("originalMax") + @inherit_doc @ignore_unicode_prefix @@ -984,18 +1004,18 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): length. It returns an array of strings that can be empty. - >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> df = sqlContext.createDataFrame([("A B c",)], ["text"]) >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") >>> reTokenizer.transform(df).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'A B c', words=[u'a', u'b', u'c']) >>> # Change a parameter. >>> reTokenizer.setParams(outputCol="tokens").transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'A B c', tokens=[u'a', u'b', u'c']) >>> # Temporarily modify a parameter. >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'A B c', words=[u'a', u'b', u'c']) >>> reTokenizer.transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'A B c', tokens=[u'a', u'b', u'c']) >>> # Must use keyword arguments to specify params. >>> reTokenizer.setParams("text") Traceback (most recent call last): @@ -1009,26 +1029,34 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") + toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " + + "lowercase before tokenizing") @keyword_only - def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): + def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, + outputCol=None, toLowercase=True): """ - __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) + __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \ + outputCol=None, toLowercase=True) """ super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") - self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+") + self.toLowercase = Param(self, "toLowercase", "whether to convert all characters to " + + "lowercase before tokenizing") + self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): + def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, + outputCol=None, toLowercase=True): """ - setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) + setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \ + outputCol=None, toLowercase=True) Sets params for this RegexTokenizer. """ kwargs = self.setParams._input_kwargs @@ -1079,6 +1107,21 @@ def getPattern(self): """ return self.getOrDefault(self.pattern) + @since("2.0.0") + def setToLowercase(self, value): + """ + Sets the value of :py:attr:`toLowercase`. + """ + self._paramMap[self.toLowercase] = value + return self + + @since("2.0.0") + def getToLowercase(self): + """ + Gets the value of toLowercase or its default value. + """ + return self.getOrDefault(self.toLowercase) + @inherit_doc class SQLTransformer(JavaTransformer): @@ -2000,6 +2043,15 @@ class PCAModel(JavaModel): .. versionadded:: 1.5.0 """ + @property + @since("2.0.0") + def pc(self): + """ + Returns a principal components Matrix. + Each column is one principal component. + """ + return self._call_java("pc") + @inherit_doc class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): From f6ddbb360ac6ac2778bbdbebbf2fcccabe73349b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 15 Jan 2016 16:03:05 -0800 Subject: [PATCH 1587/2089] [SPARK-12833][HOT-FIX] Reset the locale after we set it. Author: Yin Huai Closes #10778 from yhuai/resetLocale. --- .../datasources/csv/CSVTypeCastSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 40c5ccd0f7a4a..c28a25057e22f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -90,9 +90,14 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Float and Double Types are cast correctly with Locale") { - val locale : Locale = new Locale("fr", "FR") - Locale.setDefault(locale) - assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) - assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + val originalLocale = Locale.getDefault + try { + val locale : Locale = new Locale("fr", "FR") + Locale.setDefault(locale) + assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + } finally { + Locale.setDefault(originalLocale) + } } } From 8dbbf3e75e70e98391b4a1705472caddd129945a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 15 Jan 2016 17:07:24 -0800 Subject: [PATCH 1588/2089] [SPARK-12842][TEST-HADOOP2.7] Add Hadoop 2.7 build profile This patch adds a Hadoop 2.7 build profile in order to let us automate tests against that version. /cc rxin srowen Author: Josh Rosen Closes #10775 from JoshRosen/add-hadoop-2.7-profile. --- dev/create-release/release-build.sh | 3 +- dev/deps/spark-deps-hadoop-2.7 | 188 ++++++++++++++++++++++++++++ dev/run-tests-jenkins.py | 2 + dev/run-tests.py | 1 + dev/test-dependencies.sh | 1 + docs/building-spark.md | 3 +- pom.xml | 10 ++ 7 files changed, 206 insertions(+), 2 deletions(-) create mode 100644 dev/deps/spark-deps-hadoop-2.7 diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index b1895b16b1b61..00bf81120df65 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -168,7 +168,8 @@ if [[ "$1" == "package" ]]; then # share the same Zinc server. make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3035" & + make_binary_release "hadoop2.7" "-Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pyarn" "3036" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & wait diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 new file mode 100644 index 0000000000000..1d34854819a69 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -0,0 +1,188 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-runtime-3.5.2.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +apacheds-i18n-2.0.0-M15.jar +apacheds-kerberos-codec-2.0.0-M15.jar +api-asn1-api-1.0.0-M20.jar +api-util-1.0.0-M20.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.6.0.jar +curator-framework-2.6.0.jar +curator-recipes-2.6.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +gson-2.2.4.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.7.0.jar +hadoop-auth-2.7.0.jar +hadoop-client-2.7.0.jar +hadoop-common-2.7.0.jar +hadoop-hdfs-2.7.0.jar +hadoop-mapreduce-client-app-2.7.0.jar +hadoop-mapreduce-client-common-2.7.0.jar +hadoop-mapreduce-client-core-2.7.0.jar +hadoop-mapreduce-client-jobclient-2.7.0.jar +hadoop-mapreduce-client-shuffle-2.7.0.jar +hadoop-yarn-api-2.7.0.jar +hadoop-yarn-client-2.7.0.jar +hadoop-yarn-common-2.7.0.jar +hadoop-yarn-server-common-2.7.0.jar +hadoop-yarn-server-web-proxy-2.7.0.jar +htrace-core-3.1.0-incubating.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.5.3.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.5.3.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.1.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +super-csv-2.2.0.jar +uncommons-maths-1.2.2a.jar +univocity-parsers-1.5.6.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xercesImpl-2.9.1.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.6.jar diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index c44e522c0475d..a48d918f9dc1f 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -172,6 +172,8 @@ def main(): os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.4" if "test-hadoop2.6" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.6" + if "test-hadoop2.7" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.7" build_display_name = os.environ["BUILD_DISPLAY_NAME"] build_url = os.environ["BUILD_URL"] diff --git a/dev/run-tests.py b/dev/run-tests.py index c1646c77f1e53..5b717895106b3 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -304,6 +304,7 @@ def get_hadoop_profiles(hadoop_version): "hadoop2.3": ["-Pyarn", "-Phadoop-2.3"], "hadoop2.4": ["-Pyarn", "-Phadoop-2.4"], "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], + "hadoop2.7": ["-Pyarn", "-Phadoop-2.7"], } if hadoop_version in sbt_maven_hadoop_profiles: diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3cb5d2be2a91a..924b55287c2dc 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -36,6 +36,7 @@ HADOOP_PROFILES=( hadoop-2.3 hadoop-2.4 hadoop-2.6 + hadoop-2.7 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/docs/building-spark.md b/docs/building-spark.md index 785988902da8e..e1abcf1be501d 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -77,7 +77,8 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 - 2.6.x and later 2.xhadoop-2.6 + 2.6.xhadoop-2.6 + 2.7.x and later 2.xhadoop-2.7 diff --git a/pom.xml b/pom.xml index fc5cf970e0601..fca626991324b 100644 --- a/pom.xml +++ b/pom.xml @@ -2421,6 +2421,16 @@ + + hadoop-2.7 + + 2.7.0 + 0.9.3 + 3.4.6 + 2.6.0 + + + yarn From 3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Jan 2016 17:20:01 -0800 Subject: [PATCH 1589/2089] [SPARK-12649][SQL] support reading bucketed table This PR adds the support to read bucketed tables, and correctly populate `outputPartitioning`, so that we can avoid shuffle for some cases. TODO(follow-up PRs): * bucket pruning * avoid shuffle for bucketed table join when use any super-set of the bucketing key. (we should re-visit it after https://issues.apache.org/jira/browse/SPARK-12704 is fixed) * recognize hive bucketed table Author: Wenchen Fan Closes #10604 from cloud-fan/bucket-read. --- .../apache/spark/sql/DataFrameReader.scala | 1 + .../scala/org/apache/spark/sql/SQLConf.scala | 6 + .../spark/sql/execution/ExistingRDD.scala | 28 ++- .../InsertIntoHadoopFsRelation.scala | 2 +- .../datasources/ResolvedDataSource.scala | 4 +- .../datasources/WriterContainer.scala | 2 +- .../sql/execution/datasources/bucket.scala | 21 ++- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../datasources/json/JSONRelation.scala | 4 +- .../datasources/parquet/ParquetRelation.scala | 2 +- .../sql/execution/datasources/rules.scala | 1 + .../apache/spark/sql/sources/interfaces.scala | 55 +++++- .../datasources/json/JsonSuite.scala | 2 + .../spark/sql/hive/HiveMetastoreCatalog.scala | 26 +-- .../spark/sql/hive/execution/commands.scala | 7 +- .../spark/sql/hive/orc/OrcRelation.scala | 2 +- .../spark/sql/sources/BucketedReadSuite.scala | 178 ++++++++++++++++++ .../sql/sources/BucketedWriteSuite.scala | 16 +- 18 files changed, 314 insertions(+), 45 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 8f852e521668a..634c1bd4739b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -109,6 +109,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { sqlContext, userSpecifiedSchema = userSpecifiedSchema, partitionColumns = Array.empty[String], + bucketSpec = None, provider = source, options = extraOptions.toMap) DataFrame(sqlContext, LogicalRelation(resolved.relation)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 7976795ff5919..4e3662724c575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -422,6 +422,10 @@ private[spark] object SQLConf { doc = "The maximum number of concurrent files to open before falling back on sorting when " + "writing out files using dynamic partitioning.") + val BUCKETING_ENABLED = booleanConf("spark.sql.sources.bucketing.enabled", + defaultValue = Some(true), + doc = "When false, we will treat bucketed table as normal table") + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // @@ -590,6 +594,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon private[spark] def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) + private[spark] def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) + // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 569a21feaa8a2..92cfd5f841c51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -98,7 +99,8 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], override val nodeName: String, override val metadata: Map[String, String] = Map.empty, - isUnsafeRow: Boolean = false) + isUnsafeRow: Boolean = false, + override val outputPartitioning: Partitioning = UnknownPartitioning(0)) extends LeafNode { protected override def doExecute(): RDD[InternalRow] = { @@ -130,6 +132,24 @@ private[sql] object PhysicalRDD { metadata: Map[String, String] = Map.empty): PhysicalRDD = { // All HadoopFsRelations output UnsafeRows val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] - PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows) + + val bucketSpec = relation match { + case r: HadoopFsRelation => r.getBucketSpec + case _ => None + } + + def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { + throw new AnalysisException(s"bucket column $colName not found in existing columns " + + s"(${output.map(_.name).mkString(", ")})") + } + + bucketSpec.map { spec => + val numBuckets = spec.numBuckets + val bucketColumns = spec.bucketColumnNames.map(toAttribute) + val partitioning = HashPartitioning(bucketColumns, numBuckets) + PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows, partitioning) + }.getOrElse { + PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 7a8691e7cb9c5..314c957d57bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.getBucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index ece9b8a9a9174..cc8dcf59307f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -97,6 +97,7 @@ object ResolvedDataSource extends Logging { sqlContext: SQLContext, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) @@ -142,6 +143,7 @@ object ResolvedDataSource extends Logging { paths, Some(dataSchema), maybePartitionsSchema, + bucketSpec, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.RelationProvider => throw new AnalysisException(s"$className does not allow user-specified schemas.") @@ -173,7 +175,7 @@ object ResolvedDataSource extends Logging { SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) } } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) + dataSource.createRelation(sqlContext, paths, None, None, None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => throw new AnalysisException( s"A schema needs to be specified when using $className.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index fc77529b7db32..563fd9eefcce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -311,7 +311,7 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - private val bucketSpec = relation.bucketSpec + private val bucketSpec = relation.getBucketSpec private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 9976829638d70..c7ecd6125d860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -44,9 +44,7 @@ private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProv dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = - // TODO: throw exception here as we won't call this method during execution, after bucketed read - // support is finished. - createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters) + throw new UnsupportedOperationException("use the overload version with bucketSpec parameter") } private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { @@ -54,5 +52,20 @@ private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFact path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = - throw new UnsupportedOperationException("use bucket version") + throw new UnsupportedOperationException("use the overload version with bucketSpec parameter") +} + +private[sql] object BucketingUtils { + // The file name of bucketed data should have 3 parts: + // 1. some other information in the head of file name, ends with `-` + // 2. bucket id part, some numbers + // 3. optional file extension part, in the tail of file name, starts with `.` + // An example of bucketed parquet file name with bucket id 3: + // part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb-00003.gz.parquet + private val bucketedFileName = """.*-(\d+)(?:\..*)?$""".r + + def getBucketId(fileName: String): Option[Int] = fileName match { + case bucketedFileName(bucketId) => Some(bucketId.toInt) + case other => None + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0897fcadbc011..c3603936dfd2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -91,7 +91,7 @@ case class CreateTempTableUsing( def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( - sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) + sqlContext, userSpecifiedSchema, Array.empty[String], bucketSpec = None, provider, options) sqlContext.catalog.registerTable( tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 8a6fa4aeebc09..20c60b9c43e10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -57,7 +57,7 @@ class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegi maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, - bucketSpec = bucketSpec, + maybeBucketSpec = bucketSpec, paths = paths, parameters = parameters)(sqlContext) } @@ -68,7 +68,7 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - override val bucketSpec: Option[BucketSpec] = None, + override val maybeBucketSpec: Option[BucketSpec] = None, override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 991a5d5aef2db..30ddec686c921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -112,7 +112,7 @@ private[sql] class ParquetRelation( // This is for metastore conversion. private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - override val bucketSpec: Option[BucketSpec], + override val maybeBucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index dd3e66d8a9434..9358c9c37bf67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -36,6 +36,7 @@ private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[Logica sqlContext, userSpecifiedSchema = None, partitionColumns = Array(), + bucketSpec = None, provider = u.tableIdentifier.database.get, options = Map("path" -> u.tableIdentifier.table)) val plan = LogicalRelation(resolved.relation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 9f3607369c30f..7800776fa1a5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -28,13 +28,13 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -458,7 +458,12 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ - private[sql] def bucketSpec: Option[BucketSpec] = None + private[this] var malformedBucketFile = false + + private[sql] def maybeBucketSpec: Option[BucketSpec] = None + + final private[sql] def getBucketSpec: Option[BucketSpec] = + maybeBucketSpec.filter(_ => sqlContext.conf.bucketingEnabled() && !malformedBucketFile) private class FileStatusCache { var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] @@ -664,6 +669,35 @@ abstract class HadoopFsRelation private[sql]( }) } + /** + * Groups the input files by bucket id, if bucketing is enabled and this data source is bucketed. + * Returns None if there exists any malformed bucket files. + */ + private def groupBucketFiles( + files: Array[FileStatus]): Option[scala.collection.Map[Int, Array[FileStatus]]] = { + malformedBucketFile = false + if (getBucketSpec.isDefined) { + val groupedBucketFiles = mutable.HashMap.empty[Int, mutable.ArrayBuffer[FileStatus]] + var i = 0 + while (!malformedBucketFile && i < files.length) { + val bucketId = BucketingUtils.getBucketId(files(i).getPath.getName) + if (bucketId.isEmpty) { + logError(s"File ${files(i).getPath} is expected to be a bucket file, but there is no " + + "bucket id information in file name. Fall back to non-bucketing mode.") + malformedBucketFile = true + } else { + val bucketFiles = + groupedBucketFiles.getOrElseUpdate(bucketId.get, mutable.ArrayBuffer.empty) + bucketFiles += files(i) + } + i += 1 + } + if (malformedBucketFile) None else Some(groupedBucketFiles.mapValues(_.toArray)) + } else { + None + } + } + final private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], @@ -683,7 +717,20 @@ abstract class HadoopFsRelation private[sql]( } } - buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf) + groupBucketFiles(inputStatuses).map { groupedBucketFiles => + // For each bucket id, firstly we get all files belong to this bucket, by detecting bucket + // id from file name. Then read these files into a RDD(use one-partition empty RDD for empty + // bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result. + val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId => + groupedBucketFiles.get(bucketId).map { inputStatuses => + buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1) + }.getOrElse(sqlContext.emptyResult) + } + + new UnionRDD(sqlContext.sparkContext, perBucketRows) + }.getOrElse { + buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index e70eb2a060309..8de8ba355e7d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1223,6 +1223,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], + bucketSpec = None, provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) @@ -1230,6 +1231,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], + bucketSpec = None, provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) assert(d1 === d2) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 3d54048c24782..0cfe03ba91ec7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -143,19 +143,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - def partColsFromParts: Option[Seq[String]] = { - table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols => - (0 until numPartCols.toInt).map { index => - val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull - if (partCol == null) { + def getColumnNames(colType: String): Seq[String] = { + table.properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map { + numCols => (0 until numCols.toInt).map { index => + table.properties.get(s"spark.sql.sources.schema.${colType}Col.$index").getOrElse { throw new AnalysisException( - "Could not read partitioned columns from the metastore because it is corrupted " + - s"(missing part $index of the it, $numPartCols parts are expected).") + s"Could not read $colType columns from the metastore because it is corrupted " + + s"(missing part $index of it, $numCols parts are expected).") } - - partCol } - } + }.getOrElse(Nil) } // Originally, we used spark.sql.sources.schema to store the schema of a data source table. @@ -170,7 +167,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // We only need names at here since userSpecifiedSchema we loaded from the metastore // contains partition columns. We can always get datatypes of partitioning columns // from userSpecifiedSchema. - val partitionColumns = partColsFromParts.getOrElse(Nil) + val partitionColumns = getColumnNames("part") + + val bucketSpec = table.properties.get("spark.sql.sources.schema.numBuckets").map { n => + BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort")) + } // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... @@ -181,6 +182,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive hive, userSpecifiedSchema, partitionColumns.toArray, + bucketSpec, table.properties("spark.sql.sources.provider"), options) @@ -282,7 +284,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) val dataSource = ResolvedDataSource( - hive, userSpecifiedSchema, partitionColumns, provider, options) + hive, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options) def newSparkSQLSpecificMetastoreTable(): HiveTable = { HiveTable( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 07a352873d087..e703ac016496d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -213,7 +213,12 @@ case class CreateMetastoreDataSourceAsSelect( case SaveMode.Append => // Check if the specified data source match the data source of the existing table. val resolved = ResolvedDataSource( - sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) + sqlContext, + Some(query.schema.asNullable), + partitionColumns, + bucketSpec, + provider, + optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 14fa152c2331d..40409169b095a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -156,7 +156,7 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - override val bucketSpec: Option[BucketSpec], + override val maybeBucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala new file mode 100644 index 0000000000000..58ecdd3b801d3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf} +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.joins.SortMergeJoin +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("read bucketed data") { + val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + withTable("bucketed_table") { + df.write + .format("parquet") + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + for (i <- 0 until 5) { + val rdd = hiveContext.table("bucketed_table").filter($"i" === i).queryExecution.toRdd + assert(rdd.partitions.length == 8) + + val attrs = df.select("j", "k").schema.toAttributes + val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => { + val getBucketId = UnsafeProjection.create( + HashPartitioning(attrs, 8).partitionIdExpression :: Nil, + attrs) + rows.map(row => getBucketId(row).getInt(0) == index) + }) + + assert(checkBucketId.collect().reduce(_ && _)) + } + } + } + + private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") + private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + + private def testBucketing( + bucketing1: DataFrameWriter => DataFrameWriter, + bucketing2: DataFrameWriter => DataFrameWriter, + joinColumns: Seq[String], + shuffleLeft: Boolean, + shuffleRight: Boolean): Unit = { + withTable("bucketed_table1", "bucketed_table2") { + bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1") + bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val t1 = hiveContext.table("bucketed_table1") + val t2 = hiveContext.table("bucketed_table2") + val joined = t1.join(t2, joinCondition(t1, t2, joinColumns)) + + // First check the result is corrected. + checkAnswer( + joined.sort("bucketed_table1.k", "bucketed_table2.k"), + df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k")) + + assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin]) + val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin] + + assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft) + assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight) + } + } + } + + private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = { + joinCols.map(col => left(col) === right(col)).reduce(_ && _) + } + + test("avoid shuffle when join 2 bucketed tables") { + val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") + testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + } + + // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704 + ignore("avoid shuffle when join keys are a super-set of bucket keys") { + val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i") + testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + } + + test("only shuffle one side when join bucketed table and non-bucketed table") { + val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") + testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + } + + test("only shuffle one side when 2 bucketed tables have different bucket number") { + val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") + val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j") + testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + } + + test("only shuffle one side when 2 bucketed tables have different bucket keys") { + val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i") + val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j") + testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true) + } + + test("shuffle when join keys are not equal to bucket keys") { + val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i") + testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true) + } + + test("shuffle when join 2 bucketed tables with bucketing disabled") { + val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") + withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { + testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true) + } + } + + test("avoid shuffle when grouping keys are equal to bucket keys") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table") + val tbl = hiveContext.table("bucketed_table") + val agged = tbl.groupBy("i", "j").agg(max("k")) + + checkAnswer( + agged.sort("i", "j"), + df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) + + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + } + } + + test("avoid shuffle when grouping keys are a super-set of bucket keys") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") + val tbl = hiveContext.table("bucketed_table") + val agged = tbl.groupBy("i", "j").agg(max("k")) + + checkAnswer( + agged.sort("i", "j"), + df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) + + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + } + } + + test("fallback to non-bucketing mode if there exists any malformed bucket files") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + Utils.deleteRecursively(tableDir) + df1.write.parquet(tableDir.getAbsolutePath) + + val agged = hiveContext.table("bucketed_table").groupBy("i").count() + // make sure we fall back to non-bucketing mode and can't avoid shuffle + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined) + checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i")) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 3ea9826544edb..e812439bed9aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -62,15 +63,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) } - private val testFileName = """.*-(\d+)$""".r - private val otherFileName = """.*-(\d+)\..*""".r - private def getBucketId(fileName: String): Int = { - fileName match { - case testFileName(bucketId) => bucketId.toInt - case otherFileName(bucketId) => bucketId.toInt - } - } - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") private def testBucketing( @@ -81,7 +73,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val allBucketFiles = dataDir.listFiles().filterNot(f => f.getName.startsWith(".") || f.getName.startsWith("_") ) - val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get) assert(groupedBucketFiles.size <= 8) for ((bucketId, bucketFiles) <- groupedBucketFiles) { @@ -98,12 +90,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val qe = readBack.select(bucketCols.map(col): _*).queryExecution val rows = qe.toRdd.map(_.copy()).collect() - val getHashCode = UnsafeProjection.create( + val getBucketId = UnsafeProjection.create( HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil, qe.analyzed.output) for (row <- rows) { - val actualBucketId = getHashCode(row).getInt(0) + val actualBucketId = getBucketId(row).getInt(0) assert(actualBucketId == bucketId) } } From 9039333c0a0ce4bea32f012b81c1e82e31246fc1 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 15 Jan 2016 17:40:26 -0800 Subject: [PATCH 1590/2089] [SPARK-12644][SQL] Update parquet reader to be vectorized. This inlines a few of the Parquet decoders and adds vectorized APIs to support decoding in batch. There are a few particulars in the Parquet encodings that make this much more efficient. In particular, RLE encodings are very well suited for batch decoding. The Parquet 2.0 encodings are also very suited for this. This is a work in progress and does not affect the current execution. In subsequent patches, we will support more encodings and types before enabling this. Simple benchmarks indicate this can decode single ints about > 3x faster. Author: Nong Li Author: Nong Closes #10593 from nongli/spark-12644. --- .../org/apache/spark/util/Benchmark.scala | 6 +- .../parquet/UnsafeRowParquetRecordReader.java | 146 +++++++++- .../parquet/VectorizedPlainValuesReader.java | 66 +++++ .../parquet/VectorizedRleValuesReader.java | 274 ++++++++++++++++++ .../parquet/VectorizedValuesReader.java | 37 +++ .../execution/vectorized/ColumnVector.java | 9 +- .../execution/vectorized/ColumnarBatch.java | 13 +- .../vectorized/OffHeapColumnVector.java | 2 + .../vectorized/OnHeapColumnVector.java | 1 + .../parquet/ParquetReadBenchmark.scala | 93 +++++- .../vectorized/ColumnarBatchBenchmark.scala | 7 +- .../vectorized/ColumnarBatchSuite.scala | 27 +- 12 files changed, 625 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 457a1a05a1bf5..d484cec7ae384 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -62,10 +62,10 @@ private[spark] class Benchmark( val firstRate = results.head.avgRate // The results are going to be processor specific so it is useful to include that. println(Benchmark.getProcessorName()) - printf("%-24s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") - println("-------------------------------------------------------------------------") + printf("%-30s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") + println("-------------------------------------------------------------------------------") results.zip(benchmarks).foreach { r => - printf("%-24s %16s %16s %14s\n", + printf("%-30s %16s %16s %14s\n", r._2.name, "%10.2f" format r._1.avgMs, "%10.2f" format r._1.avgRate, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 47818c0939f2a..80805f15a8f06 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,10 +21,10 @@ import java.nio.ByteBuffer; import java.util.List; -import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; import org.apache.parquet.column.Encoding; @@ -35,9 +35,12 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; +import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -102,6 +105,25 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private static final int DEFAULT_VAR_LEN_SIZE = 32; + /** + * columnBatch object that is used for batch decoding. This is created on first use and triggers + * batched decoding. It is not valid to interleave calls to the batched interface with the row + * by row RecordReader APIs. + * This is only enabled with additional flags for development. This is still a work in progress + * and currently unsupported cases will fail with potentially difficult to diagnose errors. + * This should be only turned on for development to work on this feature. + * + * TODOs: + * - Implement all the encodings to support vectorized. + * - Implement v2 page formats (just make sure we create the correct decoders). + */ + private ColumnarBatch columnarBatch; + + /** + * The default config on whether columnarBatch should be offheap. + */ + private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; + /** * Tries to initialize the reader for this split. Returns true if this reader supports reading * this split and false otherwise. @@ -135,6 +157,15 @@ public void initialize(String path, List columns) throws IOException { initializeInternal(); } + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + super.close(); + } + @Override public boolean nextKeyValue() throws IOException, InterruptedException { if (batchIdx >= numBatched) { @@ -154,6 +185,46 @@ public float getProgress() throws IOException, InterruptedException { return (float) rowsReturned / totalRowCount; } + /** + * Returns the ColumnarBatch object that will be used for all rows returned by this reader. + * This object is reused. Calling this enables the vectorized reader. This should be called + * before any calls to nextKeyValue/nextBatch. + */ + public ColumnarBatch resultBatch() { + return resultBatch(DEFAULT_MEMORY_MODE); + } + + public ColumnarBatch resultBatch(MemoryMode memMode) { + if (columnarBatch == null) { + columnarBatch = ColumnarBatch.allocate(sparkSchema, memMode); + } + return columnarBatch; + } + + /** + * Advances to the next batch of rows. Returns false if there are no more. + */ + public boolean nextBatch() throws IOException { + assert(columnarBatch != null); + columnarBatch.reset(); + if (rowsReturned >= totalRowCount) return false; + checkEndOfRowGroup(); + + int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned); + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case INT32: + columnReaders[i].readIntBatch(num, columnarBatch.column(i)); + break; + default: + throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType()); + } + } + rowsReturned += num; + columnarBatch.setNumRows(num); + return true; + } + private void initializeInternal() throws IOException { /** * Check that the requested schema is supported. @@ -382,7 +453,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept * * Decoder to return values from a single column. */ - private static final class ColumnReader { + private final class ColumnReader { /** * Total number of values read. */ @@ -416,6 +487,10 @@ private static final class ColumnReader { private IntIterator definitionLevelColumn; private ValuesReader dataColumn; + // Only set if vectorized decoding is true. This is used instead of the row by row decoding + // with `definitionLevelColumn`. + private VectorizedRleValuesReader defColumn; + /** * Total number of values in this column (in this row group). */ @@ -521,6 +596,35 @@ private boolean next() throws IOException { return definitionLevelColumn.nextInt() == maxDefLevel; } + /** + * Reads `total` values from this columnReader into column. + * TODO: implement the other encodings. + */ + private void readIntBatch(int total, ColumnVector column) throws IOException { + int rowId = 0; + while (total > 0) { + // Compute the number of values we want to read in this page. + int leftInPage = (int)(endOfPageValueCount - valuesRead); + if (leftInPage == 0) { + readPage(); + leftInPage = (int)(endOfPageValueCount - valuesRead); + } + int num = Math.min(total, leftInPage); + defColumn.readIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0); + + // Remap the values if it is dictionary encoded. + if (useDictionary) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(column.getInt(i))); + } + } + valuesRead += num; + rowId += num; + total -= num; + } + } + private void readPage() throws IOException { DataPage page = pageReader.readPage(); // TODO: Why is this a visitor? @@ -547,21 +651,28 @@ public Void visit(DataPageV2 dataPageV2) { }); } - private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) - throws IOException { - this.pageValueCount = valueCount; + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset)throws IOException { this.endOfPageValueCount = valuesRead + pageValueCount; if (dataEncoding.usesDictionary()) { + this.dataColumn = null; if (dictionary == null) { throw new IOException( "could not read page in col " + descriptor + " as the dictionary was missing for encoding " + dataEncoding); } - this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( - descriptor, VALUES, dictionary); + if (columnarBatch != null && dataEncoding == Encoding.PLAIN_DICTIONARY) { + this.dataColumn = new VectorizedRleValuesReader(); + } else { + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + } this.useDictionary = true; } else { - this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + if (columnarBatch != null && dataEncoding == Encoding.PLAIN) { + this.dataColumn = new VectorizedPlainValuesReader(4); + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + } this.useDictionary = false; } @@ -573,8 +684,19 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int } private void readPageV1(DataPageV1 page) throws IOException { + this.pageValueCount = page.getValueCount(); ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); - ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + ValuesReader dlReader; + + // Initialize the decoders. Use custom ones if vectorized decoding is enabled. + if (columnarBatch != null && page.getDlEncoding() == Encoding.RLE) { + int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); + assert(bitWidth != 0); // not implemented + this.defColumn = new VectorizedRleValuesReader(bitWidth); + dlReader = this.defColumn; + } else { + dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + } this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { @@ -583,20 +705,20 @@ private void readPageV1(DataPageV1 page) throws IOException { int next = rlReader.getNextOffset(); dlReader.initFromPage(pageValueCount, bytes, next); next = dlReader.getNextOffset(); - initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + initDataReader(page.getValueEncoding(), bytes, next); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } } private void readPageV2(DataPageV2 page) throws IOException { + this.pageValueCount = page.getValueCount(); this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), page.getRepetitionLevels(), descriptor); this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), page.getDefinitionLevels(), descriptor); try { - initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, - page.getValueCount()); + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java new file mode 100644 index 0000000000000..dac0c52ebd2cf --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; + +import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.unsafe.Platform; + +import org.apache.parquet.column.values.ValuesReader; + +/** + * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. + */ +public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { + private byte[] buffer; + private int offset; + private final int byteSize; + + public VectorizedPlainValuesReader(int byteSize) { + this.byteSize = byteSize; + } + + @Override + public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { + this.buffer = bytes; + this.offset = offset + Platform.BYTE_ARRAY_OFFSET; + } + + @Override + public void skip() { + offset += byteSize; + } + + @Override + public void skip(int n) { + offset += n * byteSize; + } + + @Override + public void readIntegers(int total, ColumnVector c, int rowId) { + c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 4 * total; + } + + @Override + public int readInteger() { + int v = Platform.getInt(buffer, offset); + offset += 4; + return v; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java new file mode 100644 index 0000000000000..493ec9deed499 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.bitpacking.BytePacker; +import org.apache.parquet.column.values.bitpacking.Packer; +import org.apache.parquet.io.ParquetDecodingException; +import org.apache.spark.sql.execution.vectorized.ColumnVector; + +/** + * A values reader for Parquet's run-length encoded data. This is based off of the version in + * parquet-mr with these changes: + * - Supports the vectorized interface. + * - Works on byte arrays(byte[]) instead of making byte streams. + * + * This encoding is used in multiple places: + * - Definition/Repetition levels + * - Dictionary ids. + */ +public final class VectorizedRleValuesReader extends ValuesReader { + // Current decoding mode. The encoded data contains groups of either run length encoded data + // (RLE) or bit packed data. Each group contains a header that indicates which group it is and + // the number of values in the group. + // More details here: https://github.com/Parquet/parquet-format/blob/master/Encodings.md + private enum MODE { + RLE, + PACKED + } + + // Encoded data. + private byte[] in; + private int end; + private int offset; + + // bit/byte width of decoded data and utility to batch unpack them. + private int bitWidth; + private int bytesWidth; + private BytePacker packer; + + // Current decoding mode and values + private MODE mode; + private int currentCount; + private int currentValue; + + // Buffer of decoded values if the values are PACKED. + private int[] currentBuffer = new int[16]; + private int currentBufferIdx = 0; + + // If true, the bit width is fixed. This decoder is used in different places and this also + // controls if we need to read the bitwidth from the beginning of the data stream. + private final boolean fixedWidth; + + public VectorizedRleValuesReader() { + fixedWidth = false; + } + + public VectorizedRleValuesReader(int bitWidth) { + fixedWidth = true; + init(bitWidth); + } + + @Override + public void initFromPage(int valueCount, byte[] page, int start) { + this.offset = start; + this.in = page; + if (fixedWidth) { + int length = readIntLittleEndian(); + this.end = this.offset + length; + } else { + this.end = page.length; + if (this.end != this.offset) init(page[this.offset++] & 255); + } + this.currentCount = 0; + } + + /** + * Initializes the internal state for decoding ints of `bitWidth`. + */ + private void init(int bitWidth) { + Preconditions.checkArgument(bitWidth >= 0 && bitWidth <= 32, "bitWidth must be >= 0 and <= 32"); + this.bitWidth = bitWidth; + this.bytesWidth = BytesUtils.paddedByteCountFromBits(bitWidth); + this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); + } + + @Override + public int getNextOffset() { + return this.end; + } + + @Override + public boolean readBoolean() { + return this.readInteger() != 0; + } + + @Override + public void skip() { + this.readInteger(); + } + + @Override + public int readValueDictionaryId() { + return readInteger(); + } + + @Override + public int readInteger() { + if (this.currentCount == 0) { this.readNextGroup(); } + + this.currentCount--; + switch (mode) { + case RLE: + return this.currentValue; + case PACKED: + return this.currentBuffer[currentBufferIdx++]; + } + throw new RuntimeException("Unreachable"); + } + + /** + * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader + * reads the definition levels and then will read from `data` for the non-null values. + * If the value is null, c will be populated with `nullValue`. + * + * This is a batched version of this logic: + * if (this.readInt() == level) { + * c[rowId] = data.readInteger(); + * } else { + * c[rowId] = nullValue; + * } + */ + public void readIntegers(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data, int nullValue) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readIntegers(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putInt(rowId + i, data.readInteger()); + c.putNotNull(rowId + i); + } else { + c.putInt(rowId + i, nullValue); + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + /** + * Reads the next varint encoded int. + */ + private int readUnsignedVarInt() { + int value = 0; + int shift = 0; + int b; + do { + b = in[offset++] & 255; + value |= (b & 0x7F) << shift; + shift += 7; + } while ((b & 0x80) != 0); + return value; + } + + /** + * Reads the next 4 byte little endian int. + */ + private int readIntLittleEndian() { + int ch4 = in[offset] & 255; + int ch3 = in[offset + 1] & 255; + int ch2 = in[offset + 2] & 255; + int ch1 = in[offset + 3] & 255; + offset += 4; + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); + } + + /** + * Reads the next byteWidth little endian int. + */ + private int readIntLittleEndianPaddedOnBitWidth() { + switch (bytesWidth) { + case 0: + return 0; + case 1: + return in[offset++] & 255; + case 2: { + int ch2 = in[offset] & 255; + int ch1 = in[offset + 1] & 255; + offset += 2; + return (ch1 << 8) + ch2; + } + case 3: { + int ch3 = in[offset] & 255; + int ch2 = in[offset + 1] & 255; + int ch1 = in[offset + 2] & 255; + offset += 3; + return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); + } + case 4: { + return readIntLittleEndian(); + } + } + throw new RuntimeException("Unreachable"); + } + + /** + * Reads the next group. + */ + private void readNextGroup() { + Preconditions.checkArgument(this.offset < this.end, + "Reading past RLE/BitPacking stream. offset=" + this.offset + " end=" + this.end); + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int bytesToRead = (int)Math.ceil((double)(this.currentCount * this.bitWidth) / 8.0D); + + bytesToRead = Math.min(bytesToRead, this.end - this.offset); + int valueIndex = 0; + for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { + this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); + valueIndex += 8; + } + offset += bytesToRead; + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } +} \ No newline at end of file diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java new file mode 100644 index 0000000000000..49a9ed83d590a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import org.apache.spark.sql.execution.vectorized.ColumnVector; + +/** + * Interface for value decoding that supports vectorized (aka batched) decoding. + * TODO: merge this into parquet-mr. + */ +public interface VectorizedValuesReader { + int readInteger(); + + /* + * Reads `total` values into `c` start at `c[rowId]` + */ + void readIntegers(int total, ColumnVector c, int rowId); + + // TODO: add all the other parquet types. + + void skip(int n); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index d9dde92ceb6d7..85509751dbbee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.vectorized; +import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.DataType; /** @@ -33,8 +34,8 @@ public abstract class ColumnVector { /** * Allocates a column with each element of size `width` either on or off heap. */ - public static ColumnVector allocate(int capacity, DataType type, boolean offHeap) { - if (offHeap) { + public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { + if (mode == MemoryMode.OFF_HEAP) { return new OffHeapColumnVector(capacity, type); } else { return new OnHeapColumnVector(capacity, type); @@ -111,7 +112,7 @@ public void reset() { public abstract void putInts(int rowId, int count, int[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) * The data in src must be 4-byte little endian ints. */ public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); @@ -138,7 +139,7 @@ public void reset() { public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) * The data in src must be ieee formated doubles. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 47defac4534dc..2c55f854c2419 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -19,6 +19,7 @@ import java.util.Arrays; import java.util.Iterator; +import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -59,12 +60,12 @@ public final class ColumnarBatch { // Total number of rows that have been filtered. private int numRowsFiltered = 0; - public static ColumnarBatch allocate(StructType schema, boolean offHeap) { - return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, offHeap); + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) { + return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode); } - public static ColumnarBatch allocate(StructType schema, boolean offHeap, int maxRows) { - return new ColumnarBatch(schema, maxRows, offHeap); + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) { + return new ColumnarBatch(schema, maxRows, memMode); } /** @@ -282,7 +283,7 @@ public final void markFiltered(int rowId) { ++numRowsFiltered; } - private ColumnarBatch(StructType schema, int maxRows, boolean offHeap) { + private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { this.schema = schema; this.capacity = maxRows; this.columns = new ColumnVector[schema.size()]; @@ -290,7 +291,7 @@ private ColumnarBatch(StructType schema, int maxRows, boolean offHeap) { for (int i = 0; i < schema.fields().length; ++i) { StructField field = schema.fields()[i]; - columns[i] = ColumnVector.allocate(maxRows, field.dataType(), offHeap); + columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 2a9a2d1104b22..6180dd308e5e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -49,6 +49,7 @@ protected OffHeapColumnVector(int capacity, DataType type) { } else { throw new RuntimeException("Unhandled " + type); } + anyNullsSet = true; reset(); } @@ -98,6 +99,7 @@ public final void putNulls(int rowId, int count) { @Override public final void putNotNulls(int rowId, int count) { + if (!anyNullsSet) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index a7b3addf11b14..76d9956c3842f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -97,6 +97,7 @@ public final void putNulls(int rowId, int count) { @Override public final void putNotNulls(int rowId, int count) { + if (!anyNullsSet) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index ae95b50e1ee76..14be9eec9a97a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -59,24 +59,31 @@ object ParquetReadBenchmark { } def intScanBenchmark(values: Int): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) + // Benchmarks driving reader component directly. + val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) + withTempPath { dir => - sqlContext.range(values).write.parquet(dir.getCanonicalPath) - withTempTable("tempTable") { + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select cast(id as INT) as id from t1") + .write.parquet(dir.getCanonicalPath) sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") - val benchmark = new Benchmark("Single Int Column Scan", values) - benchmark.addCase("SQL Parquet Reader") { iter => + sqlBenchmark.addCase("SQL Parquet Reader") { iter => sqlContext.sql("select sum(id) from tempTable").collect() } - benchmark.addCase("SQL Parquet MR") { iter => + sqlBenchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { sqlContext.sql("select sum(id) from tempTable").collect() } } val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("ParquetReader") { num => + // Driving the parquet reader directly without Spark. + parquetReaderBenchmark.addCase("ParquetReader") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => val reader = new UnsafeRowParquetRecordReader @@ -87,26 +94,82 @@ object ParquetReadBenchmark { if (!record.isNullAt(0)) sum += record.getInt(0) } reader.close() - }} + } + } + + // Driving the parquet reader in batch mode directly. + parquetReaderBenchmark.addCase("ParquetReader(Batched)") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.getIsNull(i)) sum += col.getInt(i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark.addCase("ParquetReader(Batch -> Row)") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) sum += record.getInt(0) + } + } + } finally { + reader.close() + } + } + } /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - SQL Parquet Reader 1910.0 13.72 1.00 X - SQL Parquet MR 2330.0 11.25 0.82 X - ParquetReader 1252.6 20.93 1.52 X + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + SQL Parquet Reader 1682.6 15.58 1.00 X + SQL Parquet MR 2379.6 11.02 0.71 X */ - benchmark.run() + sqlBenchmark.run() + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Parquet Reader Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + ParquetReader 610.40 25.77 1.00 X + ParquetReader(Batched) 172.66 91.10 3.54 X + ParquetReader(Batch -> Row) 192.28 81.80 3.17 X + */ + parquetReaderBenchmark.run() } } } def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + withTempPath { dir => withTempTable("t1", "tempTable") { sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select id as c1, cast(id as STRING) as c2 from t1") + sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") .write.parquet(dir.getCanonicalPath) sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index e28153d12a354..bfe944d835bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer +import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector import org.apache.spark.sql.types.IntegerType @@ -136,7 +137,7 @@ object ColumnarBatchBenchmark { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, false) + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -155,7 +156,7 @@ object ColumnarBatchBenchmark { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = ColumnVector.allocate(count, IntegerType, true) + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -174,7 +175,7 @@ object ColumnarBatchBenchmark { // Access by directly getting the buffer backing the column. val columnOffheapDirect = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, true) + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 305a83e3e45c9..d5e517c7f56be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} @@ -28,10 +29,10 @@ import org.apache.spark.unsafe.Platform class ColumnarBatchSuite extends SparkFunSuite { test("Null Apis") { - (false :: true :: Nil).foreach { offHeap => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[Boolean] - val column = ColumnVector.allocate(1024, IntegerType, offHeap) + val column = ColumnVector.allocate(1024, IntegerType, memMode) var idx = 0 assert(column.anyNullsSet() == false) @@ -64,7 +65,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getIsNull(v._2)) - if (offHeap) { + if (memMode == MemoryMode.OFF_HEAP) { val addr = column.nullsNativeAddress() assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) } @@ -74,12 +75,12 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("Int Apis") { - (false :: true :: Nil).foreach { offHeap => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = ColumnVector.allocate(1024, IntegerType, offHeap) + val column = ColumnVector.allocate(1024, IntegerType, memMode) var idx = 0 val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray @@ -131,8 +132,8 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Off Heap=" + offHeap) - if (offHeap) { + assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Mem Mode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) } @@ -142,12 +143,12 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("Double APIs") { - (false :: true :: Nil).foreach { offHeap => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = ColumnVector.allocate(1024, DoubleType, offHeap) + val column = ColumnVector.allocate(1024, DoubleType, memMode) var idx = 0 val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray @@ -198,8 +199,8 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " Off Heap=" + offHeap) - if (offHeap) { + assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) } @@ -209,13 +210,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("ColumnarBatch basic") { - (false :: true :: Nil).foreach { offHeap => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType() .add("intCol", IntegerType) .add("doubleCol", DoubleType) .add("intCol2", IntegerType) - val batch = ColumnarBatch.allocate(schema, offHeap) + val batch = ColumnarBatch.allocate(schema, memMode) assert(batch.numCols() == 3) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) From 242efb7546084592a5e8122549a27117977303fb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 15 Jan 2016 19:07:42 -0800 Subject: [PATCH 1591/2089] [SPARK-12840] [SQL] Support passing arbitrary objects (not just expressions) into code generated classes This is a refactor to support codegen for aggregation and broadcast join. Author: Davies Liu Closes #10777 from davies/rename2. --- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +++--- .../expressions/codegen/CodeGenerator.scala | 18 ++++++++---------- .../expressions/codegen/CodegenFallback.scala | 5 +++-- .../codegen/GenerateMutableProjection.scala | 14 +++++++------- .../expressions/codegen/GenerateOrdering.scala | 10 +++++----- .../codegen/GeneratePredicate.scala | 10 +++++----- .../codegen/GenerateSafeProjection.scala | 12 ++++++------ .../codegen/GenerateUnsafeProjection.scala | 14 +++++++------- .../codegen/GenerateUnsafeRowJoiner.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../columnar/GenerateColumnAccessor.scala | 4 ++-- 11 files changed, 48 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 4035c9dfa4f8c..681694746b84c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -985,7 +985,7 @@ case class ScalaUDF( ctx.addMutableState(converterClassName, converterTerm, s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + - s"expressions[$expressionIdx]).getChildren().apply($index))).dataType());") + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") converterTerm } @@ -1005,7 +1005,7 @@ case class ScalaUDF( val catalystConverterTermIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, catalystConverterTerm, s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter((($scalaUDFClassName)expressions" + + s".createToCatalystConverter((($scalaUDFClassName)references" + s"[$catalystConverterTermIdx]).dataType());") val resultTerm = ctx.freshName("result") @@ -1020,7 +1020,7 @@ case class ScalaUDF( val funcTerm = ctx.freshName("udf") val funcExpressionIdx = ctx.references.size - 1 ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)expressions" + + s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" + s"[$funcExpressionIdx]).userDefinedFunc());") // codegen for children expressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1c7083bbdacb2..f3a39a0e7588a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -45,16 +45,15 @@ import org.apache.spark.util.Utils case class ExprCode(var code: String, var isNull: String, var value: String) /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. + * A context for codegen, tracking a list of objects that could be passed into generated Java + * function. */ class CodegenContext { /** - * Holding all the expressions those do not support codegen, will be evaluated directly. + * Holding a list of objects that could be used passed into generated class. */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a @@ -400,7 +399,7 @@ class CodegenContext { // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) - // Get all the exprs that appear at least twice and set up the state for subexpression + // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { @@ -465,7 +464,7 @@ class CodegenContext { * into generated class. */ abstract class GeneratedClass { - def generate(expressions: Array[Expression]): Any + def generate(references: Array[Any]): Any } /** @@ -475,8 +474,6 @@ abstract class GeneratedClass { */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val exprType: String = classOf[Expression].getName - protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName protected def declareMutableStates(ctx: CodegenContext): String = { @@ -534,7 +531,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UnsafeArrayData].getName, classOf[MapData].getName, classOf[UnsafeMapData].getName, - classOf[MutableRow].getName + classOf[MutableRow].getName, + classOf[Expression].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index c98b7350b3594..cface21e5f6f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -30,12 +30,13 @@ trait CodegenFallback extends Expression { case _ => } + val idx = ctx.references.length ctx.references += this val objectTerm = ctx.freshName("obj") if (nullable) { s""" /* expression: ${this.toCommentSafeString} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { ev.isNull = "false" s""" /* expression: ${this.toCommentSafeString} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW}); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index a6ec242589fa9..63d13a8b8780f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -99,24 +99,24 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public java.lang.Object generate($exprType[] expr) { - return new SpecificMutableProjection(expr); + public java.lang.Object generate(Object[] references) { + return new SpecificMutableProjection(references); } class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { - private $exprType[] expressions; - private $mutableRowType mutableRow; + private Object[] references; + private MutableRow mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificMutableProjection($exprType[] expr) { - expressions = expr; + public SpecificMutableProjection(Object[] references) { + this.references = references; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} } - public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { mutableRow = row; return this; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 88bcf5b4ed369..e033f6217003c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -111,18 +111,18 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR val ctx = newCodeGenContext() val comparisons = genComparisons(ctx, ordering) val code = s""" - public SpecificOrdering generate($exprType[] expr) { - return new SpecificOrdering(expr); + public SpecificOrdering generate(Object[] references) { + return new SpecificOrdering(references); } class SpecificOrdering extends ${classOf[BaseOrdering].getName} { - private $exprType[] expressions; + private Object[] references; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificOrdering($exprType[] expr) { - expressions = expr; + public SpecificOrdering(Object[] references) { + this.references = references; ${initMutableStates(ctx)} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 457b4f08424a6..6fbe12fc6505e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -41,17 +41,17 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool val ctx = newCodeGenContext() val eval = predicate.gen(ctx) val code = s""" - public SpecificPredicate generate($exprType[] expr) { - return new SpecificPredicate(expr); + public SpecificPredicate generate(Object[] references) { + return new SpecificPredicate(references); } class SpecificPredicate extends ${classOf[Predicate].getName} { - private final $exprType[] expressions; + private final Object[] references; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificPredicate($exprType[] expr) { - expressions = expr; + public SpecificPredicate(Object[] references) { + this.references = references; ${initMutableStates(ctx)} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 865170764640e..10bd9c6103bb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -152,19 +152,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public java.lang.Object generate($exprType[] expr) { - return new SpecificSafeProjection(expr); + public java.lang.Object generate(Object[] references) { + return new SpecificSafeProjection(references); } class SpecificSafeProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - private $mutableRowType mutableRow; + private Object[] references; + private MutableRow mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificSafeProjection($exprType[] expr) { - expressions = expr; + public SpecificSafeProjection(Object[] references) { + this.references = references; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3a929927c3f60..1a0565a8ebc1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -320,8 +320,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro create(canonicalize(expressions), subexpressionEliminationEnabled) } - protected def create(expressions: Seq[Expression]): UnsafeProjection = { - create(expressions, subexpressionEliminationEnabled = false) + protected def create(references: Seq[Expression]): UnsafeProjection = { + create(references, subexpressionEliminationEnabled = false) } private def create( @@ -331,20 +331,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public java.lang.Object generate($exprType[] exprs) { - return new SpecificUnsafeProjection(exprs); + public java.lang.Object generate(Object[] references) { + return new SpecificUnsafeProjection(references); } class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - private $exprType[] expressions; + private Object[] references; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificUnsafeProjection($exprType[] expressions) { - this.expressions = expressions; + public SpecificUnsafeProjection(Object[] references) { + this.references = references; ${initMutableStates(ctx)} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 88b3c5e47f6ff..8781cc77f4963 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -158,7 +158,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public java.lang.Object generate($exprType[] exprs) { + |public java.lang.Object generate(Object[] references) { | return new SpecificUnsafeRowJoiner(); |} | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a3c10c81c35e5..c290aa8825adc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -221,7 +221,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with val hsetTerm = ctx.freshName("hset") val hasNullTerm = ctx.freshName("hasNull") ctx.addMutableState(setName, hsetTerm, - s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") s""" ${childGen.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 55e2c0ed70029..7888e34e8a83a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -123,7 +123,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; - public SpecificColumnarIterator generate($exprType[] expr) { + public SpecificColumnarIterator generate(Object[] references) { return new SpecificColumnarIterator(); } @@ -190,6 +190,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") - compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] + compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] } } From 2f7d0b68a29de9755fc9fafd9a52c048981ad880 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 16 Jan 2016 00:38:17 -0800 Subject: [PATCH 1592/2089] [SPARK-12856] [SQL] speed up hashCode of unsafe array We iterate the bytes to calculate hashCode before, but now we have `Murmur3_x86_32.hashUnsafeBytes` that don't require the bytes to be word algned, we should use that instead. A simple benchmark shows it's about 3 X faster, benchmark code: https://gist.github.com/cloud-fan/fa77713ccebf0823b2ab#file-arrayhashbenchmark-scala Author: Wenchen Fan Closes #10784 from cloud-fan/array-hashcode. --- .../spark/sql/catalyst/expressions/UnsafeArrayData.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 3d80df227151d..648625b2cc5d2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -299,11 +300,7 @@ public UnsafeMapData getMap(int ordinal) { @Override public int hashCode() { - int result = 37; - for (int i = 0; i < sizeInBytes; i++) { - result = 37 * result + Platform.getByte(baseObject, baseOffset + i); - } - return result; + return Murmur3_x86_32.hashUnsafeBytes(baseObject, baseOffset, sizeInBytes, 42); } @Override From 86972fa52152d2149b88ba75be048a6986006285 Mon Sep 17 00:00:00 2001 From: Jeff Lam Date: Sat, 16 Jan 2016 10:41:40 +0000 Subject: [PATCH 1593/2089] [SPARK-12722][DOCS] Fixed typo in Pipeline example http://spark.apache.org/docs/latest/ml-guide.html#example-pipeline ``` val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model") ``` should be ``` val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") ``` cc: jkbradley Author: Jeff Lam Closes #10769 from Agent007/SPARK-12722. --- docs/ml-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 1343753bce246..5aafd53b584e7 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -428,7 +428,7 @@ This example follows the simple text document `Pipeline` illustrated in the figu
    {% highlight scala %} -import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.mllib.linalg.Vector @@ -466,7 +466,7 @@ model.save("/tmp/spark-logistic-regression-model") pipeline.save("/tmp/unfit-lr-model") // and load it back in during production -val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model") +val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") // Prepare test documents, which are unlabeled (id, text) tuples. val test = sqlContext.createDataFrame(Seq( From 3c0d2365d57fc49ac9bf0d7cc9bd2ef633fb5fb6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 16 Jan 2016 10:29:27 -0800 Subject: [PATCH 1594/2089] [SPARK-12796] [SQL] Whole stage codegen This is the initial work for whole stage codegen, it support Projection/Filter/Range, we will continue work on this to support more physical operators. A micro benchmark show that a query with range, filter and projection could be 3X faster then before. It's turned on by default. For a tree that have at least two chained plans, a WholeStageCodegen will be inserted into it, for example, the following plan ``` Limit 10 +- Project [(id#5L + 1) AS (id + 1)#6L] +- Filter ((id#5L & 1) = 1) +- Range 0, 1, 4, 10, [id#5L] ``` will be translated into ``` Limit 10 +- WholeStageCodegen +- Project [(id#1L + 1) AS (id + 1)#2L] +- Filter ((id#1L & 1) = 1) +- Range 0, 1, 4, 10, [id#1L] ``` Here is the call graph to generate Java source for A and B (A support codegen, but B does not): ``` * WholeStageCodegen Plan A FakeInput Plan B * ========================================================================= * * -> execute() * | * doExecute() --------> produce() * | * doProduce() -------> produce() * | * doProduce() ---> execute() * | * consume() * doConsume() ------------| * | * doConsume() <----- consume() ``` A SparkPlan that support codegen need to implement doProduce() and doConsume(): ``` def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String ``` Author: Davies Liu Closes #10735 from davies/whole2. --- .../catalyst/expressions/BoundAttribute.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 76 +++-- .../expressions/codegen/CodegenFallback.scala | 8 +- .../codegen/GenerateMutableProjection.scala | 8 +- .../codegen/GenerateOrdering.scala | 8 +- .../codegen/GeneratePredicate.scala | 8 +- .../codegen/GenerateSafeProjection.scala | 8 +- .../codegen/GenerateUnsafeProjection.scala | 10 +- .../codegen/GenerateUnsafeRowJoiner.scala | 2 +- .../expressions/conditionalExpressions.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 3 +- .../spark/sql/catalyst/trees/TreeNode.scala | 2 +- .../org/apache/spark/sql/DataFrame.scala | 3 - .../scala/org/apache/spark/sql/SQLConf.scala | 9 + .../org/apache/spark/sql/SQLContext.scala | 3 +- .../sql/execution/BufferedRowIterator.java | 64 ++++ .../spark/sql/execution/SparkPlan.scala | 1 - .../sql/execution/WholeStageCodegen.scala | 299 ++++++++++++++++++ .../spark/sql/execution/basicOperators.scala | 114 ++++++- .../columnar/GenerateColumnAccessor.scala | 6 +- .../apache/spark/sql/CachedTableSuite.scala | 6 +- .../spark/sql/ColumnExpressionSuite.scala | 2 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 6 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 6 +- .../BenchmarkWholeStageCodegen.scala | 60 ++++ .../spark/sql/execution/PlannerSuite.scala | 6 +- .../execution/WholeStageCodegenSuite.scala | 38 +++ .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../execution/joins/BroadcastJoinSuite.scala | 2 +- .../spark/sql/hive/CachedTableSuite.scala | 6 +- .../hive/execution/HiveComparisonTest.scala | 2 +- .../execution/HiveTypeCoercionSuite.scala | 2 +- .../sql/hive/execution/PruningSuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 8 +- .../ParquetHadoopFsRelationSuite.scala | 2 +- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- 37 files changed, 694 insertions(+), 107 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index dda822d05485b..4727ff1885ad7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -61,7 +61,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - if (nullable) { + if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { + ev.isNull = ctx.currentVars(ordinal).isNull + ev.value = ctx.currentVars(ordinal).value + "" + } else if (nullable) { s""" boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f3a39a0e7588a..683029ff144d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,12 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Holding a list of generated columns as input of current operator, will be used by + * BoundReference to generate code. + */ + var currentVars: Seq[ExprCode] = null + /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a * 3-tuple: java type, variable name, code to init it. @@ -77,6 +83,16 @@ class CodegenContext { mutableStates += ((javaType, variableName, initCode)) } + def declareMutableStates(): String = { + mutableStates.map { case (javaType, variableName, _) => + s"private $javaType $variableName;" + }.mkString("\n") + } + + def initMutableStates(): String = { + mutableStates.map(_._3).mkString("\n") + } + /** * Holding all the functions those will be added into generated class. */ @@ -111,6 +127,10 @@ class CodegenContext { // The collection of sub-exression result resetting methods that need to be called on each row. val subExprResetVariables = mutable.ArrayBuffer.empty[String] + def declareAddedFunctions(): String = { + addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + } + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -120,7 +140,7 @@ class CodegenContext { final val JAVA_DOUBLE = "double" /** The variable name of the input row in generated code. */ - final val INPUT_ROW = "i" + final var INPUT_ROW = "i" private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -476,20 +496,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodegenContext): String = { - ctx.mutableStates.map { case (javaType, variableName, _) => - s"private $javaType $variableName;" - }.mkString("\n") - } - - protected def initMutableStates(ctx: CodegenContext): String = { - ctx.mutableStates.map(_._3).mkString("\n") - } - - protected def declareAddedFunctions(ctx: CodegenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim - } - /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -505,16 +511,33 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def generate(expressions: InType): OutType = create(canonicalize(expressions)) + /** - * Compile the Java source code into a Java class, using Janino. + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen */ - protected def compile(code: String): GeneratedClass = { + def newCodeGenContext(): CodegenContext = { + new CodegenContext + } +} + +object CodeGenerator extends Logging { + /** + * Compile the Java source code into a Java class, using Janino. + */ + def compile(code: String): GeneratedClass = { cache.get(code) } /** - * Compile the Java source code into a Java class, using Janino. - */ + * Compile the Java source code into a Java class, using Janino. + */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) @@ -577,19 +600,4 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin result } }) - - /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = - generate(bind(expressions, inputSchema)) - - /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = create(canonicalize(expressions)) - - /** - * Create a new codegen context for expression evaluator, used to store those - * expressions that don't support codegen - */ - def newCodeGenContext(): CodegenContext = { - new CodegenContext - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index cface21e5f6f1..f58a2daf902f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -30,13 +30,15 @@ trait CodegenFallback extends Expression { case _ => } + // LeafNode does not need `input` + val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length ctx.references += this val objectTerm = ctx.freshName("obj") if (nullable) { s""" /* expression: ${this.toCommentSafeString} */ - Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW}); + Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { @@ -47,7 +49,7 @@ trait CodegenFallback extends Expression { ev.isNull = "false" s""" /* expression: ${this.toCommentSafeString} */ - Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW}); + Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 63d13a8b8780f..59ef0f5836a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -107,13 +107,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private Object[] references; private MutableRow mutableRow; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificMutableProjection(Object[] references) { this.references = references; mutableRow = new $genericMutableRowType(${expressions.size}); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { @@ -138,7 +138,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) () => { c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index e033f6217003c..6de57537ec078 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -118,12 +118,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private Object[] references; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificOrdering(Object[] references) { this.references = references; - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public int compare(InternalRow a, InternalRow b) { @@ -135,6 +135,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") - compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 6fbe12fc6505e..58065d956f072 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -47,12 +47,12 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final Object[] references; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificPredicate(Object[] references) { this.references = references; - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public boolean eval(InternalRow ${ctx.INPUT_ROW}) { @@ -63,7 +63,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 10bd9c6103bb0..e750ad9c184b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -160,13 +160,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private Object[] references; private MutableRow mutableRow; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificSafeProjection(Object[] references) { this.references = references; mutableRow = new $genericMutableRowType(${expressions.size}); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public java.lang.Object apply(java.lang.Object _i) { @@ -179,7 +179,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 1a0565a8ebc1b..61e7469ee4be2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -338,14 +338,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { private Object[] references; - - ${declareMutableStates(ctx)} - - ${declareAddedFunctions(ctx)} + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} public SpecificUnsafeProjection(Object[] references) { this.references = references; - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } // Scala.Function1 need this @@ -362,7 +360,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 8781cc77f4963..b1ffbaa3e94ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 2a24235a29c9c..1eff2c4dd0086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -224,6 +224,7 @@ object CaseWhen { } } + /** * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". * When a = b, returns c; when a = d, returns e; else returns f. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 2c12de08f4115..493e0aae01af7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -351,8 +351,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression val hasher = classOf[Murmur3_x86_32].getName def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = - ExprCode(code = "", isNull = "false", value = v) + def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) dataType match { case NullType => inlineValue(seed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index d0b29aa01f640..d74f3ef2ffba6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -452,7 +452,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and * `lastChildren` for the root node should be empty. */ - protected def generateTreeString( + def generateTreeString( depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { if (depth > 0) { lastChildren.init.foreach { isLast => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3422d0ead4fc1..95e5fbb11900e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter -import java.util.Properties import scala.language.implicitConversions import scala.reflect.ClassTag @@ -39,12 +38,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator -import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { new DataFrame(sqlContext, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4e3662724c575..4c1eb0b30b25e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -489,6 +489,13 @@ private[spark] object SQLConf { isPublic = false, doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") + val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage", + defaultValue = Some(true), + doc = "When true, the whole stage (of multiple operators) will be compiled into single java" + + " method", + isPublic = false) + + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -561,6 +568,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) + private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) private[spark] def subexpressionEliminationEnabled: Boolean = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a0939adb6d5ae..18ddffe1be211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -904,7 +904,8 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)) + Batch("Add exchange", Once, EnsureRequirements(self)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(self)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java new file mode 100644 index 0000000000000..b1bbb1da10a39 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import scala.collection.Iterator; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; + +/** + * An iterator interface used to pull the output from generated function for multiple operators + * (whole stage codegen). + * + * TODO: replaced it by batched columnar format. + */ +public class BufferedRowIterator { + protected InternalRow currentRow; + protected Iterator input; + // used when there is no column in output + protected UnsafeRow unsafeRow = new UnsafeRow(0); + + public boolean hasNext() { + if (currentRow == null) { + processNext(); + } + return currentRow != null; + } + + public InternalRow next() { + InternalRow r = currentRow; + currentRow = null; + return r; + } + + public void setInput(Iterator iter) { + input = iter; + } + + /** + * Processes the input until have a row as output (currentRow). + * + * After it's called, if currentRow is still null, it means no more rows left. + */ + protected void processNext() { + if (input.hasNext()) { + currentRow = input.next(); + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2355de3d05865..75101ea0fc6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,7 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala new file mode 100644 index 0000000000000..c15fabab805a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * An interface for those physical operators that support codegen. + */ +trait CodegenSupport extends SparkPlan { + + /** + * Whether this SparkPlan support whole stage codegen or not. + */ + def supportCodegen: Boolean = true + + /** + * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. + */ + private var parent: CodegenSupport = null + + /** + * Returns an input RDD of InternalRow and Java source code to process them. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + this.parent = parent + doProduce(ctx) + } + + /** + * Generate the Java source code to process, should be overrided by subclass to support codegen. + * + * doProduce() usually generate the framework, for example, aggregation could generate this: + * + * if (!initialized) { + * # create a hash map, then build the aggregation hash map + * # call child.produce() + * initialized = true; + * } + * while (hashmap.hasNext()) { + * row = hashmap.next(); + * # build the aggregation results + * # create varialbles for results + * # call consume(), wich will call parent.doConsume() + * } + */ + protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + + /** + * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + */ + protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { + assert(columns.length == output.length) + parent.doConsume(ctx, this, columns) + } + + + /** + * Generate the Java source code to process the rows from child SparkPlan. + * + * This should be override by subclass to support codegen. + * + * For example, Filter will generate the code like this: + * + * # code to evaluate the predicate expression, result is isNull1 and value2 + * if (isNull1 || value2) { + * # call consume(), which will call parent.doConsume() + * } + */ + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String +} + + +/** + * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. + * + * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes + * an RDD iterator of InternalRow. + */ +case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { + + override def output: Seq[Attribute] = child.output + + override def supportCodegen: Boolean = true + + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + val row = ctx.freshName("row") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columns = exprs.map(_.gen(ctx)) + val code = s""" + | while (input.hasNext()) { + | InternalRow $row = (InternalRow) input.next(); + | ${columns.map(_.code).mkString("\n")} + | ${consume(ctx, columns)} + | } + """.stripMargin + (child.execute(), code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + + override def simpleString: String = "INPUT" +} + +/** + * WholeStageCodegen compile a subtree of plans that support codegen together into single Java + * function. + * + * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): + * + * WholeStageCodegen Plan A FakeInput Plan B + * ========================================================================= + * + * -> execute() + * | + * doExecute() --------> produce() + * | + * doProduce() -------> produce() + * | + * doProduce() ---> execute() + * | + * consume() + * doConsume() ------------| + * | + * doConsume() <----- consume() + * + * SparkPlan A should override doProduce() and doConsume(). + * + * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, + * used to generated code for BoundReference. + */ +case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) + extends SparkPlan with CodegenSupport { + + override def output: Seq[Attribute] = plan.output + + override def doExecute(): RDD[InternalRow] = { + val ctx = new CodegenContext + val (rdd, code) = plan.produce(ctx, this) + val references = ctx.references.toArray + val source = s""" + public Object generate(Object[] references) { + return new GeneratedIterator(references); + } + + class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + + private Object[] references; + ${ctx.declareMutableStates()} + + public GeneratedIterator(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} + } + + protected void processNext() { + $code + } + } + """ + // try to compile, helpful for debug + // println(s"${CodeFormatter.format(source)}") + CodeGenerator.compile(source) + + rdd.mapPartitions { iter => + val clazz = CodeGenerator.compile(source) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.setInput(iter) + new Iterator[InternalRow] { + override def hasNext: Boolean = buffer.hasNext + override def next: InternalRow = buffer.next() + } + } + } + + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + throw new UnsupportedOperationException + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } + } + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + + builder.append(simpleString) + builder.append("\n") + + plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + + builder + } + + override def simpleString: String = "WholeStageCodegen" +} + + +/** + * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. + */ +private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { + + private def supportCodegen(plan: SparkPlan): Boolean = plan match { + case plan: CodegenSupport if plan.supportCodegen => + // Non-leaf with CodegenFallback does not work with whole stage codegen + val willFallback = plan.expressions.exists( + _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined + ) + // the generated code will be huge if there are too many columns + val haveManyColumns = plan.output.length > 200 + !willFallback && !haveManyColumns + case _ => false + } + + def apply(plan: SparkPlan): SparkPlan = { + if (sqlContext.conf.wholeStageEnabled) { + plan.transform { + case plan: CodegenSupport if supportCodegen(plan) && + // Whole stage codegen is only useful when there are at least two levels of operators that + // support it (save at least one projection/iterator). + plan.children.exists(supportCodegen) => + + var inputs = ArrayBuffer[SparkPlan]() + val combined = plan.transform { + case p if !supportCodegen(p) => + inputs += p + InputAdapter(p) + }.asInstanceOf[CodegenSupport] + WholeStageCodegen(combined, inputs) + } + } else { + plan + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 92c9a561312ba..9e2e0357c65f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -22,19 +22,37 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler -case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) + extends UnaryNode with CodegenSupport { override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) override def output: Seq[Attribute] = projectList.map(_.toAttribute) + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + val exprs = projectList.map(x => + ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) + ctx.currentVars = input + val output = exprs.map(_.gen(ctx)) + s""" + | ${output.map(_.code).mkString("\n")} + | + | ${consume(ctx, output)} + """.stripMargin + } + protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitionsInternal { iter => @@ -51,13 +69,30 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends } -case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { +case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output private[sql] override lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + val expr = ExpressionCanonicalizer.execute( + BindReferences.bindReference(condition, child.output)) + ctx.currentVars = input + val eval = expr.gen(ctx) + s""" + | ${eval.code} + | if (!${eval.isNull} && ${eval.value}) { + | ${consume(ctx, ctx.currentVars)} + | } + """.stripMargin + } + protected override def doExecute(): RDD[InternalRow] = { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") @@ -116,7 +151,80 @@ case class Range( numSlices: Int, numElements: BigInt, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with CodegenSupport { + + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initTerm = ctx.freshName("range_initRange") + ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + val partitionEnd = ctx.freshName("range_partitionEnd") + ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") + val number = ctx.freshName("range_number") + ctx.addMutableState("long", number, s"$number = 0L;") + val overflow = ctx.freshName("range_overflow") + ctx.addMutableState("boolean", overflow, s"$overflow = false;") + + val value = ctx.freshName("range_value") + val ev = ExprCode("", "false", value) + val BigInt = classOf[java.math.BigInteger].getName + val checkEnd = if (step > 0) { + s"$number < $partitionEnd" + } else { + s"$number > $partitionEnd" + } + + val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) + + val code = s""" + | // initialize Range + | if (!$initTerm) { + | $initTerm = true; + | if (input.hasNext()) { + | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } else { + | return; + | } + | } + | + | while (!$overflow && $checkEnd) { + | long $value = $number; + | $number += ${step}L; + | if ($number < $value ^ ${step}L < 0) { + | $overflow = true; + | } + | ${consume(ctx, Seq(ev))} + | } + """.stripMargin + + (rdd, code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } protected override def doExecute(): RDD[InternalRow] = { sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 7888e34e8a83a..72eb1f6cf0518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -143,14 +143,14 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private DataType[] columnTypes = null; private int[] columnIndexes = null; - ${declareMutableStates(ctx)} + ${ctx.declareMutableStates()} public SpecificColumnarIterator() { this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; this.mutableRow = new MutableUnsafeRow(rowWriter); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { @@ -190,6 +190,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") - compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] + CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 89b9a687682d9..e8d0678989d88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -36,12 +36,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan - executedPlan.collect { + val plan = sqlContext.table(tableName).queryExecution.sparkPlan + plan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id case _ => - fail(s"Table $tableName is not cached\n" + executedPlan) + fail(s"Table $tableName is not cached\n" + plan) }.head } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index eb4efcd1d4e41..b349bb6dc9ce9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -629,7 +629,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { - val projects = df.queryExecution.executedPlan.collect { + val projects = df.queryExecution.sparkPlan.collect { case tungstenProject: Project => tungstenProject } assert(projects.size === expectedNumProjects) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 39a65413bd592..c17be8ace9287 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -123,15 +123,15 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // equijoin - should be converted into broadcast join - val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + val plan1 = df1.join(broadcast(df2), "key").queryExecution.sparkPlan assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) // no join key -- should not be a broadcast join - val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) // planner should not crash without a join - broadcast(df1).queryExecution.executedPlan + broadcast(df1).queryExecution.sparkPlan // SPARK-12275: no physical plan for BroadcastHint in some condition withTempPath { path => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 75e81b9c9174d..bdb9421cc19d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -247,7 +247,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { val df = sql(sqlText) // First, check if we have GeneratedAggregate. - val hasGeneratedAgg = df.queryExecution.executedPlan + val hasGeneratedAgg = df.queryExecution.sparkPlan .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { @@ -792,11 +792,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-11111 null-safe join should not use cartesian product") { val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)") - val cp = df.queryExecution.executedPlan.collect { + val cp = df.queryExecution.sparkPlan.collect { case cp: CartesianProduct => cp } assert(cp.isEmpty, "should not use CartesianProduct for null-safe join") - val smj = df.queryExecution.executedPlan.collect { + val smj = df.queryExecution.sparkPlan.collect { case smj: SortMergeJoin => smj } assert(smj.size > 0, "should use SortMergeJoin") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala new file mode 100644 index 0000000000000..788b04fcf8c2e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure whole stage codegen performance. + * To run this: + * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" + */ +class BenchmarkWholeStageCodegen extends SparkFunSuite { + def testWholeStage(values: Int): Unit = { + val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + val sc = SparkContext.getOrCreate(conf) + val sqlContext = SQLContext.getOrCreate(sc) + + val benchmark = new Benchmark("Single Int Column Scan", values) + + benchmark.addCase("Without whole stage codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).filter("(id & 1) = 1").count() + } + + benchmark.addCase("With whole stage codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).filter("(id & 1) = 1").count() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + Without whole stage codegen 6725.52 31.18 1.00 X + With whole stage codegen 2233.05 93.91 3.01 X + */ + benchmark.run() + } + + ignore("benchmark") { + testWholeStage(1024 * 1024 * 200) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 03a1b8e11d455..49feeaf17d68f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -94,7 +94,7 @@ class PlannerSuite extends SharedSQLContext { """ |SELECT l.a, l.b |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan + """.stripMargin).queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } @@ -147,7 +147,7 @@ class PlannerSuite extends SharedSQLContext { val a = testData.as("a") val b = sqlContext.table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } @@ -168,7 +168,7 @@ class PlannerSuite extends SharedSQLContext { sqlContext.registerDataFrameAsTable(df, "testPushed") withTempTable("testPushed") { - val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan + val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala new file mode 100644 index 0000000000000..c54fc6ba2de3d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext + +class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { + + test("range/filter should be combined") { + val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") + val plan = df.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined) + + checkThatPlansAgree( + sqlContext.range(100), + (p: SparkPlan) => + WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()), + (p: SparkPlan) => Filter('a == 1, p), + sortAnswers = false + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 25afed25c897b..6e21d5a06150e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -48,7 +48,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("projection") { - val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index d762f7bfe914c..647a7e9a4e196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -114,7 +114,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { df.collect().map(_(0)).toArray } - val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 58581d71e1bc1..aee8e84db56e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -62,7 +62,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = df3.queryExecution.executedPlan + val plan = df3.queryExecution.sparkPlan assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 9b37dd1103764..11863caffed75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -30,12 +30,12 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { import hiveContext._ def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan - executedPlan.collect { + val plan = table(tableName).queryExecution.sparkPlan + plan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id case _ => - fail(s"Table $tableName is not cached\n" + executedPlan) + fail(s"Table $tableName is not cached\n" + plan) }.head } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index fd3339a66bec0..2e0a8698e6ffe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -485,7 +485,7 @@ abstract class HiveComparisonTest val executions = queryList.map(new TestHive.QueryExecution(_)) executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { - case (q, e) => e.executedPlan.collect { + case (q, e) => e.sparkPlan.collect { case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => (q, e, i) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 5bd323ea096a4..d2f91861ff73b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -43,7 +43,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.sql(q).queryExecution.executedPlan.collect { + val project = TestHive.sql(q).queryExecution.sparkPlan.collect { case e: Project => e }.head diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 210d566745415..b91248bfb3fc0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -144,7 +144,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { expectedScannedColumns: Seq[String], expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { - val plan = new TestHive.QueryExecution(sql).executedPlan + val plan = new TestHive.QueryExecution(sql).sparkPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index ed544c638058c..c997453803b09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -190,11 +190,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test(s"conversion is working") { assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect { case _: HiveTableScan => true }.isEmpty) assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect { case _: PhysicalRDD => true }.nonEmpty) } @@ -305,7 +305,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { + df.queryExecution.sparkPlan match { case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation].getCanonicalName} and " + @@ -335,7 +335,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { + df.queryExecution.sparkPlan match { case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation].getCanonicalName} and " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index e866493ee6c96..ba2a483bba534 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -149,7 +149,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) val df = sqlContext.read.parquet(path).filter('a === 0).select('b) - val physicalPlan = df.queryExecution.executedPlan + val physicalPlan = df.queryExecution.sparkPlan assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 058c101eebb04..9ab3e11609cec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -156,9 +156,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") { val df = partitionedDF.where(filter).select(projections: _*) val queryExecution = df.queryExecution - val executedPlan = queryExecution.executedPlan + val sparkPlan = queryExecution.sparkPlan - val rawScan = executedPlan.collect { + val rawScan = sparkPlan.collect { case p: PhysicalRDD => p } match { case Seq(scan) => scan @@ -177,7 +177,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat assert(requiredColumns === SimpleTextRelation.requiredColumns) val nonPushedFilters = { - val boundFilters = executedPlan.collect { + val boundFilters = sparkPlan.collect { case f: execution.Filter => f } match { case Nil => Nil From cede7b2a1134a6c93aff20ed5625054d988d3659 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 17 Jan 2016 09:11:43 -0800 Subject: [PATCH 1595/2089] [SPARK-12860] [SQL] speed up safe projection for primitive types The idea is simple, use `SpecificMutableRow` instead of `GenericMutableRow` as result row for safe projection. A simple benchmark shows about 1.5x speed up for primitive types, code: https://gist.github.com/cloud-fan/fa77713ccebf0823b2ab#file-safeprojectionbenchmark-scala Author: Wenchen Fan Closes #10790 from cloud-fan/safe-projection. --- .../expressions/codegen/GenerateSafeProjection.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index e750ad9c184b2..4cb6af9d9fe9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] public SpecificSafeProjection(Object[] references) { this.references = references; - mutableRow = new $genericMutableRowType(${expressions.size}); + mutableRow = (MutableRow) references[references.length - 1]; ${ctx.initMutableStates()} } @@ -180,6 +180,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[Projection] + val resultRow = new SpecificMutableRow(expressions.map(_.dataType)) + c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } From 92502703f4a29c706539f5ba47fd58b6fc41c14d Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 17 Jan 2016 09:29:08 -0800 Subject: [PATCH 1596/2089] [SPARK-12862][SPARKR] Jenkins does not run R tests Slight correction: I'm leaving sparkR as-is (ie. R file not supported) and fixed only run-tests.sh as shivaram described. I also assume we are going to cover all doc changes in https://issues.apache.org/jira/browse/SPARK-12846 instead of here. rxin shivaram zjffdu Author: felixcheung Closes #10792 from felixcheung/sparkRcmd. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- R/run-tests.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 27ad9f3958362..67ecdbc522d23 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1781,7 +1781,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(x), "map") df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") - expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) df1 <- select(df, cast(df$age, "integer")) coltypes(df) <- c("character", "integer") diff --git a/R/run-tests.sh b/R/run-tests.sh index e64a4ea94c584..9dcf0ace7d97e 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then From bc36b0f1a1f38060700903be7bf5cf012b414551 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 17 Jan 2016 11:02:37 -0800 Subject: [PATCH 1597/2089] [SQL] [MINOR] speed up hashcode for UTF8String similar to https://github.com/apache/spark/pull/10784, use `Murmur3_x86_32.hashUnsafeBytes` instead. Author: Wenchen Fan Closes #10791 from cloud-fan/string-hashcode. --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5b61386808769..87706d0b68388 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -31,6 +31,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import static org.apache.spark.unsafe.Platform.*; @@ -935,11 +936,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - int result = 1; - for (int i = 0; i < numBytes; i ++) { - result = 31 * result + getByte(i); - } - return result; + return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); } /** From 233d6cee96bb4c1723a5ab36efd19fd6180d651c Mon Sep 17 00:00:00 2001 From: Tommy YU Date: Mon, 18 Jan 2016 13:46:14 +0000 Subject: [PATCH 1598/2089] [SPARK-10264][DOCUMENTATION] Added @Since to ml.recomendation I create new pr since original pr long time no update. Please help to review. srowen Author: Tommy YU Closes #10756 from Wenpei/add_since_to_recomm. --- .../apache/spark/ml/recommendation/ALS.scala | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 472c1854d3d1f..1481c8268a198 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -180,22 +180,27 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ @Experimental +@Since("1.3.0") class ALSModel private[ml] ( - override val uid: String, - val rank: Int, + @Since("1.4.0") override val uid: String, + @Since("1.4.0") val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ + @Since("1.4.0") def setUserCol(value: String): this.type = set(userCol, value) /** @group setParam */ + @Since("1.4.0") def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ + @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.3.0") override def transform(dataset: DataFrame): DataFrame = { // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. @@ -213,6 +218,7 @@ class ALSModel private[ml] ( predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } + @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) @@ -220,6 +226,7 @@ class ALSModel private[ml] ( SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } + @Since("1.5.0") override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) @@ -303,65 +310,83 @@ object ALSModel extends MLReadable[ALSModel] { * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams +@Since("1.3.0") +class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating + @Since("1.4.0") def this() = this(Identifiable.randomUID("als")) /** @group setParam */ + @Since("1.3.0") def setRank(value: Int): this.type = set(rank, value) /** @group setParam */ + @Since("1.3.0") def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) /** @group setParam */ + @Since("1.3.0") def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) /** @group setParam */ + @Since("1.3.0") def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) /** @group setParam */ + @Since("1.3.0") def setAlpha(value: Double): this.type = set(alpha, value) /** @group setParam */ + @Since("1.3.0") def setUserCol(value: String): this.type = set(userCol, value) /** @group setParam */ + @Since("1.3.0") def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ + @Since("1.3.0") def setRatingCol(value: String): this.type = set(ratingCol, value) /** @group setParam */ + @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.3.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ + @Since("1.3.0") def setRegParam(value: Double): this.type = set(regParam, value) /** @group setParam */ + @Since("1.3.0") def setNonnegative(value: Boolean): this.type = set(nonnegative, value) /** @group setParam */ + @Since("1.4.0") def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ + @Since("1.3.0") def setSeed(value: Long): this.type = set(seed, value) /** * Sets both numUserBlocks and numItemBlocks to the specific value. * @group setParam */ + @Since("1.3.0") def setNumBlocks(value: Int): this.type = { setNumUserBlocks(value) setNumItemBlocks(value) this } + @Since("1.3.0") override def fit(dataset: DataFrame): ALSModel = { import dataset.sqlContext.implicits._ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) @@ -381,10 +406,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams copyValues(model) } + @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): ALS = defaultCopy(extra) } From db9a860589bfc4f80d6cdf174a577ca538b82e6d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Jan 2016 10:28:01 -0800 Subject: [PATCH 1599/2089] [SPARK-12558][FOLLOW-UP] AnalysisException when multiple functions applied in GROUP BY clause Addresses the comments from Yin. https://github.com/apache/spark/pull/10520 Author: Dilip Biswal Closes #10758 from dilipbiswal/spark-12558-followup. --- .../spark/sql/hive/execution/HiveUDFSuite.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index dfe33ba8b0502..af76ff91a267c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -22,7 +22,7 @@ import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.UDAFPercentile -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF, GenericUDFOPAnd, GenericUDTFExplode} +import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -351,10 +351,14 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("Hive UDF in group by") { - Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1") - val count = sql("select date(cast(test_date as timestamp))" + - " from tab1 group by date(cast(test_date as timestamp))").count() - assert(count == 1) + withTempTable("tab1") { + Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1") + sql(s"CREATE TEMPORARY FUNCTION testUDFToDate AS '${classOf[GenericUDFToDate].getName}'") + val count = sql("select testUDFToDate(cast(test_date as timestamp))" + + " from tab1 group by testUDFToDate(cast(test_date as timestamp))").count() + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToDate") + assert(count == 1) + } } test("SPARK-11522 select input_file_name from non-parquet table"){ From 44fcf992aa516153a43d7141d3b8e092f0671ba4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 18 Jan 2016 11:08:44 -0800 Subject: [PATCH 1600/2089] [SPARK-12873][SQL] Add more comment in HiveTypeCoercion for type widening I was reading this part of the analyzer code again and got confused by the difference between findWiderTypeForTwo and findTightestCommonTypeOfTwo. I also simplified WidenSetOperationTypes to make it a lot simpler. The easiest way to review this one is to just read the original code, and the new code. The logic is super simple. Author: Reynold Xin Closes #10802 from rxin/SPARK-12873. --- .../catalyst/analysis/HiveTypeCoercion.scala | 86 +++++++++++-------- .../analysis/HiveTypeCoercionSuite.scala | 3 +- 2 files changed, 49 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2737fe32cd086..7df3787e6d2d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -27,10 +27,20 @@ import org.apache.spark.sql.types._ /** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. + * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in + * operations into compatible ones. + * + * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on + * the hive codebase. + * + * Notes about type widening / tightest common types: Broadly, there are two cases when we need + * to widen data types (e.g. union, binary comparison). In case 1, we are looking for a common + * data type for two or more data types, and in this case no loss of precision is allowed. Examples + * include type inference in JSON (e.g. what's the column's data type if one row is an integer + * while the other row is a long?). In case 2, we are looking for a widened data type with + * some acceptable loss of precision (e.g. there is no common type for double and decimal because + * double's range is larger than decimal, and yet decimal is more precise than double, but in + * union we would cast the decimal into double). */ object HiveTypeCoercion { @@ -63,6 +73,8 @@ object HiveTypeCoercion { DoubleType) /** + * Case 1 type widening (see the classdoc comment above for HiveTypeCoercion). + * * Find the tightest common type of two types that might be used in a binary expression. * This handles all numeric types except fixed-precision decimals interacting with each other or * with primitive types, because in that case the precision and scale of the result depends on @@ -118,6 +130,12 @@ object HiveTypeCoercion { }) } + /** + * Case 2 type widening (see the classdoc comment above for HiveTypeCoercion). + * + * i.e. the main difference with [[findTightestCommonTypeOfTwo]] is that here we allow some + * loss of precision when widening decimal and double. + */ private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { case (t1: DecimalType, t2: DecimalType) => Some(DecimalPrecision.widerDecimalType(t1, t2)) @@ -125,9 +143,7 @@ object HiveTypeCoercion { Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) case (d: DecimalType, t: IntegralType) => Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (t: FractionalType, d: DecimalType) => - Some(DoubleType) - case (d: DecimalType, t: FractionalType) => + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => Some(DoubleType) case _ => findTightestCommonTypeToString(t1, t2) @@ -200,41 +216,37 @@ object HiveTypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - private[this] def widenOutputTypes( - planName: String, - left: LogicalPlan, - right: LogicalPlan): (LogicalPlan, LogicalPlan) = { - require(left.output.length == right.output.length) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if p.analyzed => p - val castedTypes = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - findWiderTypeForTwo(lhs.dataType, rhs.dataType) - case other => None - } + case s @ SetOperation(left, right) if s.childrenResolved + && left.output.length == right.output.length && !s.resolved => - def castOutput(plan: LogicalPlan): LogicalPlan = { - val casted = plan.output.zip(castedTypes).map { - case (e, Some(dt)) if e.dataType != dt => - Alias(Cast(e, dt), e.name)() - case (e, _) => e + // Tracks the list of data types to widen. + // Some(dataType) means the right-hand side and the left-hand side have different types, + // and there is a target type to widen both sides to. + val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map { + case (lhs, rhs) if lhs.dataType != rhs.dataType => + findWiderTypeForTwo(lhs.dataType, rhs.dataType) + case other => None } - Project(casted, plan) - } - if (castedTypes.exists(_.isDefined)) { - (castOutput(left), castOutput(right)) - } else { - (left, right) - } + if (targetTypes.exists(_.isDefined)) { + // There is at least one column to widen. + s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes))) + } else { + // If we cannot find any column to widen, then just return the original set. + s + } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p if p.analyzed => p - - case s @ SetOperation(left, right) if s.childrenResolved - && left.output.length == right.output.length && !s.resolved => - val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) - s.makeCopy(Array(newLeft, newRight)) + /** Given a plan, add an extra project on top to widen some columns' data types. */ + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = { + val casted = plan.output.zip(targetTypes).map { + case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + Project(casted, plan) } } @@ -372,8 +384,6 @@ object HiveTypeCoercion { * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE - * - * Note: Union/Except/Interact is handled by WidenTypes */ // scalastyle:on object DecimalPrecision extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index b1f6c0b802d8e..b326aa9c55992 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -387,7 +387,7 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenSetOperationTypes for union except and intersect") { + test("WidenSetOperationTypes for union, except, and intersect") { def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { logical.output.zip(expectTypes).foreach { case (attr, dt) => assert(attr.dataType === dt) @@ -499,7 +499,6 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval)) } - /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. From 5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 18 Jan 2016 12:50:58 -0800 Subject: [PATCH 1601/2089] [SPARK-12346][ML] Missing attribute names in GLM for vector-type features Currently `summary()` fails on a GLM model fitted over a vector feature missing ML attrs, since the output feature attrs will also have no name. We can avoid this situation by forcing `VectorAssembler` to make up suitable names when inputs are missing names. cc mengxr Author: Eric Liang Closes #10323 from ericl/spark-12346. --- .../spark/ml/feature/VectorAssembler.scala | 6 +-- .../spark/ml/feature/RFormulaSuite.scala | 38 +++++++++++++++++++ .../ml/feature/VectorAssemblerSuite.scala | 4 +- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 716bc63e00995..7ff5ad143f80b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -70,19 +70,19 @@ class VectorAssembler(override val uid: String) val group = AttributeGroup.fromStructField(field) if (group.attributes.isDefined) { // If attributes are defined, copy them with updated names. - group.attributes.get.map { attr => + group.attributes.get.zipWithIndex.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) } else { - attr + attr.withName(c + "_" + i) } } } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.fill(numAttrs)(NumericAttribute.defaultAttr) + Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index dc20a5ec2152d..16e565d8b588b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(attrs === expectedAttrs) } + test("vector attribute generation") { + val formula = new RFormula().setFormula("id ~ vec") + val original = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec_0"), Some(1)), + new NumericAttribute(Some("vec_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("vector attribute generation with unnamed input attrs") { + val formula = new RFormula().setFormula("id ~ vec2") + val base = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val metadata = new AttributeGroup( + "vec2", + Array[Attribute]( + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr)).toMetadata + val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec2_0"), Some(1)), + new NumericAttribute(Some("vec2_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = sqlContext.createDataFrame( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index f7de7c1e93fb2..dce994fdbd056 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -111,8 +111,8 @@ class VectorAssemblerSuite assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) val userSalaryOut = features.getAttr(4) assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) - assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) - assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0")) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1")) } test("read/write") { From 302bb569f3e1f09e2e7cea5e4e7f5c6953b0fc82 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 18 Jan 2016 13:27:18 -0800 Subject: [PATCH 1602/2089] [SPARK-12884] Move classes to their own files for readability This is a small step in implementing SPARK-10620, which migrates `TaskMetrics` to accumulators. This patch is strictly a cleanup patch and introduces no change in functionality. It literally just moves classes to their own files to avoid having single monolithic ones that contain 10 different classes. Parent PR: #10717 Author: Andrew Or Closes #10810 from andrewor14/move-things. --- .../{Accumulators.scala => Accumulable.scala} | 179 +---------------- .../scala/org/apache/spark/Accumulator.scala | 160 +++++++++++++++ .../apache/spark/InternalAccumulator.scala | 58 ++++++ .../apache/spark/executor/InputMetrics.scala | 77 ++++++++ .../apache/spark/executor/OutputMetrics.scala | 53 +++++ .../spark/executor/ShuffleReadMetrics.scala | 87 ++++++++ .../spark/executor/ShuffleWriteMetrics.scala | 53 +++++ .../apache/spark/executor/TaskMetrics.scala | 186 +----------------- 8 files changed, 493 insertions(+), 360 deletions(-) rename core/src/main/scala/org/apache/spark/{Accumulators.scala => Accumulable.scala} (54%) create mode 100644 core/src/main/scala/org/apache/spark/Accumulator.scala create mode 100644 core/src/main/scala/org/apache/spark/InternalAccumulator.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/InputMetrics.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala similarity index 54% rename from core/src/main/scala/org/apache/spark/Accumulators.scala rename to core/src/main/scala/org/apache/spark/Accumulable.scala index 5592b75afb75b..a456d420b8d6a 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -20,14 +20,12 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable -import scala.collection.Map -import scala.collection.mutable -import scala.ref.WeakReference import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils + /** * A data type that can be accumulated, ie has an commutative and associative "add" operation, * but where the result type, `R`, may be different from the element type being added, `T`. @@ -166,6 +164,7 @@ class Accumulable[R, T] private[spark] ( override def toString: String = if (value_ == null) "null" else value_.toString } + /** * Helper object defining how to accumulate values of a particular type. An implicit * AccumulableParam needs to be available when you create [[Accumulable]]s of a specific type. @@ -201,6 +200,7 @@ trait AccumulableParam[R, T] extends Serializable { def zero(initialValue: R): R } + private[spark] class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] extends AccumulableParam[R, T] { @@ -224,176 +224,3 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa copy } } - -/** - * A simpler value of [[Accumulable]] where the result type being accumulated is the same - * as the types of elements being merged, i.e. variables that are only "added" to through an - * associative operation and can therefore be efficiently supported in parallel. They can be used - * to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric - * value types, and programmers can add support for new types. - * - * An accumulator is created from an initial value `v` by calling [[SparkContext#accumulator]]. - * Tasks running on the cluster can then add to it using the [[Accumulable#+=]] operator. - * However, they cannot read its value. Only the driver program can read the accumulator's value, - * using its value method. - * - * The interpreter session below shows an accumulator being used to add up the elements of an array: - * - * {{{ - * scala> val accum = sc.accumulator(0) - * accum: spark.Accumulator[Int] = 0 - * - * scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x) - * ... - * 10/09/29 18:41:08 INFO SparkContext: Tasks finished in 0.317106 s - * - * scala> accum.value - * res2: Int = 10 - * }}} - * - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `T` - * @tparam T result type - */ -class Accumulator[T] private[spark] ( - @transient private[spark] val initialValue: T, - param: AccumulatorParam[T], - name: Option[String], - internal: Boolean) - extends Accumulable[T, T](initialValue, param, name, internal) { - - def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { - this(initialValue, param, name, false) - } - - def this(initialValue: T, param: AccumulatorParam[T]) = { - this(initialValue, param, None, false) - } -} - -/** - * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add - * in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be - * available when you create Accumulators of a specific type. - * - * @tparam T type of value to accumulate - */ -trait AccumulatorParam[T] extends AccumulableParam[T, T] { - def addAccumulator(t1: T, t2: T): T = { - addInPlace(t1, t2) - } -} - -object AccumulatorParam { - - // The following implicit objects were in SparkContext before 1.2 and users had to - // `import SparkContext._` to enable them. Now we move them here to make the compiler find - // them automatically. However, as there are duplicate codes in SparkContext for backward - // compatibility, please update them accordingly if you modify the following implicit objects. - - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double): Double = 0.0 - } - - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { - def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int): Int = 0 - } - - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long): Long = t1 + t2 - def zero(initialValue: Long): Long = 0L - } - - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float): Float = t1 + t2 - def zero(initialValue: Float): Float = 0f - } - - // TODO: Add AccumulatorParams for other types, e.g. lists and strings -} - -// TODO: The multi-thread support in accumulators is kind of lame; check -// if there's a more intuitive way of doing it right -private[spark] object Accumulators extends Logging { - /** - * This global map holds the original accumulator objects that are created on the driver. - * It keeps weak references to these objects so that accumulators can be garbage-collected - * once the RDDs and user-code that reference them are cleaned up. - */ - val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() - - private var lastId: Long = 0 - - def newId(): Long = synchronized { - lastId += 1 - lastId - } - - def register(a: Accumulable[_, _]): Unit = synchronized { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) - } - - def remove(accId: Long) { - synchronized { - originals.remove(accId) - } - } - - // Add values to the original accumulators with some given IDs - def add(values: Map[Long, Any]): Unit = synchronized { - for ((id, value) <- values) { - if (originals.contains(id)) { - // Since we are now storing weak references, we must check whether the underlying data - // is valid. - originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value - case None => - throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") - } - } else { - logWarning(s"Ignoring accumulator update for unknown accumulator id $id") - } - } - } - -} - -private[spark] object InternalAccumulator { - val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" - val TEST_ACCUMULATOR = "testAccumulator" - - // For testing only. - // This needs to be a def since we don't want to reuse the same accumulator across stages. - private def maybeTestAccumulator: Option[Accumulator[Long]] = { - if (sys.props.contains("spark.testing")) { - Some(new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) - } else { - None - } - } - - /** - * Accumulators for tracking internal metrics. - * - * These accumulators are created with the stage such that all tasks in the stage will - * add to the same set of accumulators. We do this to report the distribution of accumulator - * values across all tasks within each stage. - */ - def create(sc: SparkContext): Seq[Accumulator[Long]] = { - val internalAccumulators = Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq - internalAccumulators.foreach { accumulator => - sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) - } - internalAccumulators - } -} diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala new file mode 100644 index 0000000000000..007136e6ae349 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.{mutable, Map} +import scala.ref.WeakReference + + +/** + * A simpler value of [[Accumulable]] where the result type being accumulated is the same + * as the types of elements being merged, i.e. variables that are only "added" to through an + * associative operation and can therefore be efficiently supported in parallel. They can be used + * to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric + * value types, and programmers can add support for new types. + * + * An accumulator is created from an initial value `v` by calling [[SparkContext#accumulator]]. + * Tasks running on the cluster can then add to it using the [[Accumulable#+=]] operator. + * However, they cannot read its value. Only the driver program can read the accumulator's value, + * using its value method. + * + * The interpreter session below shows an accumulator being used to add up the elements of an array: + * + * {{{ + * scala> val accum = sc.accumulator(0) + * accum: spark.Accumulator[Int] = 0 + * + * scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x) + * ... + * 10/09/29 18:41:08 INFO SparkContext: Tasks finished in 0.317106 s + * + * scala> accum.value + * res2: Int = 10 + * }}} + * + * @param initialValue initial value of accumulator + * @param param helper object defining how to add elements of type `T` + * @tparam T result type + */ +class Accumulator[T] private[spark] ( + @transient private[spark] val initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { + + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } + + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } +} + + +// TODO: The multi-thread support in accumulators is kind of lame; check +// if there's a more intuitive way of doing it right +private[spark] object Accumulators extends Logging { + /** + * This global map holds the original accumulator objects that are created on the driver. + * It keeps weak references to these objects so that accumulators can be garbage-collected + * once the RDDs and user-code that reference them are cleaned up. + */ + val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() + + private var lastId: Long = 0 + + def newId(): Long = synchronized { + lastId += 1 + lastId + } + + def register(a: Accumulable[_, _]): Unit = synchronized { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) + } + + def remove(accId: Long) { + synchronized { + originals.remove(accId) + } + } + + // Add values to the original accumulators with some given IDs + def add(values: Map[Long, Any]): Unit = synchronized { + for ((id, value) <- values) { + if (originals.contains(id)) { + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value + case None => + throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") + } + } else { + logWarning(s"Ignoring accumulator update for unknown accumulator id $id") + } + } + } + +} + + +/** + * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add + * in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be + * available when you create Accumulators of a specific type. + * + * @tparam T type of value to accumulate + */ +trait AccumulatorParam[T] extends AccumulableParam[T, T] { + def addAccumulator(t1: T, t2: T): T = { + addInPlace(t1, t2) + } +} + + +object AccumulatorParam { + + // The following implicit objects were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, as there are duplicate codes in SparkContext for backward + // compatibility, please update them accordingly if you modify the following implicit objects. + + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double): Double = 0.0 + } + + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int): Int = 0 + } + + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f + } + + // TODO: Add AccumulatorParams for other types, e.g. lists and strings +} diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala new file mode 100644 index 0000000000000..6ea997c079f33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + + +// This is moved to its own file because many more things will be added to it in SPARK-10620. +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. + private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } + + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. We do this to report the distribution of accumulator + * values across all tasks within each stage. + */ + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala new file mode 100644 index 0000000000000..8f1d7f89a44b4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * Method by which input data was read. Network means that the data was read over the network + * from a remote block manager (which may have stored the data on-disk or in-memory). + */ +@DeveloperApi +object DataReadMethod extends Enumeration with Serializable { + type DataReadMethod = Value + val Memory, Disk, Hadoop, Network = Value +} + + +/** + * :: DeveloperApi :: + * Metrics about reading input data. + */ +@DeveloperApi +case class InputMetrics(readMethod: DataReadMethod.Value) { + + /** + * This is volatile so that it is visible to the updater thread. + */ + @volatile @transient var bytesReadCallback: Option[() => Long] = None + + /** + * Total bytes read. + */ + private var _bytesRead: Long = _ + def bytesRead: Long = _bytesRead + def incBytesRead(bytes: Long): Unit = _bytesRead += bytes + + /** + * Total records read. + */ + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + def incRecordsRead(records: Long): Unit = _recordsRead += records + + /** + * Invoke the bytesReadCallback and mutate bytesRead. + */ + def updateBytesRead() { + bytesReadCallback.foreach { c => + _bytesRead = c() + } + } + + /** + * Register a function that can be called to get up-to-date information on how many bytes the task + * has read from an input source. + */ + def setBytesReadCallback(f: Option[() => Long]) { + bytesReadCallback = f + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala new file mode 100644 index 0000000000000..ad132d004cde0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * Method by which output data was written. + */ +@DeveloperApi +object DataWriteMethod extends Enumeration with Serializable { + type DataWriteMethod = Value + val Hadoop = Value +} + + +/** + * :: DeveloperApi :: + * Metrics about writing output data. + */ +@DeveloperApi +case class OutputMetrics(writeMethod: DataWriteMethod.Value) { + /** + * Total bytes written + */ + private var _bytesWritten: Long = _ + def bytesWritten: Long = _bytesWritten + private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value + + /** + * Total records written + */ + private var _recordsWritten: Long = 0L + def recordsWritten: Long = _recordsWritten + private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value +} diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala new file mode 100644 index 0000000000000..e985b35ace623 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * Metrics pertaining to shuffle data read in a given task. + */ +@DeveloperApi +class ShuffleReadMetrics extends Serializable { + /** + * Number of remote blocks fetched in this shuffle by this task + */ + private var _remoteBlocksFetched: Int = _ + def remoteBlocksFetched: Int = _remoteBlocksFetched + private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value + private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + + /** + * Number of local blocks fetched in this shuffle by this task + */ + private var _localBlocksFetched: Int = _ + def localBlocksFetched: Int = _localBlocksFetched + private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value + private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value + + /** + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. + */ + private var _fetchWaitTime: Long = _ + def fetchWaitTime: Long = _fetchWaitTime + private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value + private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value + + /** + * Total number of remote bytes read from the shuffle by this task + */ + private var _remoteBytesRead: Long = _ + def remoteBytesRead: Long = _remoteBytesRead + private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value + private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + private var _localBytesRead: Long = _ + def localBytesRead: Long = _localBytesRead + private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). + */ + def totalBytesRead: Long = _remoteBytesRead + _localBytesRead + + /** + * Number of blocks fetched in this shuffle by this task (remote or local) + */ + def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched + + /** + * Total number of records read from the shuffle by this task + */ + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + private[spark] def incRecordsRead(value: Long) = _recordsRead += value + private[spark] def decRecordsRead(value: Long) = _recordsRead -= value +} diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala new file mode 100644 index 0000000000000..469ebe26c7b56 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * Metrics pertaining to shuffle data written in a given task. + */ +@DeveloperApi +class ShuffleWriteMetrics extends Serializable { + /** + * Number of bytes written for the shuffle by this task + */ + @volatile private var _shuffleBytesWritten: Long = _ + def shuffleBytesWritten: Long = _shuffleBytesWritten + private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value + private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value + + /** + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds + */ + @volatile private var _shuffleWriteTime: Long = _ + def shuffleWriteTime: Long = _shuffleWriteTime + private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value + private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value + + /** + * Total number of records written to the shuffle by this task + */ + @volatile private var _shuffleRecordsWritten: Long = _ + def shuffleRecordsWritten: Long = _shuffleRecordsWritten + private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value + private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value + private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value +} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 42207a9553592..ce1fcbff71208 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -27,6 +27,7 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.util.Utils + /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. @@ -241,6 +242,7 @@ class TaskMetrics extends Serializable { } } + private[spark] object TaskMetrics { private val hostNameCache = new ConcurrentHashMap[String, String]() @@ -251,187 +253,3 @@ private[spark] object TaskMetrics { if (canonicalHost != null) canonicalHost else host } } - -/** - * :: DeveloperApi :: - * Method by which input data was read. Network means that the data was read over the network - * from a remote block manager (which may have stored the data on-disk or in-memory). - */ -@DeveloperApi -object DataReadMethod extends Enumeration with Serializable { - type DataReadMethod = Value - val Memory, Disk, Hadoop, Network = Value -} - -/** - * :: DeveloperApi :: - * Method by which output data was written. - */ -@DeveloperApi -object DataWriteMethod extends Enumeration with Serializable { - type DataWriteMethod = Value - val Hadoop = Value -} - -/** - * :: DeveloperApi :: - * Metrics about reading input data. - */ -@DeveloperApi -case class InputMetrics(readMethod: DataReadMethod.Value) { - - /** - * This is volatile so that it is visible to the updater thread. - */ - @volatile @transient var bytesReadCallback: Option[() => Long] = None - - /** - * Total bytes read. - */ - private var _bytesRead: Long = _ - def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long): Unit = _bytesRead += bytes - - /** - * Total records read. - */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records - - /** - * Invoke the bytesReadCallback and mutate bytesRead. - */ - def updateBytesRead() { - bytesReadCallback.foreach { c => - _bytesRead = c() - } - } - - /** - * Register a function that can be called to get up-to-date information on how many bytes the task - * has read from an input source. - */ - def setBytesReadCallback(f: Option[() => Long]) { - bytesReadCallback = f - } -} - -/** - * :: DeveloperApi :: - * Metrics about writing output data. - */ -@DeveloperApi -case class OutputMetrics(writeMethod: DataWriteMethod.Value) { - /** - * Total bytes written - */ - private var _bytesWritten: Long = _ - def bytesWritten: Long = _bytesWritten - private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value - - /** - * Total records written - */ - private var _recordsWritten: Long = 0L - def recordsWritten: Long = _recordsWritten - private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value -} - -/** - * :: DeveloperApi :: - * Metrics pertaining to shuffle data read in a given task. - */ -@DeveloperApi -class ShuffleReadMetrics extends Serializable { - /** - * Number of remote blocks fetched in this shuffle by this task - */ - private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched: Int = _remoteBlocksFetched - private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - - /** - * Number of local blocks fetched in this shuffle by this task - */ - private var _localBlocksFetched: Int = _ - def localBlocksFetched: Int = _localBlocksFetched - private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value - - /** - * Time the task spent waiting for remote shuffle blocks. This only includes the time - * blocking on shuffle input data. For instance if block B is being fetched while the task is - * still not finished processing block A, it is not considered to be blocking on block B. - */ - private var _fetchWaitTime: Long = _ - def fetchWaitTime: Long = _fetchWaitTime - private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value - private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - - /** - * Total number of remote bytes read from the shuffle by this task - */ - private var _remoteBytesRead: Long = _ - def remoteBytesRead: Long = _remoteBytesRead - private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value - private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value - - /** - * Shuffle data that was read from the local disk (as opposed to from a remote executor). - */ - private var _localBytesRead: Long = _ - def localBytesRead: Long = _localBytesRead - private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value - - /** - * Total bytes fetched in the shuffle by this task (both remote and local). - */ - def totalBytesRead: Long = _remoteBytesRead + _localBytesRead - - /** - * Number of blocks fetched in this shuffle by this task (remote or local) - */ - def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched - - /** - * Total number of records read from the shuffle by this task - */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - private[spark] def incRecordsRead(value: Long) = _recordsRead += value - private[spark] def decRecordsRead(value: Long) = _recordsRead -= value -} - -/** - * :: DeveloperApi :: - * Metrics pertaining to shuffle data written in a given task. - */ -@DeveloperApi -class ShuffleWriteMetrics extends Serializable { - /** - * Number of bytes written for the shuffle by this task - */ - @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten: Long = _shuffleBytesWritten - private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value - private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - - /** - * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds - */ - @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime: Long = _shuffleWriteTime - private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value - private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - - /** - * Total number of records written to the shuffle by this task - */ - @volatile private var _shuffleRecordsWritten: Long = _ - def shuffleRecordsWritten: Long = _shuffleRecordsWritten - private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value - private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value - private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value -} From b8cb548a4394221f2b029c84c7f130774da69e3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 18 Jan 2016 13:34:12 -0800 Subject: [PATCH 1603/2089] [SPARK-10985][CORE] Avoid passing evicted blocks throughout BlockManager This patch refactors portions of the BlockManager and CacheManager in order to avoid having to pass `evictedBlocks` lists throughout the code. It appears that these lists were only consumed by `TaskContext.taskMetrics`, so the new code now directly updates the metrics from the lower-level BlockManager methods. Author: Josh Rosen Closes #10776 from JoshRosen/SPARK-10985. --- .../scala/org/apache/spark/CacheManager.scala | 20 ++--- .../apache/spark/memory/MemoryManager.scala | 18 +---- .../spark/memory/StaticMemoryManager.scala | 18 ++--- .../spark/memory/StorageMemoryPool.scala | 30 ++------ .../spark/memory/UnifiedMemoryManager.scala | 18 ++--- .../apache/spark/storage/BlockManager.scala | 71 +++++++++-------- .../apache/spark/storage/MemoryStore.scala | 77 +++++++------------ .../org/apache/spark/CacheManagerSuite.scala | 9 ++- .../spark/memory/MemoryManagerSuite.scala | 23 ++---- .../memory/StaticMemoryManagerSuite.scala | 24 +++--- .../spark/memory/TestMemoryManager.scala | 10 +-- .../memory/UnifiedMemoryManagerSuite.scala | 30 ++++---- .../spark/storage/BlockManagerSuite.scala | 55 +++++++------ .../receiver/ReceivedBlockHandler.scala | 8 +- 14 files changed, 170 insertions(+), 241 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 36b536e89c3a4..d92d8b0eef8a0 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -18,7 +18,6 @@ package org.apache.spark import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -68,12 +67,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { logInfo(s"Partition $key not found, computing it") val computedValues = rdd.computeOrReadCheckpoint(partition, context) - // Otherwise, cache the values and keep track of any updates in block statuses - val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks) - val metrics = context.taskMetrics - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updatedBlocks.toSeq) + // Otherwise, cache the values + val cachedValues = putInBlockManager(key, computedValues, storageLevel) new InterruptibleIterator(context, cachedValues) } finally { @@ -135,7 +130,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { key: BlockId, values: Iterator[T], level: StorageLevel, - updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)], effectiveStorageLevel: Option[StorageLevel] = None): Iterator[T] = { val putLevel = effectiveStorageLevel.getOrElse(level) @@ -144,8 +138,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { * This RDD is not to be cached in memory, so we can just pass the computed values as an * iterator directly to the BlockManager rather than first fully unrolling it in memory. */ - updatedBlocks ++= - blockManager.putIterator(key, values, level, tellMaster = true, effectiveStorageLevel) + blockManager.putIterator(key, values, level, tellMaster = true, effectiveStorageLevel) blockManager.get(key) match { case Some(v) => v.data.asInstanceOf[Iterator[T]] case None => @@ -163,11 +156,10 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { * single partition. Instead, we unroll the values cautiously, potentially aborting and * dropping the partition to disk if applicable. */ - blockManager.memoryStore.unrollSafely(key, values, updatedBlocks) match { + blockManager.memoryStore.unrollSafely(key, values) match { case Left(arr) => // We have successfully unrolled the entire partition, so cache it in memory - updatedBlocks ++= - blockManager.putArray(key, arr, level, tellMaster = true, effectiveStorageLevel) + blockManager.putArray(key, arr, level, tellMaster = true, effectiveStorageLevel) arr.iterator.asInstanceOf[Iterator[T]] case Right(it) => // There is not enough space to cache this partition in memory @@ -176,7 +168,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { logWarning(s"Persisting partition $key to disk instead.") val diskOnlyLevel = StorageLevel(useDisk = true, useMemory = false, useOffHeap = false, deserialized = false, putLevel.replication) - putInBlockManager[T](key, returnValues, level, updatedBlocks, Some(diskOnlyLevel)) + putInBlockManager[T](key, returnValues, level, Some(diskOnlyLevel)) } else { returnValues } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 33f8b9f16c11b..b5adbd88a2c23 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -19,10 +19,8 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import org.apache.spark.storage.{BlockId, MemoryStore} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator @@ -67,17 +65,11 @@ private[spark] abstract class MemoryManager( storageMemoryPool.setMemoryStore(store) } - // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) - /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. * @return whether all N bytes were successfully granted. */ - def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + def acquireStorageMemory(blockId: BlockId, numBytes: Long): Boolean /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -85,14 +77,10 @@ private[spark] abstract class MemoryManager( * This extra method allows subclasses to differentiate behavior between acquiring storage * memory and acquiring unroll memory. For instance, the memory management model in Spark * 1.5 and before places a limit on the amount of space that can be freed from unrolling. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. * * @return whether all N bytes were successfully granted. */ - def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + def acquireUnrollMemory(blockId: BlockId, numBytes: Long): Boolean /** * Try to acquire up to `numBytes` of execution memory for the current task and return the diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 3554b558f2123..f9f8f820bc49c 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -17,10 +17,8 @@ package org.apache.spark.memory -import scala.collection.mutable - import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.storage.BlockId /** * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. @@ -53,24 +51,18 @@ private[spark] class StaticMemoryManager( (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } - override def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + override def acquireStorageMemory(blockId: BlockId, numBytes: Long): Boolean = synchronized { if (numBytes > maxStorageMemory) { // Fail fast if the block simply won't fit logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + s"memory limit ($maxStorageMemory bytes)") false } else { - storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + storageMemoryPool.acquireMemory(blockId, numBytes) } } - override def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + override def acquireUnrollMemory(blockId: BlockId, numBytes: Long): Boolean = synchronized { val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory val freeMemory = storageMemoryPool.memoryFree // When unrolling, we will use all of the existing free memory, and, if necessary, @@ -80,7 +72,7 @@ private[spark] class StaticMemoryManager( val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) // Keep it within the range 0 <= X <= maxNumBytesToFree val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) - storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree) } private[memory] diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 4036484aada23..6a88966f60d23 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -19,11 +19,8 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, MemoryStore} /** * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage @@ -58,15 +55,11 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. * @return whether all N bytes were successfully granted. */ - def acquireMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + def acquireMemory(blockId: BlockId, numBytes: Long): Boolean = lock.synchronized { val numBytesToFree = math.max(0, numBytes - memoryFree) - acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + acquireMemory(blockId, numBytes, numBytesToFree) } /** @@ -80,19 +73,12 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w def acquireMemory( blockId: BlockId, numBytesToAcquire: Long, - numBytesToFree: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + numBytesToFree: Long): Boolean = lock.synchronized { assert(numBytesToAcquire >= 0) assert(numBytesToFree >= 0) assert(memoryUsed <= poolSize) if (numBytesToFree > 0) { - memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) - } + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree) } // NOTE: If the memory store evicts blocks, then those evictions will synchronously call // back into this StorageMemoryPool in order to free memory. Therefore, these variables @@ -129,9 +115,7 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, evictedBlocks) - val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum + val spaceFreedByEviction = memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree) // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. decrementPoolSize(spaceFreedByEviction) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 57a24ac140287..a3321e3f179f6 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -17,10 +17,8 @@ package org.apache.spark.memory -import scala.collection.mutable - import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.storage.BlockId /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that @@ -133,10 +131,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( } } - override def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + override def acquireStorageMemory(blockId: BlockId, numBytes: Long): Boolean = synchronized { assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) if (numBytes > maxStorageMemory) { @@ -152,14 +147,11 @@ private[spark] class UnifiedMemoryManager private[memory] ( onHeapExecutionMemoryPool.decrementPoolSize(memoryBorrowedFromExecution) storageMemoryPool.incrementPoolSize(memoryBorrowedFromExecution) } - storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + storageMemoryPool.acquireMemory(blockId, numBytes) } - override def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, evictedBlocks) + override def acquireUnrollMemory(blockId: BlockId, numBytes: Long): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e49d79b8ad66e..e0a8e88df224a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -612,12 +612,16 @@ private[spark] class BlockManager( None } + /** + * @return true if the block was stored or false if the block was already stored or an + * error occurred. + */ def putIterator( blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { + effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { require(values != null, "Values is null") doPut(blockId, IteratorValues(values), level, tellMaster, effectiveStorageLevel) } @@ -641,28 +645,32 @@ private[spark] class BlockManager( /** * Put a new block of values to the block manager. - * Return a list of blocks updated as a result of this put. + * + * @return true if the block was stored or false if the block was already stored or an + * error occurred. */ def putArray( blockId: BlockId, values: Array[Any], level: StorageLevel, tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { + effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { require(values != null, "Values is null") doPut(blockId, ArrayValues(values), level, tellMaster, effectiveStorageLevel) } /** * Put a new block of serialized bytes to the block manager. - * Return a list of blocks updated as a result of this put. + * + * @return true if the block was stored or false if the block was already stored or an + * error occurred. */ def putBytes( blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { + effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { require(bytes != null, "Bytes is null") doPut(blockId, ByteBufferValues(bytes), level, tellMaster, effectiveStorageLevel) } @@ -674,14 +682,16 @@ private[spark] class BlockManager( * The effective storage level refers to the level according to which the block will actually be * handled. This allows the caller to specify an alternate behavior of doPut while preserving * the original level specified by the user. + * + * @return true if the block was stored or false if the block was already stored or an + * error occurred. */ private def doPut( blockId: BlockId, data: BlockValues, level: StorageLevel, tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None) - : Seq[(BlockId, BlockStatus)] = { + effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { require(blockId != null, "BlockId is null") require(level != null && level.isValid, "StorageLevel is null or invalid") @@ -689,9 +699,6 @@ private[spark] class BlockManager( require(level != null && level.isValid, "Effective StorageLevel is null or invalid") } - // Return value - val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - /* Remember the block's storage level so that we can correctly drop it to disk if it needs * to be dropped right after it got put into memory. Note, however, that other threads will * not be able to get() this block until we call markReady on its BlockInfo. */ @@ -702,7 +709,7 @@ private[spark] class BlockManager( if (oldBlockOpt.isDefined) { if (oldBlockOpt.get.waitForReady()) { logWarning(s"Block $blockId already exists on this machine; not re-adding it") - return updatedBlocks + return false } // TODO: So the block info exists - but previous attempt to load it (?) failed. // What do we do now ? Retry on it ? @@ -743,11 +750,12 @@ private[spark] class BlockManager( case _ => null } + var marked = false + putBlockInfo.synchronized { logTrace("Put for block %s took %s to get into synchronized block" .format(blockId, Utils.getUsedTimeMs(startTimeMs))) - var marked = false try { // returnValues - Whether to return the values put // blockStore - The type of storage to put these values into @@ -783,11 +791,6 @@ private[spark] class BlockManager( case _ => } - // Keep track of which blocks are dropped from memory - if (putLevel.useMemory) { - result.droppedBlocks.foreach { updatedBlocks += _ } - } - val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) if (putBlockStatus.storageLevel != StorageLevel.NONE) { // Now that the block is in either the memory, externalBlockStore, or disk store, @@ -797,7 +800,11 @@ private[spark] class BlockManager( if (tellMaster) { reportBlockStatus(blockId, putBlockInfo, putBlockStatus) } - updatedBlocks += ((blockId, putBlockStatus)) + Option(TaskContext.get()).foreach { taskContext => + val metrics = taskContext.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, putBlockStatus))) + } } } finally { // If we failed in putting the block to memory/disk, notify other possible readers @@ -847,7 +854,7 @@ private[spark] class BlockManager( .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } - updatedBlocks + marked } /** @@ -967,32 +974,27 @@ private[spark] class BlockManager( /** * Write a block consisting of a single object. + * + * @return true if the block was stored or false if the block was already stored or an + * error occurred. */ def putSingle( blockId: BlockId, value: Any, level: StorageLevel, - tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { + tellMaster: Boolean = true): Boolean = { putIterator(blockId, Iterator(value), level, tellMaster) } - def dropFromMemory( - blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { - dropFromMemory(blockId, () => data) - } - /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * * If `data` is not put on disk, it won't be created. - * - * Return the block status if the given block has been updated, else None. */ def dropFromMemory( blockId: BlockId, - data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], ByteBuffer]): Unit = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId) @@ -1005,10 +1007,10 @@ private[spark] class BlockManager( if (!info.waitForReady()) { // If we get here, the block write failed. logWarning(s"Block $blockId was marked as failure. Nothing to drop") - return None + return } else if (blockInfo.asScala.get(blockId).isEmpty) { logWarning(s"Block $blockId was already dropped.") - return None + return } var blockIsUpdated = false val level = info.level @@ -1044,11 +1046,14 @@ private[spark] class BlockManager( blockInfo.remove(blockId) } if (blockIsUpdated) { - return Some(status) + Option(TaskContext.get()).foreach { taskContext => + val metrics = taskContext.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, status))) + } } } } - None } /** diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index bdab8c2332fae..76aaa782b9524 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -95,9 +95,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) - PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) + tryToPut(blockId, bytes, bytes.limit, deserialized = false) + PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -110,8 +109,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { // Work on a duplicate - since the original input might be used elsewhere. lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks) + val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false) val data = if (putSuccess) { assert(bytes.limit == size) @@ -119,7 +117,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } else { null } - PutResult(size, data, droppedBlocks) + PutResult(size, data) } override def putArray( @@ -127,15 +125,14 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks) - PutResult(sizeEstimate, Left(values.iterator), droppedBlocks) + tryToPut(blockId, values, sizeEstimate, deserialized = true) + PutResult(sizeEstimate, Left(values.iterator)) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) - PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) + tryToPut(blockId, bytes, bytes.limit, deserialized = false) + PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -165,22 +162,20 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo level: StorageLevel, returnValues: Boolean, allowPersistToDisk: Boolean): PutResult = { - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val unrolledValues = unrollSafely(blockId, values, droppedBlocks) + val unrolledValues = unrollSafely(blockId, values) unrolledValues match { case Left(arrayValues) => // Values are fully unrolled in memory, so store them as an array val res = putArray(blockId, arrayValues, level, returnValues) - droppedBlocks ++= res.droppedBlocks - PutResult(res.size, res.data, droppedBlocks) + PutResult(res.size, res.data) case Right(iteratorValues) => // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk && allowPersistToDisk) { logWarning(s"Persisting block $blockId to disk instead.") val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues) - PutResult(res.size, res.data, droppedBlocks) + PutResult(res.size, res.data) } else { - PutResult(0, Left(iteratorValues), droppedBlocks) + PutResult(0, Left(iteratorValues)) } } } @@ -246,11 +241,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * This method returns either an array with the contents of the entire block or an iterator * containing the values of the block (if the array would have exceeded available memory). */ - def unrollSafely( - blockId: BlockId, - values: Iterator[Any], - droppedBlocks: ArrayBuffer[(BlockId, BlockStatus)]) - : Either[Array[Any], Iterator[Any]] = { + def unrollSafely(blockId: BlockId, values: Iterator[Any]): Either[Array[Any], Iterator[Any]] = { // Number of elements unrolled so far var elementsUnrolled = 0 @@ -270,7 +261,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks) + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -286,8 +277,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - keepUnrolling = reserveUnrollMemoryForThisTask( - blockId, amountToRequest, droppedBlocks) + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) // New threshold is currentSize * memoryGrowthFactor memoryThreshold += amountToRequest } @@ -337,9 +327,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo blockId: BlockId, value: Any, size: Long, - deserialized: Boolean, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - tryToPut(blockId, () => value, size, deserialized, droppedBlocks) + deserialized: Boolean): Boolean = { + tryToPut(blockId, () => value, size, deserialized) } /** @@ -355,16 +344,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * blocks to free memory for one block, another thread may use up the freed space for * another block. * - * All blocks evicted in the process, if any, will be added to `droppedBlocks`. - * * @return whether put was successful. */ private def tryToPut( blockId: BlockId, value: () => Any, size: Long, - deserialized: Boolean, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + deserialized: Boolean): Boolean = { /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has @@ -380,7 +366,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo // happen atomically. This relies on the assumption that all memory acquisitions are // synchronized on the same lock. releasePendingUnrollMemoryForThisTask() - val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks) + val enoughMemory = memoryManager.acquireStorageMemory(blockId, size) if (enoughMemory) { // We acquired enough memory for the block, so go ahead and put it val entry = new MemoryEntry(value(), size, deserialized) @@ -398,8 +384,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } else { Right(value().asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) - droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } + blockManager.dropFromMemory(blockId, () => data) } enoughMemory } @@ -413,13 +398,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * * @param blockId the ID of the block we are freeing space for, if any * @param space the size of this block - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. + * @return the amount of memory (in bytes) freed by eviction */ - private[spark] def evictBlocksToFreeSpace( - blockId: Option[BlockId], - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + private[spark] def evictBlocksToFreeSpace(blockId: Option[BlockId], space: Long): Long = { assert(space > 0) memoryManager.synchronized { var freedMemory = 0L @@ -453,17 +434,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } else { Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) - droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } + blockManager.dropFromMemory(blockId, () => data) } } - true + freedMemory } else { blockId.foreach { id => logInfo(s"Will not store $id as it would require dropping another block " + "from the same RDD") } - false + 0L } } } @@ -481,12 +461,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo * Reserve memory for unrolling the given block for this task. * @return whether the request is granted. */ - def reserveUnrollMemoryForThisTask( - blockId: BlockId, - memory: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + def reserveUnrollMemoryForThisTask(blockId: BlockId, memory: Long): Boolean = { memoryManager.synchronized { - val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks) + val success = memoryManager.acquireUnrollMemory(blockId, memory) if (success) { val taskAttemptId = currentTaskAttemptId() unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 30aa94c8a5971..3865c201bf893 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -85,7 +85,12 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager val context = TaskContext.empty() - cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) - assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) + try { + TaskContext.setTaskContext(context) + cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) + assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) + } finally { + TaskContext.unset() + } } } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 3b2368798c1dd..d9764c7c10983 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -70,8 +70,7 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft */ protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) - when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) - .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) + when(ms.evictBlocksToFreeSpace(any(), anyLong())).thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) mm.setMemoryStore(ms) ms } @@ -89,9 +88,9 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft * records the number of bytes this is called with. This variable is expected to be cleared * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. */ - private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Boolean] = { - new Answer[Boolean] { - override def answer(invocation: InvocationOnMock): Boolean = { + private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = { + new Answer[Long] { + override def answer(invocation: InvocationOnMock): Long = { val args = invocation.getArguments val numBytesToFree = args(1).asInstanceOf[Long] assert(numBytesToFree > 0) @@ -101,20 +100,12 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft if (numBytesToFree <= mm.storageMemoryUsed) { // We can evict enough blocks to fulfill the request for space mm.releaseStorageMemory(numBytesToFree) - args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( + evictedBlocks.append( (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))) - // We need to add this call so that that the suite-level `evictedBlocks` is updated when - // execution evicts storage; in that case, args.last will not be equal to evictedBlocks - // because it will be a temporary buffer created inside of the MemoryManager rather than - // being passed in by the test code. - if (!(evictedBlocks eq args.last)) { - evictedBlocks.append( - (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))) - } - true + numBytesToFree } else { // No blocks were evicted because eviction would not free enough space. - false + 0L } } } diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 68cf26fc3ed5d..eee78d396e147 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -81,22 +81,22 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val dummyBlock = TestBlockId("you can see the world you brought to live") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 10L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) - assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 100L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted - assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) + assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction - assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem)) assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 1L)) assertEvictBlocksToFreeSpaceCalled(ms, 1L) assert(evictedBlocks.nonEmpty) evictedBlocks.clear() @@ -107,12 +107,12 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 1L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 1L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired @@ -134,7 +134,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase - assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 50L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) @@ -152,21 +152,21 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val maxStorageMem = 1000L val dummyBlock = TestBlockId("lonely water") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) - assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.acquireUnrollMemory(dummyBlock, 100L)) when(ms.currentUnrollMemory).thenReturn(100L) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 100L) mm.releaseUnrollMemory(40L) assert(mm.storageMemoryUsed === 60L) when(ms.currentUnrollMemory).thenReturn(60L) - assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 800L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. // As of this point, cache memory is 800 bytes and current unroll memory is 60 bytes. // Requesting 240 more bytes of unroll memory will leave our total unroll memory at // 300 bytes, still under the 400-byte limit. Therefore, all 240 bytes are granted. - assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) + assert(mm.acquireUnrollMemory(dummyBlock, 240L)) assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 860 + 240 - 1000 when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 assert(mm.storageMemoryUsed === 1000L) @@ -174,7 +174,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { // We already have 300 bytes of unroll memory, so requesting 150 more will leave us // above the 400-byte limit. Since there is not enough free memory, this request will // fail even after evicting as much as we can (400 - 300 = 100 bytes). - assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) + assert(!mm.acquireUnrollMemory(dummyBlock, 150L)) assertEvictBlocksToFreeSpaceCalled(ms, 100L) assert(mm.storageMemoryUsed === 900L) // Release beyond what was acquired diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 4a1e49b45df40..e5cb9d3a99f0b 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -41,14 +41,8 @@ class TestMemoryManager(conf: SparkConf) grant } } - override def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def acquireStorageMemory(blockId: BlockId, numBytes: Long): Boolean = true + override def acquireUnrollMemory(blockId: BlockId, numBytes: Long): Boolean = true override def releaseStorageMemory(numBytes: Long): Unit = {} override private[memory] def releaseExecutionMemory( numBytes: Long, diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 6cc48597d38f9..0c4359c3c2cd5 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -74,24 +74,24 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val maxMemory = 1000L val (mm, ms) = makeThings(maxMemory) assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 10L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) - assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 100L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted - assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) + assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction - assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, maxMemory)) assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) assert(evictedBlocks.nonEmpty) evictedBlocks.clear() - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 1L)) assertEvictBlocksToFreeSpaceCalled(ms, 1L) assert(evictedBlocks.nonEmpty) evictedBlocks.clear() @@ -102,12 +102,12 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 1L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 1L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired @@ -120,7 +120,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) // Acquire enough storage memory to exceed the storage region - assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 750L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) @@ -140,7 +140,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region - assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 400L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 400L) assert(mm.executionMemoryUsed === 300L) @@ -157,7 +157,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) // Acquire enough storage memory to exceed the storage region size - assert(mm.acquireStorageMemory(dummyBlock, 700L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 700L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 700L) @@ -182,11 +182,11 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.storageMemoryUsed === 0L) assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should not be able to evict execution - assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 100L)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) assertEvictBlocksToFreeSpaceNotCalled(ms) - assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) + assert(!mm.acquireStorageMemory(dummyBlock, 250L)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) // Do not attempt to evict blocks, since evicting will not free enough memory: @@ -199,11 +199,11 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.storageMemoryUsed === 0L) assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should still not be able to evict execution - assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 750L)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) assertEvictBlocksToFreeSpaceNotCalled(ms) // since there were 800 bytes free - assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) + assert(!mm.acquireStorageMemory(dummyBlock, 850L)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) // Do not attempt to evict blocks, since evicting will not free enough memory: @@ -243,7 +243,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) assert(mm.acquireExecutionMemory(100L, 1, MemoryMode.ON_HEAP) === 100L) // Fill up all of the remaining memory with storage. - assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 800L)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 800) assert(mm.executionMemoryUsed === 200) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0f3156117004b..6e6cf6385f919 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -184,8 +184,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a1", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", () => null: Either[Array[Any], ByteBuffer]) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") @@ -425,8 +425,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE t2.join() t3.join() - store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a1", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", () => null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -847,23 +847,37 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) + def getUpdatedBlocks(task: => Unit): Seq[(BlockId, BlockStatus)] = { + val context = TaskContext.empty() + try { + TaskContext.setTaskContext(context) + task + } finally { + TaskContext.unset() + } + context.taskMetrics.updatedBlocks.getOrElse(Seq.empty) + } + // 1 updated block (i.e. list1) - val updatedBlocks1 = + val updatedBlocks1 = getUpdatedBlocks { store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks1.size === 1) assert(updatedBlocks1.head._1 === TestBlockId("list1")) assert(updatedBlocks1.head._2.storageLevel === StorageLevel.MEMORY_ONLY) // 1 updated block (i.e. list2) - val updatedBlocks2 = + val updatedBlocks2 = getUpdatedBlocks { store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + } assert(updatedBlocks2.size === 1) assert(updatedBlocks2.head._1 === TestBlockId("list2")) assert(updatedBlocks2.head._2.storageLevel === StorageLevel.MEMORY_ONLY) // 2 updated blocks - list1 is kicked out of memory while list3 is added - val updatedBlocks3 = + val updatedBlocks3 = getUpdatedBlocks { store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks3.size === 2) updatedBlocks3.foreach { case (id, status) => id match { @@ -875,8 +889,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains("list3"), "list3 was not in memory store") // 2 updated blocks - list2 is kicked out of memory (but put on disk) while list4 is added - val updatedBlocks4 = + val updatedBlocks4 = getUpdatedBlocks { store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks4.size === 2) updatedBlocks4.foreach { case (id, status) => id match { @@ -889,8 +904,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains("list4"), "list4 was not in memory store") // No updated blocks - list5 is too big to fit in store and nothing is kicked out - val updatedBlocks5 = + val updatedBlocks5 = getUpdatedBlocks { store.putIterator("list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks5.size === 0) // memory store contains only list3 and list4 @@ -1005,8 +1021,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.currentUnrollMemoryForThisTask === 0) def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - memoryStore.reserveUnrollMemoryForThisTask( - TestBlockId(""), memory, new ArrayBuffer[(BlockId, BlockStatus)]) + memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory) } // Reserve @@ -1062,11 +1077,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) val memoryStore = store.memoryStore - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. - var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) + var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.releasePendingUnrollMemoryForThisTask() @@ -1074,24 +1088,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) - unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) + unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - assert(droppedBlocks.size === 1) - assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) - droppedBlocks.clear() + assert(memoryStore.contains("someBlock2")) + assert(!memoryStore.contains("someBlock1")) memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. // In the mean time, however, we kicked out someBlock2 before giving up. store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) - unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) + unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - assert(droppedBlocks.size === 1) - assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) - droppedBlocks.clear() + assert(!memoryStore.contains("someBlock2")) } test("safely unroll blocks through putIterator") { @@ -1238,7 +1249,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE }) assert(result.size === 13000) assert(result.data === null) - assert(result.droppedBlocks === Nil) } test("put a small ByteBuffer to MemoryStore") { @@ -1252,6 +1262,5 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE }) assert(result.size === 10000) assert(result.data === Right(bytes)) - assert(result.droppedBlocks === Nil) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index faa5aca1d8f7a..e22e320b17126 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -71,7 +71,7 @@ private[streaming] class BlockManagerBasedBlockHandler( var numRecords: Option[Long] = None - val putResult: Seq[(BlockId, BlockStatus)] = block match { + val putSucceeded: Boolean = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, @@ -88,7 +88,7 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") } - if (!putResult.map { _._1 }.contains(blockId)) { + if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } @@ -184,9 +184,9 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in block manager val storeInBlockManagerFuture = Future { - val putResult = + val putSucceeded = blockManager.putBytes(blockId, serializedBlock, effectiveStorageLevel, tellMaster = true) - if (!putResult.map { _._1 }.contains(blockId)) { + if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } From 38c3c0e31a00aed56f0bc0791a2789c845a3fd61 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 18 Jan 2016 13:55:42 -0800 Subject: [PATCH 1604/2089] [SPARK-12855][SQL] Remove parser dialect developer API This pull request removes the public developer parser API for external parsers. Given everything a parser depends on (e.g. logical plans and expressions) are internal and not stable, external parsers will break with every release of Spark. It is a bad idea to create the illusion that Spark actually supports pluggable parsers. In addition, this also reduces incentives for 3rd party projects to contribute parse improvements back to Spark. Author: Reynold Xin Closes #10801 from rxin/SPARK-12855. --- project/MimaExcludes.scala | 4 +- .../spark/sql/catalyst/CatalystQl.scala | 2 +- .../spark/sql/catalyst/ParserDialect.scala | 6 +-- .../spark/sql/catalyst/errors/package.scala | 2 - .../scala/org/apache/spark/sql/SQLConf.scala | 20 ---------- .../org/apache/spark/sql/SQLContext.scala | 35 ++--------------- .../apache/spark/sql/execution/commands.scala | 8 +--- .../sql/execution/datasources/DDLParser.scala | 5 ++- .../org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 ---------- .../apache/spark/sql/hive/HiveContext.scala | 10 ++--- .../apache/spark/sql/hive/test/TestHive.scala | 2 - .../sql/hive/execution/SQLQuerySuite.scala | 39 ------------------- 13 files changed, 16 insertions(+), 139 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ccd3c34bb5c8c..4430bfd3b0380 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -53,7 +53,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect") ) ++ Seq( ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index c1591ecfe2b4d..4bc721d0b2074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler /** - * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. + * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s. */ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserDialect { object Token { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 7d9fbf2f12ee6..3aa141af63079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** - * Root class of SQL Parser Dialect, and we don't guarantee the binary - * compatibility for the future release, let's keep it as the internal - * interface for advanced user. + * Interface for a parser. */ -@DeveloperApi trait ParserDialect { /** Creates LogicalPlan for a given SQL string. */ def parsePlan(sqlText: String): LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index d2a90a50c89f4..0d44d1dd9650b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -38,8 +38,6 @@ package object errors { } } - class DialectException(msg: String, cause: Throwable) extends Exception(msg, cause) - /** * Wraps any exceptions that are thrown while executing `f` in a * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4c1eb0b30b25e..2d664d3ee691b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -278,11 +278,6 @@ private[spark] object SQLConf { doc = "When true, common subexpressions will be eliminated.", isPublic = false) - val DIALECT = stringConf( - "spark.sql.dialect", - defaultValue = Some("sql"), - doc = "The default SQL dialect to use.") - val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", defaultValue = Some(true), doc = "Whether the query analyzer should be case sensitive or not.") @@ -524,21 +519,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon new java.util.HashMap[String, String]()) /** ************************ Spark SQL Params/Hints ******************* */ - // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? - - /** - * The SQL dialect that is used when parsing queries. This defaults to 'sql' which uses - * a simple SQL parser provided by Spark SQL. This is currently the only option for users of - * SQLContext. - * - * When using a HiveContext, this value defaults to 'hiveql', which uses the Hive 0.12.0 HiveQL - * parser. Users can change this to 'sql' if they want to run queries that aren't supported by - * HiveQL (e.g., SELECT 1). - * - * Note that the choice of dialect does not affect things like what tables are available or - * how query execution is performed. - */ - private[spark] def dialect: String = getConf(DIALECT) private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 18ddffe1be211..b8c8c78b91a5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} @@ -33,13 +32,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.SQLConf.SQLConfEntry -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} +import org.apache.spark.sql.catalyst.{InternalRow, _} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ @@ -206,30 +203,10 @@ class SQLContext private[sql]( protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient - protected[sql] val ddlParser = new DDLParser(sqlParser) + protected[sql] val sqlParser: ParserDialect = new SparkSQLParser(new SparkQl(conf)) @transient - protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect()) - - protected[sql] def getSQLDialect(): ParserDialect = { - try { - val clazz = Utils.classForName(dialectClassName) - clazz.getConstructor(classOf[ParserConf]) - .newInstance(conf) - .asInstanceOf[ParserDialect] - } catch { - case NonFatal(e) => - // Since we didn't find the available SQL Dialect, it will fail even for SET command: - // SET spark.sql.dialect=sql; Let's reset as default dialect automatically. - val dialect = conf.dialect - // reset the sql dialect - conf.unsetConf(SQLConf.DIALECT) - // throw out the exception, and the default sql dialect will take effect for next query. - throw new DialectException( - s"""Instantiating dialect '$dialect' failed. - |Reverting to default dialect '${conf.dialect}'""".stripMargin, e) - } - } + protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) @@ -239,12 +216,6 @@ class SQLContext private[sql]( protected[sql] def executePlan(plan: LogicalPlan) = new sparkexecution.QueryExecution(this, plan) - protected[sql] def dialectClassName = if (conf.dialect == "sql") { - classOf[SparkQl].getCanonicalName - } else { - conf.dialect - } - /** * Add a jar to SQLContext */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 2e2ce88211a08..3cfa3dfd9c7ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -201,13 +201,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm case Some((key, None)) => val runFunc = (sqlContext: SQLContext) => { val value = - try { - if (key == SQLConf.DIALECT.key) { - sqlContext.conf.dialect - } else { - sqlContext.getConf(key) - } - } catch { + try sqlContext.getConf(key) catch { case _: NoSuchElementException => "" } Seq(Row(key, value)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 10655a85ccf89..4dea947f6a849 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -37,9 +37,10 @@ class DDLParser(fallback: => ParserDialect) override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) - override def parseTableIdentifier(sql: String): TableIdentifier = - + override def parseTableIdentifier(sql: String): TableIdentifier = { fallback.parseTableIdentifier(sql) + } + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { try { parsePlan(input) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8c2530fd684a7..3a27466176a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1064,7 +1064,7 @@ object functions extends LegacyFunctions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.getSQLDialect()).getOrElse(new CatalystQl()) + val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl()) Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bdb9421cc19d7..d7f182352b4c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,7 +23,6 @@ import java.sql.Timestamp import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.CatalystQl import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.{aggregate, SparkQl} import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} @@ -32,8 +31,6 @@ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -/** A SQL Dialect for testing purpose, and it can not be nested type */ -class MyDialect(conf: ParserConf) extends CatalystQl(conf) class SQLQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -148,23 +145,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .count(), Row(24, 1) :: Row(14, 1) :: Nil) } - test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sparkContext) - newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) - assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) - assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) - } - - test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sparkContext) - newContext.sql("SET spark.sql.dialect=MyTestClass") - intercept[DialectException] { - newContext.sql("SELECT 1") - } - // test if the dialect set back to DefaultSQLDialect - assert(newContext.getSQLDialect().getClass === classOf[SparkQl]) - } - test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7bdca52200305..dd536b78c7242 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -542,16 +542,12 @@ class HiveContext private[hive]( } protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - protected[sql] override def getSQLDialect(): ParserDialect = { - if (conf.dialect == "hiveql") { - new ExtendedHiveQlParser(this) - } else { - super.getSQLDialect() - } + @transient + protected[sql] override val sqlParser: ParserDialect = { + new SparkSQLParser(new ExtendedHiveQlParser(this)) } @transient diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 033746d42f557..a33223af24370 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -120,8 +120,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { new this.QueryExecution(plan) protected[sql] override lazy val conf: SQLConf = new SQLConf { - // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" - override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 61d5aa7ae6b31..f7d8d395ba648 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry} -import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -57,8 +56,6 @@ case class WindowData( month: Int, area: String, product: Int) -/** A SQL Dialect for testing purpose, and it can not be nested type */ -class MyDialect(conf: ParserConf) extends HiveQl(conf) /** * A collection of hive query tests where we generate the answers ourselves instead of depending on @@ -337,42 +334,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("SQL dialect at the start of HiveContext") { - val hiveContext = new HiveContext(sqlContext.sparkContext) - val dialectConf = "spark.sql.dialect" - checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql")) - assert(hiveContext.getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) - } - - test("SQL Dialect Switching") { - assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) - setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) - assert(getSQLDialect().getClass === classOf[MyDialect]) - assert(sql("SELECT 1").collect() === Array(Row(1))) - - // set the dialect back to the DefaultSQLDialect - sql("SET spark.sql.dialect=sql") - assert(getSQLDialect().getClass === classOf[SparkQl]) - sql("SET spark.sql.dialect=hiveql") - assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) - - // set invalid dialect - sql("SET spark.sql.dialect.abc=MyTestClass") - sql("SET spark.sql.dialect=abc") - intercept[Exception] { - sql("SELECT 1") - } - // test if the dialect set back to HiveQLDialect - getSQLDialect().getClass === classOf[ExtendedHiveQlParser] - - sql("SET spark.sql.dialect=MyTestClass") - intercept[DialectException] { - sql("SELECT 1") - } - // test if the dialect set back to HiveQLDialect - assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser]) - } - test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( From 4f11e3f2aa4f097ed66296fe72b5b5384924010c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jan 2016 14:15:27 -0800 Subject: [PATCH 1605/2089] [SPARK-12841][SQL] fix cast in filter In SPARK-10743 we wrap cast with `UnresolvedAlias` to give `Cast` a better alias if possible. However, for cases like `filter`, the `UnresolvedAlias` can't be resolved and actually we don't need a better alias for this case. This PR move the cast wrapping logic to `Column.named` so that we will only do it when we need a alias name. Author: Wenchen Fan Closes #10781 from cloud-fan/bug. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 17 ++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dadea6b54a946..9257fba60e36c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -147,7 +147,7 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr transform { + expr transformUp { case u @ UnresolvedAlias(child, optionalAliasName) => child match { case ne: NamedExpression => ne case e if !e.resolved => u diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 97bf7a0cc4514..2ab091e40a073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -133,6 +133,15 @@ class Column(protected[sql] val expr: Expression) extends Logging { case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString)) + // If we have a top level Cast, there is a chance to give it a better alias, if there is a + // NamedExpression under this Cast. + case c: Cast => c.transformUp { + case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) + } match { + case ne: NamedExpression => ne + case other => Alias(expr, expr.prettyString)() + } + case expr: Expression => Alias(expr, expr.prettyString)() } @@ -921,13 +930,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { - expr match { - // keeps the name of expression if possible when do cast. - case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) - case _ => Cast(expr, to) - } - } + def cast(to: DataType): Column = withExpr { Cast(expr, to) } /** * Casts the column to a different data type, using the canonical string representation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d6c140dfea9ed..afc8df07fd9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1007,6 +1007,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") + assert(df.select($"src.i".cast(StringType).cast(IntegerType)).columns.head === "i") } test("SPARK-11301: fix case sensitivity for filter on partitioned columns") { @@ -1228,4 +1229,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b")) checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c")) } + + test("SPARK-12841: cast in filter") { + checkAnswer( + Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"), + Row(1, "a")) + } } From 404190221a788ebc3a0cbf5cb47cf532436ce965 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jan 2016 15:10:04 -0800 Subject: [PATCH 1606/2089] [SPARK-12882][SQL] simplify bucket tests and add more comments Right now, the bucket tests are kind of hard to understand, this PR simplifies them and add more commetns. Author: Wenchen Fan Closes #10813 from cloud-fan/bucket-comment. --- .../spark/sql/sources/BucketedReadSuite.scala | 56 +++++++++------ .../sql/sources/BucketedWriteSuite.scala | 68 ++++++++++++------- 2 files changed, 78 insertions(+), 46 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 58ecdd3b801d3..150d0c748631e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + /** + * A helper method to test the bucket read functionality using join. It will save `df1` and `df2` + * to hive tables, bucketed or not, according to the given bucket specifics. Next we will join + * these 2 tables, and firstly make sure the answer is corrected, and then check if the shuffle + * exists as user expected according to the `shuffleLeft` and `shuffleRight`. + */ private def testBucketing( - bucketing1: DataFrameWriter => DataFrameWriter, - bucketing2: DataFrameWriter => DataFrameWriter, + bucketSpecLeft: Option[BucketSpec], + bucketSpecRight: Option[BucketSpec], joinColumns: Seq[String], shuffleLeft: Boolean, shuffleRight: Boolean): Unit = { withTable("bucketed_table1", "bucketed_table2") { - bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1") - bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2") + def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = { + bucketSpec.map { spec => + writer.bucketBy( + spec.numBuckets, + spec.bucketColumnNames.head, + spec.bucketColumnNames.tail: _*) + }.getOrElse(writer) + } + + withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1") + withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { val t1 = hiveContext.table("bucketed_table1") @@ -95,42 +111,42 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } test("avoid shuffle when join 2 bucketed tables") { - val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") - testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) } // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704 ignore("avoid shuffle when join keys are a super-set of bucket keys") { - val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i") - testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) } test("only shuffle one side when join bucketed table and non-bucketed table") { - val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") - testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) } test("only shuffle one side when 2 bucketed tables have different bucket number") { - val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") - val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j") - testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil)) + testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) } test("only shuffle one side when 2 bucketed tables have different bucket keys") { - val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i") - val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j") - testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true) + val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil)) + testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true) } test("shuffle when join keys are not equal to bucket keys") { - val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i") - testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true) + val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true) } test("shuffle when join 2 bucketed tables with bucketing disabled") { - val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j") + val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { - testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true) + testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index e812439bed9aa..dad1fc1273810 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + /** + * A helper method to check the bucket write functionality in low level, i.e. check the written + * bucket files to see if the data are correct. User should pass in a data dir that these bucket + * files are written to, and the format of data(parquet, json, etc.), and the bucketing + * information. + */ private def testBucketing( dataDir: File, source: String, + numBuckets: Int, bucketCols: Seq[String], sortCols: Seq[String] = Nil): Unit = { val allBucketFiles = dataDir.listFiles().filterNot(f => f.getName.startsWith(".") || f.getName.startsWith("_") ) - val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get) - assert(groupedBucketFiles.size <= 8) - - for ((bucketId, bucketFiles) <- groupedBucketFiles) { - for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) { - val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType) - val columns = (bucketCols ++ sortCols).zip(types).map { - case (colName, dt) => col(colName).cast(dt) - } - val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*) - if (sortCols.nonEmpty) { - checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect()) - } + for (bucketFile <- allBucketFiles) { + val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get + assert(bucketId >= 0 && bucketId < numBuckets) - val qe = readBack.select(bucketCols.map(col): _*).queryExecution - val rows = qe.toRdd.map(_.copy()).collect() - val getBucketId = UnsafeProjection.create( - HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil, - qe.analyzed.output) + // We may loss the type information after write(e.g. json format doesn't keep schema + // information), here we get the types from the original dataframe. + val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType) + val columns = (bucketCols ++ sortCols).zip(types).map { + case (colName, dt) => col(colName).cast(dt) + } - for (row <- rows) { - val actualBucketId = getBucketId(row).getInt(0) - assert(actualBucketId == bucketId) - } + // Read the bucket file into a dataframe, so that it's easier to test. + val readBack = sqlContext.read.format(source) + .load(bucketFile.getAbsolutePath) + .select(columns: _*) + + // If we specified sort columns while writing bucket table, make sure the data in this + // bucket file is already sorted. + if (sortCols.nonEmpty) { + checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect()) + } + + // Go through all rows in this bucket file, calculate bucket id according to bucket column + // values, and make sure it equals to the expected bucket id that inferred from file name. + val qe = readBack.select(bucketCols.map(col): _*).queryExecution + val rows = qe.toRdd.map(_.copy()).collect() + val getBucketId = UnsafeProjection.create( + HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil, + qe.analyzed.output) + + for (row <- rows) { + val actualBucketId = getBucketId(row).getInt(0) + assert(actualBucketId == bucketId) } } } @@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val tableDir = new File(hiveContext.warehousePath, "bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k")) + testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) } } } @@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val tableDir = new File(hiveContext.warehousePath, "bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k")) + testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k")) } } } @@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .saveAsTable("bucketed_table") val tableDir = new File(hiveContext.warehousePath, "bucketed_table") - testBucketing(tableDir, source, Seq("i", "j")) + testBucketing(tableDir, source, 8, Seq("i", "j")) } } } @@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .saveAsTable("bucketed_table") val tableDir = new File(hiveContext.warehousePath, "bucketed_table") - testBucketing(tableDir, source, Seq("i", "j"), Seq("k")) + testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k")) } } } From a973f483f6b819ed4ecac27ff5c064ea13a8dd71 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 18 Jan 2016 15:38:03 -0800 Subject: [PATCH 1607/2089] [SPARK-12814][DOCUMENT] Add deploy instructions for Python in flume integration doc This PR added instructions to get flume assembly jar for Python users in the flume integration page like Kafka doc. Author: Shixiong Zhu Closes #10746 from zsxwing/flume-doc. --- docs/streaming-flume-integration.md | 13 +++++++++++-- docs/streaming-kafka-integration.md | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 383d954409ce4..e2d589b843f99 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -71,7 +71,16 @@ configuring Flume agents. cluster (Mesos, YARN or Spark Standalone), so that resource allocation can match the names and launch the receiver in the right machine. -3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-flume-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-flume-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Pull-based Approach using a Custom Sink Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. @@ -157,7 +166,7 @@ configuring Flume agents. Note that each input DStream can be configured to receive data from multiple sinks. -3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). +3. **Deploying:** This is same as the first approach. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 9454714eeb9cb..015a2f1fa0bdc 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -71,7 +71,7 @@ Next, we discuss how to use this approach in your streaming application. ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-assembly` from the - [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. @@ -207,4 +207,4 @@ Next, we discuss how to use this approach in your streaming application. Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. -3. **Deploying:** This is same as the first approach, for Scala, Java and Python. +3. **Deploying:** This is same as the first approach. From 4bcea1b8595424678aa6c92d66ba08c92e0fefe5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 18 Jan 2016 16:26:52 -0800 Subject: [PATCH 1608/2089] Revert "[SPARK-12829] Turn Java style checker on" This reverts commit 591c88c9e2a6c2e2ca84f1b66c635f198a16d112. `lint-java` doesn't work on a machine with a clean Maven cache. --- dev/run-tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5b717895106b3..8f47728f206c3 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -538,7 +538,8 @@ def main(): or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - run_java_style_checks() + # run_java_style_checks() + pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): From 721845c1b64fd6e3b911bd77c94e01dc4e5fd102 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 18 Jan 2016 16:50:05 -0800 Subject: [PATCH 1609/2089] [SPARK-12894][DOCUMENT] Add deploy instructions for Python in Kinesis integration doc This PR added instructions to get Kinesis assembly jar for Python users in the Kinesis integration page like Kafka doc. Author: Shixiong Zhu Closes #10822 from zsxwing/kinesis-doc. --- docs/streaming-kinesis-integration.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 07194b0a6b758..5f5e2b9087cc9 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -15,12 +15,13 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Configuring Spark Streaming Application -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. **Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.** 2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream of byte array as follows: @@ -116,7 +117,16 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m In other versions of the API, you can also specify the AWS access key and secret key directly. -3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kinesis-asl-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kinesis-asl-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. *Points to remember at runtime:* From 39ac56fc60734d0e095314fc38a7b36fbb4c80f7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 18 Jan 2016 17:10:32 -0800 Subject: [PATCH 1610/2089] [SPARK-12889][SQL] Rename ParserDialect -> ParserInterface. Based on discussions in #10801, I'm submitting a pull request to rename ParserDialect to ParserInterface. Author: Reynold Xin Closes #10817 from rxin/SPARK-12889. --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala | 2 +- .../catalyst/{ParserDialect.scala => ParserInterface.scala} | 2 +- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../scala/org/apache/spark/sql/execution/SparkSQLParser.scala | 4 ++-- .../apache/spark/sql/execution/datasources/DDLParser.scala | 4 ++-- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 4 ++-- 7 files changed, 10 insertions(+), 10 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/{ParserDialect.scala => ParserInterface.scala} (98%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 9443369808984..38fa5cb585ee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser - extends StandardTokenParsers with PackratParsers with ParserDialect { + extends StandardTokenParsers with PackratParsers with ParserInterface { def parsePlan(input: String): LogicalPlan = synchronized { // Initialize the Keywords. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 4bc721d0b2074..5fb41f7e4bf8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -33,7 +33,7 @@ import org.apache.spark.util.random.RandomSampler /** * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s. */ -private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserDialect { +private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface { object Token { def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { CurrentOrigin.setPosition(node.line, node.positionInLine) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserInterface.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserInterface.scala index 3aa141af63079..24ec452c4d2ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserInterface.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * Interface for a parser. */ -trait ParserDialect { +trait ParserInterface { /** Creates LogicalPlan for a given SQL string. */ def parsePlan(sqlText: String): LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b8c8c78b91a5e..147e3557b632b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -203,7 +203,7 @@ class SQLContext private[sql]( protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient - protected[sql] val sqlParser: ParserDialect = new SparkSQLParser(new SparkQl(conf)) + protected[sql] val sqlParser: ParserInterface = new SparkSQLParser(new SparkQl(conf)) @transient protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index 1af2c756cdc5a..d2d8271563726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.StringType * parameter because this allows us to return a different dialect if we * have to. */ -class SparkSQLParser(fallback: => ParserDialect) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParser { override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 4dea947f6a849..f4766b037027d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -22,7 +22,7 @@ import scala.util.matching.Regex import org.apache.spark.Logging import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ /** * A parser for foreign DDL commands. */ -class DDLParser(fallback: => ParserDialect) +class DDLParser(fallback: => ParserInterface) extends AbstractSparkSQLParser with DataTypeParser with Logging { override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index dd536b78c7242..eaca3c9269bb7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,7 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect} +import org.apache.spark.sql.catalyst.{InternalRow, ParserInterface} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -546,7 +546,7 @@ class HiveContext private[hive]( } @transient - protected[sql] override val sqlParser: ParserDialect = { + protected[sql] override val sqlParser: ParserInterface = { new SparkSQLParser(new ExtendedHiveQlParser(this)) } From 323d51f1dadf733e413203d678cb3f76e4d68981 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 18 Jan 2016 17:29:54 -0800 Subject: [PATCH 1611/2089] [SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin Currently SortMergeJoin and BroadcastHashJoin do not support condition, the need a followed Filter for that, the result projection to generate UnsafeRow could be very expensive if they generate lots of rows and could be filtered mostly by condition. This PR brings the support of condition for SortMergeJoin and BroadcastHashJoin, just like other outer joins do. This could improve the performance of Q72 by 7x (from 120s to 16.5s). Author: Davies Liu Closes #10653 from davies/filter_join. --- .../spark/sql/execution/SparkStrategies.scala | 24 ++---- .../execution/joins/BroadcastHashJoin.scala | 1 + .../spark/sql/execution/joins/HashJoin.scala | 81 +++++++++++-------- .../sql/execution/joins/HashOuterJoin.scala | 5 +- .../sql/execution/joins/SortMergeJoin.scala | 46 +++++++---- .../sql/execution/joins/InnerJoinSuite.scala | 11 +-- 6 files changed, 96 insertions(+), 72 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910519d0e6814..df0f730499211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.{execution, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object EquiJoinSelection extends Strategy with PredicateHelper { - private[this] def makeBroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: LogicalPlan, - right: LogicalPlan, - condition: Option[Expression], - side: joins.BuildSide): Seq[SparkPlan] = { - val broadcastHashJoin = execution.joins.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right)) - condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- Inner joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) + joins.BroadcastHashJoin( + leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + joins.BroadcastHashJoin( + leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => - val mergeJoin = - joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + joins.SortMergeJoin( + leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil // --- Outer joins -------------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 0a818cc2c2a27..c9ea579b5e809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -39,6 +39,7 @@ case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], buildSide: BuildSide, + condition: Option[Expression], left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7f9d9daa5ab20..8ef854001f4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.joins +import java.util.NoSuchElementException + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan @@ -29,6 +31,7 @@ trait HashJoin { val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide + val condition: Option[Expression] val left: SparkPlan val right: SparkPlan @@ -50,6 +53,12 @@ trait HashJoin { protected def streamSideKeyGenerator: Projection = UnsafeProjection.create(streamedKeys, streamedPlan.output) + @transient private[this] lazy val boundCondition = if (condition.isDefined) { + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + } else { + (r: InternalRow) => true + } + protected def hashJoin( streamIter: Iterator[InternalRow], numStreamRows: LongSQLMetric, @@ -68,44 +77,52 @@ trait HashJoin { private[this] val joinKeys = streamSideKeyGenerator - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) + override final def hasNext: Boolean = { + while (true) { + // check if it's end of current matches + if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) { + currentHashMatches = null + currentMatchPosition = -1 + } - override final def next(): InternalRow = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - numOutputRows += 1 - resultProjection(ret) - } + // find the next match + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + numStreamRows += 1 + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashedRelation.get(key) + if (currentHashMatches != null) { + currentMatchPosition = 0 + } + } + } + if (currentHashMatches == null) { + return false + } - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - numStreamRows += 1 - val key = joinKeys(currentStreamedRow) - if (!key.anyNull) { - currentHashMatches = hashedRelation.get(key) + // found some matches + buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + if (boundCondition(joinRow)) { + return true + } else { + currentMatchPosition += 1 } } + false // unreachable + } - if (currentHashMatches == null) { - false + override final def next(): InternalRow = { + // next() could be called without calling hasNext() + if (hasNext) { + currentMatchPosition += 1 + numOutputRows += 1 + resultProjection(joinRow) } else { - currentMatchPosition = 0 - true + throw new NoSuchElementException } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6d464d6946b78..9e614309de129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -78,8 +78,11 @@ trait HashOuterJoin { @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient private[this] lazy val boundCondition = + @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + } else { + (row: InternalRow) => true + } // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 812f881d06fb8..322a954b4f792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + condition: Option[Expression], left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -64,6 +65,13 @@ case class SortMergeJoin( val numOutputRows = longMetric("numOutputRows") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } new RowIterator { // The projection used to extract keys from input rows of the left child. private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) @@ -89,26 +97,34 @@ case class SortMergeJoin( private[this] val resultProjection: (InternalRow) => InternalRow = UnsafeProjection.create(schema) + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } + override def advanceNext(): Boolean = { - if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { - if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 - } else { - currentRightMatches = null - currentLeftRow = null - currentMatchIdx = -1 + while (currentMatchIdx >= 0) { + if (currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + return false + } } - } - if (currentLeftRow != null) { joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) currentMatchIdx += 1 - numOutputRows += 1 - true - } else { - false + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } } + false } override def getRow: InternalRow = resultProjection(joinRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 42fadaa8e2215..ab81b702596af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.{DataFrame, Row, SQLConf} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.localSeqToDataFrameHolder @@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - val broadcastHashJoin = - execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) - boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan) } def makeSortMergeJoin( @@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan) = { val sortMergeJoin = - execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) - val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) + joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan) + EnsureRequirements(sqlContext).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { From 2b5d11f34d73eb7117c0c4668c1abb27dcc3a403 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 18 Jan 2016 19:22:29 -0800 Subject: [PATCH 1612/2089] [SPARK-12885][MINOR] Rename 3 fields in ShuffleWriteMetrics This is a small step in implementing SPARK-10620, which migrates TaskMetrics to accumulators. This patch is strictly a cleanup patch and introduces no change in functionality. It literally just renames 3 fields for consistency. Today we have: ``` inputMetrics.recordsRead outputMetrics.bytesWritten shuffleReadMetrics.localBlocksFetched ... shuffleWriteMetrics.shuffleRecordsWritten shuffleWriteMetrics.shuffleBytesWritten shuffleWriteMetrics.shuffleWriteTime ``` The shuffle write ones are kind of redundant. We can drop the `shuffle` part in the method names. I added backward compatible (but deprecated) methods with the old names. Parent PR: #10717 Author: Andrew Or Closes #10811 from andrewor14/rename-things. --- .../sort/BypassMergeSortShuffleWriter.java | 4 +- .../shuffle/sort/ShuffleExternalSorter.java | 4 +- .../shuffle/sort/UnsafeShuffleWriter.java | 6 +-- .../storage/TimeTrackingOutputStream.java | 10 ++-- .../spark/executor/ShuffleWriteMetrics.scala | 37 ++++++++----- .../spark/scheduler/SparkListener.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/sort/SortShuffleWriter.scala | 2 +- .../status/api/v1/AllStagesResource.scala | 12 ++--- .../spark/storage/DiskBlockObjectWriter.scala | 12 ++--- .../apache/spark/ui/exec/ExecutorsTab.scala | 2 +- .../spark/ui/jobs/JobProgressListener.scala | 8 +-- .../org/apache/spark/ui/jobs/StagePage.scala | 14 ++--- .../org/apache/spark/util/JsonProtocol.scala | 13 ++--- .../collection/ExternalAppendOnlyMap.scala | 4 +- .../util/collection/ExternalSorter.scala | 4 +- .../sort/UnsafeShuffleWriterSuite.java | 20 +++---- .../scala/org/apache/spark/ShuffleSuite.scala | 4 +- .../metrics/InputOutputMetricsSuite.scala | 2 +- .../spark/scheduler/SparkListenerSuite.scala | 2 +- .../BypassMergeSortShuffleWriterSuite.scala | 8 +-- .../storage/DiskBlockObjectWriterSuite.scala | 54 +++++++++---------- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 12 ++--- 24 files changed, 126 insertions(+), 114 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 56cdc22f36261..a06dc1ce91542 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -143,7 +143,7 @@ public void write(Iterator> records) throws IOException { // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be // included in the shuffle write time. - writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + writeMetrics.incWriteTime(System.nanoTime() - openStartTime); while (records.hasNext()) { final Product2 record = records.next(); @@ -203,7 +203,7 @@ private long[] writePartitionedFile(File outputFile) throws IOException { threwException = false; } finally { Closeables.close(out, threwException); - writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; return lengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 9affff80143d7..2c84de5bf2a5a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -233,8 +233,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. - writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); - taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 744c3008ca50e..c8cc7056975ec 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -298,8 +298,8 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. - writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); - writeMetrics.incShuffleBytesWritten(outputFile.length()); + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incBytesWritten(outputFile.length()); return partitionLengths; } } catch (IOException e) { @@ -411,7 +411,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th spillInputChannelPositions[i] += actualBytesTransferred; bytesToTransfer -= actualBytesTransferred; } - writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; } diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index dc2aa30466cc6..5d0555a8c28e1 100644 --- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -42,34 +42,34 @@ public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream o public void write(int b) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void write(byte[] b) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void write(byte[] b, int off, int len) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b, off, len); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void flush() throws IOException { final long startTime = System.nanoTime(); outputStream.flush(); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void close() throws IOException { final long startTime = System.nanoTime(); outputStream.close(); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 469ebe26c7b56..24795f860087f 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -26,28 +26,39 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi class ShuffleWriteMetrics extends Serializable { + /** * Number of bytes written for the shuffle by this task */ - @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten: Long = _shuffleBytesWritten - private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value - private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value + @volatile private var _bytesWritten: Long = _ + def bytesWritten: Long = _bytesWritten + private[spark] def incBytesWritten(value: Long) = _bytesWritten += value + private[spark] def decBytesWritten(value: Long) = _bytesWritten -= value /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ - @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime: Long = _shuffleWriteTime - private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value - private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value + @volatile private var _writeTime: Long = _ + def writeTime: Long = _writeTime + private[spark] def incWriteTime(value: Long) = _writeTime += value + private[spark] def decWriteTime(value: Long) = _writeTime -= value /** * Total number of records written to the shuffle by this task */ - @volatile private var _shuffleRecordsWritten: Long = _ - def shuffleRecordsWritten: Long = _shuffleRecordsWritten - private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value - private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value - private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value + @volatile private var _recordsWritten: Long = _ + def recordsWritten: Long = _recordsWritten + private[spark] def incRecordsWritten(value: Long) = _recordsWritten += value + private[spark] def decRecordsWritten(value: Long) = _recordsWritten -= value + private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value + + // Legacy methods for backward compatibility. + // TODO: remove these once we make this class private. + @deprecated("use bytesWritten instead", "2.0.0") + def shuffleBytesWritten: Long = bytesWritten + @deprecated("use writeTime instead", "2.0.0") + def shuffleWriteTime: Long = writeTime + @deprecated("use recordsWritten instead", "2.0.0") + def shuffleRecordsWritten: Long = recordsWritten + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3130a65240a99..f5267f58c2e42 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -271,7 +271,7 @@ class StatsReportListener extends SparkListener with Logging { // Shuffle write showBytesDistribution("shuffle bytes written:", - (_, metric) => metric.shuffleWriteMetrics.map(_.shuffleBytesWritten), taskInfoMetrics) + (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics) // Fetch & I/O showMillisDistribution("fetch wait time:", diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 294e16cde1931..2970968f0bd47 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -90,7 +90,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, so should be included in the shuffle write time. - writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) + writeMetrics.incWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { shuffleState.completedMapTasks.add(mapId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index f83cf8859e581..5c5a5f5a4cb6a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -94,7 +94,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val startTime = System.nanoTime() sorter.stop() context.taskMetrics.shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - startTime)) + _.incWriteTime(System.nanoTime - startTime)) sorter = null } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 341ae782362a0..078718ba11260 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -214,9 +214,9 @@ private[v1] object AllStagesResource { raw.shuffleWriteMetrics } def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.shuffleBytesWritten), - writeRecords = submetricQuantiles(_.shuffleRecordsWritten), - writeTime = submetricQuantiles(_.shuffleWriteTime) + writeBytes = submetricQuantiles(_.bytesWritten), + writeRecords = submetricQuantiles(_.recordsWritten), + writeTime = submetricQuantiles(_.writeTime) ) }.metricOption @@ -283,9 +283,9 @@ private[v1] object AllStagesResource { def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): ShuffleWriteMetrics = { new ShuffleWriteMetrics( - bytesWritten = internal.shuffleBytesWritten, - writeTime = internal.shuffleWriteTime, - recordsWritten = internal.shuffleRecordsWritten + bytesWritten = internal.bytesWritten, + writeTime = internal.writeTime, + recordsWritten = internal.recordsWritten ) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e36a367323b20..c34d49c0d9061 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -102,7 +102,7 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() val start = System.nanoTime() fos.getFD.sync() - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) + writeMetrics.incWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -132,7 +132,7 @@ private[spark] class DiskBlockObjectWriter( close() finalPosition = file.length() // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + writeMetrics.incBytesWritten(finalPosition - reportedPosition) } else { finalPosition = file.length() } @@ -152,8 +152,8 @@ private[spark] class DiskBlockObjectWriter( // truncating the file to its initial position. try { if (initialized) { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) + writeMetrics.decBytesWritten(reportedPosition - initialPosition) + writeMetrics.decRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -201,7 +201,7 @@ private[spark] class DiskBlockObjectWriter( */ def recordWritten(): Unit = { numRecordsWritten += 1 - writeMetrics.incShuffleRecordsWritten(1) + writeMetrics.incRecordsWritten(1) if (numRecordsWritten % 32 == 0) { updateBytesWritten() @@ -226,7 +226,7 @@ private[spark] class DiskBlockObjectWriter( */ private def updateBytesWritten() { val pos = channel.position() - writeMetrics.incShuffleBytesWritten(pos - reportedPosition) + writeMetrics.incBytesWritten(pos - reportedPosition) reportedPosition = pos } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 2d955a66601ee..160d7a4dff2dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -129,7 +129,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp } metrics.shuffleWriteMetrics.foreach { shuffleWrite => executorToShuffleWrite(eid) = - executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.shuffleBytesWritten + executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.bytesWritten } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index ca37829216f22..4a9f8b30525fe 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -426,14 +426,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) val shuffleWriteDelta = - (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)) + (taskMetrics.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.bytesWritten).getOrElse(0L)) stageData.shuffleWriteBytes += shuffleWriteDelta execSummary.shuffleWrite += shuffleWriteDelta val shuffleWriteRecordsDelta = - (taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleRecordsWritten).getOrElse(0L)) + (taskMetrics.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.recordsWritten).getOrElse(0L)) stageData.shuffleWriteRecords += shuffleWriteRecordsDelta execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 6d4066a870cdd..914f6183cc2a4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -500,11 +500,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getFormattedSizeQuantiles(shuffleReadRemoteSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble + metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L).toDouble + metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val shuffleWriteQuantiles = Shuffle Write Size / Records +: @@ -619,7 +619,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val shuffleReadTimeProportion = toProportion(shuffleReadTime) val shuffleWriteTime = (metricsOpt.flatMap(_.shuffleWriteMetrics - .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong + .map(_.writeTime)).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) @@ -930,13 +930,13 @@ private[ui] class TaskDataSource( val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L) + val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") val shuffleWriteRecords = maybeShuffleWrite - .map(_.shuffleRecordsWritten.toString).getOrElse("") + .map(_.recordsWritten.toString).getOrElse("") - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) + val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.writeTime) val writeTimeSortable = maybeWriteTime.getOrElse(0L) val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => if (ms == 0) "" else UIUtils.formatDuration(ms) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a6460bc8b8202..b88221a249eb8 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -331,10 +331,11 @@ private[spark] object JsonProtocol { ("Total Records Read" -> shuffleReadMetrics.recordsRead) } + // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = { - ("Shuffle Bytes Written" -> shuffleWriteMetrics.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) ~ - ("Shuffle Records Written" -> shuffleWriteMetrics.shuffleRecordsWritten) + ("Shuffle Bytes Written" -> shuffleWriteMetrics.bytesWritten) ~ + ("Shuffle Write Time" -> shuffleWriteMetrics.writeTime) ~ + ("Shuffle Records Written" -> shuffleWriteMetrics.recordsWritten) } def inputMetricsToJson(inputMetrics: InputMetrics): JValue = { @@ -752,9 +753,9 @@ private[spark] object JsonProtocol { def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = { val metrics = new ShuffleWriteMetrics - metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) - metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long]) - metrics.setShuffleRecordsWritten((json \ "Shuffle Records Written") + metrics.incBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) + metrics.incWriteTime((json \ "Shuffle Write Time").extract[Long]) + metrics.setRecordsWritten((json \ "Shuffle Records Written") .extractOpt[Long].getOrElse(0)) metrics } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 4a44481cf4e14..ff9dad7d38bf0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -193,8 +193,8 @@ class ExternalAppendOnlyMap[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) + _diskBytesSpilled += curWriteMetrics.bytesWritten + batchSizes.append(curWriteMetrics.bytesWritten) objectsWritten = 0 } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 63ba954a7fa7e..4c7416e00b004 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -262,8 +262,8 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += spillMetrics.shuffleBytesWritten - batchSizes.append(spillMetrics.shuffleBytesWritten) + _diskBytesSpilled += spillMetrics.bytesWritten + batchSizes.append(spillMetrics.bytesWritten) spillMetrics = null objectsWritten = 0 } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 5fe64bde3604a..625fdd57eb5d4 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -279,8 +279,8 @@ public void writeEmptyIterator() throws Exception { assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().recordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().bytesWritten()); assertEquals(0, taskMetrics.diskBytesSpilled()); assertEquals(0, taskMetrics.memoryBytesSpilled()); } @@ -311,10 +311,10 @@ public void writeWithoutSpilling() throws Exception { HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertEquals(0, taskMetrics.diskBytesSpilled()); assertEquals(0, taskMetrics.memoryBytesSpilled()); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } private void testMergingSpills( @@ -354,11 +354,11 @@ private void testMergingSpills( assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @Test @@ -416,11 +416,11 @@ public void writeEnoughDataToTriggerSpill() throws Exception { readRecordsFromFile(); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @Test @@ -437,11 +437,11 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce readRecordsFromFile(); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @Test diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c45d81459e8e2..6ffa1c8ac1406 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -450,8 +450,8 @@ object ShuffleSuite { val listener = new SparkListener { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m => - recordsWritten += m.shuffleRecordsWritten - bytesWritten += m.shuffleBytesWritten + recordsWritten += m.recordsWritten + bytesWritten += m.bytesWritten } taskEnd.taskMetrics.shuffleReadMetrics.foreach { m => recordsRead += m.recordsRead diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index aaf62e0f91067..e5a448298a624 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -212,7 +212,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext metrics.inputMetrics.foreach(inputRead += _.recordsRead) metrics.outputMetrics.foreach(outputWritten += _.recordsWritten) metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead) - metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.shuffleRecordsWritten) + metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.recordsWritten) } }) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index dc15f5932d6f8..c87158d89f3fb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -269,7 +269,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match taskMetrics.inputMetrics should not be ('defined) taskMetrics.outputMetrics should not be ('defined) taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L) + taskMetrics.shuffleWriteMetrics.get.bytesWritten should be > (0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index ef6ce04e3ff28..fdacd8c9f5908 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -145,8 +145,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get - assert(shuffleWriteMetrics.shuffleBytesWritten === 0) - assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) + assert(shuffleWriteMetrics.bytesWritten === 0) + assert(shuffleWriteMetrics.recordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) assert(taskMetrics.memoryBytesSpilled === 0) } @@ -169,8 +169,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(writer.getPartitionLengths.filter(_ == 0L).size === 4) // should be 4 zero length files assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get - assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) assert(taskMetrics.memoryBytesSpilled === 0) } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 5d36617cfc447..8eff3c297035d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -50,18 +50,18 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write - assert(writeMetrics.shuffleRecordsWritten === 1) + assert(writeMetrics.recordsWritten === 1) // Metrics don't update on every write - assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.bytesWritten == 0) // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() writer.write(Long.box(i), Long.box(i)) } - assert(writeMetrics.shuffleBytesWritten > 0) - assert(writeMetrics.shuffleRecordsWritten === 33) + assert(writeMetrics.bytesWritten > 0) + assert(writeMetrics.recordsWritten === 33) writer.commitAndClose() - assert(file.length() == writeMetrics.shuffleBytesWritten) + assert(file.length() == writeMetrics.bytesWritten) } test("verify write metrics on revert") { @@ -72,19 +72,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write - assert(writeMetrics.shuffleRecordsWritten === 1) + assert(writeMetrics.recordsWritten === 1) // Metrics don't update on every write - assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.bytesWritten == 0) // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() writer.write(Long.box(i), Long.box(i)) } - assert(writeMetrics.shuffleBytesWritten > 0) - assert(writeMetrics.shuffleRecordsWritten === 33) + assert(writeMetrics.bytesWritten > 0) + assert(writeMetrics.recordsWritten === 33) writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleBytesWritten == 0) - assert(writeMetrics.shuffleRecordsWritten == 0) + assert(writeMetrics.bytesWritten == 0) + assert(writeMetrics.recordsWritten == 0) } test("Reopening a closed block writer") { @@ -109,11 +109,11 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(i, i) } writer.commitAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - assert(writeMetrics.shuffleRecordsWritten === 1000) + val bytesWritten = writeMetrics.bytesWritten + assert(writeMetrics.recordsWritten === 1000) writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleRecordsWritten === 1000) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.recordsWritten === 1000) + assert(writeMetrics.bytesWritten === bytesWritten) } test("commitAndClose() should be idempotent") { @@ -125,13 +125,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(i, i) } writer.commitAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - val writeTime = writeMetrics.shuffleWriteTime - assert(writeMetrics.shuffleRecordsWritten === 1000) + val bytesWritten = writeMetrics.bytesWritten + val writeTime = writeMetrics.writeTime + assert(writeMetrics.recordsWritten === 1000) writer.commitAndClose() - assert(writeMetrics.shuffleRecordsWritten === 1000) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - assert(writeMetrics.shuffleWriteTime === writeTime) + assert(writeMetrics.recordsWritten === 1000) + assert(writeMetrics.bytesWritten === bytesWritten) + assert(writeMetrics.writeTime === writeTime) } test("revertPartialWritesAndClose() should be idempotent") { @@ -143,13 +143,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(i, i) } writer.revertPartialWritesAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - val writeTime = writeMetrics.shuffleWriteTime - assert(writeMetrics.shuffleRecordsWritten === 0) + val bytesWritten = writeMetrics.bytesWritten + val writeTime = writeMetrics.writeTime + assert(writeMetrics.recordsWritten === 0) writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleRecordsWritten === 0) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - assert(writeMetrics.shuffleWriteTime === writeTime) + assert(writeMetrics.recordsWritten === 0) + assert(writeMetrics.bytesWritten === bytesWritten) + assert(writeMetrics.writeTime === writeTime) } test("fileSegment() can only be called after commitAndClose() has been called") { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e02f5a1b20fe3..ee2d56a679395 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -277,7 +277,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with shuffleReadMetrics.incRemoteBytesRead(base + 1) shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) - shuffleWriteMetrics.incShuffleBytesWritten(base + 3) + shuffleWriteMetrics.incBytesWritten(base + 3) taskMetrics.setExecutorRunTime(base + 4) taskMetrics.incDiskBytesSpilled(base + 5) taskMetrics.incMemoryBytesSpilled(base + 6) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 068e8397c89bb..9dd400fc1c2de 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -227,7 +227,7 @@ class JsonProtocolSuite extends SparkFunSuite { .removeField { case (field, _) => field == "Shuffle Records Written" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) assert(newMetrics.shuffleReadMetrics.get.recordsRead == 0) - assert(newMetrics.shuffleWriteMetrics.get.shuffleRecordsWritten == 0) + assert(newMetrics.shuffleWriteMetrics.get.recordsWritten == 0) } test("OutputMetrics backward compatibility") { @@ -568,8 +568,8 @@ class JsonProtocolSuite extends SparkFunSuite { } private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) { - assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten) - assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime) + assert(metrics1.bytesWritten === metrics2.bytesWritten) + assert(metrics1.writeTime === metrics2.writeTime) } private def assertEquals(metrics1: InputMetrics, metrics2: InputMetrics) { @@ -794,9 +794,9 @@ class JsonProtocolSuite extends SparkFunSuite { t.outputMetrics = Some(outputMetrics) } else { val sw = new ShuffleWriteMetrics - sw.incShuffleBytesWritten(a + b + c) - sw.incShuffleWriteTime(b + c + d) - sw.setShuffleRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) + sw.incBytesWritten(a + b + c) + sw.incWriteTime(b + c + d) + sw.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) t.shuffleWriteMetrics = Some(sw) } // Make at most 6 blocks From 74ba84b64cab0bf3828033037267955aca296d3a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 18 Jan 2016 19:40:10 -0800 Subject: [PATCH 1613/2089] [HOT][BUILD] Changed the import order This PR is to fix the master's build break. The following tests failed due to the import order issues in the master. https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/49651/consoleFull https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/49652/consoleFull https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/49653/consoleFull Author: gatorsmile Closes #10823 from gatorsmile/importOrder. --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index df0f730499211..c4ddb6d76b2c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.{execution, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index ab81b702596af..149f34dbd748f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner @@ -24,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -import org.apache.spark.sql.{DataFrame, Row, SQLConf} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.localSeqToDataFrameHolder From 453dae56716bc254bf5022fddc9b8327c9b1a49f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 18 Jan 2016 21:42:07 -0800 Subject: [PATCH 1614/2089] [SPARK-12668][SQL] Providing aliases for CSV options to be similar to Pandas and R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-12668 Spark CSV datasource has been being merged (filed in [SPARK-12420](https://issues.apache.org/jira/browse/SPARK-12420)). This is a quicky PR that simply renames several CSV options to similar Pandas and R. - Alias for delimiter ­-> sep - charset -­> encoding Author: hyukjinkwon Closes #10800 from HyukjinKwon/SPARK-12668. --- .../execution/datasources/csv/CSVParameters.scala | 8 +++++--- .../execution/datasources/csv/CSVRelation.scala | 3 ++- .../sql/execution/datasources/csv/CSVSuite.scala | 15 +++++++++++++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala index ec16bdbd8bfb3..127c9728da2d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -21,7 +21,7 @@ import java.nio.charset.Charset import org.apache.spark.Logging -private[sql] case class CSVParameters(parameters: Map[String, String]) extends Logging { +private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging { private def getChar(paramName: String, default: Char): Char = { val paramValue = parameters.get(paramName) @@ -44,9 +44,11 @@ private[sql] case class CSVParameters(parameters: Map[String, String]) extends L } } - val delimiter = CSVTypeCast.toChar(parameters.getOrElse("delimiter", ",")) + val delimiter = CSVTypeCast.toChar( + parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode = parameters.getOrElse("mode", "PERMISSIVE") - val charset = parameters.getOrElse("charset", Charset.forName("UTF-8").name()) + val charset = parameters.getOrElse("encoding", + parameters.getOrElse("charset", Charset.forName("UTF-8").name())) val quote = getChar("quote", '\"') val escape = getChar("escape", '\\') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 9267479755e82..53818853ffb3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -58,9 +58,10 @@ private[csv] class CSVRelation( if (Charset.forName(params.charset) == Charset.forName("UTF-8")) { sqlContext.sparkContext.textFile(location) } else { + val charset = params.charset sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location) .mapPartitions { _.map { pair => - new String(pair._2.getBytes, 0, pair._2.getLength, params.charset) + new String(pair._2.getBytes, 0, pair._2.getLength, charset) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 8fdd31aa4334f..071b5ef56d58b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -122,7 +122,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(exception.getMessage.contains("1-9588-osi")) } - ignore("test different encoding") { + test("test different encoding") { // scalastyle:off sqlContext.sql( s""" @@ -135,6 +135,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(sqlContext.table("carsTable"), withHeader = true) } + test("test aliases sep and encoding for delimiter and charset") { + val cars = sqlContext + .read + .format("csv") + .option("header", "true") + .option("encoding", "iso-8859-1") + .option("sep", "þ") + .load(testFile(carsFile8859)) + + verifyCars(cars, withHeader = true) + } + test("DDL test with tab separated file") { sqlContext.sql( s""" @@ -337,5 +349,4 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } - } From c00744e60f77edb238aff1e30b450dca65451e91 Mon Sep 17 00:00:00 2001 From: proflin Date: Tue, 19 Jan 2016 00:15:43 -0800 Subject: [PATCH 1615/2089] [SQL][MINOR] Fix one little mismatched comment according to the codes in interface.scala Author: proflin Closes #10824 from proflin/master. --- .../main/scala/org/apache/spark/sql/sources/interfaces.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7800776fa1a5a..8911ad370aa7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -426,7 +426,7 @@ abstract class OutputWriter { * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file * systems, it's able to discover partitioning information from the paths of input directories, and * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] - * must override one of the three `buildScan` methods to implement the read path. + * must override one of the four `buildScan` methods to implement the read path. * * For the write path, it provides the ability to write to both non-partitioned and partitioned * tables. Directory layout of the partitioned tables is compatible with Hive. From d8c4b00a234514cc3a877e3daed5d1378a2637e8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 19 Jan 2016 09:34:49 +0000 Subject: [PATCH 1616/2089] [SPARK-7683][PYSPARK] Confusing behavior of fold function of RDD in pyspark Fix order of arguments that Pyspark RDD.fold passes to its op - should be (acc, obj) like other implementations. Obviously, this is a potentially breaking change, so can only happen for 2.x CC davies Author: Sean Owen Closes #10771 from srowen/SPARK-7683. --- python/pyspark/rdd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a019c05862549..c28594625457a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -861,7 +861,7 @@ def fold(self, zeroValue, op): def func(iterator): acc = zeroValue for obj in iterator: - acc = op(obj, acc) + acc = op(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided From ebd9ce0f1f55f7d2d3bd3b92c4b0a495c51ac6fd Mon Sep 17 00:00:00 2001 From: Wojciech Jurczyk Date: Tue, 19 Jan 2016 09:36:45 +0000 Subject: [PATCH 1617/2089] [MLLIB] Fix CholeskyDecomposition assertion's message Change assertion's message so it's consistent with the code. The old message says that the invoked method was lapack.dports, where in fact it was lapack.dppsv method. Author: Wojciech Jurczyk Closes #10818 from wjur/wjur/rename_error_message. --- .../org/apache/spark/mllib/linalg/CholeskyDecomposition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 0cd371e9cce34..ffdcddec11470 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -37,7 +37,7 @@ private[spark] object CholeskyDecomposition { val info = new intW(0) lapack.dppsv("U", k, 1, A, bx, k, info) val code = info.`val` - assert(code == 0, s"lapack.dpotrs returned $code.") + assert(code == 0, s"lapack.dppsv returned $code.") bx } From 0ddba6d88ff093a96b4931f71bd0a599afbbca78 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 19 Jan 2016 10:15:54 -0800 Subject: [PATCH 1618/2089] [SPARK-11944][PYSPARK][MLLIB] python mllib.clustering.bisecting k means From the coverage issues for 1.6 : Add Python API for mllib.clustering.BisectingKMeans. Author: Holden Karau Closes #10150 from holdenk/SPARK-11937-python-api-coverage-SPARK-11944-python-mllib.clustering.BisectingKMeans. --- .../mllib/api/python/PythonMLLibAPI.scala | 17 +++ python/pyspark/mllib/clustering.py | 136 +++++++++++++++++- python/pyspark/mllib/tests.py | 11 ++ 3 files changed, 159 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 061db56c74938..05f9a76d32671 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -119,6 +119,23 @@ private[python] class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib BisectingKMeans.run() + */ + def trainBisectingKMeans( + data: JavaRDD[Vector], + k: Int, + maxIterations: Int, + minDivisibleClusterSize: Double, + seed: Long): BisectingKMeansModel = { + new BisectingKMeans() + .setK(k) + .setMaxIterations(maxIterations) + .setMinDivisibleClusterSize(minDivisibleClusterSize) + .setSeed(seed) + .run(data) + } + /** * Java stub for Python mllib LinearRegressionWithSGD.train() */ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 580cb512d8025..4e9eb96fd9da1 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -38,12 +38,129 @@ from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable from pyspark.streaming import DStream -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', - 'PowerIterationClusteringModel', 'PowerIterationClustering', - 'StreamingKMeans', 'StreamingKMeansModel', +__all__ = ['BisectingKMeansModel', 'BisectingKMeans', 'KMeansModel', 'KMeans', + 'GaussianMixtureModel', 'GaussianMixture', 'PowerIterationClusteringModel', + 'PowerIterationClustering', 'StreamingKMeans', 'StreamingKMeansModel', 'LDA', 'LDAModel'] +@inherit_doc +class BisectingKMeansModel(JavaModelWrapper): + """ + .. note:: Experimental + + A clustering model derived from the bisecting k-means method. + + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) + >>> bskm = BisectingKMeans() + >>> model = bskm.train(sc.parallelize(data, 2), k=4) + >>> p = array([0.0, 0.0]) + >>> model.predict(p) + 0 + >>> model.k + 4 + >>> model.computeCost(p) + 0.0 + + .. versionadded:: 2.0.0 + """ + + def __init__(self, java_model): + super(BisectingKMeansModel, self).__init__(java_model) + self.centers = [c.toArray() for c in self.call("clusterCenters")] + + @property + @since('2.0.0') + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy + arrays.""" + return self.centers + + @property + @since('2.0.0') + def k(self): + """Get the number of clusters""" + return self.call("k") + + @since('2.0.0') + def predict(self, x): + """ + Find the cluster that each of the points belongs to in this + model. + + :param x: the point (or RDD of points) to determine + compute the clusters for. + """ + if isinstance(x, RDD): + vecs = x.map(_convert_to_vector) + return self.call("predict", vecs) + + x = _convert_to_vector(x) + return self.call("predict", x) + + @since('2.0.0') + def computeCost(self, x): + """ + Return the Bisecting K-means cost (sum of squared distances of + points to their nearest center) for this model on the given + data. If provided with an RDD of points returns the sum. + + :param point: the point or RDD of points to compute the cost(s). + """ + if isinstance(x, RDD): + vecs = x.map(_convert_to_vector) + return self.call("computeCost", vecs) + + return self.call("computeCost", _convert_to_vector(x)) + + +class BisectingKMeans(object): + """ + .. note:: Experimental + + A bisecting k-means algorithm based on the paper "A comparison of + document clustering techniques" by Steinbach, Karypis, and Kumar, + with modification to fit Spark. + The algorithm starts from a single cluster that contains all points. + Iteratively it finds divisible clusters on the bottom level and + bisects each of them using k-means, until there are `k` leaf + clusters in total or no leaf clusters are divisible. + The bisecting steps of clusters on the same level are grouped + together to increase parallelism. If bisecting all divisible + clusters on the bottom level would result more than `k` leaf + clusters, larger clusters get higher priority. + + Based on + U{http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf} + Steinbach, Karypis, and Kumar, A comparison of document clustering + techniques, KDD Workshop on Text Mining, 2000. + + .. versionadded:: 2.0.0 + """ + + @since('2.0.0') + def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1888008604): + """ + Runs the bisecting k-means algorithm return the model. + + :param rdd: input RDD to be trained on + :param k: The desired number of leaf clusters (default: 4). + The actual number could be smaller if there are no divisible + leaf clusters. + :param maxIterations: the max number of k-means iterations to + split clusters (default: 20) + :param minDivisibleClusterSize: the minimum number of points + (if >= 1.0) or the minimum proportion of points (if < 1.0) + of a divisible cluster (default: 1) + :param seed: a random seed (default: -1888008604 from + classOf[BisectingKMeans].getName.##) + """ + java_model = callMLlibFunc( + "trainBisectingKMeans", rdd.map(_convert_to_vector), + k, maxIterations, minDivisibleClusterSize, seed) + return BisectingKMeansModel(java_model) + + @inherit_doc class KMeansModel(Saveable, Loader): @@ -118,7 +235,13 @@ def k(self): @since('0.9.0') def predict(self, x): - """Find the cluster to which x belongs in this model.""" + """ + Find the cluster that each of the points belongs to in this + model. + + :param x: the point (or RDD of points) to determine + compute the clusters for. + """ best = 0 best_distance = float("inf") if isinstance(x, RDD): @@ -136,7 +259,10 @@ def predict(self, x): def computeCost(self, rdd): """ Return the K-means cost (sum of squared distances of points to - their nearest center) for this model on the given data. + their nearest center) for this model on the given + data. + + :param point: the RDD of points to compute the cost on. """ cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), [_convert_to_vector(c) for c in self.centers]) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3436a28b2974f..32ed48e10388e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -419,6 +419,17 @@ class ListTests(MLlibTestCase): as NumPy arrays. """ + def test_bisecting_kmeans(self): + from pyspark.mllib.clustering import BisectingKMeans + data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) + bskm = BisectingKMeans() + model = bskm.train(sc.parallelize(data, 2), k=4) + p = array([0.0, 0.0]) + rdd_p = self.sc.parallelize([p]) + self.assertEqual(model.predict(p), model.predict(rdd_p).first()) + self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) + self.assertEqual(model.k, len(model.clusterCenters)) + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ From e14817b528ccab4b4685b45a95e2325630b5fc53 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Jan 2016 10:44:51 -0800 Subject: [PATCH 1619/2089] [SPARK-12870][SQL] better format bucket id in file name for normal parquet file without bucket, it's file name ends with a jobUUID which maybe all numbers and mistakeny regarded as bucket id. This PR improves the format of bucket id in file name by using a different seperator, `_`, so that the regex is more robust. Author: Wenchen Fan Closes #10799 from cloud-fan/fix-bucket. --- .../spark/sql/execution/datasources/bucket.scala | 14 ++++++++++---- .../execution/datasources/json/JSONRelation.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 2 +- .../apache/spark/sql/hive/orc/OrcRelation.scala | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index c7ecd6125d860..3e0d484b74cfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -57,15 +57,21 @@ private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFact private[sql] object BucketingUtils { // The file name of bucketed data should have 3 parts: - // 1. some other information in the head of file name, ends with `-` - // 2. bucket id part, some numbers + // 1. some other information in the head of file name + // 2. bucket id part, some numbers, starts with "_" + // * The other-information part may use `-` as separator and may have numbers at the end, + // e.g. a normal parquet file without bucketing may have name: + // part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet, and we will mistakenly + // treat `431234567891` as bucket id. So here we pick `_` as separator. // 3. optional file extension part, in the tail of file name, starts with `.` // An example of bucketed parquet file name with bucket id 3: - // part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb-00003.gz.parquet - private val bucketedFileName = """.*-(\d+)(?:\..*)?$""".r + // part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r def getBucketId(fileName: String): Option[Int] = fileName match { case bucketedFileName(bucketId) => Some(bucketId.toInt) case other => None } + + def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 20c60b9c43e10..31c5620c9a80e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -193,7 +193,7 @@ private[json] class JsonOutputWriter( val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } }.getRecordWriter(context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 30ddec686c921..b460ec1d26047 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -90,7 +90,7 @@ private[sql] class ParquetOutputWriter( val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 40409169b095a..800823febab26 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -103,7 +103,7 @@ private[orc] class OrcOutputWriter( val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId - val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString.orc" new OrcOutputFormat().getRecordWriter( From b122c861cd72b580334a7532f0a52c0439552bdf Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 19 Jan 2016 10:58:51 -0800 Subject: [PATCH 1620/2089] [SPARK-12887] Do not expose var's in TaskMetrics This is a step in implementing SPARK-10620, which migrates TaskMetrics to accumulators. TaskMetrics has a bunch of var's, some are fully public, some are `private[spark]`. This is bad coding style that makes it easy to accidentally overwrite previously set metrics. This has happened a few times in the past and caused bugs that were difficult to debug. Instead, we should have get-or-create semantics, which are more readily understandable. This makes sense in the case of TaskMetrics because these are just aggregated metrics that we want to collect throughout the task, so it doesn't matter who's incrementing them. Parent PR: #10717 Author: Andrew Or Author: Josh Rosen Author: andrewor14 Closes #10815 from andrewor14/get-or-create-metrics. --- .../sort/BypassMergeSortShuffleWriter.java | 3 +- .../shuffle/sort/UnsafeShuffleWriter.java | 3 +- .../unsafe/sort/UnsafeExternalSorter.java | 4 +- .../scala/org/apache/spark/CacheManager.scala | 6 +- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/executor/TaskMetrics.scala | 180 ++++++++++++------ .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 3 +- .../apache/spark/rdd/PairRDDFunctions.scala | 46 +++-- .../shuffle/BlockStoreShuffleReader.scala | 4 +- .../shuffle/hash/HashShuffleWriter.scala | 3 +- .../shuffle/sort/SortShuffleWriter.scala | 6 +- .../apache/spark/storage/BlockManager.scala | 12 +- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../spark/storage/StorageStatusListener.scala | 2 +- .../apache/spark/ui/storage/StorageTab.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 165 ++++++++-------- .../util/collection/ExternalSorter.scala | 10 +- .../org/apache/spark/CacheManagerSuite.scala | 2 +- .../spark/executor/TaskMetricsSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../storage/StorageStatusListenerSuite.scala | 12 +- .../ui/jobs/JobProgressListenerSuite.scala | 23 +-- .../spark/ui/storage/StorageTabSuite.scala | 8 +- .../apache/spark/util/JsonProtocolSuite.scala | 17 +- .../datasources/SqlNewHadoopRDD.scala | 3 +- .../execution/UnsafeRowSerializerSuite.scala | 1 - 27 files changed, 281 insertions(+), 246 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index a06dc1ce91542..dc4f289ae7f89 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -114,8 +114,7 @@ public BypassMergeSortShuffleWriter( this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); - this.writeMetrics = new ShuffleWriteMetrics(); - taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); this.serializer = Serializer.getSerializer(dep.serializer()); this.shuffleBlockResolver = shuffleBlockResolver; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index c8cc7056975ec..d3d79a27ea1c6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -119,8 +119,7 @@ public UnsafeShuffleWriter( this.shuffleId = dep.shuffleId(); this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); this.partitioner = dep.partitioner(); - this.writeMetrics = new ShuffleWriteMetrics(); - taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 68dc0c6d415f6..a6edc1ad3f665 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -122,9 +122,7 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - // TODO: metrics tracking + integration with shuffle write metrics - // need to connect the write metrics to task metrics so we count the spill IO somewhere. - this.writeMetrics = new ShuffleWriteMetrics(); + this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index d92d8b0eef8a0..fa8e2b953835b 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -43,8 +43,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - val existingMetrics = context.taskMetrics - .getInputMetricsForReadMethod(blockResult.readMethod) + val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) existingMetrics.incBytesRead(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] @@ -66,11 +65,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { try { logInfo(s"Partition $key not found, computing it") val computedValues = rdd.computeOrReadCheckpoint(partition, context) - - // Otherwise, cache the values val cachedValues = putInBlockManager(key, computedValues, storageLevel) new InterruptibleIterator(context, cachedValues) - } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9b14184364246..75d7e34d60eb2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -425,7 +425,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => - metrics.updateShuffleReadMetrics() + metrics.mergeShuffleReadMetrics() metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) metrics.updateAccumulators() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ce1fcbff71208..32ef5a9b5606f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -102,14 +102,37 @@ class TaskMetrics extends Serializable { private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value - /** - * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read - * are stored here. - */ private var _inputMetrics: Option[InputMetrics] = None + /** + * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted + * data, defined only in tasks with input. + */ def inputMetrics: Option[InputMetrics] = _inputMetrics + /** + * Get or create a new [[InputMetrics]] associated with this task. + */ + private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = { + synchronized { + val metrics = _inputMetrics.getOrElse { + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + } + // If there already exists an InputMetric with the same read method, we can just return + // that one. Otherwise, if the read method is different from the one previously seen by + // this task, we return a new dummy one to avoid clobbering the values of the old metrics. + // In the future we should try to store input metrics from all different read methods at + // the same time (SPARK-5225). + if (metrics.readMethod == readMethod) { + metrics + } else { + new InputMetrics(readMethod) + } + } + } + /** * This should only be used when recreating TaskMetrics, not when updating input metrics in * executors @@ -118,18 +141,37 @@ class TaskMetrics extends Serializable { _inputMetrics = inputMetrics } + private var _outputMetrics: Option[OutputMetrics] = None + /** - * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much - * data was written are stored here. + * Metrics related to writing data externally (e.g. to a distributed filesystem), + * defined only in tasks with output. */ - var outputMetrics: Option[OutputMetrics] = None + def outputMetrics: Option[OutputMetrics] = _outputMetrics + + @deprecated("setting OutputMetrics is for internal use only", "2.0.0") + def outputMetrics_=(om: Option[OutputMetrics]): Unit = { + _outputMetrics = om + } /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. - * This includes read metrics aggregated over all the task's shuffle dependencies. + * Get or create a new [[OutputMetrics]] associated with this task. */ + private[spark] def registerOutputMetrics( + writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized { + _outputMetrics.getOrElse { + val metrics = new OutputMetrics(writeMethod) + _outputMetrics = Some(metrics) + metrics + } + } + private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None + /** + * Metrics related to shuffle read aggregated across all shuffle dependencies. + * This is defined only if there are shuffle dependencies in this task. + */ def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** @@ -141,66 +183,35 @@ class TaskMetrics extends Serializable { } /** - * ShuffleReadMetrics per dependency for collecting independently while task is in progress. - */ - @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] = - new ArrayBuffer[ShuffleReadMetrics]() - - /** - * If this task writes to shuffle output, metrics on the written shuffle data will be collected - * here - */ - var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None - - /** - * Storage statuses of any blocks that have been updated as a result of this task. - */ - var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None - - /** + * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency. + * * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization - * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each - * dependency, and merge these metrics before reporting them to the driver. This method returns - * a ShuffleReadMetrics for a dependency and registers it for merging later. - */ - private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized { - val readMetrics = new ShuffleReadMetrics() - depsShuffleReadMetrics += readMetrics - readMetrics - } + * issues from readers in different threads, in-progress tasks use a [[ShuffleReadMetrics]] for + * each dependency and merge these metrics before reporting them to the driver. + */ + @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[ShuffleReadMetrics] /** - * Returns the input metrics object that the task should use. Currently, if - * there exists an input metric with the same readMethod, we return that one - * so the caller can accumulate bytes read. If the readMethod is different - * than previously seen by this task, we return a new InputMetric but don't - * record it. + * Create a temporary [[ShuffleReadMetrics]] for a particular shuffle dependency. * - * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, - * we can store all the different inputMetrics (one per readMethod). + * All usages are expected to be followed by a call to [[mergeShuffleReadMetrics]], which + * merges the temporary values synchronously. Otherwise, all temporary data collected will + * be lost. */ - private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { - synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) - } - } + private[spark] def registerTempShuffleReadMetrics(): ShuffleReadMetrics = synchronized { + val readMetrics = new ShuffleReadMetrics + tempShuffleReadMetrics += readMetrics + readMetrics } /** - * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. + * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`. + * This is expected to be called on executor heartbeat and at the end of a task. */ - private[spark] def updateShuffleReadMetrics(): Unit = synchronized { - if (!depsShuffleReadMetrics.isEmpty) { - val merged = new ShuffleReadMetrics() - for (depMetrics <- depsShuffleReadMetrics) { + private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { + if (tempShuffleReadMetrics.nonEmpty) { + val merged = new ShuffleReadMetrics + for (depMetrics <- tempShuffleReadMetrics) { merged.incFetchWaitTime(depMetrics.fetchWaitTime) merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) @@ -212,6 +223,55 @@ class TaskMetrics extends Serializable { } } + private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None + + /** + * Metrics related to shuffle write, defined only in shuffle map stages. + */ + def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics + + @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0") + def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = { + _shuffleWriteMetrics = swm + } + + /** + * Get or create a new [[ShuffleWriteMetrics]] associated with this task. + */ + private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized { + _shuffleWriteMetrics.getOrElse { + val metrics = new ShuffleWriteMetrics + _shuffleWriteMetrics = Some(metrics) + metrics + } + } + + private var _updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = + Seq.empty[(BlockId, BlockStatus)] + + /** + * Storage statuses of any blocks that have been updated as a result of this task. + */ + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses + + @deprecated("setting updated blocks is for internal use only", "2.0.0") + def updatedBlocks_=(ub: Option[Seq[(BlockId, BlockStatus)]]): Unit = { + _updatedBlockStatuses = ub.getOrElse(Seq.empty[(BlockId, BlockStatus)]) + } + + private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = { + _updatedBlockStatuses ++= v + } + + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = { + _updatedBlockStatuses = v + } + + @deprecated("use updatedBlockStatuses instead", "2.0.0") + def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { + if (_updatedBlockStatuses.nonEmpty) Some(_updatedBlockStatuses) else None + } + private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index a7a6e0b8a94f6..a79ab86d49227 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -212,7 +212,7 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) // Sets the thread local variable for the file's name split.inputSplit.value match { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7a1197830443f..5cc9c81cc6749 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -129,8 +129,7 @@ class NewHadoopRDD[K, V]( logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 16a856f594e97..33f2f0b44f773 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1092,7 +1092,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) + val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = + initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") @@ -1103,15 +1104,17 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.write(pair._1, pair._2) // Update bytes written metric every few records - maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) - bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } - outputMetrics.setRecordsWritten(recordsWritten) + outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } 1 } : Int @@ -1177,7 +1180,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) + val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = + initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1189,35 +1193,43 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) // Update bytes written metric every few records - maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } } { writer.close() } writer.commit() - bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } - outputMetrics.setRecordsWritten(recordsWritten) + outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } } self.context.runJob(self, writeToFile) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + // return type: (output metrics, bytes written callback), defined only if the latter is defined + private def initHadoopOutputMetrics( + context: TaskContext): Option[(OutputMetrics, () => Long)] = { val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) - if (bytesWrittenCallback.isDefined) { - context.taskMetrics.outputMetrics = Some(outputMetrics) + bytesWrittenCallback.map { b => + (context.taskMetrics().registerOutputMetrics(DataWriteMethod.Hadoop), b) } - (outputMetrics, bytesWrittenCallback) } - private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], - outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { + private def maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], + recordsWritten: Long): Unit = { if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } - outputMetrics.setRecordsWritten(recordsWritten) + outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b0abda4a81b8d..a57e5b0bfb865 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -65,13 +65,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( recordIter.map(record => { readMetrics.incRecordsRead(1) record }), - context.taskMetrics().updateShuffleReadMetrics()) + context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 412bf70000da7..28bcced901a70 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -42,8 +42,7 @@ private[spark] class HashShuffleWriter[K, V]( // we don't try deleting files, etc twice. private var stopping = false - private val writeMetrics = new ShuffleWriteMetrics() - metrics.shuffleWriteMetrics = Some(writeMetrics) + private val writeMetrics = metrics.registerShuffleWriteMetrics() private val blockManager = SparkEnv.get.blockManager private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 5c5a5f5a4cb6a..7eb3d9603736c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -45,8 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null - private val writeMetrics = new ShuffleWriteMetrics() - context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) + private val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics() /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { @@ -93,8 +92,7 @@ private[spark] class SortShuffleWriter[K, V, C]( if (sorter != null) { val startTime = System.nanoTime() sorter.stop() - context.taskMetrics.shuffleWriteMetrics.foreach( - _.incWriteTime(System.nanoTime - startTime)) + writeMetrics.incWriteTime(System.nanoTime - startTime) sorter = null } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0a8e88df224a..77fd03a6bcfc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -800,10 +800,8 @@ private[spark] class BlockManager( if (tellMaster) { reportBlockStatus(blockId, putBlockInfo, putBlockStatus) } - Option(TaskContext.get()).foreach { taskContext => - val metrics = taskContext.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, putBlockStatus))) + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus))) } } } finally { @@ -1046,10 +1044,8 @@ private[spark] class BlockManager( blockInfo.remove(blockId) } if (blockIsUpdated) { - Option(TaskContext.get()).foreach { taskContext => - val metrics = taskContext.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, status))) + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status))) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 037bec1d9c33b..c6065df64ae03 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -101,7 +101,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() + private[this] val shuffleMetrics = context.taskMetrics().registerTempShuffleReadMetrics() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index ec711480ebf30..d98aae8ff0c68 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -63,7 +63,7 @@ class StorageStatusListener extends SparkListener { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { - val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + val updatedBlocks = metrics.updatedBlockStatuses if (updatedBlocks.length > 0) { updateStorageStatus(info.executorId, updatedBlocks) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 2d9b885c684b2..f1e28b4e1e9c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -63,8 +63,8 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc */ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val metrics = taskEnd.taskMetrics - if (metrics != null && metrics.updatedBlocks.isDefined) { - updateRDDInfo(metrics.updatedBlocks.get) + if (metrics != null && metrics.updatedBlockStatuses.nonEmpty) { + updateRDDInfo(metrics.updatedBlockStatuses) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b88221a249eb8..efa22b99936af 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -292,21 +292,38 @@ private[spark] object JsonProtocol { } def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { - val shuffleReadMetrics = - taskMetrics.shuffleReadMetrics.map(shuffleReadMetricsToJson).getOrElse(JNothing) - val shuffleWriteMetrics = - taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) - val inputMetrics = - taskMetrics.inputMetrics.map(inputMetricsToJson).getOrElse(JNothing) - val outputMetrics = - taskMetrics.outputMetrics.map(outputMetricsToJson).getOrElse(JNothing) - val updatedBlocks = - taskMetrics.updatedBlocks.map { blocks => - JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) - }) + val shuffleReadMetrics: JValue = + taskMetrics.shuffleReadMetrics.map { rm => + ("Remote Blocks Fetched" -> rm.remoteBlocksFetched) ~ + ("Local Blocks Fetched" -> rm.localBlocksFetched) ~ + ("Fetch Wait Time" -> rm.fetchWaitTime) ~ + ("Remote Bytes Read" -> rm.remoteBytesRead) ~ + ("Local Bytes Read" -> rm.localBytesRead) ~ + ("Total Records Read" -> rm.recordsRead) + }.getOrElse(JNothing) + val shuffleWriteMetrics: JValue = + taskMetrics.shuffleWriteMetrics.map { wm => + ("Shuffle Bytes Written" -> wm.shuffleBytesWritten) ~ + ("Shuffle Write Time" -> wm.shuffleWriteTime) ~ + ("Shuffle Records Written" -> wm.shuffleRecordsWritten) + }.getOrElse(JNothing) + val inputMetrics: JValue = + taskMetrics.inputMetrics.map { im => + ("Data Read Method" -> im.readMethod.toString) ~ + ("Bytes Read" -> im.bytesRead) ~ + ("Records Read" -> im.recordsRead) + }.getOrElse(JNothing) + val outputMetrics: JValue = + taskMetrics.outputMetrics.map { om => + ("Data Write Method" -> om.writeMethod.toString) ~ + ("Bytes Written" -> om.bytesWritten) ~ + ("Records Written" -> om.recordsWritten) }.getOrElse(JNothing) + val updatedBlocks = + JArray(taskMetrics.updatedBlockStatuses.toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) + }) ("Host Name" -> taskMetrics.hostname) ~ ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ ("Executor Run Time" -> taskMetrics.executorRunTime) ~ @@ -322,34 +339,6 @@ private[spark] object JsonProtocol { ("Updated Blocks" -> updatedBlocks) } - def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = { - ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ - ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ - ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ - ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~ - ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~ - ("Total Records Read" -> shuffleReadMetrics.recordsRead) - } - - // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. - def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = { - ("Shuffle Bytes Written" -> shuffleWriteMetrics.bytesWritten) ~ - ("Shuffle Write Time" -> shuffleWriteMetrics.writeTime) ~ - ("Shuffle Records Written" -> shuffleWriteMetrics.recordsWritten) - } - - def inputMetricsToJson(inputMetrics: InputMetrics): JValue = { - ("Data Read Method" -> inputMetrics.readMethod.toString) ~ - ("Bytes Read" -> inputMetrics.bytesRead) ~ - ("Records Read" -> inputMetrics.recordsRead) - } - - def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = { - ("Data Write Method" -> outputMetrics.writeMethod.toString) ~ - ("Bytes Written" -> outputMetrics.bytesWritten) ~ - ("Records Written" -> outputMetrics.recordsWritten) - } - def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { val reason = Utils.getFormattedClassName(taskEndReason) val json: JObject = taskEndReason match { @@ -721,58 +710,54 @@ private[spark] object JsonProtocol { metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long]) metrics.incMemoryBytesSpilled((json \ "Memory Bytes Spilled").extract[Long]) metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) - metrics.setShuffleReadMetrics( - Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) - metrics.shuffleWriteMetrics = - Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) - metrics.setInputMetrics( - Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson)) - metrics.outputMetrics = - Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson) - metrics.updatedBlocks = - Utils.jsonOption(json \ "Updated Blocks").map { value => - value.extract[List[JValue]].map { block => - val id = BlockId((block \ "Block ID").extract[String]) - val status = blockStatusFromJson(block \ "Status") - (id, status) - } - } - metrics - } - def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = { - val metrics = new ShuffleReadMetrics - metrics.incRemoteBlocksFetched((json \ "Remote Blocks Fetched").extract[Int]) - metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int]) - metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long]) - metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long]) - metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0)) - metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0)) - metrics - } + // Shuffle read metrics + Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + val readMetrics = metrics.registerTempShuffleReadMetrics() + readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) + readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) + readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) + readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) + readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L)) + metrics.mergeShuffleReadMetrics() + } - def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = { - val metrics = new ShuffleWriteMetrics - metrics.incBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) - metrics.incWriteTime((json \ "Shuffle Write Time").extract[Long]) - metrics.setRecordsWritten((json \ "Shuffle Records Written") - .extractOpt[Long].getOrElse(0)) - metrics - } + // Shuffle write metrics + // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. + Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + val writeMetrics = metrics.registerShuffleWriteMetrics() + writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) + writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written") + .extractOpt[Long].getOrElse(0L)) + writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) + } - def inputMetricsFromJson(json: JValue): InputMetrics = { - val metrics = new InputMetrics( - DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.incBytesRead((json \ "Bytes Read").extract[Long]) - metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0)) - metrics - } + // Output metrics + Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + val writeMethod = DataWriteMethod.withName((outJson \ "Data Write Method").extract[String]) + val outputMetrics = metrics.registerOutputMetrics(writeMethod) + outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) + outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L)) + } + + // Input metrics + Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) + val inputMetrics = metrics.registerInputMetrics(readMethod) + inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) + inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + } + + // Updated blocks + Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => + val id = BlockId((blockJson \ "Block ID").extract[String]) + val status = blockStatusFromJson(blockJson \ "Status") + (id, status) + }) + } - def outputMetricsFromJson(json: JValue): OutputMetrics = { - val metrics = new OutputMetrics( - DataWriteMethod.withName((json \ "Data Write Method").extract[String])) - metrics.setBytesWritten((json \ "Bytes Written").extract[Long]) - metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0)) metrics } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4c7416e00b004..df9e0502e7361 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -644,6 +644,8 @@ private[spark] class ExternalSorter[K, V, C]( blockId: BlockId, outputFile: File): Array[Long] = { + val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics() + // Track location of each range in the output file val lengths = new Array[Long](numPartitions) @@ -652,8 +654,8 @@ private[spark] class ExternalSorter[K, V, C]( val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, - context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter( + blockId, outputFile, serInstance, fileBufferSize, writeMetrics) val partitionId = it.nextPartition() while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(writer) @@ -666,8 +668,8 @@ private[spark] class ExternalSorter[K, V, C]( // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, - context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter( + blockId, outputFile, serInstance, fileBufferSize, writeMetrics) for (elem <- elements) { writer.write(elem._1, elem._2) } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 3865c201bf893..48a0282b30cf0 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -88,7 +88,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before try { TaskContext.setTaskContext(context) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) - assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) + assert(context.taskMetrics.updatedBlockStatuses.size === 2) } finally { TaskContext.unset() } diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 8275fd87764cd..e5ec2aa1be355 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite class TaskMetricsSuite extends SparkFunSuite { test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { val taskMetrics = new TaskMetrics() - taskMetrics.updateShuffleReadMetrics() + taskMetrics.mergeShuffleReadMetrics() assert(taskMetrics.shuffleReadMetrics.isEmpty) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6e6cf6385f919..e1b2c9633edca 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -855,7 +855,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } finally { TaskContext.unset() } - context.taskMetrics.updatedBlocks.getOrElse(Seq.empty) + context.taskMetrics.updatedBlockStatuses } // 1 updated block (i.e. list1) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 355d80d06898b..9de434166bba3 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -85,8 +85,8 @@ class StorageStatusListenerSuite extends SparkFunSuite { val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L)) val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L)) val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L)) - taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) - taskMetrics2.updatedBlocks = Some(Seq(block3)) + taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2)) + taskMetrics2.setUpdatedBlockStatuses(Seq(block3)) // Task end with new blocks assert(listener.executorIdToStorageStatus("big").numBlocks === 0) @@ -108,8 +108,8 @@ class StorageStatusListenerSuite extends SparkFunSuite { val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L)) val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L)) val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L)) - taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3)) - taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3)) + taskMetrics1.setUpdatedBlockStatuses(Seq(droppedBlock1, droppedBlock3)) + taskMetrics2.setUpdatedBlockStatuses(Seq(droppedBlock2, droppedBlock3)) listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) @@ -133,8 +133,8 @@ class StorageStatusListenerSuite extends SparkFunSuite { val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L)) val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L)) val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L)) - taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) - taskMetrics2.updatedBlocks = Some(Seq(block3)) + taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2)) + taskMetrics2.setUpdatedBlockStatuses(Seq(block3)) listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 3) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index ee2d56a679395..607617cbe91ca 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -184,12 +184,12 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val conf = new SparkConf() val listener = new JobProgressListener(conf) val taskMetrics = new TaskMetrics() - val shuffleReadMetrics = new ShuffleReadMetrics() + val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() assert(listener.stageIdToData.size === 0) // finish this task, should get updated shuffleRead shuffleReadMetrics.incRemoteBytesRead(1000) - taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) + taskMetrics.mergeShuffleReadMetrics() var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) @@ -270,22 +270,19 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with def makeTaskMetrics(base: Int): TaskMetrics = { val taskMetrics = new TaskMetrics() - val shuffleReadMetrics = new ShuffleReadMetrics() - val shuffleWriteMetrics = new ShuffleWriteMetrics() - taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) - taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + taskMetrics.setExecutorRunTime(base + 4) + taskMetrics.incDiskBytesSpilled(base + 5) + taskMetrics.incMemoryBytesSpilled(base + 6) + val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() shuffleReadMetrics.incRemoteBytesRead(base + 1) shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) + taskMetrics.mergeShuffleReadMetrics() + val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() shuffleWriteMetrics.incBytesWritten(base + 3) - taskMetrics.setExecutorRunTime(base + 4) - taskMetrics.incDiskBytesSpilled(base + 5) - taskMetrics.incMemoryBytesSpilled(base + 6) - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - taskMetrics.setInputMetrics(Some(inputMetrics)) + val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) inputMetrics.incBytesRead(base + 7) - val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) - taskMetrics.outputMetrics = Some(outputMetrics) + val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) outputMetrics.setBytesWritten(base + 8) taskMetrics } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 5ac922c2172ce..d1dbf7c1558b2 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -127,7 +127,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { // Task end with a few new persisted blocks, some from the same RDD val metrics1 = new TaskMetrics - metrics1.updatedBlocks = Some(Seq( + metrics1.setUpdatedBlockStatuses(Seq( (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L)), (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L)), (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L)) @@ -146,7 +146,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { // Task end with a few dropped blocks val metrics2 = new TaskMetrics - metrics2.updatedBlocks = Some(Seq( + metrics2.setUpdatedBlockStatuses(Seq( (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L)), (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L)), (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L)), // doesn't actually exist @@ -173,8 +173,8 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { val taskMetrics1 = new TaskMetrics val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L)) val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L)) - taskMetrics0.updatedBlocks = Some(Seq(block0)) - taskMetrics1.updatedBlocks = Some(Seq(block1)) + taskMetrics0.setUpdatedBlockStatuses(Seq(block0)) + taskMetrics1.setUpdatedBlockStatuses(Seq(block1)) bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener.rddInfoList.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 9dd400fc1c2de..e5ca2de4ad537 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -557,7 +557,7 @@ class JsonProtocolSuite extends SparkFunSuite { metrics1.shuffleWriteMetrics, metrics2.shuffleWriteMetrics, assertShuffleWriteEquals) assertOptionEquals( metrics1.inputMetrics, metrics2.inputMetrics, assertInputMetricsEquals) - assertOptionEquals(metrics1.updatedBlocks, metrics2.updatedBlocks, assertBlocksEquals) + assertBlocksEquals(metrics1.updatedBlockStatuses, metrics2.updatedBlockStatuses) } private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) { @@ -773,34 +773,31 @@ class JsonProtocolSuite extends SparkFunSuite { t.incMemoryBytesSpilled(a + c) if (hasHadoopInput) { - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) inputMetrics.incBytesRead(d + e + f) inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) - t.setInputMetrics(Some(inputMetrics)) } else { - val sr = new ShuffleReadMetrics + val sr = t.registerTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1) sr.incLocalBytesRead(a + f) - t.setShuffleReadMetrics(Some(sr)) + t.mergeShuffleReadMetrics() } if (hasOutput) { - val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) + val outputMetrics = t.registerOutputMetrics(DataWriteMethod.Hadoop) outputMetrics.setBytesWritten(a + b + c) outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1) - t.outputMetrics = Some(outputMetrics) } else { - val sw = new ShuffleWriteMetrics + val sw = t.registerShuffleWriteMetrics() sw.incBytesWritten(a + b + c) sw.incWriteTime(b + c + d) sw.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) - t.shuffleWriteMetrics = Some(sw) } // Make at most 6 blocks - t.updatedBlocks = Some((1 to (e % 5 + 1)).map { i => + t.setUpdatedBlockStatuses((1 to (e % 5 + 1)).map { i => (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i)) }.toSeq) t diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index d45d2db62f3a9..8222b84d33e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -126,8 +126,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf(isDriverSide = false) - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 9f09eb4429c12..7438e11ef7176 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -127,7 +127,6 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { assert(sorter.numSpills > 0) // Merging spilled files should not throw assertion error - taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } { // Clean up From 2388de51912efccaceeb663ac56fc500a79d2ceb Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 19 Jan 2016 11:08:52 -0800 Subject: [PATCH 1621/2089] [SPARK-12804][ML] Fix LogisticRegression with FitIntercept on all same label training data CC jkbradley mengxr dbtsai Author: Feynman Liang Closes #10743 from feynmanliang/SPARK-12804. --- .../classification/LogisticRegression.scala | 200 +++++++++--------- .../LogisticRegressionSuite.scala | 43 ++++ 2 files changed, 148 insertions(+), 95 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 486043e8d9741..dad8dfc84ec15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -276,113 +276,123 @@ class LogisticRegression @Since("1.2.0") ( val numClasses = histogram.length val numFeatures = summarizer.mean.size - if (numInvalid != 0) { - val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + - s"Found $numInvalid invalid labels." - logError(msg) - throw new SparkException(msg) - } - - if (numClasses > 2) { - val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " + - s"binary classification. Found $numClasses in the input dataset." - logError(msg) - throw new SparkException(msg) - } + val (coefficients, intercept, objectiveHistory) = { + if (numInvalid != 0) { + val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + + s"Found $numInvalid invalid labels." + logError(msg) + throw new SparkException(msg) + } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + if (numClasses > 2) { + val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " + + s"binary classification. Found $numClasses in the input dataset." + logError(msg) + throw new SparkException(msg) + } else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) { + logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " + + s"zeros and the intercept will be positive infinity; as a result, " + + s"training is not needed.") + (Vectors.sparse(numFeatures, Seq()), Double.PositiveInfinity, Array.empty[Double]) + } else if ($(fitIntercept) && numClasses == 1) { + logWarning(s"All labels are zero and fitIntercept=true, so the coefficients will be " + + s"zeros and the intercept will be negative infinity; as a result, " + + s"training is not needed.") + (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double]) + } else { + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) - val regParamL1 = $(elasticNetParam) * $(regParam) - val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) + val regParamL1 = $(elasticNetParam) * $(regParam) + val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) - val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization), - featuresStd, featuresMean, regParamL2) + val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), + $(standardization), featuresStd, featuresMean, regParamL2) - val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - } else { - def regParamL1Fun = (index: Int) => { - // Remove the L1 penalization on the intercept - if (index == numFeatures) { - 0.0 + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { - if ($(standardization)) { - regParamL1 - } else { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. - if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0 + def regParamL1Fun = (index: Int) => { + // Remove the L1 penalization on the intercept + if (index == numFeatures) { + 0.0 + } else { + if ($(standardization)) { + regParamL1 + } else { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0 + } + } } + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - } - new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) - } - - val initialCoefficientsWithIntercept = - Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) - - if ($(fitIntercept)) { - /* - For binary logistic regression, when we initialize the coefficients as zeros, - it will converge faster if we initialize the intercept such that - it follows the distribution of the labels. - - {{{ - P(0) = 1 / (1 + \exp(b)), and - P(1) = \exp(b) / (1 + \exp(b)) - }}}, hence - {{{ - b = \log{P(1) / P(0)} = \log{count_1 / count_0} - }}} - */ - initialCoefficientsWithIntercept.toArray(numFeatures) - = math.log(histogram(1) / histogram(0)) - } - val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.toBreeze.toDenseVector) + val initialCoefficientsWithIntercept = + Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) + + if ($(fitIntercept)) { + /* + For binary logistic regression, when we initialize the coefficients as zeros, + it will converge faster if we initialize the intercept such that + it follows the distribution of the labels. + + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} + */ + initialCoefficientsWithIntercept.toArray(numFeatures) = math.log( + histogram(1) / histogram(0)) + } - val (coefficients, intercept, objectiveHistory) = { - /* - Note that in Logistic Regression, the objective history (loss + regularization) - is log-likelihood which is invariance under feature standardization. As a result, - the objective history from optimizer is the same as the one in the original space. - */ - val arrayBuilder = mutable.ArrayBuilder.make[Double] - var state: optimizer.State = null - while (states.hasNext) { - state = states.next() - arrayBuilder += state.adjustedValue - } + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialCoefficientsWithIntercept.toBreeze.toDenseVector) + + /* + Note that in Logistic Regression, the objective history (loss + regularization) + is log-likelihood which is invariance under feature standardization. As a result, + the objective history from optimizer is the same as the one in the original space. + */ + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } - if (state == null) { - val msg = s"${optimizer.getClass.getName} failed." - logError(msg) - throw new SparkException(msg) - } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + logError(msg) + throw new SparkException(msg) + } - /* - The coefficients are trained in the scaled space; we're converting them back to - the original space. - Note that the intercept in scaled space and original space is the same; - as a result, no scaling is needed. - */ - val rawCoefficients = state.x.toArray.clone() - var i = 0 - while (i < numFeatures) { - rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } - i += 1 - } + /* + The coefficients are trained in the scaled space; we're converting them back to + the original space. + Note that the intercept in scaled space and original space is the same; + as a result, no scaling is needed. + */ + val rawCoefficients = state.x.toArray.clone() + var i = 0 + while (i < numFeatures) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } - if ($(fitIntercept)) { - (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, - arrayBuilder.result()) - } else { - (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) + if ($(fitIntercept)) { + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) + } else { + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index ff0d0ff771042..972c0868a454a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -883,6 +884,48 @@ class LogisticRegressionSuite assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) } + test("logistic regression with all labels the same") { + val sameLabels = dataset + .withColumn("zeroLabel", lit(0.0)) + .withColumn("oneLabel", lit(1.0)) + + // fitIntercept=true + val lrIntercept = new LogisticRegression() + .setFitIntercept(true) + .setMaxIter(3) + + val allZeroInterceptModel = lrIntercept + .setLabelCol("zeroLabel") + .fit(sameLabels) + assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) + assert(allZeroInterceptModel.intercept === Double.NegativeInfinity) + assert(allZeroInterceptModel.summary.totalIterations === 0) + + val allOneInterceptModel = lrIntercept + .setLabelCol("oneLabel") + .fit(sameLabels) + assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) + assert(allOneInterceptModel.intercept === Double.PositiveInfinity) + assert(allOneInterceptModel.summary.totalIterations === 0) + + // fitIntercept=false + val lrNoIntercept = new LogisticRegression() + .setFitIntercept(false) + .setMaxIter(3) + + val allZeroNoInterceptModel = lrNoIntercept + .setLabelCol("zeroLabel") + .fit(sameLabels) + assert(allZeroNoInterceptModel.intercept === 0.0) + assert(allZeroNoInterceptModel.summary.totalIterations > 0) + + val allOneNoInterceptModel = lrNoIntercept + .setLabelCol("oneLabel") + .fit(sameLabels) + assert(allOneNoInterceptModel.intercept === 0.0) + assert(allOneNoInterceptModel.summary.totalIterations > 0) + } + test("read/write") { def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = { assert(model.intercept === model2.intercept) From b72e01e82148a908eb19bb3f526f9777bfe27dde Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 19 Jan 2016 11:35:58 -0800 Subject: [PATCH 1622/2089] [SPARK-12867][SQL] Nullability of Intersect can be stricter JIRA: https://issues.apache.org/jira/browse/SPARK-12867 When intersecting one nullable column with one non-nullable column, the result will not contain any null. Thus, we can make nullability of `intersect` stricter. liancheng Could you please check if the code changes are appropriate? Also added test cases to verify the results. Thanks! Author: gatorsmile Closes #10812 from gatorsmile/nullabilityIntersect. --- .../plans/logical/basicOperators.scala | 18 ++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 2a1b1b131d813..f4a3d85d2a8a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -91,11 +91,6 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = - left.output.zip(right.output).map { case (leftAttr, rightAttr) => - leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) - } - final override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && @@ -108,13 +103,24 @@ private[sql] object SetOperation { case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) + } + override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes Statistics(sizeInBytes = sizeInBytes) } } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) + } +} case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afc8df07fd9ab..bd11a387a1d5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -337,6 +337,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } + test("intersect - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(_.nullable == false)) + + val df1 = nonNullableInts.intersect(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(_.nullable == false)) + + val df2 = nullInts.intersect(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(_.nullable == false)) + + val df3 = nullInts.intersect(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable == true)) + + val df4 = nonNullableInts.intersect(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(_.nullable == false)) + } + test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) From 4dbd3161227a32736105cef624f9df21650a359c Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 19 Jan 2016 12:24:21 -0800 Subject: [PATCH 1623/2089] [SPARK-12560][SQL] SqlTestUtils.stripSparkFilter needs to copy utf8strings See https://issues.apache.org/jira/browse/SPARK-12560 This isn't causing any problems currently because the tests for string predicate pushdown are currently disabled. I ran into this while trying to turn them back on with a different version of parquet. Figure it was good to fix now in any case. Author: Imran Rashid Closes #10510 from squito/SPARK-12560. --- .../src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7df344edb4edd..5f73d71d4510a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -189,7 +189,7 @@ private[sql] trait SQLTestUtils .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() - .map(row => Row.fromSeq(row.toSeq(schema))) + .map(row => Row.fromSeq(row.copy().toSeq(schema))) sqlContext.createDataFrame(childRDD, schema) } From c78e2080e00a73159ab749691ad634fa6c0a2302 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 19 Jan 2016 12:31:03 -0800 Subject: [PATCH 1624/2089] [SPARK-12816][SQL] De-alias type when generating schemas Call `dealias` on local types to fix schema generation for abstract type members, such as ```scala type KeyValue = (Int, String) ``` Add simple test Author: Jakob Odersky Closes #10749 from jodersky/aliased-schema. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 5 ++++- .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 79f723cf9b8a0..643228d0eb27d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -642,7 +642,10 @@ trait ScalaReflection { * * @see SPARK-5281 */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe + def localTypeOf[T: TypeTag]: `Type` = { + val tag = implicitly[TypeTag[T]] + tag.in(mirror).tpe.normalize + } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index c2aace1ef238e..a32f5b70a0124 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -69,6 +69,10 @@ case class ComplexData( case class GenericData[A]( genericField: A) +object GenericData { + type IntData = GenericData[Int] +} + case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } @@ -186,6 +190,10 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } + test("type-aliased data") { + assert(schemaFor[GenericData[Int]] == schemaFor[GenericData.IntData]) + } + test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) From c6f971b4aeca7265ab374fa46c5c452461d9b6a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Lipt=C3=A1k?= Date: Tue, 19 Jan 2016 14:06:53 -0800 Subject: [PATCH 1625/2089] [SPARK-11295] Add packages to JUnit output for Python tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SPARK-11295 Add packages to JUnit output for Python tests This improves grouping/display of test case results. Author: Gábor Lipták Closes #9263 from gliptak/SPARK-11295. --- python/pyspark/ml/tests.py | 1 + python/pyspark/mllib/tests.py | 24 ++++++++++++++---------- python/pyspark/sql/tests.py | 1 + python/pyspark/streaming/tests.py | 1 + python/pyspark/tests.py | 1 + 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4eb17bfdcca90..9ea639dc4f960 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -394,6 +394,7 @@ def test_fit_maximize_metric(self): if __name__ == "__main__": + from pyspark.ml.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 32ed48e10388e..ea7d297cba2ae 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -77,21 +77,24 @@ pass ser = PickleSerializer() -sc = SparkContext('local[4]', "MLlib tests") class MLlibTestCase(unittest.TestCase): def setUp(self): - self.sc = sc + self.sc = SparkContext('local[4]', "MLlib tests") + + def tearDown(self): + self.sc.stop() class MLLibStreamingTestCase(unittest.TestCase): def setUp(self): - self.sc = sc + self.sc = SparkContext('local[4]', "MLlib tests") self.ssc = StreamingContext(self.sc, 1.0) def tearDown(self): self.ssc.stop(False) + self.sc.stop() @staticmethod def _eventually(condition, timeout=30.0, catch_assertions=False): @@ -1166,7 +1169,7 @@ def test_predictOn_model(self): clusterWeights=[1.0, 1.0, 1.0, 1.0]) predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] predict_stream = self.ssc.queueStream(predict_data) predict_val = stkm.predictOn(predict_stream) @@ -1197,7 +1200,7 @@ def test_trainOn_predictOn(self): # classification based in the initial model would have been 0 # proving that the model is updated. batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [sc.parallelize(batch) for batch in batches] + batches = [self.sc.parallelize(batch) for batch in batches] input_stream = self.ssc.queueStream(batches) predict_results = [] @@ -1230,7 +1233,7 @@ def test_dim(self): self.assertEqual(len(point.features), 3) linear_data = LinearDataGenerator.generateLinearRDD( - sc=sc, nexamples=6, nfeatures=2, eps=0.1, + sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, nParts=2, intercept=0.0).collect() self.assertEqual(len(linear_data), 6) for point in linear_data: @@ -1406,7 +1409,7 @@ def test_parameter_accuracy(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) input_stream = self.ssc.queueStream(batches) slr.trainOn(input_stream) @@ -1430,7 +1433,7 @@ def test_parameter_convergence(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) model_weights = [] input_stream = self.ssc.queueStream(batches) @@ -1463,7 +1466,7 @@ def test_prediction(self): 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], 100, 42 + i, 0.1) batches.append( - sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) input_stream = self.ssc.queueStream(batches) output_stream = slr.predictOnValues(input_stream) @@ -1494,7 +1497,7 @@ def test_train_prediction(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) predict_batches = [ b.map(lambda lp: (lp.label, lp.features)) for b in batches] @@ -1580,6 +1583,7 @@ def test_als_ratings_id_long_error(self): if __name__ == "__main__": + from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c03cb9338ae68..ae8620274dd20 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1259,6 +1259,7 @@ def test_collect_functions(self): if __name__ == "__main__": + from pyspark.sql.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 86b05d9fd2424..24b812615cbb4 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1635,6 +1635,7 @@ def search_kinesis_asl_assembly_jar(): are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' if __name__ == "__main__": + from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() mqtt_assembly_jar = search_mqtt_assembly_jar() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 5bd94476597ab..23720502a82c8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2008,6 +2008,7 @@ def test_statcounter_array(self): if __name__ == "__main__": + from pyspark.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: From efd7eed3222799d66d4fcb68785142dc570c8150 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 19 Jan 2016 14:28:00 -0800 Subject: [PATCH 1626/2089] [BUILD] Runner for spark packages This is a convenience method added to the SBT build for developers, though if people think its useful we could consider adding a official script that runs using the assembly instead of compiling on demand. It simply compiles spark (without requiring an assembly), and invokes Spark Submit to download / run the package. Example Usage: ``` $ build/sbt > sparkPackage com.databricks:spark-sql-perf_2.10:0.2.4 com.databricks.spark.sql.perf.RunBenchmark --help ``` Author: Michael Armbrust Closes #10834 from marmbrus/sparkPackageRunner. --- project/SparkBuild.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4c34c888cfd5e..06e561ae0d89b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -274,6 +274,11 @@ object SparkBuild extends PomBuild { * Usage: `build/sbt sparkShell` */ val sparkShell = taskKey[Unit]("start a spark-shell.") + val sparkPackage = inputKey[Unit]( + s""" + |Download and run a spark package. + |Usage `builds/sbt "sparkPackage [args] + """.stripMargin) val sparkSql = taskKey[Unit]("starts the spark sql CLI.") enable(Seq( @@ -287,6 +292,16 @@ object SparkBuild extends PomBuild { (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value }, + sparkPackage := { + import complete.DefaultParsers._ + val packages :: className :: otherArgs = spaceDelimited(" [args]").parsed.toList + val scalaRun = (runner in run).value + val classpath = (fullClasspath in Runtime).value + val args = Seq("--packages", packages, "--class", className, (Keys.`package` in Compile in "core").value.getCanonicalPath) ++ otherArgs + println(args) + scalaRun.run("org.apache.spark.deploy.SparkSubmit", classpath.map(_.data), args, streams.value.log) + }, + javaOptions in Compile += "-Dspark.master=local", sparkSql := { From 43f1d59e17d89d19b322d639c5069a3fc0c8e2ed Mon Sep 17 00:00:00 2001 From: scwf Date: Tue, 19 Jan 2016 14:49:55 -0800 Subject: [PATCH 1627/2089] [SPARK-2750][WEB UI] Add https support to the Web UI Author: scwf Author: Marcelo Vanzin Author: WangTaoTheTonic Author: w00228970 Closes #10238 from vanzin/SPARK-2750. --- .../scala/org/apache/spark/SSLOptions.scala | 50 ++++++- .../org/apache/spark/SecurityManager.scala | 16 +-- .../apache/spark/deploy/DeployMessage.scala | 3 +- .../spark/deploy/history/HistoryServer.scala | 5 +- .../apache/spark/deploy/master/Master.scala | 4 +- .../spark/deploy/master/WorkerInfo.scala | 7 +- .../spark/deploy/master/ui/MasterWebUI.scala | 4 +- .../deploy/mesos/ui/MesosClusterUI.scala | 2 +- .../apache/spark/deploy/worker/Worker.scala | 6 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 3 +- .../org/apache/spark/ui/JettyUtils.scala | 79 +++++++++- .../scala/org/apache/spark/ui/SparkUI.scala | 3 +- .../scala/org/apache/spark/ui/WebUI.scala | 5 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- core/src/test/resources/spark.keystore | Bin 0 -> 1383 bytes .../apache/spark/SecurityManagerSuite.scala | 22 +-- .../apache/spark/deploy/DeployTestUtils.scala | 2 +- .../spark/deploy/master/MasterSuite.scala | 5 +- .../master/PersistenceEngineSuite.scala | 7 +- .../scala/org/apache/spark/ui/UISuite.scala | 135 ++++++++++++++---- docs/configuration.md | 22 +++ docs/security.md | 49 ++++++- 22 files changed, 338 insertions(+), 93 deletions(-) create mode 100644 core/src/test/resources/spark.keystore diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 3b9c885bf97a7..261265f0b4c55 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -39,8 +39,11 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param keyStore a path to the key-store file * @param keyStorePassword a password to access the key-store file * @param keyPassword a password to access the private key in the key-store + * @param keyStoreType the type of the key-store + * @param needClientAuth set true if SSL needs client authentication * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file + * @param trustStoreType the type of the trust-store * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java * @param enabledAlgorithms a set of encryption algorithms that may be used */ @@ -49,8 +52,11 @@ private[spark] case class SSLOptions( keyStore: Option[File] = None, keyStorePassword: Option[String] = None, keyPassword: Option[String] = None, + keyStoreType: Option[String] = None, + needClientAuth: Boolean = false, trustStore: Option[File] = None, trustStorePassword: Option[String] = None, + trustStoreType: Option[String] = None, protocol: Option[String] = None, enabledAlgorithms: Set[String] = Set.empty) extends Logging { @@ -63,12 +69,18 @@ private[spark] case class SSLOptions( val sslContextFactory = new SslContextFactory() keyStore.foreach(file => sslContextFactory.setKeyStorePath(file.getAbsolutePath)) - trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath)) keyStorePassword.foreach(sslContextFactory.setKeyStorePassword) - trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) + keyStoreType.foreach(sslContextFactory.setKeyStoreType) + if (needClientAuth) { + trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath)) + trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) + trustStoreType.foreach(sslContextFactory.setTrustStoreType) + } protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) + if (supportedAlgorithms.nonEmpty) { + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) + } Some(sslContextFactory) } else { @@ -82,6 +94,13 @@ private[spark] case class SSLOptions( */ def createAkkaConfig: Option[Config] = { if (enabled) { + if (keyStoreType.isDefined) { + logWarning("Akka configuration does not support key store type."); + } + if (trustStoreType.isDefined) { + logWarning("Akka configuration does not support trust store type."); + } + Some(ConfigFactory.empty() .withValue("akka.remote.netty.tcp.security.key-store", ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse(""))) @@ -110,7 +129,9 @@ private[spark] case class SSLOptions( * The supportedAlgorithms set is a subset of the enabledAlgorithms that * are supported by the current Java security provider for this protocol. */ - private val supportedAlgorithms: Set[String] = { + private val supportedAlgorithms: Set[String] = if (enabledAlgorithms.isEmpty) { + Set() + } else { var context: SSLContext = null try { context = SSLContext.getInstance(protocol.orNull) @@ -133,7 +154,11 @@ private[spark] case class SSLOptions( logDebug(s"Discarding unsupported cipher $cipher") } - enabledAlgorithms & providerAlgorithms + val supported = enabledAlgorithms & providerAlgorithms + require(supported.nonEmpty || sys.env.contains("SPARK_TESTING"), + "SSLContext does not support any of the enabled algorithms: " + + enabledAlgorithms.mkString(",")) + supported } /** Returns a string representation of this SSLOptions with all the passwords masked. */ @@ -153,9 +178,12 @@ private[spark] object SSLOptions extends Logging { * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory * $ - `[ns].keyStorePassword` - a password to the key-store file * $ - `[ns].keyPassword` - a password to the private key + * $ - `[ns].keyStoreType` - the type of the key-store + * $ - `[ns].needClientAuth` - whether SSL needs client authentication * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current * directory * $ - `[ns].trustStorePassword` - a password to the trust-store file + * $ - `[ns].trustStoreType` - the type of trust-store * $ - `[ns].protocol` - a protocol name supported by a particular Java version * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers * @@ -183,12 +211,21 @@ private[spark] object SSLOptions extends Logging { val keyPassword = conf.getOption(s"$ns.keyPassword") .orElse(defaults.flatMap(_.keyPassword)) + val keyStoreType = conf.getOption(s"$ns.keyStoreType") + .orElse(defaults.flatMap(_.keyStoreType)) + + val needClientAuth = + conf.getBoolean(s"$ns.needClientAuth", defaultValue = defaults.exists(_.needClientAuth)) + val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") .orElse(defaults.flatMap(_.trustStorePassword)) + val trustStoreType = conf.getOption(s"$ns.trustStoreType") + .orElse(defaults.flatMap(_.trustStoreType)) + val protocol = conf.getOption(s"$ns.protocol") .orElse(defaults.flatMap(_.protocol)) @@ -202,8 +239,11 @@ private[spark] object SSLOptions extends Logging { keyStore, keyStorePassword, keyPassword, + keyStoreType, + needClientAuth, trustStore, trustStorePassword, + trustStoreType, protocol, enabledAlgorithms) } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 64e483e384772..c5aec05c03fce 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -244,14 +244,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) - // SSL configuration for different communication layers - they can override the default - // configuration at a specified namespace. The namespace *must* start with spark.ssl. - val fileServerSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.fs", Some(defaultSSLOptions)) - val akkaSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.akka", Some(defaultSSLOptions)) - - logDebug(s"SSLConfiguration for file server: $fileServerSSLOptions") - logDebug(s"SSLConfiguration for Akka: $akkaSSLOptions") - + // SSL configuration for the file server. This is used by Utils.setupSecureURLConnection(). + val fileServerSSLOptions = getSSLOptions("fs") val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { val trustStoreManagers = for (trustStore <- fileServerSSLOptions.trustStore) yield { @@ -292,6 +286,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) (None, None) } + def getSSLOptions(module: String): SSLOptions = { + val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) + logDebug(s"Created SSL options for $module: $opts") + opts + } + /** * Split a comma separated String, filter out any empty items, and return a Set of strings */ diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 3feb7cea593e0..3e78c7ae240f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -41,8 +41,7 @@ private[deploy] object DeployMessages { worker: RpcEndpointRef, cores: Int, memory: Int, - webUiPort: Int, - publicAddress: String) + workerWebUiUrl: String) extends DeployMessage { Utils.checkHost(host, "Required hostname") assert (port > 0) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 96007a06e3c54..1f13d7db348ec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -49,7 +49,8 @@ class HistoryServer( provider: ApplicationHistoryProvider, securityManager: SecurityManager, port: Int) - extends WebUI(securityManager, port, conf) with Logging with UIRoot { + extends WebUI(securityManager, securityManager.getSSLOptions("historyServer"), port, conf) + with Logging with UIRoot { // How many applications to retain private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) @@ -233,7 +234,7 @@ object HistoryServer extends Logging { val UI_PATH_PREFIX = "/history" - def main(argStrings: Array[String]) { + def main(argStrings: Array[String]): Unit = { Utils.initDaemon(log) new HistoryServerArguments(conf, argStrings) initSecurity() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 0deab8ddd5270..202a1b787c21b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -383,7 +383,7 @@ private[deploy] class Master( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -392,7 +392,7 @@ private[deploy] class Master( context.reply(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - workerRef, workerUiPort, publicAddress) + workerRef, workerWebUiUrl) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) context.reply(RegisteredWorker(self, masterWebUiUrl)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index f751966605206..4e20c10fd1427 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -29,8 +29,7 @@ private[spark] class WorkerInfo( val cores: Int, val memory: Int, val endpoint: RpcEndpointRef, - val webUiPort: Int, - val publicAddress: String) + val webUiAddress: String) extends Serializable { Utils.checkHost(host, "Expected hostname") @@ -98,10 +97,6 @@ private[spark] class WorkerInfo( coresUsed -= driver.desc.cores } - def webUiAddress : String = { - "http://" + this.publicAddress + ":" + this.webUiPort - } - def setState(state: WorkerState.Value): Unit = { this.state = state } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 750ef0a962550..d7543926f3850 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -32,8 +32,8 @@ class MasterWebUI( val master: Master, requestedPort: Int, customMasterPage: Option[MasterPage] = None) - extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging - with UIRoot { + extends WebUI(master.securityMgr, master.securityMgr.getSSLOptions("standalone"), + requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index da9740bb41f59..baad098a0cd1f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -31,7 +31,7 @@ private[spark] class MesosClusterUI( conf: SparkConf, dispatcherPublicAddress: String, val scheduler: MesosClusterScheduler) - extends WebUI(securityManager, port, conf) { + extends WebUI(securityManager, securityManager.getSSLOptions("mesos"), port, conf) { initialize() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 98e17da489741..179d3b9f20b1f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -100,6 +100,7 @@ private[deploy] class Worker( private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" + private var workerWebUiUrl: String = "" private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString private var registered = false private var connected = false @@ -184,6 +185,9 @@ private[deploy] class Worker( shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() + + val scheme = if (webUi.sslOptions.enabled) "https" else "http" + workerWebUiUrl = s"$scheme://$publicAddress:${webUi.boundPort}" registerWithMaster() metricsSystem.registerSource(workerSource) @@ -336,7 +340,7 @@ private[deploy] class Worker( private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = { masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + workerId, host, port, self, cores, memory, workerWebUiUrl)) .onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 1a0598e50dcf1..b45b6824949e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -34,7 +34,8 @@ class WorkerWebUI( val worker: Worker, val workDir: File, requestedPort: Int) - extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") + extends WebUI(worker.securityMgr, worker.securityMgr.getSSLOptions("standalone"), + requestedPort, worker.conf, name = "WorkerUI") with Logging { private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index b796a44fe01ac..bc143b7de399c 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -17,21 +17,24 @@ package org.apache.spark.ui -import java.net.{InetSocketAddress, URL} +import java.net.{URI, URL} import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} +import scala.collection.mutable.{ArrayBuffer, StringBuilder} import scala.language.implicitConversions import scala.xml.Node -import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.{Connector, Request, Server} import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.nio.SelectChannelConnector +import org.eclipse.jetty.server.ssl.SslSelectChannelConnector import org.eclipse.jetty.servlet._ import org.eclipse.jetty.util.thread.QueuedThreadPool import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SSLOptions} import org.apache.spark.util.Utils /** @@ -224,23 +227,51 @@ private[spark] object JettyUtils extends Logging { def startJettyServer( hostName: String, port: Int, + sslOptions: SSLOptions, handlers: Seq[ServletContextHandler], conf: SparkConf, serverName: String = ""): ServerInfo = { + val collection = new ContextHandlerCollection addFilters(handlers, conf) - val collection = new ContextHandlerCollection val gzipHandlers = handlers.map { h => val gzipHandler = new GzipHandler gzipHandler.setHandler(h) gzipHandler } - collection.setHandlers(gzipHandlers.toArray) // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(hostName, currentPort)) + val server = new Server + val connectors = new ArrayBuffer[Connector] + // Create a connector on port currentPort to listen for HTTP requests + val httpConnector = new SelectChannelConnector() + httpConnector.setPort(currentPort) + connectors += httpConnector + + sslOptions.createJettySslContextFactory().foreach { factory => + // If the new port wraps around, do not try a privileged port. + val securePort = + if (currentPort != 0) { + (currentPort + 400 - 1024) % (65536 - 1024) + 1024 + } else { + 0 + } + val scheme = "https" + // Create a connector on port securePort to listen for HTTPS requests + val connector = new SslSelectChannelConnector(factory) + connector.setPort(securePort) + connectors += connector + + // redirect the HTTP requests to HTTPS port + collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) + } + + gzipHandlers.foreach(collection.addHandler) + connectors.foreach(_.setHost(hostName)) + server.setConnectors(connectors.toArray) + val pool = new QueuedThreadPool pool.setDaemon(true) server.setThreadPool(pool) @@ -262,6 +293,42 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } + + private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { + val redirectHandler: ContextHandler = new ContextHandler + redirectHandler.setContextPath("/") + redirectHandler.setHandler(new AbstractHandler { + override def handle( + target: String, + baseRequest: Request, + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + if (baseRequest.isSecure) { + return + } + val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, + baseRequest.getRequestURI, baseRequest.getQueryString) + response.setContentLength(0) + response.encodeRedirectURL(httpsURI) + response.sendRedirect(httpsURI) + baseRequest.setHandled(true) + } + }) + redirectHandler + } + + // Create a new URI from the arguments, handling IPv6 host encoding and default ports. + private def createRedirectURI( + scheme: String, server: String, port: Int, path: String, query: String) = { + val redirectServer = if (server.contains(":") && !server.startsWith("[")) { + s"[${server}]" + } else { + server + } + val authority = s"$redirectServer:$port" + new URI(scheme, authority, path, query, null).toString + } + } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index e319937702f23..eb53aa8e23ae7 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -50,7 +50,8 @@ private[spark] class SparkUI private ( var appName: String, val basePath: String, val startTime: Long) - extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") + extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), + conf, basePath, "SparkUI") with Logging with UIRoot { diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 3925235984723..fe4949b9f6fee 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -26,7 +26,7 @@ import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SSLOptions} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils @@ -38,6 +38,7 @@ import org.apache.spark.util.Utils */ private[spark] abstract class WebUI( val securityManager: SecurityManager, + val sslOptions: SSLOptions, port: Int, conf: SparkConf, basePath: String = "", @@ -133,7 +134,7 @@ private[spark] abstract class WebUI( def bind() { assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) try { - serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf, name)) + serverInfo = Some(startJettyServer("0.0.0.0", port, sslOptions, handlers, conf, name)) logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index f2d93edd4fd2e..3f4ac9b2f18cd 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -86,7 +86,7 @@ private[spark] object AkkaUtils extends Logging { val secureCookie = if (isAuthOn) secretKey else "" logDebug(s"In createActorSystem, requireCookie is: $requireCookie") - val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig + val akkaSslConfig = securityManager.getSSLOptions("akka").createAkkaConfig .getOrElse(ConfigFactory.empty()) val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) diff --git a/core/src/test/resources/spark.keystore b/core/src/test/resources/spark.keystore new file mode 100644 index 0000000000000000000000000000000000000000..f30716b57b30298e78e85011d7595bb42d6a85d7 GIT binary patch literal 1383 zcmezO_TO6u1_mY|W&~sI#Dc`+jMU;q4hU zrO=-|ukKbAC9>v6Xj~6^dGz-;rEIy-hpu9VYZQKRFxVT-dox2Ha~PrWerl|?>*gSJ~<=YJoCs7 z*}f&=)m!s=l$QKIcjfZVg^^3XXEFT`vO68ixrQZv8(+};usDIh+i!Q@T=VR`^89pT z6UnxZD&GWt=NYjq{y2LTm;dYUELm3^{5SfoauvOH{%7jM zPd9ov=e?iU^JzzTZ}X+bIj3AJ&nEJhmq@bQG`idUM*r_}6^5YVRe!GJFIIo??*6>C z7vIbpE||~zQ!sDui>F(=K72@uKDoT)BuD=)FFucpI|Dxnt`Wa(wwCK?M(X@4_g0qv ztpD_f>7yp2OTnRV-yp{hjyJClXHNV+>q73A6Efx(7W|w4hRtoVLjNK5_p&_tOri@G z{Lh`aywmi#B-gd1RwMiDXR`Sds`^gv$v>tTp*Z=;ypD7=&V}2nyWfYn94Z5+Yt{%o zQv*w2N=^W#ngfr2=%p^1Top_ze^fw_@olmx$#AyB~F$QUX>2RAg% zNA@=`KQlM>G8i;=GBq|b>|Y()b*zay^1{UnwF`D`zQs^EAxidc%t7#mf$4xWh?;QP`aNzKxWvBigJX3kB%Cn;TgK>#Ki-+Ei zc;e5llU;DN!SO-MqGZpx%c~L(NKLurv%1*WaJBx7*DqEbh)WH4UC(G58aeeo6Eh`HsJjn8~h_Q&2Yt_uD=KU+=u`I5WC28qf3udlRWFG)yjv46m)qD=0ESebC ziYo0X-*0sLL%{Uzrn!0-|1q(j+ZbcPrTDqLXr@NmG%jVavp<7ao@I!_uIxq zi}0s+x-40>E*VS9-92;o$X37h9T#V&HuGLNv9?}$<0-YDvbCT06;8F*j&j~ua5_~f w#jTO2x_9%}r!KmNKLjS12(B>byIgicKkS7jrwSLx5-rOqYs=I>YaZYR0Q-$KZvX%Q literal 0 HcmV?d00001 diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e0226803bb1cf..7603cef773691 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -181,9 +181,10 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) + val akkaSSLOptions = securityManager.getSSLOptions("akka") assert(securityManager.fileServerSSLOptions.enabled === true) - assert(securityManager.akkaSSLOptions.enabled === true) + assert(akkaSSLOptions.enabled === true) assert(securityManager.sslSocketFactory.isDefined === true) assert(securityManager.hostnameVerifier.isDefined === true) @@ -198,15 +199,15 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) - assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.akkaSSLOptions.keyStore.isDefined === true) - assert(securityManager.akkaSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) + assert(akkaSSLOptions.trustStore.isDefined === true) + assert(akkaSSLOptions.trustStore.get.getName === "truststore") + assert(akkaSSLOptions.keyStore.isDefined === true) + assert(akkaSSLOptions.keyStore.get.getName === "keystore") + assert(akkaSSLOptions.trustStorePassword === Some("password")) + assert(akkaSSLOptions.keyStorePassword === Some("password")) + assert(akkaSSLOptions.keyPassword === Some("password")) + assert(akkaSSLOptions.protocol === Some("TLSv1.2")) + assert(akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { @@ -218,7 +219,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { val securityManager = new SecurityManager(conf) assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.akkaSSLOptions.enabled === false) assert(securityManager.sslSocketFactory.isDefined === false) assert(securityManager.hostnameVerifier.isDefined === false) } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 86455a13d0fe7..190e4dd7285b3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -50,7 +50,7 @@ private[deploy] object DeployTestUtils { createDriverDesc(), new Date()) def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, "http://publicAddress:80") workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis workerInfo } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 10e33a32ba4c3..ce00807ea46b9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -90,8 +90,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva cores = 0, memory = 0, endpoint = null, - webUiPort = 0, - publicAddress = "" + webUiAddress = "http://localhost:80" ) val (rpcEnv, _, _) = @@ -376,7 +375,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { val workerId = System.currentTimeMillis.toString - new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, 101, "address") + new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, "http://localhost:80") } private def scheduleExecutorsOnWorkers( diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index b4deed7f877e8..62fe0eaedfd27 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -88,9 +88,7 @@ class PersistenceEngineSuite extends SparkFunSuite { cores = 0, memory = 0, endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) + webUiAddress = "http://localhost:80") persistenceEngine.addWorker(workerToPersist) @@ -109,8 +107,7 @@ class PersistenceEngineSuite extends SparkFunSuite { assert(workerToPersist.cores === recoveryWorkerInfo.cores) assert(workerToPersist.memory === recoveryWorkerInfo.memory) assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + assert(workerToPersist.webUiAddress === recoveryWorkerInfo.webUiAddress) } finally { testRpcEnv.shutdown() testRpcEnv.awaitTermination() diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 2d28b67ef23fa..69c46058f1c1a 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ui -import java.net.ServerSocket +import java.net.{BindException, ServerSocket} import scala.io.Source -import scala.util.{Failure, Success, Try} +import org.eclipse.jetty.server.Server import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ class UISuite extends SparkFunSuite { @@ -45,6 +45,20 @@ class UISuite extends SparkFunSuite { sc } + private def sslDisabledConf(): (SparkConf, SSLOptions) = { + val conf = new SparkConf + (conf, new SecurityManager(conf).getSSLOptions("ui")) + } + + private def sslEnabledConf(): (SparkConf, SSLOptions) = { + val conf = new SparkConf() + .set("spark.ssl.ui.enabled", "true") + .set("spark.ssl.ui.keyStore", "./src/test/resources/spark.keystore") + .set("spark.ssl.ui.keyStorePassword", "123456") + .set("spark.ssl.ui.keyPassword", "123456") + (conf, new SecurityManager(conf).getSSLOptions("ui")) + } + ignore("basic ui visibility") { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible @@ -70,33 +84,92 @@ class UISuite extends SparkFunSuite { } test("jetty selects different port under contention") { - val server = new ServerSocket(0) - val startPort = server.getLocalPort - val serverInfo1 = JettyUtils.startJettyServer( - "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) - val serverInfo2 = JettyUtils.startJettyServer( - "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) - // Allow some wiggle room in case ports on the machine are under contention - val boundPort1 = serverInfo1.boundPort - val boundPort2 = serverInfo2.boundPort - assert(boundPort1 != startPort) - assert(boundPort2 != startPort) - assert(boundPort1 != boundPort2) - serverInfo1.server.stop() - serverInfo2.server.stop() - server.close() + var server: ServerSocket = null + var serverInfo1: ServerInfo = null + var serverInfo2: ServerInfo = null + val (conf, sslOptions) = sslDisabledConf() + try { + server = new ServerSocket(0) + val startPort = server.getLocalPort + serverInfo1 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf) + serverInfo2 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf) + // Allow some wiggle room in case ports on the machine are under contention + val boundPort1 = serverInfo1.boundPort + val boundPort2 = serverInfo2.boundPort + assert(boundPort1 != startPort) + assert(boundPort2 != startPort) + assert(boundPort1 != boundPort2) + } finally { + stopServer(serverInfo1) + stopServer(serverInfo2) + closeSocket(server) + } + } + + test("jetty with https selects different port under contention") { + var server: ServerSocket = null + var serverInfo1: ServerInfo = null + var serverInfo2: ServerInfo = null + try { + server = new ServerSocket(0) + val startPort = server.getLocalPort + val (conf, sslOptions) = sslEnabledConf() + serverInfo1 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf, "server1") + serverInfo2 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf, "server2") + // Allow some wiggle room in case ports on the machine are under contention + val boundPort1 = serverInfo1.boundPort + val boundPort2 = serverInfo2.boundPort + assert(boundPort1 != startPort) + assert(boundPort2 != startPort) + assert(boundPort1 != boundPort2) + } finally { + stopServer(serverInfo1) + stopServer(serverInfo2) + closeSocket(server) + } } test("jetty binds to port 0 correctly") { - val serverInfo = JettyUtils.startJettyServer( - "0.0.0.0", 0, Seq[ServletContextHandler](), new SparkConf) - val server = serverInfo.server - val boundPort = serverInfo.boundPort - assert(server.getState === "STARTED") - assert(boundPort != 0) - Try { new ServerSocket(boundPort) } match { - case Success(s) => fail("Port %s doesn't seem used by jetty server".format(boundPort)) - case Failure(e) => + var socket: ServerSocket = null + var serverInfo: ServerInfo = null + val (conf, sslOptions) = sslDisabledConf() + try { + serverInfo = JettyUtils.startJettyServer( + "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](), conf) + val server = serverInfo.server + val boundPort = serverInfo.boundPort + assert(server.getState === "STARTED") + assert(boundPort != 0) + intercept[BindException] { + socket = new ServerSocket(boundPort) + } + } finally { + stopServer(serverInfo) + closeSocket(socket) + } + } + + test("jetty with https binds to port 0 correctly") { + var socket: ServerSocket = null + var serverInfo: ServerInfo = null + try { + val (conf, sslOptions) = sslEnabledConf() + serverInfo = JettyUtils.startJettyServer( + "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](), conf) + val server = serverInfo.server + val boundPort = serverInfo.boundPort + assert(server.getState === "STARTED") + assert(boundPort != 0) + intercept[BindException] { + socket = new ServerSocket(boundPort) + } + } finally { + stopServer(serverInfo) + closeSocket(socket) } } @@ -117,4 +190,12 @@ class UISuite extends SparkFunSuite { assert(splitUIAddress(2).toInt == boundPort) } } + + def stopServer(info: ServerInfo): Unit = { + if (info != null && info.server != null) info.server.stop + } + + def closeSocket(socket: ServerSocket): Unit = { + if (socket != null) socket.close + } } diff --git a/docs/configuration.md b/docs/configuration.md index 08392c39187b9..12ac60129633d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1430,6 +1430,7 @@ Apart from these, the following properties are also available, and may be useful The reference list of protocols one can find on this page. + Note: If not set, it will use the default cipher suites of JVM. @@ -1454,6 +1455,13 @@ Apart from these, the following properties are also available, and may be useful A password to the key-store. + + spark.ssl.keyStoreType + JKS + + The type of the key-store. + + spark.ssl.protocol None @@ -1463,6 +1471,13 @@ Apart from these, the following properties are also available, and may be useful page. + + spark.ssl.needClientAuth + false + + Set true if SSL needs client authentication. + + spark.ssl.trustStore None @@ -1478,6 +1493,13 @@ Apart from these, the following properties are also available, and may be useful A password to the trust-store. + + spark.ssl.trustStoreType + JKS + + The type of the trust-store. + + diff --git a/docs/security.md b/docs/security.md index 1b7741d4dd93c..a4cc0f42b2482 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,15 +6,19 @@ title: Security Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. +* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. ## Web UI -The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting +and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via the `spark.ui.https.enabled` setting. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. +### Authentication +A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. + +Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. ## Event Logging @@ -23,8 +27,8 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for file server) protocols. SASL encryption is -supported for the block transfer service. Encryption is not yet supported for the WebUI. +Spark supports SSL for Akka and HTTP protocols. SASL encryption is supported for the block transfer +service. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle files, cached data, and other application files. If encrypting this data is desired, a workaround is @@ -32,8 +36,41 @@ to configure your cluster manager to store application data on encrypted disks. ### SSL Configuration -Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings +which will be used for all the supported communication protocols unless they are overwritten by +protocol-specific settings. This way the user can easily provide the common settings for all the +protocols without disabling the ability to configure each one individually. The common SSL settings +are at `spark.ssl` namespace in Spark configuration. The following table describes the +component-specific configuration namespaces used to override the default settings: + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Config NamespaceComponent
    spark.ssl.akkaAkka communication channels
    spark.ssl.fsHTTP file server and broadcast server
    spark.ssl.uiSpark application Web UI
    spark.ssl.standaloneStandalone Master / Worker Web UI
    spark.ssl.historyServerHistory Server Web UI
    +The full breakdown of available SSL options can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. ### YARN mode From f6f7ca9d2ef65da15f42085993e58e618637fad5 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Tue, 19 Jan 2016 14:59:20 -0800 Subject: [PATCH 1628/2089] [SPARK-9716][ML] BinaryClassificationEvaluator should accept Double prediction column This PR aims to allow the prediction column of `BinaryClassificationEvaluator` to be of double type. Author: BenFradet Closes #10472 from BenFradet/SPARK-9716. --- .../BinaryClassificationEvaluator.scala | 9 ++++-- .../apache/spark/ml/util/SchemaUtils.scala | 17 ++++++++++ .../BinaryClassificationEvaluatorSuite.scala | 32 +++++++++++++++++++ python/pyspark/ml/evaluation.py | 5 +-- 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index f71726f110e84..a1d36c4becfa2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: * Evaluator for binary classification, which expects two input columns: rawPrediction and label. + * The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) + * or of type vector (length-2 vector of raw predictions, scores, or label probabilities). */ @Since("1.2.0") @Experimental @@ -78,13 +80,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.2.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema - SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) - .map { case Row(rawPrediction: Vector, label: Double) => - (rawPrediction(1), label) + .map { + case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) + case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76f651488aef9..e71dd9eee03e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -43,6 +43,23 @@ private[spark] object SchemaUtils { s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } + /** + * Check whether the given schema contains a column of one of the require data types. + * @param colName column name + * @param dataTypes required column data types + */ + def checkColumnTypes( + schema: StructType, + colName: String, + dataTypes: Seq[DataType], + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(dataTypes.exists(actualDataType.equals), + s"Column $colName must be of type equal to one of the following types: " + + s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + } + /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index a535c1218ecfa..27349950dc119 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext class BinaryClassificationEvaluatorSuite @@ -36,4 +37,35 @@ class BinaryClassificationEvaluatorSuite .setMetricName("areaUnderPR") testDefaultReadWrite(evaluator) } + + test("should accept both vector and double raw prediction col") { + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") + + val vectorDF = sqlContext.createDataFrame(Seq( + (0d, Vectors.dense(12, 2.5)), + (1d, Vectors.dense(1, 3)), + (0d, Vectors.dense(10, 2)) + )).toDF("label", "rawPrediction") + assert(evaluator.evaluate(vectorDF) === 1.0) + + val doubleDF = sqlContext.createDataFrame(Seq( + (0d, 0d), + (1d, 1d), + (0d, 0d) + )).toDF("label", "rawPrediction") + assert(evaluator.evaluate(doubleDF) === 1.0) + + val stringDF = sqlContext.createDataFrame(Seq( + (0d, "0d"), + (1d, "1d"), + (0d, "0d") + )).toDF("label", "rawPrediction") + val thrown = intercept[IllegalArgumentException] { + evaluator.evaluate(stringDF) + } + assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + + "equal to one of the following types: [DoubleType, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + } } diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index dcc1738ec518b..6ff68abd8f18f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -106,8 +106,9 @@ def isLargerBetter(self): @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): """ - Evaluator for binary classification, which expects two input - columns: rawPrediction and label. + Evaluator for binary classification, which expects two input columns: rawPrediction and label. + The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label + 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities). >>> from pyspark.mllib.linalg import Vectors >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]), From 3e84ef0a54c53c45d7802cd2fecfa1c223580aee Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 19 Jan 2016 16:14:41 -0800 Subject: [PATCH 1629/2089] [SPARK-12770][SQL] Implement rules for branch elimination for CaseWhen The three optimization cases are: 1. If the first branch's condition is a true literal, remove the CaseWhen and use the value from that branch. 2. If a branch's condition is a false or null literal, remove that branch. 3. If only the else branch is left, remove the CaseWhen and use the value from the else branch. Author: Reynold Xin Closes #10827 from rxin/SPARK-12770. --- .../sql/catalyst/optimizer/Optimizer.scala | 18 +++++++++ .../optimizer/SimplifyConditionalSuite.scala | 37 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index cc3371c08fac4..b7caa49f5046b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -635,6 +635,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue + + case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) => + // If there are branches that are always false, remove them. + // If there are no more branches left, just use the else value. + // Note that these two are handled together here in a single case statement because + // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null). + val newBranches = branches.filter(_._1 != FalseLiteral) + if (newBranches.isEmpty) { + elseValue.getOrElse(Literal.create(null, e.dataType)) + } else { + e.copy(branches = newBranches) + } + + case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + // If the first branch is a true literal, remove the entire CaseWhen and use the value + // from that. Note that CaseWhen.branches should never be empty, and as a result the + // headOption (rather than head) added above is just a extra (and unnecessary) safeguard. + branches.head._2 } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 8e5d7ef3c9d49..d436b627f6bd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.IntegerType class SimplifyConditionalSuite extends PlanTest with PredicateHelper { @@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { comparePlans(actual, correctAnswer) } + private val trueBranch = (TrueLiteral, Literal(5)) + private val normalBranch = (NonFoldableLiteral(true), Literal(10)) + private val unreachableBranch = (FalseLiteral, Literal(20)) + test("simplify if") { assertEquivalent( If(TrueLiteral, Literal(10), Literal(20)), @@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) } + test("remove unreachable branches") { + // i.e. removing branches whose conditions are always false + assertEquivalent( + CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None), + CaseWhen(normalBranch :: Nil, None)) + } + + test("remove entire CaseWhen if only the else branch is reachable") { + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))), + Literal(30)) + + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None), + Literal.create(null, IntegerType)) + } + + test("remove entire CaseWhen if the first branch is always true") { + assertEquivalent( + CaseWhen(trueBranch :: normalBranch :: Nil, None), + Literal(5)) + + // Test branch elimination and simplification in combination + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None), + Literal(5)) + + // Make sure this doesn't trigger if there is a non-foldable branch before the true branch + assertEquivalent( + CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None), + CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None)) + } } From 37fefa66cbd61bc592aba42b0ed3aefc0cf3abb0 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 19 Jan 2016 16:33:48 -0800 Subject: [PATCH 1630/2089] [SPARK-12168][SPARKR] Add automated tests for conflicted function in R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently this is reported when loading the SparkR package in R (probably would add is.nan) ``` Loading required package: methods Attaching package: ‘SparkR’ The following objects are masked from ‘package:stats’: cov, filter, lag, na.omit, predict, sd, var The following objects are masked from ‘package:base’: colnames, colnames<-, intersect, rank, rbind, sample, subset, summary, table, transform ``` Adding this test adds an automated way to track changes to masked method. Also, the second part of this test check for those functions that would not be accessible without namespace/package prefix. Incidentally, this might point to how we would fix those inaccessible functions in base or stats. Looking for feedback for adding this test. Author: felixcheung Closes #10171 from felixcheung/rmaskedtest. --- R/pkg/NAMESPACE | 2 +- R/pkg/inst/tests/testthat/test_context.R | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 34d14373b9027..27d2f9822f294 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -271,10 +271,10 @@ export("as.DataFrame", "createExternalTable", "dropTempTable", "jsonFile", - "read.json", "loadDF", "parquetFile", "read.df", + "read.json", "read.parquet", "read.text", "sql", diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 1707e314beff5..92dbd575c20ab 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -17,6 +17,29 @@ context("test functions in sparkR.R") +test_that("Check masked functions", { + # Check that we are not masking any new function from base, stats, testthat unexpectedly + masked <- conflicts(detail = TRUE)$`package:SparkR` + expect_true("describe" %in% masked) # only when with testthat.. + func <- lapply(masked, function(x) { capture.output(showMethods(x))[[1]] }) + funcSparkROrEmpty <- grepl("\\(package SparkR\\)$|^$", func) + maskedBySparkR <- masked[funcSparkROrEmpty] + expect_equal(length(maskedBySparkR), 18) + expect_equal(sort(maskedBySparkR), sort(c("describe", "cov", "filter", "lag", "na.omit", + "predict", "sd", "var", "colnames", "colnames<-", + "intersect", "rank", "rbind", "sample", "subset", + "summary", "table", "transform"))) + # above are those reported as masked when `library(SparkR)` + # note that many of these methods are still callable without base:: or stats:: prefix + # there should be a test for each of these, except followings, which are currently "broken" + funcHasAny <- unlist(lapply(masked, function(x) { + any(grepl("=\"ANY\"", capture.output(showMethods(x)[-1]))) + })) + maskedCompletely <- masked[!funcHasAny] + expect_equal(length(maskedCompletely), 4) + expect_equal(sort(maskedCompletely), sort(c("cov", "filter", "sample", "table"))) +}) + test_that("repeatedly starting and stopping SparkR", { for (i in 1:4) { sc <- sparkR.init() From 3ac648289c543b56937d67b5df5c3e228ef47cbd Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 19 Jan 2016 16:37:18 -0800 Subject: [PATCH 1631/2089] [SPARK-12337][SPARKR] Implement dropDuplicates() method of DataFrame in SparkR. Author: Sun Rui Closes #10309 from sun-rui/SPARK-12337. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 30 ++++++++++++++++++ R/pkg/R/generics.R | 7 +++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 38 ++++++++++++++++++++++- 4 files changed, 75 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 27d2f9822f294..7739e9ea8689f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -39,6 +39,7 @@ exportMethods("arrange", "describe", "dim", "distinct", + "dropDuplicates", "dropna", "dtypes", "except", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 35695b9df1974..629c1ce2eddc1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1645,6 +1645,36 @@ setMethod("where", filter(x, condition) }) +#' dropDuplicates +#' +#' Returns a new DataFrame with duplicate rows removed, considering only +#' the subset of columns. +#' +#' @param x A DataFrame. +#' @param colnames A character vector of column names. +#' @return A DataFrame with duplicate rows removed. +#' @family DataFrame functions +#' @rdname dropduplicates +#' @name dropDuplicates +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' dropDuplicates(df) +#' dropDuplicates(df, c("col1", "col2")) +#' } +setMethod("dropDuplicates", + signature(x = "DataFrame"), + function(x, colNames = columns(x)) { + stopifnot(class(colNames) == "character") + + sdf <- callJMethod(x@sdf, "dropDuplicates", as.list(colNames)) + dataFrame(sdf) + }) + #' Join #' #' Join two DataFrames based on the given join expression. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 860329988f97c..d616266ead41b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -428,6 +428,13 @@ setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname dropduplicates +#' @export +setGeneric("dropDuplicates", + function(x, colNames = columns(x)) { + standardGeneric("dropDuplicates") + }) + #' @rdname nafunctions #' @export setGeneric("dropna", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 67ecdbc522d23..6610734cf4fae 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -734,7 +734,7 @@ test_that("head() and first() return the correct data", { expect_equal(ncol(testFirst), 2) }) -test_that("distinct() and unique on DataFrames", { +test_that("distinct(), unique() and dropDuplicates() on DataFrames", { lines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}", @@ -750,6 +750,42 @@ test_that("distinct() and unique on DataFrames", { uniques2 <- unique(df) expect_is(uniques2, "DataFrame") expect_equal(count(uniques2), 3) + + # Test dropDuplicates() + df <- createDataFrame( + sqlContext, + list( + list(2, 1, 2), list(1, 1, 1), + list(1, 2, 1), list(2, 1, 2), + list(2, 2, 2), list(2, 2, 1), + list(2, 1, 1), list(1, 1, 2), + list(1, 2, 2), list(1, 2, 1)), + schema = c("key", "value1", "value2")) + result <- collect(dropDuplicates(df)) + expected <- rbind.data.frame( + c(1, 1, 1), c(1, 1, 2), c(1, 2, 1), + c(1, 2, 2), c(2, 1, 1), c(2, 1, 2), + c(2, 2, 1), c(2, 2, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2),], + expected) + + result <- collect(dropDuplicates(df, c("key", "value1"))) + expected <- rbind.data.frame( + c(1, 1, 1), c(1, 2, 1), c(2, 1, 2), c(2, 2, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2),], + expected) + + result <- collect(dropDuplicates(df, "key")) + expected <- rbind.data.frame( + c(1, 1, 1), c(2, 1, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2),], + expected) }) test_that("sample on a DataFrame", { From beda9014220be77dd735e6af1903e7d93dceb110 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Jan 2016 16:51:17 -0800 Subject: [PATCH 1632/2089] Revert "[SPARK-11295] Add packages to JUnit output for Python tests" This reverts commit c6f971b4aeca7265ab374fa46c5c452461d9b6a7. --- python/pyspark/ml/tests.py | 1 - python/pyspark/mllib/tests.py | 24 ++++++++++-------------- python/pyspark/sql/tests.py | 1 - python/pyspark/streaming/tests.py | 1 - python/pyspark/tests.py | 1 - 5 files changed, 10 insertions(+), 18 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9ea639dc4f960..4eb17bfdcca90 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -394,7 +394,6 @@ def test_fit_maximize_metric(self): if __name__ == "__main__": - from pyspark.ml.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index ea7d297cba2ae..32ed48e10388e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -77,24 +77,21 @@ pass ser = PickleSerializer() +sc = SparkContext('local[4]', "MLlib tests") class MLlibTestCase(unittest.TestCase): def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - - def tearDown(self): - self.sc.stop() + self.sc = sc class MLLibStreamingTestCase(unittest.TestCase): def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") + self.sc = sc self.ssc = StreamingContext(self.sc, 1.0) def tearDown(self): self.ssc.stop(False) - self.sc.stop() @staticmethod def _eventually(condition, timeout=30.0, catch_assertions=False): @@ -1169,7 +1166,7 @@ def test_predictOn_model(self): clusterWeights=[1.0, 1.0, 1.0, 1.0]) predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] + predict_data = [sc.parallelize(batch, 1) for batch in predict_data] predict_stream = self.ssc.queueStream(predict_data) predict_val = stkm.predictOn(predict_stream) @@ -1200,7 +1197,7 @@ def test_trainOn_predictOn(self): # classification based in the initial model would have been 0 # proving that the model is updated. batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [self.sc.parallelize(batch) for batch in batches] + batches = [sc.parallelize(batch) for batch in batches] input_stream = self.ssc.queueStream(batches) predict_results = [] @@ -1233,7 +1230,7 @@ def test_dim(self): self.assertEqual(len(point.features), 3) linear_data = LinearDataGenerator.generateLinearRDD( - sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, + sc=sc, nexamples=6, nfeatures=2, eps=0.1, nParts=2, intercept=0.0).collect() self.assertEqual(len(linear_data), 6) for point in linear_data: @@ -1409,7 +1406,7 @@ def test_parameter_accuracy(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) + batches.append(sc.parallelize(batch)) input_stream = self.ssc.queueStream(batches) slr.trainOn(input_stream) @@ -1433,7 +1430,7 @@ def test_parameter_convergence(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) + batches.append(sc.parallelize(batch)) model_weights = [] input_stream = self.ssc.queueStream(batches) @@ -1466,7 +1463,7 @@ def test_prediction(self): 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], 100, 42 + i, 0.1) batches.append( - self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) input_stream = self.ssc.queueStream(batches) output_stream = slr.predictOnValues(input_stream) @@ -1497,7 +1494,7 @@ def test_train_prediction(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) + batches.append(sc.parallelize(batch)) predict_batches = [ b.map(lambda lp: (lp.label, lp.features)) for b in batches] @@ -1583,7 +1580,6 @@ def test_als_ratings_id_long_error(self): if __name__ == "__main__": - from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ae8620274dd20..c03cb9338ae68 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1259,7 +1259,6 @@ def test_collect_functions(self): if __name__ == "__main__": - from pyspark.sql.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 24b812615cbb4..86b05d9fd2424 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1635,7 +1635,6 @@ def search_kinesis_asl_assembly_jar(): are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' if __name__ == "__main__": - from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() mqtt_assembly_jar = search_mqtt_assembly_jar() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 23720502a82c8..5bd94476597ab 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2008,7 +2008,6 @@ def test_statcounter_array(self): if __name__ == "__main__": - from pyspark.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: From 488bbb216c82306e82b8963a331d48d484e8eadd Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 19 Jan 2016 18:31:03 -0800 Subject: [PATCH 1633/2089] [SPARK-12232][SPARKR] New R API for read.table to avoid name conflict shivaram sorry it took longer to fix some conflicts, this is the change to add an alias for `table` Author: felixcheung Closes #10406 from felixcheung/readtable. --- R/pkg/NAMESPACE | 2 +- R/pkg/R/SQLContext.R | 7 ++++--- R/pkg/inst/tests/testthat/test_context.R | 15 ++++++++------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 13 ++++--------- docs/sparkr.md | 11 ++++------- 5 files changed, 21 insertions(+), 27 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7739e9ea8689f..00634c1a70c26 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -280,7 +280,7 @@ export("as.DataFrame", "read.text", "sql", "str", - "table", + "tableToDF", "tableNames", "tables", "uncacheTable") diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 99679b4a774d3..16a2578678cd3 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -352,6 +352,8 @@ sql <- function(sqlContext, sqlQuery) { #' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a DataFrame. #' @return DataFrame +#' @rdname tableToDF +#' @name tableToDF #' @export #' @examples #'\dontrun{ @@ -360,15 +362,14 @@ sql <- function(sqlContext, sqlQuery) { #' path <- "path/to/file.json" #' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- table(sqlContext, "table") +#' new_df <- tableToDF(sqlContext, "table") #' } -table <- function(sqlContext, tableName) { +tableToDF <- function(sqlContext, tableName) { sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) } - #' Tables #' #' Returns a DataFrame containing names of tables in the given database. diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 92dbd575c20ab..3b14a497b487a 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -24,11 +24,11 @@ test_that("Check masked functions", { func <- lapply(masked, function(x) { capture.output(showMethods(x))[[1]] }) funcSparkROrEmpty <- grepl("\\(package SparkR\\)$|^$", func) maskedBySparkR <- masked[funcSparkROrEmpty] - expect_equal(length(maskedBySparkR), 18) - expect_equal(sort(maskedBySparkR), sort(c("describe", "cov", "filter", "lag", "na.omit", - "predict", "sd", "var", "colnames", "colnames<-", - "intersect", "rank", "rbind", "sample", "subset", - "summary", "table", "transform"))) + namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", + "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", + "summary", "transform") + expect_equal(length(maskedBySparkR), length(namesOfMasked)) + expect_equal(sort(maskedBySparkR), sort(namesOfMasked)) # above are those reported as masked when `library(SparkR)` # note that many of these methods are still callable without base:: or stats:: prefix # there should be a test for each of these, except followings, which are currently "broken" @@ -36,8 +36,9 @@ test_that("Check masked functions", { any(grepl("=\"ANY\"", capture.output(showMethods(x)[-1]))) })) maskedCompletely <- masked[!funcHasAny] - expect_equal(length(maskedCompletely), 4) - expect_equal(sort(maskedCompletely), sort(c("cov", "filter", "sample", "table"))) + namesOfMaskedCompletely <- c("cov", "filter", "sample") + expect_equal(length(maskedCompletely), length(namesOfMaskedCompletely)) + expect_equal(sort(maskedCompletely), sort(namesOfMaskedCompletely)) }) test_that("repeatedly starting and stopping SparkR", { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6610734cf4fae..14d40d5066e78 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -335,7 +335,6 @@ writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { # ArrayType df <- read.json(sqlContext, complexTypeJsonPath) - ldf <- collect(df) expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) @@ -490,19 +489,15 @@ test_that("insertInto() on a registered table", { unlink(parquetPath2) }) -test_that("table() returns a new DataFrame", { +test_that("tableToDF() returns a new DataFrame", { df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") - tabledf <- table(sqlContext, "table1") + tabledf <- tableToDF(sqlContext, "table1") expect_is(tabledf, "DataFrame") expect_equal(count(tabledf), 3) + tabledf2 <- tableToDF(sqlContext, "table1") + expect_equal(count(tabledf2), 3) dropTempTable(sqlContext, "table1") - - # nolint start - # Test base::table is working - #a <- letters[1:3] - #expect_equal(class(table(a, sample(a))), "table") - # nolint end }) test_that("toRDD() returns an RRDD", { diff --git a/docs/sparkr.md b/docs/sparkr.md index ea81532c611e2..73e38b8c70f01 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -375,13 +375,6 @@ The following functions are masked by the SparkR package: sample in package:base base::sample(x, size, replace = FALSE, prob = NULL) - - table in package:base -
    base::table(...,
    -            exclude = if (useNA == "no") c(NA, NaN),
    -            useNA = c("no", "ifany", "always"),
    -            dnn = list.names(...), deparse.level = 1)
    - Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. @@ -394,3 +387,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading From SparkR 1.5.x to 1.6 - Before Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API. + +## Upgrading From SparkR 1.6.x to 2.0 + + - The method `table` has been removed and replaced by `tableToDF`. From 6844d36aea91e9a7114f477a1cf3cdb9a882926a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 19 Jan 2016 20:45:52 -0800 Subject: [PATCH 1634/2089] [SPARK-12871][SQL] Support to specify the option for compression codec. https://issues.apache.org/jira/browse/SPARK-12871 This PR added an option to support to specify compression codec. This adds the option `codec` as an alias `compression` as filed in [SPARK-12668 ](https://issues.apache.org/jira/browse/SPARK-12668). Note that I did not add configurations for Hadoop 1.x as this `CsvRelation` is using Hadoop 2.x API and I guess it is going to drop Hadoop 1.x support. Author: hyukjinkwon Closes #10805 from HyukjinKwon/SPARK-12420. --- .../datasources/csv/CSVParameters.scala | 36 +++++++++++++++++-- .../datasources/csv/CSVRelation.scala | 10 ++++++ .../execution/datasources/csv/CSVSuite.scala | 26 ++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala index 127c9728da2d1..676a3d3bca9f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.Charset +import org.apache.hadoop.io.compress._ + import org.apache.spark.Logging +import org.apache.spark.util.Utils private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging { @@ -35,7 +38,7 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] private def getBool(paramName: String, default: Boolean = false): Boolean = { val param = parameters.getOrElse(paramName, default.toString) - if (param.toLowerCase() == "true") { + if (param.toLowerCase == "true") { true } else if (param.toLowerCase == "false") { false @@ -73,6 +76,11 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] val nullValue = parameters.getOrElse("nullValue", "") + val compressionCodec: Option[String] = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CSVCompressionCodecs.getCodecClassName) + } + val maxColumns = 20480 val maxCharsPerColumn = 100000 @@ -85,7 +93,6 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] } private[csv] object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" val DROP_MALFORMED_MODE = "DROPMALFORMED" val FAIL_FAST_MODE = "FAILFAST" @@ -107,3 +114,28 @@ private[csv] object ParseModes { true // We default to permissive is the mode string is not valid } } + +private[csv] object CSVCompressionCodecs { + private val shortCompressionCodecNames = Map( + "bzip2" -> classOf[BZip2Codec].getName, + "gzip" -> classOf[GzipCodec].getName, + "lz4" -> classOf[Lz4Codec].getName, + "snappy" -> classOf[SnappyCodec].getName) + + /** + * Return the full version of the given codec class. + * If it is already a class name, just return it. + */ + def getCodecClassName(name: String): String = { + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + try { + // Validate the codec name + Utils.classForName(codecName) + codecName + } catch { + case e: ClassNotFoundException => + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 53818853ffb3b..1502501c3b89e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -24,6 +24,7 @@ import scala.util.control.NonFatal import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.hadoop.mapreduce.RecordWriter @@ -99,6 +100,15 @@ private[csv] class CSVRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { + val conf = job.getConfiguration + params.compressionCodec.foreach { codec => + conf.set("mapreduce.output.fileoutputformat.compress", "true") + conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) + conf.set("mapreduce.map.output.compress", "true") + conf.set("mapreduce.map.output.compress.codec", codec) + } + new CSVOutputWriterFactory(params) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 071b5ef56d58b..a79566b1f3658 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -349,4 +349,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } + + test("save csv with compression codec option") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("compression", "gZiP") + .save(csvDir) + + val compressedFiles = new File(csvDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".gz"))) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } } From 753b1945115245800898959e3ab249a94a1935e9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 20 Jan 2016 00:00:28 -0800 Subject: [PATCH 1635/2089] [SPARK-12912][SQL] Add a test suite for EliminateSubQueries Also updated documentation to explain why ComputeCurrentTime and EliminateSubQueries are in the optimizer rather than analyzer. Author: Reynold Xin Closes #10837 from rxin/optimizer-analyzer-comment. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 15 ++-- .../spark/sql/catalyst/trees/TreeNode.scala | 41 ++++++----- .../optimizer/EliminateSubQueriesSuite.scala | 69 +++++++++++++++++++ 4 files changed, 103 insertions(+), 26 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9257fba60e36c..d4b4bc88b3f2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -297,7 +297,7 @@ class Analyzer( * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation): LogicalPlan = { + private def getTable(u: UnresolvedRelation): LogicalPlan = { try { catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { @@ -1165,7 +1165,7 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubQueries extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case Subquery(_, child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b7caa49f5046b..04643f0274bd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -35,11 +35,16 @@ import org.apache.spark.sql.types._ */ abstract class Optimizer extends RuleExecutor[LogicalPlan] { def batches: Seq[Batch] = { - // SubQueries are only needed for analysis and can be removed before execution. - Batch("Remove SubQueries", FixedPoint(100), - EliminateSubQueries) :: - Batch("Compute Current Time", Once, + // Technically some of the rules in Finish Analysis are not optimizer rules and belong more + // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). + // However, because we also use the analyzer to canonicalized queries (for view definition), + // we do not eliminate subqueries or compute current time in the analyzer. + Batch("Finish Analysis", Once, + EliminateSubQueries, ComputeCurrentTime) :: + ////////////////////////////////////////////////////////////////////////////////////////// + // Optimizer rules start here + ////////////////////////////////////////////////////////////////////////////////////////// Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -57,7 +62,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ProjectCollapsing, CombineFilters, CombineLimits, - // Constant folding + // Constant folding and strength reduction NullPropagation, OptimizeIn, ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index d74f3ef2ffba6..57e1a3c9eb226 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -244,6 +244,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * When `rule` does not apply to a given node it is left unchanged. * Users should not expect a specific directionality. If a specific directionality is needed, * transformDown or transformUp should be used. + * * @param rule the function use to transform this nodes children */ def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = { @@ -253,6 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a copy of this node where `rule` has been recursively applied to it and all of its * children (pre-order). When `rule` does not apply to a given node it is left unchanged. + * * @param rule the function used to transform this nodes children */ def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { @@ -268,6 +270,26 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. + * + * @param rule the function use to transform this nodes children + */ + def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } + } + } + /** * Returns a copy of this node where `rule` has been recursively applied to all the children of * this node. When `rule` does not apply to a given node it is left unchanged. @@ -332,25 +354,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (changed) makeCopy(newArgs) else this } - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order). When `rule` does not apply to a given node, it is left - * unchanged. - * @param rule the function use to transform this nodes children - */ - def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[BaseType]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) - } - } - } - /** * Args to the constructor that should be copied, but not transformed. * These are appended to the transformed args automatically by makeCopy diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala new file mode 100644 index 0000000000000..e0d430052fb55 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class EliminateSubQueriesSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("EliminateSubQueries", Once, EliminateSubQueries) :: Nil + } + + private def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + private def afterOptimization(plan: LogicalPlan): LogicalPlan = { + Optimize.execute(analysis.SimpleAnalyzer.execute(plan)) + } + + test("eliminate top level subquery") { + val input = LocalRelation('a.int, 'b.int) + val query = Subquery("a", input) + comparePlans(afterOptimization(query), input) + } + + test("eliminate mid-tree subquery") { + val input = LocalRelation('a.int, 'b.int) + val query = Filter(TrueLiteral, Subquery("a", input)) + comparePlans( + afterOptimization(query), + Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) + } + + test("eliminate multiple subqueries") { + val input = LocalRelation('a.int, 'b.int) + val query = Filter(TrueLiteral, Subquery("c", Subquery("b", Subquery("a", input)))) + comparePlans( + afterOptimization(query), + Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) + } + +} From 8e4f894e986ccd943df9ddf55fc853eb0558886f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 20 Jan 2016 10:02:40 -0800 Subject: [PATCH 1636/2089] [SPARK-12881] [SQL] subexpress elimination in mutable projection Author: Davies Liu Closes #10814 from davies/mutable_subexpr. --- .../expressions/EquivalentExpressions.scala | 5 ++- .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 8 ++-- .../codegen/GenerateMutableProjection.scala | 43 ++++++++++++++----- .../codegen/GenerateUnsafeProjection.scala | 6 +-- .../SubexpressionEliminationSuite.scala | 13 ++++++ .../spark/sql/execution/SparkPlan.scala | 6 ++- .../apache/spark/sql/execution/Window.scala | 8 +++- .../aggregate/SortBasedAggregate.scala | 3 +- .../aggregate/TungstenAggregate.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++ 11 files changed, 80 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f7162e420d19a..affd1bdb327c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback + /** * This class is used to compute equality of (sub)expression trees. Expressions can be added * to this class and they subsequently query for expression equality. Expression trees are @@ -67,7 +69,8 @@ class EquivalentExpressions { */ def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf - if (!skip && !addExpr(root)) { + // the children of CodegenFallback will not be used to generate code (call eval() instead) + if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) { root.children.foreach(addExprTree(_, ignoreLeaf)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 25cf210c4b527..db17ba7c84ffc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -100,8 +100,8 @@ abstract class Expression extends TreeNode[Expression] { ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = ExprCode("", isNull, primitive) + val value = ctx.freshName("value") + val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) // Add `this` in the comment. ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 683029ff144d8..2747c315ad374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -125,7 +125,7 @@ class CodegenContext { val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // The collection of sub-exression result resetting methods that need to be called on each row. - val subExprResetVariables = mutable.ArrayBuffer.empty[String] + val subexprFunctions = mutable.ArrayBuffer.empty[String] def declareAddedFunctions(): String = { addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") @@ -424,9 +424,9 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { val expr = e.head - val isNull = freshName("isNull") - val value = freshName("value") val fnName = freshName("evalExpr") + val isNull = s"${fnName}IsNull" + val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. val code = expr.gen(this) @@ -461,7 +461,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subExprResetVariables += s"$fnName($INPUT_ROW);" + subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 59ef0f5836a3c..d9fe76133c6ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -38,12 +38,29 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean): (() => MutableProjection) = { + create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) + } + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + create(expressions, false) + } + + private def create( + expressions: Seq[Expression], + useSubexprElimination: Boolean): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCodes = expressions.zipWithIndex.map { - case (NoOp, _) => "" - case (e, i) => - val evaluationCode = e.gen(ctx) + val (validExpr, index) = expressions.zipWithIndex.filter { + case (NoOp, _) => false + case _ => true + }.unzip + val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + val projectionCodes = exprVals.zip(index).map { + case (ev, i) => + val e = expressions(i) if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" @@ -51,22 +68,25 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ctx.addMutableState(ctx.javaType(e.dataType), value, s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; + ${ev.code} + this.$isNull = ${ev.isNull}; + this.$value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - ${evaluationCode.code} - this.$value = ${evaluationCode.value}; + ${ev.code} + this.$value = ${ev.value}; """ } } - val updates = expressions.zipWithIndex.map { - case (NoOp, _) => "" + + // Evaluate all the the subexpressions. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") + + val updates = validExpr.zip(index).map { case (e, i) => if (e.nullable) { if (e.dataType.isInstanceOf[DecimalType]) { @@ -128,6 +148,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $evalSubexpr $allProjections // copy all the results into MutableRow $allUpdates diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 61e7469ee4be2..72bf39a0398b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -294,13 +294,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - // Reset the subexpression values for each row. - val subexprReset = ctx.subExprResetVariables.mkString("\n") + // Evaluate all the subexpression. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") val code = s""" $bufferHolder.reset(); - $subexprReset + $evalSubexpr ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index a61297b2c0395..43a3eb9dec97c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -154,4 +154,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite { equivalence.addExpr(sum) assert(equivalence.getAllEquivalentExprs.isEmpty) } + + test("Children of CodegenFallback") { + val one = Literal(1) + val two = Add(one, one) + val explode = Explode(two) + val add = Add(two, explode) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + // the `two` inside `explode` should not be added + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 0) + assert(equivalence.getAllEquivalentExprs.filter(_.size == 1).size == 3) // add, two, explode + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 75101ea0fc6d2..b19b772409d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[this] def isTesting: Boolean = sys.props.contains("spark.testing") protected def newMutableProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean = false): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") try { - GenerateMutableProjection.generate(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } catch { case e: Exception => if (isTesting) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 168b5ab0316d1..26a7340f1ae10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -194,7 +194,11 @@ case class Window( val functions = functionSeq.toArray // Construct an aggregate processor if we need one. - def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => newMutableProjection(expressions, schema)) // Create the factory val factory = key match { @@ -206,7 +210,7 @@ case class Window( ordinal, functions, child.output, - newMutableProjection, + (expressions, schema) => newMutableProjection(expressions, schema), offset) // Growing Frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 1d56592c40b96..06a3991459f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -87,7 +87,8 @@ case class SortBasedAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a9cf04388d2e8..8dcbab4c8cfbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -94,7 +94,8 @@ case class TungstenAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, testFallbackStartsAt, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d7f182352b4c9..b159346bed9f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.{aggregate, SparkQl} import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ @@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + // Would be nice if semantic equals for `+` understood commutative verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) From 9376ae723e4ec0515120c488541617a0538f8879 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 20 Jan 2016 10:48:10 -0800 Subject: [PATCH 1637/2089] [SPARK-6519][ML] Add spark.ml API for bisecting k-means Author: Yu ISHIKAWA Closes #9604 from yu-iskw/SPARK-6519. --- .../spark/ml/clustering/BisectingKMeans.scala | 196 ++++++++++++++++++ .../ml/clustering/BisectingKMeansSuite.scala | 85 ++++++++ 2 files changed, 281 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala new file mode 100644 index 0000000000000..0b47cbbac8fe5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.clustering. + {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} + + +/** + * Common params for BisectingKMeans and BisectingKMeansModel + */ +private[clustering] trait BisectingKMeansParams extends Params + with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.0.0") + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** @group expertParam */ + @Since("2.0.0") + final val minDivisibleClusterSize = new Param[Double]( + this, + "minDivisibleClusterSize", + "the minimum number of points (if >= 1.0) or the minimum proportion", + (value: Double) => value > 0) + + /** @group expertGetParam */ + @Since("2.0.0") + def getMinDivisibleClusterSize: Double = $(minDivisibleClusterSize) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Model fitted by BisectingKMeans. + * + * @param parentModel a model trained by spark.mllib.clustering.BisectingKMeans. + */ +@Since("2.0.0") +@Experimental +class BisectingKMeansModel private[ml] ( + @Since("2.0.0") override val uid: String, + private val parentModel: MLlibBisectingKMeansModel + ) extends Model[BisectingKMeansModel] with BisectingKMeansParams { + + @Since("2.0.0") + override def copy(extra: ParamMap): BisectingKMeansModel = { + val copied = new BisectingKMeansModel(uid, parentModel) + copyValues(copied, extra) + } + + @Since("2.0.0") + override def transform(dataset: DataFrame): DataFrame = { + val predictUDF = udf((vector: Vector) => predict(vector)) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + @Since("2.0.0") + def clusterCenters: Array[Vector] = parentModel.clusterCenters + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("2.0.0") + def computeCost(dataset: DataFrame): Double = { + SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + parentModel.computeCost(data) + } +} + +/** + * :: Experimental :: + * + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("2.0.0") +@Experimental +class BisectingKMeans @Since("2.0.0") ( + @Since("2.0.0") override val uid: String) + extends Estimator[BisectingKMeansModel] with BisectingKMeansParams { + + setDefault( + k -> 4, + maxIter -> 20, + minDivisibleClusterSize -> 1.0) + + @Since("2.0.0") + override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra) + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("bisecting k-means")) + + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group expertSetParam */ + @Since("2.0.0") + def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + + @Since("2.0.0") + override def fit(dataset: DataFrame): BisectingKMeansModel = { + val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + + val bkm = new MLlibBisectingKMeans() + .setK($(k)) + .setMaxIterations($(maxIter)) + .setMinDivisibleClusterSize($(minDivisibleClusterSize)) + .setSeed($(seed)) + val parentModel = bkm.run(rdd) + val model = new BisectingKMeansModel(uid, parentModel) + copyValues(model) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala new file mode 100644 index 0000000000000..b26571eb9f35a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame + +class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + final val k = 5 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val bkm = new BisectingKMeans() + + assert(bkm.getK === 4) + assert(bkm.getFeaturesCol === "features") + assert(bkm.getPredictionCol === "prediction") + assert(bkm.getMaxIter === 20) + assert(bkm.getMinDivisibleClusterSize === 1.0) + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + .setK(9) + .setMinDivisibleClusterSize(2.0) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setSeed(123) + + assert(bkm.getK === 9) + assert(bkm.getFeaturesCol === "test_feature") + assert(bkm.getPredictionCol === "test_prediction") + assert(bkm.getMaxIter === 33) + assert(bkm.getMinDivisibleClusterSize === 2.0) + assert(bkm.getSeed === 123) + + intercept[IllegalArgumentException] { + new BisectingKMeans().setK(1) + } + + intercept[IllegalArgumentException] { + new BisectingKMeans().setMinDivisibleClusterSize(0) + } + } + + test("fit & transform") { + val predictionColName = "bisecting_kmeans_prediction" + val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = bkm.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) + } +} From 9bb35c5b59e58dbebbdc6856d611bff73dd35a91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Lipt=C3=A1k?= Date: Wed, 20 Jan 2016 11:11:10 -0800 Subject: [PATCH 1638/2089] [SPARK-11295][PYSPARK] Add packages to JUnit output for Python tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is #9263 from gliptak (improving grouping/display of test case results) with a small fix of bisecting k-means unit test. Author: Gábor Lipták Author: Xiangrui Meng Closes #10850 from mengxr/SPARK-11295. --- python/pyspark/ml/tests.py | 1 + python/pyspark/mllib/tests.py | 26 +++++++++++++++----------- python/pyspark/sql/tests.py | 1 + python/pyspark/streaming/tests.py | 1 + python/pyspark/tests.py | 1 + 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4eb17bfdcca90..9ea639dc4f960 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -394,6 +394,7 @@ def test_fit_maximize_metric(self): if __name__ == "__main__": + from pyspark.ml.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 32ed48e10388e..79ce4959c9266 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -77,21 +77,24 @@ pass ser = PickleSerializer() -sc = SparkContext('local[4]', "MLlib tests") class MLlibTestCase(unittest.TestCase): def setUp(self): - self.sc = sc + self.sc = SparkContext('local[4]', "MLlib tests") + + def tearDown(self): + self.sc.stop() class MLLibStreamingTestCase(unittest.TestCase): def setUp(self): - self.sc = sc + self.sc = SparkContext('local[4]', "MLlib tests") self.ssc = StreamingContext(self.sc, 1.0) def tearDown(self): self.ssc.stop(False) + self.sc.stop() @staticmethod def _eventually(condition, timeout=30.0, catch_assertions=False): @@ -423,7 +426,7 @@ def test_bisecting_kmeans(self): from pyspark.mllib.clustering import BisectingKMeans data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) bskm = BisectingKMeans() - model = bskm.train(sc.parallelize(data, 2), k=4) + model = bskm.train(self.sc.parallelize(data, 2), k=4) p = array([0.0, 0.0]) rdd_p = self.sc.parallelize([p]) self.assertEqual(model.predict(p), model.predict(rdd_p).first()) @@ -1166,7 +1169,7 @@ def test_predictOn_model(self): clusterWeights=[1.0, 1.0, 1.0, 1.0]) predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] predict_stream = self.ssc.queueStream(predict_data) predict_val = stkm.predictOn(predict_stream) @@ -1197,7 +1200,7 @@ def test_trainOn_predictOn(self): # classification based in the initial model would have been 0 # proving that the model is updated. batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [sc.parallelize(batch) for batch in batches] + batches = [self.sc.parallelize(batch) for batch in batches] input_stream = self.ssc.queueStream(batches) predict_results = [] @@ -1230,7 +1233,7 @@ def test_dim(self): self.assertEqual(len(point.features), 3) linear_data = LinearDataGenerator.generateLinearRDD( - sc=sc, nexamples=6, nfeatures=2, eps=0.1, + sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, nParts=2, intercept=0.0).collect() self.assertEqual(len(linear_data), 6) for point in linear_data: @@ -1406,7 +1409,7 @@ def test_parameter_accuracy(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) input_stream = self.ssc.queueStream(batches) slr.trainOn(input_stream) @@ -1430,7 +1433,7 @@ def test_parameter_convergence(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) model_weights = [] input_stream = self.ssc.queueStream(batches) @@ -1463,7 +1466,7 @@ def test_prediction(self): 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], 100, 42 + i, 0.1) batches.append( - sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) input_stream = self.ssc.queueStream(batches) output_stream = slr.predictOnValues(input_stream) @@ -1494,7 +1497,7 @@ def test_train_prediction(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) predict_batches = [ b.map(lambda lp: (lp.label, lp.features)) for b in batches] @@ -1580,6 +1583,7 @@ def test_als_ratings_id_long_error(self): if __name__ == "__main__": + from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c03cb9338ae68..ae8620274dd20 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1259,6 +1259,7 @@ def test_collect_functions(self): if __name__ == "__main__": + from pyspark.sql.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 86b05d9fd2424..24b812615cbb4 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1635,6 +1635,7 @@ def search_kinesis_asl_assembly_jar(): are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' if __name__ == "__main__": + from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() mqtt_assembly_jar = search_mqtt_assembly_jar() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 5bd94476597ab..23720502a82c8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2008,6 +2008,7 @@ def test_statcounter_array(self): if __name__ == "__main__": + from pyspark.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: From 9753835cf3acc135e61bf668223046e29306c80d Mon Sep 17 00:00:00 2001 From: Imran Younus Date: Wed, 20 Jan 2016 11:16:59 -0800 Subject: [PATCH 1639/2089] [SPARK-12230][ML] WeightedLeastSquares.fit() should handle division by zero properly if standard deviation of target variable is zero. This fixes the behavior of WeightedLeastSquars.fit() when the standard deviation of the target variable is zero. If the fitIntercept is true, there is no need to train. Author: Imran Younus Closes #10274 from iyounus/SPARK-12230_bug_fix_in_weighted_least_squares. --- .../spark/ml/optim/WeightedLeastSquares.scala | 21 +++++- .../ml/optim/WeightedLeastSquaresSuite.scala | 69 +++++++++++++++++-- 2 files changed, 83 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 8617722ae542f..797870eb8ce8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares( val aaBar = summary.aaBar val aaValues = aaBar.values + if (bStd == 0) { + if (fitIntercept) { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + val coefficients = new DenseVector(Array.ofDim(k-1)) + val intercept = bBar + val diagInvAtWA = new DenseVector(Array(0D)) + return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + } else { + require(!(regParam > 0.0 && standardizeLabel), + "The standard deviation of the label is zero. " + + "Model cannot be regularized with standardization=true") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } + } + // add regularization to diagonals var i = 0 var j = 2 @@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares( if (standardizeFeatures) { lambda *= aVar(j - 2) } - if (standardizeLabel) { - // TODO: handle the case when bStd = 0 + if (standardizeLabel && bStd != 0) { lambda /= bStd } aaValues(i) += lambda diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index b542ba3dc54d2..0b58a9821f57b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { private var instances: RDD[Instance] = _ + private var instancesConstLabel: RDD[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + */ + instancesConstLabel = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) } test("WLS against lm") { @@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) - .fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization).fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } + idx += 1 + } + } + + test("WLS against lm when label is constant and no regularization") { + /* + R code: + + df.const.label <- as.data.frame(cbind(A, b.const)) + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization).fit(instancesConstLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } idx += 1 } } + test("WLS with regularization when label is constant") { + // if regParam is non-zero and standardization is true, the problem is ill-defined and + // an exception is thrown. + val wls = new WeightedLeastSquares( + fitIntercept = false, regParam = 0.1, standardizeFeatures = true, + standardizeLabel = true) + intercept[IllegalArgumentException]{ + wls.fit(instancesConstLabel) + } + } + test("WLS against glmnet") { /* R code: From e75e340a406b765608258b49f7e2f1107d4605fb Mon Sep 17 00:00:00 2001 From: Rajesh Balamohan Date: Wed, 20 Jan 2016 11:20:26 -0800 Subject: [PATCH 1640/2089] =?UTF-8?q?[SPARK-12925][SQL]=20Improve=20HiveIn?= =?UTF-8?q?spectors.unwrap=20for=20StringObjectIns=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Text is in UTF-8 and converting it via "UTF8String.fromString" incurs decoding and encoding, which turns out to be expensive and redundant. Profiler snapshot details is attached in the JIRA (ref:https://issues.apache.org/jira/secure/attachment/12783331/SPARK-12925_profiler_cpu_samples.png) Author: Rajesh Balamohan Closes #10848 from rajeshbalamohan/SPARK-12925. --- .../main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7a260e72eb459..5d84feb483eac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -320,7 +320,9 @@ private[hive] trait HiveInspectors { case hvoi: HiveCharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) + // Text is in UTF-8 already. No need to convert again via fromString + val wObj = x.getPrimitiveWritableObject(data) + UTF8String.fromBytes(wObj.getBytes, 0, wObj.getLength) case x: StringObjectInspector => UTF8String.fromString(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) From ab4a6bfd11b870428eb2a96aa213f7d34c0aa622 Mon Sep 17 00:00:00 2001 From: Rajesh Balamohan Date: Wed, 20 Jan 2016 11:30:03 -0800 Subject: [PATCH 1641/2089] [SPARK-12898] Consider having dummyCallSite for HiveTableScan Currently, HiveTableScan runs with getCallSite which is really expensive and shows up when scanning through large table with partitions (e.g TPC-DS) which slows down the overall runtime of the job. It would be good to consider having dummyCallSite in HiveTableScan. Author: Rajesh Balamohan Closes #10825 from rajeshbalamohan/SPARK-12898. --- .../spark/sql/hive/execution/HiveTableScan.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 1588728bdbaa4..eff8833e9232e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.{BooleanType, DataType} +import org.apache.spark.util.Utils /** * The Hive table scan operator. Column and partition pruning are both handled. @@ -133,11 +134,17 @@ case class HiveTableScan( } protected override def doExecute(): RDD[InternalRow] = { + // Using dummyCallSite, as getCallSite can turn out to be expensive with + // with multiple partitions. val rdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) + Utils.withDummyCallSite(sqlContext.sparkContext) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + Utils.withDummyCallSite(sqlContext.sparkContext) { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + } } rdd.mapPartitionsInternal { iter => val proj = UnsafeProjection.create(schema) From e3727c409fe7d1fb6e27a14faddd0602f963745e Mon Sep 17 00:00:00 2001 From: Takahashi Hiroshi Date: Wed, 20 Jan 2016 11:44:04 -0800 Subject: [PATCH 1642/2089] [SPARK-10263][ML] Add @Since annotation to ml.param and ml.* Add Since annotations to ml.param and ml.* Author: Takahashi Hiroshi Author: Hiroshi Takahashi Closes #8935 from taishi-oss/issue10263. --- .../scala/org/apache/spark/ml/Pipeline.scala | 21 ++++++++++++--- .../org/apache/spark/ml/param/params.scala | 26 +++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 32570a16e6707..cbac7bbf49fc4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -85,25 +85,32 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ +@Since("1.2.0") @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { +class Pipeline @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("pipeline")) /** * param for pipeline stages * @group param */ + @Since("1.2.0") val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") /** @group setParam */ + @Since("1.2.0") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. /** @group getParam */ + @Since("1.2.0") def getStages: Array[PipelineStage] = $(stages).clone() + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() $(stages).foreach(_.validateParams()) @@ -121,6 +128,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M * @param dataset input dataset * @return fitted pipeline */ + @Since("1.2.0") override def fit(dataset: DataFrame): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) @@ -158,12 +166,14 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M new PipelineModel(uid, transformers.toArray).setParent(this) } + @Since("1.4.0") override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) new Pipeline().setStages(newStages) } + @Since("1.2.0") override def transformSchema(schema: StructType): StructType = { validateParams() val theStages = $(stages) @@ -275,10 +285,11 @@ object Pipeline extends MLReadable[Pipeline] { * :: Experimental :: * Represents a fitted pipeline. */ +@Since("1.2.0") @Experimental class PipelineModel private[ml] ( - override val uid: String, - val stages: Array[Transformer]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") val stages: Array[Transformer]) extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. */ @@ -286,21 +297,25 @@ class PipelineModel private[ml] ( this(uid, stages.asScala.toArray) } + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) } + @Since("1.2.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) } + @Since("1.2.0") override def transformSchema(schema: StructType): StructType = { validateParams() stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } + @Since("1.4.0") override def copy(extra: ParamMap): PipelineModel = { new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c0546695e487b..f48923d69974b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -504,8 +504,11 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In * :: Experimental :: * A param and its value. */ +@Since("1.2.0") @Experimental -case class ParamPair[T](param: Param[T], value: T) { +case class ParamPair[T] @Since("1.2.0") ( + @Since("1.2.0") param: Param[T], + @Since("1.2.0") value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. param.validate(value) @@ -786,6 +789,7 @@ abstract class JavaParams extends Params * :: Experimental :: * A param to value map. */ +@Since("1.2.0") @Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -799,17 +803,20 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Creates an empty param map. */ + @Since("1.2.0") def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). */ + @Since("1.2.0") def put[T](param: Param[T], value: T): this.type = put(param -> value) /** * Puts a list of param pairs (overwrites if the input params exists). */ @varargs + @Since("1.2.0") def put(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => map(p.param.asInstanceOf[Param[Any]]) = p.value @@ -820,6 +827,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Optionally returns the value associated with a param. */ + @Since("1.2.0") def get[T](param: Param[T]): Option[T] = { map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } @@ -827,6 +835,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Returns the value associated with a param or a default value. */ + @Since("1.4.0") def getOrElse[T](param: Param[T], default: T): T = { get(param).getOrElse(default) } @@ -835,6 +844,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Gets the value of the input param or its default value if it does not exist. * Raises a NoSuchElementException if there is no value associated with the input param. */ + @Since("1.2.0") def apply[T](param: Param[T]): T = { get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") @@ -844,6 +854,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Checks whether a parameter is explicitly specified. */ + @Since("1.2.0") def contains(param: Param[_]): Boolean = { map.contains(param.asInstanceOf[Param[Any]]) } @@ -851,6 +862,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Removes a key from this map and returns its value associated previously as an option. */ + @Since("1.4.0") def remove[T](param: Param[T]): Option[T] = { map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } @@ -858,6 +870,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Filters this param map for the given parent. */ + @Since("1.2.0") def filter(parent: Params): ParamMap = { // Don't use filterKeys because mutable.Map#filterKeys // returns the instance of collections.Map, not mutable.Map. @@ -870,8 +883,10 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Creates a copy of this param map. */ + @Since("1.2.0") def copy: ParamMap = new ParamMap(map.clone()) + @Since("1.2.0") override def toString: String = { map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent}-${param.name}: $value" @@ -882,6 +897,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Returns a new param map that contains parameters in this map and the given map, * where the latter overwrites this if there exist conflicts. */ + @Since("1.2.0") def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. new ParamMap(this.map ++ other.map) @@ -890,6 +906,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Adds all parameters from the input param map into this param map. */ + @Since("1.2.0") def ++=(other: ParamMap): this.type = { // TODO: Provide a better method name for Java users. this.map ++= other.map @@ -899,6 +916,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Converts this param map to a sequence of param pairs. */ + @Since("1.2.0") def toSeq: Seq[ParamPair[_]] = { map.toSeq.map { case (param, value) => ParamPair(param, value) @@ -908,21 +926,25 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Number of param pairs in this map. */ + @Since("1.3.0") def size: Int = map.size } +@Since("1.2.0") @Experimental object ParamMap { /** * Returns an empty param map. */ + @Since("1.2.0") def empty: ParamMap = new ParamMap() /** * Constructs a param map by specifying its entries. */ @varargs + @Since("1.2.0") def apply(paramPairs: ParamPair[_]*): ParamMap = { new ParamMap().put(paramPairs: _*) } From 944fdadf77523570f6b33544ad0b388031498952 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 20 Jan 2016 11:57:53 -0800 Subject: [PATCH 1643/2089] [SPARK-12847][CORE][STREAMING] Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events Including the following changes: 1. Add StreamingListenerForwardingBus to WrappedStreamingListenerEvent process events in `onOtherEvent` to StreamingListener 2. Remove StreamingListenerBus 3. Merge AsynchronousListenerBus and LiveListenerBus to the same class LiveListenerBus 4. Add `logEvent` method to SparkListenerEvent so that EventLoggingListener can use it to ignore WrappedStreamingListenerEvents Author: Shixiong Zhu Closes #10779 from zsxwing/streaming-listener. --- .../scala/org/apache/spark/SparkContext.scala | 6 +- .../scheduler/EventLoggingListener.scala | 4 +- .../spark/scheduler/LiveListenerBus.scala | 167 ++++++++++++++- .../spark/scheduler/SparkListener.scala | 5 +- .../spark/scheduler/SparkListenerBus.scala | 2 +- .../spark/util/AsynchronousListenerBus.scala | 190 ------------------ .../org/apache/spark/util/ListenerBus.scala | 14 +- project/MimaExcludes.scala | 4 + .../spark/streaming/StreamingContext.scala | 9 +- .../streaming/scheduler/JobScheduler.scala | 4 +- .../scheduler/StreamingListenerBus.scala | 69 +++++-- .../spark/streaming/InputStreamsSuite.scala | 2 +- .../streaming/StreamingListenerSuite.scala | 22 ++ .../scheduler/ReceiverTrackerSuite.scala | 2 - 14 files changed, 269 insertions(+), 231 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 77acb7052ddf5..d7c605a583c4f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1644,9 +1644,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Shut down the SparkContext. def stop() { - if (AsynchronousListenerBus.withinListenerThread.value) { - throw new SparkException("Cannot stop SparkContext within listener thread of" + - " AsynchronousListenerBus") + if (LiveListenerBus.withinListenerThread.value) { + throw new SparkException( + s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index aa607c5a2df93..36f2b74f948f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -200,7 +200,9 @@ private[spark] class EventLoggingListener( override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } override def onOtherEvent(event: SparkListenerEvent): Unit = { - logEvent(event, flushLogger = true) + if (event.logEvent) { + logEvent(event, flushLogger = true) + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index be23056e7d423..1c21313d1cb17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,24 +17,169 @@ package org.apache.spark.scheduler +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.util.AsynchronousListenerBus +import scala.util.DynamicVariable + +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. * - * Until start() is called, all posted events are only buffered. Only after this listener bus + * Until `start()` is called, all posted events are only buffered. Only after this listener bus * has started will events be actually propagated to all attached listeners. This listener bus - * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). + * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus - extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus") - with SparkListenerBus { +private[spark] class LiveListenerBus extends SparkListenerBus { + + self => + + import LiveListenerBus._ + + private var sparkContext: SparkContext = null + + // Cap the capacity of the event queue so we get an explicit error (rather than + // an OOM exception) if it's perpetually being added to more quickly than it's being drained. + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + + // Indicate if `start()` is called + private val started = new AtomicBoolean(false) + // Indicate if `stop()` is called + private val stopped = new AtomicBoolean(false) + + // Indicate if we are processing some event + // Guarded by `self` + private var processingEvent = false private val logDroppedEvent = new AtomicBoolean(false) - override def onDropEvent(event: SparkListenerEvent): Unit = { + // A counter that represents the number of events produced and consumed in the queue + private val eventLock = new Semaphore(0) + + private val listenerThread = new Thread(name) { + setDaemon(true) + override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { + LiveListenerBus.withinListenerThread.withValue(true) { + while (true) { + eventLock.acquire() + self.synchronized { + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } + } + } + } + } + } + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + * + * @param sc Used to stop the SparkContext in case the listener thread dies. + */ + def start(sc: SparkContext): Unit = { + if (started.compareAndSet(false, true)) { + sparkContext = sc + listenerThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + def post(event: SparkListenerEvent): Unit = { + if (stopped.get) { + // Drop further events to make `listenerThread` exit ASAP + logError(s"$name has already stopped! Dropping event $event") + return + } + val eventAdded = eventQueue.offer(event) + if (eventAdded) { + eventLock.release() + } else { + onDropEvent(event) + } + } + + /** + * For testing only. Wait until there are no more events in the queue, or until the specified + * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue + * emptied. + * Exposed for testing. + */ + @throws(classOf[TimeoutException]) + def waitUntilEmpty(timeoutMillis: Long): Unit = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!queueIsEmpty) { + if (System.currentTimeMillis > finishTime) { + throw new TimeoutException( + s"The event queue is not empty after $timeoutMillis milliseconds") + } + /* Sleep rather than using wait/notify, because this is used only for testing and + * wait/notify add overhead in the general case. */ + Thread.sleep(10) + } + } + + /** + * For testing only. Return whether the listener daemon thread is still alive. + * Exposed for testing. + */ + def listenerThreadIsAlive: Boolean = listenerThread.isAlive + + /** + * Return whether the event queue is empty. + * + * The use of synchronized here guarantees that all events that once belonged to this queue + * have already been processed by all attached listeners, if this returns true. + */ + private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but drop the + * new events after stopping. + */ + def stop(): Unit = { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know + // `stop` is called. + eventLock.release() + listenerThread.join() + } else { + // Keep quiet + } + } + + /** + * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be + * notified with the dropped events. + * + * Note: `onDropEvent` can be called in any thread. + */ + def onDropEvent(event: SparkListenerEvent): Unit = { if (logDroppedEvent.compareAndSet(false, true)) { // Only log the following message once to avoid duplicated annoying logs. logError("Dropping SparkListenerEvent because no remaining room in event queue. " + @@ -42,5 +187,13 @@ private[spark] class LiveListenerBus "the rate at which tasks are being started by the scheduler.") } } +} + +private[spark] object LiveListenerBus { + // Allows for Context to check whether stop() call is made within listener thread + val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) + /** The thread name of Spark listener bus */ + val name = "SparkListenerBus" } + diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index f5267f58c2e42..6c6883d703bea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -34,7 +34,10 @@ import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") -trait SparkListenerEvent +trait SparkListenerEvent { + /* Whether output this event to the event log */ + protected[spark] def logEvent: Boolean = true +} @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 95722a07144ec..94f0574f0e165 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.ListenerBus */ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] { - override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { + protected override def doPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => listener.onStageSubmitted(stageSubmitted) diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala deleted file mode 100644 index f6b7ea2f37869..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean - -import scala.util.DynamicVariable - -import org.apache.spark.SparkContext - -/** - * Asynchronously passes events to registered listeners. - * - * Until `start()` is called, all posted events are only buffered. Only after this listener bus - * has started will events be actually propagated to all attached listeners. This listener bus - * is stopped when `stop()` is called, and it will drop further events after stopping. - * - * @param name name of the listener bus, will be the name of the listener thread. - * @tparam L type of listener - * @tparam E type of event - */ -private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String) - extends ListenerBus[L, E] { - - self => - - private var sparkContext: SparkContext = null - - /* Cap the capacity of the event queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[E](EVENT_QUEUE_CAPACITY) - - // Indicate if `start()` is called - private val started = new AtomicBoolean(false) - // Indicate if `stop()` is called - private val stopped = new AtomicBoolean(false) - - // Indicate if we are processing some event - // Guarded by `self` - private var processingEvent = false - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread(name) { - setDaemon(true) - override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - AsynchronousListenerBus.withinListenerThread.withValue(true) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { - self.synchronized { - processingEvent = false - } - } - } - } - } - } - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - * - * @param sc Used to stop the SparkContext in case the listener thread dies. - */ - def start(sc: SparkContext) { - if (started.compareAndSet(false, true)) { - sparkContext = sc - listenerThread.start() - } else { - throw new IllegalStateException(s"$name already started!") - } - } - - def post(event: E) { - if (stopped.get) { - // Drop further events to make `listenerThread` exit ASAP - logError(s"$name has already stopped! Dropping event $event") - return - } - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - onDropEvent(event) - } - } - - /** - * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue - * emptied. - * Exposed for testing. - */ - @throws(classOf[TimeoutException]) - def waitUntilEmpty(timeoutMillis: Long): Unit = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - throw new TimeoutException( - s"The event queue is not empty after $timeoutMillis milliseconds") - } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) - } - } - - /** - * For testing only. Return whether the listener daemon thread is still alive. - * Exposed for testing. - */ - def listenerThreadIsAlive: Boolean = listenerThread.isAlive - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } - - /** - * Stop the listener bus. It will wait until the queued events have been processed, but drop the - * new events after stopping. - */ - def stop() { - if (!started.get()) { - throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") - } - if (stopped.compareAndSet(false, true)) { - // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know - // `stop` is called. - eventLock.release() - listenerThread.join() - } else { - // Keep quiet - } - } - - /** - * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be - * notified with the dropped events. - * - * Note: `onDropEvent` can be called in any thread. - */ - def onDropEvent(event: E): Unit -} - -private[spark] object AsynchronousListenerBus { - /* Allows for Context to check whether stop() call is made within listener thread - */ - val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) -} - diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 13cb516b583e9..5e1fab009c7d7 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -36,10 +36,18 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. */ - final def addListener(listener: L) { + final def addListener(listener: L): Unit = { listeners.add(listener) } + /** + * Remove a listener and it won't receive any events. This method is thread-safe and can be called + * in any thread. + */ + final def removeListener(listener: L): Unit = { + listeners.remove(listener) + } + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -52,7 +60,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { while (iter.hasNext) { val listener = iter.next() try { - onPostEvent(listener, event) + doPostEvent(listener, event) } catch { case NonFatal(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) @@ -64,7 +72,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same * thread. */ - def onPostEvent(listener: L, event: E): Unit + protected def doPostEvent(listener: L, event: E): Unit private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4430bfd3b0380..6469201446f0c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -153,6 +153,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") ) case v if v.startsWith("1.6") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 157ee92fd71b3..b7070dda99dc6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -37,6 +37,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.{RDD, RDDOperationScope} +import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ @@ -44,7 +45,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiverSupervisor, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -694,9 +695,9 @@ class StreamingContext private[streaming] ( */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null - if (AsynchronousListenerBus.withinListenerThread.value) { - throw new SparkException("Cannot stop StreamingContext within listener thread of" + - " AsynchronousListenerBus") + if (LiveListenerBus.withinListenerThread.value) { + throw new SparkException( + s"Cannot stop StreamingContext within listener thread of ${LiveListenerBus.name}") } synchronized { // The state should always be Stopped after calling `stop()`, even if we haven't started yet diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1ed6fb0aa9d52..9535c8e5b768a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -49,7 +49,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock - val listenerBus = new StreamingListenerBus() + val listenerBus = new StreamingListenerBus(ssc.sparkContext.listenerBus) // These two are created only when scheduler starts. // eventLoop not being null means the scheduler has been started and not stopped @@ -76,7 +76,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { rateController <- inputDStream.rateController } ssc.addStreamingListener(rateController) - listenerBus.start(ssc.sparkContext) + listenerBus.start() receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) receiverTracker.start() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index ca111bb636ed5..39f6e711a67ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -17,19 +17,37 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} +import org.apache.spark.util.ListenerBus -import org.apache.spark.Logging -import org.apache.spark.util.AsynchronousListenerBus +/** + * A Streaming listener bus to forward events to StreamingListeners. This one will wrap received + * Streaming events as WrappedStreamingListenerEvent and send them to Spark listener bus. It also + * registers itself with Spark listener bus, so that it can receive WrappedStreamingListenerEvents, + * unwrap them as StreamingListenerEvent and dispatch them to StreamingListeners. + */ +private[streaming] class StreamingListenerBus(sparkListenerBus: LiveListenerBus) + extends SparkListener with ListenerBus[StreamingListener, StreamingListenerEvent] { -/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ -private[spark] class StreamingListenerBus - extends AsynchronousListenerBus[StreamingListener, StreamingListenerEvent]("StreamingListenerBus") - with Logging { + /** + * Post a StreamingListenerEvent to the Spark listener bus asynchronously. This event will be + * dispatched to all StreamingListeners in the thread of the Spark listener bus. + */ + def post(event: StreamingListenerEvent) { + sparkListenerBus.post(new WrappedStreamingListenerEvent(event)) + } - private val logDroppedEvent = new AtomicBoolean(false) + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case WrappedStreamingListenerEvent(e) => + postToAll(e) + case _ => + } + } - override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = { + protected override def doPostEvent( + listener: StreamingListener, + event: StreamingListenerEvent): Unit = { event match { case receiverStarted: StreamingListenerReceiverStarted => listener.onReceiverStarted(receiverStarted) @@ -51,12 +69,31 @@ private[spark] class StreamingListenerBus } } - override def onDropEvent(event: StreamingListenerEvent): Unit = { - if (logDroppedEvent.compareAndSet(false, true)) { - // Only log the following message once to avoid duplicated annoying logs. - logError("Dropping StreamingListenerEvent because no remaining room in event queue. " + - "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + - "rate at which events are being started by the scheduler.") - } + /** + * Register this one with the Spark listener bus so that it can receive Streaming events and + * forward them to StreamingListeners. + */ + def start(): Unit = { + sparkListenerBus.addListener(this) // for getting callbacks on spark events + } + + /** + * Unregister this one with the Spark listener bus and all StreamingListeners won't receive any + * events after that. + */ + def stop(): Unit = { + sparkListenerBus.removeListener(this) + } + + /** + * Wrapper for StreamingListenerEvent as SparkListenerEvent so that it can be posted to Spark + * listener bus. + */ + private case class WrappedStreamingListenerEvent(streamingListenerEvent: StreamingListenerEvent) + extends SparkListenerEvent { + + // Do not log streaming events in event log as history server does not support streaming + // events (SPARK-12140). TODO Once SPARK-12140 is resolved we should set it to true. + protected[spark] override def logEvent: Boolean = false } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 2e231601c3953..75591f04ca00d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -77,7 +77,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // Ensure progress listener has been notified of all events - ssc.scheduler.listenerBus.waitUntilEmpty(500) + ssc.sparkContext.listenerBus.waitUntilEmpty(500) // Verify all "InputInfo"s have been reported assert(ssc.progressListener.numTotalReceivedRecords === input.size) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 628a5082074db..1ed68c74db9fd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future +import org.mockito.Mockito.{mock, reset, verifyNoMoreInteractions} import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -216,6 +217,27 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { assert(failureReasons(1).contains("This is another failed job")) } + test("StreamingListener receives no events after stopping StreamingListenerBus") { + val streamingListener = mock(classOf[StreamingListener]) + + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc.addStreamingListener(streamingListener) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + ssc.start() + ssc.stop() + + // Because "streamingListener" has already received some events, let's clear that. + reset(streamingListener) + + // Post a Streaming event after stopping StreamingContext + val receiverInfoStopped = ReceiverInfo(0, "test", false, "localhost", "0") + ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfoStopped)) + ssc.sparkContext.listenerBus.waitUntilEmpty(1000) + // The StreamingListener should not receive any event + verifyNoMoreInteractions(streamingListener) + } + private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = { val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc) _ssc.addStreamingListener(contextStoppingCollector) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index b67189fbd7f03..cfd7f86f84411 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -34,8 +34,6 @@ class ReceiverTrackerSuite extends TestSuiteBase { test("send rate update to receivers") { withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => - ssc.scheduler.listenerBus.start(ssc.sc) - val newRateLimit = 100L val inputDStream = new RateTestInputDStream(ssc) val tracker = new ReceiverTracker(ssc) From b7d74a602f622d8e105b349bd6d17ba42e7668dc Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 20 Jan 2016 13:55:41 -0800 Subject: [PATCH 1644/2089] [SPARK-7799][SPARK-12786][STREAMING] Add "streaming-akka" project Include the following changes: 1. Add "streaming-akka" project and org.apache.spark.streaming.akka.AkkaUtils for creating an actorStream 2. Remove "StreamingContext.actorStream" and "JavaStreamingContext.actorStream" 3. Update the ActorWordCount example and add the JavaActorWordCount example 4. Make "streaming-zeromq" depend on "streaming-akka" and update the codes accordingly Author: Shixiong Zhu Closes #10744 from zsxwing/streaming-akka-2. --- dev/sparktestsupport/modules.py | 12 ++ docs/streaming-custom-receivers.md | 49 ++++-- docs/streaming-programming-guide.md | 4 +- examples/pom.xml | 5 + .../streaming/JavaActorWordCount.java | 14 +- .../examples/streaming/ActorWordCount.scala | 37 +++-- .../examples/streaming/ZeroMQWordCount.scala | 13 +- external/akka/pom.xml | 73 +++++++++ .../spark/streaming/akka}/ActorReceiver.scala | 64 +++++--- .../spark/streaming/akka/AkkaUtils.scala | 147 ++++++++++++++++++ .../streaming/akka/JavaAkkaUtilsSuite.java | 66 ++++++++ .../spark/streaming/akka/AkkaUtilsSuite.scala | 64 ++++++++ external/zeromq/pom.xml | 5 + .../streaming/zeromq/ZeroMQReceiver.scala | 2 +- .../spark/streaming/zeromq/ZeroMQUtils.scala | 76 ++++++--- .../zeromq/JavaZeroMQStreamSuite.java | 31 ++-- .../streaming/zeromq/ZeroMQStreamSuite.scala | 16 +- pom.xml | 1 + project/MimaExcludes.scala | 10 ++ project/SparkBuild.scala | 9 +- .../spark/streaming/StreamingContext.scala | 24 +-- .../api/java/JavaStreamingContext.scala | 64 -------- 22 files changed, 601 insertions(+), 185 deletions(-) create mode 100644 external/akka/pom.xml rename {streaming/src/main/scala/org/apache/spark/streaming/receiver => external/akka/src/main/scala/org/apache/spark/streaming/akka}/ActorReceiver.scala (74%) create mode 100644 external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala create mode 100644 external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java create mode 100644 external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 93a8c15e3ec30..efe58ea2e0e78 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -222,6 +222,18 @@ def contains_file(self, filename): ) +streaming_akka = Module( + name="streaming-akka", + dependencies=[streaming], + source_file_regexes=[ + "external/akka", + ], + sbt_test_goals=[ + "streaming-akka/test", + ] +) + + streaming_flume = Module( name="streaming-flume", dependencies=[streaming], diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 97db865daa371..95b99862ec062 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -257,25 +257,54 @@ The following table summarizes the characteristics of both types of receivers ## Implementing and Using a Custom Actor-based Receiver +
    +
    + Custom [Akka Actors](http://doc.akka.io/docs/akka/2.3.11/scala/actors.html) can also be used to -receive data. The [`ActorHelper`](api/scala/index.html#org.apache.spark.streaming.receiver.ActorHelper) -trait can be mixed in to any Akka actor, which allows received data to be stored in Spark using - `store(...)` methods. The supervisor strategy of this actor can be configured to handle failures, etc. +receive data. Extending [`ActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.ActorReceiver) +allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of +this actor can be configured to handle failures, etc. {% highlight scala %} -class CustomActor extends Actor with ActorHelper { + +class CustomActor extends ActorReceiver { def receive = { case data: String => store(data) } } + +// A new input stream can be created with this custom actor as +val ssc: StreamingContext = ... +val lines = AkkaUtils.createStream[String](ssc, Props[CustomActor](), "CustomReceiver") + {% endhighlight %} -And a new input stream can be created with this custom actor as +See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. +
    +
    + +Custom [Akka UntypedActors](http://doc.akka.io/docs/akka/2.3.11/java/untyped-actors.html) can also be used to +receive data. Extending [`JavaActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.JavaActorReceiver) +allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of +this actor can be configured to handle failures, etc. + +{% highlight java %} + +class CustomActor extends JavaActorReceiver { + @Override + public void onReceive(Object msg) throws Exception { + store((String) msg); + } +} + +// A new input stream can be created with this custom actor as +JavaStreamingContext jssc = ...; +JavaDStream lines = AkkaUtils.createStream(jssc, Props.create(CustomActor.class), "CustomReceiver"); -{% highlight scala %} -val ssc: StreamingContext = ... -val lines = ssc.actorStream[String](Props[CustomActor], "CustomReceiver") {% endhighlight %} -See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) -for an end-to-end example. +See [JavaActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaActorWordCount.scala) for an end-to-end example. +
    +
    + +Python API Since actors are available only in the Java and Scala libraries, AkkaUtils is not available in the Python API. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8fd075d02b78e..93c34efb6662d 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -659,11 +659,11 @@ methods for creating DStreams from files and Akka actors as input sources. Python API `fileStream` is not available in the Python API, only `textFileStream` is available. - **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka - actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver + actors by using `AkkaUtils.createStream(ssc, actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html) for more details. Python API Since actors are available only in the Java and Scala - libraries, `actorStream` is not available in the Python API. + libraries, `AkkaUtils.createStream` is not available in the Python API. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. diff --git a/examples/pom.xml b/examples/pom.xml index 1a0d5e5854642..9437cee2abfdf 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -75,6 +75,11 @@ spark-streaming-flume_${scala.binary.version} ${project.version} + + org.apache.spark + spark-streaming-akka_${scala.binary.version} + ${project.version} + org.apache.spark spark-streaming-mqtt_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java index 2377207779fec..62e563380a9e7 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java @@ -31,7 +31,8 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.apache.spark.streaming.receiver.JavaActorReceiver; +import org.apache.spark.streaming.akka.AkkaUtils; +import org.apache.spark.streaming.akka.JavaActorReceiver; /** * A sample actor as receiver, is also simplest. This receiver actor @@ -56,6 +57,7 @@ public void preStart() { remotePublisher.tell(new SubscribeReceiver(getSelf()), getSelf()); } + @Override public void onReceive(Object msg) throws Exception { store((T) msg); } @@ -100,18 +102,20 @@ public static void main(String[] args) { String feederActorURI = "akka.tcp://test@" + host + ":" + port + "/user/FeederActor"; /* - * Following is the use of actorStream to plug in custom actor as receiver + * Following is the use of AkkaUtils.createStream to plug in custom actor as receiver * * An important point to note: * Since Actor may exist outside the spark framework, It is thus user's responsibility * to ensure the type safety, i.e type of data received and InputDstream * should be same. * - * For example: Both actorStream and JavaSampleActorReceiver are parameterized + * For example: Both AkkaUtils.createStream and JavaSampleActorReceiver are parameterized * to same type to ensure type safety. */ - JavaDStream lines = jssc.actorStream( - Props.create(JavaSampleActorReceiver.class, feederActorURI), "SampleReceiver"); + JavaDStream lines = AkkaUtils.createStream( + jssc, + Props.create(JavaSampleActorReceiver.class, feederActorURI), + "SampleReceiver"); // compute wordcount lines.flatMap(new FlatMapFunction() { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 88cdc6bc144e5..8e88987439ffc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -22,12 +22,12 @@ import scala.collection.mutable.LinkedList import scala.reflect.ClassTag import scala.util.Random -import akka.actor.{actorRef2Scala, Actor, ActorRef, Props} +import akka.actor._ +import com.typesafe.config.ConfigFactory -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.receiver.ActorReceiver -import org.apache.spark.util.AkkaUtils +import org.apache.spark.streaming.akka.{ActorReceiver, AkkaUtils} case class SubscribeReceiver(receiverActor: ActorRef) case class UnsubscribeReceiver(receiverActor: ActorRef) @@ -78,8 +78,7 @@ class FeederActor extends Actor { * * @see [[org.apache.spark.examples.streaming.FeederActor]] */ -class SampleActorReceiver[T: ClassTag](urlOfPublisher: String) -extends ActorReceiver { +class SampleActorReceiver[T](urlOfPublisher: String) extends ActorReceiver { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) @@ -108,9 +107,13 @@ object FeederActor { } val Seq(host, port) = args.toSeq - val conf = new SparkConf - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = conf, - securityManager = new SecurityManager(conf))._1 + val akkaConf = ConfigFactory.parseString( + s"""akka.actor.provider = "akka.remote.RemoteActorRefProvider" + |akka.remote.enabled-transports = ["akka.remote.netty.tcp"] + |akka.remote.netty.tcp.hostname = "$host" + |akka.remote.netty.tcp.port = $port + |""".stripMargin) + val actorSystem = ActorSystem("test", akkaConf) val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") println("Feeder started as:" + feeder) @@ -121,6 +124,7 @@ object FeederActor { /** * A sample word count program demonstrating the use of plugging in + * * Actor as Receiver * Usage: ActorWordCount * and describe the AkkaSystem that Spark Sample feeder is running on. @@ -146,20 +150,21 @@ object ActorWordCount { val ssc = new StreamingContext(sparkConf, Seconds(2)) /* - * Following is the use of actorStream to plug in custom actor as receiver + * Following is the use of AkkaUtils.createStream to plug in custom actor as receiver * * An important point to note: * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e type of data received and InputDstream + * to ensure the type safety, i.e type of data received and InputDStream * should be same. * - * For example: Both actorStream and SampleActorReceiver are parameterized + * For example: Both AkkaUtils.createStream and SampleActorReceiver are parameterized * to same type to ensure type safety. */ - - val lines = ssc.actorStream[String]( - Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format( - host, port.toInt))), "SampleReceiver") + val lines = AkkaUtils.createStream[String]( + ssc, + Props(classOf[SampleActorReceiver[String]], + "akka.tcp://test@%s:%s/user/FeederActor".format(host, port.toInt)), + "SampleReceiver") // compute wordcount lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 96448905760fb..f612e508eb78e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -25,8 +25,9 @@ import akka.actor.actorRef2Scala import akka.util.ByteString import akka.zeromq._ import akka.zeromq.Subscribe +import com.typesafe.config.ConfigFactory -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.zeromq._ @@ -69,10 +70,10 @@ object SimpleZeroMQPublisher { * * To run this example locally, you may run publisher as * `$ bin/run-example \ - * org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` + * org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.0.1:1234 foo` * and run the example as * `$ bin/run-example \ - * org.apache.spark.examples.streaming.ZeroMQWordCount tcp://127.0.1.1:1234 foo` + * org.apache.spark.examples.streaming.ZeroMQWordCount tcp://127.0.0.1:1234 foo` */ // scalastyle:on object ZeroMQWordCount { @@ -90,7 +91,11 @@ object ZeroMQWordCount { def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. - val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) + val lines = ZeroMQUtils.createStream( + ssc, + url, + Subscribe(topic), + bytesToStringIterator _) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/external/akka/pom.xml b/external/akka/pom.xml new file mode 100644 index 0000000000000..34de9bae00e49 --- /dev/null +++ b/external/akka/pom.xml @@ -0,0 +1,73 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-akka_2.10 + + streaming-akka + + jar + Spark Project External Akka + http://spark.apache.org/ + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + ${akka.group} + akka-actor_${scala.binary.version} + ${akka.version} + + + ${akka.group} + akka-remote_${scala.binary.version} + ${akka.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala similarity index 74% rename from streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala rename to external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala index 0eabf3d260b26..c75dc92445b64 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.streaming.receiver +package org.apache.spark.streaming.akka import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger @@ -26,23 +26,44 @@ import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} +import com.typesafe.config.ConfigFactory -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver.Receiver /** * :: DeveloperApi :: * A helper with set of defaults for supervisor strategy */ @DeveloperApi -object ActorSupervisorStrategy { +object ActorReceiver { - val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = + /** + * A OneForOneStrategy supervisor strategy with `maxNrOfRetries = 10` and + * `withinTimeRange = 15 millis`. For RuntimeException, it will restart the ActorReceiver; for + * others, it just escalates the failure to the supervisor of the supervisor. + */ + val defaultSupervisorStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = 15 millis) { case _: RuntimeException => Restart case _: Exception => Escalate } + + /** + * A default ActorSystem creator. It will use a unique system name + * (streaming-actor-system-) to start an ActorSystem that supports remote + * communication. + */ + val defaultActorSystemCreator: () => ActorSystem = () => { + val uniqueSystemName = s"streaming-actor-system-${TaskContext.get().taskAttemptId()}" + val akkaConf = ConfigFactory.parseString( + s"""akka.actor.provider = "akka.remote.RemoteActorRefProvider" + |akka.remote.enabled-transports = ["akka.remote.netty.tcp"] + |""".stripMargin) + ActorSystem(uniqueSystemName, akkaConf) + } } /** @@ -58,13 +79,12 @@ object ActorSupervisorStrategy { * } * } * - * // Can be used with an actorStream as follows - * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") + * AkkaUtils.createStream[String](ssc, Props[MyActor](),"MyActorReceiver") * * }}} * * @note Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of push block and InputDStream + * to ensure the type safety, i.e. parametrized type of push block and InputDStream * should be same. */ @DeveloperApi @@ -103,18 +123,18 @@ abstract class ActorReceiver extends Actor { * * @example {{{ * class MyActor extends JavaActorReceiver { - * def receive { - * case anything: String => store(anything) + * @Override + * public void onReceive(Object msg) throws Exception { + * store((String) msg); * } * } * - * // Can be used with an actorStream as follows - * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") + * AkkaUtils.createStream(jssc, Props.create(MyActor.class), "MyActorReceiver"); * * }}} * * @note Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of push block and InputDStream + * to ensure the type safety, i.e. parametrized type of push block and InputDStream * should be same. */ @DeveloperApi @@ -147,8 +167,8 @@ abstract class JavaActorReceiver extends UntypedActor { /** * :: DeveloperApi :: * Statistics for querying the supervisor about state of workers. Used in - * conjunction with `StreamingContext.actorStream` and - * [[org.apache.spark.streaming.receiver.ActorReceiver]]. + * conjunction with `AkkaUtils.createStream` and + * [[org.apache.spark.streaming.akka.ActorReceiverSupervisor]]. */ @DeveloperApi case class Statistics(numberOfMsgs: Int, @@ -157,10 +177,10 @@ case class Statistics(numberOfMsgs: Int, otherInfo: String) /** Case class to receive data sent by child actors */ -private[streaming] sealed trait ActorReceiverData -private[streaming] case class SingleItemData[T](item: T) extends ActorReceiverData -private[streaming] case class IteratorData[T](iterator: Iterator[T]) extends ActorReceiverData -private[streaming] case class ByteBufferData(bytes: ByteBuffer) extends ActorReceiverData +private[akka] sealed trait ActorReceiverData +private[akka] case class SingleItemData[T](item: T) extends ActorReceiverData +private[akka] case class IteratorData[T](iterator: Iterator[T]) extends ActorReceiverData +private[akka] case class ByteBufferData(bytes: ByteBuffer) extends ActorReceiverData /** * Provides Actors as receivers for receiving stream. @@ -181,14 +201,16 @@ private[streaming] case class ByteBufferData(bytes: ByteBuffer) extends ActorRec * context.parent ! Props(new Worker, "Worker") * }}} */ -private[streaming] class ActorReceiverSupervisor[T: ClassTag]( +private[akka] class ActorReceiverSupervisor[T: ClassTag]( + actorSystemCreator: () => ActorSystem, props: Props, name: String, storageLevel: StorageLevel, receiverSupervisorStrategy: SupervisorStrategy ) extends Receiver[T](storageLevel) with Logging { - protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + private lazy val actorSystem = actorSystemCreator() + protected lazy val actorSupervisor = actorSystem.actorOf(Props(new Supervisor), "Supervisor" + streamId) class Supervisor extends Actor { @@ -241,5 +263,7 @@ private[streaming] class ActorReceiverSupervisor[T: ClassTag]( def onStop(): Unit = { actorSupervisor ! PoisonPill + actorSystem.shutdown() + actorSystem.awaitTermination() } } diff --git a/external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala b/external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala new file mode 100644 index 0000000000000..38c35c5ae7a18 --- /dev/null +++ b/external/akka/src/main/scala/org/apache/spark/streaming/akka/AkkaUtils.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.akka + +import scala.reflect.ClassTag + +import akka.actor.{ActorSystem, Props, SupervisorStrategy} + +import org.apache.spark.api.java.function.{Function0 => JFunction0} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream + +object AkkaUtils { + + /** + * Create an input stream with a user-defined actor. See [[ActorReceiver]] for more details. + * + * @param ssc The StreamingContext instance + * @param propsForActor Props object defining creation of the actor + * @param actorName Name of the actor + * @param storageLevel RDD storage level (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will + * be shut down when the receiver is stopping (default: + * ActorReceiver.defaultActorSystemCreator) + * @param supervisorStrategy the supervisor strategy (default: ActorReceiver.defaultStrategy) + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e. parametrized type of data received and createStream + * should be same. + */ + def createStream[T: ClassTag]( + ssc: StreamingContext, + propsForActor: Props, + actorName: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, + actorSystemCreator: () => ActorSystem = ActorReceiver.defaultActorSystemCreator, + supervisorStrategy: SupervisorStrategy = ActorReceiver.defaultSupervisorStrategy + ): ReceiverInputDStream[T] = ssc.withNamedScope("actor stream") { + val cleanF = ssc.sc.clean(actorSystemCreator) + ssc.receiverStream(new ActorReceiverSupervisor[T]( + cleanF, + propsForActor, + actorName, + storageLevel, + supervisorStrategy)) + } + + /** + * Create an input stream with a user-defined actor. See [[JavaActorReceiver]] for more details. + * + * @param jssc The StreamingContext instance + * @param propsForActor Props object defining creation of the actor + * @param actorName Name of the actor + * @param storageLevel Storage level to use for storing the received objects + * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will + * be shut down when the receiver is stopping. + * @param supervisorStrategy the supervisor strategy + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e. parametrized type of data received and createStream + * should be same. + */ + def createStream[T]( + jssc: JavaStreamingContext, + propsForActor: Props, + actorName: String, + storageLevel: StorageLevel, + actorSystemCreator: JFunction0[ActorSystem], + supervisorStrategy: SupervisorStrategy + ): JavaReceiverInputDStream[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + createStream[T]( + jssc.ssc, + propsForActor, + actorName, + storageLevel, + () => actorSystemCreator.call(), + supervisorStrategy) + } + + /** + * Create an input stream with a user-defined actor. See [[JavaActorReceiver]] for more details. + * + * @param jssc The StreamingContext instance + * @param propsForActor Props object defining creation of the actor + * @param actorName Name of the actor + * @param storageLevel Storage level to use for storing the received objects + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e. parametrized type of data received and createStream + * should be same. + */ + def createStream[T]( + jssc: JavaStreamingContext, + propsForActor: Props, + actorName: String, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + createStream[T](jssc.ssc, propsForActor, actorName, storageLevel) + } + + /** + * Create an input stream with a user-defined actor. Storage level of the data will be the default + * StorageLevel.MEMORY_AND_DISK_SER_2. See [[JavaActorReceiver]] for more details. + * + * @param jssc The StreamingContext instance + * @param propsForActor Props object defining creation of the actor + * @param actorName Name of the actor + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e. parametrized type of data received and createStream + * should be same. + */ + def createStream[T]( + jssc: JavaStreamingContext, + propsForActor: Props, + actorName: String + ): JavaReceiverInputDStream[T] = { + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] + createStream[T](jssc.ssc, propsForActor, actorName) + } +} diff --git a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java new file mode 100644 index 0000000000000..b732506767154 --- /dev/null +++ b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.akka; + +import akka.actor.ActorSystem; +import akka.actor.Props; +import akka.actor.SupervisorStrategy; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.Test; + +import org.apache.spark.api.java.function.Function0; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; + +public class JavaAkkaUtilsSuite { + + @Test // tests the API, does not actually test data receiving + public void testAkkaUtils() { + JavaStreamingContext jsc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + try { + JavaReceiverInputDStream test1 = AkkaUtils.createStream( + jsc, Props.create(JavaTestActor.class), "test"); + JavaReceiverInputDStream test2 = AkkaUtils.createStream( + jsc, Props.create(JavaTestActor.class), "test", StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = AkkaUtils.createStream( + jsc, + Props.create(JavaTestActor.class), + "test", StorageLevel.MEMORY_AND_DISK_SER_2(), + new ActorSystemCreatorForTest(), + SupervisorStrategy.defaultStrategy()); + } finally { + jsc.stop(); + } + } +} + +class ActorSystemCreatorForTest implements Function0 { + @Override + public ActorSystem call() { + return null; + } +} + + +class JavaTestActor extends JavaActorReceiver { + @Override + public void onReceive(Object message) throws Exception { + store((String) message); + } +} diff --git a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala new file mode 100644 index 0000000000000..f437585a98e4f --- /dev/null +++ b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.akka + +import akka.actor.{Props, SupervisorStrategy} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream + +class AkkaUtilsSuite extends SparkFunSuite { + + test("createStream") { + val ssc: StreamingContext = new StreamingContext("local[2]", "test", Seconds(1000)) + try { + // tests the API, does not actually test data receiving + val test1: ReceiverInputDStream[String] = AkkaUtils.createStream( + ssc, Props[TestActor](), "test") + val test2: ReceiverInputDStream[String] = AkkaUtils.createStream( + ssc, Props[TestActor](), "test", StorageLevel.MEMORY_AND_DISK_SER_2) + val test3: ReceiverInputDStream[String] = AkkaUtils.createStream( + ssc, + Props[TestActor](), + "test", + StorageLevel.MEMORY_AND_DISK_SER_2, + supervisorStrategy = SupervisorStrategy.defaultStrategy) + val test4: ReceiverInputDStream[String] = AkkaUtils.createStream( + ssc, Props[TestActor](), "test", StorageLevel.MEMORY_AND_DISK_SER_2, () => null) + val test5: ReceiverInputDStream[String] = AkkaUtils.createStream( + ssc, Props[TestActor](), "test", StorageLevel.MEMORY_AND_DISK_SER_2, () => null) + val test6: ReceiverInputDStream[String] = AkkaUtils.createStream( + ssc, + Props[TestActor](), + "test", + StorageLevel.MEMORY_AND_DISK_SER_2, + () => null, + SupervisorStrategy.defaultStrategy) + } finally { + ssc.stop() + } + } +} + +class TestActor extends ActorReceiver { + override def receive: Receive = { + case m: String => store(m) + } +} diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a725988449075..7781aaeed9e0c 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -41,6 +41,11 @@ ${project.version} provided + + org.apache.spark + spark-streaming-akka_${scala.binary.version} + ${project.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 506ba8782d3d5..dd367cd43b807 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -23,7 +23,7 @@ import akka.util.ByteString import akka.zeromq._ import org.apache.spark.Logging -import org.apache.spark.streaming.receiver.ActorReceiver +import org.apache.spark.streaming.akka.ActorReceiver /** * A receiver to subscribe to ZeroMQ stream. diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 63cd8a2721f0c..1784d6e8623ad 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -20,29 +20,33 @@ package org.apache.spark.streaming.zeromq import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import akka.actor.{Props, SupervisorStrategy} +import akka.actor.{ActorSystem, Props, SupervisorStrategy} import akka.util.ByteString import akka.zeromq.Subscribe -import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.api.java.function.{Function => JFunction, Function0 => JFunction0} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.akka.{ActorReceiver, AkkaUtils} import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.receiver.ActorSupervisorStrategy object ZeroMQUtils { /** * Create an input stream that receives messages pushed by a zeromq publisher. - * @param ssc StreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to + * @param ssc StreamingContext object + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe Topic to subscribe to * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic * and each frame has sequence of byte thus it needs the converter * (which might be deserializer of bytes) to translate from sequence * of sequence of bytes, where sequence refer to a frame * and sub sequence refer to its payload. * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_AND_DISK_SER_2. + * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will + * be shut down when the receiver is stopping (default: + * ActorReceiver.defaultActorSystemCreator) + * @param supervisorStrategy the supervisor strategy (default: ActorReceiver.defaultStrategy) */ def createStream[T: ClassTag]( ssc: StreamingContext, @@ -50,22 +54,31 @@ object ZeroMQUtils { subscribe: Subscribe, bytesToObjects: Seq[ByteString] => Iterator[T], storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, - supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy + actorSystemCreator: () => ActorSystem = ActorReceiver.defaultActorSystemCreator, + supervisorStrategy: SupervisorStrategy = ActorReceiver.defaultSupervisorStrategy ): ReceiverInputDStream[T] = { - ssc.actorStream(Props(new ZeroMQReceiver(publisherUrl, subscribe, bytesToObjects)), - "ZeroMQReceiver", storageLevel, supervisorStrategy) + AkkaUtils.createStream( + ssc, + Props(new ZeroMQReceiver(publisherUrl, subscribe, bytesToObjects)), + "ZeroMQReceiver", + storageLevel, + actorSystemCreator, + supervisorStrategy) } /** * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote ZeroMQ publisher - * @param subscribe Topic to subscribe to + * @param jssc JavaStreamingContext object + * @param publisherUrl Url of remote ZeroMQ publisher + * @param subscribe Topic to subscribe to * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each * frame has sequence of byte thus it needs the converter(which might be * deserializer of bytes) to translate from sequence of sequence of bytes, * where sequence refer to a frame and sub sequence refer to its payload. - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects + * @param actorSystemCreator A function to create ActorSystem in executors. `ActorSystem` will + * be shut down when the receiver is stopping. + * @param supervisorStrategy the supervisor strategy (default: ActorReceiver.defaultStrategy) */ def createStream[T]( jssc: JavaStreamingContext, @@ -73,25 +86,33 @@ object ZeroMQUtils { subscribe: Subscribe, bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], storageLevel: StorageLevel, + actorSystemCreator: JFunction0[ActorSystem], supervisorStrategy: SupervisorStrategy ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) + createStream[T]( + jssc.ssc, + publisherUrl, + subscribe, + fn, + storageLevel, + () => actorSystemCreator.call(), + supervisorStrategy) } /** * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to + * @param jssc JavaStreamingContext object + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe Topic to subscribe to * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each * frame has sequence of byte thus it needs the converter(which might be * deserializer of bytes) to translate from sequence of sequence of bytes, * where sequence refer to a frame and sub sequence refer to its payload. - * @param storageLevel RDD storage level. + * @param storageLevel RDD storage level. */ def createStream[T]( jssc: JavaStreamingContext, @@ -104,14 +125,19 @@ object ZeroMQUtils { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) + createStream[T]( + jssc.ssc, + publisherUrl, + subscribe, + fn, + storageLevel) } /** * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to + * @param jssc JavaStreamingContext object + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe Topic to subscribe to * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each * frame has sequence of byte thus it needs the converter(which might * be deserializer of bytes) to translate from sequence of sequence of @@ -128,6 +154,10 @@ object ZeroMQUtils { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T](jssc.ssc, publisherUrl, subscribe, fn) + createStream[T]( + jssc.ssc, + publisherUrl, + subscribe, + fn) } } diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java b/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java index 417b91eecb0ee..9ff4b41f97d50 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java @@ -17,14 +17,17 @@ package org.apache.spark.streaming.zeromq; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; +import akka.actor.ActorSystem; import akka.actor.SupervisorStrategy; import akka.util.ByteString; import akka.zeromq.Subscribe; +import org.junit.Test; + import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function0; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; public class JavaZeroMQStreamSuite extends LocalJavaStreamingContext { @@ -32,19 +35,29 @@ public class JavaZeroMQStreamSuite extends LocalJavaStreamingContext { public void testZeroMQStream() { String publishUrl = "abc"; Subscribe subscribe = new Subscribe((ByteString)null); - Function> bytesToObjects = new Function>() { - @Override - public Iterable call(byte[][] bytes) throws Exception { - return null; - } - }; + Function> bytesToObjects = new BytesToObjects(); + Function0 actorSystemCreator = new ActorSystemCreatorForTest(); JavaReceiverInputDStream test1 = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects); JavaReceiverInputDStream test2 = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2()); JavaReceiverInputDStream test3 = ZeroMQUtils.createStream( - ssc,publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2(), + ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2(), actorSystemCreator, SupervisorStrategy.defaultStrategy()); } } + +class BytesToObjects implements Function> { + @Override + public Iterable call(byte[][] bytes) throws Exception { + return null; + } +} + +class ActorSystemCreatorForTest implements Function0 { + @Override + public ActorSystem call() { + return null; + } +} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index 35d2e62c68480..bac2679cabae5 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -42,14 +42,22 @@ class ZeroMQStreamSuite extends SparkFunSuite { // tests the API, does not actually test data receiving val test1: ReceiverInputDStream[String] = - ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects) + ZeroMQUtils.createStream( + ssc, publishUrl, subscribe, bytesToObjects, actorSystemCreator = () => null) val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2) + ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2, () => null) val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects, - StorageLevel.MEMORY_AND_DISK_SER_2, SupervisorStrategy.defaultStrategy) + StorageLevel.MEMORY_AND_DISK_SER_2, () => null, SupervisorStrategy.defaultStrategy) + val test4: ReceiverInputDStream[String] = + ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects) + val test5: ReceiverInputDStream[String] = ZeroMQUtils.createStream( + ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2) + val test6: ReceiverInputDStream[String] = ZeroMQUtils.createStream( + ssc, publishUrl, subscribe, bytesToObjects, + StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy = SupervisorStrategy.defaultStrategy) - // TODO: Actually test data receiving + // TODO: Actually test data receiving. A real test needs the native ZeroMQ library ssc.stop() } } diff --git a/pom.xml b/pom.xml index fca626991324b..43f08efaae86d 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/flume external/flume-sink external/flume-assembly + external/akka external/mqtt external/mqtt-assembly external/zeromq diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6469201446f0c..905fb4cd90377 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -153,6 +153,16 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-7799 Add "streaming-akka" project + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream") ) ++ Seq( // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 06e561ae0d89b..3927b88fb0bf6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -35,11 +35,11 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", - "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", + "streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl, @@ -232,8 +232,9 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) + // TODO: remove streamingAkka from this list after 2.0.0 allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, unsafe, testTags).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -649,7 +650,7 @@ object Unidoc { "-public", "-group", "Core Java API", packageList("api.java", "api.java.function"), "-group", "Spark Streaming", packageList( - "streaming.api.java", "streaming.flume", "streaming.kafka", + "streaming.api.java", "streaming.flume", "streaming.akka", "streaming.kafka", "streaming.mqtt", "streaming.twitter", "streaming.zeromq", "streaming.kinesis" ), "-group", "MLlib", packageList( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b7070dda99dc6..ec57c05e3b5bb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -25,7 +25,6 @@ import scala.collection.mutable.Queue import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} @@ -42,7 +41,7 @@ import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorReceiverSupervisor, ActorSupervisorStrategy, Receiver} +import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} @@ -295,27 +294,6 @@ class StreamingContext private[streaming] ( } } - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel RDD storage level (default: StorageLevel.MEMORY_AND_DISK_SER_2) - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T: ClassTag]( - props: Props, - name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, - supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy - ): ReceiverInputDStream[T] = withNamedScope("actor stream") { - receiverStream(new ActorReceiverSupervisor[T](props, name, storageLevel, supervisorStrategy)) - } - /** * Create a input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 00f9d8a9e8817..7a25ce54b6ff0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -24,7 +24,6 @@ import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -356,69 +355,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) } - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String, - storageLevel: StorageLevel, - supervisorStrategy: SupervisorStrategy - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - ssc.actorStream[T](props, name, storageLevel, supervisorStrategy) - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - ssc.actorStream[T](props, name, storageLevel) - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - ssc.actorStream[T](props, name) - } - /** * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. From 8f90c151878571e20625e2a53561441ec0035dfc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 20 Jan 2016 14:59:30 -0800 Subject: [PATCH 1645/2089] [SPARK-12616][SQL] Making Logical Operator `Union` Support Arbitrary Number of Children The existing `Union` logical operator only supports two children. Thus, adding a new logical operator `Unions` which can have arbitrary number of children to replace the existing one. `Union` logical plan is a binary node. However, a typical use case for union is to union a very large number of input sources (DataFrames, RDDs, or files). It is not uncommon to union hundreds of thousands of files. In this case, our optimizer can become very slow due to the large number of logical unions. We should change the Union logical plan to support an arbitrary number of children, and add a single rule in the optimizer to collapse all adjacent `Unions` into a single `Unions`. Note that this problem doesn't exist in physical plan, because the physical `Unions` already supports arbitrary number of children. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10577 from gatorsmile/unionAllMultiChildren. --- .../spark/sql/catalyst/CatalystQl.scala | 4 +- .../sql/catalyst/analysis/Analyzer.scala | 12 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 8 ++ .../catalyst/analysis/HiveTypeCoercion.scala | 71 +++++++++++----- .../sql/catalyst/optimizer/Optimizer.scala | 71 ++++++++++------ .../sql/catalyst/planning/patterns.scala | 23 +++++- .../plans/logical/basicOperators.scala | 47 ++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 ++ .../analysis/DecimalPrecisionSuite.scala | 2 +- .../analysis/HiveTypeCoercionSuite.scala | 82 +++++++++++++++---- ...ownSuite.scala => SetOperationSuite.scala} | 43 ++++++++-- .../org/apache/spark/sql/DataFrame.scala | 5 +- .../scala/org/apache/spark/sql/Dataset.scala | 9 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 11 +-- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 16 +++- .../spark/sql/execution/PlannerSuite.scala | 12 --- .../apache/spark/sql/hive/SQLBuilder.scala | 12 +-- .../sql/hive/LogicalPlanToSQLSuite.scala | 4 + 20 files changed, 322 insertions(+), 122 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{SetOperationPushDownSuite.scala => SetOperationSuite.scala} (70%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 5fb41f7e4bf8c..35273c7e24ae4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -402,8 +402,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C overwrite) } - // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) + // If there are multiple INSERTS just UNION them together into one query. + val query = if (queries.length == 1) queries.head else Union(queries) // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d4b4bc88b3f2f..33d76eeb21287 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -66,7 +66,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, CTESubstitution, - WindowsSubstitution), + WindowsSubstitution, + EliminateUnions), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -1170,6 +1171,15 @@ object EliminateSubQueries extends Rule[LogicalPlan] { } } +/** + * Removes [[Union]] operators from the plan if it just has one child. + */ +object EliminateUnions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Union(children) if children.size == 1 => children.head + } +} + /** * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2a2e0d27d9435..f2e78d97442e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -189,6 +189,14 @@ trait CheckAnalysis { s"but the left table has ${left.output.length} columns and the right has " + s"${right.output.length}") + case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => + val firstError = s.children.find(_.output.length != s.children.head.output.length).get + failAnalysis( + s""" + |Unions can only be performed on tables with the same number of columns, + | but one table has '${firstError.output.length}' columns and another table has + | '${s.children.head.output.length}' columns""".stripMargin) + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 7df3787e6d2d3..c557c3231997a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +30,7 @@ import org.apache.spark.sql.types._ /** - * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in + * A collection of [[Rule]] that can be used to coerce differing types that participate in * operations into compatible ones. * * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on @@ -219,31 +222,59 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p - case s @ SetOperation(left, right) if s.childrenResolved - && left.output.length == right.output.length && !s.resolved => + case s @ SetOperation(left, right) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + s.makeCopy(Array(newChildren.head, newChildren.last)) - // Tracks the list of data types to widen. - // Some(dataType) means the right-hand side and the left-hand side have different types, - // and there is a target type to widen both sides to. - val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - findWiderTypeForTwo(lhs.dataType, rhs.dataType) - case other => None - } + case s: Union if s.childrenResolved && + s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + s.makeCopy(Array(newChildren)) + } - if (targetTypes.exists(_.isDefined)) { - // There is at least one column to widen. - s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes))) - } else { - // If we cannot find any column to widen, then just return the original set. - s - } + /** Build new children with the widest types for each attribute among all the children */ + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + require(children.forall(_.output.length == children.head.output.length)) + + // Get a sequence of data types, each of which is the widest type of this specific attribute + // in all the children + val targetTypes: Seq[DataType] = + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) + + if (targetTypes.nonEmpty) { + // Add an extra Project if the targetTypes are different from the original types. + children.map(widenTypes(_, targetTypes)) + } else { + // Unable to find a target type to widen, then just return the original set. + children + } + } + + /** Get the widest type for each attribute in all the children */ + @tailrec private def getWidestTypes( + children: Seq[LogicalPlan], + attrIndex: Int, + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + // Return the result after the widen data types have been found for all the children + if (attrIndex >= children.head.output.length) return castedTypes.toSeq + + // For the attrIndex-th attribute, find the widest type + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + // If unable to find an appropriate widen type for this column, return an empty Seq + case None => Seq.empty[DataType] + // Otherwise, record the result in the queue and find the type for the next column + case Some(widenType) => + castedTypes.enqueue(widenType) + getWidestTypes(children, attrIndex + 1, castedTypes) + } } /** Given a plan, add an extra project on top to widen some columns' data types. */ - private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = { + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = { val casted = plan.output.zip(targetTypes).map { - case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } Project(casted, plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 04643f0274bd4..44455b482074b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -45,6 +45,13 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// + // - Do the first call of CombineUnions before starting the major Optimizer rules, + // since it can reduce the number of iteration and the other rules could add/move + // extra operators between two adjacent Union operators. + // - Call CombineUnions again in Batch("Operator Optimizations"), + // since the other rules might make two separate Unions operators adjacent. + Batch("Union", Once, + CombineUnions) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -62,6 +69,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ProjectCollapsing, CombineFilters, CombineLimits, + CombineUnions, // Constant folding and strength reduction NullPropagation, OptimizeIn, @@ -138,11 +146,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. */ - private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { - assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) - assert(bn.left.output.size == bn.right.output.size) - - AttributeMap(bn.left.output.zip(bn.right.output)) + private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = { + assert(left.output.size == right.output.size) + AttributeMap(left.output.zip(right.output)) } /** @@ -176,32 +182,38 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down filter into union - case Filter(condition, u @ Union(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(u) - Filter(nondeterministic, - Union( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) // Push down deterministic projection through UNION ALL - case p @ Project(projectList, u @ Union(left, right)) => + case p @ Project(projectList, Union(children)) => + assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + val newFirstChild = Project(projectList, children.head) + val newOtherChildren = children.tail.map ( child => { + val rewrites = buildRewrites(children.head, child) + Project(projectList.map(pushToRight(_, rewrites)), child) + } ) + Union(newFirstChild +: newOtherChildren) } else { p } + // Push down filter into union + case Filter(condition, Union(children)) => + assert(children.nonEmpty) + val (deterministic, nondeterministic) = partitionByDeterministic(condition) + val newFirstChild = Filter(deterministic, children.head) + val newOtherChildren = children.tail.map { + child => { + val rewrites = buildRewrites(children.head, child) + Filter(pushToRight(deterministic, rewrites), child) + } + } + Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) + // Push down filter through INTERSECT - case Filter(condition, i @ Intersect(left, right)) => + case Filter(condition, Intersect(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(i) + val rewrites = buildRewrites(left, right) Filter(nondeterministic, Intersect( Filter(deterministic, left), @@ -210,9 +222,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { ) // Push down filter through EXCEPT - case Filter(condition, e @ Except(left, right)) => + case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(e) + val rewrites = buildRewrites(left, right) Filter(nondeterministic, Except( Filter(deterministic, left), @@ -662,6 +674,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Combines all adjacent [[Union]] operators into a single [[Union]]. + */ +object CombineUnions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Unions(children) => Union(children) + } +} + /** * Combines two adjacent [[Filter]] operators into one, merging the * conditions into one conjunctive predicate. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd3f15cbe107b..f0ee124e88a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.planning +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -170,17 +173,29 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { } } + /** * A pattern that collects all adjacent unions and returns their children as a Seq. */ object Unions { def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(u)) + case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan])) case _ => None } - private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { - case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) - case other => other :: Nil + // Doing a depth-first tree traversal to combine all the union children. + @tailrec + private def collectUnionChildren( + plans: mutable.Stack[LogicalPlan], + children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + if (plans.isEmpty) children + else { + plans.pop match { + case Union(grandchildren) => + grandchildren.reverseMap(plans.push(_)) + collectUnionChildren(plans, children) + case other => collectUnionChildren(plans, children :+ other) + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f4a3d85d2a8a4..e9c970cd08088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -101,19 +101,6 @@ private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - - override def output: Seq[Attribute] = - left.output.zip(right.output).map { case (leftAttr, rightAttr) => - leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) - } - - override def statistics: Statistics = { - val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes - Statistics(sizeInBytes = sizeInBytes) - } -} - case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def output: Seq[Attribute] = @@ -127,6 +114,40 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } +/** Factory for constructing new `Union` nodes. */ +object Union { + def apply(left: LogicalPlan, right: LogicalPlan): Union = { + Union (left :: right :: Nil) + } +} + +case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { + + // updating nullability to make all the children consistent + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) + + override lazy val resolved: Boolean = { + // allChildrenCompatible needs to be evaluated after childrenResolved + def allChildrenCompatible: Boolean = + children.tail.forall( child => + // compare the attribute number with the first child + child.output.length == children.head.output.length && + // compare the data types with the first child + child.output.zip(children.head.output).forall { + case (l, r) => l.dataType == r.dataType } + ) + + children.length > 1 && childrenResolved && allChildrenCompatible + } + + override def statistics: Statistics = { + val sizeInBytes = children.map(_.statistics.sizeInBytes).sum + Statistics(sizeInBytes = sizeInBytes) + } +} + case class Join( left: LogicalPlan, right: LogicalPlan, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 975cd87d090e4..ab680282208c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -237,6 +237,12 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("Eliminate the unnecessary union") { + val plan = Union(testRelation :: Nil) + val expected = testRelation + checkAnalysis(plan, expected) + } + test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 39c8f56c1bca6..24c608eaa5b39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Union(left, right) => (left.output.head, right.output.head) + case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index b326aa9c55992..c30434a0063b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -387,19 +387,19 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenSetOperationTypes for union, except, and intersect") { - def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { - logical.output.zip(expectTypes).foreach { case (attr, dt) => - assert(attr.dataType === dt) - } + private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) } + } - val left = LocalRelation( + test("WidenSetOperationTypes for except and intersect") { + val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) - val right = LocalRelation( + val secondTable = LocalRelation( AttributeReference("s", StringType)(), AttributeReference("d", DecimalType(2, 1))(), AttributeReference("f", FloatType)(), @@ -408,15 +408,65 @@ class HiveTypeCoercionSuite extends PlanTest { val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(left, right)).asInstanceOf[Union] - val r2 = wt(Except(left, right)).asInstanceOf[Except] - val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) checkOutput(r2.right, expectedTypes) - checkOutput(r3.left, expectedTypes) - checkOutput(r3.right, expectedTypes) + + // Check if a Project is added + assert(r1.left.isInstanceOf[Project]) + assert(r1.right.isInstanceOf[Project]) + assert(r2.left.isInstanceOf[Project]) + assert(r2.right.isInstanceOf[Project]) + + val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except] + checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + + // Check if no Project is added + assert(r3.left.isInstanceOf[LocalRelation]) + assert(r3.right.isInstanceOf[LocalRelation]) + } + + test("WidenSetOperationTypes for union") { + val firstTable = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val secondTable = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + val thirdTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", FloatType)(), + AttributeReference("q", DoubleType)()) + val forthTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", ByteType)(), + AttributeReference("q", DoubleType)()) + + val wt = HiveTypeCoercion.WidenSetOperationTypes + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) + + val unionRelation = wt( + Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] + assert(unionRelation.children.length == 4) + checkOutput(unionRelation.children.head, expectedTypes) + checkOutput(unionRelation.children(1), expectedTypes) + checkOutput(unionRelation.children(2), expectedTypes) + checkOutput(unionRelation.children(3), expectedTypes) + + assert(unionRelation.children.head.isInstanceOf[Project]) + assert(unionRelation.children(1).isInstanceOf[Project]) + assert(unionRelation.children(2).isInstanceOf[Project]) + assert(unionRelation.children(3).isInstanceOf[Project]) } test("Transform Decimal precision/scale for union except and intersect") { @@ -438,8 +488,8 @@ class HiveTypeCoercionSuite extends PlanTest { val r2 = dp(Except(left1, right1)).asInstanceOf[Except] val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] - checkOutput(r1.left, expectedType1) - checkOutput(r1.right, expectedType1) + checkOutput(r1.children.head, expectedType1) + checkOutput(r1.children.last, expectedType1) checkOutput(r2.left, expectedType1) checkOutput(r2.right, expectedType1) checkOutput(r3.left, expectedType1) @@ -459,7 +509,7 @@ class HiveTypeCoercionSuite extends PlanTest { val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] - checkOutput(r1.right, Seq(expectedType)) + checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) @@ -467,7 +517,7 @@ class HiveTypeCoercionSuite extends PlanTest { val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] - checkOutput(r4.left, Seq(expectedType)) + checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) checkOutput(r6.left, Seq(expectedType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala similarity index 70% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index a498b463a69e9..2283f7c008ba2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -24,48 +24,73 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class SetOperationPushDownSuite extends PlanTest { +class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, + CombineUnions, SetOperationPushDown, SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) + val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) + val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) - test("union/intersect/except: filter to each side") { - val unionQuery = testUnion.where('a === 1) + test("union: combine unions into one unions") { + val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation) + val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation)) + val unionOptimized1 = Optimize.execute(unionQuery1.analyze) + val unionOptimized2 = Optimize.execute(unionQuery2.analyze) + + comparePlans(unionOptimized1, unionOptimized2) + + val combinedUnions = Union(unionOptimized1 :: unionOptimized2 :: Nil) + val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) + val unionQuery3 = Union(unionQuery1, unionQuery2) + val unionOptimized3 = Optimize.execute(unionQuery3.analyze) + comparePlans(combinedUnionsOptimized, unionOptimized3) + } + + test("intersect/except: filter to each side") { val intersectQuery = testIntersect.where('b < 10) val exceptQuery = testExcept.where('c >= 5) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze val intersectCorrectAnswer = Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(unionOptimized, unionCorrectAnswer) comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } + test("union: filter to each side") { + val unionQuery = testUnion.where('a === 1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.where('a === 1) :: + testRelation2.where('d === 1) :: + testRelation3.where('g === 1) :: Nil).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + } + test("union: project to each side") { val unionQuery = testUnion.select('a) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze + Union(testRelation.select('a) :: + testRelation2.select('d) :: + testRelation3.select('g) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 95e5fbb11900e..518f9dcf94a70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} @@ -1002,7 +1003,9 @@ class DataFrame private[sql]( * @since 1.3.0 */ def unionAll(other: DataFrame): DataFrame = withPlan { - Union(logicalPlan, other.logicalPlan) + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9a9f7d111cf4b..bd99c399571c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ +import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} @@ -603,7 +604,11 @@ class Dataset[T] private[sql]( * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) => + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(left, right)) + } /** * Returns a new [[Dataset]] where any elements present in `other` have been removed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4ddb6d76b2c6..60fbb595e5758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -336,7 +336,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil - case Unions(unionChildren) => + case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 9e2e0357c65f0..6deb72adad5ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -281,13 +281,10 @@ case class Range( * Union two plans, without a distinct. This is UNION ALL in SQL. */ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - override def output: Seq[Attribute] = { - children.tail.foldLeft(children.head.output) { case (currentOutput, child) => - currentOutput.zip(child.output).map { case (a1, a2) => - a1.withNullability(a1.nullable || a2.nullable) - } - } - } + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) + protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 1a3df1b117b68..3c0f25a5dc535 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -298,9 +298,9 @@ public void testSetOperation() { Dataset intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); - Dataset unioned = ds.union(ds2); + Dataset unioned = ds.union(ds2).union(ds); Assert.assertEquals( - Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"), + Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"), unioned.collectAsList()); Dataset subtracted = ds.subtract(ds2); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index bd11a387a1d5d..09bbe57a43ceb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ @@ -98,6 +98,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } + test("union all") { + val unionDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + // Before optimizer, Union should be combined. + assert(unionDF.queryExecution.analyzed.collect { + case j: Union if j.children.size == 5 => j }.size === 1) + + checkAnswer( + unionDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 49feeaf17d68f..8fca5e2167d04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -51,18 +51,6 @@ class PlannerSuite extends SharedSQLContext { s"The plan of query $query does not have partial aggregations.") } - test("unions are collapsed") { - val planner = sqlContext.planner - import planner._ - val query = testData.unionAll(testData).unionAll(testData).logicalPlan - val planned = BasicOperators(query).head - val logicalUnions = query collect { case u: logical.Union => u } - val physicalUnions = planned collect { case u: execution.Union => u } - - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) - } - test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed testPartialAggregationPlan(query) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index e83b4bffff857..1654594538366 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -129,11 +129,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi conditionSQL = condition.sql } yield s"$childSQL $whereOrHaving $conditionSQL" - case Union(left, right) => - for { - leftSQL <- toSQL(left) - rightSQL <- toSQL(right) - } yield s"$leftSQL UNION ALL $rightSQL" + case Union(children) if children.length > 1 => + val childrenSql = children.map(toSQL(_)) + if (childrenSql.exists(_.isEmpty)) { + None + } else { + Some(childrenSql.map(_.get).mkString(" UNION ALL ")) + } // Persisted data source relation case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 0604d9f47c5da..261a4746f4287 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -105,6 +105,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") } + test("three-child union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0") + } + test("case") { checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") } From f3934a8d656f1668bec065751b2a11411229b6f5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Jan 2016 15:08:27 -0800 Subject: [PATCH 1646/2089] [SPARK-12888][SQL] benchmark the new hash expression Benchmark it on 4 different schemas, the result: ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz Hash For simple: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- interpreted version 31.47 266.54 1.00 X codegen version 64.52 130.01 0.49 X ``` ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz Hash For normal: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- interpreted version 4068.11 0.26 1.00 X codegen version 1175.92 0.89 3.46 X ``` ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz Hash For array: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- interpreted version 9276.70 0.06 1.00 X codegen version 14762.23 0.04 0.63 X ``` ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz Hash For map: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- interpreted version 58869.79 0.01 1.00 X codegen version 9285.36 0.06 6.34 X ``` Author: Wenchen Fan Closes #10816 from cloud-fan/hash-benchmark. --- .../org/apache/spark/sql/HashBenchmark.scala | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala new file mode 100644 index 0000000000000..184f845b4dce2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark for the previous interpreted hash function(InternalRow.hashCode) vs the new codegen + * hash expression(Murmur3Hash). + */ +object HashBenchmark { + + def test(name: String, schema: StructType, iters: Int): Unit = { + val numRows = 1024 * 8 + + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + val attrs = schema.toAttributes + val safeProjection = GenerateSafeProjection.generate(attrs, attrs) + + val rows = (1 to numRows).map(_ => + // The output of encoder is UnsafeRow, use safeProjection to turn in into safe format. + safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() + ).toArray + + val benchmark = new Benchmark("Hash For " + name, iters * numRows) + benchmark.addCase("interpreted version") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numRows) { + sum += rows(i).hashCode() + i += 1 + } + } + } + + val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) + benchmark.addCase("codegen version") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numRows) { + sum += getHashCode(rows(i)).getInt(0) + i += 1 + } + } + } + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val simple = new StructType().add("i", IntegerType) + test("simple", simple, 1024) + + val normal = new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("bigDecimal", DecimalType.SYSTEM_DEFAULT) + .add("smallDecimal", DecimalType.USER_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + test("normal", normal, 128) + + val arrayOfInt = ArrayType(IntegerType) + val array = new StructType() + .add("array", arrayOfInt) + .add("arrayOfArray", ArrayType(arrayOfInt)) + test("array", array, 64) + + val mapOfInt = MapType(IntegerType, IntegerType) + val map = new StructType() + .add("map", mapOfInt) + .add("mapOfMap", MapType(IntegerType, mapOfInt)) + test("map", map, 64) + } +} From 10173279305a0e8a62bfbfe7a9d5d1fd558dd8e1 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 20 Jan 2016 15:13:01 -0800 Subject: [PATCH 1647/2089] [SPARK-12848][SQL] Change parsed decimal literal datatype from Double to Decimal The current parser turns a decimal literal, for example ```12.1```, into a Double. The problem with this approach is that we convert an exact literal into a non-exact ```Double```. The PR changes this behavior, a Decimal literal is now converted into an extact ```BigDecimal```. The behavior for scientific decimals, for example ```12.1e01```, is unchanged. This will be converted into a Double. This PR replaces the ```BigDecimal``` literal by a ```Double``` literal, because the ```BigDecimal``` is the default now. You can use the double literal by appending a 'D' to the value, for instance: ```3.141527D``` cc davies rxin Author: Herman van Hovell Closes #10796 from hvanhovell/SPARK-12848. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- .../sql/catalyst/parser/ExpressionParser.g | 2 +- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 +- .../spark/sql/catalyst/CatalystQl.scala | 8 ++- .../catalyst/analysis/HiveTypeCoercion.scala | 3 +- .../spark/sql/MathExpressionsSuite.scala | 12 ++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 50 +++++++++---------- .../datasources/json/JsonSuite.scala | 4 +- .../execution/HiveCompatibilitySuite.scala | 13 +++-- .../HiveWindowFunctionQuerySuite.scala | 2 +- ... + 1.0-0-404b0ea20c125c9648b7919a8f41add3} | 0 ...1 + 1.0-0-77ca48f121bd2ef41efb9ee3bc28418} | 0 ... + '1'-0-6beb1ef5178117a9fd641008ed5ebb80} | 0 ....0 + 1-0-bec2842d2b009973b4d4b8f10b5554f8} | 0 ... + 1.0-0-eafdfdbb14980ee517c388dc117d91a8} | 0 ...0 + 1L-0-ef273f05968cd0e91af8c76949c73798} | 0 ...0 + 1S-0-9f93538c38920d52b322bfc40cc2f31a} | 0 ...0 + 1Y-0-9e354e022b1b423f366bf79ed7522f2a} | 0 ... + 1.0-0-9b0510d0bb3e9ee6a7698369b008a280} | 0 ... + 1.0-0-c3d54e5b6034b7796ed16896a434d1ba} | 0 ... + 1.0-0-7b54e1d367c2ed1f5c181298ee5470d0} | 0 ...l end -0-cf71a4c4cce08635cc80a64a1ae6bc83} | 0 ....0 end -0-dfc876530eeaa7c42978d1bc0b1fd58} | 0 ...tDISTs-0-d9065e533430691d70b3370174fbbd50} | 0 .../execution/HiveTypeCoercionSuite.scala | 20 +++++--- .../sql/hive/execution/HiveUDFSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 8 +-- .../sql/sources/BucketedWriteSuite.scala | 10 ++-- 28 files changed, 83 insertions(+), 59 deletions(-) rename sql/hive/src/test/resources/golden/{'1' + 1.0-0-5db3b55120a19863d96460d399c2d0e => '1' + 1.0-0-404b0ea20c125c9648b7919a8f41add3} (100%) rename sql/hive/src/test/resources/golden/{1 + 1.0-0-4f5da98a11db8e7192423c27db767ca6 => 1 + 1.0-0-77ca48f121bd2ef41efb9ee3bc28418} (100%) rename sql/hive/src/test/resources/golden/{1.0 + '1'-0-a6ec78b3b93d52034aab829d43210e73 => 1.0 + '1'-0-6beb1ef5178117a9fd641008ed5ebb80} (100%) rename sql/hive/src/test/resources/golden/{1.0 + 1-0-30a4b1c8227906931cd0532367bebc43 => 1.0 + 1-0-bec2842d2b009973b4d4b8f10b5554f8} (100%) rename sql/hive/src/test/resources/golden/{1.0 + 1.0-0-87321b2e30ee2986b00b631d0e4f4d8d => 1.0 + 1.0-0-eafdfdbb14980ee517c388dc117d91a8} (100%) rename sql/hive/src/test/resources/golden/{1.0 + 1L-0-44bb88a1c9280952e8119a3ab1bb4205 => 1.0 + 1L-0-ef273f05968cd0e91af8c76949c73798} (100%) rename sql/hive/src/test/resources/golden/{1.0 + 1S-0-31fbe14d01fb532176c1689680398368 => 1.0 + 1S-0-9f93538c38920d52b322bfc40cc2f31a} (100%) rename sql/hive/src/test/resources/golden/{1.0 + 1Y-0-12bcf6e49e83abd2aa36ea612b418d43 => 1.0 + 1Y-0-9e354e022b1b423f366bf79ed7522f2a} (100%) rename sql/hive/src/test/resources/golden/{1L + 1.0-0-95a30c4b746f520f1251981a66cef5c8 => 1L + 1.0-0-9b0510d0bb3e9ee6a7698369b008a280} (100%) rename sql/hive/src/test/resources/golden/{1S + 1.0-0-8dfa46ec33c1be5ffba2e40cbfe5349e => 1S + 1.0-0-c3d54e5b6034b7796ed16896a434d1ba} (100%) rename sql/hive/src/test/resources/golden/{1Y + 1.0-0-3ad5e3db0d0300312d33231e7c2a6c8d => 1Y + 1.0-0-7b54e1d367c2ed1f5c181298ee5470d0} (100%) rename sql/hive/src/test/resources/golden/{case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 => case when then 1.0 else null end -0-cf71a4c4cce08635cc80a64a1ae6bc83} (100%) rename sql/hive/src/test/resources/golden/{case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 => case when then null else 1.0 end -0-dfc876530eeaa7c42978d1bc0b1fd58} (100%) rename sql/hive/src/test/resources/golden/{windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad => windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50} (100%) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 14d40d5066e78..a389dd71a28ee 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1812,7 +1812,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(x), "map") df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") - expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) coltypes(df) <- c("character", "integer") diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 047a7e56cb577..957bb234e4901 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -122,7 +122,7 @@ constant | BigintLiteral | SmallintLiteral | TinyintLiteral - | DecimalLiteral + | DoubleLiteral | booleanValue ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index ee2882e51c450..e4ffc634e8bf4 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -418,9 +418,9 @@ TinyintLiteral (Digit)+ 'Y' ; -DecimalLiteral +DoubleLiteral : - Number 'B' 'D' + Number 'D' ; ByteLengthLiteral diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 35273c7e24ae4..f531d59a75cf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -623,6 +623,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val CASE = "(?i)CASE".r val INTEGRAL = "[+-]?\\d+".r + val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r protected def nodeToExpr(node: ASTNode): Expression = node match { /* Attribute References */ @@ -785,8 +786,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast if ast.tokenType == SparkSqlParser.BigintLiteral => Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - case ast if ast.tokenType == SparkSqlParser.DecimalLiteral => - Literal(Decimal(ast.text.substring(0, ast.text.length() - 2))) + case ast if ast.tokenType == SparkSqlParser.DoubleLiteral => + Literal(ast.text.toDouble) case ast if ast.tokenType == SparkSqlParser.Number => val text = ast.text @@ -799,7 +800,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Literal(v.longValue()) case v => Literal(v.underlying()) } + case DECIMAL(_*) => + Literal(BigDecimal(text).underlying()) case _ => + // Convert a scientifically notated decimal into a double. Literal(text.toDouble) } case ast if ast.tokenType == SparkSqlParser.StringLiteral => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index c557c3231997a..6e43bdfd92d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -692,12 +692,11 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. - // Convert If(null literal, _, _) into boolean type. // In the optimizer, we should short-circuit this directly into false value. case If(pred, left, right) if pred.dataType == NullType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index aec450e0a6084..013a90875e2b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -212,7 +212,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) - val pi = "3.1415BD" + val pi = "3.1415" checkAnswer( sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), @@ -367,6 +367,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { checkAnswer( input.toDF("key", "value").selectExpr("abs(key) a").sort("a"), input.map(pair => Row(pair._2))) + + checkAnswer( + sql("select abs(0), abs(-1), abs(123), abs(-9223372036854775807), abs(9223372036854775807)"), + Row(0, 1, 123, 9223372036854775807L, 9223372036854775807L) + ) + + checkAnswer( + sql("select abs(0.0), abs(-3.14159265), abs(3.14159265)"), + Row(BigDecimal("0.0"), BigDecimal("3.14159265"), BigDecimal("3.14159265")) + ) } test("log2") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b159346bed9f7..47308966e92cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1174,19 +1174,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3)) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8)) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -1200,11 +1200,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) checkAnswer( - sql("SELECT 9223372036854775808BD"), Row(new java.math.BigDecimal("9223372036854775808")) + sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - sql("SELECT -9223372036854775809BD"), Row(new java.math.BigDecimal("-9223372036854775809")) + sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } @@ -1219,11 +1219,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) checkAnswer( - sql("SELECT -5.2BD"), Row(BigDecimal(-5.2)) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8d) + sql("SELECT +6.8e0"), Row(6.8d) ) checkAnswer( @@ -1598,20 +1598,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3BD * 3.0BD"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000BD * 3.0BD"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000BD * 30.0BD"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000BD * 3.000000000000000000BD"), + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000BD * 3.0000000000000000000BD"), + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), Row(null)) - checkAnswer(sql("select 10.3BD / 3.0BD"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000BD / 3.0BD"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000BD / 30.0BD"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000BD / 3.00000000000000000BD"), + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000BD / 3.00000000000000000BD"), + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) } @@ -1637,13 +1637,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("precision smaller than scale") { - checkAnswer(sql("select 10.00BD"), Row(BigDecimal("10.00"))) - checkAnswer(sql("select 1.00BD"), Row(BigDecimal("1.00"))) - checkAnswer(sql("select 0.10BD"), Row(BigDecimal("0.10"))) - checkAnswer(sql("select 0.01BD"), Row(BigDecimal("0.01"))) - checkAnswer(sql("select 0.001BD"), Row(BigDecimal("0.001"))) - checkAnswer(sql("select -0.01BD"), Row(BigDecimal("-0.01"))) - checkAnswer(sql("select -0.001BD"), Row(BigDecimal("-0.001"))) + checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) + checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) } test("external sorting updates peak execution memory") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8de8ba355e7d4..a3c6a1d7b20ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -442,13 +442,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2BD from jsonTable where num_str > 14"), + sql("select num_str + 1.2 from jsonTable where num_str > 14"), Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2BD from jsonTable where num_str >= 92233720368547758060BD"), + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), Row(new java.math.BigDecimal("92233720368547758071.2")) ) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 828ec9710550c..554d47d651aef 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -323,7 +323,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Feature removed in HIVE-11145 "alter_partition_protect_mode", "drop_partitions_ignore_protection", - "protectmode" + "protectmode", + + // Spark parser treats numerical literals differently: it creates decimals instead of doubles. + "udf_abs", + "udf_format_number", + "udf_round", + "udf_round_3", + "view_cast" ) /** @@ -884,7 +891,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_10_trims", "udf_E", "udf_PI", - "udf_abs", "udf_acos", "udf_add", "udf_array", @@ -928,7 +934,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_find_in_set", "udf_float", "udf_floor", - "udf_format_number", "udf_from_unixtime", "udf_greaterthan", "udf_greaterthanorequal", @@ -976,8 +981,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", - "udf_round_3", "udf_rpad", "udf_rtrim", "udf_sign", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index bad3ca6da231f..d0b4cbe401eb3 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -559,7 +559,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte """ |select p_mfgr,p_name, p_size, |histogram_numeric(p_retailprice, 5) over w1 as hist, - |percentile(p_partkey, 0.5) over w1 as per, + |percentile(p_partkey, cast(0.5 as double)) over w1 as per, |row_number() over(distribute by p_mfgr sort by p_name) as rn |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name diff --git a/sql/hive/src/test/resources/golden/'1' + 1.0-0-5db3b55120a19863d96460d399c2d0e b/sql/hive/src/test/resources/golden/'1' + 1.0-0-404b0ea20c125c9648b7919a8f41add3 similarity index 100% rename from sql/hive/src/test/resources/golden/'1' + 1.0-0-5db3b55120a19863d96460d399c2d0e rename to sql/hive/src/test/resources/golden/'1' + 1.0-0-404b0ea20c125c9648b7919a8f41add3 diff --git a/sql/hive/src/test/resources/golden/1 + 1.0-0-4f5da98a11db8e7192423c27db767ca6 b/sql/hive/src/test/resources/golden/1 + 1.0-0-77ca48f121bd2ef41efb9ee3bc28418 similarity index 100% rename from sql/hive/src/test/resources/golden/1 + 1.0-0-4f5da98a11db8e7192423c27db767ca6 rename to sql/hive/src/test/resources/golden/1 + 1.0-0-77ca48f121bd2ef41efb9ee3bc28418 diff --git a/sql/hive/src/test/resources/golden/1.0 + '1'-0-a6ec78b3b93d52034aab829d43210e73 b/sql/hive/src/test/resources/golden/1.0 + '1'-0-6beb1ef5178117a9fd641008ed5ebb80 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + '1'-0-a6ec78b3b93d52034aab829d43210e73 rename to sql/hive/src/test/resources/golden/1.0 + '1'-0-6beb1ef5178117a9fd641008ed5ebb80 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1-0-30a4b1c8227906931cd0532367bebc43 b/sql/hive/src/test/resources/golden/1.0 + 1-0-bec2842d2b009973b4d4b8f10b5554f8 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1-0-30a4b1c8227906931cd0532367bebc43 rename to sql/hive/src/test/resources/golden/1.0 + 1-0-bec2842d2b009973b4d4b8f10b5554f8 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1.0-0-87321b2e30ee2986b00b631d0e4f4d8d b/sql/hive/src/test/resources/golden/1.0 + 1.0-0-eafdfdbb14980ee517c388dc117d91a8 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1.0-0-87321b2e30ee2986b00b631d0e4f4d8d rename to sql/hive/src/test/resources/golden/1.0 + 1.0-0-eafdfdbb14980ee517c388dc117d91a8 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1L-0-44bb88a1c9280952e8119a3ab1bb4205 b/sql/hive/src/test/resources/golden/1.0 + 1L-0-ef273f05968cd0e91af8c76949c73798 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1L-0-44bb88a1c9280952e8119a3ab1bb4205 rename to sql/hive/src/test/resources/golden/1.0 + 1L-0-ef273f05968cd0e91af8c76949c73798 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1S-0-31fbe14d01fb532176c1689680398368 b/sql/hive/src/test/resources/golden/1.0 + 1S-0-9f93538c38920d52b322bfc40cc2f31a similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1S-0-31fbe14d01fb532176c1689680398368 rename to sql/hive/src/test/resources/golden/1.0 + 1S-0-9f93538c38920d52b322bfc40cc2f31a diff --git a/sql/hive/src/test/resources/golden/1.0 + 1Y-0-12bcf6e49e83abd2aa36ea612b418d43 b/sql/hive/src/test/resources/golden/1.0 + 1Y-0-9e354e022b1b423f366bf79ed7522f2a similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1Y-0-12bcf6e49e83abd2aa36ea612b418d43 rename to sql/hive/src/test/resources/golden/1.0 + 1Y-0-9e354e022b1b423f366bf79ed7522f2a diff --git a/sql/hive/src/test/resources/golden/1L + 1.0-0-95a30c4b746f520f1251981a66cef5c8 b/sql/hive/src/test/resources/golden/1L + 1.0-0-9b0510d0bb3e9ee6a7698369b008a280 similarity index 100% rename from sql/hive/src/test/resources/golden/1L + 1.0-0-95a30c4b746f520f1251981a66cef5c8 rename to sql/hive/src/test/resources/golden/1L + 1.0-0-9b0510d0bb3e9ee6a7698369b008a280 diff --git a/sql/hive/src/test/resources/golden/1S + 1.0-0-8dfa46ec33c1be5ffba2e40cbfe5349e b/sql/hive/src/test/resources/golden/1S + 1.0-0-c3d54e5b6034b7796ed16896a434d1ba similarity index 100% rename from sql/hive/src/test/resources/golden/1S + 1.0-0-8dfa46ec33c1be5ffba2e40cbfe5349e rename to sql/hive/src/test/resources/golden/1S + 1.0-0-c3d54e5b6034b7796ed16896a434d1ba diff --git a/sql/hive/src/test/resources/golden/1Y + 1.0-0-3ad5e3db0d0300312d33231e7c2a6c8d b/sql/hive/src/test/resources/golden/1Y + 1.0-0-7b54e1d367c2ed1f5c181298ee5470d0 similarity index 100% rename from sql/hive/src/test/resources/golden/1Y + 1.0-0-3ad5e3db0d0300312d33231e7c2a6c8d rename to sql/hive/src/test/resources/golden/1Y + 1.0-0-7b54e1d367c2ed1f5c181298ee5470d0 diff --git a/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 b/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-cf71a4c4cce08635cc80a64a1ae6bc83 similarity index 100% rename from sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 rename to sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-cf71a4c4cce08635cc80a64a1ae6bc83 diff --git a/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 b/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-dfc876530eeaa7c42978d1bc0b1fd58 similarity index 100% rename from sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 rename to sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-dfc876530eeaa7c42978d1bc0b1fd58 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 similarity index 100% rename from sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad rename to sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index d2f91861ff73b..6b424d73430e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -25,20 +25,26 @@ import org.apache.spark.sql.hive.test.TestHive * A set of tests that validate type promotion and coercion rules. */ class HiveTypeCoercionSuite extends HiveComparisonTest { - val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'") + val baseTypes = Seq( + ("1", "1"), + ("1.0", "CAST(1.0 AS DOUBLE)"), + ("1L", "1L"), + ("1S", "1S"), + ("1Y", "1Y"), + ("'1'", "'1'")) - baseTypes.foreach { i => - baseTypes.foreach { j => - createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1") + baseTypes.foreach { case (ni, si) => + baseTypes.foreach { case (nj, sj) => + createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1") } } val nullVal = "null" - baseTypes.init.foreach { i => + baseTypes.init.foreach { case (i, s) => createQueryTest(s"case when then $i else $nullVal end ", - s"SELECT case when true then $i else $nullVal end FROM src limit 1") + s"SELECT case when true then $s else $nullVal end FROM src limit 1") createQueryTest(s"case when then $nullVal else $i end ", - s"SELECT case when true then $nullVal else $i end FROM src limit 1") + s"SELECT case when true then $nullVal else $s end FROM src limit 1") } test("[SPARK-2210] boolean cast on boolean value should be removed") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index af76ff91a267c..703cfffee1587 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -143,10 +143,10 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("Generic UDAF aggregates") { - checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"), sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + checkAnswer(sql("SELECT percentile_approx(100.0D, array(0.9D, 0.9D)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f7d8d395ba648..683008960aa28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1012,9 +1012,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | java_method("java.lang.String", "isEmpty"), | java_method("java.lang.Math", "max", 2, 3), | java_method("java.lang.Math", "min", 2, 3), - | java_method("java.lang.Math", "round", 2.5), - | java_method("java.lang.Math", "exp", 1.0), - | java_method("java.lang.Math", "floor", 1.9) + | java_method("java.lang.Math", "round", 2.5D), + | java_method("java.lang.Math", "exp", 1.0D), + | java_method("java.lang.Math", "floor", 1.9D) |FROM src tablesample (1 rows) """.stripMargin), Row( @@ -1461,6 +1461,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """ |SELECT json_tuple(json, 'f1', 'f2'), 3.14, str |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test - """.stripMargin), Row("value1", "12", 3.14, "hello")) + """.stripMargin), Row("value1", "12", BigDecimal("3.14"), "hello")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index dad1fc1273810..8cac7fe48f171 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.io.File +import java.net.URI import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection @@ -65,6 +66,11 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + def tableDir: File = { + val identifier = hiveContext.sqlParser.parseTableIdentifier("bucketed_table") + new File(URI.create(hiveContext.catalog.hiveDefaultTableFilePath(identifier))) + } + /** * A helper method to check the bucket write functionality in low level, i.e. check the written * bucket files to see if the data are correct. User should pass in a data dir that these bucket @@ -127,7 +133,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .bucketBy(8, "j", "k") .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") for (i <- 0 until 5) { testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) } @@ -145,7 +150,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .sortBy("k") .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") for (i <- 0 until 5) { testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k")) } @@ -161,7 +165,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .bucketBy(8, "i", "j") .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") testBucketing(tableDir, source, 8, Seq("i", "j")) } } @@ -176,7 +179,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle .sortBy("k") .saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.warehousePath, "bucketed_table") testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k")) } } From b362239df566bc949283f2ac195ee89af105605a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 20 Jan 2016 15:24:01 -0800 Subject: [PATCH 1648/2089] [SPARK-12797] [SQL] Generated TungstenAggregate (without grouping keys) As discussed in #10786, the generated TungstenAggregate does not support imperative functions. For a query ``` sqlContext.range(10).filter("id > 1").groupBy().count() ``` The generated code will looks like: ``` /* 032 */ if (!initAgg0) { /* 033 */ initAgg0 = true; /* 034 */ /* 035 */ // initialize aggregation buffer /* 037 */ long bufValue2 = 0L; /* 038 */ /* 039 */ /* 040 */ // initialize Range /* 041 */ if (!range_initRange5) { /* 042 */ range_initRange5 = true; ... /* 071 */ } /* 072 */ /* 073 */ while (!range_overflow8 && range_number7 < range_partitionEnd6) { /* 074 */ long range_value9 = range_number7; /* 075 */ range_number7 += 1L; /* 076 */ if (range_number7 < range_value9 ^ 1L < 0) { /* 077 */ range_overflow8 = true; /* 078 */ } /* 079 */ /* 085 */ boolean primitive11 = false; /* 086 */ primitive11 = range_value9 > 1L; /* 087 */ if (!false && primitive11) { /* 092 */ // do aggregate and update aggregation buffer /* 099 */ long primitive17 = -1L; /* 100 */ primitive17 = bufValue2 + 1L; /* 101 */ bufValue2 = primitive17; /* 105 */ } /* 107 */ } /* 109 */ /* 110 */ // output the result /* 112 */ bufferHolder25.reset(); /* 114 */ rowWriter26.initialize(bufferHolder25, 1); /* 118 */ rowWriter26.write(0, bufValue2); /* 120 */ result24.pointTo(bufferHolder25.buffer, bufferHolder25.totalSize()); /* 121 */ currentRow = result24; /* 122 */ return; /* 124 */ } /* 125 */ ``` cc nongli Author: Davies Liu Closes #10840 from davies/gen_agg. --- .../sql/execution/WholeStageCodegen.scala | 12 ++- .../aggregate/TungstenAggregate.scala | 87 ++++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 8 +- .../execution/WholeStageCodegenSuite.scala | 12 +++ .../execution/metric/SQLMetricsSuite.scala | 4 +- 5 files changed, 111 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index c15fabab805a7..57f4945de9804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) */ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { + private def supportCodegen(e: Expression): Boolean = e match { + case e: LeafExpression => true + // CodegenFallback requires the input to be an InternalRow + case e: CodegenFallback => false + case _ => true + } + private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => - // Non-leaf with CodegenFallback does not work with whole stage codegen - val willFallback = plan.expressions.exists( - _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined - ) + val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns val haveManyColumns = plan.output.length > 200 !willFallback && !haveManyColumns diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 8dcbab4c8cfbc..23e54f344d252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType @@ -35,7 +36,7 @@ case class TungstenAggregate( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryNode { + extends UnaryNode with CodegenSupport { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -113,6 +114,86 @@ case class TungstenAggregate( } } + override def supportCodegen: Boolean = { + groupingExpressions.isEmpty && + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && + // final aggregation only have one row, do not need to codegen + !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private val modes = aggregateExpressions.map(_.mode).distinct + + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // generate variables for aggregation buffer + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val initExpr = functions.flatMap(f => f.initialValues) + bufVars = initExpr.map { e => + val isNull = ctx.freshName("bufIsNull") + val value = ctx.freshName("bufValue") + // The initial expression should not access any column + val ev = e.gen(ctx) + val initVars = s""" + | boolean $isNull = ${ev.isNull}; + | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + """.stripMargin + ExprCode(ev.code + initVars, isNull, value) + } + + val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) + val source = + s""" + | if (!$initAgg) { + | $initAgg = true; + | + | // initialize aggregation buffer + | ${bufVars.map(_.code).mkString("\n")} + | + | $childSource + | + | // output the result + | ${consume(ctx, bufVars)} + | } + """.stripMargin + + (rdd, source) + } + + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + // only have DeclarativeAggregate + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + // the mode could be only Partial or PartialMerge + val updateExpr = if (modes.contains(Partial)) { + functions.flatMap(_.updateExpressions) + } else { + functions.flatMap(_.mergeExpressions) + } + + val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) + ctx.currentVars = bufVars ++ input + // TODO: support subexpression elimination + val codes = boundExpr.zipWithIndex.map { case (e, i) => + val ev = e.gen(ctx) + s""" + | ${ev.code} + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin + } + + s""" + | // do aggregate and update aggregation buffer + | ${codes.mkString("")} + """.stripMargin + } + override def simpleString: String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 788b04fcf8c2e..c4aad398bfa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - Without whole stage codegen 6725.52 31.18 1.00 X - With whole stage codegen 2233.05 93.91 3.01 X + Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Without whole stage codegen 7775.53 26.97 1.00 X + With whole stage codegen 342.15 612.94 22.73 X */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index c54fc6ba2de3d..300788c88ab2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.functions.{avg, col, max} import org.apache.spark.sql.test.SharedSQLContext class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { @@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = false ) } + + test("Aggregate should be included in WholeStageCodegen") { + val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id"))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(9, 4.5))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4339f7260dcb9..51285431a47ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - df.collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + df.collect() + } sparkContext.listenerBus.waitUntilEmpty(10000) val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) From 015c8efb3774c57be6f3fee5a454622879cab1ec Mon Sep 17 00:00:00 2001 From: wangfei Date: Wed, 20 Jan 2016 17:11:52 -0800 Subject: [PATCH 1649/2089] [SPARK-8968][SQL] external sort by the partition clomns when dynamic partitioning to optimize the memory overhead Now the hash based writer dynamic partitioning show the bad performance for big data and cause many small files and high GC. This patch we do external sort first so that each time we only need open one writer. before this patch: ![gc](https://cloud.githubusercontent.com/assets/7018048/9149788/edc48c6e-3dec-11e5-828c-9995b56e4d65.PNG) after this patch: ![gc-optimize-externalsort](https://cloud.githubusercontent.com/assets/7018048/9149794/60f80c9c-3ded-11e5-8a56-7ae18ddc7a2f.png) Author: wangfei Author: scwf Closes #7336 from scwf/dynamic-optimize-basedon-apachespark. --- .../hive/execution/InsertIntoHiveTable.scala | 69 ++---- .../spark/sql/hive/hiveWriterContainers.scala | 196 +++++++++++++----- 2 files changed, 166 insertions(+), 99 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index b02ace786c66c..feb133d44898a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -24,20 +24,16 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.{Context, ErrorMsg} -import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.types.DataType +import org.apache.spark.SparkException import org.apache.spark.util.SerializableJobConf private[hive] @@ -46,19 +42,12 @@ case class InsertIntoHiveTable( partition: Map[String, Option[String]], child: SparkPlan, overwrite: Boolean, - ifNotExists: Boolean) extends UnaryNode with HiveInspectors { + ifNotExists: Boolean) extends UnaryNode { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @transient private lazy val catalog = sc.catalog - private def newSerializer(tableDesc: TableDesc): Serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] - serializer.initialize(null, tableDesc.getProperties) - serializer - } - def output: Seq[Attribute] = Seq.empty private def saveAsHiveFile( @@ -78,44 +67,10 @@ case class InsertIntoHiveTable( conf.value, SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writeToFile _) + sc.sparkContext.runJob(rdd, writerContainer.writeToFile _) writerContainer.commitJob() - // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] - - val fieldOIs = standardOI.getAllStructFieldRefs.asScala - .map(_.getFieldObjectInspector).toArray - val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray - val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} - val outputData = new Array[Any](fieldOIs.length) - - writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - - val proj = FromUnsafeProjection(child.schema) - iterator.foreach { row => - var i = 0 - val safeRow = proj(row) - while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) - i += 1 - } - - writerContainer - .getLocalFileWriter(safeRow, table.schema) - .write(serializer.serialize(outputData, standardOI)) - } - - writerContainer.close() - } } /** @@ -194,11 +149,21 @@ case class InsertIntoHiveTable( val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) - new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + new SparkHiveDynamicPartitionWriterContainer( + jobConf, + fileSinkConf, + dynamicPartColNames, + child.output, + table) } else { - new SparkHiveWriterContainer(jobConf, fileSinkConf) + new SparkHiveWriterContainer( + jobConf, + fileSinkConf, + child.output, + table) } + @transient val outputClass = writerContainer.newSerializer(table.tableDesc).getSerializedClass saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) val outputPath = FileOutputFormat.getOutputPath(jobConf) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 22182ba00986f..e9e08dbf8386a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import java.text.NumberFormat import java.util.Date -import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.FileUtils @@ -28,14 +28,18 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Serializer +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableJobConf @@ -45,9 +49,13 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging with Serializable { + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + inputSchema: Seq[Attribute], + table: MetastoreRelation) + extends Logging + with HiveInspectors + with Serializable { private val now = new Date() private val tableDesc: TableDesc = fileSinkConf.getTableInfo @@ -93,14 +101,12 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { - writer - } - def close() { // Seems the boolean value passed into close does not matter. - writer.close(false) - commit() + if (writer != null) { + writer.close(false) + commit() + } } def commitJob() { @@ -123,6 +129,13 @@ private[hive] class SparkHiveWriterContainer( SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } + def abortTask(): Unit = { + if (committer != null) { + committer.abortTask(taskContext) + } + logError(s"Task attempt $taskContext aborted.") + } + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { jobID = jobId splitID = splitId @@ -140,6 +153,44 @@ private[hive] class SparkHiveWriterContainer( conf.value.setBoolean("mapred.task.is.map", true) conf.value.setInt("mapred.task.partition", splitID) } + + def newSerializer(tableDesc: TableDesc): Serializer = { + val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + serializer.initialize(null, tableDesc.getProperties) + serializer + } + + protected def prepareForWrite() = { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + val dataTypes = inputSchema.map(_.dataType) + val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) } + val outputData = new Array[Any](fieldOIs.length) + (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) + } + + // this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + i += 1 + } + writer.write(serializer.serialize(outputData, standardOI)) + } + + close() + } } private[hive] object SparkHiveWriterContainer { @@ -163,25 +214,22 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc, - dynamicPartColNames: Array[String]) - extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + dynamicPartColNames: Array[String], + inputSchema: Seq[Attribute], + table: MetastoreRelation) + extends SparkHiveWriterContainer(jobConf, fileSinkConf, inputSchema, table) { import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) - @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ - override protected def initWriters(): Unit = { - // NOTE: This method is executed at the executor side. - // Actual writers are created for each dynamic partition on the fly. - writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + // do nothing } override def close(): Unit = { - writers.values.foreach(_.close(false)) - commit() + // do nothing } override def commitJob(): Unit = { @@ -198,33 +246,89 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: InternalRow, schema: StructType) - : FileSinkOperator.RecordWriter = { - def convertToHiveRawString(col: String, value: Any): String = { - val raw = String.valueOf(value) - schema(col).dataType match { - case DateType => DateTimeUtils.dateToString(raw.toInt) - case _: DecimalType => BigDecimal(raw).toString() - case _ => raw - } + // this function is executed on executor side + override def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + + val partitionOutput = inputSchema.takeRight(dynamicPartColNames.length) + val dataOutput = inputSchema.take(fieldOIs.length) + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionOutput, inputSchema) + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataOutput, inputSchema) + + val fun: AnyRef = (pathString: String) => FileUtils.escapePathName(pathString, defaultPartName) + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionOutput.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF(fun, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartName), escaped) + val partitionName = Literal(dynamicPartColNames(i) + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName } - val nonDynamicPartLen = row.numFields - dynamicPartColNames.length - val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => - val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) - val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$colName=$colString" - }.mkString + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionOutput) + + // If anything below fails, we should abort the task. + try { + val sorter: UnsafeKVExternalSorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionOutput), + StructType.fromAttributes(dataOutput), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } - def newWriter(): FileSinkOperator.RecordWriter = { + logInfo(s"Sorting complete. Writing out partition files one at a time.") + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: FileSinkOperator.RecordWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close(false) + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey) + } + + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (sortedIterator.getValue.isNullAt(i)) { + null + } else { + wrappers(i)(sortedIterator.getValue.get(i, dataTypes(i))) + } + i += 1 + } + currentWriter.write(serializer.serialize(outputData, standardOI)) + } + } finally { + if (currentWriter != null) { + currentWriter.close(false) + } + } + commit() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + /** Open and returns a new OutputWriter given a partition key. */ + def newOutputWriter(key: InternalRow): FileSinkOperator.RecordWriter = { + val partitionPath = getPartitionString(key).getString(0) val newFileSinkDesc = new FileSinkDesc( - fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getDirName + partitionPath, fileSinkConf.getTableInfo, fileSinkConf.getCompressed) newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) @@ -234,7 +338,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // to avoid write to the same file when `spark.speculation=true` val path = FileOutputFormat.getTaskOutputPath( conf.value, - dynamicPartPath.stripPrefix("/") + "/" + getOutputName) + partitionPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, @@ -244,7 +348,5 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( path, Reporter.NULL) } - - writers.getOrElseUpdate(dynamicPartPath, newWriter()) } } From d60f8d74ace5670b1b451a0ea0b93d3b9775bb52 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 20 Jan 2016 17:48:18 -0800 Subject: [PATCH 1650/2089] [SPARK-8968] [SQL] [HOT-FIX] Fix scala 2.11 build. --- .../scala/org/apache/spark/sql/hive/hiveWriterContainers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index e9e08dbf8386a..fc0dfd2d2f818 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -49,7 +49,7 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - @transient jobConf: JobConf, + @transient private val jobConf: JobConf, fileSinkConf: FileSinkDesc, inputSchema: Seq[Attribute], table: MetastoreRelation) From d7415991a1c65f44ba385bc697b458125366523f Mon Sep 17 00:00:00 2001 From: Shubhanshu Mishra Date: Wed, 20 Jan 2016 18:06:06 -0800 Subject: [PATCH 1651/2089] [SPARK-12910] Fixes : R version for installing sparkR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Testing code: ``` $ ./install-dev.sh USING R_HOME = /usr/bin ERROR: this R is version 2.15.1, package 'SparkR' requires R >= 3.0 ``` Using the new argument: ``` $ ./install-dev.sh /content/username/SOFTWARE/R-3.2.3 USING R_HOME = /content/username/SOFTWARE/R-3.2.3/bin * installing *source* package ‘SparkR’ ... ** R ** inst ** preparing package for lazy loading Creating a new generic function for ‘colnames’ in package ‘SparkR’ Creating a new generic function for ‘colnames<-’ in package ‘SparkR’ Creating a new generic function for ‘cov’ in package ‘SparkR’ Creating a new generic function for ‘na.omit’ in package ‘SparkR’ Creating a new generic function for ‘filter’ in package ‘SparkR’ Creating a new generic function for ‘intersect’ in package ‘SparkR’ Creating a new generic function for ‘sample’ in package ‘SparkR’ Creating a new generic function for ‘transform’ in package ‘SparkR’ Creating a new generic function for ‘subset’ in package ‘SparkR’ Creating a new generic function for ‘summary’ in package ‘SparkR’ Creating a new generic function for ‘lag’ in package ‘SparkR’ Creating a new generic function for ‘rank’ in package ‘SparkR’ Creating a new generic function for ‘sd’ in package ‘SparkR’ Creating a new generic function for ‘var’ in package ‘SparkR’ Creating a new generic function for ‘predict’ in package ‘SparkR’ Creating a new generic function for ‘rbind’ in package ‘SparkR’ Creating a generic function for ‘lapply’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘Filter’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘alias’ from package ‘stats’ in package ‘SparkR’ Creating a generic function for ‘substr’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘%in%’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘mean’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘unique’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘nrow’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘ncol’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘head’ from package ‘utils’ in package ‘SparkR’ Creating a generic function for ‘factorial’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘atan2’ from package ‘base’ in package ‘SparkR’ Creating a generic function for ‘ifelse’ from package ‘base’ in package ‘SparkR’ ** help No man pages found in package ‘SparkR’ *** installing help indices ** building package indices ** testing if installed package can be loaded * DONE (SparkR) ``` Author: Shubhanshu Mishra Closes #10836 from napsternxg/master. --- R/README.md | 10 ++++++++++ R/install-dev.sh | 11 +++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/R/README.md b/R/README.md index 005f56da1670c..bb3464ba9955d 100644 --- a/R/README.md +++ b/R/README.md @@ -1,6 +1,16 @@ # R on Spark SparkR is an R package that provides a light-weight frontend to use Spark from R. +### Installing sparkR + +Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. +By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. +Example: +``` +# where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript +export R_HOME=/home/username/R +./install-dev.sh +``` ### SparkR development diff --git a/R/install-dev.sh b/R/install-dev.sh index 4972bb9217072..befd413c4cd26 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -35,12 +35,19 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR pushd $FWDIR > /dev/null +if [ ! -z "$R_HOME" ] + then + R_SCRIPT_PATH="$R_HOME/bin" + else + R_SCRIPT_PATH="$(dirname $(which R))" +fi +echo "USING R_HOME = $R_HOME" # Generate Rd files if devtools is installed -Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' +"$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' # Install SparkR to $LIB_DIR -R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +"$R_SCRIPT_PATH/"R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ # Zip the SparkR package so that it can be distributed to worker nodes on YARN cd $LIB_DIR From 1b2a918e59addcdccdf8e011bce075cc9dd07b93 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 20 Jan 2016 21:08:15 -0800 Subject: [PATCH 1652/2089] [SPARK-12204][SPARKR] Implement drop method for DataFrame in SparkR. Author: Sun Rui Closes #10201 from sun-rui/SPARK-12204. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 77 ++++++++++++++++------- R/pkg/R/generics.R | 4 ++ R/pkg/inst/tests/testthat/test_context.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 31 +++++++-- docs/sql-programming-guide.md | 13 ++++ 6 files changed, 101 insertions(+), 27 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 00634c1a70c26..2cc1544bef080 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -39,6 +39,7 @@ exportMethods("arrange", "describe", "dim", "distinct", + "drop", "dropDuplicates", "dropna", "dtypes", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 629c1ce2eddc1..4653a73e11be3 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1192,23 +1192,10 @@ setMethod("$", signature(x = "DataFrame"), setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) - cols <- columns(x) - if (name %in% cols) { - if (is.null(value)) { - cols <- Filter(function(c) { c != name }, cols) - } - cols <- lapply(cols, function(c) { - if (c == name) { - alias(value, name) - } else { - col(c) - } - }) - nx <- select(x, cols) + + if (is.null(value)) { + nx <- drop(x, name) } else { - if (is.null(value)) { - return(x) - } nx <- withColumn(x, name, value) } x@sdf <- nx@sdf @@ -1386,12 +1373,13 @@ setMethod("selectExpr", #' WithColumn #' -#' Return a new DataFrame with the specified column added. +#' Return a new DataFrame by adding a column or replacing the existing column +#' that has the same name. #' #' @param x A DataFrame -#' @param colName A string containing the name of the new column. +#' @param colName A column name. #' @param col A Column expression. -#' @return A DataFrame with the new column added. +#' @return A DataFrame with the new column added or the existing column replaced. #' @family DataFrame functions #' @rdname withColumn #' @name withColumn @@ -1404,12 +1392,16 @@ setMethod("selectExpr", #' path <- "path/to/file.json" #' df <- read.json(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' # Replace an existing column +#' newDF2 <- withColumn(newDF, "newCol", newDF$col1) #' } setMethod("withColumn", signature(x = "DataFrame", colName = "character", col = "Column"), function(x, colName, col) { - select(x, x$"*", alias(col, colName)) + sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc) + dataFrame(sdf) }) + #' Mutate #' #' Return a new DataFrame with the specified columns added. @@ -2401,4 +2393,47 @@ setMethod("str", cat(paste0("\nDisplaying first ", ncol(localDF), " columns only.")) } } - }) \ No newline at end of file + }) + +#' drop +#' +#' Returns a new DataFrame with columns dropped. +#' This is a no-op if schema doesn't contain column name(s). +#' +#' @param x A SparkSQL DataFrame. +#' @param cols A character vector of column names or a Column. +#' @return A DataFrame +#' +#' @family DataFrame functions +#' @rdname drop +#' @name drop +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlCtx, path) +#' drop(df, "col1") +#' drop(df, c("col1", "col2")) +#' drop(df, df$col1) +#' } +setMethod("drop", + signature(x = "DataFrame"), + function(x, col) { + stopifnot(class(col) == "character" || class(col) == "Column") + + if (class(col) == "Column") { + sdf <- callJMethod(x@sdf, "drop", col@jc) + } else { + sdf <- callJMethod(x@sdf, "drop", as.list(col)) + } + dataFrame(sdf) + }) + +# Expose base::drop +setMethod("drop", + signature(x = "ANY"), + function(x) { + base::drop(x) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index d616266ead41b..9a8ab97bb8f9a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -428,6 +428,10 @@ setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname drop +#' @export +setGeneric("drop", function(x, ...) { standardGeneric("drop") }) + #' @rdname dropduplicates #' @export setGeneric("dropDuplicates", diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 3b14a497b487a..ad3f9722a4802 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -26,7 +26,7 @@ test_that("Check masked functions", { maskedBySparkR <- masked[funcSparkROrEmpty] namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform") + "summary", "transform", "drop") expect_equal(length(maskedBySparkR), length(namesOfMasked)) expect_equal(sort(maskedBySparkR), sort(namesOfMasked)) # above are those reported as masked when `library(SparkR)` diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index a389dd71a28ee..e59841ab9f787 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -824,11 +824,6 @@ test_that("select operators", { df$age2 <- df$age * 2 expect_equal(columns(df), c("name", "age", "age2")) expect_equal(count(where(df, df$age2 == df$age * 2)), 2) - - df$age2 <- NULL - expect_equal(columns(df), c("name", "age")) - df$age3 <- NULL - expect_equal(columns(df), c("name", "age")) }) test_that("select with column", { @@ -854,6 +849,27 @@ test_that("select with column", { "To select multiple columns, use a character vector or list for col") }) +test_that("drop column", { + df <- select(read.json(sqlContext, jsonPath), "name", "age") + df1 <- drop(df, "name") + expect_equal(columns(df1), c("age")) + + df$age2 <- df$age + df1 <- drop(df, c("name", "age")) + expect_equal(columns(df1), c("age2")) + + df1 <- drop(df, df$age) + expect_equal(columns(df1), c("name", "age2")) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) + + # Test to make sure base::drop is not masked + expect_equal(drop(1:3 %*% 2:4), 20) +}) + test_that("subsetting", { # read.json returns columns in random order df <- select(read.json(sqlContext, jsonPath), "name", "age") @@ -1462,6 +1478,11 @@ test_that("withColumn() and withColumnRenamed()", { expect_equal(columns(newDF)[3], "newAge") expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + # Replace existing column + newDF <- withColumn(df, "age", df$age + 2) + expect_equal(length(columns(newDF)), 2) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + newDF2 <- withColumnRenamed(df, "age", "newerAge") expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index bc89c781562bd..fddc51379406b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2150,6 +2150,8 @@ options. --conf spark.sql.hive.thriftServer.singleSession=true \ ... {% endhighlight %} + - Since 1.6.1, withColumn method in sparkR supports adding a new column to or replacing existing columns + of the same name of a DataFrame. - From Spark 1.6, LongType casts to TimestampType expect seconds instead of microseconds. This change was made to match the behavior of Hive 1.2 for more consistent type casting to TimestampType @@ -2183,6 +2185,7 @@ options. users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate the DataFrame and the new DataFrame will include new files. + - DataFrame.withColumn method in pySpark supports adding a new column or replacing existing columns of the same name. ## Upgrading from Spark SQL 1.3 to 1.4 @@ -2262,6 +2265,16 @@ sqlContext.setConf("spark.sql.retainGroupColumns", "false")
    +#### Behavior change on DataFrame.withColumn + +Prior to 1.4, DataFrame.withColumn() supports adding a column only. The column will always be added +as a new column with its specified name in the result DataFrame even if there may be any existing +columns of the same name. Since 1.4, DataFrame.withColumn() supports adding a column of a different +name from names of all existing columns or replacing existing columns of the same name. + +Note that this change is only for Scala API, not for PySpark and SparkR. + + ## Upgrading from Spark SQL 1.0-1.2 to 1.3 In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the From 85200c09adc6eb98fadb8505f55cb44e3d8b3390 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 21 Jan 2016 16:30:20 +0100 Subject: [PATCH 1653/2089] [SPARK-12534][DOC] update documentation to list command line equivalent to properties Several Spark properties equivalent to Spark submit command line options are missing. Author: felixcheung Closes #10491 from felixcheung/sparksubmitdoc. --- docs/configuration.md | 10 +++++----- docs/job-scheduling.md | 5 ++++- docs/running-on-yarn.md | 27 +++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 12ac60129633d..acaeb830081e2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -173,7 +173,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overriden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. @@ -687,10 +687,10 @@ Apart from these, the following properties are also available, and may be useful spark.rdd.compress false - Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER in Java - and Scala or StorageLevel.MEMORY_ONLY in Python). - Can save substantial space at the cost of some extra CPU time. + Whether to compress serialized RDD partitions (e.g. for + StorageLevel.MEMORY_ONLY_SER in Java + and Scala or StorageLevel.MEMORY_ONLY in Python). + Can save substantial space at the cost of some extra CPU time. diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 6c587b3f0d8db..95d47794ea337 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -39,7 +39,10 @@ Resource allocation can be configured as follows, based on the cluster type: and optionally set `spark.cores.max` to limit each application's resource share as in the standalone mode. You should also set `spark.executor.memory` to control the executor memory. * **YARN:** The `--num-executors` option to the Spark YARN client controls how many executors it will allocate - on the cluster, while `--executor-memory` and `--executor-cores` control the resources per executor. + on the cluster (`spark.executor.instances` as configuration property), while `--executor-memory` + (`spark.executor.memory` configuration property) and `--executor-cores` (`spark.executor.cores` configuration + property) control the resources per executor. For more information, see the + [YARN Spark Properties](running-on-yarn.html). A second option available on Mesos is _dynamic sharing_ of CPU cores. In this mode, each Spark application still has a fixed and independent memory allocation (set by `spark.executor.memory`), but when the diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index a148c867eb18f..ad66b9f64a34b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -113,6 +113,19 @@ If you need a reference to the proper location to put log files in the YARN so t Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. + + spark.driver.memory + 1g + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 1g, 2g). + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-memory command line option + or in your default properties file. + + spark.driver.cores 1 @@ -202,6 +215,13 @@ If you need a reference to the proper location to put log files in the YARN so t Comma-separated list of files to be placed in the working directory of each executor. + + spark.executor.cores + 1 in YARN mode, all the available cores on the worker in standalone mode. + + The number of cores to use on each executor. For YARN and standalone mode only. + + spark.executor.instances 2 @@ -209,6 +229,13 @@ If you need a reference to the proper location to put log files in the YARN so t The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + + spark.executor.memory + 1g + + Amount of memory to use per executor process (e.g. 2g, 8g). + + spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 From b4574e387d0124667bdbb35f8c7c3e2065b14ba9 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 21 Jan 2016 17:24:48 -0800 Subject: [PATCH 1654/2089] [SPARK-12908][ML] Add warning message for LogisticRegression for potential converge issue When all labels are the same, it's a dangerous ground for LogisticRegression without intercept to converge. GLMNET doesn't support this case, and will just exit. GLM can train, but will have a warning message saying the algorithm doesn't converge. Author: DB Tsai Closes #10862 from dbtsai/add-tests. --- .../spark/ml/classification/LogisticRegression.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index dad8dfc84ec15..c98a78a515dc3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -300,6 +300,14 @@ class LogisticRegression @Since("1.2.0") ( s"training is not needed.") (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double]) } else { + if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) { + logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } else if (!$(fitIntercept) && numClasses == 1) { + logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } + val featuresMean = summarizer.mean.toArray val featuresStd = summarizer.variance.toArray.map(math.sqrt) From 55c7dd031b8a58976922e469626469aa4aff1391 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 21 Jan 2016 18:55:28 -0800 Subject: [PATCH 1655/2089] [SPARK-12747][SQL] Use correct type name for Postgres JDBC's real array https://issues.apache.org/jira/browse/SPARK-12747 Postgres JDBC driver uses "FLOAT4" or "FLOAT8" not "real". Author: Liang-Chi Hsieh Closes #10695 from viirya/fix-postgres-jdbc. --- .../apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 8 +++++--- .../scala/org/apache/spark/sql/jdbc/PostgresDialect.scala | 2 ++ .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 559dc1fed163a..7d011be37067b 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -41,10 +41,10 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { conn.setCatalog("foo") conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " - + "c10 integer[], c11 text[])").executeUpdate() + + "c10 integer[], c11 text[], c12 real[])").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " - + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate() + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate() } test("Type mapping for various types") { @@ -52,7 +52,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) - assert(types.length == 12) + assert(types.length == 13) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) @@ -65,6 +65,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(classOf[String].isAssignableFrom(types(9))) assert(classOf[Seq[Int]].isAssignableFrom(types(10))) assert(classOf[Seq[String]].isAssignableFrom(types(11))) + assert(classOf[Seq[Double]].isAssignableFrom(types(12))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) @@ -80,6 +81,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getString(9) == "192.168.0.0/16") assert(rows(0).getSeq(10) == Seq(1, 2)) assert(rows(0).getSeq(11) == Seq("a", null, "b")) + assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) } test("Basic write test") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index ad9e31690b2d8..8d43966480a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -60,6 +60,8 @@ private object PostgresDialect extends JdbcDialect { case StringType => Some(JdbcType("TEXT", Types.CHAR)) case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) + case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) + case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) case ArrayType(et, _) if et.isInstanceOf[AtomicType] => getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 1fa22e2933318..518607543b482 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -514,6 +514,8 @@ class JDBCSuite extends SparkFunSuite val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") + assert(Postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") val errMsg = intercept[IllegalArgumentException] { Postgres.getJDBCType(ByteType) } From 006906db591666a7111066afd226325452be2e3e Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Thu, 21 Jan 2016 21:17:36 -0800 Subject: [PATCH 1656/2089] [SPARK-12960] [PYTHON] Some examples are missing support for python2 Without importing the print_function, the lines later on like ```print("Usage: direct_kafka_wordcount.py ", file=sys.stderr)``` fail when using python2.*. Import fixes that problem and doesn't break anything on python3 either. Author: Mark Grover Closes #10872 from markgrover/python2_compat. --- examples/src/main/python/streaming/direct_kafka_wordcount.py | 1 + .../src/main/python/examples/streaming/kinesis_wordcount_asl.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index ea20678b9acad..7097f7f4502bd 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -28,6 +28,7 @@ examples/src/main/python/streaming/direct_kafka_wordcount.py \ localhost:9092 test` """ +from __future__ import print_function import sys diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index f428f64da3c42..51f8c5ca6656b 100644 --- a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -54,6 +54,8 @@ See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on the Kinesis Spark Streaming integration. """ +from __future__ import print_function + import sys from pyspark import SparkContext From e13c147e74a52d74e259f04e49e368fab64cdc1f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 22 Jan 2016 01:03:41 -0800 Subject: [PATCH 1657/2089] [SPARK-12959][SQL] Writing Bucketed Data with Disabled Bucketing in SQLConf When users turn off bucketing in SQLConf, we should issue some messages to tell users these operations will be converted to normal way. Also added a test case for this scenario and fixed the helper function. Do you think this PR is helpful when using bucket tables? cloud-fan Thank you! Author: gatorsmile Closes #10870 from gatorsmile/bucketTableWritingTestcases. --- .../InsertIntoHadoopFsRelation.scala | 2 +- .../datasources/WriterContainer.scala | 2 +- .../sql/sources/BucketedWriteSuite.scala | 28 ++++++++++++++++--- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 314c957d57bb8..2d3e1714d2b7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty && relation.getBucketSpec.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.maybeBucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 563fd9eefcce3..6340229dbb78b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -311,7 +311,7 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - private val bucketSpec = relation.getBucketSpec + private val bucketSpec = relation.maybeBucketSpec private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 8cac7fe48f171..59b74d2b4c5ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import java.io.File import java.net.URI -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, SQLConf} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils @@ -88,10 +88,11 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle ) for (bucketFile <- allBucketFiles) { - val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get - assert(bucketId >= 0 && bucketId < numBuckets) + val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse { + fail(s"Unable to find the related bucket files.") + } - // We may loss the type information after write(e.g. json format doesn't keep schema + // We may lose the type information after write(e.g. json format doesn't keep schema // information), here we get the types from the original dataframe. val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType) val columns = (bucketCols ++ sortCols).zip(types).map { @@ -183,4 +184,23 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } } } + + test("write bucketed data with bucketing disabled") { + // The configuration BUCKETING_ENABLED does not affect the writing path + withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) + } + } + } + } + } } From 8a88e121283472c26e70563a4e04c109e9b183b3 Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Fri, 22 Jan 2016 10:35:02 -0800 Subject: [PATCH 1658/2089] [SPARK-12629][SPARKR] Fixes for DataFrame saveAsTable method I've tried to solve some of the issues mentioned in: https://issues.apache.org/jira/browse/SPARK-12629 Please, let me know what do you think. Thanks! Author: Narine Kokhlikyan Closes #10580 from NarineK/sparkrSavaAsRable. --- R/pkg/R/DataFrame.R | 23 +++++++++++++++++------ R/pkg/R/generics.R | 12 ++++++++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 15 ++++++++++++++- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 4653a73e11be3..3b7b8250b94f7 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1997,7 +1997,13 @@ setMethod("write.df", signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "error", ...){ if (is.null(source)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) + } else { + stop("sparkRHive or sparkRSQL context has to be specified") + } source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } @@ -2055,13 +2061,18 @@ setMethod("saveDF", #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", - signature(df = "DataFrame", tableName = "character", source = "character", - mode = "character"), + signature(df = "DataFrame", tableName = "character"), function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) + } else { + stop("sparkRHive or sparkRSQL context has to be specified") + } + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 9a8ab97bb8f9a..04784d51566cb 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,7 +539,7 @@ setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("samp #' @rdname saveAsTable #' @export -setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { +setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) { standardGeneric("saveAsTable") }) @@ -552,7 +552,15 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) +setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) { + standardGeneric("write.df") +}) + +#' @rdname write.df +#' @export +setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { + standardGeneric("saveDF") +}) #' @rdname write.json #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index e59841ab9f787..b52a11fb1a348 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -953,8 +953,21 @@ test_that("test HiveContext", { df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") expect_equal(count(df3), 3) - unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) + df4 <- sql(hiveCtx, "select * from hivetestbl") + expect_is(df4, "DataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + invisible(saveAsTable(df, "parquetest", "parquet", mode="overwrite", path=parquetDataPath)) + df5 <- sql(hiveCtx, "select * from parquetest") + expect_is(df5, "DataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) }) test_that("column operators", { From d8fefab4d8149f0638282570c75271ef35c65cff Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 22 Jan 2016 12:33:18 -0800 Subject: [PATCH 1659/2089] [HOTFIX][BUILD][TEST-MAVEN] Remove duplicate dependency Author: Shixiong Zhu Closes #10868 from zsxwing/hotfix-akka-pom. --- external/akka/pom.xml | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/external/akka/pom.xml b/external/akka/pom.xml index 34de9bae00e49..06c8e8aaabd8c 100644 --- a/external/akka/pom.xml +++ b/external/akka/pom.xml @@ -48,6 +48,10 @@ test-jar test + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${akka.group} akka-actor_${scala.binary.version} @@ -58,13 +62,6 @@ akka-remote_${scala.binary.version} ${akka.version} - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - target/scala-${scala.binary.version}/classes From bc1babd63da4ee56e6d371eb24805a5d714e8295 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 22 Jan 2016 21:20:04 -0800 Subject: [PATCH 1660/2089] [SPARK-7997][CORE] Remove Akka from Spark Core and Streaming - Remove Akka dependency from core. Note: the streaming-akka project still uses Akka. - Remove HttpFileServer - Remove Akka configs from SparkConf and SSLOptions - Rename `spark.akka.frameSize` to `spark.rpc.message.maxSize`. I think it's still worth to keep this config because using `DirectTaskResult` or `IndirectTaskResult` depends on it. - Update comments and docs Author: Shixiong Zhu Closes #10854 from zsxwing/remove-akka. --- core/pom.xml | 17 +- .../org/apache/spark/ContextCleaner.scala | 6 +- .../org/apache/spark/HttpFileServer.scala | 91 ------ .../org/apache/spark/MapOutputTracker.scala | 7 +- .../scala/org/apache/spark/SSLOptions.scala | 43 +-- .../org/apache/spark/SecurityManager.scala | 19 +- .../scala/org/apache/spark/SparkConf.scala | 32 +-- .../scala/org/apache/spark/SparkEnv.scala | 49 +--- .../org/apache/spark/deploy/Client.scala | 4 - .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 12 +- .../org/apache/spark/executor/Executor.scala | 8 +- .../cluster/CoarseGrainedClusterMessage.scala | 1 - .../CoarseGrainedSchedulerBackend.scala | 15 +- .../cluster/SimrSchedulerBackend.scala | 2 +- .../org/apache/spark/util/AkkaUtils.scala | 139 --------- .../org/apache/spark/util/RpcUtils.scala | 12 + .../org/apache/spark/FileServerSuite.scala | 264 ------------------ .../apache/spark/HeartbeatReceiverSuite.scala | 4 +- .../org/apache/spark/LocalSparkContext.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 12 +- .../apache/spark/SecurityManagerSuite.scala | 12 - .../apache/spark/deploy/DeployTestUtils.scala | 4 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../deploy/worker/DriverRunnerTest.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +- .../CoarseGrainedSchedulerBackendSuite.scala | 8 +- .../spark/scheduler/SparkListenerSuite.scala | 17 +- .../scheduler/TaskResultGetterSuite.scala | 26 +- dev/deps/spark-deps-hadoop-2.2 | 5 - dev/deps/spark-deps-hadoop-2.3 | 5 - dev/deps/spark-deps-hadoop-2.4 | 5 - dev/deps/spark-deps-hadoop-2.6 | 5 - dev/deps/spark-deps-hadoop-2.7 | 5 - docs/cluster-overview.md | 2 +- docs/configuration.md | 65 +---- docs/security.md | 30 +- .../mllib/feature/VectorTransformer.scala | 2 +- .../mllib/util/LocalClusterSparkContext.scala | 2 +- pom.xml | 5 + project/MimaExcludes.scala | 4 +- .../streaming/ReceivedBlockTrackerSuite.scala | 1 - .../spark/deploy/yarn/ExecutorRunnable.scala | 2 +- 43 files changed, 123 insertions(+), 831 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/HttpFileServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala delete mode 100644 core/src/test/scala/org/apache/spark/FileServerSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 2071a58de92e4..0ab170e028ab4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -185,19 +185,6 @@ commons-net commons-net - - ${akka.group} - akka-remote_${scala.binary.version} - - - ${akka.group} - akka-slf4j_${scala.binary.version} - - - ${akka.group} - akka-testkit_${scala.binary.version} - test - org.scala-lang scala-library @@ -224,6 +211,10 @@ io.netty netty-all + + io.netty + netty + com.clearspring.analytics stream diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 4628093b91cb8..5a42299a0bf83 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -86,8 +86,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). * * Due to SPARK-3015, this is set to true by default. This is intended to be only a temporary - * workaround for the issue, which is ultimately caused by the way the BlockManager actors - * issue inter-dependent blocking Akka messages to each other at high frequencies. This happens, + * workaround for the issue, which is ultimately caused by the way the BlockManager endpoints + * issue inter-dependent blocking RPC messages to each other at high frequencies. This happens, * for instance, when the driver performs a GC and cleans up all broadcast blocks that are no * longer in scope. */ @@ -101,7 +101,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { * exceptions on cleanup of shuffle blocks, as reported in SPARK-3139. To avoid that, this * parameter by default disables blocking on shuffle cleanups. Note that this does not affect * the cleanup of RDDs and broadcasts. This is intended to be a temporary workaround, - * until the real Akka issue (referred to in the comment above `blockOnCleanupTasks`) is + * until the real RPC issue (referred to in the comment above `blockOnCleanupTasks`) is * resolved. */ private val blockOnShuffleCleanupTasks = sc.conf.getBoolean( diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala deleted file mode 100644 index 46f9f9e9af7da..0000000000000 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -import com.google.common.io.Files - -import org.apache.spark.util.Utils - -private[spark] class HttpFileServer( - conf: SparkConf, - securityManager: SecurityManager, - requestedPort: Int = 0) - extends Logging { - - var baseDir : File = null - var fileDir : File = null - var jarDir : File = null - var httpServer : HttpServer = null - var serverUri : String = null - - def initialize() { - baseDir = Utils.createTempDir(Utils.getLocalDir(conf), "httpd") - fileDir = new File(baseDir, "files") - jarDir = new File(baseDir, "jars") - fileDir.mkdir() - jarDir.mkdir() - logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server") - httpServer.start() - serverUri = httpServer.uri - logDebug("HTTP file server started at: " + serverUri) - } - - def stop() { - httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs - try { - Utils.deleteRecursively(baseDir) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: ${baseDir.getAbsolutePath}", e) - } - } - - def addFile(file: File) : String = { - addFileToDir(file, fileDir) - serverUri + "/files/" + Utils.encodeFileNameToURIRawPath(file.getName) - } - - def addJar(file: File) : String = { - addFileToDir(file, jarDir) - serverUri + "/jars/" + Utils.encodeFileNameToURIRawPath(file.getName) - } - - def addDirectory(path: String, resourceBase: String): String = { - httpServer.addDirectory(path, resourceBase) - serverUri + path - } - - def addFileToDir(file: File, dir: File) : String = { - // Check whether the file is a directory. If it is, throw a more meaningful exception. - // If we don't catch this, Guava throws a very confusing error message: - // java.io.FileNotFoundException: [file] (No such file or directory) - // even though the directory ([file]) exists. - if (file.isDirectory) { - throw new IllegalArgumentException(s"$file cannot be a directory.") - } - Files.copy(file, new File(dir, file.getName)) - dir + "/" + Utils.encodeFileNameToURIRawPath(file.getName) - } - -} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1b59beb8d6efd..eb2fdecc83011 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -40,7 +40,7 @@ private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] class MapOutputTrackerMasterEndpoint( override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) extends RpcEndpoint with Logging { - val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => @@ -48,9 +48,10 @@ private[spark] class MapOutputTrackerMasterEndpoint( logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.length - if (serializedSize > maxAkkaFrameSize) { + if (serializedSize > maxRpcMessageSize) { + val msg = s"Map output statuses were $serializedSize bytes which " + - s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." + s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)." /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 261265f0b4c55..d755f07965e6c 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,9 +21,6 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext -import scala.collection.JavaConverters._ - -import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory /** @@ -31,8 +28,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * generate specific objects to configure SSL for different communication protocols. * * SSLOptions is intended to provide the maximum common set of SSL settings, which are supported - * by the protocol, which it can generate the configuration for. Since Akka doesn't support client - * authentication with SSL, SSLOptions cannot support it either. + * by the protocol, which it can generate the configuration for. * * @param enabled enables or disables SSL; if it is set to false, the rest of the * settings are disregarded @@ -88,43 +84,6 @@ private[spark] case class SSLOptions( } } - /** - * Creates an Akka configuration object which contains all the SSL settings represented by this - * object. It can be used then to compose the ultimate Akka configuration. - */ - def createAkkaConfig: Option[Config] = { - if (enabled) { - if (keyStoreType.isDefined) { - logWarning("Akka configuration does not support key store type."); - } - if (trustStoreType.isDefined) { - logWarning("Akka configuration does not support trust store type."); - } - - Some(ConfigFactory.empty() - .withValue("akka.remote.netty.tcp.security.key-store", - ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse(""))) - .withValue("akka.remote.netty.tcp.security.key-store-password", - ConfigValueFactory.fromAnyRef(keyStorePassword.getOrElse(""))) - .withValue("akka.remote.netty.tcp.security.trust-store", - ConfigValueFactory.fromAnyRef(trustStore.map(_.getAbsolutePath).getOrElse(""))) - .withValue("akka.remote.netty.tcp.security.trust-store-password", - ConfigValueFactory.fromAnyRef(trustStorePassword.getOrElse(""))) - .withValue("akka.remote.netty.tcp.security.key-password", - ConfigValueFactory.fromAnyRef(keyPassword.getOrElse(""))) - .withValue("akka.remote.netty.tcp.security.random-number-generator", - ConfigValueFactory.fromAnyRef("")) - .withValue("akka.remote.netty.tcp.security.protocol", - ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) - .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(supportedAlgorithms.asJava)) - .withValue("akka.remote.netty.tcp.enable-ssl", - ConfigValueFactory.fromAnyRef(true))) - } else { - None - } - } - /* * The supportedAlgorithms set is a subset of the enabledAlgorithms that * are supported by the current Java security provider for this protocol. diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index c5aec05c03fce..0675957e16680 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -67,17 +67,6 @@ import org.apache.spark.util.Utils * At this point spark has multiple communication protocols that need to be secured and * different underlying mechanisms are used depending on the protocol: * - * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. - * Akka remoting allows you to specify a secure cookie that will be exchanged - * and ensured to be identical in the connection handshake between the client - * and the server. If they are not identical then the client will be refused - * to connect to the server. There is no control of the underlying - * authentication mechanism so its not clear if the password is passed in - * plaintext or uses DIGEST-MD5 or some other mechanism. - * - * Akka also has an option to turn on SSL, this option is currently supported (see - * the details below). - * * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty * for the HttpServer. Jetty supports multiple authentication mechanisms - * Basic, Digest, Form, Spengo, etc. It also supports multiple different login @@ -168,16 +157,16 @@ import org.apache.spark.util.Utils * denote the global configuration for all the supported protocols. In order to override the global * configuration for the particular protocol, the properties must be overwritten in the * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global - * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be either `akka` for - * Akka based connections or `fs` for broadcast and file server. + * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be only`fs` for + * broadcast and file server. * * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of * options that can be specified. * * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions * object parses Spark configuration at a given namespace and builds the common representation - * of SSL settings. SSLOptions is then used to provide protocol-specific configuration like - * TypeSafe configuration for Akka or SSLContextFactory for Jetty. + * of SSL settings. SSLOptions is then used to provide protocol-specific SSLContextFactory for + * Jetty. * * SSL must be configured on each node and configured for each component involved in * communication using the particular protocol. In YARN clusters, the key-store can be prepared on diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 340e1f7824d1e..36e240e618490 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -344,17 +344,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { .map{case (k, v) => (k.substring(prefix.length), v)} } - /** Get all akka conf variables set on this SparkConf */ - def getAkkaConf: Seq[(String, String)] = - /* This is currently undocumented. If we want to make this public we should consider - * nesting options under the spark namespace to avoid conflicts with user akka options. - * Otherwise users configuring their own akka code via system properties could mess up - * spark's akka options. - * - * E.g. spark.akka.option.x.y.x = "value" - */ - getAll.filter { case (k, _) => isAkkaConf(k) } - /** * Returns the Spark application id, valid in the Driver after TaskScheduler registration and * from the start in the Executor. @@ -600,7 +589,9 @@ private[spark] object SparkConf extends Logging { "spark.yarn.max.executor.failures" -> Seq( AlternateConfig("spark.yarn.max.worker.failures", "1.5")), "spark.memory.offHeap.enabled" -> Seq( - AlternateConfig("spark.unsafe.offHeap", "1.6")) + AlternateConfig("spark.unsafe.offHeap", "1.6")), + "spark.rpc.message.maxSize" -> Seq( + AlternateConfig("spark.akka.frameSize", "1.6")) ) /** @@ -615,21 +606,13 @@ private[spark] object SparkConf extends Logging { }.toMap } - /** - * Return whether the given config is an akka config (e.g. akka.actor.provider). - * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). - */ - def isAkkaConf(name: String): Boolean = name.startsWith("akka.") - /** * Return whether the given config should be passed to an executor on start-up. * - * Certain akka and authentication configs are required from the executor when it connects to + * Certain authentication configs are required from the executor when it connects to * the scheduler, while the rest of the spark configs can be inherited from the driver later. */ def isExecutorStartupConf(name: String): Boolean = { - isAkkaConf(name) || - name.startsWith("spark.akka") || (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || name.startsWith("spark.rpc") || @@ -664,12 +647,19 @@ private[spark] object SparkConf extends Logging { logWarning( s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + s"may be removed in the future. ${cfg.deprecationMessage}") + return } allAlternatives.get(key).foreach { case (newKey, cfg) => logWarning( s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + s"and may be removed in the future. Please use the new key '$newKey' instead.") + return + } + if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) { + logWarning( + s"The configuration key $key is not supported any more " + + s"because Spark doesn't use Akka since 2.0") } } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ec43be0e2f3a4..9461afdc54124 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -23,7 +23,6 @@ import java.net.Socket import scala.collection.mutable import scala.util.Properties -import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -39,12 +38,12 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), - * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently + * including the serializer, RpcEnv, block manager, map output tracker, etc. Currently * Spark code finds the SparkEnv through a global variable, so all the threads can access the same * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). * @@ -55,7 +54,6 @@ import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} class SparkEnv ( val executorId: String, private[spark] val rpcEnv: RpcEnv, - _actorSystem: ActorSystem, // TODO Remove actorSystem val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -71,10 +69,6 @@ class SparkEnv ( val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { - // TODO Remove actorSystem - @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") - val actorSystem: ActorSystem = _actorSystem - private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -96,13 +90,8 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - actorSystem.shutdown() rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() + rpcEnv.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. @@ -152,8 +141,8 @@ class SparkEnv ( object SparkEnv extends Logging { @volatile private var env: SparkEnv = _ - private[spark] val driverActorSystemName = "sparkDriver" - private[spark] val executorActorSystemName = "sparkExecutor" + private[spark] val driverSystemName = "sparkDriver" + private[spark] val executorSystemName = "sparkExecutor" def set(e: SparkEnv) { env = e @@ -202,7 +191,7 @@ object SparkEnv extends Logging { /** * Create a SparkEnv for an executor. - * In coarse-grained mode, the executor provides an actor system that is already instantiated. + * In coarse-grained mode, the executor provides an RpcEnv that is already instantiated. */ private[spark] def createExecutorEnv( conf: SparkConf, @@ -245,28 +234,11 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - // Create the ActorSystem for Akka and get the port it binds to. - val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, + val systemName = if (isDriver) driverSystemName else executorSystemName + val rpcEnv = RpcEnv.create(systemName, hostname, port, conf, securityManager, clientMode = !isDriver) - val actorSystem: ActorSystem = { - val actorSystemPort = - if (port == 0 || rpcEnv.address == null) { - port - } else { - rpcEnv.address.port + 1 - } - // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem( - actorSystemName + "ActorSystem", - hostname, - actorSystemPort, - conf, - securityManager - )._1 - } - // Figure out which port Akka actually bound to in case the original port is 0 or occupied. + // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. // In the non-driver case, the RPC env's address may be null since it may not be listening // for incoming connections. if (isDriver) { @@ -325,7 +297,7 @@ object SparkEnv extends Logging { new MapOutputTrackerWorker(conf) } - // Have to assign trackerActor after initialization as MapOutputTrackerActor + // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint // requires the MapOutputTracker itself mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint( @@ -398,7 +370,6 @@ object SparkEnv extends Logging { val envInstance = new SparkEnv( executorId, rpcEnv, - actorSystem, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 63a20ab41a0f7..dcef03ef3e3ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -219,11 +219,7 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { - conf.set("spark.akka.logLifecycleEvents", "true") - } conf.set("spark.rpc.askTimeout", "10") - conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a1e8da1ec8f5d..a6749f7e38802 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -183,7 +183,7 @@ object SparkSubmit { } // In standalone cluster mode, there are two submission gateways: - // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper + // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper // (2) The new REST-based gateway introduced in Spark 1.3 // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over // to use the legacy gateway if the master endpoint turns out to be not a REST server. diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 58bd9ca3d12c3..e3a6c4c07a75e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -37,7 +37,6 @@ private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, - hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) @@ -55,8 +54,7 @@ private[spark] class CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[RegisterExecutorResponse]( - RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + ref.ask[RegisterExecutorResponse](RegisterExecutor(executorId, self, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { @@ -184,14 +182,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) - // SparkEnv will set spark.executor.port if the rpc env is listening for incoming - // connections (e.g., if it's using akka). Otherwise, the executor is running in - // client mode only, and does not accept incoming connections. - val sparkHostPort = env.conf.getOption("spark.executor.port").map { port => - hostname + ":" + port - }.orNull env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( - env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) + env.rpcEnv, driverUrl, executorId, cores, userClassPath, env)) workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 75d7e34d60eb2..030ae41db4a62 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -40,7 +40,7 @@ import org.apache.spark.util._ * Spark executor, backed by a threadpool to run tasks. * * This can be used with Mesos, YARN, and the standalone scheduler. - * An internal RPC interface (at the moment Akka) is used for communication with the driver, + * An internal RPC interface is used for communication with the driver, * except in the case of Mesos fine-grained mode. */ private[spark] class Executor( @@ -97,9 +97,9 @@ private[spark] class Executor( // Set the classloader for serializer env.serializer.setDefaultClassLoader(replClassLoader) - // Akka's message frame size. If task result is bigger than this, we use the block manager + // Max RPC message size. If task result is bigger than this, we use the block manager // to send the result back. - private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) // Limit of bytes for total size of results (default is 1GB) private val maxResultSize = Utils.getMaxResultSize(conf) @@ -263,7 +263,7 @@ private[spark] class Executor( s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + s"dropping it.") ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) - } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + } else if (resultSize >= maxRpcMessageSize) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index f3d0d85476772..29e469c3f5983 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -48,7 +48,6 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutor( executorId: String, executorRef: RpcEndpointRef, - hostPort: String, cores: Int, logUrls: Map[String, String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index b808993aa6cd3..f69a3d371e5dd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -27,7 +27,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME -import org.apache.spark.util.{AkkaUtils, SerializableBuffer, ThreadUtils, Utils} +import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} /** * A scheduler backend that waits for coarse-grained executors to connect. @@ -46,7 +46,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. var minRegisteredRatio = @@ -134,7 +134,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + case RegisterExecutor(executorId, executorRef, cores, logUrls) => if (executorDataMap.contains(executorId)) { context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { @@ -224,14 +224,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) - if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + if (serializedTask.limit >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + - "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + - "spark.akka.frameSize or using broadcast variables for large values." - msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, - AkkaUtils.reservedSizeBytes) + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + + "spark.rpc.message.maxSize or using broadcast variables for large values." + msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 0a6f2c01c18df..a298cf5ef9a6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -49,7 +49,7 @@ private[spark] class SimrSchedulerBackend( val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") logInfo("Writing to HDFS file: " + driverFilePath) - logInfo("Writing Akka address: " + driverUrl) + logInfo("Writing Driver address: " + driverUrl) logInfo("Writing Spark UI Address: " + appUIAddress) // Create temporary file to prevent race condition where executors get empty driverUrl file diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala deleted file mode 100644 index 3f4ac9b2f18cd..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.JavaConverters._ - -import akka.actor.{ActorSystem, ExtendedActorSystem} -import com.typesafe.config.ConfigFactory -import org.apache.log4j.{Level, Logger} - -import org.apache.spark.{Logging, SecurityManager, SparkConf} - -/** - * Various utility classes for working with Akka. - */ -private[spark] object AkkaUtils extends Logging { - - /** - * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the - * ActorSystem itself and its port (which is hard to get from Akka). - * - * Note: the `name` parameter is important, as even if a client sends a message to right - * host + port, if the system name is incorrect, Akka will drop the message. - * - * If indestructible is set to true, the Actor System will continue running in the event - * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. - */ - def createActorSystem( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): (ActorSystem, Int) = { - val startService: Int => (ActorSystem, Int) = { actualPort => - doCreateActorSystem(name, host, actualPort, conf, securityManager) - } - Utils.startServiceOnPort(port, startService, conf, name) - } - - private def doCreateActorSystem( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): (ActorSystem, Int) = { - - val akkaThreads = conf.getInt("spark.akka.threads", 4) - val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", - conf.get("spark.network.timeout", "120s")) - val akkaFrameSize = maxFrameSizeBytes(conf) - val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) - val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" - if (!akkaLogLifecycleEvents) { - // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent. - // See: https://www.assembla.com/spaces/akka/tickets/3787#/ - Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL)) - } - - val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - - val akkaHeartBeatPausesS = conf.getTimeAsSeconds("spark.akka.heartbeat.pauses", "6000s") - val akkaHeartBeatIntervalS = conf.getTimeAsSeconds("spark.akka.heartbeat.interval", "1000s") - - val secretKey = securityManager.getSecretKey() - val isAuthOn = securityManager.isAuthenticationEnabled() - if (isAuthOn && secretKey == null) { - throw new Exception("Secret key is null with authentication on") - } - val requireCookie = if (isAuthOn) "on" else "off" - val secureCookie = if (isAuthOn) secretKey else "" - logDebug(s"In createActorSystem, requireCookie is: $requireCookie") - - val akkaSslConfig = securityManager.getSSLOptions("akka").createAkkaConfig - .getOrElse(ConfigFactory.empty()) - - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) - .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( - s""" - |akka.daemonic = on - |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] - |akka.stdout-loglevel = "ERROR" - |akka.jvm-exit-on-fatal-error = off - |akka.remote.require-cookie = "$requireCookie" - |akka.remote.secure-cookie = "$secureCookie" - |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatIntervalS s - |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPausesS s - |akka.actor.provider = "akka.remote.RemoteActorRefProvider" - |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" - |akka.remote.netty.tcp.hostname = "$host" - |akka.remote.netty.tcp.port = $port - |akka.remote.netty.tcp.tcp-nodelay = on - |akka.remote.netty.tcp.connection-timeout = $akkaTimeoutS s - |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B - |akka.remote.netty.tcp.execution-pool-size = $akkaThreads - |akka.actor.default-dispatcher.throughput = $akkaBatchSize - |akka.log-config-on-start = $logAkkaConfig - |akka.remote.log-remote-lifecycle-events = $lifecycleEvents - |akka.log-dead-letters = $lifecycleEvents - |akka.log-dead-letters-during-shutdown = $lifecycleEvents - """.stripMargin)) - - val actorSystem = ActorSystem(name, akkaConf) - val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider - val boundPort = provider.getDefaultAddress.port.get - (actorSystem, boundPort) - } - - private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 - - /** Returns the configured max frame size for Akka messages in bytes. */ - def maxFrameSizeBytes(conf: SparkConf): Int = { - val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128) - if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { - throw new IllegalArgumentException( - s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") - } - frameSizeInMB * 1024 * 1024 - } - - /** Space reserved for extra data in an Akka message besides serialized task or task result. */ - val reservedSizeBytes = 200 * 1024 - -} diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index b68936f5c9f0a..2bb8de568e803 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -53,4 +53,16 @@ private[spark] object RpcUtils { def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") } + + private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 + + /** Returns the configured max message size for messages in bytes. */ + def maxMessageSizeBytes(conf: SparkConf): Int = { + val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128) + if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) { + throw new IllegalArgumentException( + s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB") + } + maxSizeInMB * 1024 * 1024 + } } diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala deleted file mode 100644 index bc7059b77fec5..0000000000000 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io._ -import java.net.URI -import java.util.jar.{JarEntry, JarOutputStream} -import javax.net.ssl.SSLException - -import com.google.common.io.{ByteStreams, Files} -import org.apache.commons.lang3.RandomUtils - -import org.apache.spark.util.Utils - -class FileServerSuite extends SparkFunSuite with LocalSparkContext { - - import SSLSampleConfigs._ - - @transient var tmpDir: File = _ - @transient var tmpFile: File = _ - @transient var tmpJarUrl: String = _ - - def newConf: SparkConf = new SparkConf(loadDefaults = false).set("spark.authenticate", "false") - - override def beforeEach() { - super.beforeEach() - resetSparkContext() - } - - override def beforeAll() { - super.beforeAll() - - tmpDir = Utils.createTempDir() - val testTempDir = new File(tmpDir, "test") - testTempDir.mkdir() - - val textFile = new File(testTempDir, "FileServerSuite.txt") - val pw = new PrintWriter(textFile) - // scalastyle:off println - pw.println("100") - // scalastyle:on println - pw.close() - - val jarFile = new File(testTempDir, "test.jar") - val jarStream = new FileOutputStream(jarFile) - val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) - - val jarEntry = new JarEntry(textFile.getName) - jar.putNextEntry(jarEntry) - - val in = new FileInputStream(textFile) - ByteStreams.copy(in, jar) - - in.close() - jar.close() - jarStream.close() - - tmpFile = textFile - tmpJarUrl = jarFile.toURI.toURL.toString - } - - override def afterAll() { - try { - Utils.deleteRecursively(tmpDir) - } finally { - super.afterAll() - } - } - - test("Distributing files locally") { - sc = new SparkContext("local[4]", "test", newConf) - sc.addFile(tmpFile.toString) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test("Distributing files locally security On") { - val sparkConf = new SparkConf(false) - sparkConf.set("spark.authenticate", "true") - sparkConf.set("spark.authenticate.secret", "good") - sc = new SparkContext("local[4]", "test", sparkConf) - - sc.addFile(tmpFile.toString) - assert(sc.env.securityManager.isAuthenticationEnabled() === true) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test("Distributing files locally using URL as input") { - // addFile("file:///....") - sc = new SparkContext("local[4]", "test", newConf) - sc.addFile(new File(tmpFile.toString).toURI.toString) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test ("Dynamically adding JARS locally") { - sc = new SparkContext("local[4]", "test", newConf) - sc.addJar(tmpJarUrl) - val testData = Array((1, 1)) - sc.parallelize(testData).foreach { x => - if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { - throw new SparkException("jar not added") - } - } - } - - test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) - sc.addFile(tmpFile.toString) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) - sc.addJar(tmpJarUrl) - val testData = Array((1, 1)) - sc.parallelize(testData).foreach { x => - if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { - throw new SparkException("jar not added") - } - } - } - - test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) - sc.addJar(tmpJarUrl.replace("file", "local")) - val testData = Array((1, 1)) - sc.parallelize(testData).foreach { x => - if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { - throw new SparkException("jar not added") - } - } - } - - test ("HttpFileServer should work with SSL") { - val sparkConf = sparkSSLConfig() - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - fileTransferTest(server, sm) - } finally { - server.stop() - } - } - - test ("HttpFileServer should work with SSL and good credentials") { - val sparkConf = sparkSSLConfig() - sparkConf.set("spark.authenticate", "true") - sparkConf.set("spark.authenticate.secret", "good") - - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - fileTransferTest(server, sm) - } finally { - server.stop() - } - } - - test ("HttpFileServer should not work with valid SSL and bad credentials") { - val sparkConf = sparkSSLConfig() - sparkConf.set("spark.authenticate", "true") - sparkConf.set("spark.authenticate.secret", "bad") - - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - intercept[IOException] { - fileTransferTest(server) - } - } finally { - server.stop() - } - } - - test ("HttpFileServer should not work with SSL when the server is untrusted") { - val sparkConf = sparkSSLConfigUntrusted() - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - intercept[SSLException] { - fileTransferTest(server) - } - } finally { - server.stop() - } - } - - def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { - val randomContent = RandomUtils.nextBytes(100) - val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) - Files.write(randomContent, file) - server.addFile(file) - - val uri = new URI(server.serverUri + "/files/" + file.getName) - - val connection = if (sm != null && sm.isAuthenticationEnabled()) { - Utils.constructURIForAuthentication(uri, sm).toURL.openConnection() - } else { - uri.toURL.openConnection() - } - - if (sm != null) { - Utils.setupSecureURLConnection(connection, sm) - } - - val buf = ByteStreams.toByteArray(connection.getInputStream) - assert(buf === randomContent) - } - -} diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 18e53508406dc..c7f629a14ba24 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -175,9 +175,9 @@ class HeartbeatReceiverSuite val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( - RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) + RegisterExecutor(executorId1, dummyExecutorEndpointRef1, 0, Map.empty)) fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( - RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) + RegisterExecutor(executorId2, dummyExecutorEndpointRef2, 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index e1a0bf7c933b9..24ec99c7e5e60 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -52,7 +52,7 @@ object LocalSparkContext { if (sc != null) { sc.stop() } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 3819c0a8f31dc..6546def5966e1 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -154,9 +154,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.shutdown() } - test("remote fetch below akka frame size") { + test("remote fetch below max RPC message size") { val newConf = new SparkConf - newConf.set("spark.akka.frameSize", "1") + newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~123B, and no exception should be thrown + // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) @@ -179,9 +179,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("remote fetch exceeds akka frame size") { + test("remote fetch exceeds max RPC message size") { val newConf = new SparkConf - newConf.set("spark.akka.frameSize", "1") + newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) @@ -189,7 +189,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. + // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 7603cef773691..8bdb237c28f66 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -181,10 +181,8 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) - val akkaSSLOptions = securityManager.getSSLOptions("akka") assert(securityManager.fileServerSSLOptions.enabled === true) - assert(akkaSSLOptions.enabled === true) assert(securityManager.sslSocketFactory.isDefined === true) assert(securityManager.hostnameVerifier.isDefined === true) @@ -198,16 +196,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - - assert(akkaSSLOptions.trustStore.isDefined === true) - assert(akkaSSLOptions.trustStore.get.getName === "truststore") - assert(akkaSSLOptions.keyStore.isDefined === true) - assert(akkaSSLOptions.keyStore.get.getName === "keystore") - assert(akkaSSLOptions.trustStorePassword === Some("password")) - assert(akkaSSLOptions.keyStorePassword === Some("password")) - assert(akkaSSLOptions.keyPassword === Some("password")) - assert(akkaSSLOptions.protocol === Some("TLSv1.2")) - assert(akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 190e4dd7285b3..9c13c15281a42 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -69,7 +69,7 @@ private[deploy] object DeployTestUtils { "publicAddress", new File("sparkHome"), new File("workDir"), - "akka://worker", + "spark://worker", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) @@ -84,7 +84,7 @@ private[deploy] object DeployTestUtils { new File("sparkHome"), createDriverDesc(), null, - "akka://worker", + "spark://worker", new SecurityManager(conf)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index ab3d4cafebefa..fdada0777f9a9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -544,7 +544,7 @@ class StandaloneDynamicAllocationSuite val endpointRef = mock(classOf[RpcEndpointRef]) val mockAddress = mock(classOf[RpcAddress]) when(endpointRef.address).thenReturn(mockAddress) - val message = RegisterExecutor(id, endpointRef, s"localhost:$port", 10, Map.empty) + val message = RegisterExecutor(id, endpointRef, 10, Map.empty) val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] backend.driverEndpoint.askWithRetry[CoarseGrainedClusterMessage](message) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index bd8b0655f4bbf..2a1696be3660a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -34,7 +34,7 @@ class DriverRunnerTest extends SparkFunSuite { val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) val conf = new SparkConf() new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - driverDescription, null, "akka://1.2.3.4/worker/", new SecurityManager(conf)) + driverDescription, null, "spark://1.2.3.4/worker/", new SecurityManager(conf)) } private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 64e486d791cde..6f4eda8b47dde 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -626,9 +626,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val e = intercept[Exception] { Await.result(f, 1 seconds) } - assert(e.isInstanceOf[TimeoutException] || // For Akka - e.isInstanceOf[NotSerializableException] // For Netty - ) + assert(e.isInstanceOf[NotSerializableException]) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 70f40fb26c2f6..04cccc67e328e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.scheduler import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} -import org.apache.spark.util.{AkkaUtils, SerializableBuffer} +import org.apache.spark.util.{RpcUtils, SerializableBuffer} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { - test("serialized task larger than akka frame size") { + test("serialized task larger than max RPC message size") { val conf = new SparkConf - conf.set("spark.akka.frameSize", "1") + conf.set("spark.rpc.message.maxSize", "1") conf.set("spark.default.parallelism", "1") sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) - val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) + val frameSize = RpcUtils.maxMessageSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) val thrown = intercept[SparkException] { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index c87158d89f3fb..58d217ffef566 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { @@ -284,19 +284,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - val conf = new SparkConf().set("spark.akka.frameSize", "1") + val conf = new SparkConf().set("spark.rpc.message.maxSize", "1") sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) - // Make a task whose result is larger than the akka frame size - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt - assert(akkaFrameSize === 1024 * 1024) + // Make a task whose result is larger than the RPC message size + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + assert(maxRpcMessageSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) - .map { x => 1.to(akkaFrameSize).toArray } + .map { x => 1.to(maxRpcMessageSize).toArray } .reduce { case (x, y) => x } - assert(result === 1.to(akkaFrameSize).toArray) + assert(result === 1.to(maxRpcMessageSize).toArray) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val TASK_INDEX = 0 @@ -310,7 +309,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveTaskEvents sc.addSparkListener(listener) - // Make a task whose result is larger than the akka frame size + // Make a task whose result is larger than the RPC message size val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index bc72c3685e8c1..cc2557c2f1df2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId import org.apache.spark.TestUtils.JavaSourceFromString -import org.apache.spark.util.{MutableURLClassLoader, Utils} +import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -77,22 +77,22 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule */ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { - // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small + // Set the RPC message size to be as small as possible (it must be an integer, so 1 is as small // as we can make it) so the tests don't take too long. - def conf: SparkConf = new SparkConf().set("spark.akka.frameSize", "1") + def conf: SparkConf = new SparkConf().set("spark.rpc.message.maxSize", "1") - test("handling results smaller than Akka frame size") { + test("handling results smaller than max RPC message size") { sc = new SparkContext("local", "test", conf) val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } - test("handling results larger than Akka frame size") { + test("handling results larger than max RPC message size") { sc = new SparkContext("local", "test", conf) - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt - val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) - assert(result === 1.to(akkaFrameSize).toArray) + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + val result = + sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x) + assert(result === 1.to(maxRpcMessageSize).toArray) val RESULT_BLOCK_ID = TaskResultBlockId(0) assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, @@ -114,11 +114,11 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local } val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) scheduler.taskResultGetter = resultGetter - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt - val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + val result = + sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x) assert(resultGetter.removeBlockSuccessfully) - assert(result === 1.to(akkaFrameSize).toArray) + assert(result === 1.to(maxRpcMessageSize).toArray) // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 0760529b578ee..4d9937c5cbc34 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -2,9 +2,6 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.jar -akka-actor_2.10-2.3.11.jar -akka-remote_2.10-2.3.11.jar -akka-slf4j_2.10-2.3.11.jar antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -44,7 +41,6 @@ commons-math3-3.4.1.jar commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar -config-1.2.1.jar core-1.1.2.jar curator-client-2.4.0.jar curator-framework-2.4.0.jar @@ -179,7 +175,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 191f2a0e4e86f..fd659ee20df1a 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -2,9 +2,6 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -akka-actor_2.10-2.3.11.jar -akka-remote_2.10-2.3.11.jar -akka-slf4j_2.10-2.3.11.jar antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -45,7 +42,6 @@ commons-math3-3.4.1.jar commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar -config-1.2.1.jar core-1.1.2.jar curator-client-2.4.0.jar curator-framework-2.4.0.jar @@ -170,7 +166,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 9134e997c8457..afae3deb9ada2 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -2,9 +2,6 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -akka-actor_2.10-2.3.11.jar -akka-remote_2.10-2.3.11.jar -akka-slf4j_2.10-2.3.11.jar antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -45,7 +42,6 @@ commons-math3-3.4.1.jar commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar -config-1.2.1.jar core-1.1.2.jar curator-client-2.4.0.jar curator-framework-2.4.0.jar @@ -171,7 +167,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8c45832873848..5a6460136a3a0 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,9 +2,6 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -akka-actor_2.10-2.3.11.jar -akka-remote_2.10-2.3.11.jar -akka-slf4j_2.10-2.3.11.jar antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -49,7 +46,6 @@ commons-math3-3.4.1.jar commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar -config-1.2.1.jar core-1.1.2.jar curator-client-2.6.0.jar curator-framework-2.6.0.jar @@ -177,7 +173,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1d34854819a69..70083e7f3d16a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,9 +2,6 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -akka-actor_2.10-2.3.11.jar -akka-remote_2.10-2.3.11.jar -akka-slf4j_2.10-2.3.11.jar antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -49,7 +46,6 @@ commons-math3-3.4.1.jar commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar -config-1.2.1.jar core-1.1.2.jar curator-client-2.6.0.jar curator-framework-2.6.0.jar @@ -178,7 +174,6 @@ stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar super-csv-2.2.0.jar -uncommons-maths-1.2.2a.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar xbean-asm5-shaded-4.4.jar diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 2810112f5294e..814e4406cf435 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -35,7 +35,7 @@ There are several useful things to note about this architecture: processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). 3. The driver program must listen for and accept incoming connections from its executors throughout - its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + its lifetime (e.g., see [spark.driver.port in the network config section](configuration.html#networking)). As such, the driver program must be network addressable from the worker nodes. 4. Because the driver schedules tasks on the cluster, it should be run close to the worker diff --git a/docs/configuration.md b/docs/configuration.md index acaeb830081e2..d2a2f1052405d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -944,52 +944,12 @@ Apart from these, the following properties are also available, and may be useful - + - - - - - - - - - - - - - - - - - - - - @@ -1015,28 +975,12 @@ Apart from these, the following properties are also available, and may be useful This is used for communicating with the executors and the standalone Master. - - - - - - - - - - diff --git a/docs/security.md b/docs/security.md index a4cc0f42b2482..32c33d285747a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -27,8 +27,7 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP protocols. SASL encryption is supported for the block transfer -service. +Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle files, cached data, and other application files. If encrypting this data is desired, a workaround is @@ -48,10 +47,6 @@ component-specific configuration namespaces used to override the default setting - - - - @@ -137,7 +132,7 @@ configure those ports. - + @@ -145,7 +140,7 @@ configure those ports. - +
    Property NameDefaultMeaning
    spark.akka.frameSizespark.rpc.message.maxSize 128 Maximum message size (in MB) to allow in "control plane" communication; generally only applies to map output size information sent between executors and the driver. Increase this if you are running - jobs with many thousands of map and reduce tasks and see messages about the frame size. -
    spark.akka.heartbeat.interval1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network - with those. -
    spark.akka.heartbeat.pauses6000s - This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart - beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune - this along with spark.akka.heartbeat.interval if you need to. -
    spark.akka.threads4 - Number of actor threads to use for communication. Can be useful to increase on large clusters - when the driver has a lot of CPU cores. -
    spark.akka.timeout100s - Communication timeout between Spark nodes. + jobs with many thousands of map and reduce tasks and see messages about the RPC message size.
    spark.executor.port(random) - Port for the executor to listen on. This is used for communicating with the driver. - This is only relevant when using the Akka RPC backend. -
    spark.fileserver.port(random) - Port for the driver's HTTP file server to listen on. - This is only relevant when using the Akka RPC backend. -
    spark.network.timeout 120s Default timeout for all network interactions. This config will be used in place of - spark.core.connection.ack.wait.timeout, spark.akka.timeout, + spark.core.connection.ack.wait.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or spark.rpc.lookupTimeout if they are not configured. @@ -1418,8 +1362,7 @@ Apart from these, the following properties are also available, and may be useful

    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for particular protocol denoted by YYY. Currently YYY can be - either akka for Akka based connections or fs for file - server.

    + only fs for file server.

    Config Namespace Component
    spark.ssl.akkaAkka communication channels
    spark.ssl.fs HTTP file server and broadcast server7077 Submit job to cluster /
    Join cluster
    SPARK_MASTER_PORTAkka-based. Set to "0" to choose a port randomly. Standalone mode only.Set to "0" to choose a port randomly. Standalone mode only.
    Standalone Master(random) Schedule executors SPARK_WORKER_PORTAkka-based. Set to "0" to choose a port randomly. Standalone mode only.Set to "0" to choose a port randomly. Standalone mode only.
    @@ -178,24 +173,7 @@ configure those ports. (random) Connect to application /
    Notify executor state changes spark.driver.port - Akka-based. Set to "0" to choose a port randomly. - - - Driver - Executor - (random) - Schedule tasks - spark.executor.port - Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is - configured. - - - Executor - Driver - (random) - File server for files and jars - spark.fileserver.port - Jetty-based. Only used if Akka RPC backend is configured. + Set to "0" to choose a port randomly. Executor / Driver diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 5778fd1d09254..ca7385128d79a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -47,7 +47,7 @@ trait VectorTransformer extends Serializable { */ @Since("1.1.0") def transform(data: RDD[Vector]): RDD[Vector] = { - // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. + // Later in #1498 , all RDD objects are sent via broadcasting instead of RPC. // So it should be no longer necessary to explicitly broadcast `this` object. data.map(x => this.transform(x)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 9b2d023bbf738..95d874b8432eb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -29,7 +29,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => val conf = new SparkConf() .setMaster("local-cluster[2, 1, 1024]") .setAppName("test-cluster") - .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data + .set("spark.rpc.message.maxSize", "1") // set to 1MB to detect direct serialization of data sc = new SparkContext(conf) } diff --git a/pom.xml b/pom.xml index 43f08efaae86d..f08642f606788 100644 --- a/pom.xml +++ b/pom.xml @@ -569,6 +569,11 @@ netty-all 4.0.29.Final
    + + io.netty + netty + 3.8.0.Final + org.apache.derby derby diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 905fb4cd90377..c65fae482c5ca 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -162,7 +162,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") ) ++ Seq( // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 081f5a1c93e6e..898db85190b67 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -41,7 +41,6 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() - val akkaTimeout = 10 seconds val streamId = 1 var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 31fa53e24b507..21ac04dc76c32 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -166,7 +166,7 @@ class ExecutorRunnable( // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend - // uses Akka to connect to the scheduler, the akka settings are needed as well as the + // uses RPC to connect to the scheduler, the RPC settings are needed as well as the // authentication settings. sparkConf.getAll .filter { case (k, v) => SparkConf.isExecutorStartupConf(k) } From ea5c38fe75e4dfa60e61c5b4f20b742b67cb49b2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 22 Jan 2016 22:14:47 -0800 Subject: [PATCH 1661/2089] [HOTFIX]Remove rpcEnv.awaitTermination to avoid dead-lock in some test Looks rpcEnv.awaitTermination may block some tests forever. Just remove it and investigate the tests. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 9461afdc54124..12c7b2048a8c8 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -91,7 +91,6 @@ class SparkEnv ( metricsSystem.stop() outputCommitCoordinator.stop() rpcEnv.shutdown() - rpcEnv.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. From 5af5a02160b42115579003b749c4d1831bf9d48e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 22 Jan 2016 23:53:12 -0800 Subject: [PATCH 1662/2089] [SPARK-12872][SQL] Support to specify the option for compression codec for JSON datasource https://issues.apache.org/jira/browse/SPARK-12872 This PR makes the JSON datasource can compress output by option instead of manually setting Hadoop configurations. For reflecting codec by names, it is similar with https://github.com/apache/spark/pull/10805. As `CSVCompressionCodecs` can be shared with other datasources, it became a separate class to share as `CompressionCodecs`. Author: hyukjinkwon Closes #10858 from HyukjinKwon/SPARK-12872. --- .../datasources/CompressionCodecs.scala | 47 +++++++++++++++++++ .../datasources/csv/CSVParameters.scala | 28 +---------- .../datasources/json/JSONOptions.scala | 12 +++-- .../datasources/json/JSONRelation.scala | 10 ++++ .../datasources/json/JsonSuite.scala | 28 +++++++++++ 5 files changed, 96 insertions(+), 29 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala new file mode 100644 index 0000000000000..e683a95ed2aef --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec} + +import org.apache.spark.util.Utils + +private[datasources] object CompressionCodecs { + private val shortCompressionCodecNames = Map( + "bzip2" -> classOf[BZip2Codec].getName, + "gzip" -> classOf[GzipCodec].getName, + "lz4" -> classOf[Lz4Codec].getName, + "snappy" -> classOf[SnappyCodec].getName) + + /** + * Return the full version of the given codec class. + * If it is already a class name, just return it. + */ + def getCodecClassName(name: String): String = { + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + try { + // Validate the codec name + Utils.classForName(codecName) + codecName + } catch { + case e: ClassNotFoundException => + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala index 676a3d3bca9f7..0278675aa61b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -22,6 +22,7 @@ import java.nio.charset.Charset import org.apache.hadoop.io.compress._ import org.apache.spark.Logging +import org.apache.spark.sql.execution.datasources.CompressionCodecs import org.apache.spark.util.Utils private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging { @@ -78,7 +79,7 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] val compressionCodec: Option[String] = { val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CSVCompressionCodecs.getCodecClassName) + name.map(CompressionCodecs.getCodecClassName) } val maxColumns = 20480 @@ -114,28 +115,3 @@ private[csv] object ParseModes { true // We default to permissive is the mode string is not valid } } - -private[csv] object CSVCompressionCodecs { - private val shortCompressionCodecNames = Map( - "bzip2" -> classOf[BZip2Codec].getName, - "gzip" -> classOf[GzipCodec].getName, - "lz4" -> classOf[Lz4Codec].getName, - "snappy" -> classOf[SnappyCodec].getName) - - /** - * Return the full version of the given codec class. - * If it is already a class name, just return it. - */ - def getCodecClassName(name: String): String = { - val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) - try { - // Validate the codec name - Utils.classForName(codecName) - codecName - } catch { - case e: ClassNotFoundException => - throw new IllegalArgumentException(s"Codec [$codecName] " + - s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index aee9cf2bdbcaa..e74a76c532367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import org.apache.spark.sql.execution.datasources.CompressionCodecs + /** * Options for the JSON data source. * @@ -32,7 +34,8 @@ case class JSONOptions( allowSingleQuotes: Boolean = true, allowNumericLeadingZeros: Boolean = false, allowNonNumericNumbers: Boolean = false, - allowBackslashEscapingAnyCharacter: Boolean = false) { + allowBackslashEscapingAnyCharacter: Boolean = false, + compressionCodec: Option[String] = None) { /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { @@ -46,7 +49,6 @@ case class JSONOptions( } } - object JSONOptions { def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions( samplingRatio = @@ -64,6 +66,10 @@ object JSONOptions { allowNonNumericNumbers = parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true), allowBackslashEscapingAnyCharacter = - parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false), + compressionCodec = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CompressionCodecs.getCodecClassName) + } ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 31c5620c9a80e..93727abcc7de9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.core.JsonFactory import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.{JobConf, TextInputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -162,6 +163,15 @@ private[sql] class JSONRelation( } override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + val conf = job.getConfiguration + options.compressionCodec.foreach { codec => + conf.set("mapreduce.output.fileoutputformat.compress", "true") + conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) + conf.set("mapreduce.map.output.compress", "true") + conf.set("mapreduce.map.output.compress.codec", codec) + } + new BucketedOutputWriterFactory { override def newInstance( path: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a3c6a1d7b20ed..d22fa7905aec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1467,6 +1467,34 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-12872 Support to specify the option for compression codec") { + withTempDir { dir => + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val jsonDF = sqlContext.read.json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write + .format("json") + .option("compression", "gZiP") + .save(jsonDir) + + val compressedFiles = new File(jsonDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".gz"))) + + val jsonCopy = sqlContext.read + .format("json") + .load(jsonDir) + + assert(jsonCopy.count == jsonDF.count) + val jsonCopySome = jsonCopy.selectExpr("string", "long", "boolean") + val jsonDFSome = jsonDF.selectExpr("string", "long", "boolean") + checkAnswer(jsonCopySome, jsonDFSome) + } + } + test("Casting long as timestamp") { withTempTable("jsonTable") { val schema = (new StructType).add("ts", TimestampType) From 1c690ddafa8376c55cbc5b7a7a750200abfbe2a6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 23 Jan 2016 00:34:55 -0800 Subject: [PATCH 1663/2089] [SPARK-12933][SQL] Initial implementation of Count-Min sketch This PR adds an initial implementation of count min sketch, contained in a new module spark-sketch under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1]. As required by the [design doc][2], spark-sketch should have no external dependency. Two classes, `Murmur3_x86_32` and `Platform` are copied to spark-sketch from spark-unsafe for hashing facilities. They'll also be used in the upcoming bloom filter implementation. The following features will be added in future follow-up PRs: - Serialization support - DataFrame API integration [1]: https://github.com/addthis/stream-lib/blob/aac6b4d23a8686b000f80baa447e0922ecac3bcb/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java [2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf Author: Cheng Lian Closes #10851 from liancheng/count-min-sketch. --- common/sketch/pom.xml | 42 +++ .../spark/util/sketch/CountMinSketch.java | 132 +++++++++ .../spark/util/sketch/CountMinSketchImpl.java | 268 ++++++++++++++++++ .../spark/util/sketch/Murmur3_x86_32.java | 126 ++++++++ .../apache/spark/util/sketch/Platform.java | 172 +++++++++++ .../util/sketch/CountMinSketchSuite.scala | 112 ++++++++ dev/sparktestsupport/modules.py | 12 + pom.xml | 1 + project/SparkBuild.scala | 39 ++- 9 files changed, 892 insertions(+), 12 deletions(-) create mode 100644 common/sketch/pom.xml create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java create mode 100644 common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml new file mode 100644 index 0000000000000..67723fa421ab1 --- /dev/null +++ b/common/sketch/pom.xml @@ -0,0 +1,42 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sketch_2.10 + jar + Spark Project Sketch + http://spark.apache.org/ + + sketch + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java new file mode 100644 index 0000000000000..21b161bc74ae0 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in + * sub-linear space. Currently, supported data types include: + *
      + *
    • {@link Byte}
    • + *
    • {@link Short}
    • + *
    • {@link Integer}
    • + *
    • {@link Long}
    • + *
    • {@link String}
    • + *
    + * Each {@link CountMinSketch} is initialized with a random seed, and a pair + * of parameters: + *
      + *
    1. relative error (or {@code eps}), and + *
    2. confidence (or {@code delta}) + *
    + * Suppose you want to estimate the number of times an element {@code x} has appeared in a data + * stream so far. With probability {@code delta}, the estimate of this frequency is within the + * range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the + * total count of items have appeared the the data stream so far. + * + * Under the cover, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array + * with depth {@code d} and width {@code w}, where + *
      + *
    • {@code d = ceil(2 / eps)}
    • + *
    • {@code w = ceil(-log(1 - confidence) / log(2))}
    • + *
    + * + * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details, + * including proofs of the estimates and error bounds used in this implementation. + * + * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. + */ +abstract public class CountMinSketch { + /** + * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. + */ + public abstract double relativeError(); + + /** + * Returns the confidence (or {@code delta}) of this {@link CountMinSketch}. + */ + public abstract double confidence(); + + /** + * Depth of this {@link CountMinSketch}. + */ + public abstract int depth(); + + /** + * Width of this {@link CountMinSketch}. + */ + public abstract int width(); + + /** + * Total count of items added to this {@link CountMinSketch} so far. + */ + public abstract long totalCount(); + + /** + * Adds 1 to {@code item}. + */ + public abstract void add(Object item); + + /** + * Adds {@code count} to {@code item}. + */ + public abstract void add(Object item, long count); + + /** + * Returns the estimated frequency of {@code item}. + */ + public abstract long estimateCount(Object item); + + /** + * Merges another {@link CountMinSketch} with this one in place. + * + * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed + * can be merged. + */ + public abstract CountMinSketch mergeInPlace(CountMinSketch other); + + /** + * Writes out this {@link CountMinSketch} to an output stream in binary format. + */ + public abstract void writeTo(OutputStream out); + + /** + * Reads in a {@link CountMinSketch} from an input stream. + */ + public static CountMinSketch readFrom(InputStream in) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** + * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random + * {@code seed}. + */ + public static CountMinSketch create(int depth, int width, int seed) { + return new CountMinSketchImpl(depth, width, seed); + } + + /** + * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence}, + * and random {@code seed}. + */ + public static CountMinSketch create(double eps, double confidence, int seed) { + return new CountMinSketchImpl(eps, confidence, seed); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java new file mode 100644 index 0000000000000..e9fdbe3a86862 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.util.Arrays; +import java.util.Random; + +class CountMinSketchImpl extends CountMinSketch { + public static final long PRIME_MODULUS = (1L << 31) - 1; + + private int depth; + private int width; + private long[][] table; + private long[] hashA; + private long totalCount; + private double eps; + private double confidence; + + public CountMinSketchImpl(int depth, int width, int seed) { + this.depth = depth; + this.width = width; + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + initTablesWith(depth, width, seed); + } + + public CountMinSketchImpl(double eps, double confidence, int seed) { + // 2/w = eps ; w = 2/eps + // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) + this.eps = eps; + this.confidence = confidence; + this.width = (int) Math.ceil(2 / eps); + this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2)); + initTablesWith(depth, width, seed); + } + + private void initTablesWith(int depth, int width, int seed) { + this.table = new long[depth][width]; + this.hashA = new long[depth]; + Random r = new Random(seed); + // We're using a linear hash functions + // of the form (a*x+b) mod p. + // a,b are chosen independently for each hash function. + // However we can set b = 0 as all it does is shift the results + // without compromising their uniformity or independence with + // the other hashes. + for (int i = 0; i < depth; ++i) { + hashA[i] = r.nextInt(Integer.MAX_VALUE); + } + } + + @Override + public double relativeError() { + return eps; + } + + @Override + public double confidence() { + return confidence; + } + + @Override + public int depth() { + return depth; + } + + @Override + public int width() { + return width; + } + + @Override + public long totalCount() { + return totalCount; + } + + @Override + public void add(Object item) { + add(item, 1); + } + + @Override + public void add(Object item, long count) { + if (item instanceof String) { + addString((String) item, count); + } else { + long longValue; + + if (item instanceof Long) { + longValue = (Long) item; + } else if (item instanceof Integer) { + longValue = ((Integer) item).longValue(); + } else if (item instanceof Short) { + longValue = ((Short) item).longValue(); + } else if (item instanceof Byte) { + longValue = ((Byte) item).longValue(); + } else { + throw new IllegalArgumentException( + "Support for " + item.getClass().getName() + " not implemented" + ); + } + + addLong(longValue, count); + } + } + + private void addString(String item, long count) { + if (count < 0) { + throw new IllegalArgumentException("Negative increments not implemented"); + } + + int[] buckets = getHashBuckets(item, depth, width); + + for (int i = 0; i < depth; ++i) { + table[i][buckets[i]] += count; + } + + totalCount += count; + } + + private void addLong(long item, long count) { + if (count < 0) { + throw new IllegalArgumentException("Negative increments not implemented"); + } + + for (int i = 0; i < depth; ++i) { + table[i][hash(item, i)] += count; + } + + totalCount += count; + } + + private int hash(long item, int count) { + long hash = hashA[count] * item; + // A super fast way of computing x mod 2^p-1 + // See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf + // page 149, right after Proposition 7. + hash += hash >> 32; + hash &= PRIME_MODULUS; + // Doing "%" after (int) conversion is ~2x faster than %'ing longs. + return ((int) hash) % width; + } + + private static int[] getHashBuckets(String key, int hashCount, int max) { + byte[] b; + try { + b = key.getBytes("UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + return getHashBuckets(b, hashCount, max); + } + + private static int[] getHashBuckets(byte[] b, int hashCount, int max) { + int[] result = new int[hashCount]; + int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); + int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1); + for (int i = 0; i < hashCount; i++) { + result[i] = Math.abs((hash1 + i * hash2) % max); + } + return result; + } + + @Override + public long estimateCount(Object item) { + if (item instanceof String) { + return estimateCountForStringItem((String) item); + } else { + long longValue; + + if (item instanceof Long) { + longValue = (Long) item; + } else if (item instanceof Integer) { + longValue = ((Integer) item).longValue(); + } else if (item instanceof Short) { + longValue = ((Short) item).longValue(); + } else if (item instanceof Byte) { + longValue = ((Byte) item).longValue(); + } else { + throw new IllegalArgumentException( + "Support for " + item.getClass().getName() + " not implemented" + ); + } + + return estimateCountForLongItem(longValue); + } + } + + private long estimateCountForLongItem(long item) { + long res = Long.MAX_VALUE; + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][hash(item, i)]); + } + return res; + } + + private long estimateCountForStringItem(String item) { + long res = Long.MAX_VALUE; + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][buckets[i]]); + } + return res; + } + + @Override + public CountMinSketch mergeInPlace(CountMinSketch other) { + if (other == null) { + throw new CMSMergeException("Cannot merge null estimator"); + } + + if (!(other instanceof CountMinSketchImpl)) { + throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName()); + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + if (this.depth != that.depth) { + throw new CMSMergeException("Cannot merge estimators of different depth"); + } + + if (this.width != that.width) { + throw new CMSMergeException("Cannot merge estimators of different width"); + } + + if (!Arrays.equals(this.hashA, that.hashA)) { + throw new CMSMergeException("Cannot merge estimators of different seed"); + } + + for (int i = 0; i < this.table.length; ++i) { + for (int j = 0; j < this.table[i].length; ++j) { + this.table[i][j] = this.table[i][j] + that.table[i][j]; + } + } + + this.totalCount += that.totalCount; + + return this; + } + + @Override + public void writeTo(OutputStream out) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + protected static class CMSMergeException extends RuntimeException { + public CMSMergeException(String message) { + super(message); + } + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java new file mode 100644 index 0000000000000..3d1f28bcb911e --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure +// spark-sketch has no external dependencies. +final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + return hashInt(input, seed); + } + + public static int hashInt(int input, int seed) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = Platform.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); + int h1 = seed; + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = Platform.getInt(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return h1; + } + + public int hashLong(long input) { + return hashLong(input, seed); + } + + public static int hashLong(long input, int seed) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java new file mode 100644 index 0000000000000..75d6a6beec408 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no +// external dependencies. +final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala new file mode 100644 index 0000000000000..ec5b4eddeca0d --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite + private val epsOfTotalCount = 0.0001 + + private val confidence = 0.99 + + private val seed = 42 + + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"accuracy - $typeName") { + val r = new Random() + + val numAllItems = 1000000 + val allItems = Array.fill(numAllItems)(itemGenerator(r)) + + val numSamples = numAllItems / 10 + val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) + + val exactFreq = { + val sampledItems = sampledItemIndices.map(allItems) + sampledItems.groupBy(identity).mapValues(_.length.toLong) + } + + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + + val probCorrect = { + val numErrors = allItems.map { item => + val count = exactFreq.getOrElse(item, 0L) + val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems + if (ratio > epsOfTotalCount) 1 else 0 + }.sum + + 1D - numErrors.toDouble / numAllItems + } + + assert( + probCorrect > confidence, + s"Confidence not reached: required $confidence, reached $probCorrect" + ) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + val r = new Random() + val numToMerge = 5 + val numItemsPerSketch = 100000 + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { + itemGenerator(r) + } + + val sketches = perSketchItems.map { items => + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + items.foreach(sketch.add) + sketch + } + + val mergedSketch = sketches.reduce(_ mergeInPlace _) + + val expectedSketch = { + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + perSketchItems.foreach(_.foreach(sketch.add)) + sketch + } + + perSketchItems.foreach { + _.foreach { item => + assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item)) + } + } + } + } + + def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + testAccuracy[T](typeName)(itemGenerator) + testMergeInPlace[T](typeName)(itemGenerator) + } + + testItemType[Byte]("Byte") { _.nextInt().toByte } + + testItemType[Short]("Short") { _.nextInt().toShort } + + testItemType[Int]("Int") { _.nextInt() } + + testItemType[Long]("Long") { _.nextLong() } + + testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } +} diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index efe58ea2e0e78..032c0616edb1e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -113,6 +113,18 @@ def contains_file(self, filename): ) +sketch = Module( + name="sketch", + dependencies=[], + source_file_regexes=[ + "common/sketch/", + ], + sbt_test_goals=[ + "sketch/test" + ] +) + + graphx = Module( name="graphx", dependencies=[], diff --git a/pom.xml b/pom.xml index f08642f606788..fb7750602c425 100644 --- a/pom.xml +++ b/pom.xml @@ -86,6 +86,7 @@ + common/sketch tags core graphx diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3927b88fb0bf6..4224a65a822b8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,13 +34,24 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = - Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", - "streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver" + ).map(ProjectRef(buildLocation, _)) + + val streamingProjects@Seq( + streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt, + streamingTwitter, streamingZeromq + ) = Seq( + "streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka", + "streaming-mqtt", "streaming-twitter", "streaming-zeromq" + ).map(ProjectRef(buildLocation, _)) + + val allProjects@Seq( + core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + ) = Seq( + "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "test-tags", "sketch" + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl, streamingKinesisAsl, dockerIntegrationTests) = @@ -232,11 +243,15 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove streamingAkka from this list after 2.0.0 - allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach { - x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) - } + // TODO: remove streamingAkka and sketch from this list after 2.0.0 + allProjects.filterNot { x => + Seq( + spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, + unsafe, streamingAkka, testTags, sketch + ).contains(x) + }.foreach { x => + enable(MimaBuild.mimaSettings(sparkHome, x))(x) + } /* Unsafe settings */ enable(Unsafe.settings)(unsafe) From 358a33bbff549826b2336c317afc7274bdd30fdb Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Sat, 23 Jan 2016 20:19:58 +0900 Subject: [PATCH 1664/2089] [SPARK-12859][STREAMING][WEB UI] Names of input streams with receivers don't fit in Streaming page Added CSS style to force names of input streams with receivers to wrap Author: Alex Bozarth Closes #10873 from ajbozarth/spark12859. --- .../scala/org/apache/spark/streaming/ui/StreamingPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index b3692c3ea302b..c5d9f26cb241e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -466,7 +466,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
    -
    {receiverName}
    +
    {receiverName}
    Avg: {receivedRecords.formattedAvg} events/sec
    From 56f57f894eafeda48ce118eec16ecb88dbd1b9dc Mon Sep 17 00:00:00 2001 From: Mortada Mehyar Date: Sat, 23 Jan 2016 11:36:33 +0000 Subject: [PATCH 1665/2089] =?UTF-8?q?[SPARK-12760][DOCS]=20invalid=20lambd?= =?UTF-8?q?a=20expression=20in=20python=20example=20for=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …local vs cluster srowen thanks for the PR at https://github.com/apache/spark/pull/10866! sorry it took me a while. This is related to https://github.com/apache/spark/pull/10866, basically the assignment in the lambda expression in the python example is actually invalid ``` In [1]: data = [1, 2, 3, 4, 5] In [2]: counter = 0 In [3]: rdd = sc.parallelize(data) In [4]: rdd.foreach(lambda x: counter += x) File "", line 1 rdd.foreach(lambda x: counter += x) ^ SyntaxError: invalid syntax ``` Author: Mortada Mehyar Closes #10867 from mortada/doc_python_fix. --- docs/programming-guide.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index bad25e63e89e6..4d21d4320c50d 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -789,9 +789,12 @@ counter = 0 rdd = sc.parallelize(data) # Wrong: Don't do this!! -rdd.foreach(lambda x: counter += x) +def increment_counter(x): + global counter + counter += x +rdd.foreach(increment_counter) -print("Counter value: " + counter) +print("Counter value: ", counter) {% endhighlight %}
    From aca2a0165405b9eba27ac5e4739e36a618b96676 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 23 Jan 2016 11:45:12 +0000 Subject: [PATCH 1666/2089] [SPARK-12760][DOCS] inaccurate description for difference between local vs cluster mode in closure handling Clarify that modifying a driver local variable won't have the desired effect in cluster modes, and may or may not work as intended in local mode Author: Sean Owen Closes #10866 from srowen/SPARK-12760. --- docs/programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4d21d4320c50d..e45081464af1c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -755,7 +755,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which may behave differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
    @@ -803,11 +803,11 @@ print("Counter value: ", counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The behavior of the above code is undefined, and may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks, each of which is executed by an executor. Prior to execution, Spark computes the task's **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. -However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on separate worker nodes each have their own copy of the closure. +The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. +In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. From 5f56980127704d3c2877d0d0b5047791c00fdac9 Mon Sep 17 00:00:00 2001 From: jayadevanmurali Date: Sat, 23 Jan 2016 11:48:48 +0000 Subject: [PATCH 1667/2089] [SPARK-11137][STREAMING] Make StreamingContext.stop() exception-safe Make StreamingContext.stop() exception-safe Author: jayadevanmurali Closes #10807 from jayadevanmurali/branch-0.1-SPARK-11137. --- .../spark/streaming/StreamingContext.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ec57c05e3b5bb..32bea88ec6cc0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -693,12 +693,20 @@ class StreamingContext private[streaming] ( // interrupted. See SPARK-12001 for more details. Because the body of this case can be // executed twice in the case of a partial stop, all methods called here need to be // idempotent. - scheduler.stop(stopGracefully) + Utils.tryLogNonFatalError { + scheduler.stop(stopGracefully) + } // Removing the streamingSource to de-register the metrics on stop() - env.metricsSystem.removeSource(streamingSource) - uiTab.foreach(_.detach()) + Utils.tryLogNonFatalError { + env.metricsSystem.removeSource(streamingSource) + } + Utils.tryLogNonFatalError { + uiTab.foreach(_.detach()) + } StreamingContext.setActiveContext(null) - waiter.notifyStop() + Utils.tryLogNonFatalError { + waiter.notifyStop() + } if (shutdownHookRef != null) { shutdownHookRefToRemove = shutdownHookRef shutdownHookRef = null From 423783a08bb8730852973aca19603e444d15040d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 23 Jan 2016 12:13:05 -0800 Subject: [PATCH 1668/2089] [SPARK-12904][SQL] Strength reduction for integral and decimal literal comparisons This pull request implements strength reduction for comparing integral expressions and decimal literals, which is more common now because we switch to parsing fractional literals as decimal types (rather than doubles). I added the rules to the existing DecimalPrecision rule with some refactoring to simplify the control flow. I also moved DecimalPrecision rule into its own file due to the growing size. Author: Reynold Xin Closes #10882 from rxin/SPARK-12904-1. --- .../catalyst/analysis/DecimalPrecision.scala | 259 ++++++++++++++++++ .../catalyst/analysis/HiveTypeCoercion.scala | 138 +--------- .../sql/catalyst/expressions/literals.scala | 18 ++ .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 97 ++++++- .../org/apache/spark/sql/SQLContext.scala | 1 + 6 files changed, 376 insertions(+), 139 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala new file mode 100644 index 0000000000000..ad56c9864979b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + + +// scalastyle:off +/** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) + * sum(e1) p1 + 10 s1 + * avg(e1) p1 + 4 s1 + 4 + * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. + * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + */ +// scalastyle:on +object DecimalPrecision extends Rule[LogicalPlan] { + import scala.math.{max, min} + + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + // Returns the wider decimal type that's wider than both of them + def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + DecimalType.bounded(range + scale, scale) + } + + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // fix decimal precision for expressions + case q => q.transformExpressions( + decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) + } + + /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ + private val decimalAndDecimal: PartialFunction[Expression, Expression] = { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + // Skip nodes who is already promoted + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + val resultType = DecimalType.bounded(intDig + decDig, decDig) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = widerDecimalType(p1, s1, p2, s2) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + + // TODO: MaxOf, MinOf, etc might want other rules + // SUM and AVERAGE are handled by the implementations of those expressions + } + + /** + * Strength reduction for comparing integral expressions with decimal literals. + * 1. int_col > decimal_literal => int_col > floor(decimal_literal) + * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal) + * 3. int_col < decimal_literal => int_col < ceil(decimal_literal) + * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal) + * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col + * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col + * 7. decimal_literal < int_col => floor(decimal_literal) < int_col + * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col + * + * Note that technically this is an "optimization" and should go into the optimizer. However, + * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern + * match because there are multuple (at least 2) levels of casts involved. + * + * There are a lot more possible rules we can implement, but we don't do them + * because we are not sure how common they are. + */ + private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = { + + case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThan(i, Literal(value.floor.toLong)) + } + + case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThanOrEqual(i, Literal(value.ceil.toLong)) + } + + case LessThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThan(i, Literal(value.ceil.toLong)) + } + + case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThanOrEqual(i, Literal(value.floor.toLong)) + } + + case GreaterThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThan(Literal(value.ceil.toLong), i) + } + + case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThanOrEqual(Literal(value.floor.toLong), i) + } + + case LessThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThan(Literal(value.floor.toLong), i) + } + + case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThanOrEqual(Literal(value.ceil.toLong), i) + } + } + + /** + * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other + * side is a decimal. + */ + private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + (left.dataType, right.dataType) match { + case (t: IntegralType, DecimalType.Fixed(p, s)) => + b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) + case (DecimalType.Fixed(p, s), t: IntegralType) => + b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(left, Cast(right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(left, DoubleType), right)) + case _ => + b + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6e43bdfd92d0e..957ac89fa530d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -81,7 +82,7 @@ object HiveTypeCoercion { * Find the tightest common type of two types that might be used in a binary expression. * This handles all numeric types except fixed-precision decimals interacting with each other or * with primitive types, because in that case the precision and scale of the result depends on - * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. + * the operation. Those rules are implemented in [[DecimalPrecision]]. */ val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) @@ -381,141 +382,6 @@ object HiveTypeCoercion { } } - // scalastyle:off - /** - * Calculates and propagates precision for fixed-precision decimals. Hive has a number of - * rules for this based on the SQL standard and MS SQL: - * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf - * https://msdn.microsoft.com/en-us/library/ms190476.aspx - * - * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 - * respectively, then the following operations have the following precision / scale: - * - * Operation Result Precision Result Scale - * ------------------------------------------------------------------------ - * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) - * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) - * e1 * e2 p1 + p2 + 1 s1 + s2 - * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) - * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) - * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 - * - * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision. - * - * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited - * precision, do the math on unlimited-precision numbers, then introduce casts back to the - * required fixed precision. This allows us to do all rounding and overflow handling in the - * cast-to-fixed-precision operator. - * - * In addition, when mixing non-decimal types with decimals, we use the following rules: - * - BYTE gets turned into DECIMAL(3, 0) - * - SHORT gets turned into DECIMAL(5, 0) - * - INT gets turned into DECIMAL(10, 0) - * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE - */ - // scalastyle:on - object DecimalPrecision extends Rule[LogicalPlan] { - import scala.math.{max, min} - - private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - - // Returns the wider decimal type that's wider than both of them - def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { - widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) - } - // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) - def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { - val scale = max(s1, s2) - val range = max(p1 - s1, p2 - s2) - DecimalType.bounded(range + scale, scale) - } - - private def promotePrecision(e: Expression, dataType: DataType): Expression = { - PromotePrecision(Cast(e, dataType)) - } - - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - - // fix decimal precision for expressions - case q => q.transformExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e - - // Skip nodes who is already promoted - case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e - - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig - } - val resultType = DecimalType.bounded(intDig + decDig, decDig) - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - // resultType may have lower precision, so we cast them into wider type first. - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - // resultType may have lower precision, so we cast them into wider type first. - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - val resultType = widerDecimalType(p1, s1, p2, s2) - b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b - } - - // TODO: MaxOf, MinOf, etc might want other rules - - // SUM and AVERAGE are handled by the implementations of those expressions - } - } - } - /** * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index db30845fdab6c..ca0892eb42158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -139,6 +139,24 @@ object IntegerLiteral { } } +/** + * Extractor for and other utility methods for decimal literals. + */ +object DecimalLiteral { + def apply(v: Long): Literal = Literal(Decimal(v)) + + def apply(v: Double): Literal = Literal(Decimal(v)) + + def unapply(e: Expression): Option[Decimal] = e match { + case Literal(v, _: DecimalType) => Some(v.asInstanceOf[Decimal]) + case _ => None + } + + def largerThanLargestLong(v: Decimal): Boolean = v > Decimal(Long.MaxValue) + + def smallerThanSmallestLong(v: Decimal): Boolean = v < Decimal(Long.MinValue) +} + /** * In order to do type checking, use Literal.create() instead of constructor */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 44455b482074b..6addc2080648b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1006,7 +1006,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. * * This uses the same rules for increasing the precision and scale of the output as - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]]. + * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 24c608eaa5b39..b2613e4909288 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union} import org.apache.spark.sql.types._ -class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { + +class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { val conf = new SimpleCatalystConf(true) val catalog = new SimpleCatalog(conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) @@ -181,4 +184,94 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { assert(d4.isWiderThan(FloatType) === false) assert(d4.isWiderThan(DoubleType) === false) } + + test("strength reduction for integer/decimal comparisons - basic test") { + Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt => + val int = AttributeReference("a", dt)() + + ruleTest(int > Literal(Decimal(4)), int > Literal(4L)) + ruleTest(int > Literal(Decimal(4.7)), int > Literal(4L)) + + ruleTest(int >= Literal(Decimal(4)), int >= Literal(4L)) + ruleTest(int >= Literal(Decimal(4.7)), int >= Literal(5L)) + + ruleTest(int < Literal(Decimal(4)), int < Literal(4L)) + ruleTest(int < Literal(Decimal(4.7)), int < Literal(5L)) + + ruleTest(int <= Literal(Decimal(4)), int <= Literal(4L)) + ruleTest(int <= Literal(Decimal(4.7)), int <= Literal(4L)) + + ruleTest(Literal(Decimal(4)) > int, Literal(4L) > int) + ruleTest(Literal(Decimal(4.7)) > int, Literal(5L) > int) + + ruleTest(Literal(Decimal(4)) >= int, Literal(4L) >= int) + ruleTest(Literal(Decimal(4.7)) >= int, Literal(4L) >= int) + + ruleTest(Literal(Decimal(4)) < int, Literal(4L) < int) + ruleTest(Literal(Decimal(4.7)) < int, Literal(4L) < int) + + ruleTest(Literal(Decimal(4)) <= int, Literal(4L) <= int) + ruleTest(Literal(Decimal(4.7)) <= int, Literal(5L) <= int) + + } + } + + test("strength reduction for integer/decimal comparisons - overflow test") { + val maxValue = Literal(Decimal(Long.MaxValue)) + val overflow = Literal(Decimal(Long.MaxValue) + Decimal(0.1)) + val minValue = Literal(Decimal(Long.MinValue)) + val underflow = Literal(Decimal(Long.MinValue) - Decimal(0.1)) + + Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt => + val int = AttributeReference("a", dt)() + + ruleTest(int > maxValue, int > Literal(Long.MaxValue)) + ruleTest(int > overflow, FalseLiteral) + ruleTest(int > minValue, int > Literal(Long.MinValue)) + ruleTest(int > underflow, TrueLiteral) + + ruleTest(int >= maxValue, int >= Literal(Long.MaxValue)) + ruleTest(int >= overflow, FalseLiteral) + ruleTest(int >= minValue, int >= Literal(Long.MinValue)) + ruleTest(int >= underflow, TrueLiteral) + + ruleTest(int < maxValue, int < Literal(Long.MaxValue)) + ruleTest(int < overflow, TrueLiteral) + ruleTest(int < minValue, int < Literal(Long.MinValue)) + ruleTest(int < underflow, FalseLiteral) + + ruleTest(int <= maxValue, int <= Literal(Long.MaxValue)) + ruleTest(int <= overflow, TrueLiteral) + ruleTest(int <= minValue, int <= Literal(Long.MinValue)) + ruleTest(int <= underflow, FalseLiteral) + + ruleTest(maxValue > int, Literal(Long.MaxValue) > int) + ruleTest(overflow > int, TrueLiteral) + ruleTest(minValue > int, Literal(Long.MinValue) > int) + ruleTest(underflow > int, FalseLiteral) + + ruleTest(maxValue >= int, Literal(Long.MaxValue) >= int) + ruleTest(overflow >= int, TrueLiteral) + ruleTest(minValue >= int, Literal(Long.MinValue) >= int) + ruleTest(underflow >= int, FalseLiteral) + + ruleTest(maxValue < int, Literal(Long.MaxValue) < int) + ruleTest(overflow < int, FalseLiteral) + ruleTest(minValue < int, Literal(Long.MinValue) < int) + ruleTest(underflow < int, TrueLiteral) + + ruleTest(maxValue <= int, Literal(Long.MaxValue) <= int) + ruleTest(overflow <= int, FalseLiteral) + ruleTest(minValue <= int, Literal(Long.MinValue) <= int) + ruleTest(underflow <= int, TrueLiteral) + } + } + + /** strength reduction for integer/decimal comparisons */ + def ruleTest(initial: Expression, transformed: Expression): Unit = { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + DecimalPrecision(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 147e3557b632b..b774da33aebe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -74,6 +74,7 @@ class SQLContext private[sql]( def this(sparkContext: SparkContext) = { this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true) } + def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user From cfdcef70ddd25484f1cb1791e529210d602c2283 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sat, 23 Jan 2016 12:14:16 -0800 Subject: [PATCH 1669/2089] [STREAMING][MINOR] Scaladoc + logs Found while doing code review Author: Jacek Laskowski Closes #10878 from jaceklaskowski/streaming-scaladoc-logs-tiny-fixes. --- .../main/scala/org/apache/spark/streaming/StateSpec.scala | 5 ++--- .../apache/spark/streaming/dstream/MapWithStateDStream.scala | 2 +- .../spark/streaming/scheduler/ReceivedBlockTracker.scala | 2 +- .../spark/streaming/ui/StreamingJobProgressListener.scala | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index f1114c1e5ac6a..66f646d7dc136 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -30,9 +30,8 @@ import org.apache.spark.util.ClosureCleaner * `mapWithState` operation of a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). - * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or - * [[org.apache.spark.streaming.StateSpec StateSpec.create()]] to create instances of - * this class. + * Use [[org.apache.spark.streaming.StateSpec.function() StateSpec.function]] factory methods + * to create instances of this class. * * Example in Scala: * {{{ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index 36ff9c7e6182f..ed08191f41cc8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -90,7 +90,7 @@ private[streaming] class MapWithStateDStreamImpl[ } /** - * A DStream that allows per-key state to be maintains, and arbitrary records to be generated + * A DStream that allows per-key state to be maintained, and arbitrary records to be generated * based on updates to the state. This is the main DStream that implements the `mapWithState` * operation on DStreams. * diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 60b5c838e9734..5f1c671c3c568 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -166,7 +166,7 @@ private[streaming] class ReceivedBlockTracker( def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq - logInfo("Deleting batches " + timesToCleanup) + logInfo(s"Deleting batches: ${timesToCleanup.mkString(" ")}") if (writeToLog(BatchCleanupEvent(timesToCleanup))) { timeToAllocatedBlocks --= timesToCleanup writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 4908be0536353..cacd430cf339c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -91,7 +91,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { val batchUIData = BatchUIData(batchStarted.batchInfo) - runningBatchUIData(batchStarted.batchInfo.batchTime) = BatchUIData(batchStarted.batchInfo) + runningBatchUIData(batchStarted.batchInfo.batchTime) = batchUIData waitingBatchUIData.remove(batchStarted.batchInfo.batchTime) totalReceivedRecords += batchUIData.numRecords From f4004601b00044e1d5c57e72c45b815f36feac3e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 24 Jan 2016 11:29:27 -0800 Subject: [PATCH 1670/2089] [SPARK-12971] Fix Hive tests which fail in Hadoop-2.3 SBT build ErrorPositionSuite and one of the HiveComparisonTest tests have been consistently failing on the Hadoop 2.3 SBT build (but on no other builds). I believe that this is due to test isolation issues (e.g. tests sharing state via the sets of temporary tables that are registered to TestHive). This patch attempts to improve the isolation of these tests in order to address this issue. Author: Josh Rosen Closes #10884 from JoshRosen/fix-failing-hadoop-2.3-hive-tests. --- .../spark/sql/hive/ErrorPositionSuite.scala | 20 ++++++++++++++++--- .../hive/execution/HiveComparisonTest.scala | 6 +++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 14a466cfe9486..4b6da7cd33692 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -19,20 +19,34 @@ package org.apache.spark.sql.hive import scala.util.Try -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton -class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { import hiveContext.implicits._ - before { + override protected def beforeEach(): Unit = { + super.beforeEach() + if (sqlContext.tableNames().contains("src")) { + sqlContext.dropTempTable("src") + } + Seq((1, "")).toDF("key", "value").registerTempTable("src") Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") } + override protected def afterEach(): Unit = { + try { + sqlContext.dropTempTable("src") + sqlContext.dropTempTable("dupAttributes") + } finally { + super.afterEach() + } + } + positionTest("ambiguous attribute reference 1", "SELECT a from dupAttributes", "a") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 2e0a8698e6ffe..207bb814f0a27 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -150,7 +150,11 @@ abstract class HiveComparisonTest """.stripMargin }) - super.afterAll() + try { + TestHive.reset() + } finally { + super.afterAll() + } } protected def prepareAnswer( From a83400135d7506fd2e1c8971f372012fcb3cef23 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 24 Jan 2016 11:48:28 -0800 Subject: [PATCH 1671/2089] [SPARK-10498][TOOLS][BUILD] Add requirements.txt file for dev python tools Minor since so few people use them, but it would probably be good to have a requirements file for our python release tools for easier setup (also version pinning). cc JoshRosen who looked at the original JIRA. Author: Holden Karau Closes #10871 from holdenk/SPARK-10498-add-requirements-file-for-dev-python-tools. --- dev/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 dev/requirements.txt diff --git a/dev/requirements.txt b/dev/requirements.txt new file mode 100644 index 0000000000000..bf042d22a8b47 --- /dev/null +++ b/dev/requirements.txt @@ -0,0 +1,3 @@ +jira==1.0.3 +PyGithub==1.26.0 +Unidecode==0.04.19 From e789b1d2c1eab6187f54424ed92697ca200c3101 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Sun, 24 Jan 2016 12:29:26 -0800 Subject: [PATCH 1672/2089] =?UTF-8?q?[SPARK-12120][PYSPARK]=20Improve=20ex?= =?UTF-8?q?ception=20message=20when=20failing=20to=20init=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ialize HiveContext in PySpark davies Mind to review ? This is the error message after this PR ``` 15/12/03 16:59:53 WARN ObjectStore: Failed to get database default, returning NoSuchObjectException /Users/jzhang/github/spark/python/pyspark/sql/context.py:689: UserWarning: You must build Spark with Hive. Export 'SPARK_HIVE=true' and run build/sbt assembly warnings.warn("You must build Spark with Hive. " Traceback (most recent call last): File "", line 1, in File "/Users/jzhang/github/spark/python/pyspark/sql/context.py", line 663, in read return DataFrameReader(self) File "/Users/jzhang/github/spark/python/pyspark/sql/readwriter.py", line 56, in __init__ self._jreader = sqlContext._ssql_ctx.read() File "/Users/jzhang/github/spark/python/pyspark/sql/context.py", line 692, in _ssql_ctx raise e py4j.protocol.Py4JJavaError: An error occurred while calling None.org.apache.spark.sql.hive.HiveContext. : java.lang.RuntimeException: java.net.ConnectException: Call From jzhangMBPr.local/127.0.0.1 to 0.0.0.0:9000 failed on connection exception: java.net.ConnectException: Connection refused; For more details see: http://wiki.apache.org/hadoop/ConnectionRefused at org.apache.hadoop.hive.ql.session.SessionState.start(SessionState.java:522) at org.apache.spark.sql.hive.client.ClientWrapper.(ClientWrapper.scala:194) at org.apache.spark.sql.hive.client.IsolatedClientLoader.createClient(IsolatedClientLoader.scala:238) at org.apache.spark.sql.hive.HiveContext.executionHive$lzycompute(HiveContext.scala:218) at org.apache.spark.sql.hive.HiveContext.executionHive(HiveContext.scala:208) at org.apache.spark.sql.hive.HiveContext.functionRegistry$lzycompute(HiveContext.scala:462) at org.apache.spark.sql.hive.HiveContext.functionRegistry(HiveContext.scala:461) at org.apache.spark.sql.UDFRegistration.(UDFRegistration.scala:40) at org.apache.spark.sql.SQLContext.(SQLContext.scala:330) at org.apache.spark.sql.hive.HiveContext.(HiveContext.scala:90) at org.apache.spark.sql.hive.HiveContext.(HiveContext.scala:101) at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method) at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:57) at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45) at java.lang.reflect.Constructor.newInstance(Constructor.java:526) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:234) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:381) at py4j.Gateway.invoke(Gateway.java:214) at py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:79) at py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:68) at py4j.GatewayConnection.run(GatewayConnection.java:209) at java.lang.Thread.run(Thread.java:745) ``` Author: Jeff Zhang Closes #10126 from zjffdu/SPARK-12120. --- python/pyspark/sql/context.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 91e27cf16e439..7f6fb410abe81 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -15,6 +15,7 @@ # limitations under the License. # +from __future__ import print_function import sys import warnings import json @@ -571,9 +572,10 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) + print("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", file=sys.stderr) + raise def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) From 3327fd28170b549516fee1972dc6f4c32541591b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 24 Jan 2016 19:40:34 -0800 Subject: [PATCH 1673/2089] [SPARK-12624][PYSPARK] Checks row length when converting Java arrays to Python rows When actual row length doesn't conform to specified schema field length, we should give a better error message instead of throwing an unintuitive `ArrayOutOfBoundsException`. Author: Cheng Lian Closes #10886 from liancheng/spark-12624. --- python/pyspark/sql/tests.py | 9 +++++++++ .../scala/org/apache/spark/sql/execution/python.scala | 9 ++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ae8620274dd20..7593b991a780b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -364,6 +364,15 @@ def test_infer_schema_to_local(self): df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_create_dataframe_schema_mismatch(self): + input = [Row(a=1)] + rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) + schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) + df = self.sqlCtx.createDataFrame(rdd, schema) + message = ".*Input row doesn't have expected number of values required by the schema.*" + with self.assertRaisesRegexp(Exception, message): + df.show() + def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index 41e35fd724cde..e3a016e18db87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -220,7 +220,14 @@ object EvaluatePython { ArrayBasedMapData(keys, values) case (c, StructType(fields)) if c.getClass.isArray => - new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + new GenericInternalRow(array.zip(fields).map { case (e, f) => fromJava(e, f.dataType) }) From 3adebfc9a37fdee5b7a4e891c4ee597b85f824c3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jan 2016 00:57:56 -0800 Subject: [PATCH 1674/2089] [SPARK-12901][SQL] Refactor options for JSON and CSV datasource (not case class and same format). https://issues.apache.org/jira/browse/SPARK-12901 This PR refactors the options in JSON and CSV datasources. In more details, 1. `JSONOptions` uses the same format as `CSVOptions`. 2. Not case classes. 3. `CSVRelation` that does not have to be serializable (it was `with Serializable` but I removed) Author: hyukjinkwon Closes #10895 from HyukjinKwon/SPARK-12901. --- .../{CSVParameters.scala => CSVOptions.scala} | 7 +-- .../execution/datasources/csv/CSVParser.scala | 8 +-- .../datasources/csv/CSVRelation.scala | 12 ++-- .../datasources/json/JSONOptions.scala | 59 ++++++++----------- .../datasources/json/JSONRelation.scala | 2 +- .../datasources/json/JsonSuite.scala | 4 +- 6 files changed, 40 insertions(+), 52 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/{CSVParameters.scala => CSVOptions.scala} (95%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 0278675aa61b0..5d0e99d7601dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.Charset -import org.apache.hadoop.io.compress._ - import org.apache.spark.Logging import org.apache.spark.sql.execution.datasources.CompressionCodecs -import org.apache.spark.util.Utils -private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging { +private[sql] class CSVOptions( + @transient parameters: Map[String, String]) + extends Logging with Serializable { private def getChar(paramName: String, default: Char): Char = { val paramValue = parameters.get(paramName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index ba1cc42f3e446..8f1421844c648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -29,7 +29,7 @@ import org.apache.spark.Logging * @param params Parameters object * @param headers headers for the columns */ -private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String]) { +private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { protected lazy val parser: CsvParser = { val settings = new CsvParserSettings() @@ -58,7 +58,7 @@ private[sql] abstract class CsvReader(params: CSVParameters, headers: Seq[String * @param params Parameters object for configuration * @param headers headers for columns */ -private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) extends Logging { +private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat @@ -93,7 +93,7 @@ private[sql] class LineCsvWriter(params: CSVParameters, headers: Seq[String]) ex * * @param params Parameters object */ -private[sql] class LineCsvReader(params: CSVParameters) +private[sql] class LineCsvReader(params: CSVOptions) extends CsvReader(params, null) { /** * parse a line @@ -118,7 +118,7 @@ private[sql] class LineCsvReader(params: CSVParameters) */ private[sql] class BulkCsvReader( iter: Iterator[String], - params: CSVParameters, + params: CSVOptions, headers: Seq[String]) extends CsvReader(params, headers) with Iterator[Array[String]] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 1502501c3b89e..5959f7cc5051b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -43,14 +43,14 @@ private[csv] class CSVRelation( private val maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], private val parameters: Map[String, String]) - (@transient val sqlContext: SQLContext) extends HadoopFsRelation with Serializable { + (@transient val sqlContext: SQLContext) extends HadoopFsRelation { override lazy val dataSchema: StructType = maybeDataSchema match { case Some(structType) => structType case None => inferSchema(paths) } - private val params = new CSVParameters(parameters) + private val params = new CSVOptions(parameters) @transient private var cachedRDD: Option[RDD[String]] = None @@ -170,7 +170,7 @@ object CSVRelation extends Logging { file: RDD[String], header: Seq[String], firstLine: String, - params: CSVParameters): RDD[Array[String]] = { + params: CSVOptions): RDD[Array[String]] = { // If header is set, make sure firstLine is materialized before sending to executors. file.mapPartitionsWithIndex({ case (split, iter) => new BulkCsvReader( @@ -186,7 +186,7 @@ object CSVRelation extends Logging { requiredColumns: Array[String], inputs: Array[FileStatus], sqlContext: SQLContext, - params: CSVParameters): RDD[Row] = { + params: CSVOptions): RDD[Row] = { val schemaFields = schema.fields val requiredFields = StructType(requiredColumns.map(schema(_))).fields @@ -249,7 +249,7 @@ object CSVRelation extends Logging { } } -private[sql] class CSVOutputWriterFactory(params: CSVParameters) extends OutputWriterFactory { +private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( path: String, dataSchema: StructType, @@ -262,7 +262,7 @@ private[sql] class CsvOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext, - params: CSVParameters) extends OutputWriter with Logging { + params: CSVOptions) extends OutputWriter with Logging { // create the Generator without separator inserted between 2 records private[this] val text = new Text() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index e74a76c532367..0a083b5e3598e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -26,16 +26,30 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs * * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ -case class JSONOptions( - samplingRatio: Double = 1.0, - primitivesAsString: Boolean = false, - allowComments: Boolean = false, - allowUnquotedFieldNames: Boolean = false, - allowSingleQuotes: Boolean = true, - allowNumericLeadingZeros: Boolean = false, - allowNonNumericNumbers: Boolean = false, - allowBackslashEscapingAnyCharacter: Boolean = false, - compressionCodec: Option[String] = None) { +private[sql] class JSONOptions( + @transient parameters: Map[String, String]) + extends Serializable { + + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + val primitivesAsString = + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + val allowComments = + parameters.get("allowComments").map(_.toBoolean).getOrElse(false) + val allowUnquotedFieldNames = + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + val allowSingleQuotes = + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + val allowNumericLeadingZeros = + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + val allowNonNumericNumbers = + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + val allowBackslashEscapingAnyCharacter = + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + val compressionCodec = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CompressionCodecs.getCodecClassName) + } /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { @@ -48,28 +62,3 @@ case class JSONOptions( allowBackslashEscapingAnyCharacter) } } - -object JSONOptions { - def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions( - samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0), - primitivesAsString = - parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false), - allowComments = - parameters.get("allowComments").map(_.toBoolean).getOrElse(false), - allowUnquotedFieldNames = - parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false), - allowSingleQuotes = - parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true), - allowNumericLeadingZeros = - parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), - allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true), - allowBackslashEscapingAnyCharacter = - parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false), - compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } - ) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 93727abcc7de9..c893558136549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -75,7 +75,7 @@ private[sql] class JSONRelation( (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) + val options: JSONOptions = new JSONOptions(parameters) /** Constraints to be imposed on schema to be stored. */ private def checkConstraints(schema: StructType): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index d22fa7905aec1..00eaeb0d34e87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1240,7 +1240,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema.infer(empty, "", JSONOptions()) + val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map())) assert(StructType(Seq()) === emptySchema) } @@ -1264,7 +1264,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions()) + val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map())) assert(StructType(Seq()) === emptySchema) } From d8e480521e362bc6bc5d8ebcea9b2d50f72a71b9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 25 Jan 2016 09:22:10 +0000 Subject: [PATCH 1675/2089] [SPARK-12932][JAVA API] improved error message for java type inference failure Author: Andy Grove Closes #10865 from andygrove/SPARK-12932. --- .../org/apache/spark/sql/catalyst/JavaTypeInference.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index b5de60cdb7b76..3c3717d5043aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -406,7 +406,8 @@ object JavaTypeInference { expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil }) } else { - throw new UnsupportedOperationException(s"no encoder found for ${other.getName}") + throw new UnsupportedOperationException( + s"Cannot infer type for class ${other.getName} because it is not bean-compliant") } } } From 4ee8191e57cb823a23ceca17908af86e70354554 Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Mon, 25 Jan 2016 09:51:41 +0000 Subject: [PATCH 1676/2089] [SPARK-12755][CORE] Stop the event logger before the DAG scheduler [SPARK-12755][CORE] Stop the event logger before the DAG scheduler to avoid a race condition where the standalone master attempts to build the app's history UI before the event log is stopped. This contribution is my original work, and I license this work to the Spark project under the project's open source license. Author: Michael Allman Closes #10700 from mallman/stop_event_logger_first. --- .../main/scala/org/apache/spark/SparkContext.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d7c605a583c4f..fa8c0f524bf3a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1675,12 +1675,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Utils.tryLogNonFatalError { _executorAllocationManager.foreach(_.stop()) } - if (_dagScheduler != null) { - Utils.tryLogNonFatalError { - _dagScheduler.stop() - } - _dagScheduler = null - } if (_listenerBusStarted) { Utils.tryLogNonFatalError { listenerBus.stop() @@ -1690,6 +1684,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) From dd2325d9a7de7bef9a6bc2f0d5f26e605545b52d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 25 Jan 2016 11:52:26 -0800 Subject: [PATCH 1677/2089] [SPARK-11965][ML][DOC] Update user guide for RFormula feature interactions Update user guide for RFormula feature interactions. Meanwhile we also update other new features such as supporting string label in Spark 1.6. Author: Yanbo Liang Closes #10222 from yanboliang/spark-11965. --- docs/ml-features.md | 20 +++++++++++++++++- .../apache/spark/ml/feature/RFormula.scala | 21 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 677e4bfb916e8..5809f65d637e4 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1121,7 +1121,25 @@ for more details on the API. ## RFormula -`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). +Currently we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. +The basic operators are: + +* `~` separate target and terms +* `+` concat terms, "+ 0" means removing intercept +* `-` remove a term, "- 1" means removing intercept +* `:` interaction (multiplication for numeric values, or binarized categorical values) +* `.` all columns except target + +Suppose `a` and `b` are double columns, we use the following simple examples to illustrate the effect of `RFormula`: + +* `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2` are coefficients. +* `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` are coefficients. + +`RFormula` produces a vector column of features and a double or string column of label. +Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. +If the label column is of type string, it will be first transformed to double with `StringIndexer`. +If the label column does not exist in the DataFrame, the output label column will be created from the specified response variable in the formula. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 6cc9d025445c0..c21da218b36d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -45,6 +45,27 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { * Implements the transforms required for fitting a dataset against an R model formula. Currently * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * + * The basic operators are: + * - `~` separate target and terms + * - `+` concat terms, "+ 0" means removing intercept + * - `-` remove a term, "- 1" means removing intercept + * - `:` interaction (multiplication for numeric values, or binarized categorical values) + * - `.` all columns except target + * + * Suppose `a` and `b` are double columns, we use the following simple examples + * to illustrate the effect of `RFormula`: + * - `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2` + * are coefficients. + * - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` + * are coefficients. + * + * RFormula produces a vector column of features and a double or string column of label. + * Like when formulas are used in R for linear regression, string input columns will be one-hot + * encoded, and numeric columns will be cast to doubles. + * If the label column is of type string, it will be first transformed to double with + * `StringIndexer`. If the label column does not exist in the DataFrame, the output label column + * will be created from the specified response variable in the formula. */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { From ef8fb3612c7be1ac9058750be39ee28d88a148b4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 25 Jan 2016 12:36:53 -0800 Subject: [PATCH 1678/2089] Closes #10879 Closes #9046 Closes #8532 Closes #10756 Closes #8960 Closes #10485 Closes #10467 From c037d25482ea63430fb42bfd86124c268be5a4a4 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Mon, 25 Jan 2016 14:42:44 -0600 Subject: [PATCH 1679/2089] [SPARK-12149][WEB UI] Executor UI improvement suggestions - Color UI Added color coding to the Executors page for Active Tasks, Failed Tasks, Completed Tasks and Task Time. Active Tasks is shaded blue with it's range based on percentage of total cores used. Failed Tasks is shaded red ranging over the first 10% of total tasks failed Completed Tasks is shaded green ranging over 10% of total tasks including failed and active tasks, but only when there are active or failed tasks on that executor. Task Time is shaded red when GC Time goes over 10% of total time with it's range directly corresponding to the percent of total time. Author: Alex Bozarth Closes #10154 from ajbozarth/spark12149. --- .../org/apache/spark/status/api/v1/api.scala | 2 + .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../scala/org/apache/spark/ui/ToolTips.scala | 3 + .../apache/spark/ui/exec/ExecutorsPage.scala | 98 +++++++++++++++---- .../apache/spark/ui/exec/ExecutorsTab.scala | 10 +- .../executor_list_json_expectation.json | 2 + project/MimaExcludes.scala | 6 ++ 7 files changed, 103 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index fe372116f1b69..3adf5b1109af4 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -55,11 +55,13 @@ class ExecutorSummary private[spark]( val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, + val maxTasks: Int, val activeTasks: Int, val failedTasks: Int, val completedTasks: Int, val totalTasks: Int, val totalDuration: Long, + val totalGCTime: Long, val totalInputBytes: Long, val totalShuffleRead: Long, val totalShuffleWrite: Long, diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index eb53aa8e23ae7..cf45414c4f786 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -195,7 +195,7 @@ private[spark] object SparkUI { val environmentListener = new EnvironmentListener val storageStatusListener = new StorageStatusListener - val executorsListener = new ExecutorsListener(storageStatusListener) + val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index cb122eaed83d1..2d2d80be4aabe 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -87,4 +87,7 @@ private[spark] object ToolTips { multiple operations (e.g. two map() functions) if they can be pipelined. Some operations also create multiple RDDs internally. Cached RDDs are shown in green. """ + + val TASK_TIME = + "Shaded red when garbage collection (GC) time is over 10% of task time" } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 440dfa2679563..e36b96b3e6978 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -50,6 +50,8 @@ private[ui] class ExecutorsPage( threadDumpEnabled: Boolean) extends WebUIPage("") { private val listener = parent.listener + // When GCTimePercent is edited change ToolTips.TASK_TIME to match + private val GCTimePercent = 0.1 def render(request: HttpServletRequest): Seq[Node] = { val (storageStatusList, execInfo) = listener.synchronized { @@ -77,7 +79,7 @@ private[ui] class ExecutorsPage( Failed Tasks Complete Tasks Total Tasks - Task Time + Task Time (GC Time) Input Shuffle Read @@ -129,13 +131,8 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} - {info.activeTasks} - {info.failedTasks} - {info.completedTasks} - {info.totalTasks} - - {Utils.msDurationToString(info.totalDuration)} - + {taskData(info.maxTasks, info.activeTasks, info.failedTasks, info.completedTasks, + info.totalTasks, info.totalDuration, info.totalGCTime)} {Utils.bytesToString(info.totalInputBytes)} @@ -177,7 +174,6 @@ private[ui] class ExecutorsPage( val maximumMemory = execInfo.map(_.maxMemory).sum val memoryUsed = execInfo.map(_.memoryUsed).sum val diskUsed = execInfo.map(_.diskUsed).sum - val totalDuration = execInfo.map(_.totalDuration).sum val totalInputBytes = execInfo.map(_.totalInputBytes).sum val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum @@ -192,13 +188,13 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} - {execInfo.map(_.activeTasks).sum} - {execInfo.map(_.failedTasks).sum} - {execInfo.map(_.completedTasks).sum} - {execInfo.map(_.totalTasks).sum} - - {Utils.msDurationToString(totalDuration)} - + {taskData(execInfo.map(_.maxTasks).sum, + execInfo.map(_.activeTasks).sum, + execInfo.map(_.failedTasks).sum, + execInfo.map(_.completedTasks).sum, + execInfo.map(_.totalTasks).sum, + execInfo.map(_.totalDuration).sum, + execInfo.map(_.totalGCTime).sum)} {Utils.bytesToString(totalInputBytes)} @@ -219,7 +215,7 @@ private[ui] class ExecutorsPage( Failed Tasks Complete Tasks Total Tasks - Task Time + Task Time (GC Time) Input Shuffle Read @@ -233,6 +229,70 @@ private[ui] class ExecutorsPage( } + + private def taskData( + maxTasks: Int, + activeTasks: Int, + failedTasks: Int, + completedTasks: Int, + totalTasks: Int, + totalDuration: Long, + totalGCTime: Long): + Seq[Node] = { + // Determine Color Opacity from 0.5-1 + // activeTasks range from 0 to maxTasks + val activeTasksAlpha = + if (maxTasks > 0) { + (activeTasks.toDouble / maxTasks) * 0.5 + 0.5 + } else { + 1 + } + // failedTasks range max at 10% failure, alpha max = 1 + val failedTasksAlpha = + if (totalTasks > 0) { + math.min(10 * failedTasks.toDouble / totalTasks, 1) * 0.5 + 0.5 + } else { + 1 + } + // totalDuration range from 0 to 50% GC time, alpha max = 1 + val totalDurationAlpha = + if (totalDuration > 0) { + math.min(totalGCTime.toDouble / totalDuration + 0.5, 1) + } else { + 1 + } + + val tableData = + 0) { + "background:hsla(240, 100%, 50%, " + activeTasksAlpha + ");color:white" + } else { + "" + } + }>{activeTasks} + 0) { + "background:hsla(0, 100%, 50%, " + failedTasksAlpha + ");color:white" + } else { + "" + } + }>{failedTasks} + {completedTasks} + {totalTasks} + GCTimePercent * totalDuration) { + "background:hsla(0, 100%, 50%, " + totalDurationAlpha + ");color:white" + } else { + "" + } + }> + {Utils.msDurationToString(totalDuration)} + ({Utils.msDurationToString(totalGCTime)}) + ; + + tableData + } } private[spark] object ExecutorsPage { @@ -245,11 +305,13 @@ private[spark] object ExecutorsPage { val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed + val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0) val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0) val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks val totalDuration = listener.executorToDuration.getOrElse(execId, 0L) + val totalGCTime = listener.executorToJvmGCTime.getOrElse(execId, 0L) val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L) val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L) val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L) @@ -261,11 +323,13 @@ private[spark] object ExecutorsPage { rddBlocks, memUsed, diskUsed, + maxTasks, activeTasks, failedTasks, completedTasks, totalTasks, totalDuration, + totalGCTime, totalInputBytes, totalShuffleRead, totalShuffleWrite, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 160d7a4dff2dc..a9e926b158780 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{ExceptionFailure, Resubmitted, SparkContext} +import org.apache.spark.{ExceptionFailure, Resubmitted, SparkConf, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} @@ -43,11 +43,14 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi -class ExecutorsListener(storageStatusListener: StorageStatusListener) extends SparkListener { +class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) + extends SparkListener { + val executorToTasksMax = HashMap[String, Int]() val executorToTasksActive = HashMap[String, Int]() val executorToTasksComplete = HashMap[String, Int]() val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() + val executorToJvmGCTime = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() val executorToInputRecords = HashMap[String, Long]() val executorToOutputBytes = HashMap[String, Long]() @@ -62,6 +65,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap + executorToTasksMax(eid) = + executorAdded.executorInfo.totalCores / conf.getInt("spark.task.cpus", 1) executorIdToData(eid) = ExecutorUIData(executorAdded.time) } @@ -131,6 +136,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp executorToShuffleWrite(eid) = executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.bytesWritten } + executorToJvmGCTime(eid) = executorToJvmGCTime.getOrElse(eid, 0L) + metrics.jvmGCTime } } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index cb622e147249e..94f8aeac55b5d 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -4,11 +4,13 @@ "rddBlocks" : 8, "memoryUsed" : 28000128, "diskUsed" : 0, + "maxTasks" : 0, "activeTasks" : 0, "failedTasks" : 1, "completedTasks" : 31, "totalTasks" : 32, "totalDuration" : 8820, + "totalGCTime" : 352, "totalInputBytes" : 28000288, "totalShuffleRead" : 0, "totalShuffleWrite" : 13180, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c65fae482c5ca..501456b043170 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -127,6 +127,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") + ) ++ Seq( + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") ) ++ // SPARK-12665 Remove deprecated and unused classes Seq( @@ -301,6 +304,9 @@ object MimaExcludes { // SPARK-3580 Add getNumPartitions method to JavaRDD ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") + ) ++ Seq( + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") ) ++ // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a // private class. From 7d877c3439d872ec2a9e07d245e9c96174c0cf00 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 25 Jan 2016 12:44:20 -0800 Subject: [PATCH 1680/2089] [SPARK-12902] [SQL] visualization for generated operators This PR brings back visualization for generated operators, they looks like: ![sql](https://cloud.githubusercontent.com/assets/40902/12460920/0dc7956a-bf6b-11e5-9c3f-8389f452526e.png) ![stage](https://cloud.githubusercontent.com/assets/40902/12460923/11806ac4-bf6b-11e5-9c72-e84a62c5ea93.png) Note: SQL metrics are not supported right now, because they are very slow, will be supported once we have batch mode. Author: Davies Liu Closes #10828 from davies/viz_codegen. --- .../apache/spark/ui/static/spark-dag-viz.js | 2 +- .../spark/ui/scope/RDDOperationGraph.scala | 6 +- .../sql/execution/ui/static/spark-sql-viz.css | 6 ++ .../spark/sql/execution/SparkPlanInfo.scala | 9 +- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 2 +- .../sql/execution/ui/SparkPlanGraph.scala | 95 ++++++++++++++----- .../execution/metric/SQLMetricsSuite.scala | 10 +- .../sql/execution/ui/SQLListenerSuite.scala | 2 +- 9 files changed, 104 insertions(+), 32 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 83dbea40b63f3..4337c42087e79 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -284,7 +284,7 @@ function renderDot(dot, container, forJob) { renderer(container, g); // Find the stage cluster and mark it for styling and post-processing - container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); + container.selectAll("g.cluster[name^=\"Stage \"]").classed("stage", true); } /* -------------------- * diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 06da74f1b6b5f..003c218aada9c 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -130,7 +130,11 @@ private[ui] object RDDOperationGraph extends Logging { } } // Attach the outermost cluster to the root cluster, and the RDD to the innermost cluster - rddClusters.headOption.foreach { cluster => rootCluster.attachChildCluster(cluster) } + rddClusters.headOption.foreach { cluster => + if (!rootCluster.childClusters.contains(cluster)) { + rootCluster.attachChildCluster(cluster) + } + } rddClusters.lastOption.foreach { cluster => cluster.attachChildNode(node) } } } diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index ddd3a91dd8ef8..303f8ebb8814c 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -20,6 +20,12 @@ text-shadow: none; } +#plan-viz-graph svg g.cluster rect { + fill: #A0DFFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + #plan-viz-graph svg g.node rect { fill: #C3EBFF; stroke: #3EC0FF; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 4f750ad13ab84..4dd9928244197 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -36,12 +36,17 @@ class SparkPlanInfo( private[sql] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val children = plan match { + case WholeStageCodegen(child, _) => child :: Nil + case InputAdapter(child) => child :: Nil + case plan => plan.children + } val metrics = plan.metrics.toSeq.map { case (key, metric) => new SQLMetricInfo(metric.name.getOrElse(key), metric.id, Utils.getFormattedClassName(metric.param)) } - val children = plan.children.map(fromSparkPlan) - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, plan.metadata, metrics) + new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + plan.metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index c74ad40406992..49915adf6cd29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -99,7 +99,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") } private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { - val metadata = graph.nodes.flatMap { node => + val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
    {node.desc}
    } @@ -110,7 +110,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution")
    {graph.makeDotFile(metrics)}
    -
    {graph.nodes.size.toString}
    +
    {graph.allNodes.size.toString}
    {metadata}
    {planVisualizationResources} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index cd56136927088..83c64f755f90f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -219,7 +219,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case SparkListenerSQLExecutionStart(executionId, description, details, physicalPlanDescription, sparkPlanInfo, time) => val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + val sqlPlanMetrics = physicalPlanGraph.allNodes.flatMap { node => node.metrics.map(metric => metric.accumulatorId -> metric) } val executionUIData = new SQLExecutionUIData( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 3a6eff9399825..4eb248569b281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -41,6 +41,16 @@ private[ui] case class SparkPlanGraph( dotFile.append("}") dotFile.toString() } + + /** + * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. + */ + val allNodes: Seq[SparkPlanGraphNode] = { + nodes.flatMap { + case cluster: SparkPlanGraphCluster => cluster.nodes :+ cluster + case node => Seq(node) + } + } } private[sql] object SparkPlanGraph { @@ -52,7 +62,7 @@ private[sql] object SparkPlanGraph { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null) new SparkPlanGraph(nodes, edges) } @@ -60,22 +70,40 @@ private[sql] object SparkPlanGraph { planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], - edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) + edges: mutable.ArrayBuffer[SparkPlanGraphEdge], + parent: SparkPlanGraphNode, + subgraph: SparkPlanGraphCluster): Unit = { + if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) { + val cluster = new SparkPlanGraphCluster( + nodeIdGenerator.getAndIncrement(), + planInfo.nodeName, + planInfo.simpleString, + mutable.ArrayBuffer[SparkPlanGraphNode]()) + nodes += cluster + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) + } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) { + buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) + } else { + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val node = new SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + if (subgraph == null) { + nodes += node + } else { + subgraph.nodes += node + } + + if (parent != null) { + edges += SparkPlanGraphEdge(node.id, parent.id) + } + planInfo.children.foreach( + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) } - val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, - planInfo.simpleString, planInfo.metadata, metrics) - - nodes += node - val childrenNodes = planInfo.children.map( - child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) - for (child <- childrenNodes) { - edges += SparkPlanGraphEdge(child.id, node.id) - } - node } } @@ -86,12 +114,12 @@ private[sql] object SparkPlanGraph { * @param name the name of this SparkPlan node * @param metrics metrics that this SparkPlan node will track */ -private[ui] case class SparkPlanGraphNode( - id: Long, - name: String, - desc: String, - metadata: Map[String, String], - metrics: Seq[SQLPlanMetric]) { +private[ui] class SparkPlanGraphNode( + val id: Long, + val name: String, + val desc: String, + val metadata: Map[String, String], + val metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { val builder = new mutable.StringBuilder(name) @@ -117,6 +145,27 @@ private[ui] case class SparkPlanGraphNode( } } +/** + * Represent a tree of SparkPlan for WholeStageCodegen. + */ +private[ui] class SparkPlanGraphCluster( + id: Long, + name: String, + desc: String, + val nodes: mutable.ArrayBuffer[SparkPlanGraphNode]) + extends SparkPlanGraphNode(id, name, desc, Map.empty, Nil) { + + override def makeDotNode(metricsValue: Map[Long, String]): String = { + s""" + | subgraph cluster${id} { + | label=${name}; + | ${nodes.map(_.makeDotNode(metricsValue)).mkString(" \n")} + | } + """.stripMargin + } +} + + /** * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child * node id. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 51285431a47ed..cbae19ebd269d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -86,7 +86,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( - df.queryExecution.executedPlan)).nodes.filter { node => + df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => @@ -134,6 +134,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("WholeStageCodegen metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) + // TODO: update metrics in generated operators + val df = sqlContext.range(10).filter('id < 5) + testSparkPlanMetrics(df, 1, Map.empty) + } + test("TungstenAggregate metrics") { // Assume the execution plan is // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index eef3c1f3e34d9..81a159d542c67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -83,7 +83,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val df = createTestDataFrame val accumulatorIds = SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) - .nodes.flatMap(_.metrics.map(_.accumulatorId)) + .allNodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => From 00026fa9912ecee5637f1e7dd222f977f31f6766 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 25 Jan 2016 12:59:11 -0800 Subject: [PATCH 1681/2089] [SPARK-12901][SQL][HOT-FIX] Fix scala 2.11 compilation. --- .../apache/spark/sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../spark/sql/execution/datasources/json/JSONOptions.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 5d0e99d7601dc..709daccbbef58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.execution.datasources.CompressionCodecs private[sql] class CSVOptions( - @transient parameters: Map[String, String]) + @transient private val parameters: Map[String, String]) extends Logging with Serializable { private def getChar(paramName: String, default: Char): Char = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 0a083b5e3598e..fe5b20697e40e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient parameters: Map[String, String]) + @transient private val parameters: Map[String, String]) extends Serializable { val samplingRatio = From 9348431da212ec3ab7be2b8e89a952a48b4e2a31 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 25 Jan 2016 13:38:09 -0800 Subject: [PATCH 1682/2089] [SPARK-12975][SQL] Throwing Exception when Bucketing Columns are part of Partitioning Columns When users are using `partitionBy` and `bucketBy` at the same time, some bucketing columns might be part of partitioning columns. For example, ``` df.write .format(source) .partitionBy("i") .bucketBy(8, "i", "k") .saveAsTable("bucketed_table") ``` However, in the above case, adding column `i` into `bucketBy` is useless. It is just wasting extra CPU when reading or writing bucket tables. Thus, like Hive, we can issue an exception and let users do the change. Also added a test case for checking if the information of `sortBy` and `bucketBy` columns are correctly saved in the metastore table. Could you check if my understanding is correct? cloud-fan rxin marmbrus Thanks! Author: gatorsmile Closes #10891 from gatorsmile/commonKeysInPartitionByBucketBy. --- .../apache/spark/sql/DataFrameWriter.scala | 9 +++ .../sql/hive/MetastoreDataSourcesSuite.scala | 55 ++++++++++++++++++- .../sql/sources/BucketedWriteSuite.scala | 22 +++++++- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ab63fe4aa88b7..12eb2393634a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -240,6 +240,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { n <- numBuckets } yield { require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + + // partitionBy columns cannot be used in bucketBy + if (normalizedParCols.nonEmpty && + normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) { + throw new AnalysisException( + s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " + + s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'") + } + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 253f13c598520..211932fea00ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -745,7 +745,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - test("Saving partition columns information") { + test("Saving partitionBy columns information") { val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") val tableName = s"partitionInfo_${System.currentTimeMillis()}" @@ -776,6 +776,59 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("Saving information for sortBy and bucketBy columns") { + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") + val tableName = s"bucketingInfo_${System.currentTimeMillis()}" + + withTable(tableName) { + df.write + .format("parquet") + .bucketBy(8, "d", "b") + .sortBy("c") + .saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val expectedSortByColumns = StructType(df.schema("c") :: Nil) + + val numBuckets = metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt + assert(numBuckets == 8) + + val numBucketCols = metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt + assert(numBucketCols == 2) + + val numSortCols = metastoreTable.properties("spark.sql.sources.schema.numSortCols").toInt + assert(numSortCols == 1) + + val actualBucketByColumns = + StructType( + (0 until numBucketCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.bucketCol.$index")) + }) + // Make sure bucketBy columns are correctly stored in metastore. + assert( + expectedBucketByColumns.sameType(actualBucketByColumns), + s"Partitions columns stored in metastore $actualBucketByColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedBucketByColumns.") + + val actualSortByColumns = + StructType( + (0 until numSortCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.sortCol.$index")) + }) + // Make sure sortBy columns are correctly stored in metastore. + assert( + expectedSortByColumns.sameType(actualSortByColumns), + s"Partitions columns stored in metastore $actualSortByColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedSortByColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } + } + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = { (from to to).map(i => i -> s"str$i").toDF("c1", "c2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 59b74d2b4c5ea..a32f8fb4c5a1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -92,10 +92,13 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle fail(s"Unable to find the related bucket files.") } + // Remove the duplicate columns in bucketCols and sortCols; + // Otherwise, we got analysis errors due to duplicate names + val selectedColumns = (bucketCols ++ sortCols).distinct // We may lose the type information after write(e.g. json format doesn't keep schema // information), here we get the types from the original dataframe. - val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType) - val columns = (bucketCols ++ sortCols).zip(types).map { + val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType) + val columns = selectedColumns.zip(types).map { case (colName, dt) => col(colName).cast(dt) } @@ -158,6 +161,21 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } } + test("write bucketed data with the overlapping bucketBy and partitionBy columns") { + intercept[AnalysisException](df.write + .partitionBy("i", "j") + .bucketBy(8, "j", "k") + .sortBy("k") + .saveAsTable("bucketed_table")) + } + + test("write bucketed data with the identical bucketBy and partitionBy columns") { + intercept[AnalysisException](df.write + .partitionBy("i") + .bucketBy(8, "i") + .saveAsTable("bucketed_table")) + } + test("write bucketed data without partitionBy") { for (source <- Seq("parquet", "json", "orc")) { withTable("bucketed_table") { From dcae355c64d7f6fdf61df2feefe464eb96c4cf5e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 25 Jan 2016 13:54:21 -0800 Subject: [PATCH 1683/2089] [SPARK-12905][ML][PYSPARK] PCAModel return eigenvalues for PySpark ```PCAModel``` can output ```explainedVariance``` at Python side. cc mengxr srowen Author: Yanbo Liang Closes #10830 from yanboliang/spark-12905. --- .../main/scala/org/apache/spark/ml/feature/PCA.scala | 2 ++ python/pyspark/ml/feature.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 7020397f3b064..0e07dfabfeaab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -102,6 +102,8 @@ object PCA extends DefaultParamsReadable[PCA] { * Model fitted by [[PCA]]. * * @param pc A principal components Matrix. Each column is one principal component. + * @param explainedVariance A vector of proportions of variance explained by + * each principal component. */ @Experimental class PCAModel private[ml] ( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 141ec3492aa94..1fa0eab384e7b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1987,6 +1987,8 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): >>> model = pca.fit(df) >>> model.transform(df).collect()[0].pca_features DenseVector([1.648..., -4.013...]) + >>> model.explainedVariance + DenseVector([0.794..., 0.205...]) .. versionadded:: 1.5.0 """ @@ -2052,6 +2054,15 @@ def pc(self): """ return self._call_java("pc") + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns a vector of proportions of variance + explained by each principal component. + """ + return self._call_java("explainedVariance") + @inherit_doc class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): From 6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 25 Jan 2016 15:05:05 -0800 Subject: [PATCH 1684/2089] [SPARK-12934][SQL] Count-min sketch serialization This PR adds serialization support for `CountMinSketch`. A version number is added to version the serialized binary format. Author: Cheng Lian Closes #10893 from liancheng/cms-serialization. --- .../spark/util/sketch/CountMinSketch.java | 32 ++++- .../spark/util/sketch/CountMinSketchImpl.java | 129 ++++++++++++++++-- .../sketch/IncompatibleMergeException.java | 24 ++++ .../util/sketch/CountMinSketchSuite.scala | 47 ++++++- 4 files changed, 213 insertions(+), 19 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 21b161bc74ae0..67938644d9f6c 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -17,6 +17,7 @@ package org.apache.spark.util.sketch; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -54,6 +55,25 @@ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { + /** + * Version number of the serialized binary format. + */ + public enum Version { + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + public int getVersionNumber() { + return versionNumber; + } + } + + public abstract Version version(); + /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. */ @@ -99,19 +119,23 @@ abstract public class CountMinSketch { * * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed * can be merged. + * + * @exception IncompatibleMergeException if the {@code other} {@link CountMinSketch} has + * incompatible depth, width, relative-error, confidence, or random seed. */ - public abstract CountMinSketch mergeInPlace(CountMinSketch other); + public abstract CountMinSketch mergeInPlace(CountMinSketch other) + throws IncompatibleMergeException; /** * Writes out this {@link CountMinSketch} to an output stream in binary format. */ - public abstract void writeTo(OutputStream out); + public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. */ - public static CountMinSketch readFrom(InputStream in) { - throw new UnsupportedOperationException("Not implemented yet"); + public static CountMinSketch readFrom(InputStream in) throws IOException { + return CountMinSketchImpl.readFrom(in); } /** diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e9fdbe3a86862..0209446ea3b1d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -17,11 +17,30 @@ package org.apache.spark.util.sketch; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; +/* + * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian + * order): + * + * - Version number, always 1 (32 bit) + * - Total count of added items (64 bit) + * - Depth (32 bit) + * - Width (32 bit) + * - Hash functions (depth * 64 bit) + * - Count table + * - Row 0 (width * 64 bit) + * - Row 1 (width * 64 bit) + * - ... + * - Row depth - 1 (width * 64 bit) + */ class CountMinSketchImpl extends CountMinSketch { public static final long PRIME_MODULUS = (1L << 31) - 1; @@ -33,7 +52,7 @@ class CountMinSketchImpl extends CountMinSketch { private double eps; private double confidence; - public CountMinSketchImpl(int depth, int width, int seed) { + CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; this.width = width; this.eps = 2.0 / width; @@ -41,7 +60,7 @@ public CountMinSketchImpl(int depth, int width, int seed) { initTablesWith(depth, width, seed); } - public CountMinSketchImpl(double eps, double confidence, int seed) { + CountMinSketchImpl(double eps, double confidence, int seed) { // 2/w = eps ; w = 2/eps // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) this.eps = eps; @@ -51,6 +70,53 @@ public CountMinSketchImpl(double eps, double confidence, int seed) { initTablesWith(depth, width, seed); } + CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) { + this.depth = depth; + this.width = width; + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + this.hashA = hashA; + this.table = table; + this.totalCount = totalCount; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof CountMinSketchImpl)) { + return false; + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + return + this.depth == that.depth && + this.width == that.width && + this.totalCount == that.totalCount && + Arrays.equals(this.hashA, that.hashA) && + Arrays.deepEquals(this.table, that.table); + } + + @Override + public int hashCode() { + int hash = depth; + + hash = hash * 31 + width; + hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32)); + hash = hash * 31 + Arrays.hashCode(hashA); + hash = hash * 31 + Arrays.deepHashCode(table); + + return hash; + } + + @Override + public Version version() { + return Version.V1; + } + private void initTablesWith(int depth, int width, int seed) { this.table = new long[depth][width]; this.hashA = new long[depth]; @@ -221,27 +287,29 @@ private long estimateCountForStringItem(String item) { } @Override - public CountMinSketch mergeInPlace(CountMinSketch other) { + public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException { if (other == null) { - throw new CMSMergeException("Cannot merge null estimator"); + throw new IncompatibleMergeException("Cannot merge null estimator"); } if (!(other instanceof CountMinSketchImpl)) { - throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName()); + throw new IncompatibleMergeException( + "Cannot merge estimator of class " + other.getClass().getName() + ); } CountMinSketchImpl that = (CountMinSketchImpl) other; if (this.depth != that.depth) { - throw new CMSMergeException("Cannot merge estimators of different depth"); + throw new IncompatibleMergeException("Cannot merge estimators of different depth"); } if (this.width != that.width) { - throw new CMSMergeException("Cannot merge estimators of different width"); + throw new IncompatibleMergeException("Cannot merge estimators of different width"); } if (!Arrays.equals(this.hashA, that.hashA)) { - throw new CMSMergeException("Cannot merge estimators of different seed"); + throw new IncompatibleMergeException("Cannot merge estimators of different seed"); } for (int i = 0; i < this.table.length; ++i) { @@ -256,13 +324,48 @@ public CountMinSketch mergeInPlace(CountMinSketch other) { } @Override - public void writeTo(OutputStream out) { - throw new UnsupportedOperationException("Not implemented yet"); + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(version().getVersionNumber()); + + dos.writeLong(this.totalCount); + dos.writeInt(this.depth); + dos.writeInt(this.width); + + for (int i = 0; i < this.depth; ++i) { + dos.writeLong(this.hashA[i]); + } + + for (int i = 0; i < this.depth; ++i) { + for (int j = 0; j < this.width; ++j) { + dos.writeLong(table[i][j]); + } + } } - protected static class CMSMergeException extends RuntimeException { - public CMSMergeException(String message) { - super(message); + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + // Ignores version number + dis.readInt(); + + long totalCount = dis.readLong(); + int depth = dis.readInt(); + int width = dis.readInt(); + + long hashA[] = new long[depth]; + for (int i = 0; i < depth; ++i) { + hashA[i] = dis.readLong(); + } + + long table[][] = new long[depth][width]; + for (int i = 0; i < depth; ++i) { + for (int j = 0; j < width; ++j) { + table[i][j] = dis.readLong(); + } } + + return new CountMinSketchImpl(depth, width, totalCount, hashA, table); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java new file mode 100644 index 0000000000000..64b567caa57c1 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +public class IncompatibleMergeException extends Exception { + public IncompatibleMergeException(String message) { + super(message); + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index ec5b4eddeca0d..b9c7f5c23a8fe 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.sketch +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import scala.reflect.ClassTag import scala.util.Random @@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite private val seed = 42 + // Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(sketch: CountMinSketch): Unit = { + val out = new ByteArrayOutputStream() + sketch.writeTo(out) + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = CountMinSketch.readFrom(in) + + assert(sketch === deserialized) + } + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { test(s"accuracy - $typeName") { - val r = new Random() + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) val numAllItems = 1000000 val allItems = Array.fill(numAllItems)(itemGenerator(r)) @@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + checkSerDe(sketch) val probCorrect = { val numErrors = allItems.map { item => @@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { test(s"mergeInPlace - $typeName") { - val r = new Random() + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + val numToMerge = 5 val numItemsPerSketch = 100000 val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { @@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val sketches = perSketchItems.map { items => val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + items.foreach(sketch.add) + checkSerDe(sketch) + sketch } val mergedSketch = sketches.reduce(_ mergeInPlace _) + checkSerDe(mergedSketch) val expectedSketch = { val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[Long]("Long") { _.nextLong() } testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + CountMinSketch.create(10, 10, 1).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 20, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 10, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + } } From be375fcbd200fb0e210b8edcfceb5a1bcdbba94b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 25 Jan 2016 16:23:59 -0800 Subject: [PATCH 1685/2089] [SPARK-12879] [SQL] improve the unsafe row writing framework As we begin to use unsafe row writing framework(`BufferHolder` and `UnsafeRowWriter`) in more and more places(`UnsafeProjection`, `UnsafeRowParquetRecordReader`, `GenerateColumnAccessor`, etc.), we should add more doc to it and make it easier to use. This PR abstract the technique used in `UnsafeRowParquetRecordReader`: avoid unnecessary operatition as more as possible. For example, do not always point the row to the buffer at the end, we only need to update the size of row. If all fields are of primitive type, we can even save the row size updating. Then we can apply this technique to more places easily. a local benchmark shows `UnsafeProjection` is up to 1.7x faster after this PR: **old version** ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- single long 2616.04 102.61 1.00 X single nullable long 3032.54 88.52 0.86 X primitive types 9121.05 29.43 0.29 X nullable primitive types 12410.60 21.63 0.21 X ``` **new version** ``` Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- single long 1533.34 175.07 1.00 X single nullable long 2306.73 116.37 0.66 X primitive types 8403.93 31.94 0.18 X nullable primitive types 12448.39 21.56 0.12 X ``` For single non-nullable long(the best case), we can have about 1.7x speed up. Even it's nullable, we can still have 1.3x speed up. For other cases, it's not such a boost as the saved operations only take a little proportion of the whole process. The benchmark code is included in this PR. Author: Wenchen Fan Closes #10809 from cloud-fan/unsafe-projection. --- .../expressions/codegen/BufferHolder.java | 44 +++--- .../expressions/codegen/UnsafeRowWriter.java | 58 +++++--- .../codegen/GenerateUnsafeProjection.scala | 66 ++++++--- .../spark/sql/UnsafeProjectionBenchmark.scala | 136 ++++++++++++++++++ .../parquet/UnsafeRowParquetRecordReader.java | 17 +-- .../columnar/GenerateColumnAccessor.scala | 8 +- .../datasources/text/DefaultSource.scala | 7 +- 7 files changed, 258 insertions(+), 78 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index d26b1b187c27b..af61e2011f400 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -21,24 +21,40 @@ import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer when construct unsafe rows. + * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and + * automatically re-point the unsafe row to it. + * + * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written + * to the data buffer directly and no extra copy is needed. There should be only one instance of + * this class per writing program, so that the memory segment/data buffer can be reused. Note that + * for each incoming record, we should call `reset` of BufferHolder instance before write the record + * and reuse the data buffer. + * + * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update + * the size of the result row, after writing a record to the buffer. However, we can skip this step + * if the fields of row are all fixed-length, as the size of result row is also fixed. */ public class BufferHolder { public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; + private final UnsafeRow row; + private final int fixedSize; - public BufferHolder() { - this(64); + public BufferHolder(UnsafeRow row) { + this(row, 64); } - public BufferHolder(int size) { - buffer = new byte[size]; + public BufferHolder(UnsafeRow row, int initialSize) { + this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields(); + this.buffer = new byte[fixedSize + initialSize]; + this.row = row; + this.row.pointTo(buffer, buffer.length); } /** - * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize, UnsafeRow row) { + public void grow(int neededSize) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. @@ -50,22 +66,12 @@ public void grow(int neededSize, UnsafeRow row) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; - if (row != null) { - row.pointTo(buffer, length * 2); - } + row.pointTo(buffer, buffer.length); } } - public void grow(int neededSize) { - grow(neededSize, null); - } - public void reset() { - cursor = Platform.BYTE_ARRAY_OFFSET; - } - public void resetTo(int offset) { - assert(offset <= buffer.length); - cursor = Platform.BYTE_ARRAY_OFFSET + offset; + cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } public int totalSize() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index e227c0dec9748..4776617043878 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -26,38 +26,56 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A helper class to write data into global row buffer using `UnsafeRow` format, - * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + * A helper class to write data into global row buffer using `UnsafeRow` format. + * + * It will remember the offset of row buffer which it starts to write, and move the cursor of row + * buffer while writing. If new data(can be the input record if this is the outermost writer, or + * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be + * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * `startingOffset` and clear out null bits. + * + * Note that if this is the outermost writer, which means we will always write from the very + * beginning of the global row buffer, we don't need to update `startingOffset` and can just call + * `zeroOutNullBytes` before writing new data. */ public class UnsafeRowWriter { - private BufferHolder holder; + private final BufferHolder holder; // The offset of the global buffer where we start to write this row. private int startingOffset; - private int nullBitsSize; - private UnsafeRow row; + private final int nullBitsSize; + private final int fixedSize; - public void initialize(BufferHolder holder, int numFields) { + public UnsafeRowWriter(BufferHolder holder, int numFields) { this.holder = holder; - this.startingOffset = holder.cursor; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); + this.fixedSize = nullBitsSize + 8 * numFields; + this.startingOffset = holder.cursor; + } + + /** + * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null + * bits. This should be called before we write a new nested struct to the row buffer. + */ + public void reset() { + this.startingOffset = holder.cursor; // grow the global buffer to make sure it has enough space to write fixed-length data. - final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize, row); + holder.grow(fixedSize); holder.cursor += fixedSize; - // zero-out the null bits region + zeroOutNullBytes(); + } + + /** + * Clears out null bits. This should be called before we write a new row to row buffer. + */ + public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { Platform.putLong(holder.buffer, startingOffset + i, 0L); } } - public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { - initialize(holder, numFields); - this.row = row; - } - private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); @@ -98,7 +116,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes, row); + holder.grow(paddingBytes); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -161,7 +179,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16, row); + holder.grow(16); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -193,7 +211,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize, row); + holder.grow(roundedSize); zeroOutPaddingBytes(numBytes); @@ -214,7 +232,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize, row); + holder.grow(roundedSize); zeroOutPaddingBytes(numBytes); @@ -230,7 +248,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. - holder.grow(16, row); + holder.grow(16); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 72bf39a0398b1..6aa9cbf08bdb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - private val rowWriterClass = classOf[UnsafeRowWriter].getName - private val arrayWriterClass = classOf[UnsafeArrayWriter].getName - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, @@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String): String = { + bufferHolder: String, + isTopLevel: Boolean = false): String = { + val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") + ctx.addMutableState(rowWriterClass, rowWriter, + s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + + val resetWriter = if (isTopLevel) { + // For top level row writer, it always writes to the beginning of the global buffer holder, + // which means its fixed-size region always in the same position, so we don't need to call + // `reset` to set up its fixed-size region every time. + if (inputs.map(_.isNull).forall(_ == "false")) { + // If all fields are not nullable, which means the null bits never changes, then we don't + // need to clear it out every time. + "" + } else { + s"$rowWriter.zeroOutNullBytes();" + } + } else { + s"$rowWriter.reset();" + } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => @@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ - case _ if ctx.isPrimitiveType(dt) => - s""" - $rowWriter.write($index, ${input.value}); - """ - case t: DecimalType => s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" @@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } s""" - $rowWriter.initialize($bufferHolder, ${inputs.length}); + $resetWriter ${ctx.splitExpressions(row, writeFields)} """.trim } @@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, elementType: DataType, bufferHolder: String): String = { + val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"this.$arrayWriter = new $arrayWriterClass();") @@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) + val numVarLenFields = exprTypes.count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } + val result = ctx.freshName("result") ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") - val bufferHolder = ctx.freshName("bufferHolder") + + val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + ctx.addMutableState(holderClass, holder, + s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + + val resetBufferHolder = if (numVarLenFields == 0) { + "" + } else { + s"$holder.reset();" + } + val updateRowSize = if (numVarLenFields == 0) { + "" + } else { + s"$result.setTotalSize($holder.totalSize());" + } // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val writeExpressions = + writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val code = s""" - $bufferHolder.reset(); + $resetBufferHolder $evalSubexpr - ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - - $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); + $writeExpressions + $updateRowSize """ ExprCode(code, "false", result) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala new file mode 100644 index 0000000000000..a6d90409382e5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + */ +object UnsafeProjectionBenchmark { + + def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = { + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray + } + + def main(args: Array[String]) { + val iters = 1024 * 16 + val numRows = 1024 * 16 + + val benchmark = new Benchmark("unsafe projection", iters * numRows) + + + val schema1 = new StructType().add("l", LongType, false) + val attrs1 = schema1.toAttributes + val rows1 = generateRows(schema1, numRows) + val projection1 = UnsafeProjection.create(attrs1, attrs1) + + benchmark.addCase("single long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection1(rows1(i)).getLong(0) + i += 1 + } + } + } + + val schema2 = new StructType().add("l", LongType, true) + val attrs2 = schema2.toAttributes + val rows2 = generateRows(schema2, numRows) + val projection2 = UnsafeProjection.create(attrs2, attrs2) + + benchmark.addCase("single nullable long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection2(rows2(i)).getLong(0) + i += 1 + } + } + } + + + val schema3 = new StructType() + .add("boolean", BooleanType, false) + .add("byte", ByteType, false) + .add("short", ShortType, false) + .add("int", IntegerType, false) + .add("long", LongType, false) + .add("float", FloatType, false) + .add("double", DoubleType, false) + val attrs3 = schema3.toAttributes + val rows3 = generateRows(schema3, numRows) + val projection3 = UnsafeProjection.create(attrs3, attrs3) + + benchmark.addCase("7 primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection3(rows3(i)).getLong(0) + i += 1 + } + } + } + + + val schema4 = new StructType() + .add("boolean", BooleanType, true) + .add("byte", ByteType, true) + .add("short", ShortType, true) + .add("int", IntegerType, true) + .add("long", LongType, true) + .add("float", FloatType, true) + .add("double", DoubleType, true) + val attrs4 = schema4.toAttributes + val rows4 = generateRows(schema4, numRows) + val projection4 = UnsafeProjection.create(attrs4, attrs4) + + benchmark.addCase("7 nullable primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection4(rows4(i)).getLong(0) + i += 1 + } + } + } + + + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + single long 1533.34 175.07 1.00 X + single nullable long 2306.73 116.37 0.66 X + primitive types 8403.93 31.94 0.18 X + nullable primitive types 12448.39 21.56 0.12 X + */ + benchmark.run() + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 80805f15a8f06..17adfec32192f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -73,11 +73,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private boolean containsVarLenFields; - /** - * The number of bytes in the fixed length portion of the row. - */ - private int fixedSizeBytes; - /** * For each request column, the reader to read this column. * columnsReaders[i] populated the UnsafeRow's attribute at i. @@ -266,19 +261,13 @@ private void initializeInternal() throws IOException { /** * Initialize rows and rowWriters. These objects are reused across all rows in the relation. */ - int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); - rowByteSize += 8 * requestedSchema.getFieldCount(); - fixedSizeBytes = rowByteSize; - rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; containsVarLenFields = numVarLenFields > 0; rowWriters = new UnsafeRowWriter[rows.length]; for (int i = 0; i < rows.length; ++i) { rows[i] = new UnsafeRow(requestedSchema.getFieldCount()); - rowWriters[i] = new UnsafeRowWriter(); - BufferHolder holder = new BufferHolder(rowByteSize); - rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); - rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length); + BufferHolder holder = new BufferHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE); + rowWriters[i] = new UnsafeRowWriter(holder, requestedSchema.getFieldCount()); } } @@ -295,7 +284,7 @@ private boolean loadBatch() throws IOException { if (containsVarLenFields) { for (int i = 0; i < rowWriters.length; ++i) { - rowWriters[i].holder().resetTo(fixedSizeBytes); + rowWriters[i].holder().reset(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 72eb1f6cf0518..738b9a35d1c9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -132,8 +132,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private BufferHolder bufferHolder = new BufferHolder(unsafeRow); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -181,9 +181,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; bufferHolder.reset(); - rowWriter.initialize(bufferHolder, $numFields); + rowWriter.zeroOutNullBytes(); ${extractors.mkString("\n")} - unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()); + unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index bd2d17c0189ec..430257f60d9fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -98,16 +98,15 @@ private[sql] class TextRelation( sqlContext.sparkContext.hadoopRDD( conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) .mapPartitions { iter => - val bufferHolder = new BufferHolder - val unsafeRowWriter = new UnsafeRowWriter val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) iter.map { case (_, line) => // Writes to an UnsafeRow directly bufferHolder.reset() - unsafeRowWriter.initialize(bufferHolder, 1) unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()) + unsafeRow.setTotalSize(bufferHolder.totalSize()) unsafeRow } } From 109061f7ad27225669cbe609ec38756b31d4e1b9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 25 Jan 2016 17:58:11 -0800 Subject: [PATCH 1686/2089] [SPARK-12936][SQL] Initial bloom filter implementation This PR adds an initial implementation of bloom filter in the newly added sketch module. The implementation is based on the [`BloomFilter` class in guava](https://code.google.com/p/guava-libraries/source/browse/guava/src/com/google/common/hash/BloomFilter.java). Some difference from the design doc: * expose `bitSize` instead of `sizeInBytes` to user. * always need the `expectedInsertions` parameter when create bloom filter. Author: Wenchen Fan Closes #10883 from cloud-fan/bloom-filter. --- .../apache/spark/util/sketch/BitArray.java | 94 ++++++++++ .../apache/spark/util/sketch/BloomFilter.java | 153 ++++++++++++++++ .../spark/util/sketch/BloomFilterImpl.java | 164 ++++++++++++++++++ .../spark/util/sketch/BitArraySuite.scala | 77 ++++++++ .../spark/util/sketch/BloomFilterSuite.scala | 114 ++++++++++++ 5 files changed, 602 insertions(+) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java create mode 100644 common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala create mode 100644 common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java new file mode 100644 index 0000000000000..1bc665ad54b72 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.util.Arrays; + +public final class BitArray { + private final long[] data; + private long bitCount; + + static int numWords(long numBits) { + long numWords = (long) Math.ceil(numBits / 64.0); + if (numWords > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); + } + return (int) numWords; + } + + BitArray(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive"); + } + this.data = new long[numWords(numBits)]; + long bitCount = 0; + for (long value : data) { + bitCount += Long.bitCount(value); + } + this.bitCount = bitCount; + } + + /** Returns true if the bit changed value. */ + boolean set(long index) { + if (!get(index)) { + data[(int) (index >>> 6)] |= (1L << index); + bitCount++; + return true; + } + return false; + } + + boolean get(long index) { + return (data[(int) (index >>> 6)] & (1L << index)) != 0; + } + + /** Number of bits */ + long bitSize() { + return (long) data.length * Long.SIZE; + } + + /** Number of set bits (1s) */ + long cardinality() { + return bitCount; + } + + /** Combines the two BitArrays using bitwise OR. */ + void putAll(BitArray array) { + assert data.length == array.data.length : "BitArrays must be of equal length when merging"; + long bitCount = 0; + for (int i = 0; i < data.length; i++) { + data[i] |= array.data[i]; + bitCount += Long.bitCount(data[i]); + } + this.bitCount = bitCount; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || !(o instanceof BitArray)) return false; + + BitArray bitArray = (BitArray) o; + return Arrays.equals(data, bitArray.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java new file mode 100644 index 0000000000000..38949c6311df8 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +/** + * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether + * an element is a member of a set. It returns false when the element is definitely not in the + * set, returns true when the element is probably in the set. + * + * Internally a Bloom filter is initialized with 2 information: how many space to use(number of + * bits) and how many hash values to calculate for each record. To get as lower false positive + * probability as possible, user should call {@link BloomFilter#create} to automatically pick a + * best combination of these 2 parameters. + * + * Currently the following data types are supported: + *
      + *
    • {@link Byte}
    • + *
    • {@link Short}
    • + *
    • {@link Integer}
    • + *
    • {@link Long}
    • + *
    • {@link String}
    • + *
    + * + * The implementation is largely based on the {@code BloomFilter} class from guava. + */ +public abstract class BloomFilter { + /** + * Returns the false positive probability, i.e. the probability that + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that + * has not actually been put in the {@code BloomFilter}. + * + *

    Ideally, this number should be close to the {@code fpp} parameter + * passed in to create this bloom filter, or smaller. If it is + * significantly higher, it is usually the case that too many elements (more than + * expected) have been put in the {@code BloomFilter}, degenerating it. + */ + public abstract double expectedFpp(); + + /** + * Returns the number of bits in the underlying bit array. + */ + public abstract long bitSize(); + + /** + * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@link #mightContain(Object)} with the same element will always return {@code true}. + * + * @return true if the bloom filter's bits changed as a result of this operation. If the bits + * changed, this is definitely the first time {@code object} has been added to the + * filter. If the bits haven't changed, this might be the first time {@code object} + * has been added to the filter. Note that {@code put(t)} always returns the + * opposite result to what {@code mightContain(t)} would have returned at the time + * it is called. + */ + public abstract boolean put(Object item); + + /** + * Determines whether a given bloom filter is compatible with this bloom filter. For two + * bloom filters to be compatible, they must have the same bit size. + * + * @param other The bloom filter to check for compatibility. + */ + public abstract boolean isCompatible(BloomFilter other); + + /** + * Combines this bloom filter with another bloom filter by performing a bitwise OR of the + * underlying data. The mutations happen to this instance. Callers must ensure the + * bloom filters are appropriately sized to avoid saturating them. + * + * @param other The bloom filter to combine this bloom filter with. It is not mutated. + * @throws IllegalArgumentException if {@code isCompatible(that) == false} + */ + public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; + + /** + * Returns {@code true} if the element might have been put in this Bloom filter, + * {@code false} if this is definitely not the case. + */ + public abstract boolean mightContain(Object item); + + /** + * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the + * expected insertions and total number of bits in the Bloom filter. + * + * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. + * + * @param n expected insertions (must be positive) + * @param m total number of bits in Bloom filter (must be positive) + */ + private static int optimalNumOfHashFunctions(long n, long m) { + // (m / n) * log(2), but avoid truncation due to division! + return Math.max(1, (int) Math.round((double) m / n * Math.log(2))); + } + + /** + * Computes m (total bits of Bloom filter) which is expected to achieve, for the specified + * expected insertions, the required false positive probability. + * + * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula. + * + * @param n expected insertions (must be positive) + * @param p false positive rate (must be 0 < p < 1) + */ + private static long optimalNumOfBits(long n, double p) { + return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); + } + + static final double DEFAULT_FPP = 0.03; + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}. + */ + public static BloomFilter create(long expectedNumItems) { + return create(expectedNumItems, DEFAULT_FPP); + } + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick + * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter. + */ + public static BloomFilter create(long expectedNumItems, double fpp) { + assert fpp > 0.0 : "False positive probability must be > 0.0"; + assert fpp < 1.0 : "False positive probability must be < 1.0"; + long numBits = optimalNumOfBits(expectedNumItems, fpp); + return create(expectedNumItems, numBits); + } + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code numBits}, it will + * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. + */ + public static BloomFilter create(long expectedNumItems, long numBits) { + assert expectedNumItems > 0 : "Expected insertions must be > 0"; + assert numBits > 0 : "number of bits must be > 0"; + int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); + return new BloomFilterImpl(numHashFunctions, numBits); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java new file mode 100644 index 0000000000000..bbd6cf719dc0e --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.UnsupportedEncodingException; + +public class BloomFilterImpl extends BloomFilter { + + private final int numHashFunctions; + private final BitArray bits; + + BloomFilterImpl(int numHashFunctions, long numBits) { + this.numHashFunctions = numHashFunctions; + this.bits = new BitArray(numBits); + } + + @Override + public double expectedFpp() { + return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions); + } + + @Override + public long bitSize() { + return bits.bitSize(); + } + + private static long hashObjectToLong(Object item) { + if (item instanceof String) { + try { + byte[] bytes = ((String) item).getBytes("utf-8"); + return hashBytesToLong(bytes); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException("Only support utf-8 string", e); + } + } else { + long longValue; + + if (item instanceof Long) { + longValue = (Long) item; + } else if (item instanceof Integer) { + longValue = ((Integer) item).longValue(); + } else if (item instanceof Short) { + longValue = ((Short) item).longValue(); + } else if (item instanceof Byte) { + longValue = ((Byte) item).longValue(); + } else { + throw new IllegalArgumentException( + "Support for " + item.getClass().getName() + " not implemented" + ); + } + + int h1 = Murmur3_x86_32.hashLong(longValue, 0); + int h2 = Murmur3_x86_32.hashLong(longValue, h1); + return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + } + } + + private static long hashBytesToLong(byte[] bytes) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + } + + @Override + public boolean put(Object item) { + long bitSize = bits.bitSize(); + + // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash + // values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy for long type, it hash the input long + // element with every i to produce n hash values. + long hash64 = hashObjectToLong(item); + int h1 = (int) (hash64 >> 32); + int h2 = (int) hash64; + + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContain(Object item) { + long bitSize = bits.bitSize(); + long hash64 = hashObjectToLong(item); + int h1 = (int) (hash64 >> 32); + int h2 = (int) hash64; + + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean isCompatible(BloomFilter other) { + if (other == null) { + return false; + } + + if (!(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions; + } + + @Override + public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException { + // Duplicates the logic of `isCompatible` here to provide better error message. + if (other == null) { + throw new IncompatibleMergeException("Cannot merge null bloom filter"); + } + + if (!(other instanceof BloomFilter)) { + throw new IncompatibleMergeException( + "Cannot merge bloom filter of class " + other.getClass().getName() + ); + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + if (this.bitSize() != that.bitSize()) { + throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size"); + } + + if (this.numHashFunctions != that.numHashFunctions) { + throw new IncompatibleMergeException( + "Cannot merge bloom filters with different number of hash functions"); + } + + this.bits.putAll(that.bits); + return this; + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala new file mode 100644 index 0000000000000..ff728f0ebcb85 --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite + + test("error case when create BitArray") { + intercept[IllegalArgumentException](new BitArray(0)) + intercept[IllegalArgumentException](new BitArray(64L * Integer.MAX_VALUE + 1)) + } + + test("bitSize") { + assert(new BitArray(64).bitSize() == 64) + // BitArray is word-aligned, so 65~128 bits need 2 long to store, which is 128 bits. + assert(new BitArray(65).bitSize() == 128) + assert(new BitArray(127).bitSize() == 128) + assert(new BitArray(128).bitSize() == 128) + } + + test("set") { + val bitArray = new BitArray(64) + assert(bitArray.set(1)) + // Only returns true if the bit changed. + assert(!bitArray.set(1)) + assert(bitArray.set(2)) + } + + test("normal operation") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val bitArray = new BitArray(320) + val indexes = (1 to 100).map(_ => r.nextInt(320).toLong).distinct + + indexes.foreach(bitArray.set) + indexes.foreach(i => assert(bitArray.get(i))) + assert(bitArray.cardinality() == indexes.length) + } + + test("merge") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val bitArray1 = new BitArray(64 * 6) + val bitArray2 = new BitArray(64 * 6) + + val indexes1 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct + val indexes2 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct + + indexes1.foreach(bitArray1.set) + indexes2.foreach(bitArray2.set) + + bitArray1.putAll(bitArray2) + indexes1.foreach(i => assert(bitArray1.get(i))) + indexes2.foreach(i => assert(bitArray1.get(i))) + assert(bitArray1.cardinality() == (indexes1 ++ indexes2).distinct.length) + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala new file mode 100644 index 0000000000000..d2de509f19517 --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite + private final val EPSILON = 0.01 + + def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + test(s"accuracy - $typeName") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + val fpp = 0.05 + val numInsertion = numItems / 10 + + val allItems = Array.fill(numItems)(itemGen(r)) + + val filter = BloomFilter.create(numInsertion, fpp) + + // insert first `numInsertion` items. + allItems.take(numInsertion).foreach(filter.put) + + // false negative is not allowed. + assert(allItems.take(numInsertion).forall(filter.mightContain)) + + // The number of inserted items doesn't exceed `expectedNumItems`, so the `expectedFpp` + // should not be significantly higher than the one we passed in to create this bloom filter. + assert(filter.expectedFpp() - fpp < EPSILON) + + val errorCount = allItems.drop(numInsertion).count(filter.mightContain) + + // Also check the actual fpp is not significantly higher than we expected. + val actualFpp = errorCount.toDouble / (numItems - numInsertion) + assert(actualFpp - fpp < EPSILON) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val items1 = Array.fill(numItems / 2)(itemGen(r)) + val items2 = Array.fill(numItems / 2)(itemGen(r)) + + val filter1 = BloomFilter.create(numItems) + items1.foreach(filter1.put) + + val filter2 = BloomFilter.create(numItems) + items2.foreach(filter2.put) + + filter1.mergeInPlace(filter2) + + // After merge, `filter1` has `numItems` items which doesn't exceed `expectedNumItems`, so the + // `expectedFpp` should not be significantly higher than the default one. + assert(filter1.expectedFpp() - BloomFilter.DEFAULT_FPP < EPSILON) + + items1.foreach(i => assert(filter1.mightContain(i))) + items2.foreach(i => assert(filter1.mightContain(i))) + } + } + + def testItemType[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + testAccuracy[T](typeName, numItems)(itemGen) + testMergeInPlace[T](typeName, numItems)(itemGen) + } + + testItemType[Byte]("Byte", 160) { _.nextInt().toByte } + + testItemType[Short]("Short", 1000) { _.nextInt().toShort } + + testItemType[Int]("Int", 100000) { _.nextInt() } + + testItemType[Long]("Long", 100000) { _.nextLong() } + + testItemType[String]("String", 100000) { r => r.nextString(r.nextInt(512)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + BloomFilter.create(1000).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(1000, 6400) + val filter2 = BloomFilter.create(1000, 3200) + filter1.mergeInPlace(filter2) + } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(1000, 6400) + val filter2 = BloomFilter.create(2000, 6400) + filter1.mergeInPlace(filter2) + } + } +} From fdcc3512f7b45e5b067fc26cb05146f79c4a5177 Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 25 Jan 2016 18:23:47 -0800 Subject: [PATCH 1687/2089] [SPARK-12934] use try-with-resources for streams liancheng please take a look Author: tedyu Closes #10906 from tedyu/master. --- .../main/java/org/apache/spark/util/sketch/CountMinSketch.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 67938644d9f6c..9f4ff42403c34 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -128,11 +128,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) /** * Writes out this {@link CountMinSketch} to an output stream in binary format. + * It is the caller's responsibility to close the stream */ public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. + * It is the caller's responsibility to close the stream */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); From b66afdeb5253913d916dcf159aaed4ffdc15fd4b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 25 Jan 2016 22:38:31 -0800 Subject: [PATCH 1688/2089] [SPARK-11922][PYSPARK][ML] Python api for ml.feature.quantile discretizer Add Python API for ml.feature.QuantileDiscretizer. One open question: Do we want to do this stuff to re-use the java model, create a new model, or use a different wrapper around the java model. cc brkyvz & mengxr Author: Holden Karau Closes #10085 from holdenk/SPARK-11937-SPARK-11922-Python-API-for-ml.feature.QuantileDiscretizer. --- python/pyspark/ml/feature.py | 89 ++++++++++++++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1fa0eab384e7b..f139d81bc490d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -30,10 +30,10 @@ __all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', - 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', - 'Word2Vec', 'Word2VecModel'] + 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', + 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -991,6 +991,87 @@ def getDegree(self): return self.getOrDefault(self.degree) +@inherit_doc +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned + categorical features. The bin ranges are chosen by taking a sample of the data and dividing it + into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, + covering all real values. This attempts to find numBuckets partitions based on a sample of data, + but it may find fewer depending on the data sample values. + + >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> qds = QuantileDiscretizer(numBuckets=2, + ... inputCol="values", outputCol="buckets") + >>> bucketizer = qds.fit(df) + >>> splits = bucketizer.getSplits() + >>> splits[0] + -inf + >>> print("%2.1f" % round(splits[1], 1)) + 0.4 + >>> bucketed = bucketizer.transform(df).head() + >>> bucketed.buckets + 0.0 + + .. versionadded:: 2.0.0 + """ + + # a placeholder to make it appear in the generated doc + numBuckets = Param(Params._dummy(), "numBuckets", + "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2. Default 2.") + + @keyword_only + def __init__(self, numBuckets=2, inputCol=None, outputCol=None): + """ + __init__(self, numBuckets=2, inputCol=None, outputCol=None) + """ + super(QuantileDiscretizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", + self.uid) + self.numBuckets = Param(self, "numBuckets", + "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2.") + self._setDefault(numBuckets=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numBuckets=2, inputCol=None, outputCol=None): + """ + setParams(self, numBuckets=2, inputCol=None, outputCol=None) + Set the params for the QuantileDiscretizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumBuckets(self, value): + """ + Sets the value of :py:attr:`numBuckets`. + """ + self._paramMap[self.numBuckets] = value + return self + + @since("2.0.0") + def getNumBuckets(self): + """ + Gets the value of numBuckets or its default value. + """ + return self.getOrDefault(self.numBuckets) + + def _create_model(self, java_model): + """ + Private method to convert the java_model to a Python model. + """ + return Bucketizer(splits=list(java_model.getSplits()), + inputCol=self.getInputCol(), + outputCol=self.getOutputCol()) + + @inherit_doc @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): From ae47ba718a280fc12720a71b981c38dbe647f35b Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 25 Jan 2016 22:41:52 -0800 Subject: [PATCH 1689/2089] [SPARK-12834] Change ser/de of JavaArray and JavaList https://issues.apache.org/jira/browse/SPARK-12834 We use `SerDe.dumps()` to serialize `JavaArray` and `JavaList` in `PythonMLLibAPI`, then deserialize them with `PickleSerializer` in Python side. However, there is no need to transform them in such an inefficient way. Instead of it, we can use type conversion to convert them, e.g. `list(JavaArray)` or `list(JavaList)`. What's more, there is an issue to Ser/De Scala Array as I said in https://issues.apache.org/jira/browse/SPARK-12780 Author: Xusen Yin Closes #10772 from yinxusen/SPARK-12834. --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 05f9a76d32671..088ec6a0c0465 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1490,7 +1490,11 @@ private[spark] object SerDe extends Serializable { initialize() def dumps(obj: AnyRef): Array[Byte] = { - new Pickler().dumps(obj) + obj match { + // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834. + case array: Array[_] => new Pickler().dumps(array.toSeq.asJava) + case _ => new Pickler().dumps(obj) + } } def loads(bytes: Array[Byte]): AnyRef = { From 27c910f7f29087d1ac216d4933d641d6515fd6ad Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 25 Jan 2016 22:53:34 -0800 Subject: [PATCH 1690/2089] [SPARK-10086][MLLIB][STREAMING][PYSPARK] ignore StreamingKMeans test in PySpark for now I saw several failures from recent PR builds, e.g., https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/50015/consoleFull. This PR marks the test as ignored and we will fix the flakyness in SPARK-10086. gliptak Do you know why the test failure didn't show up in the Jenkins "Test Result"? cc: jkbradley Author: Xiangrui Meng Closes #10909 from mengxr/SPARK-10086. --- python/pyspark/mllib/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 79ce4959c9266..25a7c29982b3b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -1189,6 +1189,7 @@ def condition(): self._eventually(condition, catch_assertions=True) + @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") def test_trainOn_predictOn(self): """Test that prediction happens on the updated model.""" stkm = StreamingKMeans(decayFactor=0.0, k=2) From d54cfed5a6953a9ce2b9de2f31ee2d673cb5cc62 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 26 Jan 2016 00:51:08 -0800 Subject: [PATCH 1691/2089] [SQL][MINOR] A few minor tweaks to CSV reader. This pull request simply fixes a few minor coding style issues in csv, as I was reviewing the change post-hoc. Author: Reynold Xin Closes #10919 from rxin/csv-minor. --- .../datasources/csv/CSVInferSchema.scala | 21 +++++++------------ .../datasources/csv/CSVRelation.scala | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 0aa4539e60516..ace8cd7ad864e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -30,16 +30,15 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ -private[sql] object CSVInferSchema { +private[csv] object CSVInferSchema { /** * Similar to the JSON schema inference * 1. Infer type of each row * 2. Merge row types to find common type * 3. Replace any null types with string type - * TODO(hossein): Can we reuse JSON schema inference? [SPARK-12670] */ - def apply( + def infer( tokenRdd: RDD[Array[String]], header: Array[String], nullValue: String = ""): StructType = { @@ -65,10 +64,7 @@ private[sql] object CSVInferSchema { rowSoFar } - private[csv] def mergeRowTypes( - first: Array[DataType], - second: Array[DataType]): Array[DataType] = { - + def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { first.zipAll(second, NullType, NullType).map { case ((a, b)) => val tpe = findTightestCommonType(a, b).getOrElse(StringType) tpe match { @@ -82,8 +78,7 @@ private[sql] object CSVInferSchema { * Infer type of string field. Given known type Double, and a string "1", there is no * point checking if it is an Int, as the final type must be Double or higher. */ - private[csv] def inferField( - typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { + def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { if (field == null || field.isEmpty || field == nullValue) { typeSoFar } else { @@ -155,7 +150,8 @@ private[sql] object CSVInferSchema { } } -object CSVTypeCast { + +private[csv] object CSVTypeCast { /** * Casts given string datum to specified type. @@ -167,7 +163,7 @@ object CSVTypeCast { * @param datum string value * @param castType SparkSQL type */ - private[csv] def castTo( + def castTo( datum: String, castType: DataType, nullable: Boolean = true, @@ -201,10 +197,9 @@ object CSVTypeCast { * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one * character. - * */ @throws[IllegalArgumentException] - private[csv] def toChar(str: String): Char = { + def toChar(str: String): Char = { if (str.charAt(0) == '\\') { str.charAt(1) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 5959f7cc5051b..dc449fea956f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -139,7 +139,7 @@ private[csv] class CSVRelation( val parsedRdd = tokenRdd(header, paths) if (params.inferSchemaFlag) { - CSVInferSchema(parsedRdd, header, params.nullValue) + CSVInferSchema.infer(parsedRdd, header, params.nullValue) } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName => From 6743de3a98e3f0d0e6064ca1872fa88c3aeaa143 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Jan 2016 00:53:05 -0800 Subject: [PATCH 1692/2089] [SPARK-12937][SQL] bloom filter serialization This PR adds serialization support for BloomFilter. A version number is added to version the serialized binary format. Author: Wenchen Fan Closes #10920 from cloud-fan/bloom-filter. --- .../apache/spark/util/sketch/BitArray.java | 46 +++++++++++++----- .../apache/spark/util/sketch/BloomFilter.java | 42 +++++++++++++++- .../spark/util/sketch/BloomFilterImpl.java | 48 ++++++++++++++++++- .../spark/util/sketch/CountMinSketch.java | 25 ++++++---- .../spark/util/sketch/CountMinSketchImpl.java | 22 +-------- .../spark/util/sketch/BloomFilterSuite.scala | 20 ++++++++ 6 files changed, 159 insertions(+), 44 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 1bc665ad54b72..2a0484e324b13 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -17,6 +17,9 @@ package org.apache.spark.util.sketch; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.Arrays; public final class BitArray { @@ -24,6 +27,9 @@ public final class BitArray { private long bitCount; static int numWords(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive, but got " + numBits); + } long numWords = (long) Math.ceil(numBits / 64.0); if (numWords > Integer.MAX_VALUE) { throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); @@ -32,13 +38,14 @@ static int numWords(long numBits) { } BitArray(long numBits) { - if (numBits <= 0) { - throw new IllegalArgumentException("numBits must be positive"); - } - this.data = new long[numWords(numBits)]; + this(new long[numWords(numBits)]); + } + + private BitArray(long[] data) { + this.data = data; long bitCount = 0; - for (long value : data) { - bitCount += Long.bitCount(value); + for (long word : data) { + bitCount += Long.bitCount(word); } this.bitCount = bitCount; } @@ -78,13 +85,28 @@ void putAll(BitArray array) { this.bitCount = bitCount; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || !(o instanceof BitArray)) return false; + void writeTo(DataOutputStream out) throws IOException { + out.writeInt(data.length); + for (long datum : data) { + out.writeLong(datum); + } + } + + static BitArray readFrom(DataInputStream in) throws IOException { + int numWords = in.readInt(); + long[] data = new long[numWords]; + for (int i = 0; i < numWords; i++) { + data[i] = in.readLong(); + } + return new BitArray(data); + } - BitArray bitArray = (BitArray) o; - return Arrays.equals(data, bitArray.data); + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof BitArray)) return false; + BitArray that = (BitArray) other; + return Arrays.equals(data, that.data); } @Override diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 38949c6311df8..00378d58518f6 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -17,6 +17,10 @@ package org.apache.spark.util.sketch; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + /** * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether * an element is a member of a set. It returns false when the element is definitely not in the @@ -39,6 +43,28 @@ * The implementation is largely based on the {@code BloomFilter} class from guava. */ public abstract class BloomFilter { + + public enum Version { + /** + * {@code BloomFilter} binary format version 1 (all values written in big-endian order): + * - Version number, always 1 (32 bit) + * - Total number of words of the underlying bit array (32 bit) + * - The words/longs (numWords * 64 bit) + * - Number of hash functions (32 bit) + */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + /** * Returns the false positive probability, i.e. the probability that * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that @@ -83,7 +109,7 @@ public abstract class BloomFilter { * bloom filters are appropriately sized to avoid saturating them. * * @param other The bloom filter to combine this bloom filter with. It is not mutated. - * @throws IllegalArgumentException if {@code isCompatible(that) == false} + * @throws IncompatibleMergeException if {@code isCompatible(other) == false} */ public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; @@ -93,6 +119,20 @@ public abstract class BloomFilter { */ public abstract boolean mightContain(Object item); + /** + * Writes out this {@link BloomFilter} to an output stream in binary format. + * It is the caller's responsibility to close the stream. + */ + public abstract void writeTo(OutputStream out) throws IOException; + + /** + * Reads in a {@link BloomFilter} from an input stream. + * It is the caller's responsibility to close the stream. + */ + public static BloomFilter readFrom(InputStream in) throws IOException { + return BloomFilterImpl.readFrom(in); + } + /** * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index bbd6cf719dc0e..1c08d07afaeaa 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -17,7 +17,7 @@ package org.apache.spark.util.sketch; -import java.io.UnsupportedEncodingException; +import java.io.*; public class BloomFilterImpl extends BloomFilter { @@ -25,8 +25,32 @@ public class BloomFilterImpl extends BloomFilter { private final BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { + this(new BitArray(numBits), numHashFunctions); + } + + private BloomFilterImpl(BitArray bits, int numHashFunctions) { + this.bits = bits; this.numHashFunctions = numHashFunctions; - this.bits = new BitArray(numBits); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return bits.hashCode() * 31 + numHashFunctions; } @Override @@ -161,4 +185,24 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep this.bits.putAll(that.bits); return this; } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + bits.writeTo(dos); + dos.writeInt(numHashFunctions); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + + return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); + } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 9f4ff42403c34..00c0b1b9e2db8 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -55,10 +55,21 @@ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { - /** - * Version number of the serialized binary format. - */ + public enum Version { + /** + * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): + * - Version number, always 1 (32 bit) + * - Total count of added items (64 bit) + * - Depth (32 bit) + * - Width (32 bit) + * - Hash functions (depth * 64 bit) + * - Count table + * - Row 0 (width * 64 bit) + * - Row 1 (width * 64 bit) + * - ... + * - Row depth - 1 (width * 64 bit) + */ V1(1); private final int versionNumber; @@ -67,13 +78,11 @@ public enum Version { this.versionNumber = versionNumber; } - public int getVersionNumber() { + int getVersionNumber() { return versionNumber; } } - public abstract Version version(); - /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. */ @@ -128,13 +137,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) /** * Writes out this {@link CountMinSketch} to an output stream in binary format. - * It is the caller's responsibility to close the stream + * It is the caller's responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. - * It is the caller's responsibility to close the stream + * It is the caller's responsibility to close the stream. */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 0209446ea3b1d..d08809605a932 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -26,21 +26,6 @@ import java.util.Arrays; import java.util.Random; -/* - * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian - * order): - * - * - Version number, always 1 (32 bit) - * - Total count of added items (64 bit) - * - Depth (32 bit) - * - Width (32 bit) - * - Hash functions (depth * 64 bit) - * - Count table - * - Row 0 (width * 64 bit) - * - Row 1 (width * 64 bit) - * - ... - * - Row depth - 1 (width * 64 bit) - */ class CountMinSketchImpl extends CountMinSketch { public static final long PRIME_MODULUS = (1L << 31) - 1; @@ -112,11 +97,6 @@ public int hashCode() { return hash; } - @Override - public Version version() { - return Version.V1; - } - private void initTablesWith(int depth, int width, int seed) { this.table = new long[depth][width]; this.hashA = new long[depth]; @@ -327,7 +307,7 @@ public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMerg public void writeTo(OutputStream out) throws IOException { DataOutputStream dos = new DataOutputStream(out); - dos.writeInt(version().getVersionNumber()); + dos.writeInt(Version.V1.getVersionNumber()); dos.writeLong(this.totalCount); dos.writeInt(this.depth); diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala index d2de509f19517..a0408d2da4dff 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.sketch +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import scala.reflect.ClassTag import scala.util.Random @@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite private final val EPSILON = 0.01 + // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(filter: BloomFilter): Unit = { + val out = new ByteArrayOutputStream() + filter.writeTo(out) + out.close() + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = BloomFilter.readFrom(in) + in.close() + + assert(filter == deserialized) + } + def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { test(s"accuracy - $typeName") { // use a fixed seed to make the test predictable. @@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite // Also check the actual fpp is not significantly higher than we expected. val actualFpp = errorCount.toDouble / (numItems - numInsertion) assert(actualFpp - fpp < EPSILON) + + checkSerDe(filter) } } @@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite items1.foreach(i => assert(filter1.mightContain(i))) items2.foreach(i => assert(filter1.mightContain(i))) + + checkSerDe(filter1) } } From 5936bf9fa85ccf7f0216145356140161c2801682 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Jan 2016 11:36:00 +0000 Subject: [PATCH 1693/2089] [SPARK-12961][CORE] Prevent snappy-java memory leak JIRA: https://issues.apache.org/jira/browse/SPARK-12961 To prevent memory leak in snappy-java, just call the method once and cache the result. After the library releases new version, we can remove this object. JoshRosen Author: Liang-Chi Hsieh Closes #10875 from viirya/prevent-snappy-memory-leak. --- .../apache/spark/io/CompressionCodec.scala | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 717804626f852..ae014becef755 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -149,12 +149,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { - - try { - Snappy.getNativeLibraryVersion - } catch { - case e: Error => throw new IllegalArgumentException(e) - } + val version = SnappyCompressionCodec.version override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt @@ -164,6 +159,19 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } +/** + * Object guards against memory leak bug in snappy-java library: + * (https://github.com/xerial/snappy-java/issues/131). + * Before a new version of the library, we only call the method once and cache the result. + */ +private final object SnappyCompressionCodec { + private lazy val version: String = try { + Snappy.getNativeLibraryVersion + } catch { + case e: Error => throw new IllegalArgumentException(e) + } +} + /** * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version From 649e9d0f5b2d5fc13f2dd5be675331510525927f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 26 Jan 2016 11:55:28 +0000 Subject: [PATCH 1694/2089] [SPARK-3369][CORE][STREAMING] Java mapPartitions Iterator->Iterable is inconsistent with Scala's Iterator->Iterator Fix Java function API methods for flatMap and mapPartitions to require producing only an Iterator, not Iterable. Also fix DStream.flatMap to require a function producing TraversableOnce only, not Traversable. CC rxin pwendell for API change; tdas since it also touches streaming. Author: Sean Owen Closes #10413 from srowen/SPARK-3369. --- .../api/java/function/CoGroupFunction.java | 2 +- .../java/function/DoubleFlatMapFunction.java | 3 +- .../api/java/function/FlatMapFunction.java | 3 +- .../api/java/function/FlatMapFunction2.java | 3 +- .../java/function/FlatMapGroupsFunction.java | 2 +- .../java/function/MapPartitionsFunction.java | 2 +- .../java/function/PairFlatMapFunction.java | 3 +- .../apache/spark/api/java/JavaRDDLike.scala | 20 ++++++------ .../java/org/apache/spark/JavaAPISuite.java | 24 +++++++------- docs/streaming-programming-guide.md | 4 +-- .../apache/spark/examples/JavaPageRank.java | 18 +++++------ .../apache/spark/examples/JavaWordCount.java | 5 +-- .../streaming/JavaActorWordCount.java | 5 +-- .../streaming/JavaCustomReceiver.java | 7 +++-- .../streaming/JavaDirectKafkaWordCount.java | 6 ++-- .../streaming/JavaKafkaWordCount.java | 9 +++--- .../streaming/JavaNetworkWordCount.java | 11 ++++--- .../JavaRecoverableNetworkWordCount.java | 6 ++-- .../streaming/JavaSqlNetworkWordCount.java | 8 ++--- .../JavaStatefulNetworkWordCount.java | 5 +-- .../JavaTwitterHashTagJoinSentiments.java | 5 +-- .../java/org/apache/spark/Java8APISuite.java | 2 +- .../streaming/JavaKinesisWordCountASL.java | 9 ++++-- project/MimaExcludes.scala | 31 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 4 +-- .../apache/spark/sql/JavaDatasetSuite.java | 25 +++++++-------- .../streaming/api/java/JavaDStreamLike.scala | 10 +++--- .../spark/streaming/dstream/DStream.scala | 2 +- .../streaming/dstream/FlatMappedDStream.scala | 2 +- .../apache/spark/streaming/JavaAPISuite.java | 20 ++++++------ 30 files changed, 146 insertions(+), 110 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java index 279639af5d430..07aebb75e8f4e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -25,5 +25,5 @@ * Datasets. */ public interface CoGroupFunction extends Serializable { - Iterable call(K key, Iterator left, Iterator right) throws Exception; + Iterator call(K key, Iterator left, Iterator right) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 57fd0a7a80494..576087b6f428e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that returns zero or more records of type Double from each input record. */ public interface DoubleFlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index ef0d1824121ec..2d8ea6d1a5a7e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - Iterable call(T t) throws Exception; + Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index 14a98a38ef5ab..fc97b63f825d0 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2 extends Serializable { - Iterable call(T1 t1, T2 t2) throws Exception; + Iterator call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index d7a80e7b129b0..bae574ab5755d 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -24,5 +24,5 @@ * A function that returns zero or more output records from each grouping key and its values. */ public interface FlatMapGroupsFunction extends Serializable { - Iterable call(K key, Iterator values) throws Exception; + Iterator call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java index 6cb569ce0cb6b..cf9945a215aff 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -24,5 +24,5 @@ * Base interface for function used in Dataset's mapPartitions. */ public interface MapPartitionsFunction extends Serializable { - Iterable call(Iterator input) throws Exception; + Iterator call(Iterator input) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java index 691ef2eceb1f6..51eed2e67b9fa 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -18,6 +18,7 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; import scala.Tuple2; @@ -26,5 +27,5 @@ * key-value pairs are represented as scala.Tuple2 objects. */ public interface PairFlatMapFunction extends Serializable { - public Iterable> call(T t) throws Exception; + Iterator> call(T t) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0f8d13cf5cc2f..7340defabfe58 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -121,7 +121,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { - def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -130,7 +130,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { - def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[jl.Double] = (x: T) => f.call(x).asScala new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } @@ -139,7 +139,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -149,7 +149,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -160,7 +160,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) @@ -171,7 +171,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } @@ -182,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -193,7 +193,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) @@ -205,7 +205,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) @@ -290,7 +290,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).asScala } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 44d5cac7c2de5..8117ad9e60641 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -880,8 +880,8 @@ public void flatMap() { "The quick brown fox jumps over the lazy dog.")); JavaRDD words = rdd.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(x.split(" ")); + public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); } }); Assert.assertEquals("Hello", words.first()); @@ -890,12 +890,12 @@ public Iterable call(String x) { JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String s) { + public Iterator> call(String s) { List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { pairs.add(new Tuple2<>(word, word)); } - return pairs; + return pairs.iterator(); } } ); @@ -904,12 +904,12 @@ public Iterable> call(String s) { JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override - public Iterable call(String s) { + public Iterator call(String s) { List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } - return lengths; + return lengths.iterator(); } }); Assert.assertEquals(5.0, doubles.first(), 0.01); @@ -930,8 +930,8 @@ public void mapsFromPairsToPairs() { JavaPairRDD swapped = pairRDD.flatMapToPair( new PairFlatMapFunction, String, Integer>() { @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); + public Iterator> call(Tuple2 item) { + return Collections.singletonList(item.swap()).iterator(); } }); swapped.collect(); @@ -951,12 +951,12 @@ public void mapPartitions() { JavaRDD partitionSums = rdd.mapPartitions( new FlatMapFunction, Integer>() { @Override - public Iterable call(Iterator iter) { + public Iterator call(Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); } - return Collections.singletonList(sum); + return Collections.singletonList(sum).iterator(); } }); Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); @@ -1367,8 +1367,8 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = new FlatMapFunction2, Iterator, Integer>() { @Override - public Iterable call(Iterator i, Iterator s) { - return Arrays.asList(Iterators.size(i), Iterators.size(s)); + public Iterator call(Iterator i, Iterator s) { + return Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); } }; diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 93c34efb6662d..7e681b67cf0c2 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -165,8 +165,8 @@ space into words. // Split each line into words JavaDStream words = lines.flatMap( new FlatMapFunction() { - @Override public Iterable call(String x) { - return Arrays.asList(x.split(" ")); + @Override public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); } }); {% endhighlight %} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index a5db8accdf138..635fb6a373c47 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -17,7 +17,10 @@ package org.apache.spark.examples; - +import java.util.ArrayList; +import java.util.List; +import java.util.Iterator; +import java.util.regex.Pattern; import scala.Tuple2; @@ -32,11 +35,6 @@ import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; -import java.util.ArrayList; -import java.util.List; -import java.util.Iterator; -import java.util.regex.Pattern; - /** * Computes the PageRank of URLs from an input file. Input file should * be in format of: @@ -108,13 +106,13 @@ public Double call(Iterable rs) { JavaPairRDD contribs = links.join(ranks).values() .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { @Override - public Iterable> call(Tuple2, Double> s) { + public Iterator> call(Tuple2, Double> s) { int urlCount = Iterables.size(s._1); - List> results = new ArrayList>(); + List> results = new ArrayList<>(); for (String n : s._1) { - results.add(new Tuple2(n, s._2() / urlCount)); + results.add(new Tuple2<>(n, s._2() / urlCount)); } - return results; + return results.iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 9a6a944f7edef..d746a3d2b6773 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -27,6 +27,7 @@ import org.apache.spark.api.java.function.PairFunction; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -46,8 +47,8 @@ public static void main(String[] args) throws Exception { JavaRDD words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) { - return Arrays.asList(SPACE.split(s)); + public Iterator call(String s) { + return Arrays.asList(SPACE.split(s)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java index 62e563380a9e7..cf774667f6c5f 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; +import java.util.Iterator; import scala.Tuple2; @@ -120,8 +121,8 @@ public static void main(String[] args) { // compute wordcount lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) { - return Arrays.asList(s.split("\\s+")); + public Iterator call(String s) { + return Arrays.asList(s.split("\\s+")).iterator(); } }).mapToPair(new PairFunction() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 4b50fbf59f80e..3d668adcf815f 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -17,7 +17,6 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; import com.google.common.io.Closeables; import org.apache.spark.SparkConf; @@ -37,6 +36,8 @@ import java.io.InputStreamReader; import java.net.ConnectException; import java.net.Socket; +import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; /** @@ -74,8 +75,8 @@ public static void main(String[] args) { new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index f9a5e7f69ffe1..5107500a127c5 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -20,11 +20,11 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.collect.Lists; import kafka.serializer.StringDecoder; import org.apache.spark.SparkConf; @@ -87,8 +87,8 @@ public String call(Tuple2 tuple2) { }); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 337f8ffb5bfb0..0df4cb40a9a76 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -17,20 +17,19 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; import java.util.Map; import java.util.HashMap; import java.util.regex.Pattern; - import scala.Tuple2; -import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -88,8 +87,8 @@ public String call(Tuple2 tuple2) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index 3e9f0f4b8f127..b82b319acb735 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -17,8 +17,11 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; +import java.util.regex.Pattern; + import scala.Tuple2; -import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -31,8 +34,6 @@ import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import java.util.regex.Pattern; - /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. * @@ -67,8 +68,8 @@ public static void main(String[] args) { args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index bc963a02be608..bc8cbcdef7272 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -21,11 +21,11 @@ import java.io.IOException; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.collect.Lists; import com.google.common.io.Files; import org.apache.spark.Accumulator; @@ -138,8 +138,8 @@ private static JavaStreamingContext createContext(String ip, JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 084f68a8be437..f0228f5e63450 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -17,10 +17,10 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; -import com.google.common.collect.Lists; - import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; @@ -72,8 +72,8 @@ public static void main(String[] args) { args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index f52cc7c20576b..6beab90f086d8 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -73,8 +74,8 @@ public static void main(String[] args) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java index d869768026ae3..f0ae9a99bae47 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java @@ -34,6 +34,7 @@ import twitter4j.Status; import java.util.Arrays; +import java.util.Iterator; import java.util.List; /** @@ -70,8 +71,8 @@ public static void main(String[] args) { JavaDStream words = stream.flatMap(new FlatMapFunction() { @Override - public Iterable call(Status s) { - return Arrays.asList(s.getText().split(" ")); + public Iterator call(Status s) { + return Arrays.asList(s.getText().split(" ")).iterator(); } }); diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 27d494ce355f7..c0b58e713f642 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -294,7 +294,7 @@ public void zipPartitions() { sizeS += 1; s.next(); } - return Arrays.asList(sizeI, sizeS); + return Arrays.asList(sizeI, sizeS).iterator(); }; JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 06e0ff28afd95..64e044aa8e4a4 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -16,7 +16,10 @@ */ package org.apache.spark.examples.streaming; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -38,7 +41,6 @@ import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.google.common.collect.Lists; /** * Consumes messages from a Amazon Kinesis streams and does wordcount. @@ -154,8 +156,9 @@ public static void main(String[] args) { // Convert each line of Array[Byte] to String, and split into words JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + public Iterator call(byte[] line) { + String s = new String(line, StandardCharsets.UTF_8); + return Arrays.asList(WORD_SEPARATOR.split(s)).iterator(); } }); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 501456b043170..643bee69694df 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -60,6 +60,37 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") ) ++ + Seq( + // SPARK-3369 Fix Iterable/Iterator in Java API + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapGroupsFunction.call") + ) ++ Seq( // SPARK-4819 replace Guava Optional ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bd99c399571c8..f182270a08729 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -346,7 +346,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) } @@ -366,7 +366,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - val func: (T) => Iterable[U] = x => f.call(x).asScala + val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3c0f25a5dc535..a6fb62c17d59b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -111,24 +111,24 @@ public Integer call(String v) throws Exception { Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { @Override - public Iterable call(Iterator it) throws Exception { - List ls = new LinkedList(); + public Iterator call(Iterator it) { + List ls = new LinkedList<>(); while (it.hasNext()) { - ls.add(it.next().toUpperCase()); + ls.add(it.next().toUpperCase(Locale.ENGLISH)); } - return ls; + return ls.iterator(); } }, Encoders.STRING()); Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); Dataset flatMapped = ds.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) throws Exception { - List ls = new LinkedList(); + public Iterator call(String s) { + List ls = new LinkedList<>(); for (char c : s.toCharArray()) { ls.add(String.valueOf(c)); } - return ls; + return ls.iterator(); } }, Encoders.STRING()); Assert.assertEquals( @@ -192,12 +192,12 @@ public String call(Integer key, Iterator values) throws Exception { Dataset flatMapped = grouped.flatMapGroups( new FlatMapGroupsFunction() { @Override - public Iterable call(Integer key, Iterator values) throws Exception { + public Iterator call(Integer key, Iterator values) { StringBuilder sb = new StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); } - return Collections.singletonList(sb.toString()); + return Collections.singletonList(sb.toString()).iterator(); } }, Encoders.STRING()); @@ -228,10 +228,7 @@ public Integer call(Integer v) throws Exception { grouped2, new CoGroupFunction() { @Override - public Iterable call( - Integer key, - Iterator left, - Iterator right) throws Exception { + public Iterator call(Integer key, Iterator left, Iterator right) { StringBuilder sb = new StringBuilder(key.toString()); while (left.hasNext()) { sb.append(left.next()); @@ -240,7 +237,7 @@ public Iterable call( while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()); + return Collections.singletonList(sb.toString()).iterator(); } }, Encoders.STRING()); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a791a474c673d..f10de485d0f75 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -166,8 +166,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * and then flattening the results */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { - import scala.collection.JavaConverters._ - def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -176,8 +175,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * and then flattening the results */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { - import scala.collection.JavaConverters._ - def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -189,7 +187,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -202,7 +200,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1dfb4e7abc0ed..db79eeab9c0cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -550,7 +550,7 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope { + def flatMap[U: ClassTag](flatMapFunc: T => TraversableOnce[U]): DStream[U] = ssc.withScope { new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 96a444a7baa5e..d60a6179782e0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -25,7 +25,7 @@ import org.apache.spark.streaming.{Duration, Time} private[streaming] class FlatMappedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], - flatMapFunc: T => Traversable[U] + flatMapFunc: T => TraversableOnce[U] ) extends DStream[U](parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 4dbcef293487c..806cea24caddb 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -271,12 +271,12 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override - public Iterable call(Iterator in) { + public Iterator call(Iterator in) { StringBuilder out = new StringBuilder(); while (in.hasNext()) { out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Arrays.asList(out.toString()); + return Arrays.asList(out.toString()).iterator(); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -759,8 +759,8 @@ public void testFlatMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(x.split("(?!^)")); + public Iterator call(String x) { + return Arrays.asList(x.split("(?!^)")).iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -846,12 +846,12 @@ public void testPairFlatMap() { JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) { + public Iterator> call(String in) { List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { out.add(new Tuple2<>(in.length(), letter)); } - return out; + return out.iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -1019,13 +1019,13 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) { + public Iterator> call(Iterator> in) { List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); } - return out; + return out.iterator(); } }); @@ -1089,12 +1089,12 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) { + public Iterator> call(Tuple2 in) { List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { out.add(new Tuple2<>(in._2(), s.toString())); } - return out; + return out.iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); From ae0309a8812a4fade3a0ea67d8986ca870aeb9eb Mon Sep 17 00:00:00 2001 From: zhuol Date: Tue, 26 Jan 2016 09:40:02 -0600 Subject: [PATCH 1695/2089] [SPARK-10911] Executors should System.exit on clean shutdown. Call system.exit explicitly to make sure non-daemon user threads terminate. Without this, user applications might live forever if the cluster manager does not appropriately kill them. E.g., YARN had this bug: HADOOP-12441. Author: zhuol Closes #9946 from zhuoliu/10911. --- .../org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index e3a6c4c07a75e..136cf4a84d387 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -241,6 +241,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) + System.exit(0) } private def printUsageAndExit() = { From 08c781ca672820be9ba32838bbe40d2643c4bde4 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 26 Jan 2016 07:50:37 -0800 Subject: [PATCH 1696/2089] [SPARK-12682][SQL] Add support for (optionally) not storing tables in hive metadata format This PR adds a new table option (`skip_hive_metadata`) that'd allow the user to skip storing the table metadata in hive metadata format. While this could be useful in general, the specific use-case for this change is that Hive doesn't handle wide schemas well (see https://issues.apache.org/jira/browse/SPARK-12682 and https://issues.apache.org/jira/browse/SPARK-6024) which in turn prevents such tables from being queried in SparkSQL. Author: Sameer Agarwal Closes #10826 from sameeragarwal/skip-hive-metadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 ++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0cfe03ba91ec7..80e45d5162801 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -327,7 +327,14 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // TODO: Support persisting partitioned data source relations in Hive compatible format val qualifiedTableName = tableIdent.quotedString + val skipHiveMetadata = options.getOrElse("skipHiveMetadata", "false").toBoolean val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.relation) match { + case _ if skipHiveMetadata => + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 211932fea00ed..d9e4b020fdfcc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -900,4 +900,36 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sqlContext.sql("""use default""") sqlContext.sql("""drop database if exists testdb8156 CASCADE""") } + + test("skip hive metadata on table creation") { + val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) + + catalog.createDataSourceTable( + tableIdent = TableIdentifier("not_skip_hive_metadata"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "parquet", + options = Map("path" -> "just a dummy path", "skipHiveMetadata" -> "false"), + isExternal = false) + + // As a proxy for verifying that the table was stored in Hive compatible format, we verify that + // each column of the table is of native type StringType. + assert(catalog.client.getTable("default", "not_skip_hive_metadata").schema + .forall(column => HiveMetastoreTypes.toDataType(column.hiveType) == StringType)) + + catalog.createDataSourceTable( + tableIdent = TableIdentifier("skip_hive_metadata"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "parquet", + options = Map("path" -> "just a dummy path", "skipHiveMetadata" -> "true"), + isExternal = false) + + // As a proxy for verifying that the table was stored in SparkSQL format, we verify that + // the table has a column type as array of StringType. + assert(catalog.client.getTable("default", "skip_hive_metadata").schema + .forall(column => HiveMetastoreTypes.toDataType(column.hiveType) == ArrayType(StringType))) + } } From cbd507d69cea24adfb335d8fe26ab5a13c053ffc Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 26 Jan 2016 11:31:54 -0800 Subject: [PATCH 1697/2089] [SPARK-7799][STREAMING][DOCUMENT] Add the linking and deploying instructions for streaming-akka project Since `actorStream` is an external project, we should add the linking and deploying instructions for it. A follow up PR of #10744 Author: Shixiong Zhu Closes #10856 from zsxwing/akka-link-instruction. --- docs/streaming-custom-receivers.md | 81 ++++++++++++++++-------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 95b99862ec062..84547748618d1 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -257,54 +257,61 @@ The following table summarizes the characteristics of both types of receivers ## Implementing and Using a Custom Actor-based Receiver -

    -
    - Custom [Akka Actors](http://doc.akka.io/docs/akka/2.3.11/scala/actors.html) can also be used to -receive data. Extending [`ActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.ActorReceiver) -allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of -this actor can be configured to handle failures, etc. +receive data. Here are the instructions. -{% highlight scala %} +1. **Linking:** You need to add the following dependency to your SBT or Maven project (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). -class CustomActor extends ActorReceiver { - def receive = { - case data: String => store(data) - } -} + groupId = org.apache.spark + artifactId = spark-streaming-akka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} -// A new input stream can be created with this custom actor as -val ssc: StreamingContext = ... -val lines = AkkaUtils.createStream[String](ssc, Props[CustomActor](), "CustomReceiver") +2. **Programming:** -{% endhighlight %} +
    +
    -See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. -
    -
    + You need to extend [`ActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.ActorReceiver) + so as to store received data into Spark using `store(...)` methods. The supervisor strategy of + this actor can be configured to handle failures, etc. -Custom [Akka UntypedActors](http://doc.akka.io/docs/akka/2.3.11/java/untyped-actors.html) can also be used to -receive data. Extending [`JavaActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.JavaActorReceiver) -allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of -this actor can be configured to handle failures, etc. + class CustomActor extends ActorReceiver { + def receive = { + case data: String => store(data) + } + } -{% highlight java %} + // A new input stream can be created with this custom actor as + val ssc: StreamingContext = ... + val lines = AkkaUtils.createStream[String](ssc, Props[CustomActor](), "CustomReceiver") -class CustomActor extends JavaActorReceiver { - @Override - public void onReceive(Object msg) throws Exception { - store((String) msg); - } -} + See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. +
    +
    -// A new input stream can be created with this custom actor as -JavaStreamingContext jssc = ...; -JavaDStream lines = AkkaUtils.createStream(jssc, Props.create(CustomActor.class), "CustomReceiver"); + You need to extend [`JavaActorReceiver`](api/scala/index.html#org.apache.spark.streaming.akka.JavaActorReceiver) + so as to store received data into Spark using `store(...)` methods. The supervisor strategy of + this actor can be configured to handle failures, etc. -{% endhighlight %} + class CustomActor extends JavaActorReceiver { + @Override + public void onReceive(Object msg) throws Exception { + store((String) msg); + } + } -See [JavaActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaActorWordCount.scala) for an end-to-end example. -
    -
    + // A new input stream can be created with this custom actor as + JavaStreamingContext jssc = ...; + JavaDStream lines = AkkaUtils.createStream(jssc, Props.create(CustomActor.class), "CustomReceiver"); + + See [JavaActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaActorWordCount.scala) for an end-to-end example. +
    +
    + +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. +You need to package `spark-streaming-akka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into +the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` +are marked as `provided` dependencies as those are already present in a Spark installation. Then +use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). Python API Since actors are available only in the Java and Scala libraries, AkkaUtils is not available in the Python API. From 8beab68152348c44cf2f89850f792f164b06470d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 11:56:46 -0800 Subject: [PATCH 1698/2089] [SPARK-11923][ML] Python API for ml.feature.ChiSqSelector https://issues.apache.org/jira/browse/SPARK-11923 Author: Xusen Yin Closes #10186 from yinxusen/SPARK-11923. --- python/pyspark/ml/feature.py | 98 +++++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f139d81bc490d..32f324685a9cf 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -33,7 +33,8 @@ 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', + 'ChiSqSelector', 'ChiSqSelectorModel'] @inherit_doc @@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel): """ +@inherit_doc +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol): + """ + .. note:: Experimental + + Chi-Squared feature selection, which selects categorical features to use for predicting a + categorical label. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + ... (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + ... (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)], + ... ["features", "label"]) + >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") + >>> model = selector.fit(df) + >>> model.transform(df).head().selectedFeatures + DenseVector([1.0]) + >>> model.selectedFeatures + [3] + + .. versionadded:: 2.0.0 + """ + + # a placeholder to make it appear in the generated doc + numTopFeatures = \ + Param(Params._dummy(), "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. If the number of features is < numTopFeatures, then this will select " + + "all features.") + + @keyword_only + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + """ + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + """ + super(ChiSqSelector, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) + self.numTopFeatures = \ + Param(self, "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. If the number of features is < numTopFeatures, then this will " + + "select all features.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="labels"): + """ + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ + labelCol="labels") + Sets params for this ChiSqSelector. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumTopFeatures(self, value): + """ + Sets the value of :py:attr:`numTopFeatures`. + """ + self._paramMap[self.numTopFeatures] = value + return self + + @since("2.0.0") + def getNumTopFeatures(self): + """ + Gets the value of numTopFeatures or its default value. + """ + return self.getOrDefault(self.numTopFeatures) + + def _create_model(self, java_model): + return ChiSqSelectorModel(java_model) + + +class ChiSqSelectorModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by ChiSqSelector. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def selectedFeatures(self): + """ + List of indices to select (filter). Must be ordered asc. + """ + return self._call_java("selectedFeatures") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From fbf7623d49525e3aa6b08f482afd7ee8118d80cb Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 13:18:01 -0800 Subject: [PATCH 1699/2089] [SPARK-12952] EMLDAOptimizer initialize() should return EMLDAOptimizer other than its parent class https://issues.apache.org/jira/browse/SPARK-12952 Author: Xusen Yin Closes #10863 from yinxusen/SPARK-12952. --- .../org/apache/spark/mllib/clustering/LDAOptimizer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index c19595e6cd21b..7a41f74191536 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -95,7 +95,9 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * Compute bipartite term/doc graph. */ - override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { + override private[clustering] def initialize( + docs: RDD[(Long, Vector)], + lda: LDA): EMLDAOptimizer = { // EMLDAOptimizer currently only supports symmetric document-topic priors val docConcentration = lda.getDocConcentration From ee74498de372b16fe6350e3617e9e6ec87c6ae7b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 26 Jan 2016 14:20:11 -0800 Subject: [PATCH 1700/2089] [SPARK-8725][PROJECT-INFRA] Test modules in topologically-sorted order in dev/run-tests This patch improves our `dev/run-tests` script to test modules in a topologically-sorted order based on modules' dependencies. This will help to ensure that bugs in upstream projects are not misattributed to downstream projects because those projects' tests were the first ones to exhibit the failure Topological sorting is also useful for shortening the feedback loop when testing pull requests: if I make a change in SQL then the SQL tests should run before MLlib, not after. In addition, this patch also updates our test module definitions to split `sql` into `catalyst`, `sql`, and `hive` in order to allow more tests to be skipped when changing only `hive/` files. Author: Josh Rosen Closes #10885 from JoshRosen/SPARK-8725. --- NOTICE | 16 ++++++ dev/run-tests.py | 25 ++++++---- dev/sparktestsupport/modules.py | 54 +++++++++++++++++--- dev/sparktestsupport/toposort.py | 85 ++++++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 18 deletions(-) create mode 100644 dev/sparktestsupport/toposort.py diff --git a/NOTICE b/NOTICE index e416aadce9911..6a26155fb4952 100644 --- a/NOTICE +++ b/NOTICE @@ -650,3 +650,19 @@ For CSV functionality: */ +=============================================================================== +For dev/sparktestsupport/toposort.py: + +Copyright 2014 True Blade Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/dev/run-tests.py b/dev/run-tests.py index 8f47728f206c3..c78a66f6aa54e 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -29,6 +29,7 @@ from sparktestsupport import SPARK_HOME, USER_HOME, ERROR_CODES from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +from sparktestsupport.toposort import toposort_flatten, toposort import sparktestsupport.modules as modules @@ -43,7 +44,7 @@ def determine_modules_for_files(filenames): If a file is not associated with a more specific submodule, then this method will consider that file to belong to the 'root' module. - >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) + >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"])) ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] @@ -99,14 +100,16 @@ def determine_modules_to_test(changed_modules): Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. - >>> sorted(x.name for x in determine_modules_to_test([modules.root])) + Returns a topologically-sorted list of modules (ties are broken by sorting on module names). + + >>> [x.name for x in determine_modules_to_test([modules.root])] ['root'] - >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) - ['examples', 'graphx'] - >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> [x.name for x in determine_modules_to_test([modules.graphx])] + ['graphx', 'examples'] + >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] + ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', + 'pyspark-mllib', 'pyspark-ml'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be @@ -116,7 +119,9 @@ def determine_modules_to_test(changed_modules): modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) - return modules_to_test.union(set(changed_modules)) + modules_to_test = modules_to_test.union(set(changed_modules)) + return toposort_flatten( + {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True) def determine_tags_to_exclude(changed_modules): @@ -377,12 +382,12 @@ def run_scala_tests_maven(test_profiles): def run_scala_tests_sbt(test_modules, test_profiles): - sbt_test_goals = set(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) + sbt_test_goals = list(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) if not sbt_test_goals: return - profiles_and_goals = test_profiles + list(sbt_test_goals) + profiles_and_goals = test_profiles + sbt_test_goals print("[info] Running Spark tests using SBT with these arguments: ", " ".join(profiles_and_goals)) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 032c0616edb1e..07c3078e45490 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -15,12 +15,14 @@ # limitations under the License. # +from functools import total_ordering import itertools import re all_modules = [] +@total_ordering class Module(object): """ A module is the basic abstraction in our test runner script. Each module consists of a set of @@ -75,20 +77,56 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= def contains_file(self, filename): return any(re.match(p, filename) for p in self.source_file_prefixes) + def __repr__(self): + return "Module<%s>" % self.name + + def __lt__(self, other): + return self.name < other.name + + def __eq__(self, other): + return self.name == other.name + + def __ne__(self, other): + return not (self.name == other.name) + + def __hash__(self): + return hash(self.name) + + +catalyst = Module( + name="catalyst", + dependencies=[], + source_file_regexes=[ + "sql/catalyst/", + ], + sbt_test_goals=[ + "catalyst/test", + ], +) + sql = Module( name="sql", - dependencies=[], + dependencies=[catalyst], source_file_regexes=[ - "sql/(?!hive-thriftserver)", + "sql/core/", + ], + sbt_test_goals=[ + "sql/test", + ], +) + +hive = Module( + name="hive", + dependencies=[sql], + source_file_regexes=[ + "sql/hive/", "bin/spark-sql", ], build_profile_flags=[ "-Phive", ], sbt_test_goals=[ - "catalyst/test", - "sql/test", "hive/test", ], test_tags=[ @@ -99,7 +137,7 @@ def contains_file(self, filename): hive_thriftserver = Module( name="hive-thriftserver", - dependencies=[sql], + dependencies=[hive], source_file_regexes=[ "sql/hive-thriftserver", "sbin/start-thriftserver.sh", @@ -282,7 +320,7 @@ def contains_file(self, filename): examples = Module( name="examples", - dependencies=[graphx, mllib, streaming, sql], + dependencies=[graphx, mllib, streaming, hive], source_file_regexes=[ "examples/", ], @@ -314,7 +352,7 @@ def contains_file(self, filename): pyspark_sql = Module( name="pyspark-sql", - dependencies=[pyspark_core, sql], + dependencies=[pyspark_core, hive], source_file_regexes=[ "python/pyspark/sql" ], @@ -404,7 +442,7 @@ def contains_file(self, filename): sparkr = Module( name="sparkr", - dependencies=[sql, mllib], + dependencies=[hive, mllib], source_file_regexes=[ "R/", ], diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py new file mode 100644 index 0000000000000..6c67b4504bc3b --- /dev/null +++ b/dev/sparktestsupport/toposort.py @@ -0,0 +1,85 @@ +####################################################################### +# Implements a topological sort algorithm. +# +# Copyright 2014 True Blade Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Notes: +# Based on http://code.activestate.com/recipes/578272-topological-sort +# with these major changes: +# Added unittests. +# Deleted doctests (maybe not the best idea in the world, but it cleans +# up the docstring). +# Moved functools import to the top of the file. +# Changed assert to a ValueError. +# Changed iter[items|keys] to [items|keys], for python 3 +# compatibility. I don't think it matters for python 2 these are +# now lists instead of iterables. +# Copy the input so as to leave it unmodified. +# Renamed function from toposort2 to toposort. +# Handle empty input. +# Switch tests to use set literals. +# +######################################################################## + +from functools import reduce as _reduce + + +__all__ = ['toposort', 'toposort_flatten'] + + +def toposort(data): + """Dependencies are expressed as a dictionary whose keys are items +and whose values are a set of dependent items. Output is a list of +sets in topological order. The first set consists of items with no +dependences, each subsequent set consists of items that depend upon +items in the preceeding sets. +""" + + # Special case empty input. + if len(data) == 0: + return + + # Copy the input so as to leave it unmodified. + data = data.copy() + + # Ignore self dependencies. + for k, v in data.items(): + v.discard(k) + # Find all items that don't depend on anything. + extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) + # Add empty dependences where needed. + data.update({item: set() for item in extra_items_in_deps}) + while True: + ordered = set(item for item, dep in data.items() if len(dep) == 0) + if not ordered: + break + yield ordered + data = {item: (dep - ordered) + for item, dep in data.items() + if item not in ordered} + if len(data) != 0: + raise ValueError('Cyclic dependencies exist among these items: {}'.format( + ', '.join(repr(x) for x in data.items()))) + + +def toposort_flatten(data, sort=True): + """Returns a single list of dependencies. For any set returned by +toposort(), those items are sorted and appended to the result (just to +make the results deterministic).""" + + result = [] + for d in toposort(data): + result.extend((sorted if sort else list)(d)) + return result From 83507fea9f45c336d73dd4795b8cb37bcd63e31d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 14:29:29 -0800 Subject: [PATCH 1701/2089] [SQL] Minor Scaladoc format fix Otherwise the `^` character is always marked as error in IntelliJ since it represents an unclosed superscript markup tag. Author: Cheng Lian Closes #10926 from liancheng/agg-doc-fix. --- .../sql/catalyst/expressions/aggregate/interfaces.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ddd99c51ab0c3..561fa3321d8fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -200,7 +200,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)` * will be 0 and the position of the first buffer value of `avg(y)` will be 2: - * + * {{{ * avg(x) mutableAggBufferOffset = 0 * | * v @@ -210,7 +210,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * ^ * | * avg(y) mutableAggBufferOffset = 2 - * + * }}} */ protected val mutableAggBufferOffset: Int @@ -233,7 +233,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` * will be 3 (position 0 is used for the value of `key`): - * + * {{{ * avg(x) inputAggBufferOffset = 1 * | * v @@ -243,7 +243,7 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * ^ * | * avg(y) inputAggBufferOffset = 3 - * + * }}} */ protected val inputAggBufferOffset: Int From 19fdb21afbf0eae4483cf6d4ef32daffd1994b89 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 26 Jan 2016 14:58:39 -0800 Subject: [PATCH 1702/2089] [SPARK-12993][PYSPARK] Remove usage of ADD_FILES in pyspark environment variable ADD_FILES is created for adding python files on spark context to be distributed to executors (SPARK-865), this is deprecated now. User are encouraged to use --py-files for adding python files. Author: Jeff Zhang Closes #10913 from zjffdu/SPARK-12993. --- python/pyspark/shell.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 26cafca8b8381..7c37f75193473 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -32,15 +32,10 @@ from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the deprecated equivalent of ADD_JARS -add_files = None -if os.environ.get("ADD_FILES") is not None: - add_files = os.environ.get("ADD_FILES").split(',') - if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext(pyFiles=add_files) +sc = SparkContext() atexit.register(lambda: sc.stop()) try: @@ -68,10 +63,6 @@ platform.python_build()[1])) print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) -if add_files is not None: - print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") - print("Adding files: [%s]" % ", ".join(add_files)) - # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') From eb917291ca1a2d68ca0639cb4b1464a546603eba Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 26 Jan 2016 15:53:48 -0800 Subject: [PATCH 1703/2089] [SPARK-10509][PYSPARK] Reduce excessive param boiler plate code The current python ml params require cut-and-pasting the param setup and description between the class & ```__init__``` methods. Remove this possible case of errors & simplify use of custom params by adding a ```_copy_new_parent``` method to param so as to avoid cut and pasting (and cut and pasting at different indentation levels urgh). Author: Holden Karau Closes #10216 from holdenk/SPARK-10509-excessive-param-boiler-plate-code. --- python/pyspark/ml/classification.py | 32 ------ python/pyspark/ml/clustering.py | 7 -- python/pyspark/ml/evaluation.py | 12 --- python/pyspark/ml/feature.py | 98 +------------------ python/pyspark/ml/param/__init__.py | 22 +++++ .../ml/param/_shared_params_code_gen.py | 17 +--- python/pyspark/ml/param/shared.py | 81 +-------------- python/pyspark/ml/pipeline.py | 4 +- python/pyspark/ml/recommendation.py | 11 --- python/pyspark/ml/regression.py | 46 --------- python/pyspark/ml/tests.py | 12 +++ python/pyspark/ml/tuning.py | 18 ---- 12 files changed, 43 insertions(+), 317 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 265c6a14f1ca4..3179fb30ab4d7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -72,7 +72,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti .. versionadded:: 1.3.0 """ - # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") @@ -92,10 +91,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for threshold in binary classification, in range [0, 1]. - self.threshold = Param(self, "threshold", - "Threshold in binary classification prediction, in range [0, 1]." + - " If threshold and thresholds are both set, they must match.") self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -232,7 +227,6 @@ class TreeClassifierParams(object): """ supportedImpurities = ["entropy", "gini"] - # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + @@ -240,10 +234,6 @@ class TreeClassifierParams(object): def __init__(self): super(TreeClassifierParams, self).__init__() - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = Param(self, "impurity", "Criterion used for information " + - "gain calculation (case-insensitive). Supported options: " + - ", ".join(self.supportedImpurities)) @since("1.6.0") def setImpurity(self, value): @@ -485,7 +475,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) @@ -504,10 +493,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) - #: param for Loss function which GBT tries to minimize (case-insensitive). - self.lossType = Param(self, "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1) @@ -597,7 +582,6 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + "default is 1.0") modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + @@ -615,13 +599,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid) - #: param for the smoothing parameter. - self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + - "default is 1.0") - #: param for the model type. - self.modelType = Param(self, "modelType", "The model type which is a string " + - "(case-sensitive). Supported options: multinomial (default) " + - "and bernoulli.") self._setDefault(smoothing=1.0, modelType="multinomial") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -734,7 +711,6 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + "neurons and output layer of 10 neurons, default is [1, 1].") @@ -753,14 +729,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) - self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + - "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + - "100 neurons and output layer of 10 neurons, default is [1, 1].") - self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + - "matrices. Data is stacked within partitions. If block size is " + - "more than remaining data in a partition then it is adjusted to " + - "the size of this data. Recommended size is between 10 and 1000, " + - "default is 128.") self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 9189c02220228..60d1c9aaec988 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -73,7 +73,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + @@ -90,12 +89,6 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self.k = Param(self, "k", "number of clusters to create") - self.initMode = Param(self, "initMode", - "the initialization algorithm. This can be either \"random\" to " + - "choose random points as initial cluster centers, or \"k-means||\" " + - "to use a parallel variant of k-means++") - self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 6ff68abd8f18f..c9b95b3bf45d9 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -124,7 +124,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -138,9 +137,6 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", super(BinaryClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) - #: param for metric name in evaluation (areaUnderROC|areaUnderPR) - self.metricName = Param(self, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") kwargs = self.__init__._input_kwargs @@ -210,9 +206,6 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(RegressionEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) - #: param for metric name in evaluation (mse|rmse|r2|mae) - self.metricName = Param(self, "metricName", - "metric name in evaluation (mse|rmse|r2|mae)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="rmse") kwargs = self.__init__._input_kwargs @@ -265,7 +258,6 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation " "(f1|precision|recall|weightedPrecision|weightedRecall)") @@ -280,10 +272,6 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) - # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) - self.metricName = Param(self, "metricName", - "metric name in evaluation" - " (f1|precision|recall|weightedPrecision|weightedRecall)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") kwargs = self.__init__._input_kwargs diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 32f324685a9cf..22081233b04d5 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -57,7 +57,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]") @@ -68,8 +67,6 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): """ super(Binarizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) - self.threshold = Param(self, "threshold", - "threshold in binary classification prediction, in range [0, 1]") self._setDefault(threshold=0.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -125,7 +122,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.3.0 """ - # a placeholder to make it appear in the generated doc splits = \ Param(Params._dummy(), "splits", "Split points for mapping continuous features into buckets. With n+1 splits, " + @@ -142,19 +138,6 @@ def __init__(self, splits=None, inputCol=None, outputCol=None): """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) - #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, - # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) - # except the last bucket, which also includes y. The splits should be strictly increasing. - # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, - # values outside the splits specified will be treated as errors. - self.splits = \ - Param(self, "splits", - "Split points for mapping continuous features into buckets. With n+1 splits, " + - "there are n buckets. A bucket defined by splits x,y holds values in the " + - "range [x,y) except the last bucket, which also includes y. The splits " + - "should be strictly increasing. Values at -inf, inf must be explicitly " + - "provided to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -210,7 +193,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc minTF = Param( Params._dummy(), "minTF", "Filter to ignore rare words in" + " a document. For each document, terms with frequency/count less than the given" + @@ -235,22 +217,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outpu super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self.minTF = Param( - self, "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then " + - "this specifies a fraction (out of the document's token count). Note that the " + - "parameter is only used in transform of CountVectorizerModel and does not affect" + - "fitting. Default 1.0") - self.minDF = Param( - self, "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of " + - "documents. Default 1.0") - self.vocabSize = Param( - self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -359,7 +325,6 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + "default False.") @@ -370,8 +335,6 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): """ super(DCT, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) - self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " + - "default False.") self._setDefault(inverse=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -423,7 +386,6 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + "it must be MLlib Vector type.") @@ -435,8 +397,6 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): super(ElementwiseProduct, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", self.uid) - self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " + - "it must be MLlib Vector type.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -531,7 +491,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc minDocFreq = Param(Params._dummy(), "minDocFreq", "minimum of documents in which a term should appear for filtering") @@ -542,8 +501,6 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): """ super(IDF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) - self.minDocFreq = Param(self, "minDocFreq", - "minimum of documents in which a term should appear for filtering") self._setDefault(minDocFreq=0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -623,7 +580,6 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc min = Param(Params._dummy(), "min", "Lower bound of the output feature range") max = Param(Params._dummy(), "max", "Upper bound of the output feature range") @@ -634,8 +590,6 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): """ super(MinMaxScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) - self.min = Param(self, "min", "Lower bound of the output feature range") - self.max = Param(self, "max", "Upper bound of the output feature range") self._setDefault(min=0.0, max=1.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -745,7 +699,6 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") @keyword_only @@ -755,7 +708,6 @@ def __init__(self, n=2, inputCol=None, outputCol=None): """ super(NGram, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) - self.n = Param(self, "n", "number of elements per n-gram (>=1)") self._setDefault(n=2) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -808,7 +760,6 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc p = Param(Params._dummy(), "p", "the p norm value.") @keyword_only @@ -818,7 +769,6 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): """ super(Normalizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) - self.p = Param(self, "p", "the p norm value.") self._setDefault(p=2.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -887,7 +837,6 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only @@ -897,7 +846,6 @@ def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self.dropLast = Param(self, "dropLast", "whether to drop the last category") self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -950,7 +898,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") @keyword_only @@ -961,7 +908,6 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): super(PolynomialExpansion, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) - self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)") self._setDefault(degree=2) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1107,7 +1053,6 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") @@ -1123,11 +1068,6 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, """ super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) - self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") - self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") - self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") - self.toLowercase = Param(self, "toLowercase", "whether to convert all characters to " + - "lowercase before tokenizing") self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1223,7 +1163,6 @@ class SQLTransformer(JavaTransformer): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc statement = Param(Params._dummy(), "statement", "SQL statement") @keyword_only @@ -1233,7 +1172,6 @@ def __init__(self, statement=None): """ super(SQLTransformer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) - self.statement = Param(self, "statement", "SQL statement") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1285,7 +1223,6 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc withMean = Param(Params._dummy(), "withMean", "Center data with mean") withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") @@ -1296,8 +1233,6 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): """ super(StandardScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) - self.withMean = Param(self, "withMean", "Center data with mean") - self.withStd = Param(self, "withStd", "Scale to unit standard deviation") self._setDefault(withMean=False, withStd=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1453,7 +1388,6 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make the labels show up in generated doc labels = Param(Params._dummy(), "labels", "Optional array of labels specifying index-string mapping." + " If not provided or if empty, then metadata from inputCol is used instead.") @@ -1466,9 +1400,6 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): super(IndexToString, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) - self.labels = Param(self, "labels", - "Optional array of labels specifying index-string mapping. If not" + - " provided or if empty, then metadata from inputCol is used instead.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1507,7 +1438,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make the stopwords show up in generated doc + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words") @@ -1522,9 +1453,6 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) - self.stopWords = Param(self, "stopWords", "The words to be filtered out") - self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + - "sensitive comparison over the stop words") stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords defaultStopWords = stopWordsObj.English() self._setDefault(stopWords=defaultStopWords) @@ -1727,7 +1655,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc maxCategories = Param(Params._dummy(), "maxCategories", "Threshold for the number of values a categorical feature can take " + "(>= 2). If a feature is found to have > maxCategories values, then " + @@ -1740,10 +1667,6 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): """ super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) - self.maxCategories = Param(self, "maxCategories", - "Threshold for the number of values a categorical feature " + - "can take (>= 2). If a feature is found to have " + - "> maxCategories values, then it is declared continuous.") self._setDefault(maxCategories=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1832,7 +1755,6 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + "a vector column. There can be no overlap with names.") names = Param(Params._dummy(), "names", "An array of feature names to select features from " + @@ -1847,12 +1769,6 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): """ super(VectorSlicer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) - self.indices = Param(self, "indices", "An array of indices to select features from " + - "a vector column. There can be no overlap with names.") - self.names = Param(self, "names", "An array of feature names to select features from " + - "a vector column. These names must be specified by ML " + - "org.apache.spark.ml.attribute.Attribute. There can be no overlap " + - "with indices.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1932,7 +1848,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc vectorSize = Param(Params._dummy(), "vectorSize", "the dimension of codes after transforming from words") numPartitions = Param(Params._dummy(), "numPartitions", @@ -1950,13 +1865,6 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, """ super(Word2Vec, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) - self.vectorSize = Param(self, "vectorSize", - "the dimension of codes after transforming from words") - self.numPartitions = Param(self, "numPartitions", - "number of partitions for sentences of words") - self.minCount = Param(self, "minCount", - "the minimum number of times a token must appear to be included " + - "in the word2vec model's vocabulary") self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None) kwargs = self.__init__._input_kwargs @@ -2075,7 +1983,6 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "the number of principal components") @keyword_only @@ -2085,7 +1992,6 @@ def __init__(self, k=None, inputCol=None, outputCol=None): """ super(PCA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) - self.k = Param(self, "k", "the number of principal components") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -2185,7 +2091,6 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc formula = Param(Params._dummy(), "formula", "R model formula") @keyword_only @@ -2195,7 +2100,6 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label"): """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) - self.formula = Param(self, "formula", "R model formula") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 92ce96aa3c4df..3da36d32c5af0 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -40,6 +40,15 @@ def __init__(self, parent, name, doc, expectedType=None): self.doc = str(doc) self.expectedType = expectedType + def _copy_new_parent(self, parent): + """Copy the current param to a new parent, must be a dummy param.""" + if self.parent == "undefined": + param = copy.copy(self) + param.parent = parent.uid + return param + else: + raise ValueError("Cannot copy from non-dummy parent %s." % parent) + def __str__(self): return str(self.parent) + "__" + self.name @@ -77,6 +86,19 @@ def __init__(self): #: value returned by :py:func:`params` self._params = None + # Copy the params from the class to the object + self._copy_params() + + def _copy_params(self): + """ + Copy all params defined on the class to current object. + """ + cls = type(self) + src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)] + src_params = list(filter(lambda nameAttr: isinstance(nameAttr[1], Param), src_name_attrs)) + for name, param in src_params: + setattr(self, name, param._copy_new_parent(self)) + @property @since("1.3.0") def params(self): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 82855bc4c75ba..5e297b8214823 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -50,13 +50,11 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType): Mixin for param $name: $doc """ - # a placeholder to make it appear in the generated doc $name = Param(Params._dummy(), "$name", "$doc", $expectedType) def __init__(self): - super(Has$Name, self).__init__() - #: param for $doc - self.$name = Param(self, "$name", "$doc", $expectedType)''' + super(Has$Name, self).__init__()''' + if defaultValueStr is not None: template += ''' self._setDefault($name=$defaultValueStr)''' @@ -171,22 +169,17 @@ def get$Name(self): Mixin for Decision Tree parameters. """ - # a placeholder to make it appear in the generated doc $dummyPlaceHolders def __init__(self): - super(DecisionTreeParams, self).__init__() - $realParams''' + super(DecisionTreeParams, self).__init__()''' dtParamMethods = "" dummyPlaceholders = "" - realParams = "" paramTemplate = """$name = Param($owner, "$name", "$doc")""" for name, doc in decisionTreeParams: variable = paramTemplate.replace("$name", name).replace("$doc", doc) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " - realParams += "#: param for " + doc + "\n " - realParams += "self." + variable.replace("$owner", "self") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" - code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) - .replace("$realParams", realParams) + dtParamMethods) + code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" + + dtParamMethods) print("\n\n\n".join(code)) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 23f94314844f6..db4a8a54d4956 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -25,13 +25,10 @@ class HasMaxIter(Params): Mixin for param maxIter: max number of iterations (>= 0). """ - # a placeholder to make it appear in the generated doc maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", int) def __init__(self): super(HasMaxIter, self).__init__() - #: param for max number of iterations (>= 0). - self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).", int) def setMaxIter(self, value): """ @@ -52,13 +49,10 @@ class HasRegParam(Params): Mixin for param regParam: regularization parameter (>= 0). """ - # a placeholder to make it appear in the generated doc regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", float) def __init__(self): super(HasRegParam, self).__init__() - #: param for regularization parameter (>= 0). - self.regParam = Param(self, "regParam", "regularization parameter (>= 0).", float) def setRegParam(self, value): """ @@ -79,13 +73,10 @@ class HasFeaturesCol(Params): Mixin for param featuresCol: features column name. """ - # a placeholder to make it appear in the generated doc featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", str) def __init__(self): super(HasFeaturesCol, self).__init__() - #: param for features column name. - self.featuresCol = Param(self, "featuresCol", "features column name.", str) self._setDefault(featuresCol='features') def setFeaturesCol(self, value): @@ -107,13 +98,10 @@ class HasLabelCol(Params): Mixin for param labelCol: label column name. """ - # a placeholder to make it appear in the generated doc labelCol = Param(Params._dummy(), "labelCol", "label column name.", str) def __init__(self): super(HasLabelCol, self).__init__() - #: param for label column name. - self.labelCol = Param(self, "labelCol", "label column name.", str) self._setDefault(labelCol='label') def setLabelCol(self, value): @@ -135,13 +123,10 @@ class HasPredictionCol(Params): Mixin for param predictionCol: prediction column name. """ - # a placeholder to make it appear in the generated doc predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", str) def __init__(self): super(HasPredictionCol, self).__init__() - #: param for prediction column name. - self.predictionCol = Param(self, "predictionCol", "prediction column name.", str) self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): @@ -163,13 +148,10 @@ class HasProbabilityCol(Params): Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. """ - # a placeholder to make it appear in the generated doc probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) def __init__(self): super(HasProbabilityCol, self).__init__() - #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. - self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) self._setDefault(probabilityCol='probability') def setProbabilityCol(self, value): @@ -191,13 +173,10 @@ class HasRawPredictionCol(Params): Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. """ - # a placeholder to make it appear in the generated doc rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) def __init__(self): super(HasRawPredictionCol, self).__init__() - #: param for raw prediction (a.k.a. confidence) column name. - self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): @@ -219,13 +198,10 @@ class HasInputCol(Params): Mixin for param inputCol: input column name. """ - # a placeholder to make it appear in the generated doc inputCol = Param(Params._dummy(), "inputCol", "input column name.", str) def __init__(self): super(HasInputCol, self).__init__() - #: param for input column name. - self.inputCol = Param(self, "inputCol", "input column name.", str) def setInputCol(self, value): """ @@ -246,13 +222,10 @@ class HasInputCols(Params): Mixin for param inputCols: input column names. """ - # a placeholder to make it appear in the generated doc inputCols = Param(Params._dummy(), "inputCols", "input column names.", None) def __init__(self): super(HasInputCols, self).__init__() - #: param for input column names. - self.inputCols = Param(self, "inputCols", "input column names.", None) def setInputCols(self, value): """ @@ -273,13 +246,10 @@ class HasOutputCol(Params): Mixin for param outputCol: output column name. """ - # a placeholder to make it appear in the generated doc outputCol = Param(Params._dummy(), "outputCol", "output column name.", str) def __init__(self): super(HasOutputCol, self).__init__() - #: param for output column name. - self.outputCol = Param(self, "outputCol", "output column name.", str) self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): @@ -301,13 +271,10 @@ class HasNumFeatures(Params): Mixin for param numFeatures: number of features. """ - # a placeholder to make it appear in the generated doc numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", int) def __init__(self): super(HasNumFeatures, self).__init__() - #: param for number of features. - self.numFeatures = Param(self, "numFeatures", "number of features.", int) def setNumFeatures(self, value): """ @@ -328,13 +295,10 @@ class HasCheckpointInterval(Params): Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ - # a placeholder to make it appear in the generated doc checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. - self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def setCheckpointInterval(self, value): """ @@ -355,13 +319,10 @@ class HasSeed(Params): Mixin for param seed: random seed. """ - # a placeholder to make it appear in the generated doc seed = Param(Params._dummy(), "seed", "random seed.", int) def __init__(self): super(HasSeed, self).__init__() - #: param for random seed. - self.seed = Param(self, "seed", "random seed.", int) self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): @@ -383,13 +344,10 @@ class HasTol(Params): Mixin for param tol: the convergence tolerance for iterative algorithms. """ - # a placeholder to make it appear in the generated doc tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", float) def __init__(self): super(HasTol, self).__init__() - #: param for the convergence tolerance for iterative algorithms. - self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.", float) def setTol(self, value): """ @@ -410,13 +368,10 @@ class HasStepSize(Params): Mixin for param stepSize: Step size to be used for each iteration of optimization. """ - # a placeholder to make it appear in the generated doc stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", float) def __init__(self): super(HasStepSize, self).__init__() - #: param for Step size to be used for each iteration of optimization. - self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.", float) def setStepSize(self, value): """ @@ -437,13 +392,10 @@ class HasHandleInvalid(Params): Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. """ - # a placeholder to make it appear in the generated doc handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) def __init__(self): super(HasHandleInvalid, self).__init__() - #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. - self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str) def setHandleInvalid(self, value): """ @@ -464,13 +416,10 @@ class HasElasticNetParam(Params): Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. """ - # a placeholder to make it appear in the generated doc elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) def __init__(self): super(HasElasticNetParam, self).__init__() - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float) self._setDefault(elasticNetParam=0.0) def setElasticNetParam(self, value): @@ -492,13 +441,10 @@ class HasFitIntercept(Params): Mixin for param fitIntercept: whether to fit an intercept term. """ - # a placeholder to make it appear in the generated doc fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", bool) def __init__(self): super(HasFitIntercept, self).__init__() - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.", bool) self._setDefault(fitIntercept=True) def setFitIntercept(self, value): @@ -520,13 +466,10 @@ class HasStandardization(Params): Mixin for param standardization: whether to standardize the training features before fitting the model. """ - # a placeholder to make it appear in the generated doc standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", bool) def __init__(self): super(HasStandardization, self).__init__() - #: param for whether to standardize the training features before fitting the model. - self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.", bool) self._setDefault(standardization=True) def setStandardization(self, value): @@ -548,13 +491,10 @@ class HasThresholds(Params): Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. """ - # a placeholder to make it appear in the generated doc thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def __init__(self): super(HasThresholds, self).__init__() - #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. - self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def setThresholds(self, value): """ @@ -575,13 +515,10 @@ class HasWeightCol(Params): Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. """ - # a placeholder to make it appear in the generated doc weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def __init__(self): super(HasWeightCol, self).__init__() - #: param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. - self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def setWeightCol(self, value): """ @@ -602,13 +539,10 @@ class HasSolver(Params): Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. """ - # a placeholder to make it appear in the generated doc solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) def __init__(self): super(HasSolver, self).__init__() - #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. - self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) self._setDefault(solver='auto') def setSolver(self, value): @@ -630,7 +564,6 @@ class DecisionTreeParams(Params): Mixin for Decision Tree parameters. """ - # a placeholder to make it appear in the generated doc maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") @@ -641,19 +574,7 @@ class DecisionTreeParams(Params): def __init__(self): super(DecisionTreeParams, self).__init__() - #: param for Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - self.maxDepth = Param(self, "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") - #: param for Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature. - self.maxBins = Param(self, "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") - #: param for Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1. - self.minInstancesPerNode = Param(self, "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") - #: param for Minimum information gain for a split to be considered at a tree node. - self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") - #: param for Maximum memory in MB allocated to histogram aggregation. - self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9f5f6ac8fa4e2..661074ca96212 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -149,6 +149,8 @@ class Pipeline(Estimator): .. versionadded:: 1.3.0 """ + stages = Param(Params._dummy(), "stages", "pipeline stages") + @keyword_only def __init__(self, stages=None): """ @@ -157,8 +159,6 @@ def __init__(self, stages=None): if stages is None: stages = [] super(Pipeline, self).__init__() - #: Param for pipeline stages. - self.stages = Param(self, "stages", "pipeline stages") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b44c66f73cc49..08180a2f25eb9 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -85,7 +85,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc rank = Param(Params._dummy(), "rank", "rank of the factorization") numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks") @@ -108,16 +107,6 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) - self.rank = Param(self, "rank", "rank of the factorization") - self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks") - self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks") - self.implicitPrefs = Param(self, "implicitPrefs", "whether to use implicit preference") - self.alpha = Param(self, "alpha", "alpha for implicit preference") - self.userCol = Param(self, "userCol", "column name for user ids") - self.itemCol = Param(self, "itemCol", "column name for item ids") - self.ratingCol = Param(self, "ratingCol", "column name for ratings") - self.nonnegative = Param(self, "nonnegative", - "whether to use nonnegative constraint for least squares") self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 401bac0223ebd..74a2248ed07c8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -162,7 +162,6 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti DenseVector([0.0, 1.0]) """ - # a placeholder to make it appear in the generated doc isotonic = \ Param(Params._dummy(), "isotonic", "whether the output sequence should be isotonic/increasing (true) or" + @@ -181,14 +180,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(IsotonicRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid) - self.isotonic = \ - Param(self, "isotonic", - "whether the output sequence should be isotonic/increasing (true) or" + - "antitonic/decreasing (false).") - self.featureIndex = \ - Param(self, "featureIndex", - "The index of the feature if featuresCol is a vector column, no effect " + - "otherwise.") self._setDefault(isotonic=True, featureIndex=0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -262,15 +253,11 @@ class TreeEnsembleParams(DecisionTreeParams): Mixin for Decision Tree-based ensemble algorithms parameters. """ - # a placeholder to make it appear in the generated doc subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " + "used for learning each decision tree, in range (0, 1].") def __init__(self): super(TreeEnsembleParams, self).__init__() - #: param for Fraction of the training data, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", "Fraction of the training data " + - "used for learning each decision tree, in range (0, 1].") @since("1.4.0") def setSubsamplingRate(self, value): @@ -294,7 +281,6 @@ class TreeRegressorParams(Params): """ supportedImpurities = ["variance"] - # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + @@ -302,10 +288,6 @@ class TreeRegressorParams(Params): def __init__(self): super(TreeRegressorParams, self).__init__() - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = Param(self, "impurity", "Criterion used for information " + - "gain calculation (case-insensitive). Supported options: " + - ", ".join(self.supportedImpurities)) @since("1.4.0") def setImpurity(self, value): @@ -329,7 +311,6 @@ class RandomForestParams(TreeEnsembleParams): """ supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] - # a placeholder to make it appear in the generated doc numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).") featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", @@ -338,13 +319,6 @@ class RandomForestParams(TreeEnsembleParams): def __init__(self): super(RandomForestParams, self).__init__() - #: param for Number of trees to train (>= 1). - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1).") - #: param for The number of features to consider for splits at each tree node. - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(self.supportedFeatureSubsetStrategies)) @since("1.4.0") def setNumTrees(self, value): @@ -609,7 +583,6 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) @@ -627,10 +600,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) - #: param for Loss function which GBT tries to minimize (case-insensitive). - self.lossType = Param(self, "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) @@ -713,7 +682,6 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc censorCol = Param(Params._dummy(), "censorCol", "censor column name. The value of this column could be 0 or 1. " + "If the value is 1, it means the event has occurred i.e. " + @@ -739,20 +707,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) - #: Param for censor column name - self.censorCol = Param(self, "censorCol", - "censor column name. The value of this column could be 0 or 1. " + - "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.") - #: Param for quantile probabilities array - self.quantileProbabilities = \ - Param(self, "quantileProbabilities", - "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.") - #: Param for quantiles column name - self.quantilesCol = Param(self, "quantilesCol", - "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.") self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]) kwargs = self.__init__._input_kwargs diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9ea639dc4f960..c45a159c460f3 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -185,6 +185,18 @@ def setParams(self, seed=None): class ParamTests(PySparkTestCase): + def test_copy_new_parent(self): + testParams = TestParams() + # Copying an instantiated param should fail + with self.assertRaises(ValueError): + testParams.maxIter._copy_new_parent(testParams) + # Copying a dummy param should succeed + TestParams.maxIter._copy_new_parent(testParams) + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + def test_param(self): testParams = TestParams() maxIter = testParams.maxIter diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 08f8db57f4400..0cbe97f1d839f 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -115,18 +115,11 @@ class CrossValidator(Estimator, HasSeed): .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") - - # a placeholder to make it appear in the generated doc estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") - - # a placeholder to make it appear in the generated doc evaluator = Param( Params._dummy(), "evaluator", "evaluator used to select hyper-parameters that maximize the cross-validated metric") - - # a placeholder to make it appear in the generated doc numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only @@ -137,17 +130,6 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF seed=None) """ super(CrossValidator, self).__init__() - #: param for estimator to be cross-validated - self.estimator = Param(self, "estimator", "estimator to be cross-validated") - #: param for estimator param maps - self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps") - #: param for the evaluator used to select hyper-parameters that - #: maximize the cross-validated metric - self.evaluator = Param( - self, "evaluator", - "evaluator used to select hyper-parameters that maximize the cross-validated metric") - #: param for number of folds for cross validation - self.numFolds = Param(self, "numFolds", "number of folds for cross validation") self._setDefault(numFolds=3) kwargs = self.__init__._input_kwargs self._set(**kwargs) From 22662b241629b56205719ede2f801a476e10a3cd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 26 Jan 2016 17:24:40 -0800 Subject: [PATCH 1704/2089] [SPARK-12614][CORE] Don't throw non fatal exception from ask Right now RpcEndpointRef.ask may throw exception in some corner cases, such as calling ask after stopping RpcEnv. It's better to avoid throwing exception from RpcEndpointRef.ask. We can send the exception to the future for `ask`. Author: Shixiong Zhu Closes #10568 from zsxwing/send-ask-fail. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 54 ++++++++++--------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index ef876b1d8c15a..9ae74d9d7b898 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -211,33 +211,37 @@ private[netty] class NettyRpcEnv( } } - if (remoteAddr == address) { - val p = Promise[Any]() - p.future.onComplete { - case Success(response) => onSuccess(response) - case Failure(e) => onFailure(e) - }(ThreadUtils.sameThread) - dispatcher.postLocalMessage(message, p) - } else { - val rpcMessage = RpcOutboxMessage(serialize(message), - onFailure, - (client, response) => onSuccess(deserialize[Any](client, response))) - postToOutbox(message.receiver, rpcMessage) - promise.future.onFailure { - case _: TimeoutException => rpcMessage.onTimeout() - case _ => + try { + if (remoteAddr == address) { + val p = Promise[Any]() + p.future.onComplete { + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) + }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) + } else { + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) + } + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) }(ThreadUtils.sameThread) + } catch { + case NonFatal(e) => + onFailure(e) } - - val timeoutCancelable = timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure( - new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - promise.future.onComplete { v => - timeoutCancelable.cancel(true) - }(ThreadUtils.sameThread) promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } From 1dac964c1b996d38c65818414fc8401961a1de8a Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 26 Jan 2016 17:31:19 -0800 Subject: [PATCH 1705/2089] =?UTF-8?q?[SPARK-11622][MLLIB]=20Make=20LibSVMR?= =?UTF-8?q?elation=20extends=20HadoopFsRelation=20and=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … Add LibSVMOutputWriter The behavior of LibSVMRelation is not changed except adding LibSVMOutputWriter * Partition is still not supported * Multiple input paths is not supported Author: Jeff Zhang Closes #9595 from zjffdu/SPARK-11622. --- .../ml/source/libsvm/LibSVMRelation.scala | 102 +++++++++++++++--- .../source/libsvm/LibSVMRelationSuite.scala | 23 +++- project/MimaExcludes.scala | 4 + 3 files changed, 113 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 1bed542c40316..b9c364b05dc11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -17,16 +17,21 @@ package org.apache.spark.ml.source.libsvm +import java.io.IOException + import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.types._ /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. @@ -37,14 +42,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan with Logging with Serializable { - - override def schema: StructType = StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil - ) + extends HadoopFsRelation with Serializable { - override def buildScan(): RDD[Row] = { + override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]) + : RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) val sparse = vectorType == "sparse" @@ -66,8 +67,63 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val case _ => false } + + override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job): + _root_.org.apache.spark.sql.sources.OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new LibSVMOutputWriter(path, dataSchema, context) + } + } + } + + override def paths: Array[String] = Array(path) + + override def dataSchema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil) } + +private[libsvm] class LibSVMOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = { + val label = row.get(0) + val vector = row.get(1).asInstanceOf[Vector] + val sb = new StringBuilder(label.toString) + vector.foreachActive { case (i, v) => + sb += ' ' + sb ++= s"${i + 1}:$v" + } + buffer.set(sb.mkString) + recordWriter.write(NullWritable.get(), buffer) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} /** * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and @@ -99,16 +155,32 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") -class DefaultSource extends RelationProvider with DataSourceRegister { +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - @Since("1.6.0") - override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) - : BaseRelation = { - val path = parameters.getOrElse("path", - throw new IllegalArgumentException("'path' must be specified")) + private def verifySchema(dataSchema: StructType): Unit = { + if (dataSchema.size != 2 || + (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) + || !dataSchema(1).dataType.sameType(new VectorUDT()))) { + throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}") + } + } + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + val path = if (paths.length == 1) paths(0) + else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data") + else throw new IOException("Multiple input paths are not supported for libsvm data") + if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) { + throw new IOException("Partition is not supported for libsvm data") + } + dataSchema.foreach(verifySchema(_)) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt val vectorType = parameters.getOrElse("vectorType", "sparse") new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 5f4d5f11bdd68..528d9e21cb1fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.source.libsvm -import java.io.File +import java.io.{File, IOException} import com.google.common.base.Charsets import com.google.common.io.Files @@ -25,6 +25,7 @@ import com.google.common.io.Files import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SaveMode import org.apache.spark.util.Utils class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -82,4 +83,24 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } + + test("write libsvm data and read it again") { + val df = sqlContext.read.format("libsvm").load(path) + val tempDir2 = Utils.createTempDir() + val writepath = tempDir2.toURI.toString + df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) + + val df2 = sqlContext.read.format("libsvm").load(writepath) + val row1 = df2.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("write libsvm data failed due to invalid schema") { + val df = sqlContext.read.format("text").load(path) + val e = intercept[IOException] { + df.write.format("libsvm").save(path + "_2") + } + assert(e.getMessage.contains("Illegal schema for libsvm data")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 643bee69694df..fc7dc2181de85 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -203,6 +203,10 @@ object MimaExcludes { // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") + ) ++ Seq( + // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") ) case v if v.startsWith("1.6") => Seq( From 555127387accdd7c1cf236912941822ba8af0a52 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 26 Jan 2016 17:34:01 -0800 Subject: [PATCH 1706/2089] [SPARK-12854][SQL] Implement complex types support in ColumnarBatch This patch adds support for complex types for ColumnarBatch. ColumnarBatch supports structs and arrays. There is a simple mapping between the richer catalyst types to these two. Strings are treated as an array of bytes. ColumnarBatch will contain a column for each node of the schema. Non-complex schemas consists of just leaf nodes. Structs represent an internal node with one child for each field. Arrays are internal nodes with one child. Structs just contain nullability. Arrays contain offsets and lengths into the child array. This structure is able to handle arbitrary nesting. It has the key property that we maintain columnar throughout and that primitive types are only stored in the leaf nodes and contiguous across rows. For example, if the schema is ``` array> ``` There are three columns in the schema. The internal nodes each have one children. The leaf node contains all the int data stored consecutively. As part of this, this patch adds append APIs in addition to the Put APIs (e.g. putLong(rowid, v) vs appendLong(v)). These APIs are necessary when the batch contains variable length elements. The vectors are not fixed length and will grow as necessary. This should make the usage a lot simpler for the writer. Author: Nong Li Closes #10820 from nongli/spark-12854. --- .../sql/catalyst/expressions/UnsafeRow.java | 7 +- .../expressions/SpecificMutableRow.scala | 3 +- .../spark/sql/RandomDataGenerator.scala | 94 ++- .../spark/sql/RandomDataGeneratorSuite.scala | 4 +- .../GenerateUnsafeRowJoinerSuite.scala | 5 +- .../execution/vectorized/ColumnVector.java | 630 +++++++++++++++++- .../vectorized/ColumnVectorUtils.java | 126 ++++ .../execution/vectorized/ColumnarBatch.java | 70 +- .../vectorized/OffHeapColumnVector.java | 157 ++++- .../vectorized/OnHeapColumnVector.java | 169 ++++- .../UnsafeKVExternalSorterSuite.scala | 4 +- .../vectorized/ColumnarBatchBenchmark.scala | 78 ++- .../vectorized/ColumnarBatchSuite.scala | 397 ++++++++++- .../execution/AggregationQuerySuite.scala | 3 +- .../sql/sources/hadoopFsRelationSuites.scala | 3 +- .../org/apache/spark/unsafe/Platform.java | 11 + 16 files changed, 1671 insertions(+), 90 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1a351933a366c..a88bcbfdb7ccb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -68,6 +68,10 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields + 63)/ 64) * 8; } + public static int calculateFixedPortionByteSize(int numFields) { + return 8 * numFields + calculateBitSetWidthInBytes(numFields); + } + /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ @@ -596,10 +600,9 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { + if (i != 0) build.append(','); build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); - build.append(','); } - build.deleteCharAt(build.length() - 1); build.append(']'); return build.toString(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 475cbe005a6ee..4615c55d676f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * A parent class for mutable container objects that are reused when the values are changed, @@ -212,6 +211,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) def this() = this(Seq.empty) + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + override def numFields: Int = values.length override def setNullAt(i: Int): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 7614f055e9c04..55efea80d1a4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -21,6 +21,7 @@ import java.lang.Double.longBitsToDouble import java.lang.Float.intBitsToFloat import java.math.MathContext +import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -74,13 +75,47 @@ object RandomDataGenerator { * @param numFields the number of fields in this schema * @param acceptedTypes types to draw from. */ - def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = { + def randomSchema(rand: Random, numFields: Int, acceptedTypes: Seq[DataType]): StructType = { StructType(Seq.tabulate(numFields) { i => - val dt = acceptedTypes(Random.nextInt(acceptedTypes.size)) - StructField("col_" + i, dt, nullable = true) + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + StructField("col_" + i, dt, nullable = rand.nextBoolean()) }) } + /** + * Returns a random nested schema. This will randomly generate structs and arrays drawn from + * acceptedTypes. + */ + def randomNestedSchema(rand: Random, totalFields: Int, acceptedTypes: Seq[DataType]): + StructType = { + val fields = mutable.ArrayBuffer.empty[StructField] + var i = 0 + var numFields = totalFields + while (numFields > 0) { + val v = rand.nextInt(3) + if (v == 0) { + // Simple type: + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + fields += new StructField("col_" + i, dt, rand.nextBoolean()) + numFields -= 1 + } else if (v == 1) { + // Array + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + fields += new StructField("col_" + i, ArrayType(dt), rand.nextBoolean()) + numFields -= 1 + } else { + // Struct + // TODO: do empty structs make sense? + val n = Math.max(rand.nextInt(numFields), 1) + val nested = randomNestedSchema(rand, n, acceptedTypes) + fields += new StructField("col_" + i, nested, rand.nextBoolean()) + numFields -= n + } + i += 1 + } + StructType(fields) + } + /** * Returns a function which generates random values for the given [[DataType]], or `None` if no * random data generator is defined for that data type. The generated values will use an external @@ -90,16 +125,13 @@ object RandomDataGenerator { * * @param dataType the type to generate values for * @param nullable whether null values should be generated - * @param seed an optional seed for the random number generator + * @param rand an optional random number generator * @return a function which can be called to generate random values. */ def forType( dataType: DataType, nullable: Boolean = true, - seed: Option[Long] = None): Option[() => Any] = { - val rand = new Random() - seed.foreach(rand.setSeed) - + rand: Random = new Random): Option[() => Any] = { val valueGenerator: Option[() => Any] = dataType match { case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) case BinaryType => Some(() => { @@ -165,15 +197,15 @@ object RandomDataGenerator { rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) case ArrayType(elementType, containsNull) => { - forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { + forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } } case MapType(keyType, valueType, valueContainsNull) => { for ( - keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); + keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- - forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) + forType(valueType, nullable = valueContainsNull, rand) ) yield { () => { Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap @@ -182,7 +214,7 @@ object RandomDataGenerator { } case StructType(fields) => { val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => - forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) + forType(field.dataType, nullable = field.nullable, rand) } if (maybeFieldGenerators.forall(_.isDefined)) { val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) @@ -192,7 +224,7 @@ object RandomDataGenerator { } } case udt: UserDefinedType[_] => { - val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed) + val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand) // Because random data generator at here returns scala value, we need to // convert it to catalyst value to call udt's deserialize. val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) @@ -229,4 +261,40 @@ object RandomDataGenerator { } } } + + // Generates a random row for `schema`. + def randomRow(rand: Random, schema: StructType): Row = { + val fields = mutable.ArrayBuffer.empty[Any] + schema.fields.foreach { f => + f.dataType match { + case ArrayType(childType, nullable) => { + val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + val arr = mutable.ArrayBuffer.empty[Any] + val n = 1// rand.nextInt(10) + var i = 0 + val generator = RandomDataGenerator.forType(childType, nullable, rand) + assert(generator.isDefined, "Unsupported type") + val gen = generator.get + while (i < n) { + arr += gen() + i += 1 + } + arr + } + fields += data + } + case StructType(children) => { + fields += randomRow(rand, StructType(children)) + } + case _ => + val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) + assert(generator.isDefined, "Unsupported type") + val gen = generator.get + fields += gen() + } + } + Row.fromSeq(fields) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index cccac7efa09e9..b8ccdf7516d82 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ @@ -32,7 +34,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { */ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) - val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { + val generator = RandomDataGenerator.forType(dataType, nullable, new Random(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index 59729e7646beb..9f19745cefd20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -74,8 +74,9 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) { info(s"schema size $numFields1, $numFields2") - val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes) - val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes) + val random = new Random() + val schema1 = RandomDataGenerator.randomSchema(random, numFields1, candidateTypes) + val schema2 = RandomDataGenerator.randomSchema(random, numFields2, candidateTypes) // Create the converters needed to convert from external row to internal row and to UnsafeRows. val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 85509751dbbee..c119758d68b36 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -17,22 +17,45 @@ package org.apache.spark.sql.execution.vectorized; import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +import org.apache.commons.lang.NotImplementedException; /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. * The batched versions are preferable whenever possible. * - * Most of the APIs take the rowId as a parameter. This is the local 0-based row id for values + * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these + * columns have child columns. All of the data is stored in the child columns and the parent column + * contains nullability, and in the case of Arrays, the lengths and offsets into the child column. + * Lengths and offsets are encoded identically to INTs. + * Maps are just a special case of a two field struct. + * Strings are handled as an Array of ByteType. + * + * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the + * responsibility of the caller to call reserve() to ensure there is enough room before adding + * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas), + * the lengths are known up front. + * + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in the current RowBatch. * * A ColumnVector should be considered immutable once originally created. In other words, it is not * valid to call put APIs after reads until reset() is called. + * + * ColumnVectors are intended to be reused. */ public abstract class ColumnVector { /** - * Allocates a column with each element of size `width` either on or off heap. + * Allocates a column to store elements of `type` on or off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. */ public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { if (mode == MemoryMode.OFF_HEAP) { @@ -42,13 +65,265 @@ public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode } } + /** + * Holder object to return an array. This object is intended to be reused. Callers should + * copy the data out if it needs to be stored. + */ + public static final class Array extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + public final ColumnVector data; + public int length; + public int offset; + + // Populate if binary data is required for the Array. This is stored here as an optimization + // for string data. + public byte[] byteArray; + public int byteArrayOffset; + + // Reused staging buffer, used for loading from offheap. + protected byte[] tmpByteArray = new byte[1]; + + protected Array(ColumnVector data) { + this.data = data; + } + + @Override + public final int numElements() { return length; } + + @Override + public ArrayData copy() { + throw new NotImplementedException(); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + + if (dt instanceof ByteType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getByte(offset + i); + } + } + } else if (dt instanceof IntegerType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getInt(offset + i); + } + } + } else if (dt instanceof DoubleType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getDouble(offset + i); + } + } + } else if (dt instanceof LongType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getLong(offset + i); + } + } + } else if (dt instanceof StringType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + } + } + } else { + throw new NotImplementedException("Type " + dt); + } + return list; + } + + @Override + public final boolean isNullAt(int ordinal) { return data.getIsNull(offset + ordinal); } + + @Override + public final boolean getBoolean(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } + + @Override + public short getShort(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public int getInt(int ordinal) { return data.getInt(offset + ordinal); } + + @Override + public long getLong(int ordinal) { return data.getLong(offset + ordinal); } + + @Override + public float getFloat(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + throw new NotImplementedException(); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + Array child = data.getByteArray(offset + ordinal); + return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + } + + @Override + public byte[] getBinary(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + throw new NotImplementedException(); + } + + @Override + public ArrayData getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + } + + /** + * Holder object to return a struct. This object is intended to be reused. + */ + public static final class Struct extends InternalRow { + // The fields that make up this struct. For example, if the struct had 2 int fields, the access + // to it would be: + // int f1 = fields[0].getInt[rowId] + // int f2 = fields[1].getInt[rowId] + public final ColumnVector[] fields; + + @Override + public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); } + + @Override + public boolean getBoolean(int ordinal) { + throw new NotImplementedException(); + } + + public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); } + + @Override + public short getShort(int ordinal) { + throw new NotImplementedException(); + } + + public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); } + public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { + throw new NotImplementedException(); + } + + public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + throw new NotImplementedException(); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + Array a = getByteArray(ordinal); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } + + @Override + public byte[] getBinary(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return fields[ordinal].getStruct(rowId); + } + + public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + + public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); } + public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); } + + @Override + public final int numFields() { + return fields.length; + } + + @Override + public InternalRow copy() { + throw new NotImplementedException(); + } + + @Override + public boolean anyNull() { + throw new NotImplementedException(); + } + + protected int rowId; + + protected Struct(ColumnVector[] fields) { + this.fields = fields; + } + } + + /** + * Returns the data type of this column. + */ public final DataType dataType() { return type; } /** * Resets this column for writing. The currently stored values are no longer accessible. */ public void reset() { + if (childColumns != null) { + for (ColumnVector c: childColumns) { + c.reset(); + } + } numNulls = 0; + elementsAppended = 0; if (anyNullsSet) { putNotNulls(0, capacity); anyNullsSet = false; @@ -61,6 +336,12 @@ public void reset() { */ public abstract void close(); + /* + * Ensures that there is enough storage to store capcity elements. That is, the put() APIs + * must work for all rowIds < capcity. + */ + public abstract void reserve(int capacity); + /** * Returns the number of nulls in this column. */ @@ -96,6 +377,26 @@ public void reset() { */ public abstract boolean getIsNull(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putByte(int rowId, byte value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBytes(int rowId, int count, byte value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract byte getByte(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -118,10 +419,36 @@ public void reset() { public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Returns the integer for rowId. + * Returns the value for rowId. */ public abstract int getInt(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putLong(int rowId, long value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putLongs(int rowId, int count, long value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 8-byte little endian longs. + */ + public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract long getLong(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -145,14 +472,248 @@ public void reset() { public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); /** - * Returns the double for rowId. + * Returns the value for rowId. */ public abstract double getDouble(int rowId); + /** + * Puts a byte array that already exists in this column. + */ + public abstract void putArray(int rowId, int offset, int length); + + /** + * Returns the length of the array at rowid. + */ + public abstract int getArrayLength(int rowId); + + /** + * Returns the offset of the array at rowid. + */ + public abstract int getArrayOffset(int rowId); + + /** + * Returns a utility object to get structs. + */ + public Struct getStruct(int rowId) { + resultStruct.rowId = rowId; + return resultStruct; + } + + /** + * Returns the array at rowid. + */ + public final Array getArray(int rowId) { + resultArray.length = getArrayLength(rowId); + resultArray.offset = getArrayOffset(rowId); + return resultArray; + } + + /** + * Loads the data into array.byteArray. + */ + public abstract void loadBytes(Array array); + + /** + * Sets the value at rowId to `value`. + */ + public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + public final int putByteArray(int rowId, byte[] value) { + return putByteArray(rowId, value, 0, value.length); + } + + /** + * Returns the value for rowId. + */ + public final Array getByteArray(int rowId) { + Array array = getArray(rowId); + array.data.loadBytes(array); + return array; + } + + /** + * Append APIs. These APIs all behave similarly and will append data to the current vector. It + * is not valid to mix the put and append APIs. The append APIs are slower and should only be + * used if the sizes are not known up front. + * In all these cases, the return value is the rowId for the first appended element. + */ + public final int appendNull() { + assert (!(dataType() instanceof StructType)); // Use appendStruct() + reserve(elementsAppended + 1); + putNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNotNull() { + reserve(elementsAppended + 1); + putNotNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendNotNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNotNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendByte(byte v) { + reserve(elementsAppended + 1); + putByte(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBytes(int count, byte v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBytes(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendBytes(int length, byte[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putBytes(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendInt(int v) { + reserve(elementsAppended + 1); + putInt(elementsAppended, v); + return elementsAppended++; + } + + public final int appendInts(int count, int v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putInts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendInts(int length, int[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putInts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendLong(long v) { + reserve(elementsAppended + 1); + putLong(elementsAppended, v); + return elementsAppended++; + } + + public final int appendLongs(int count, long v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putLongs(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendLongs(int length, long[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putLongs(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendDouble(double v) { + reserve(elementsAppended + 1); + putDouble(elementsAppended, v); + return elementsAppended++; + } + + public final int appendDoubles(int count, double v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putDoubles(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendDoubles(int length, double[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putDoubles(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendByteArray(byte[] value, int offset, int length) { + int copiedOffset = arrayData().appendBytes(length, value, offset); + reserve(elementsAppended + 1); + putArray(elementsAppended, copiedOffset, length); + return elementsAppended++; + } + + public final int appendArray(int length) { + reserve(elementsAppended + 1); + putArray(elementsAppended, arrayData().elementsAppended, length); + return elementsAppended++; + } + + /** + * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this + * recursively appends a NULL to its children. + * We don't have this logic as the general appendNull implementation to optimize the more + * common non-struct case. + */ + public final int appendStruct(boolean isNull) { + if (isNull) { + appendNull(); + for (ColumnVector c: childColumns) { + if (c.type instanceof StructType) { + c.appendStruct(true); + } else { + c.appendNull(); + } + } + } else { + appendNotNull(); + } + return elementsAppended; + } + + /** + * Returns the data for the underlying array. + */ + public final ColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + + /** + * Returns the elements appended. + */ + public int getElementsAppended() { return elementsAppended; } + /** * Maximum number of rows that can be stored in this column. */ - protected final int capacity; + protected int capacity; + + /** + * Data type for this column. + */ + protected final DataType type; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. @@ -166,12 +727,63 @@ public void reset() { protected boolean anyNullsSet; /** - * Data type for this column. + * Default size of each array length value. This grows as necessary. */ - protected final DataType type; + protected static final int DEFAULT_ARRAY_LENGTH = 4; + + /** + * Current write cursor (row index) when appending data. + */ + protected int elementsAppended; - protected ColumnVector(int capacity, DataType type) { + /** + * If this is a nested type (array or struct), the column for the child data. + */ + protected final ColumnVector[] childColumns; + + /** + * Reusable Array holder for getArray(). + */ + protected final Array resultArray; + + /** + * Reusable Struct holder for getStruct(). + */ + protected final Struct resultStruct; + + /** + * Sets up the common state and also handles creating the child columns if this is a nested + * type. + */ + protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.capacity = capacity; this.type = type; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) { + DataType childType; + int childCapacity = capacity; + if (type instanceof ArrayType) { + childType = ((ArrayType)type).elementType(); + } else { + childType = DataTypes.ByteType; + childCapacity *= DEFAULT_ARRAY_LENGTH; + } + this.childColumns = new ColumnVector[1]; + this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); + this.resultArray = new Array(this.childColumns[0]); + this.resultStruct = null; + } else if (type instanceof StructType) { + StructType st = (StructType)type; + this.childColumns = new ColumnVector[st.fields().length]; + for (int i = 0; i < childColumns.length; ++i) { + this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + } + this.resultArray = null; + this.resultStruct = new Struct(this.childColumns); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java new file mode 100644 index 0000000000000..6c651a759d250 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.util.Iterator; +import java.util.List; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.*; + +import org.apache.commons.lang.NotImplementedException; + +/** + * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly + * for debugging or other non-performance critical paths. + * These utilities are mostly used to convert ColumnVectors into other formats. + */ +public class ColumnVectorUtils { + public static String toString(ColumnVector.Array a) { + return new String(a.byteArray, a.byteArrayOffset, a.length); + } + + /** + * Returns the array data as the java primitive array. + * For example, an array of IntegerType will return an int[]. + * Throws exceptions for unhandled schemas. + */ + public static Object toPrimitiveJavaArray(ColumnVector.Array array) { + DataType dt = array.data.dataType(); + if (dt instanceof IntegerType) { + int[] result = new int[array.length]; + ColumnVector data = array.data; + for (int i = 0; i < result.length; i++) { + if (data.getIsNull(array.offset + i)) { + throw new RuntimeException("Cannot handle NULL values."); + } + result[i] = data.getInt(array.offset + i); + } + return result; + } else { + throw new NotImplementedException(); + } + } + + private static void appendValue(ColumnVector dst, DataType t, Object o) { + if (o == null) { + dst.appendNull(); + } else { + if (t == DataTypes.ByteType) { + dst.appendByte(((Byte)o).byteValue()); + } else if (t == DataTypes.IntegerType) { + dst.appendInt(((Integer)o).intValue()); + } else if (t == DataTypes.LongType) { + dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.DoubleType) { + dst.appendDouble(((Double)o).doubleValue()); + } else if (t == DataTypes.StringType) { + byte[] b =((String)o).getBytes(); + dst.appendByteArray(b, 0, b.length); + } else { + throw new NotImplementedException("Type " + t); + } + } + } + + private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) { + if (t instanceof ArrayType) { + ArrayType at = (ArrayType)t; + if (src.isNullAt(fieldIdx)) { + dst.appendNull(); + } else { + List values = src.getList(fieldIdx); + dst.appendArray(values.size()); + for (Object o : values) { + appendValue(dst.arrayData(), at.elementType(), o); + } + } + } else if (t instanceof StructType) { + StructType st = (StructType)t; + if (src.isNullAt(fieldIdx)) { + dst.appendStruct(true); + } else { + dst.appendStruct(false); + Row c = src.getStruct(fieldIdx); + for (int i = 0; i < st.fields().length; i++) { + appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + } + } + } else { + appendValue(dst, t, src.get(fieldIdx)); + } + } + + /** + * Converts an iterator of rows into a single ColumnBatch. + */ + public static ColumnarBatch toBatch( + StructType schema, MemoryMode memMode, Iterator row) { + ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode); + int n = 0; + while (row.hasNext()) { + Row r = row.next(); + for (int i = 0; i < schema.fields().length; i++) { + appendValue(batch.column(i), schema.fields()[i].dataType(), r, i); + } + n++; + } + batch.setNumRows(n); + return batch; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 2c55f854c2419..d558dae50c227 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -21,12 +21,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -48,6 +46,7 @@ */ public final class ColumnarBatch { private static final int DEFAULT_BATCH_SIZE = 4 * 1024; + private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; private final StructType schema; private final int capacity; @@ -64,6 +63,10 @@ public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) { return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode); } + public static ColumnarBatch allocate(StructType type) { + return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE); + } + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) { return new ColumnarBatch(schema, maxRows, memMode); } @@ -82,25 +85,53 @@ public void close() { * Adapter class to interop with existing components that expect internal row. A lot of * performance is lost with this translation. */ - public final class Row extends InternalRow { + public static final class Row extends InternalRow { private int rowId; + private final ColumnarBatch parent; + private final int fixedLenRowSize; + + private Row(ColumnarBatch parent) { + this.parent = parent; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + } /** * Marks this row as being filtered out. This means a subsequent iteration over the rows * in this batch will not include this row. */ public final void markFiltered() { - ColumnarBatch.this.markFiltered(rowId); + parent.markFiltered(rowId); } @Override public final int numFields() { - return ColumnarBatch.this.numCols(); + return parent.numCols(); } @Override + /** + * Revisit this. This is expensive. + */ public final InternalRow copy() { - throw new NotImplementedException(); + UnsafeRow row = new UnsafeRow(parent.numCols()); + row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize); + for (int i = 0; i < parent.numCols(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = parent.schema.fields()[i].dataType(); + if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else { + throw new RuntimeException("Not implemented."); + } + } + } + return row; } @Override @@ -110,7 +141,7 @@ public final boolean anyNull() { @Override public final boolean isNullAt(int ordinal) { - return ColumnarBatch.this.column(ordinal).getIsNull(rowId); + return parent.column(ordinal).getIsNull(rowId); } @Override @@ -119,9 +150,7 @@ public final boolean getBoolean(int ordinal) { } @Override - public final byte getByte(int ordinal) { - throw new NotImplementedException(); - } + public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); } @Override public final short getShort(int ordinal) { @@ -130,13 +159,11 @@ public final short getShort(int ordinal) { @Override public final int getInt(int ordinal) { - return ColumnarBatch.this.column(ordinal).getInt(rowId); + return parent.column(ordinal).getInt(rowId); } @Override - public final long getLong(int ordinal) { - throw new NotImplementedException(); - } + public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); } @Override public final float getFloat(int ordinal) { @@ -145,7 +172,7 @@ public final float getFloat(int ordinal) { @Override public final double getDouble(int ordinal) { - return ColumnarBatch.this.column(ordinal).getDouble(rowId); + return parent.column(ordinal).getDouble(rowId); } @Override @@ -155,7 +182,8 @@ public final Decimal getDecimal(int ordinal, int precision, int scale) { @Override public final UTF8String getUTF8String(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } @Override @@ -170,12 +198,12 @@ public final CalendarInterval getInterval(int ordinal) { @Override public final InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); + return parent.column(ordinal).getStruct(rowId); } @Override public final ArrayData getArray(int ordinal) { - throw new NotImplementedException(); + return parent.column(ordinal).getArray(rowId); } @Override @@ -194,7 +222,7 @@ public final Object get(int ordinal, DataType dataType) { */ public Iterator rowIterator() { final int maxRows = ColumnarBatch.this.numRows(); - final Row row = new Row(); + final Row row = new Row(this); return new Iterator() { int rowId = 0; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 6180dd308e5e3..335124fd5a603 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,12 +18,18 @@ import java.nio.ByteOrder; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DoubleType; import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.commons.lang.NotImplementedException; + import org.apache.commons.lang.NotImplementedException; /** @@ -35,21 +41,21 @@ public final class OffHeapColumnVector extends ColumnVector { private long nulls; private long data; + // Set iff the type is array. + private long lengthData; + private long offsetData; + protected OffHeapColumnVector(int capacity, DataType type) { - super(capacity, type); + super(capacity, type, MemoryMode.OFF_HEAP); if (!ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) { throw new NotImplementedException("Only little endian is supported."); } + nulls = 0; + data = 0; + lengthData = 0; + offsetData = 0; - this.nulls = Platform.allocateMemory(capacity); - if (type instanceof IntegerType) { - this.data = Platform.allocateMemory(capacity * 4); - } else if (type instanceof DoubleType) { - this.data = Platform.allocateMemory(capacity * 8); - } else { - throw new RuntimeException("Unhandled " + type); - } - anyNullsSet = true; + reserveInternal(capacity); reset(); } @@ -67,8 +73,12 @@ public long nullsNativeAddress() { public final void close() { Platform.freeMemory(nulls); Platform.freeMemory(data); + Platform.freeMemory(lengthData); + Platform.freeMemory(offsetData); nulls = 0; data = 0; + lengthData = 0; + offsetData = 0; } // @@ -111,6 +121,33 @@ public final boolean getIsNull(int rowId) { return Platform.getByte(null, nulls + rowId) == 1; } + // + // APIs dealing with Bytes + // + + @Override + public final void putByte(int rowId, byte value) { + Platform.putByte(null, data + rowId, value); + + } + + @Override + public final void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, value); + } + } + + @Override + public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count); + } + + @Override + public final byte getByte(int rowId) { + return Platform.getByte(null, data + rowId); + } + // // APIs dealing with ints // @@ -145,6 +182,40 @@ public final int getInt(int rowId) { return Platform.getInt(null, data + 4 * rowId); } + // + // APIs dealing with Longs + // + + @Override + public final void putLong(int rowId, long value) { + Platform.putLong(null, data + 8 * rowId, value); + } + + @Override + public final void putLongs(int rowId, int count, long value) { + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8) { + Platform.putLong(null, offset, value); + } + } + + @Override + public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, + null, data + 8 * rowId, count * 8); + } + + @Override + public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 8 * rowId, count * 8); + } + + @Override + public final long getLong(int rowId) { + return Platform.getLong(null, data + 8 * rowId); + } + // // APIs dealing with doubles // @@ -178,4 +249,70 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { public final double getDouble(int rowId) { return Platform.getDouble(null, data + rowId * 8); } + + // + // APIs dealing with Arrays. + // + @Override + public final void putArray(int rowId, int offset, int length) { + assert(offset >= 0 && offset + length <= childColumns[0].capacity); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, offset); + } + + @Override + public final int getArrayLength(int rowId) { + return Platform.getInt(null, lengthData + 4 * rowId); + } + + @Override + public final int getArrayOffset(int rowId) { + return Platform.getInt(null, offsetData + 4 * rowId); + } + + // APIs dealing with ByteArrays + @Override + public final int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, result); + return result; + } + + @Override + public final void loadBytes(Array array) { + if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; + Platform.copyMemory( + null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); + array.byteArray = array.tmpByteArray; + array.byteArrayOffset = 0; + } + + @Override + public final void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Split out the slow path. + private final void reserveInternal(int newCapacity) { + if (this.resultArray != null) { + this.lengthData = + Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + this.offsetData = + Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + } else if (type instanceof ByteType) { + this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + } else if (type instanceof IntegerType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + } else if (type instanceof LongType || type instanceof DoubleType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); + Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + capacity = newCapacity; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 76d9956c3842f..8197fa11cd4c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,10 @@ */ package org.apache.spark.sql.execution.vectorized; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; import java.util.Arrays; /** @@ -37,19 +34,18 @@ public final class OnHeapColumnVector extends ColumnVector { private byte[] nulls; // Array for each type. Only 1 is populated for any type. + private byte[] byteData; private int[] intData; + private long[] longData; private double[] doubleData; + // Only set if type is Array. + private int[] arrayLengths; + private int[] arrayOffsets; + protected OnHeapColumnVector(int capacity, DataType type) { - super(capacity, type); - if (type instanceof IntegerType) { - this.intData = new int[capacity]; - } else if (type instanceof DoubleType) { - this.doubleData = new double[capacity]; - } else { - throw new RuntimeException("Unhandled " + type); - } - this.nulls = new byte[capacity]; + super(capacity, type, MemoryMode.ON_HEAP); + reserveInternal(capacity); reset(); } @@ -108,6 +104,32 @@ public final boolean getIsNull(int rowId) { return nulls[rowId] == 1; } + // + // APIs dealing with Bytes + // + + @Override + public final void putByte(int rowId, byte value) { + byteData[rowId] = value; + } + + @Override + public final void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = value; + } + } + + @Override + public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + System.arraycopy(src, srcIndex, byteData, rowId, count); + } + + @Override + public final byte getByte(int rowId) { + return byteData[rowId]; + } + // // APIs dealing with Ints // @@ -144,6 +166,43 @@ public final int getInt(int rowId) { return intData[rowId]; } + // + // APIs dealing with Longs + // + + @Override + public final void putLong(int rowId, long value) { + longData[rowId] = value; + } + + @Override + public final void putLongs(int rowId, int count, long value) { + for (int i = 0; i < count; ++i) { + longData[i + rowId] = value; + } + } + + @Override + public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + System.arraycopy(src, srcIndex, longData, rowId, count); + } + + @Override + public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i) { + longData[i + rowId] = Platform.getLong(src, srcOffset); + srcIndex += 8; + srcOffset += 8; + } + } + + @Override + public final long getLong(int rowId) { + return longData[rowId]; + } + + // // APIs dealing with doubles // @@ -173,4 +232,86 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { public final double getDouble(int rowId) { return doubleData[rowId]; } + + // + // APIs dealing with Arrays + // + + @Override + public final int getArrayLength(int rowId) { + return arrayLengths[rowId]; + } + @Override + public final int getArrayOffset(int rowId) { + return arrayOffsets[rowId]; + } + + @Override + public final void putArray(int rowId, int offset, int length) { + arrayOffsets[rowId] = offset; + arrayLengths[rowId] = length; + } + + @Override + public final void loadBytes(Array array) { + array.byteArray = byteData; + array.byteArrayOffset = array.offset; + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public final int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + arrayOffsets[rowId] = result; + arrayLengths[rowId] = length; + return result; + } + + @Override + public final void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Spilt this function out since it is the slow path. + private final void reserveInternal(int newCapacity) { + if (this.resultArray != null) { + int[] newLengths = new int[newCapacity]; + int[] newOffsets = new int[newCapacity]; + if (this.arrayLengths != null) { + System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + } + arrayLengths = newLengths; + arrayOffsets = newOffsets; + } else if (type instanceof ByteType) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; + } else if (type instanceof IntegerType) { + int[] newData = new int[newCapacity]; + if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + intData = newData; + } else if (type instanceof LongType) { + long[] newData = new long[newCapacity]; + if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + longData = newData; + } else if (type instanceof DoubleType) { + double[] newData = new double[newCapacity]; + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + doubleData = newData; + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + + byte[] newNulls = new byte[newCapacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + nulls = newNulls; + + capacity = newCapacity; + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 95c9550aebb0a..8a95359d9de25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -40,8 +40,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val rand = new Random(42) for (i <- 0 until 6) { - val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes) - val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes) + val keySchema = RandomDataGenerator.randomSchema(rand, rand.nextInt(10) + 1, keyTypes) + val valueSchema = RandomDataGenerator.randomSchema(rand, rand.nextInt(10) + 1, valueTypes) testKVSorter(keySchema, valueSchema, spill = i > 3) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index bfe944d835bd3..8efdf8adb042a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer +import scala.util.Random + import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -239,6 +241,26 @@ object ColumnarBatchBenchmark { Platform.freeMemory(buffer) } + // Adding values by appending, instead of putting. + val onHeapAppend = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.appendInt(i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + col.reset() + } + col.close + } + /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Int Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate @@ -253,6 +275,7 @@ object ColumnarBatchBenchmark { Column(off heap direct) 237.6 1379.12 1.05 X UnsafeRow (on heap) 414.6 790.35 0.60 X UnsafeRow (off heap) 487.2 672.58 0.51 X + Column On Heap Append 530.1 618.14 0.59 X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -265,6 +288,7 @@ object ColumnarBatchBenchmark { benchmark.addCase("Column(off heap direct)")(columnOffheapDirect) benchmark.addCase("UnsafeRow (on heap)")(unsafeRowOnheap) benchmark.addCase("UnsafeRow (off heap)")(unsafeRowOffheap) + benchmark.addCase("Column On Heap Append")(onHeapAppend) benchmark.run() } @@ -314,8 +338,60 @@ object ColumnarBatchBenchmark { benchmark.run() } + def stringAccess(iters: Long): Unit = { + val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + val random = new Random(0) + + def randomString(min: Int, max: Int): String = { + val len = random.nextInt(max - min) + min + val sb = new StringBuilder(len) + var i = 0 + while (i < len) { + sb.append(chars.charAt(random.nextInt(chars.length()))); + i += 1 + } + return sb.toString + } + + val minString = 3 + val maxString = 32 + val count = 4 * 1000 + + val data = Seq.fill(count)(randomString(minString, maxString)).map(_.getBytes).toArray + + def column(memoryMode: MemoryMode) = { i: Int => + val column = ColumnVector.allocate(count, BinaryType, memoryMode) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + column.putByteArray(i, data(i)) + i += 1 + } + i = 0 + while (i < count) { + sum += column.getByteArray(i).length + i += 1 + } + column.reset() + } + } + + /* + String Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------------- + On Heap 457.0 35.85 1.00 X + Off Heap 1206.0 13.59 0.38 X + */ + val benchmark = new Benchmark("String Read/Write", count * iters) + benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) + benchmark.addCase("Off Heap")(column(MemoryMode.OFF_HEAP)) + benchmark.run + } + def main(args: Array[String]): Unit = { intAccess(1024 * 40) booleanAccess(1024 * 40) + stringAccess(1024 * 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index d5e517c7f56be..215ca9ab6b770 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.execution.vectorized +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform class ColumnarBatchSuite extends SparkFunSuite { @@ -74,6 +75,45 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Byte Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[Byte] + + val column = ColumnVector.allocate(1024, ByteType, memMode) + var idx = 0 + + val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + column.putBytes(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putBytes(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + column.putByte(idx, 9) + reference += 9 + idx += 1 + + column.putBytes(idx, 3, 4) + reference += 4 + reference += 4 + reference += 4 + idx += 3 + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getByte(v._2), "MemoryMode" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getByte(null, addr + v._2)) + } + } + }} + } + test("Int Apis") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -142,6 +182,76 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Long Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Long] + + val column = ColumnVector.allocate(1024, LongType, memMode) + var idx = 0 + + val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + column.putLongs(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putLongs(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + val littleEndian = new Array[Byte](16) + littleEndian(0) = 7 + littleEndian(1) = 1 + littleEndian(8) = 6 + littleEndian(10) = 1 + + column.putLongsLittleEndian(idx, 1, littleEndian, 8) + column.putLongsLittleEndian(idx + 1, 1, littleEndian, 0) + reference += 6 + (1 << 16) + reference += 7 + (1 << 8) + idx += 2 + + column.putLongsLittleEndian(idx, 2, littleEndian, 0) + reference += 7 + (1 << 8) + reference += 6 + (1 << 16) + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextLong() + column.putLong(idx, v) + reference += v + idx += 1 + } else { + + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + column.putLongs(idx, n, n + 1) + var i = 0 + while (i < n) { + reference += (n + 1) + i += 1 + } + idx += n + } + } + + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getLong(v._2), "idx=" + v._2 + + " Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) + } + } + }} + } + test("Double APIs") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -209,15 +319,150 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("String APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[String] + + val column = ColumnVector.allocate(6, BinaryType, memMode) + assert(column.arrayData().elementsAppended == 0) + var idx = 0 + + val values = ("Hello" :: "abc" :: Nil).toArray + column.putByteArray(idx, values(0).getBytes, 0, values(0).getBytes().length) + reference += values(0) + idx += 1 + assert(column.arrayData().elementsAppended == 5) + + column.putByteArray(idx, values(1).getBytes, 0, values(1).getBytes().length) + reference += values(1) + idx += 1 + assert(column.arrayData().elementsAppended == 8) + + // Just put llo + val offset = column.putByteArray(idx, values(0).getBytes, 2, values(0).getBytes().length - 2) + reference += "llo" + idx += 1 + assert(column.arrayData().elementsAppended == 11) + + // Put the same "ll" at offset. This should not allocate more memory in the column. + column.putArray(idx, offset, 2) + reference += "ll" + idx += 1 + assert(column.arrayData().elementsAppended == 11) + + // Put a long string + val s = "abcdefghijklmnopqrstuvwxyz" + column.putByteArray(idx, (s + s).getBytes) + reference += (s + s) + idx += 1 + assert(column.arrayData().elementsAppended == 11 + (s + s).length) + + reference.zipWithIndex.foreach { v => + assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) + assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + "MemoryMode" + memMode) + } + + column.reset() + assert(column.arrayData().elementsAppended == 0) + }} + } + + test("Int Array") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + + // Fill the underlying data with all the arrays back to back. + val data = column.arrayData(); + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 0) + column.putArray(3, 3, 3) + + val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + assert(a1 === Array(0)) + assert(a2 === Array(1, 2)) + assert(a3 === Array.empty[Int]) + assert(a4 === Array(3, 4, 5)) + + // Verify the ArrayData APIs + assert(column.getArray(0).length == 1) + assert(column.getArray(0).getInt(0) == 0) + + assert(column.getArray(1).length == 2) + assert(column.getArray(1).getInt(0) == 1) + assert(column.getArray(1).getInt(1) == 2) + + assert(column.getArray(2).length == 0) + + assert(column.getArray(3).length == 3) + assert(column.getArray(3).getInt(0) == 3) + assert(column.getArray(3).getInt(1) == 4) + assert(column.getArray(3).getInt(2) == 5) + + // Add a longer array which requires resizing + column.reset + val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + assert(data.capacity == 10) + data.reserve(array.length) + assert(data.capacity == array.length * 2) + data.putInts(0, array.length, array, 0) + column.putArray(0, 0, array.length) + assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + === array) + }} + } + + test("Struct Column") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val schema = new StructType().add("int", IntegerType).add("double", DoubleType) + val column = ColumnVector.allocate(1024, schema, memMode) + + val c1 = column.getChildColumn(0) + val c2 = column.getChildColumn(1) + assert(c1.dataType() == IntegerType) + assert(c2.dataType() == DoubleType) + + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + + val s = column.getStruct(0) + assert(s.fields(0).getInt(0) == 123) + assert(s.fields(0).getInt(1) == 456) + assert(s.fields(1).getDouble(0) == 3.45) + assert(s.fields(1).getDouble(1) == 5.67) + + assert(s.getInt(0) == 123) + assert(s.getDouble(1) == 3.45) + + val s2 = column.getStruct(1) + assert(s2.getInt(0) == 456) + assert(s2.getDouble(1) == 5.67) + }} + } + test("ColumnarBatch basic") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType() .add("intCol", IntegerType) .add("doubleCol", DoubleType) .add("intCol2", IntegerType) + .add("string", BinaryType) val batch = ColumnarBatch.allocate(schema, memMode) - assert(batch.numCols() == 3) + assert(batch.numCols() == 4) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) assert(batch.capacity() > 0) @@ -227,10 +472,11 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.column(0).putInt(0, 1) batch.column(1).putDouble(0, 1.1) batch.column(2).putNull(0) + batch.column(3).putByteArray(0, "Hello".getBytes) batch.setNumRows(1) // Verify the results of the row. - assert(batch.numCols() == 3) + assert(batch.numCols() == 4) assert(batch.numRows() == 1) assert(batch.numValidRows() == 1) assert(batch.rowIterator().hasNext == true) @@ -241,6 +487,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") // Verify the iterator works correctly. val it = batch.rowIterator() @@ -251,6 +498,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) @@ -260,24 +508,27 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.numValidRows() == 0) assert(batch.rowIterator().hasNext == false) - // Reset and add 3 throws + // Reset and add 3 rows batch.reset() assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) assert(batch.rowIterator().hasNext == false) - // Add rows [NULL, 2.2, 2], [3, NULL, 3], [4, 4.4, 4] + // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] batch.column(0).putNull(0) batch.column(1).putDouble(0, 2.2) batch.column(2).putInt(0, 2) + batch.column(3).putByteArray(0, "abc".getBytes) batch.column(0).putInt(1, 3) batch.column(1).putNull(1) batch.column(2).putInt(1, 3) + batch.column(3).putByteArray(1, "".getBytes) batch.column(0).putInt(2, 4) batch.column(1).putDouble(2, 4.4) batch.column(2).putInt(2, 4) + batch.column(3).putByteArray(2, "world".getBytes) batch.setNumRows(3) def rowEquals(x: InternalRow, y: Row): Unit = { @@ -289,30 +540,152 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(x.isNullAt(2) == y.isNullAt(2)) if (!x.isNullAt(2)) assert(x.getInt(2) == y.getInt(2)) + + assert(x.isNullAt(3) == y.isNullAt(3)) + if (!x.isNullAt(3)) assert(x.getString(3) == y.getString(3)) } + // Verify assert(batch.numRows() == 3) assert(batch.numValidRows() == 3) val it2 = batch.rowIterator() - rowEquals(it2.next(), Row(null, 2.2, 2)) - rowEquals(it2.next(), Row(3, null, 3)) - rowEquals(it2.next(), Row(4, 4.4, 4)) + rowEquals(it2.next(), Row(null, 2.2, 2, "abc")) + rowEquals(it2.next(), Row(3, null, 3, "")) + rowEquals(it2.next(), Row(4, 4.4, 4, "world")) assert(!it.hasNext) // Filter out some rows and verify batch.markFiltered(1) assert(batch.numValidRows() == 2) val it3 = batch.rowIterator() - rowEquals(it3.next(), Row(null, 2.2, 2)) - rowEquals(it3.next(), Row(4, 4.4, 4)) + rowEquals(it3.next(), Row(null, 2.2, 2, "abc")) + rowEquals(it3.next(), Row(4, 4.4, 4, "world")) assert(!it.hasNext) batch.markFiltered(2) assert(batch.numValidRows() == 1) val it4 = batch.rowIterator() - rowEquals(it4.next(), Row(null, 2.2, 2)) + rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) batch.close }} } + + + private def doubleEquals(d1: Double, d2: Double): Boolean = { + if (d1.isNaN && d2.isNaN) { + true + } else { + d1 == d2 + } + } + + private def compareStruct(fields: Seq[StructField], r1: InternalRow, r2: Row, seed: Long) { + fields.zipWithIndex.foreach { v => { + assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) + if (!r1.isNullAt(v._2)) { + v._1.dataType match { + case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) + case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), + "Seed = " + seed) + case StringType => + assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case ArrayType(childType, n) => + val a1 = r1.getArray(v._2).array + val a2 = r2.getList(v._2).toArray + assert(a1.length == a2.length, "Seed = " + seed) + childType match { + case DoubleType => { + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]), + "Seed = " + seed) + i += 1 + } + } + case _ => assert(a1 === a2, "Seed = " + seed) + } + case StructType(childFields) => + compareStruct(childFields, r1.getStruct(v._2, fields.length), r2.getStruct(v._2), seed) + case _ => + throw new NotImplementedError("Not implemented " + v._1.dataType) + } + } + }} + } + + test("Convert rows") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val rows = Row(1, 2L, "a", 1.2, 'b'.toByte) :: Row(4, 5L, "cd", 2.3, 'a'.toByte) :: Nil + val schema = new StructType() + .add("i1", IntegerType) + .add("l2", LongType) + .add("string", StringType) + .add("d", DoubleType) + .add("b", ByteType) + + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + assert(batch.numRows() == 2) + assert(batch.numCols() == 5) + + val it = batch.rowIterator() + val referenceIt = rows.iterator + while (it.hasNext) { + compareStruct(schema, it.next(), referenceIt.next(), 0) + } + batch.close() + } + }} + + /** + * This test generates a random schema data, serializes it to column batches and verifies the + * results. + */ + def testRandomRows(flatSchema: Boolean, numFields: Int) { + // TODO: add remaining types. Figure out why StringType doesn't work on jenkins. + val types = Array(ByteType, IntegerType, LongType, DoubleType) + val seed = System.nanoTime() + val NUM_ROWS = 500 + val NUM_ITERS = 1000 + val random = new Random(seed) + var i = 0 + while (i < NUM_ITERS) { + val schema = if (flatSchema) { + RandomDataGenerator.randomSchema(random, numFields, types) + } else { + RandomDataGenerator.randomNestedSchema(random, numFields, types) + } + val rows = mutable.ArrayBuffer.empty[Row] + var j = 0 + while (j < NUM_ROWS) { + val row = RandomDataGenerator.randomRow(random, schema) + rows += row + j += 1 + } + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + assert(batch.numRows() == NUM_ROWS) + + val it = batch.rowIterator() + val referenceIt = rows.iterator + var k = 0 + while (it.hasNext) { + compareStruct(schema, it.next(), referenceIt.next(), seed) + k += 1 + } + batch.close() + }} + i += 1 + } + } + + test("Random flat schema") { + testRandomRows(true, 10) + } + + test("Random nested schema") { + testRandomRows(false, 30) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 76b36aa89182e..3e4cf3f79e57c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import scala.util.Random import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -879,7 +880,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te RandomDataGenerator.forType( dataType = schemaForGenerator, nullable = true, - seed = Some(System.nanoTime())) + new Random(System.nanoTime())) val dataGenerator = maybeDataGenerator .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 3f9ecf6965e1d..1a4b3ece72a66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.collection.JavaConverters._ +import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -122,7 +123,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val dataGenerator = RandomDataGenerator.forType( dataType = dataType, nullable = true, - seed = Some(System.nanoTime()) + new Random(System.nanoTime()) ).getOrElse { fail(s"Failed to create data generator for schema $dataType") } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 0d6b215fe5aaf..b29bf6a464b30 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -105,6 +105,17 @@ public static void freeMemory(long address) { _UNSAFE.freeMemory(address); } + public static long reallocateMemory(long address, long oldSize, long newSize) { + long newMemory = _UNSAFE.allocateMemory(newSize); + copyMemory(null, address, null, newMemory, oldSize); + freeMemory(address); + return newMemory; + } + + public static void setMemory(long address, byte value, long size) { + _UNSAFE.setMemory(address, size, value); + } + public static void copyMemory( Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy From b72611f20a03c790b6fd341b6ffdb3b5437609ee Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 26 Jan 2016 17:59:05 -0800 Subject: [PATCH 1707/2089] [SPARK-7780][MLLIB] intercept in logisticregressionwith lbfgs should not be regularized The intercept in Logistic Regression represents a prior on categories which should not be regularized. In MLlib, the regularization is handled through Updater, and the Updater penalizes all the components without excluding the intercept which resulting poor training accuracy with regularization. The new implementation in ML framework handles this properly, and we should call the implementation in ML from MLlib since majority of users are still using MLlib api. Note that both of them are doing feature scalings to improve the convergence, and the only difference is ML version doesn't regularize the intercept. As a result, when lambda is zero, they will converge to the same solution. Previously partially reviewed at https://github.com/apache/spark/pull/6386#issuecomment-168781424 re-opening for dbtsai to review. Author: Holden Karau Author: Holden Karau Closes #10788 from holdenk/SPARK-7780-intercept-in-logisticregressionwithLBFGS-should-not-be-regularized. --- .../classification/LogisticRegression.scala | 36 ++++++-- .../classification/LogisticRegression.scala | 82 ++++++++++++++++++- .../spark/mllib/optimization/LBFGS.scala | 28 +++++++ .../GeneralizedLinearAlgorithm.scala | 34 ++++---- .../ml/classification/OneVsRestSuite.scala | 2 +- .../LogisticRegressionSuite.scala | 25 ++++-- 6 files changed, 179 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c98a78a515dc3..9b2340a1f16fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -247,15 +247,27 @@ class LogisticRegression @Since("1.2.0") ( @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds - override protected def train(dataset: DataFrame): LogisticRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. + private var optInitialModel: Option[LogisticRegressionModel] = None + + /** @group setParam */ + private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { + this.optInitialModel = Some(model) + this + } + + override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + train(dataset, handlePersistence) + } + + protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): + LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) val (summarizer, labelSummarizer) = { @@ -343,7 +355,21 @@ class LogisticRegression @Since("1.2.0") ( val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) - if ($(fitIntercept)) { + if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) { + val vec = optInitialModel.get.coefficients + logWarning( + s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}") + } + + if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) { + val initialCoefficientsWithInterceptArray = initialCoefficientsWithIntercept.toArray + optInitialModel.get.coefficients.foreachActive { case (index, value) => + initialCoefficientsWithInterceptArray(index) = value + } + if ($(fitIntercept)) { + initialCoefficientsWithInterceptArray(numFeatures) == optInitialModel.get.intercept + } + } else if ($(fitIntercept)) { /* For binary logistic regression, when we initialize the coefficients as zeros, it will converge faster if we initialize the intercept such that @@ -434,7 +460,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { */ @Since("1.4.0") @Experimental -class LogisticRegressionModel private[ml] ( +class LogisticRegressionModel private[spark] ( @Since("1.4.0") override val uid: String, @Since("1.6.0") val coefficients: Vector, @Since("1.3.0") val intercept: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2a7697b5a79cc..bf68e3edd7ed6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -19,15 +19,18 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.classification.impl.GLMClassificationModel -import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} +import org.apache.spark.mllib.util.MLUtils.appendBias import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel /** * Classification model trained using Multinomial/Binary Logistic Regression. @@ -332,6 +335,13 @@ object LogisticRegressionWithSGD { * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. + * + * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization + * penalty to all elements including the intercept. If this is called with one of + * standard updaters (L1Updater, or SquaredL2Updater) this is translated + * into a call to ml.LogisticRegression, otherwise this will use the existing mllib + * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the + * intercept. */ @Since("1.1.0") class LogisticRegressionWithLBFGS @@ -374,4 +384,72 @@ class LogisticRegressionWithLBFGS new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) } } + + /** + * Run Logistic Regression with the configured parameters on an input RDD + * of LabeledPoint entries. + * + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * If using ml implementation, uses ml code to generate initial weights. + */ + override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = { + run(input, generateInitialWeights(input), userSuppliedWeights = false) + } + + /** + * Run Logistic Regression with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + * + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * Uses user provided weights. + */ + override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { + run(input, initialWeights, userSuppliedWeights = true) + } + + private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean): + LogisticRegressionModel = { + // ml's Logisitic regression only supports binary classifcation currently. + if (numOfLinearPredictor == 1) { + def runWithMlLogisitcRegression(elasticNetParam: Double) = { + // Prepare the ml LogisticRegression based on our settings + val lr = new org.apache.spark.ml.classification.LogisticRegression() + lr.setRegParam(optimizer.getRegParam()) + lr.setElasticNetParam(elasticNetParam) + lr.setStandardization(useFeatureScaling) + if (userSuppliedWeights) { + val uid = Identifiable.randomUID("logreg-static") + lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel( + uid, initialWeights, 1.0)) + } + lr.setFitIntercept(addIntercept) + lr.setMaxIter(optimizer.getNumIterations()) + lr.setTol(optimizer.getConvergenceTol()) + // Convert our input into a DataFrame + val sqlContext = new SQLContext(input.context) + import sqlContext.implicits._ + val df = input.toDF() + // Determine if we should cache the DF + val handlePersistence = input.getStorageLevel == StorageLevel.NONE + // Train our model + val mlLogisticRegresionModel = lr.train(df, handlePersistence) + // convert the model + val weights = Vectors.dense(mlLogisticRegresionModel.coefficients.toArray) + createModel(weights, mlLogisticRegresionModel.intercept) + } + optimizer.getUpdater() match { + case x: SquaredL2Updater => runWithMlLogisitcRegression(1.0) + case x: L1Updater => runWithMlLogisitcRegression(0.0) + case _ => super.run(input, initialWeights) + } + } else { + super.run(input, initialWeights) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index efedc112d380e..a5bd77e6bee91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -69,6 +69,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /* + * Get the convergence tolerance of iterations. + */ + private[mllib] def getConvergenceTol(): Double = { + this.convergenceTol + } + /** * Set the maximal number of iterations for L-BFGS. Default 100. * @deprecated use [[LBFGS#setNumIterations]] instead @@ -86,6 +93,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Get the maximum number of iterations for L-BFGS. Defaults to 100. + */ + private[mllib] def getNumIterations(): Int = { + this.maxNumIterations + } + /** * Set the regularization parameter. Default 0.0. */ @@ -94,6 +108,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Get the regularization parameter. + */ + private[mllib] def getRegParam(): Double = { + this.regParam + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for L-BFGS. @@ -113,6 +134,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Returns the updater, limited to internal use. + */ + private[mllib] def getUpdater(): Updater = { + updater + } + override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { val (weights, _) = LBFGS.runLBFGS( data, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index e60edc675c83f..73da899a0edd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -140,7 +140,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * translated back to resulting model weights, so it's transparent to users. * Note: This technique is used in both libsvm and glmnet packages. Default false. */ - private var useFeatureScaling = false + private[mllib] var useFeatureScaling = false /** * The dimension of training features. @@ -196,12 +196,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } /** - * Run the algorithm with the configured parameters on an input - * RDD of LabeledPoint entries. - * + * Generate the initial weights when the user does not supply them */ - @Since("0.8.0") - def run(input: RDD[LabeledPoint]): M = { + protected def generateInitialWeights(input: RDD[LabeledPoint]): Vector = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() } @@ -217,16 +214,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always * have the intercept as part of weights to have consistent design. */ - val initialWeights = { - if (numOfLinearPredictor == 1) { - Vectors.zeros(numFeatures) - } else if (addIntercept) { - Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) - } else { - Vectors.zeros(numFeatures * numOfLinearPredictor) - } + if (numOfLinearPredictor == 1) { + Vectors.zeros(numFeatures) + } else if (addIntercept) { + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) + } else { + Vectors.zeros(numFeatures * numOfLinearPredictor) } - run(input, initialWeights) + } + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + * + */ + @Since("0.8.0") + def run(input: RDD[LabeledPoint]): M = { + run(input, generateInitialWeights(input)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index d7983f92a3483..445e50d867e15 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -168,7 +168,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 8d14bb6572155..8fef1316cd216 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -215,6 +216,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w // Test if we can correctly learn A, B where Y = logistic(A + B*X) test("logistic regression with LBFGS") { + val updaters: List[Updater] = List(new SquaredL2Updater(), new L1Updater()) + updaters.foreach(testLBFGS) + } + + private def testLBFGS(myUpdater: Updater): Unit = { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -223,7 +229,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + // Override the updater + class LogisticRegressionWithLBFGSCustomUpdater + extends LogisticRegressionWithLBFGS { + override val optimizer = + new LBFGS(new LogisticGradient, myUpdater) + } + + val lr = new LogisticRegressionWithLBFGSCustomUpdater().setIntercept(true) val model = lr.run(testRDD) @@ -396,10 +410,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) // Training data with different scales without feature standardization - // will not yield the same result in the scaled space due to poor - // convergence rate. - assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) - assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + // should still converge quickly since the model still uses standardization but + // simply modifies the regularization function. See regParamL1Fun and related + // inside of LogisticRegression + assert(modelB1.weights(0) ~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) ~== modelB3.weights(0) * 1.0E6 absTol 0.1) } test("multinomial logistic regression with LBFGS") { From e7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 26 Jan 2016 19:29:47 -0800 Subject: [PATCH 1708/2089] [SPARK-12903][SPARKR] Add covar_samp and covar_pop for SparkR Add ```covar_samp``` and ```covar_pop``` for SparkR. Should we also provide ```cov``` alias for ```covar_samp```? There is ```cov``` implementation at stats.R which masks ```stats::cov``` already, but may bring to breaking API change. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10829 from yanboliang/spark-12903. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 58 +++++++++++++++++++++++ R/pkg/R/generics.R | 10 +++- R/pkg/R/stats.R | 3 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 + 5 files changed, 73 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2cc1544bef080..f194a46303e0d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -35,6 +35,8 @@ exportMethods("arrange", "count", "cov", "corr", + "covar_samp", + "covar_pop", "crosstab", "describe", "dim", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9bb7876b384ce..8f8651c295eec 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -275,6 +275,64 @@ setMethod("corr", signature(x = "Column"), column(jc) }) +#' cov +#' +#' Compute the sample covariance between two expressions. +#' +#' @rdname cov +#' @name cov +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' cov(df$c, df$d) +#' cov("c", "d") +#' covar_samp(df$c, df$d) +#' covar_samp("c", "d") +#' } +setMethod("cov", signature(x = "characterOrColumn"), + function(x, col2) { + stopifnot(is(class(col2), "characterOrColumn")) + covar_samp(x, col2) + }) + +#' @rdname cov +#' @name covar_samp +setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_samp", col1, col2) + column(jc) + }) + +#' covar_pop +#' +#' Compute the population covariance between two expressions. +#' +#' @rdname covar_pop +#' @name covar_pop +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' covar_pop(df$c, df$d) +#' covar_pop("c", "d") +#' } +setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_pop", col1, col2) + column(jc) + }) + #' cos #' #' Computes the cosine of the given value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 04784d51566cb..2dba71abec689 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -418,12 +418,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @rdname statfunctions #' @export -setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) +setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname statfunctions #' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) +#' @rdname statfunctions +#' @export +setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) + +#' @rdname statfunctions +#' @export +setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) + #' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index d17cce9c756e2..2e8076843f08a 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -66,8 +66,9 @@ setMethod("crosstab", #' cov <- cov(df, "title", "gender") #' } setMethod("cov", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2) { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "cov", col1, col2) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index b52a11fb1a348..7b5713720df87 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -996,6 +996,8 @@ test_that("column functions", { c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() c16 <- is.nan(c) + isnan(c) + isNaN(c) + c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") + c18 <- covar_pop(c, c1) + covar_pop("c", "c1") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From ce38a35b764397fcf561ac81de6da96579f5c13e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 20:12:34 -0800 Subject: [PATCH 1709/2089] [SPARK-12935][SQL] DataFrame API for Count-Min Sketch This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs. Author: Cheng Lian Closes #10911 from liancheng/cms-df-api. --- .../apache/spark/util/sketch/BloomFilter.java | 10 ++- .../spark/util/sketch/CountMinSketch.java | 26 +++--- .../spark/util/sketch/CountMinSketchImpl.java | 56 ++++++++----- sql/core/pom.xml | 5 ++ .../spark/sql/DataFrameStatFunctions.scala | 81 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 28 ++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 36 +++++++++ 7 files changed, 205 insertions(+), 37 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 00378d58518f6..d392fb187ad65 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -47,10 +47,12 @@ public abstract class BloomFilter { public enum Version { /** * {@code BloomFilter} binary format version 1 (all values written in big-endian order): - * - Version number, always 1 (32 bit) - * - Total number of words of the underlying bit array (32 bit) - * - The words/longs (numWords * 64 bit) - * - Number of hash functions (32 bit) + *
      + *
    • Version number, always 1 (32 bit)
    • + *
    • Total number of words of the underlying bit array (32 bit)
    • + *
    • The words/longs (numWords * 64 bit)
    • + *
    • Number of hash functions (32 bit)
    • + *
    */ V1(1); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 00c0b1b9e2db8..5692e574d4c7e 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -59,16 +59,22 @@ abstract public class CountMinSketch { public enum Version { /** * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): - * - Version number, always 1 (32 bit) - * - Total count of added items (64 bit) - * - Depth (32 bit) - * - Width (32 bit) - * - Hash functions (depth * 64 bit) - * - Count table - * - Row 0 (width * 64 bit) - * - Row 1 (width * 64 bit) - * - ... - * - Row depth - 1 (width * 64 bit) + *
      + *
    • Version number, always 1 (32 bit)
    • + *
    • Total count of added items (64 bit)
    • + *
    • Depth (32 bit)
    • + *
    • Width (32 bit)
    • + *
    • Hash functions (depth * 64 bit)
    • + *
    • + * Count table + *
        + *
      • Row 0 (width * 64 bit)
      • + *
      • Row 1 (width * 64 bit)
      • + *
      • ...
      • + *
      • Row {@code depth - 1} (width * 64 bit)
      • + *
      + *
    • + *
    */ V1(1); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index d08809605a932..8cc29e4076307 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -21,13 +21,16 @@ import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.OutputStream; +import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; -class CountMinSketchImpl extends CountMinSketch { - public static final long PRIME_MODULUS = (1L << 31) - 1; +class CountMinSketchImpl extends CountMinSketch implements Serializable { + private static final long PRIME_MODULUS = (1L << 31) - 1; private int depth; private int width; @@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch { private double eps; private double confidence; + private CountMinSketchImpl() { + } + CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; this.width = width; @@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch { initTablesWith(depth, width, seed); } - CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) { - this.depth = depth; - this.width = width; - this.eps = 2.0 / width; - this.confidence = 1 - 1 / Math.pow(2, depth); - this.hashA = hashA; - this.table = table; - this.totalCount = totalCount; - } - @Override public boolean equals(Object other) { if (other == this) { @@ -325,27 +321,43 @@ public void writeTo(OutputStream out) throws IOException { } public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + CountMinSketchImpl sketch = new CountMinSketchImpl(); + sketch.readFrom0(in); + return sketch; + } + + private void readFrom0(InputStream in) throws IOException { DataInputStream dis = new DataInputStream(in); - // Ignores version number - dis.readInt(); + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")"); + } - long totalCount = dis.readLong(); - int depth = dis.readInt(); - int width = dis.readInt(); + this.totalCount = dis.readLong(); + this.depth = dis.readInt(); + this.width = dis.readInt(); + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); - long hashA[] = new long[depth]; + this.hashA = new long[depth]; for (int i = 0; i < depth; ++i) { - hashA[i] = dis.readLong(); + this.hashA[i] = dis.readLong(); } - long table[][] = new long[depth][width]; + this.table = new long[depth][width]; for (int i = 0; i < depth; ++i) { for (int j = 0; j < width; ++j) { - table[i][j] = dis.readLong(); + this.table[i][j] = dis.readLong(); } } + } + + private void writeObject(ObjectOutputStream out) throws IOException { + this.writeTo(out); + } - return new CountMinSketchImpl(depth, width, totalCount, hashA, table); + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.readFrom0(in); } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 31b364f351d56..4bb55f6b7f739 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -42,6 +42,11 @@ 1.5.6 jar + + org.apache.spark + spark-sketch_2.10 + ${project.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index e66aa5f947181..465b12bb59d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.CountMinSketch /** * :: Experimental :: @@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), depth, width, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch( + colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), eps, confidence, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(depth, width, seed)) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) + } + + private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require( + colType == StringType || colType.isInstanceOf[IntegralType], + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." + ) + + singleCol.rdd.aggregate(zero)( + (sketch: CountMinSketch, row: Row) => { + sketch.add(row.get(0)) + sketch + }, + + (sketch1: CountMinSketch, sketch2: CountMinSketch) => { + sketch1.mergeInPlace(sketch2) + } + ) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ac1607ba3521a..9cf94e72d34e2 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -35,9 +35,10 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; -import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.CountMinSketch; +import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { @@ -321,4 +322,29 @@ public void testTextLoad() { Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); } + + @Test + public void testCountMinSketch() { + DataFrame df = context.range(1000); + + CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); + Assert.assertEquals(sketch1.totalCount(), 1000); + Assert.assertEquals(sketch1.depth(), 10); + Assert.assertEquals(sketch1.width(), 20); + + CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42); + Assert.assertEquals(sketch2.totalCount(), 1000); + Assert.assertEquals(sketch2.depth(), 10); + Assert.assertEquals(sketch2.width(), 20); + + CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42); + Assert.assertEquals(sketch3.totalCount(), 1000); + Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3); + + CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42); + Assert.assertEquals(sketch4.totalCount(), 1000); + Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 63ad6c439a870..8f3ea5a2860ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql import java.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DoubleType class DataFrameStatSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { sampled.groupBy("key").count().orderBy("key"), Seq(Row(0, 6), Row(1, 11))) } + + // This test case only verifies that `DataFrame.countMinSketch()` methods do return + // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in + // `CountMinSketchSuite` in project spark-sketch. + test("countMinSketch") { + val df = sqlContext.range(1000) + + val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) + assert(sketch1.totalCount() === 1000) + assert(sketch1.depth() === 10) + assert(sketch1.width() === 20) + + val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42) + assert(sketch2.totalCount() === 1000) + assert(sketch2.depth() === 10) + assert(sketch2.width() === 20) + + val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch3.totalCount() === 1000) + assert(sketch3.relativeError() === 0.001) + assert(sketch3.confidence() === 0.99 +- 5e-3) + + val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch4.totalCount() === 1000) + assert(sketch4.relativeError() === 0.001 +- 1e04) + assert(sketch4.confidence() === 0.99 +- 5e-3) + + intercept[IllegalArgumentException] { + df.select('id cast DoubleType as 'id) + .stat + .countMinSketch('id, depth = 10, width = 20, seed = 42) + } + } } From 58f5d8c1da6feeb598aa5f74ffe1593d4839d11d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 20:30:13 -0800 Subject: [PATCH 1710/2089] [SPARK-12728][SQL] Integrates SQL generation with native view This PR is a follow-up of PR #10541. It integrates the newly introduced SQL generation feature with native view to make native view canonical. In this PR, a new SQL option `spark.sql.nativeView.canonical` is added. When this option and `spark.sql.nativeView` are both `true`, Spark SQL tries to handle `CREATE VIEW` DDL statements using SQL query strings generated from view definition logical plans. If we failed to map the plan to SQL, we fallback to the original native view approach. One important issue this PR fixes is that, now we can use CTE when defining a view. Originally, when native view is turned on, we wrap the view definition text with an extra `SELECT`. However, HiveQL parser doesn't allow CTE appearing as a subquery. Namely, something like this is disallowed: ```sql SELECT n FROM ( WITH w AS (SELECT 1 AS n) SELECT * FROM w ) v ``` This PR fixes this issue because the extra `SELECT` is no longer needed (also, CTE expressions are inlined as subqueries during analysis phase, thus there won't be CTE expressions in the generated SQL query string). Author: Cheng Lian Author: Yin Huai Closes #10733 from liancheng/spark-12728.integrate-sql-gen-with-native-view. --- .../scala/org/apache/spark/sql/SQLConf.scala | 10 ++ .../apache/spark/sql/test/SQLTestUtils.scala | 13 ++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 33 ++-- .../hive/execution/CreateViewAsSelect.scala | 95 ++++++++---- .../sql/hive/LogicalPlanToSQLSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 142 ++++++++++++------ 6 files changed, 200 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2d664d3ee691b..c9ba6700998c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -367,6 +367,14 @@ private[spark] object SQLConf { "possible, or you may get wrong result.", isPublic = false) + val CANONICAL_NATIVE_VIEW = booleanConf("spark.sql.nativeView.canonical", + defaultValue = Some(true), + doc = "When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " + + "CREATE VIEW statement using SQL query string generated from view definition logical " + + "plan. If the logical plan doesn't have a SQL representation, we fallback to the " + + "original native view implementation.", + isPublic = false) + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), doc = "The name of internal column for storing raw/un-parsed JSON records that fail to parse.") @@ -550,6 +558,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + private[spark] def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) private[spark] def subexpressionEliminationEnabled: Boolean = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 5f73d71d4510a..d48143762cac0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -154,9 +154,22 @@ private[sql] trait SQLTestUtils } } + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + sqlContext.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + /** * Creates a temporary database and switches current database to it before executing `f`. This * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. */ protected def withTempDatabase(f: String => Unit): Unit = { val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 80e45d5162801..a9c0e9ab7caef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -579,25 +579,24 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p - case CreateViewAsSelect(table, child, allowExisting, replace, sql) => - if (conf.nativeView) { - if (allowExisting && replace) { - throw new AnalysisException( - "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") - } + case CreateViewAsSelect(table, child, allowExisting, replace, sql) if conf.nativeView => + if (allowExisting && replace) { + throw new AnalysisException( + "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") + } - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) - execution.CreateViewAsSelect( - table.copy( - specifiedDatabase = Some(dbName), - name = tblName), - child.output, - allowExisting, - replace) - } else { - HiveNativeCommand(sql) - } + execution.CreateViewAsSelect( + table.copy( + specifiedDatabase = Some(dbName), + name = tblName), + child, + allowExisting, + replace) + + case CreateViewAsSelect(table, child, allowExisting, replace, sql) => + HiveNativeCommand(sql) case p @ CreateTableAsSelect(table, child, allowExisting) => val schema = if (table.schema.nonEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 6e288afbb4d2d..31bda56e8a163 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, SQLBuilder} import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** @@ -32,10 +33,12 @@ import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} // from Hive and may not work for some cases like create view on self join. private[hive] case class CreateViewAsSelect( tableDesc: HiveTable, - childSchema: Seq[Attribute], + child: LogicalPlan, allowExisting: Boolean, orReplace: Boolean) extends RunnableCommand { + private val childSchema = child.output + assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) assert(tableDesc.viewText.isDefined) @@ -44,55 +47,83 @@ private[hive] case class CreateViewAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableIdentifier)) { - if (allowExisting) { - // view already exists, will do nothing, to keep consistent with Hive - } else if (orReplace) { - hiveContext.catalog.client.alertView(prepareTable()) - } else { + hiveContext.catalog.tableExists(tableIdentifier) match { + case true if allowExisting => + // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view + // already exists. + + case true if orReplace => + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` + hiveContext.catalog.client.alertView(prepareTable(sqlContext)) + + case true => + // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already + // exists. throw new AnalysisException(s"View $tableIdentifier already exists. " + "If you want to update the view definition, please use ALTER VIEW AS or " + "CREATE OR REPLACE VIEW AS") - } - } else { - hiveContext.catalog.client.createView(prepareTable()) + + case false => + hiveContext.catalog.client.createView(prepareTable(sqlContext)) } Seq.empty[Row] } - private def prepareTable(): HiveTable = { - // setup column types according to the schema of child. - val schema = if (tableDesc.schema == Nil) { - childSchema.map { attr => - HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) - } + private def prepareTable(sqlContext: SQLContext): HiveTable = { + val expandedText = if (sqlContext.conf.canonicalView) { + rebuildViewQueryString(sqlContext).getOrElse(wrapViewTextWithSelect) } else { - childSchema.zip(tableDesc.schema).map { case (attr, col) => - HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + wrapViewTextWithSelect + } + + val viewSchema = { + if (tableDesc.schema.isEmpty) { + childSchema.map { attr => + HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + } + } else { + childSchema.zip(tableDesc.schema).map { case (attr, col) => + HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + } } } - val columnNames = childSchema.map(f => verbose(f.name)) + tableDesc.copy(schema = viewSchema, viewText = Some(expandedText)) + } + private def wrapViewTextWithSelect: String = { // When user specified column names for view, we should create a project to do the renaming. // When no column name specified, we still need to create a project to declare the columns // we need, to make us more robust to top level `*`s. - val projectList = if (tableDesc.schema == Nil) { - columnNames.mkString(", ") - } else { - columnNames.zip(tableDesc.schema.map(f => verbose(f.name))).map { - case (name, alias) => s"$name AS $alias" - }.mkString(", ") + val viewOutput = { + val columnNames = childSchema.map(f => quote(f.name)) + if (tableDesc.schema.isEmpty) { + columnNames.mkString(", ") + } else { + columnNames.zip(tableDesc.schema.map(f => quote(f.name))).map { + case (name, alias) => s"$name AS $alias" + }.mkString(", ") + } } - val viewName = verbose(tableDesc.name) - - val expandedText = s"SELECT $projectList FROM (${tableDesc.viewText.get}) $viewName" + val viewText = tableDesc.viewText.get + val viewName = quote(tableDesc.name) + s"SELECT $viewOutput FROM ($viewText) $viewName" + } - tableDesc.copy(schema = schema, viewText = Some(expandedText)) + private def rebuildViewQueryString(sqlContext: SQLContext): Option[String] = { + val logicalPlan = if (tableDesc.schema.isEmpty) { + child + } else { + val projectList = childSchema.zip(tableDesc.schema).map { + case (attr, col) => Alias(attr, col.name)() + } + sqlContext.executePlan(Project(projectList, child)).analyzed + } + new SQLBuilder(logicalPlan, sqlContext).toSQL } // escape backtick with double-backtick in column name and wrap it with backtick. - private def verbose(name: String) = s"`${name.replaceAll("`", "``")}`" + private def quote(name: String) = s"`${name.replaceAll("`", "``")}`" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 261a4746f4287..1f731db26f387 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -147,7 +147,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // TODO Enable this // Query plans transformed by DistinctAggregationRewriter are not recognized yet - ignore("distinct and non-distinct aggregation") { + ignore("multi-distinct columns") { checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 683008960aa28..9e53d8a81e753 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1319,67 +1319,119 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("correctly handle CREATE OR REPLACE VIEW") { - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt", "jt2") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") - checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + Seq(true, false).foreach { enabled => + val prefix = (if (enabled) "With" else "Without") + " canonical native view: " + test(s"$prefix correctly handle CREATE OR REPLACE VIEW") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + + sql("DROP VIEW testView") + + val e = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + } + assert(e.message.contains("not allowed to define a view")) + } + } + } - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") - // make sure the view has been changed. - checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + test(s"$prefix correctly handle ALTER VIEW") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt", "jt2") { + withView("testView") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("ALTER VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + } + } + } + } - sql("DROP VIEW testView") + test(s"$prefix create hive view for json table") { + // json table is not hive-compatible, make sure the new flag fix it. + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt") { + withView("testView") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + } + } + } + } - val e = intercept[AnalysisException] { - sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + test(s"$prefix create hive view for partitioned parquet table") { + // partitioned parquet table is not hive-compatible, make sure the new flag fix it. + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("parTable") { + withView("testView") { + val df = Seq(1 -> "a").toDF("i", "j") + df.write.format("parquet").partitionBy("i").saveAsTable("parTable") + sql("CREATE VIEW testView AS SELECT i, j FROM parTable") + checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) + } } - assert(e.message.contains("not allowed to define a view")) } } } - test("correctly handle ALTER VIEW") { - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt", "jt2") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") - - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("ALTER VIEW testView AS SELECT * FROM jt2") - // make sure the view has been changed. - checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) - - sql("DROP VIEW testView") + test("CTE within view") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withView("cte_view") { + sql("CREATE VIEW cte_view AS WITH w AS (SELECT 1 AS n) SELECT n FROM w") + checkAnswer(sql("SELECT * FROM cte_view"), Row(1)) } } } - test("create hive view for json table") { - // json table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") - checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) - sql("DROP VIEW testView") + test("Using view after switching current database") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withView("v") { + sql("CREATE VIEW v AS SELECT * FROM src") + withTempDatabase { db => + activateDatabase(db) { + // Should look up table `src` in database `default`. + checkAnswer(sql("SELECT * FROM default.v"), sql("SELECT * FROM default.src")) + + // The new `src` table shouldn't be scanned. + sql("CREATE TABLE src(key INT, value STRING)") + checkAnswer(sql("SELECT * FROM default.v"), sql("SELECT * FROM default.src")) + } + } } } } - test("create hive view for partitioned parquet table") { - // partitioned parquet table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("parTable") { - val df = Seq(1 -> "a").toDF("i", "j") - df.write.format("parquet").partitionBy("i").saveAsTable("parTable") - sql("CREATE VIEW testView AS SELECT i, j FROM parTable") - checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) - sql("DROP VIEW testView") + test("Using view after adding more columns") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withTable("add_col") { + sqlContext.range(10).write.saveAsTable("add_col") + withView("v") { + sql("CREATE VIEW v AS SELECT * FROM add_col") + sqlContext.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + checkAnswer(sql("SELECT * FROM v"), sqlContext.range(10)) + } } } } From bae3c9a4eb0c320999e5dbafd62692c12823e07d Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Tue, 26 Jan 2016 21:14:39 -0800 Subject: [PATCH 1711/2089] [SPARK-12967][NETTY] Avoid NettyRpc error message during sparkContext shutdown If there's an RPC issue while sparkContext is alive but stopped (which would happen only when executing SparkContext.stop), log a warning instead. This is a common occurrence. vanzin Author: Nishkam Ravi Author: nishkamravi2 Closes #10881 from nishkamravi2/master_netty. --- .../spark/rpc/RpcEnvStoppedException.scala | 20 +++++++++++++++++++ .../apache/spark/rpc/netty/Dispatcher.scala | 4 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 6 +++++- .../org/apache/spark/rpc/netty/Outbox.scala | 7 +++++-- 4 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala new file mode 100644 index 0000000000000..c296cc23f12b7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +private[rpc] class RpcEnvStoppedException() + extends IllegalStateException("RpcEnv already stopped.") diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 19259e0e800c3..6ceff2c073998 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -106,7 +106,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e)) + postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}")) } } @@ -156,7 +156,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block val error = if (stopped) { - new IllegalStateException("RpcEnv already stopped.") + new RpcEnvStoppedException() } else { new SparkException(s"Could not find $endpointName or it has been stopped.") } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 9ae74d9d7b898..89eda857e6225 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -182,7 +182,11 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { // Message to a local RPC endpoint. - dispatcher.postOneWayMessage(message) + try { + dispatcher.postOneWayMessage(message) + } catch { + case e: RpcEnvStoppedException => logWarning(e.getMessage) + } } else { // Message to a remote RPC endpoint. postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 2316ebe347bb7..9fd64e8535752 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEnvStoppedException} private[netty] sealed trait OutboxMessage { @@ -43,7 +43,10 @@ private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends Outbo } override def onFailure(e: Throwable): Unit = { - logWarning(s"Failed to send one-way RPC.", e) + e match { + case e1: RpcEnvStoppedException => logWarning(e1.getMessage) + case e1: Throwable => logWarning(s"Failed to send one-way RPC.", e1) + } } } From 4db255c7aa756daa224d61905db745b6bccc9173 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 26 Jan 2016 21:16:56 -0800 Subject: [PATCH 1712/2089] [SPARK-12780] Inconsistency returning value of ML python models' properties https://issues.apache.org/jira/browse/SPARK-12780 Author: Xusen Yin Closes #10724 from yinxusen/SPARK-12780. --- python/pyspark/ml/feature.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 22081233b04d5..d017a231886cb 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1323,7 +1323,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] - >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels) >>> itd = inverter.transform(td) >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) @@ -1365,13 +1365,14 @@ class StringIndexerModel(JavaModel): .. versionadded:: 1.4.0 """ + @property @since("1.5.0") def labels(self): """ Ordered list of labels, corresponding to indices to be assigned. """ - return self._java_obj.labels + return self._call_java("labels") @inherit_doc From 90b0e562406a8bac529e190472e7f5da4030bf5c Mon Sep 17 00:00:00 2001 From: BenFradet Date: Wed, 27 Jan 2016 09:27:11 +0000 Subject: [PATCH 1713/2089] [SPARK-12983][CORE][DOC] Correct metrics.properties.template There are some typos or plain unintelligible sentences in the metrics template. Author: BenFradet Closes #10902 from BenFradet/SPARK-12983. --- conf/metrics.properties.template | 71 +++++++++++++++++--------------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index d6962e0da2f30..8a4f4e48335bd 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -57,39 +57,41 @@ # added to Java properties using -Dspark.metrics.conf=xxx if you want to # customize metrics system. You can also put the file in ${SPARK_HOME}/conf # and it will be loaded automatically. -# 5. MetricsServlet is added by default as a sink in master, worker and client -# driver, you can send http request "/metrics/json" to get a snapshot of all the -# registered metrics in json format. For master, requests "/metrics/master/json" and -# "/metrics/applications/json" can be sent seperately to get metrics snapshot of -# instance master and applications. MetricsServlet may not be configured by self. -# +# 5. The MetricsServlet sink is added by default as a sink in the master, +# worker and driver, and you can send HTTP requests to the "/metrics/json" +# endpoint to get a snapshot of all the registered metrics in JSON format. +# For master, requests to the "/metrics/master/json" and +# "/metrics/applications/json" endpoints can be sent separately to get +# metrics snapshots of the master instance and applications. This +# MetricsServlet does not have to be configured. ## List of available common sources and their properties. # org.apache.spark.metrics.source.JvmSource -# Note: Currently, JvmSource is the only available common source -# to add additionaly to an instance, to enable this, -# set the "class" option to its fully qulified class name (see examples below) +# Note: Currently, JvmSource is the only available common source. +# It can be added to an instance by setting the "class" option to its +# fully qualified class name (see examples below). ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink # Name: Default: Description: # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # org.apache.spark.metrics.sink.CSVSink # Name: Default: Description: # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # directory /tmp Where to store CSV files # org.apache.spark.metrics.sink.GangliaSink # Name: Default: Description: -# host NONE Hostname or multicast group of Ganglia server -# port NONE Port of Ganglia server(s) +# host NONE Hostname or multicast group of the Ganglia server, +# must be set +# port NONE Port of the Ganglia server(s), must be set # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # ttl 1 TTL of messages sent by Ganglia # mode multicast Ganglia network mode ('unicast' or 'multicast') @@ -98,19 +100,21 @@ # org.apache.spark.metrics.sink.MetricsServlet # Name: Default: Description: # path VARIES* Path prefix from the web server root -# sample false Whether to show entire set of samples for histograms ('false' or 'true') +# sample false Whether to show entire set of samples for histograms +# ('false' or 'true') # -# * Default path is /metrics/json for all instances except the master. The master has two paths: +# * Default path is /metrics/json for all instances except the master. The +# master has two paths: # /metrics/applications/json # App information # /metrics/master/json # Master information # org.apache.spark.metrics.sink.GraphiteSink # Name: Default: Description: -# host NONE Hostname of Graphite server -# port NONE Port of Graphite server +# host NONE Hostname of the Graphite server, must be set +# port NONE Port of the Graphite server, must be set # period 10 Poll period -# unit seconds Units of poll period -# prefix EMPTY STRING Prefix to prepend to metric name +# unit seconds Unit of the poll period +# prefix EMPTY STRING Prefix to prepend to every metric's name # protocol tcp Protocol ("tcp" or "udp") to use ## Examples @@ -120,42 +124,42 @@ # Enable ConsoleSink for all instances by class name #*.sink.console.class=org.apache.spark.metrics.sink.ConsoleSink -# Polling period for ConsoleSink +# Polling period for the ConsoleSink #*.sink.console.period=10 - +# Unit of the polling period for the ConsoleSink #*.sink.console.unit=seconds -# Master instance overlap polling period +# Polling period for the ConsoleSink specific for the master instance #master.sink.console.period=15 - +# Unit of the polling period for the ConsoleSink specific for the master +# instance #master.sink.console.unit=seconds -# Enable CsvSink for all instances +# Enable CsvSink for all instances by class name #*.sink.csv.class=org.apache.spark.metrics.sink.CsvSink -# Polling period for CsvSink +# Polling period for the CsvSink #*.sink.csv.period=1 - +# Unit of the polling period for the CsvSink #*.sink.csv.unit=minutes # Polling directory for CsvSink #*.sink.csv.directory=/tmp/ -# Worker instance overlap polling period +# Polling period for the CsvSink specific for the worker instance #worker.sink.csv.period=10 - +# Unit of the polling period for the CsvSink specific for the worker instance #worker.sink.csv.unit=minutes # Enable Slf4jSink for all instances by class name #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink -# Polling period for Slf4JSink +# Polling period for the Slf4JSink #*.sink.slf4j.period=1 - +# Unit of the polling period for the Slf4jSink #*.sink.slf4j.unit=minutes - -# Enable jvm source for instance master, worker, driver and executor +# Enable JvmSource for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource #worker.source.jvm.class=org.apache.spark.metrics.source.JvmSource @@ -163,4 +167,3 @@ #driver.source.jvm.class=org.apache.spark.metrics.source.JvmSource #executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource - From 093291cf9b8729c0bd057cf67aed840b11f8c94a Mon Sep 17 00:00:00 2001 From: Andrew Date: Wed, 27 Jan 2016 09:31:44 +0000 Subject: [PATCH 1714/2089] [SPARK-1680][DOCS] Explain environment variables for running on YARN in cluster mode JIRA 1680 added a property called spark.yarn.appMasterEnv. This PR draws users' attention to this special case by adding an explanation in configuration.html#environment-variables Author: Andrew Closes #10869 from weineran/branch-yarn-docs. --- docs/configuration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index d2a2f1052405d..74a8fb5d35a66 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1643,6 +1643,8 @@ to use on each machine and maximum memory. Since `spark-env.sh` is a shell script, some of these can be set programmatically -- for example, you might compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface. +Note: When running Spark on YARN in `cluster` mode, environment variables need to be set using the `spark.yarn.appMasterEnv.[EnvironmentVariableName]` property in your `conf/spark-defaults.conf` file. Environment variables that are set in `spark-env.sh` will not be reflected in the YARN Application Master process in `cluster` mode. See the [YARN-related Spark Properties](running-on-yarn.html#spark-properties) for more information. + # Configuring Logging Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a From 41f0c85f9be264103c066935e743f59caf0fe268 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 Jan 2016 08:32:13 -0800 Subject: [PATCH 1715/2089] [SPARK-13023][PROJECT INFRA] Fix handling of root module in modules_to_test() There's a minor bug in how we handle the `root` module in the `modules_to_test()` function in `dev/run-tests.py`: since `root` now depends on `build` (since every test needs to run on any build test), we now need to check for the presence of root in `modules_to_test` instead of `changed_modules`. Author: Josh Rosen Closes #10933 from JoshRosen/build-module-fix. --- dev/run-tests.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index c78a66f6aa54e..6febbf108900d 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -104,6 +104,8 @@ def determine_modules_to_test(changed_modules): >>> [x.name for x in determine_modules_to_test([modules.root])] ['root'] + >>> [x.name for x in determine_modules_to_test([modules.build])] + ['root'] >>> [x.name for x in determine_modules_to_test([modules.graphx])] ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] @@ -111,15 +113,13 @@ def determine_modules_to_test(changed_modules): ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ - # If we're going to have to run all of the tests, then we can just short-circuit - # and return 'root'. No module depends on root, so if it appears then it will be - # in changed_modules. - if modules.root in changed_modules: - return [modules.root] modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) modules_to_test = modules_to_test.union(set(changed_modules)) + # If we need to run all of the tests, then we should short-circuit and return 'root' + if modules.root in modules_to_test: + return [modules.root] return toposort_flatten( {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True) From edd473751b59b55fa3daede5ed7bc19ea8bd7170 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Wed, 27 Jan 2016 09:55:10 -0800 Subject: [PATCH 1716/2089] [SPARK-10847][SQL][PYSPARK] Pyspark - DataFrame - Optional Metadata with `None` triggers cryptic failure The error message is now changed from "Do not support type class scala.Tuple2." to "Do not support type class org.json4s.JsonAST$JNull$" to be more informative about what is not supported. Also, StructType metadata now handles JNull correctly, i.e., {'a': None}. test_metadata_null is added to tests.py to show the fix works. Author: Jason Lee Closes #8969 from jasoncl/SPARK-10847. --- python/pyspark/sql/tests.py | 7 +++++++ .../main/scala/org/apache/spark/sql/types/Metadata.scala | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7593b991a780b..410efbafe0792 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -747,6 +747,13 @@ def test_struct_type(self): except ValueError: self.assertEqual(1, 1) + def test_metadata_null(self): + from pyspark.sql.types import StructType, StringType, StructField + schema = StructType([StructField("f1", StringType(), True, None), + StructField("f2", StringType(), True, {'a': None})]) + rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) + self.sqlCtx.createDataFrame(rdd, schema) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 6ee24ee0c1913..9e0f9943bc638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -156,7 +156,9 @@ object Metadata { throw new RuntimeException(s"Do not support array of type ${other.getClass}.") } } - case other => + case (key, JNull) => + builder.putNull(key) + case (key, other) => throw new RuntimeException(s"Do not support type ${other.getClass}.") } builder.build() @@ -229,6 +231,9 @@ class MetadataBuilder { this } + /** Puts a null. */ + def putNull(key: String): this.type = put(key, null) + /** Puts a Long. */ def putLong(key: String, value: Long): this.type = put(key, value) From 87abcf7df921a5937fdb2bae8bfb30bfabc4970a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 27 Jan 2016 11:15:48 -0800 Subject: [PATCH 1717/2089] [SPARK-12895][SPARK-12896] Migrate TaskMetrics to accumulators The high level idea is that instead of having the executors send both accumulator updates and TaskMetrics, we should have them send only accumulator updates. This eliminates the need to maintain both code paths since one can be implemented in terms of the other. This effort is split into two parts: **SPARK-12895: Implement TaskMetrics using accumulators.** TaskMetrics is basically just a bunch of accumulable fields. This patch makes TaskMetrics a syntactic wrapper around a collection of accumulators so we don't need to send TaskMetrics from the executors to the driver. **SPARK-12896: Send only accumulator updates to the driver.** Now that TaskMetrics are expressed in terms of accumulators, we can capture all TaskMetrics values if we just send accumulator updates from the executors to the driver. This completes the parent issue SPARK-10620. While an effort has been made to preserve as much of the public API as possible, there were a few known breaking DeveloperApi changes that would be very awkward to maintain. I will gather the full list shortly and post it here. Note: This was once part of #10717. This patch is split out into its own patch from there to make it easier for others to review. Other smaller pieces of already been merged into master. Author: Andrew Or Closes #10835 from andrewor14/task-metrics-use-accums. --- .../shuffle/sort/UnsafeShuffleWriter.java | 8 +- .../scala/org/apache/spark/Accumulable.scala | 86 ++- .../scala/org/apache/spark/Accumulator.scala | 101 +++- .../scala/org/apache/spark/Aggregator.scala | 3 +- .../org/apache/spark/HeartbeatReceiver.scala | 6 +- .../apache/spark/InternalAccumulator.scala | 199 ++++++- .../scala/org/apache/spark/TaskContext.scala | 19 +- .../org/apache/spark/TaskContextImpl.scala | 30 +- .../org/apache/spark/TaskEndReason.scala | 29 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 8 + .../org/apache/spark/executor/Executor.scala | 45 +- .../apache/spark/executor/InputMetrics.scala | 81 ++- .../apache/spark/executor/OutputMetrics.scala | 71 ++- .../spark/executor/ShuffleReadMetrics.scala | 104 ++-- .../spark/executor/ShuffleWriteMetrics.scala | 62 +- .../apache/spark/executor/TaskMetrics.scala | 370 +++++++----- .../org/apache/spark/rdd/CoGroupedRDD.scala | 3 +- .../org/apache/spark/rdd/HadoopRDD.scala | 24 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 22 +- .../spark/scheduler/AccumulableInfo.scala | 55 +- .../apache/spark/scheduler/DAGScheduler.scala | 102 ++-- .../spark/scheduler/DAGSchedulerEvent.scala | 5 +- .../apache/spark/scheduler/ResultTask.scala | 8 +- .../spark/scheduler/ShuffleMapTask.scala | 8 +- .../spark/scheduler/SparkListener.scala | 7 +- .../org/apache/spark/scheduler/Stage.scala | 4 +- .../org/apache/spark/scheduler/Task.scala | 41 +- .../apache/spark/scheduler/TaskResult.scala | 28 +- .../spark/scheduler/TaskResultGetter.scala | 27 +- .../spark/scheduler/TaskScheduler.scala | 6 +- .../spark/scheduler/TaskSchedulerImpl.scala | 13 +- .../spark/scheduler/TaskSetManager.scala | 14 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../status/api/v1/AllStagesResource.scala | 3 +- .../spark/ui/jobs/JobProgressListener.scala | 18 +- .../org/apache/spark/ui/jobs/StagePage.scala | 36 +- .../org/apache/spark/util/JsonProtocol.scala | 124 +++- .../util/collection/ExternalSorter.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 2 - .../org/apache/spark/AccumulatorSuite.scala | 324 ++++++----- .../ExecutorAllocationManagerSuite.scala | 2 +- .../apache/spark/HeartbeatReceiverSuite.scala | 6 +- .../spark/InternalAccumulatorSuite.scala | 331 +++++++++++ .../org/apache/spark/SparkFunSuite.scala | 2 + .../spark/executor/TaskMetricsSuite.scala | 540 +++++++++++++++++- .../spark/memory/MemoryTestingUtils.scala | 3 +- .../spark/scheduler/DAGSchedulerSuite.scala | 281 ++++----- .../spark/scheduler/ReplayListenerSuite.scala | 8 +- .../spark/scheduler/TaskContextSuite.scala | 56 +- .../scheduler/TaskResultGetterSuite.scala | 67 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 36 +- .../org/apache/spark/ui/StagePageSuite.scala | 11 +- .../ui/jobs/JobProgressListenerSuite.scala | 26 +- .../apache/spark/util/JsonProtocolSuite.scala | 515 ++++++++++++----- project/MimaExcludes.scala | 9 + .../org/apache/spark/sql/execution/Sort.scala | 8 +- .../TungstenAggregationIterator.scala | 6 +- .../datasources/SqlNewHadoopRDD.scala | 22 +- .../execution/joins/BroadcastHashJoin.scala | 3 +- .../joins/BroadcastHashOuterJoin.scala | 3 +- .../joins/BroadcastLeftSemiJoinHash.scala | 3 +- .../sql/execution/metric/SQLMetrics.scala | 6 +- .../spark/sql/execution/ui/SQLListener.scala | 22 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/execution/ReferenceSort.scala | 3 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 3 +- .../UnsafeKVExternalSorterSuite.scala | 3 +- .../columnar/PartitionBatchPruningSuite.scala | 38 +- .../sql/execution/ui/SQLListenerSuite.scala | 34 +- .../sql/util/DataFrameCallbackSuite.scala | 2 +- 70 files changed, 3012 insertions(+), 1141 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index d3d79a27ea1c6..128a82579b800 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -444,13 +444,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { - // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) - Map> internalAccumulators = - taskContext.internalMetricsToAccumulators(); - if (internalAccumulators != null) { - internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) - .add(getPeakMemoryUsedBytes()); - } + taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); if (stopping) { return Option.apply(null); diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index a456d420b8d6a..bde136141f40d 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -35,40 +35,67 @@ import org.apache.spark.util.Utils * [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are * accumulating a set. You will add items to the set, and you will union two sets together. * + * All accumulators created on the driver to be used on the executors must be registered with + * [[Accumulators]]. This is already done automatically for accumulators created by the user. + * Internal accumulators must be explicitly registered by the caller. + * + * Operations are not thread-safe. + * + * @param id ID of this accumulator; for internal use only. * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` * @param name human-readable name for use in Spark's web UI * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be * thread safe so that they can be reported correctly. + * @param countFailedValues whether to accumulate values from failed tasks. This is set to true + * for system and time metrics like serialization time or bytes spilled, + * and false for things with absolute values like number of input rows. + * This should be used for internal metrics only. * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] private[spark] ( - initialValue: R, +class Accumulable[R, T] private ( + val id: Long, + @transient initialValue: R, param: AccumulableParam[R, T], val name: Option[String], - internal: Boolean) + internal: Boolean, + private[spark] val countFailedValues: Boolean) extends Serializable { private[spark] def this( - @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { - this(initialValue, param, None, internal) + initialValue: R, + param: AccumulableParam[R, T], + name: Option[String], + internal: Boolean, + countFailedValues: Boolean) = { + this(Accumulators.newId(), initialValue, param, name, internal, countFailedValues) } - def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = - this(initialValue, param, name, false) + private[spark] def this( + initialValue: R, + param: AccumulableParam[R, T], + name: Option[String], + internal: Boolean) = { + this(initialValue, param, name, internal, false /* countFailedValues */) + } - def this(@transient initialValue: R, param: AccumulableParam[R, T]) = - this(initialValue, param, None) + def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false /* internal */) - val id: Long = Accumulators.newId + def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) - @volatile @transient private var value_ : R = initialValue // Current value on master - val zero = param.zero(initialValue) // Zero value to be passed to workers + @volatile @transient private var value_ : R = initialValue // Current value on driver + val zero = param.zero(initialValue) // Zero value to be passed to executors private var deserialized = false - Accumulators.register(this) + // In many places we create internal accumulators without access to the active context cleaner, + // so if we register them here then we may never unregister these accumulators. To avoid memory + // leaks, we require the caller to explicitly register internal accumulators elsewhere. + if (!internal) { + Accumulators.register(this) + } /** * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver @@ -77,6 +104,17 @@ class Accumulable[R, T] private[spark] ( */ private[spark] def isInternal: Boolean = internal + /** + * Return a copy of this [[Accumulable]]. + * + * The copy will have the same ID as the original and will not be registered with + * [[Accumulators]] again. This method exists so that the caller can avoid passing the + * same mutable instance around. + */ + private[spark] def copy(): Accumulable[R, T] = { + new Accumulable[R, T](id, initialValue, param, name, internal, countFailedValues) + } + /** * Add more data to this accumulator / accumulable * @param term the data to add @@ -106,7 +144,7 @@ class Accumulable[R, T] private[spark] ( def merge(term: R) { value_ = param.addInPlace(value_, term)} /** - * Access the accumulator's current value; only allowed on master. + * Access the accumulator's current value; only allowed on driver. */ def value: R = { if (!deserialized) { @@ -128,7 +166,7 @@ class Accumulable[R, T] private[spark] ( def localValue: R = value_ /** - * Set the accumulator's value; only allowed on master. + * Set the accumulator's value; only allowed on driver. */ def value_= (newValue: R) { if (!deserialized) { @@ -139,22 +177,24 @@ class Accumulable[R, T] private[spark] ( } /** - * Set the accumulator's value; only allowed on master + * Set the accumulator's value. For internal use only. */ - def setValue(newValue: R) { - this.value = newValue - } + def setValue(newValue: R): Unit = { value_ = newValue } + + /** + * Set the accumulator's value. For internal use only. + */ + private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) } // Called by Java when deserializing an object private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() value_ = zero deserialized = true + // Automatically register the accumulator when it is deserialized with the task closure. - // - // Note internal accumulators sent with task are deserialized before the TaskContext is created - // and are registered in the TaskContext constructor. Other internal accumulators, such SQL - // metrics, still need to register here. + // This is for external accumulators and internal ones that do not represent task level + // metrics, e.g. internal SQL metrics, which are per-operator. val taskContext = TaskContext.get() if (taskContext != null) { taskContext.registerAccumulator(this) diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 007136e6ae349..558bd447e22c5 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -17,9 +17,14 @@ package org.apache.spark -import scala.collection.{mutable, Map} +import java.util.concurrent.atomic.AtomicLong +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable import scala.ref.WeakReference +import org.apache.spark.storage.{BlockId, BlockStatus} + /** * A simpler value of [[Accumulable]] where the result type being accumulated is the same @@ -49,14 +54,18 @@ import scala.ref.WeakReference * * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `T` + * @param name human-readable name associated with this accumulator + * @param internal whether this accumulator is used internally within Spark only + * @param countFailedValues whether to accumulate values from failed tasks * @tparam T result type */ class Accumulator[T] private[spark] ( @transient private[spark] val initialValue: T, param: AccumulatorParam[T], name: Option[String], - internal: Boolean) - extends Accumulable[T, T](initialValue, param, name, internal) { + internal: Boolean, + override val countFailedValues: Boolean = false) + extends Accumulable[T, T](initialValue, param, name, internal, countFailedValues) { def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { this(initialValue, param, name, false) @@ -75,43 +84,63 @@ private[spark] object Accumulators extends Logging { * This global map holds the original accumulator objects that are created on the driver. * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. + * TODO: Don't use a global map; these should be tied to a SparkContext at the very least. */ + @GuardedBy("Accumulators") val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() - private var lastId: Long = 0 + private val nextId = new AtomicLong(0L) - def newId(): Long = synchronized { - lastId += 1 - lastId - } + /** + * Return a globally unique ID for a new [[Accumulable]]. + * Note: Once you copy the [[Accumulable]] the ID is no longer unique. + */ + def newId(): Long = nextId.getAndIncrement + /** + * Register an [[Accumulable]] created on the driver such that it can be used on the executors. + * + * All accumulators registered here can later be used as a container for accumulating partial + * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does. + * Note: if an accumulator is registered here, it should also be registered with the active + * context cleaner for cleanup so as to avoid memory leaks. + * + * If an [[Accumulable]] with the same ID was already registered, this does nothing instead + * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct + * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates. + */ def register(a: Accumulable[_, _]): Unit = synchronized { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) + if (!originals.contains(a.id)) { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) + } } - def remove(accId: Long) { - synchronized { - originals.remove(accId) - } + /** + * Unregister the [[Accumulable]] with the given ID, if any. + */ + def remove(accId: Long): Unit = synchronized { + originals.remove(accId) } - // Add values to the original accumulators with some given IDs - def add(values: Map[Long, Any]): Unit = synchronized { - for ((id, value) <- values) { - if (originals.contains(id)) { - // Since we are now storing weak references, we must check whether the underlying data - // is valid. - originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value - case None => - throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") - } - } else { - logWarning(s"Ignoring accumulator update for unknown accumulator id $id") + /** + * Return the [[Accumulable]] registered with the given ID, if any. + */ + def get(id: Long): Option[Accumulable[_, _]] = synchronized { + originals.get(id).map { weakRef => + // Since we are storing weak references, we must check whether the underlying data is valid. + weakRef.get.getOrElse { + throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id") } } } + /** + * Clear all registered [[Accumulable]]s. For testing only. + */ + def clear(): Unit = synchronized { + originals.clear() + } + } @@ -156,5 +185,23 @@ object AccumulatorParam { def zero(initialValue: Float): Float = 0f } - // TODO: Add AccumulatorParams for other types, e.g. lists and strings + // Note: when merging values, this param just adopts the newer value. This is used only + // internally for things that shouldn't really be accumulated across tasks, like input + // read method, which should be the same across all tasks in the same stage. + private[spark] object StringAccumulatorParam extends AccumulatorParam[String] { + def addInPlace(t1: String, t2: String): String = t2 + def zero(initialValue: String): String = "" + } + + // Note: this is expensive as it makes a copy of the list every time the caller adds an item. + // A better way to use this is to first accumulate the values yourself then them all at once. + private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] { + def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2 + def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T] + } + + // For the internal metric that records what blocks are updated in a particular task + private[spark] object UpdatedBlockStatusesAccumulatorParam + extends ListAccumulatorParam[(BlockId, BlockStatus)] + } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 62629000cfc23..e493d9a3cf9cc 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -57,8 +57,7 @@ case class Aggregator[K, V, C] ( Option(context).foreach { c => c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - c.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + c.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes) } } } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index e03977828b86d..45b20c0e8d60c 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} */ private[spark] case class Heartbeat( executorId: String, - taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates blockManagerId: BlockManagerId) /** @@ -119,14 +119,14 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) context.reply(true) // Messages received from executors - case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId) => if (scheduler != null) { if (executorLastSeen.contains(executorId)) { executorLastSeen(executorId) = clock.getTimeMillis() eventLoopThread.submit(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) + executorId, accumUpdates, blockManagerId) val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) context.reply(response) } diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 6ea997c079f33..c191122c0630a 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -17,23 +17,169 @@ package org.apache.spark +import org.apache.spark.storage.{BlockId, BlockStatus} -// This is moved to its own file because many more things will be added to it in SPARK-10620. + +/** + * A collection of fields and methods concerned with internal accumulators that represent + * task level metrics. + */ private[spark] object InternalAccumulator { - val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" - val TEST_ACCUMULATOR = "testAccumulator" - - // For testing only. - // This needs to be a def since we don't want to reuse the same accumulator across stages. - private def maybeTestAccumulator: Option[Accumulator[Long]] = { - if (sys.props.contains("spark.testing")) { - Some(new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) - } else { - None + + import AccumulatorParam._ + + // Prefixes used in names of internal task level metrics + val METRICS_PREFIX = "internal.metrics." + val SHUFFLE_READ_METRICS_PREFIX = METRICS_PREFIX + "shuffle.read." + val SHUFFLE_WRITE_METRICS_PREFIX = METRICS_PREFIX + "shuffle.write." + val OUTPUT_METRICS_PREFIX = METRICS_PREFIX + "output." + val INPUT_METRICS_PREFIX = METRICS_PREFIX + "input." + + // Names of internal task level metrics + val EXECUTOR_DESERIALIZE_TIME = METRICS_PREFIX + "executorDeserializeTime" + val EXECUTOR_RUN_TIME = METRICS_PREFIX + "executorRunTime" + val RESULT_SIZE = METRICS_PREFIX + "resultSize" + val JVM_GC_TIME = METRICS_PREFIX + "jvmGCTime" + val RESULT_SERIALIZATION_TIME = METRICS_PREFIX + "resultSerializationTime" + val MEMORY_BYTES_SPILLED = METRICS_PREFIX + "memoryBytesSpilled" + val DISK_BYTES_SPILLED = METRICS_PREFIX + "diskBytesSpilled" + val PEAK_EXECUTION_MEMORY = METRICS_PREFIX + "peakExecutionMemory" + val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses" + val TEST_ACCUM = METRICS_PREFIX + "testAccumulator" + + // scalastyle:off + + // Names of shuffle read metrics + object shuffleRead { + val REMOTE_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "remoteBlocksFetched" + val LOCAL_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "localBlocksFetched" + val REMOTE_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesRead" + val LOCAL_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "localBytesRead" + val FETCH_WAIT_TIME = SHUFFLE_READ_METRICS_PREFIX + "fetchWaitTime" + val RECORDS_READ = SHUFFLE_READ_METRICS_PREFIX + "recordsRead" + } + + // Names of shuffle write metrics + object shuffleWrite { + val BYTES_WRITTEN = SHUFFLE_WRITE_METRICS_PREFIX + "bytesWritten" + val RECORDS_WRITTEN = SHUFFLE_WRITE_METRICS_PREFIX + "recordsWritten" + val WRITE_TIME = SHUFFLE_WRITE_METRICS_PREFIX + "writeTime" + } + + // Names of output metrics + object output { + val WRITE_METHOD = OUTPUT_METRICS_PREFIX + "writeMethod" + val BYTES_WRITTEN = OUTPUT_METRICS_PREFIX + "bytesWritten" + val RECORDS_WRITTEN = OUTPUT_METRICS_PREFIX + "recordsWritten" + } + + // Names of input metrics + object input { + val READ_METHOD = INPUT_METRICS_PREFIX + "readMethod" + val BYTES_READ = INPUT_METRICS_PREFIX + "bytesRead" + val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead" + } + + // scalastyle:on + + /** + * Create an internal [[Accumulator]] by name, which must begin with [[METRICS_PREFIX]]. + */ + def create(name: String): Accumulator[_] = { + require(name.startsWith(METRICS_PREFIX), + s"internal accumulator name must start with '$METRICS_PREFIX': $name") + getParam(name) match { + case p @ LongAccumulatorParam => newMetric[Long](0L, name, p) + case p @ IntAccumulatorParam => newMetric[Int](0, name, p) + case p @ StringAccumulatorParam => newMetric[String]("", name, p) + case p @ UpdatedBlockStatusesAccumulatorParam => + newMetric[Seq[(BlockId, BlockStatus)]](Seq(), name, p) + case p => throw new IllegalArgumentException( + s"unsupported accumulator param '${p.getClass.getSimpleName}' for metric '$name'.") + } + } + + /** + * Get the [[AccumulatorParam]] associated with the internal metric name, + * which must begin with [[METRICS_PREFIX]]. + */ + def getParam(name: String): AccumulatorParam[_] = { + require(name.startsWith(METRICS_PREFIX), + s"internal accumulator name must start with '$METRICS_PREFIX': $name") + name match { + case UPDATED_BLOCK_STATUSES => UpdatedBlockStatusesAccumulatorParam + case shuffleRead.LOCAL_BLOCKS_FETCHED => IntAccumulatorParam + case shuffleRead.REMOTE_BLOCKS_FETCHED => IntAccumulatorParam + case input.READ_METHOD => StringAccumulatorParam + case output.WRITE_METHOD => StringAccumulatorParam + case _ => LongAccumulatorParam } } + /** + * Accumulators for tracking internal metrics. + */ + def create(): Seq[Accumulator[_]] = { + Seq[String]( + EXECUTOR_DESERIALIZE_TIME, + EXECUTOR_RUN_TIME, + RESULT_SIZE, + JVM_GC_TIME, + RESULT_SERIALIZATION_TIME, + MEMORY_BYTES_SPILLED, + DISK_BYTES_SPILLED, + PEAK_EXECUTION_MEMORY, + UPDATED_BLOCK_STATUSES).map(create) ++ + createShuffleReadAccums() ++ + createShuffleWriteAccums() ++ + createInputAccums() ++ + createOutputAccums() ++ + sys.props.get("spark.testing").map(_ => create(TEST_ACCUM)).toSeq + } + + /** + * Accumulators for tracking shuffle read metrics. + */ + def createShuffleReadAccums(): Seq[Accumulator[_]] = { + Seq[String]( + shuffleRead.REMOTE_BLOCKS_FETCHED, + shuffleRead.LOCAL_BLOCKS_FETCHED, + shuffleRead.REMOTE_BYTES_READ, + shuffleRead.LOCAL_BYTES_READ, + shuffleRead.FETCH_WAIT_TIME, + shuffleRead.RECORDS_READ).map(create) + } + + /** + * Accumulators for tracking shuffle write metrics. + */ + def createShuffleWriteAccums(): Seq[Accumulator[_]] = { + Seq[String]( + shuffleWrite.BYTES_WRITTEN, + shuffleWrite.RECORDS_WRITTEN, + shuffleWrite.WRITE_TIME).map(create) + } + + /** + * Accumulators for tracking input metrics. + */ + def createInputAccums(): Seq[Accumulator[_]] = { + Seq[String]( + input.READ_METHOD, + input.BYTES_READ, + input.RECORDS_READ).map(create) + } + + /** + * Accumulators for tracking output metrics. + */ + def createOutputAccums(): Seq[Accumulator[_]] = { + Seq[String]( + output.WRITE_METHOD, + output.BYTES_WRITTEN, + output.RECORDS_WRITTEN).map(create) + } + /** * Accumulators for tracking internal metrics. * @@ -41,18 +187,23 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. */ - def create(sc: SparkContext): Seq[Accumulator[Long]] = { - val internalAccumulators = Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq - internalAccumulators.foreach { accumulator => - sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + def create(sc: SparkContext): Seq[Accumulator[_]] = { + val accums = create() + accums.foreach { accum => + Accumulators.register(accum) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum)) } - internalAccumulators + accums + } + + /** + * Create a new accumulator representing an internal task metric. + */ + private def newMetric[T]( + initialValue: T, + name: String, + param: AccumulatorParam[T]): Accumulator[T] = { + new Accumulator[T](initialValue, param, Some(name), internal = true, countFailedValues = true) } + } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7704abc134096..9f49cf1c4c9bd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -64,7 +64,7 @@ object TaskContext { * An empty task context that does not represent an actual task. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) + new TaskContextImpl(0, 0, 0, 0, null, null) } } @@ -138,7 +138,6 @@ abstract class TaskContext extends Serializable { */ def taskAttemptId(): Long - /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics @@ -161,20 +160,4 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit - /** - * Return the local values of internal accumulators that belong to this task. The key of the Map - * is the accumulator id and the value of the Map is the latest accumulator local value. - */ - private[spark] def collectInternalAccumulators(): Map[Long, Any] - - /** - * Return the local values of accumulators that belong to this task. The key of the Map is the - * accumulator id and the value of the Map is the latest accumulator local value. - */ - private[spark] def collectAccumulators(): Map[Long, Any] - - /** - * Accumulators for tracking internal metrics indexed by the name. - */ - private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 94ff884b742b8..27ca46f73d8ca 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,7 +17,7 @@ package org.apache.spark -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager @@ -32,11 +32,15 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, - internalAccumulators: Seq[Accumulator[Long]], - val taskMetrics: TaskMetrics = TaskMetrics.empty) + initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.create()) extends TaskContext with Logging { + /** + * Metrics associated with this task. + */ + override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators) + // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -91,24 +95,8 @@ private[spark] class TaskContextImpl( override def getMetricsSources(sourceName: String): Seq[Source] = metricsSystem.getSourcesByName(sourceName) - @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] - - private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { - accumulators(a.id) = a - } - - private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { - accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = { + taskMetrics.registerAccumulator(a) } - private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { - accumulators.mapValues(_.localValue).toMap - } - - private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { - // Explicitly register internal accumulators here because these are - // not captured in the task closure and are already deserialized - internalAccumulators.foreach(registerAccumulator) - internalAccumulators.map { a => (a.name.get, a) }.toMap - } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 13241b77bf97b..68340cc704dae 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -19,8 +19,11 @@ package org.apache.spark import java.io.{ObjectInputStream, ObjectOutputStream} +import scala.util.Try + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -115,22 +118,34 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics], - private val exceptionWrapper: Option[ThrowableSerializationWrapper]) + exceptionWrapper: Option[ThrowableSerializationWrapper], + accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]) extends TaskFailedReason { + @deprecated("use accumUpdates instead", "2.0.0") + val metrics: Option[TaskMetrics] = { + if (accumUpdates.nonEmpty) { + Try(TaskMetrics.fromAccumulatorUpdates(accumUpdates)).toOption + } else { + None + } + } + /** * `preserveCause` is used to keep the exception itself so it is available to the * driver. This may be set to `false` in the event that the exception is not in fact * serializable. */ - private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, - if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + private[spark] def this( + e: Throwable, + accumUpdates: Seq[AccumulableInfo], + preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None, accumUpdates) } - private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e, metrics, preserveCause = true) + private[spark] def this(e: Throwable, accumUpdates: Seq[AccumulableInfo]) { + this(e, accumUpdates, preserveCause = true) } def exception: Option[Throwable] = exceptionWrapper.flatMap { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 8ba3f5e241899..06b5101b1f566 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -370,6 +370,14 @@ object SparkHadoopUtil { val SPARK_YARN_CREDS_COUNTER_DELIM = "-" + /** + * Number of records to update input metrics when reading from HadoopRDDs. + * + * Each update is potentially expensive because we need to use reflection to access the + * Hadoop FileSystem API of interest (only available in 2.5), so we should do this sparingly. + */ + private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + def get: SparkHadoopUtil = { // Check each time to support changing to/from YARN val yarnMode = java.lang.Boolean.valueOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 030ae41db4a62..51c000ea5c574 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -31,7 +31,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} +import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -210,7 +210,7 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() var threwException = true - val (value, accumUpdates) = try { + val value = try { val res = task.run( taskAttemptId = taskId, attemptNumber = attemptNumber, @@ -249,10 +249,11 @@ private[spark] class Executor( m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) - m.updateAccumulators() } - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) + // Note: accumulator updates must be collected after TaskMetrics is updated + val accumUpdates = task.collectAccumulatorUpdates() + val directResult = new DirectTaskResult(valueBytes, accumUpdates) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -297,21 +298,25 @@ private[spark] class Executor( // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - val metrics: Option[TaskMetrics] = Option(task).flatMap { task => - task.metrics.map { m => + // Collect latest accumulator values to report back to the driver + val accumulatorUpdates: Seq[AccumulableInfo] = + if (task != null) { + task.metrics.foreach { m => m.setExecutorRunTime(System.currentTimeMillis() - taskStart) m.setJvmGCTime(computeTotalGcTime() - startGCTime) - m.updateAccumulators() - m } + task.collectAccumulatorUpdates(taskFailed = true) + } else { + Seq.empty[AccumulableInfo] } + val serializedTaskEndReason = { try { - ser.serialize(new ExceptionFailure(t, metrics)) + ser.serialize(new ExceptionFailure(t, accumulatorUpdates)) } catch { case _: NotSerializableException => // t is not serializable so just send the stacktrace - ser.serialize(new ExceptionFailure(t, metrics, false)) + ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false)) } } execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) @@ -418,33 +423,21 @@ private[spark] class Executor( /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { - // list of (task id, metrics) to send back to the driver - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + // list of (task id, accumUpdates) to send back to the driver + val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]() val curGCTime = computeTotalGcTime() for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => metrics.mergeShuffleReadMetrics() - metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - metrics.updateAccumulators() - - if (isLocal) { - // JobProgressListener will hold an reference of it during - // onExecutorMetricsUpdate(), then JobProgressListener can not see - // the changes of metrics any more, so make a deep copy of it - val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) - tasksMetrics += ((taskRunner.taskId, copiedMetrics)) - } else { - // It will be copied by serialization - tasksMetrics += ((taskRunner.taskId, metrics)) - } + accumUpdates += ((taskRunner.taskId, metrics.accumulatorUpdates())) } } } - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) try { val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 8f1d7f89a44b4..ed9e157ce758b 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -17,13 +17,15 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Method by which input data was read. Network means that the data was read over the network + * Method by which input data was read. Network means that the data was read over the network * from a remote block manager (which may have stored the data on-disk or in-memory). + * Operations are not thread-safe. */ @DeveloperApi object DataReadMethod extends Enumeration with Serializable { @@ -34,44 +36,75 @@ object DataReadMethod extends Enumeration with Serializable { /** * :: DeveloperApi :: - * Metrics about reading input data. + * A collection of accumulators that represents metrics about reading data from external systems. */ @DeveloperApi -case class InputMetrics(readMethod: DataReadMethod.Value) { +class InputMetrics private ( + _bytesRead: Accumulator[Long], + _recordsRead: Accumulator[Long], + _readMethod: Accumulator[String]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.RECORDS_READ), + TaskMetrics.getAccum[String](accumMap, InternalAccumulator.input.READ_METHOD)) + } /** - * This is volatile so that it is visible to the updater thread. + * Create a new [[InputMetrics]] that is not associated with any particular task. + * + * This mainly exists because of SPARK-5225, where we are forced to use a dummy [[InputMetrics]] + * because we want to ignore metrics from a second read method. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerInputMetrics]]. */ - @volatile @transient var bytesReadCallback: Option[() => Long] = None + private[executor] def this() { + this(InternalAccumulator.createInputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) + } /** - * Total bytes read. + * Total number of bytes read. */ - private var _bytesRead: Long = _ - def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long): Unit = _bytesRead += bytes + def bytesRead: Long = _bytesRead.localValue /** - * Total records read. + * Total number of records read. */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def recordsRead: Long = _recordsRead.localValue /** - * Invoke the bytesReadCallback and mutate bytesRead. + * The source from which this task reads its input. */ - def updateBytesRead() { - bytesReadCallback.foreach { c => - _bytesRead = c() - } + def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) + + @deprecated("incrementing input metrics is for internal use only", "2.0.0") + def incBytesRead(v: Long): Unit = _bytesRead.add(v) + @deprecated("incrementing input metrics is for internal use only", "2.0.0") + def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) + private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = + _readMethod.setValue(v.toString) + +} + +/** + * Deprecated methods to preserve case class matching behavior before Spark 2.0. + */ +object InputMetrics { + + @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") + def apply(readMethod: DataReadMethod.Value): InputMetrics = { + val im = new InputMetrics + im.setReadMethod(readMethod) + im } - /** - * Register a function that can be called to get up-to-date information on how many bytes the task - * has read from an input source. - */ - def setBytesReadCallback(f: Option[() => Long]) { - bytesReadCallback = f + @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") + def unapply(input: InputMetrics): Option[DataReadMethod.Value] = { + Some(input.readMethod) } } diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index ad132d004cde0..0b37d559c7462 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -17,12 +17,14 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * Method by which output data was written. + * Operations are not thread-safe. */ @DeveloperApi object DataWriteMethod extends Enumeration with Serializable { @@ -33,21 +35,70 @@ object DataWriteMethod extends Enumeration with Serializable { /** * :: DeveloperApi :: - * Metrics about writing output data. + * A collection of accumulators that represents metrics about writing data to external systems. */ @DeveloperApi -case class OutputMetrics(writeMethod: DataWriteMethod.Value) { +class OutputMetrics private ( + _bytesWritten: Accumulator[Long], + _recordsWritten: Accumulator[Long], + _writeMethod: Accumulator[String]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.BYTES_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.RECORDS_WRITTEN), + TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD)) + } + + /** + * Create a new [[OutputMetrics]] that is not associated with any particular task. + * + * This is only used for preserving matching behavior on [[OutputMetrics]], which used to be + * a case class before Spark 2.0. Once we remove support for matching on [[OutputMetrics]] + * we can remove this constructor as well. + */ + private[executor] def this() { + this(InternalAccumulator.createOutputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) + } + + /** + * Total number of bytes written. + */ + def bytesWritten: Long = _bytesWritten.localValue + /** - * Total bytes written + * Total number of records written. */ - private var _bytesWritten: Long = _ - def bytesWritten: Long = _bytesWritten - private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value + def recordsWritten: Long = _recordsWritten.localValue /** - * Total records written + * The source to which this task writes its output. */ - private var _recordsWritten: Long = 0L - def recordsWritten: Long = _recordsWritten - private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value + def writeMethod: DataWriteMethod.Value = DataWriteMethod.withName(_writeMethod.localValue) + + private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v) + private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v) + private[spark] def setWriteMethod(v: DataWriteMethod.Value): Unit = + _writeMethod.setValue(v.toString) + +} + +/** + * Deprecated methods to preserve case class matching behavior before Spark 2.0. + */ +object OutputMetrics { + + @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") + def apply(writeMethod: DataWriteMethod.Value): OutputMetrics = { + val om = new OutputMetrics + om.setWriteMethod(writeMethod) + om + } + + @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") + def unapply(output: OutputMetrics): Option[DataWriteMethod.Value] = { + Some(output.writeMethod) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index e985b35ace623..50bb645d974a3 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -17,71 +17,103 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Metrics pertaining to shuffle data read in a given task. + * A collection of accumulators that represent metrics about reading shuffle data. + * Operations are not thread-safe. */ @DeveloperApi -class ShuffleReadMetrics extends Serializable { +class ShuffleReadMetrics private ( + _remoteBlocksFetched: Accumulator[Int], + _localBlocksFetched: Accumulator[Int], + _remoteBytesRead: Accumulator[Long], + _localBytesRead: Accumulator[Long], + _fetchWaitTime: Accumulator[Long], + _recordsRead: Accumulator[Long]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.REMOTE_BLOCKS_FETCHED), + TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.LOCAL_BLOCKS_FETCHED), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.REMOTE_BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.LOCAL_BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.FETCH_WAIT_TIME), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.RECORDS_READ)) + } + /** - * Number of remote blocks fetched in this shuffle by this task + * Create a new [[ShuffleReadMetrics]] that is not associated with any particular task. + * + * This mainly exists for legacy reasons, because we use dummy [[ShuffleReadMetrics]] in + * many places only to merge their values together later. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerTempShuffleReadMetrics]] followed by + * [[TaskMetrics.mergeShuffleReadMetrics]]. */ - private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched: Int = _remoteBlocksFetched - private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + private[spark] def this() { + this(InternalAccumulator.createShuffleReadAccums().map { a => (a.name.get, a) }.toMap) + } /** - * Number of local blocks fetched in this shuffle by this task + * Number of remote blocks fetched in this shuffle by this task. */ - private var _localBlocksFetched: Int = _ - def localBlocksFetched: Int = _localBlocksFetched - private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value + def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue /** - * Time the task spent waiting for remote shuffle blocks. This only includes the time - * blocking on shuffle input data. For instance if block B is being fetched while the task is - * still not finished processing block A, it is not considered to be blocking on block B. + * Number of local blocks fetched in this shuffle by this task. */ - private var _fetchWaitTime: Long = _ - def fetchWaitTime: Long = _fetchWaitTime - private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value - private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value + def localBlocksFetched: Int = _localBlocksFetched.localValue /** - * Total number of remote bytes read from the shuffle by this task + * Total number of remote bytes read from the shuffle by this task. */ - private var _remoteBytesRead: Long = _ - def remoteBytesRead: Long = _remoteBytesRead - private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value - private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + def remoteBytesRead: Long = _remoteBytesRead.localValue /** * Shuffle data that was read from the local disk (as opposed to from a remote executor). */ - private var _localBytesRead: Long = _ - def localBytesRead: Long = _localBytesRead - private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + def localBytesRead: Long = _localBytesRead.localValue /** - * Total bytes fetched in the shuffle by this task (both remote and local). + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. + */ + def fetchWaitTime: Long = _fetchWaitTime.localValue + + /** + * Total number of records read from the shuffle by this task. */ - def totalBytesRead: Long = _remoteBytesRead + _localBytesRead + def recordsRead: Long = _recordsRead.localValue /** - * Number of blocks fetched in this shuffle by this task (remote or local) + * Total bytes fetched in the shuffle by this task (both remote and local). */ - def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched + def totalBytesRead: Long = remoteBytesRead + localBytesRead /** - * Total number of records read from the shuffle by this task + * Number of blocks fetched in this shuffle by this task (remote or local). */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - private[spark] def incRecordsRead(value: Long) = _recordsRead += value - private[spark] def decRecordsRead(value: Long) = _recordsRead -= value + def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched + + private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v) + private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v) + private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v) + private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v) + private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v) + private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + + private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v) + private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v) + private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v) + private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v) + private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) + private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) + } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 24795f860087f..c7aaabb561bba 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -17,40 +17,66 @@ package org.apache.spark.executor +import org.apache.spark.{Accumulator, InternalAccumulator} import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Metrics pertaining to shuffle data written in a given task. + * A collection of accumulators that represent metrics about writing shuffle data. + * Operations are not thread-safe. */ @DeveloperApi -class ShuffleWriteMetrics extends Serializable { +class ShuffleWriteMetrics private ( + _bytesWritten: Accumulator[Long], + _recordsWritten: Accumulator[Long], + _writeTime: Accumulator[Long]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.BYTES_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.RECORDS_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.WRITE_TIME)) + } /** - * Number of bytes written for the shuffle by this task + * Create a new [[ShuffleWriteMetrics]] that is not associated with any particular task. + * + * This mainly exists for legacy reasons, because we use dummy [[ShuffleWriteMetrics]] in + * many places only to merge their values together later. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerShuffleWriteMetrics]]. */ - @volatile private var _bytesWritten: Long = _ - def bytesWritten: Long = _bytesWritten - private[spark] def incBytesWritten(value: Long) = _bytesWritten += value - private[spark] def decBytesWritten(value: Long) = _bytesWritten -= value + private[spark] def this() { + this(InternalAccumulator.createShuffleWriteAccums().map { a => (a.name.get, a) }.toMap) + } /** - * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds + * Number of bytes written for the shuffle by this task. */ - @volatile private var _writeTime: Long = _ - def writeTime: Long = _writeTime - private[spark] def incWriteTime(value: Long) = _writeTime += value - private[spark] def decWriteTime(value: Long) = _writeTime -= value + def bytesWritten: Long = _bytesWritten.localValue /** - * Total number of records written to the shuffle by this task + * Total number of records written to the shuffle by this task. */ - @volatile private var _recordsWritten: Long = _ - def recordsWritten: Long = _recordsWritten - private[spark] def incRecordsWritten(value: Long) = _recordsWritten += value - private[spark] def decRecordsWritten(value: Long) = _recordsWritten -= value - private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value + def recordsWritten: Long = _recordsWritten.localValue + + /** + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds. + */ + def writeTime: Long = _writeTime.localValue + + private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) + private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) + private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v) + private[spark] def decBytesWritten(v: Long): Unit = { + _bytesWritten.setValue(bytesWritten - v) + } + private[spark] def decRecordsWritten(v: Long): Unit = { + _recordsWritten.setValue(recordsWritten - v) + } // Legacy methods for backward compatibility. // TODO: remove these once we make this class private. diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 32ef5a9b5606f..8d10bf588ef1f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,90 +17,161 @@ package org.apache.spark.executor -import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentHashMap - +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.DataReadMethod.DataReadMethod +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.util.Utils /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. * - * This class is used to house metrics both for in-progress and completed tasks. In executors, - * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread - * reads it to send in-progress metrics, and the task thread reads it to send metrics along with - * the completed task. + * This class is wrapper around a collection of internal accumulators that represent metrics + * associated with a task. The local values of these accumulators are sent from the executor + * to the driver when the task completes. These values are then merged into the corresponding + * accumulator previously registered on the driver. + * + * The accumulator updates are also sent to the driver periodically (on executor heartbeat) + * and when the task failed with an exception. The [[TaskMetrics]] object itself should never + * be sent to the driver. * - * So, when adding new fields, take into consideration that the whole object can be serialized for - * shipping off at any time to consumers of the SparkListener interface. + * @param initialAccums the initial set of accumulators that this [[TaskMetrics]] depends on. + * Each accumulator in this initial set must be uniquely named and marked + * as internal. Additional accumulators registered later need not satisfy + * these requirements. */ @DeveloperApi -class TaskMetrics extends Serializable { +class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { + + import InternalAccumulator._ + + // Needed for Java tests + def this() { + this(InternalAccumulator.create()) + } + + /** + * All accumulators registered with this task. + */ + private val accums = new ArrayBuffer[Accumulable[_, _]] + accums ++= initialAccums + + /** + * A map for quickly accessing the initial set of accumulators by name. + */ + private val initialAccumsMap: Map[String, Accumulator[_]] = { + val map = new mutable.HashMap[String, Accumulator[_]] + initialAccums.foreach { a => + val name = a.name.getOrElse { + throw new IllegalArgumentException( + "initial accumulators passed to TaskMetrics must be named") + } + require(a.isInternal, + s"initial accumulator '$name' passed to TaskMetrics must be marked as internal") + require(!map.contains(name), + s"detected duplicate accumulator name '$name' when constructing TaskMetrics") + map(name) = a + } + map.toMap + } + + // Each metric is internally represented as an accumulator + private val _executorDeserializeTime = getAccum(EXECUTOR_DESERIALIZE_TIME) + private val _executorRunTime = getAccum(EXECUTOR_RUN_TIME) + private val _resultSize = getAccum(RESULT_SIZE) + private val _jvmGCTime = getAccum(JVM_GC_TIME) + private val _resultSerializationTime = getAccum(RESULT_SERIALIZATION_TIME) + private val _memoryBytesSpilled = getAccum(MEMORY_BYTES_SPILLED) + private val _diskBytesSpilled = getAccum(DISK_BYTES_SPILLED) + private val _peakExecutionMemory = getAccum(PEAK_EXECUTION_MEMORY) + private val _updatedBlockStatuses = + TaskMetrics.getAccum[Seq[(BlockId, BlockStatus)]](initialAccumsMap, UPDATED_BLOCK_STATUSES) + /** - * Host's name the task runs on + * Time taken on the executor to deserialize this task. */ - private var _hostname: String = _ - def hostname: String = _hostname - private[spark] def setHostname(value: String) = _hostname = value + def executorDeserializeTime: Long = _executorDeserializeTime.localValue /** - * Time taken on the executor to deserialize this task + * Time the executor spends actually running the task (including fetching shuffle data). */ - private var _executorDeserializeTime: Long = _ - def executorDeserializeTime: Long = _executorDeserializeTime - private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value + def executorRunTime: Long = _executorRunTime.localValue + /** + * The number of bytes this task transmitted back to the driver as the TaskResult. + */ + def resultSize: Long = _resultSize.localValue /** - * Time the executor spends actually running the task (including fetching shuffle data) + * Amount of time the JVM spent in garbage collection while executing this task. */ - private var _executorRunTime: Long = _ - def executorRunTime: Long = _executorRunTime - private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value + def jvmGCTime: Long = _jvmGCTime.localValue /** - * The number of bytes this task transmitted back to the driver as the TaskResult + * Amount of time spent serializing the task result. */ - private var _resultSize: Long = _ - def resultSize: Long = _resultSize - private[spark] def setResultSize(value: Long) = _resultSize = value + def resultSerializationTime: Long = _resultSerializationTime.localValue + /** + * The number of in-memory bytes spilled by this task. + */ + def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue /** - * Amount of time the JVM spent in garbage collection while executing this task + * The number of on-disk bytes spilled by this task. */ - private var _jvmGCTime: Long = _ - def jvmGCTime: Long = _jvmGCTime - private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value + def diskBytesSpilled: Long = _diskBytesSpilled.localValue /** - * Amount of time spent serializing the task result + * Peak memory used by internal data structures created during shuffles, aggregations and + * joins. The value of this accumulator should be approximately the sum of the peak sizes + * across all such data structures created in this task. For SQL jobs, this only tracks all + * unsafe operators and ExternalSort. */ - private var _resultSerializationTime: Long = _ - def resultSerializationTime: Long = _resultSerializationTime - private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value + def peakExecutionMemory: Long = _peakExecutionMemory.localValue /** - * The number of in-memory bytes spilled by this task + * Storage statuses of any blocks that have been updated as a result of this task. */ - private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled: Long = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue + + @deprecated("use updatedBlockStatuses instead", "2.0.0") + def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { + if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None + } + + // Setters and increment-ers + private[spark] def setExecutorDeserializeTime(v: Long): Unit = + _executorDeserializeTime.setValue(v) + private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v) + private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v) + private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v) + private[spark] def setResultSerializationTime(v: Long): Unit = + _resultSerializationTime.setValue(v) + private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v) + private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v) + private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) + private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.add(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v) /** - * The number of on-disk bytes spilled by this task + * Get a Long accumulator from the given map by name, assuming it exists. + * Note: this only searches the initial set of accumulators passed into the constructor. */ - private var _diskBytesSpilled: Long = _ - def diskBytesSpilled: Long = _diskBytesSpilled - private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + private[spark] def getAccum(name: String): Accumulator[Long] = { + TaskMetrics.getAccum[Long](initialAccumsMap, name) + } + + + /* ========================== * + | INPUT METRICS | + * ========================== */ private var _inputMetrics: Option[InputMetrics] = None @@ -116,7 +187,8 @@ class TaskMetrics extends Serializable { private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = { synchronized { val metrics = _inputMetrics.getOrElse { - val metrics = new InputMetrics(readMethod) + val metrics = new InputMetrics(initialAccumsMap) + metrics.setReadMethod(readMethod) _inputMetrics = Some(metrics) metrics } @@ -128,18 +200,17 @@ class TaskMetrics extends Serializable { if (metrics.readMethod == readMethod) { metrics } else { - new InputMetrics(readMethod) + val m = new InputMetrics + m.setReadMethod(readMethod) + m } } } - /** - * This should only be used when recreating TaskMetrics, not when updating input metrics in - * executors - */ - private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) { - _inputMetrics = inputMetrics - } + + /* ============================ * + | OUTPUT METRICS | + * ============================ */ private var _outputMetrics: Option[OutputMetrics] = None @@ -149,23 +220,24 @@ class TaskMetrics extends Serializable { */ def outputMetrics: Option[OutputMetrics] = _outputMetrics - @deprecated("setting OutputMetrics is for internal use only", "2.0.0") - def outputMetrics_=(om: Option[OutputMetrics]): Unit = { - _outputMetrics = om - } - /** * Get or create a new [[OutputMetrics]] associated with this task. */ private[spark] def registerOutputMetrics( writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized { _outputMetrics.getOrElse { - val metrics = new OutputMetrics(writeMethod) + val metrics = new OutputMetrics(initialAccumsMap) + metrics.setWriteMethod(writeMethod) _outputMetrics = Some(metrics) metrics } } + + /* ================================== * + | SHUFFLE READ METRICS | + * ================================== */ + private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None /** @@ -174,21 +246,13 @@ class TaskMetrics extends Serializable { */ def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics - /** - * This should only be used when recreating TaskMetrics, not when updating read metrics in - * executors. - */ - private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { - _shuffleReadMetrics = shuffleReadMetrics - } - /** * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency. * * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization * issues from readers in different threads, in-progress tasks use a [[ShuffleReadMetrics]] for * each dependency and merge these metrics before reporting them to the driver. - */ + */ @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[ShuffleReadMetrics] /** @@ -210,19 +274,21 @@ class TaskMetrics extends Serializable { */ private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { if (tempShuffleReadMetrics.nonEmpty) { - val merged = new ShuffleReadMetrics - for (depMetrics <- tempShuffleReadMetrics) { - merged.incFetchWaitTime(depMetrics.fetchWaitTime) - merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) - merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) - merged.incRemoteBytesRead(depMetrics.remoteBytesRead) - merged.incLocalBytesRead(depMetrics.localBytesRead) - merged.incRecordsRead(depMetrics.recordsRead) - } - _shuffleReadMetrics = Some(merged) + val metrics = new ShuffleReadMetrics(initialAccumsMap) + metrics.setRemoteBlocksFetched(tempShuffleReadMetrics.map(_.remoteBlocksFetched).sum) + metrics.setLocalBlocksFetched(tempShuffleReadMetrics.map(_.localBlocksFetched).sum) + metrics.setFetchWaitTime(tempShuffleReadMetrics.map(_.fetchWaitTime).sum) + metrics.setRemoteBytesRead(tempShuffleReadMetrics.map(_.remoteBytesRead).sum) + metrics.setLocalBytesRead(tempShuffleReadMetrics.map(_.localBytesRead).sum) + metrics.setRecordsRead(tempShuffleReadMetrics.map(_.recordsRead).sum) + _shuffleReadMetrics = Some(metrics) } } + /* =================================== * + | SHUFFLE WRITE METRICS | + * =================================== */ + private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None /** @@ -230,86 +296,120 @@ class TaskMetrics extends Serializable { */ def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics - @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0") - def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = { - _shuffleWriteMetrics = swm - } - /** * Get or create a new [[ShuffleWriteMetrics]] associated with this task. */ private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized { _shuffleWriteMetrics.getOrElse { - val metrics = new ShuffleWriteMetrics + val metrics = new ShuffleWriteMetrics(initialAccumsMap) _shuffleWriteMetrics = Some(metrics) metrics } } - private var _updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = - Seq.empty[(BlockId, BlockStatus)] - /** - * Storage statuses of any blocks that have been updated as a result of this task. - */ - def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses + /* ========================== * + | OTHER THINGS | + * ========================== */ - @deprecated("setting updated blocks is for internal use only", "2.0.0") - def updatedBlocks_=(ub: Option[Seq[(BlockId, BlockStatus)]]): Unit = { - _updatedBlockStatuses = ub.getOrElse(Seq.empty[(BlockId, BlockStatus)]) + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = { + accums += a } - private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = { - _updatedBlockStatuses ++= v + /** + * Return the latest updates of accumulators in this task. + * + * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]] + * field is always empty, since this represents the partial updates recorded in this task, + * not the aggregated value across multiple tasks. + */ + def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a => + new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues) } - private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = { - _updatedBlockStatuses = v + // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. + // If so, initialize all relevant metrics classes so listeners can access them downstream. + { + var (hasShuffleRead, hasShuffleWrite, hasInput, hasOutput) = (false, false, false, false) + initialAccums + .filter { a => a.localValue != a.zero } + .foreach { a => + a.name.get match { + case sr if sr.startsWith(SHUFFLE_READ_METRICS_PREFIX) => hasShuffleRead = true + case sw if sw.startsWith(SHUFFLE_WRITE_METRICS_PREFIX) => hasShuffleWrite = true + case in if in.startsWith(INPUT_METRICS_PREFIX) => hasInput = true + case out if out.startsWith(OUTPUT_METRICS_PREFIX) => hasOutput = true + case _ => + } + } + if (hasShuffleRead) { _shuffleReadMetrics = Some(new ShuffleReadMetrics(initialAccumsMap)) } + if (hasShuffleWrite) { _shuffleWriteMetrics = Some(new ShuffleWriteMetrics(initialAccumsMap)) } + if (hasInput) { _inputMetrics = Some(new InputMetrics(initialAccumsMap)) } + if (hasOutput) { _outputMetrics = Some(new OutputMetrics(initialAccumsMap)) } } - @deprecated("use updatedBlockStatuses instead", "2.0.0") - def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { - if (_updatedBlockStatuses.nonEmpty) Some(_updatedBlockStatuses) else None - } +} - private[spark] def updateInputMetrics(): Unit = synchronized { - inputMetrics.foreach(_.updateBytesRead()) - } +private[spark] object TaskMetrics extends Logging { - @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() - // Get the hostname from cached data, since hostname is the order of number of nodes in - // cluster, so using cached hostname will decrease the object number and alleviate the GC - // overhead. - _hostname = TaskMetrics.getCachedHostName(_hostname) - } - - private var _accumulatorUpdates: Map[Long, Any] = Map.empty - @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + def empty: TaskMetrics = new TaskMetrics - private[spark] def updateAccumulators(): Unit = synchronized { - _accumulatorUpdates = _accumulatorsUpdater() + /** + * Get an accumulator from the given map by name, assuming it exists. + */ + def getAccum[T](accumMap: Map[String, Accumulator[_]], name: String): Accumulator[T] = { + require(accumMap.contains(name), s"metric '$name' is missing") + val accum = accumMap(name) + try { + // Note: we can't do pattern matching here because types are erased by compile time + accum.asInstanceOf[Accumulator[T]] + } catch { + case e: ClassCastException => + throw new SparkException(s"accumulator $name was of unexpected type", e) + } } /** - * Return the latest updates of accumulators in this task. + * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only. + * + * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we + * need the latter to post task end events to listeners, so we need to reconstruct the metrics + * on the driver. + * + * This assumes the provided updates contain the initial set of accumulators representing + * internal task level metrics. */ - def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates - - private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { - _accumulatorsUpdater = accumulatorsUpdater + def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = { + // Initial accumulators are passed into the TaskMetrics constructor first because these + // are required to be uniquely named. The rest of the accumulators from this task are + // registered later because they need not satisfy this requirement. + val (initialAccumInfos, otherAccumInfos) = accumUpdates + .filter { info => info.update.isDefined } + .partition { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) } + val initialAccums = initialAccumInfos.map { info => + val accum = InternalAccumulator.create(info.name.get) + accum.setValueAny(info.update.get) + accum + } + // We don't know the types of the rest of the accumulators, so we try to find the same ones + // that were previously registered here on the driver and make copies of them. It is important + // that we copy the accumulators here since they are used across many tasks and we want to + // maintain a snapshot of their local task values when we post them to listeners downstream. + val otherAccums = otherAccumInfos.flatMap { info => + val id = info.id + val acc = Accumulators.get(id).map { a => + val newAcc = a.copy() + newAcc.setValueAny(info.update.get) + newAcc + } + if (acc.isEmpty) { + logWarning(s"encountered unregistered accumulator $id when reconstructing task metrics.") + } + acc + } + val metrics = new TaskMetrics(initialAccums) + otherAccums.foreach(metrics.registerAccumulator) + metrics } -} - -private[spark] object TaskMetrics { - private val hostNameCache = new ConcurrentHashMap[String, String]() - - def empty: TaskMetrics = new TaskMetrics - - def getCachedHostName(host: String): String = { - val canonicalHost = hostNameCache.putIfAbsent(host, host) - if (canonicalHost != null) canonicalHost else host - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 3587e7eb1afaf..d9b0824b38ecc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -153,8 +153,7 @@ class CoGroupedRDD[K: ClassTag]( } context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index a79ab86d49227..3204e6adceca2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -212,6 +212,8 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() + // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) // Sets the thread local variable for the file's name @@ -222,14 +224,17 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.inputSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) @@ -252,6 +257,9 @@ class HadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } (key, value) } @@ -272,8 +280,8 @@ class HadoopRDD[K, V]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit] || split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 5cc9c81cc6749..4d2816e335fe3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -133,14 +133,17 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) val format = inputFormatClass.newInstance format match { @@ -182,6 +185,9 @@ class NewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } (reader.getCurrentKey, reader.getCurrentValue) } @@ -201,8 +207,8 @@ class NewHadoopRDD[K, V]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 146cfb9ba8037..9d45fff9213c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -19,47 +19,58 @@ package org.apache.spark.scheduler import org.apache.spark.annotation.DeveloperApi + /** * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + * + * Note: once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. + * + * @param id accumulator ID + * @param name accumulator name + * @param update partial value from a task, may be None if used on driver to describe a stage + * @param value total accumulated value so far, maybe None if used on executors to describe a task + * @param internal whether this accumulator was internal + * @param countFailedValues whether to count this accumulator's partial value if the task failed */ @DeveloperApi -class AccumulableInfo private[spark] ( - val id: Long, - val name: String, - val update: Option[String], // represents a partial update within a task - val value: String, - val internal: Boolean) { - - override def equals(other: Any): Boolean = other match { - case acc: AccumulableInfo => - this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value && - this.internal == acc.internal - case _ => false - } +case class AccumulableInfo private[spark] ( + id: Long, + name: Option[String], + update: Option[Any], // represents a partial update within a task + value: Option[Any], + private[spark] val internal: Boolean, + private[spark] val countFailedValues: Boolean) - override def hashCode(): Int = { - val state = Seq(id, name, update, value, internal) - state.map(_.hashCode).reduceLeft(31 * _ + _) - } -} +/** + * A collection of deprecated constructors. This will be removed soon. + */ object AccumulableInfo { + + @deprecated("do not create AccumulableInfo", "2.0.0") def apply( id: Long, name: String, update: Option[String], value: String, internal: Boolean): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, internal) + new AccumulableInfo( + id, Option(name), update, Option(value), internal, countFailedValues = false) } + @deprecated("do not create AccumulableInfo", "2.0.0") def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, internal = false) + new AccumulableInfo( + id, Option(name), update, Option(value), internal = false, countFailedValues = false) } + @deprecated("do not create AccumulableInfo", "2.0.0") def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value, internal = false) + new AccumulableInfo( + id, Option(name), None, Option(value), internal = false, countFailedValues = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6b01a10fc136e..897479b50010d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -208,11 +208,10 @@ class DAGScheduler( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics): Unit = { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo): Unit = { eventProcessLoop.post( - CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + CompletionEvent(task, reason, result, accumUpdates, taskInfo)) } /** @@ -222,9 +221,10 @@ class DAGScheduler( */ def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } @@ -1074,39 +1074,43 @@ class DAGScheduler( } } - /** Merge updates from a task to our local accumulator values */ + /** + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. + * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. + */ private def updateAccumulators(event: CompletionEvent): Unit = { val task = event.task val stage = stageIdToStage(task.stageId) - if (event.accumUpdates != null) { - try { - Accumulators.add(event.accumUpdates) - - event.accumUpdates.foreach { case (id, partialValue) => - // In this instance, although the reference in Accumulators.originals is a WeakRef, - // it's guaranteed to exist since the event.accumUpdates Map exists - - val acc = Accumulators.originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] - case None => throw new NullPointerException("Non-existent reference to Accumulator") - } - - // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name.get - val value = s"${acc.value}" - stage.latestInfo.accumulables(id) = - new AccumulableInfo(id, name, None, value, acc.isInternal) - event.taskInfo.accumulables += - new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) - } + try { + event.accumUpdates.foreach { ainfo => + assert(ainfo.update.isDefined, "accumulator from task should have a partial value") + val id = ainfo.id + val partialValue = ainfo.update.get + // Find the corresponding accumulator on the driver and update it + val acc: Accumulable[Any, Any] = Accumulators.get(id) match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => + throw new SparkException(s"attempted to access non-existent accumulator $id") + } + acc ++= partialValue + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name + stage.latestInfo.accumulables(id) = new AccumulableInfo( + id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues) + event.taskInfo.accumulables += new AccumulableInfo( + id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues) } - } catch { - // If we see an exception during accumulator update, just log the - // error and move on. - case e: Exception => - logError(s"Failed to update accumulators for $task", e) } + } catch { + case NonFatal(e) => + logError(s"Failed to update accumulators for task ${task.partitionId}", e) } } @@ -1116,6 +1120,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task + val taskId = event.taskInfo.id val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) @@ -1125,12 +1130,26 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // The success case is dealt with separately below, since we need to compute accumulator - // updates before posting. + // Reconstruct task metrics. Note: this may be null if the task has failed. + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulatorUpdates(event.accumUpdates) + } catch { + case NonFatal(e) => + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + // The success case is dealt with separately below. + // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here. if (event.reason != Success) { val attemptId = task.stageAttemptId - listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, - event.taskInfo, event.taskMetrics)) + listenerBus.post(SparkListenerTaskEnd( + stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics)) } if (!stageIdToStage.contains(task.stageId)) { @@ -1142,7 +1161,7 @@ class DAGScheduler( event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, - event.reason, event.taskInfo, event.taskMetrics)) + event.reason, event.taskInfo, taskMetrics)) stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => @@ -1291,7 +1310,8 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Do nothing here, left up to the TaskScheduler to decide how to handle user failures + // Tasks failed with exceptions might still have accumulator updates. + updateAccumulators(event) case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. @@ -1637,7 +1657,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) - case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => + case completion: CompletionEvent => dagScheduler.handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason, exception) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index dda3b6cc7f960..d5cd2da7a10d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,9 +73,8 @@ private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 6590cf6ffd24f..885f70e89fbf5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each * partition of the given RDD. Once deserialized, the type should be * (RDD[T], (TaskContext, Iterator[T]) => U). @@ -37,6 +38,9 @@ import org.apache.spark.rdd.RDD * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). + * @param _initialAccums initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] class ResultTask[T, U]( stageId: Int, @@ -45,8 +49,8 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, - internalAccumulators: Seq[Accumulator[Long]]) - extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.create()) + extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ea97ef0e746d8..89207dd175ae9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,10 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling + * @param _initialAccums initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -44,8 +48,8 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - internalAccumulators: Seq[Accumulator[Long]]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + _initialAccums: Seq[Accumulator[_]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 6c6883d703bea..ed3adbd81c28e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties +import javax.annotation.Nullable import scala.collection.Map import scala.collection.mutable @@ -60,7 +61,7 @@ case class SparkListenerTaskEnd( taskType: String, reason: TaskEndReason, taskInfo: TaskInfo, - taskMetrics: TaskMetrics) + @Nullable taskMetrics: TaskMetrics) extends SparkListenerEvent @DeveloperApi @@ -111,12 +112,12 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends /** * Periodic updates from executors. * @param execId executor id - * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics) + * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates) */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - taskMetrics: Seq[(Long, Int, Int, TaskMetrics)]) + accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])]) extends SparkListenerEvent @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 7ea24a217bd39..c1c8b47128f22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -74,10 +74,10 @@ private[scheduler] abstract class Stage( val name: String = callSite.shortForm val details: String = callSite.longForm - private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty /** Internal accumulators shared across all tasks in this stage. */ - def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators /** * Re-initialize the internal accumulators associated with this stage. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index fca57928eca1b..a49f3716e2702 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import scala.collection.mutable.HashMap @@ -41,32 +41,29 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * and divides the task output to multiple buckets (based on the task's partitioner). * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index of the number in the RDD + * @param initialAccumulators initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { + val initialAccumulators: Seq[Accumulator[_]]) extends Serializable { /** - * The key of the Map is the accumulator id and the value of the Map is the latest accumulator - * local value. - */ - type AccumulatorUpdates = Map[Long, Any] - - /** - * Called by [[Executor]] to run this task. + * Called by [[org.apache.spark.executor.Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) * @return the result of the task along with updates of Accumulators. */ final def run( - taskAttemptId: Long, - attemptNumber: Int, - metricsSystem: MetricsSystem) - : (T, AccumulatorUpdates) = { + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem): T = { context = new TaskContextImpl( stageId, partitionId, @@ -74,16 +71,14 @@ private[spark] abstract class Task[T]( attemptNumber, taskMemoryManager, metricsSystem, - internalAccumulators) + initialAccumulators) TaskContext.setTaskContext(context) - context.taskMetrics.setHostname(Utils.localHostName()) - context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - (runTask(context), context.collectAccumulators()) + runTask(context) } finally { context.markTaskCompleted() try { @@ -140,6 +135,18 @@ private[spark] abstract class Task[T]( */ def executorDeserializeTime: Long = _executorDeserializeTime + /** + * Collect the latest values of accumulators used in this task. If the task failed, + * filter out the accumulators whose values should not be included on failures. + */ + def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = { + if (context != null) { + context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues } + } else { + Seq.empty[AccumulableInfo] + } + } + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index b82c7f3fa54f8..03135e63d7551 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,11 +20,9 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.Map -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockId import org.apache.spark.util.Utils @@ -36,31 +34,24 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ -private[spark] -class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any], - var metrics: TaskMetrics) +private[spark] class DirectTaskResult[T]( + var valueBytes: ByteBuffer, + var accumUpdates: Seq[AccumulableInfo]) extends TaskResult[T] with Externalizable { private var valueObjectDeserialized = false private var valueObject: T = _ - def this() = this(null.asInstanceOf[ByteBuffer], null, null) + def this() = this(null.asInstanceOf[ByteBuffer], null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - - out.writeInt(valueBytes.remaining); + out.writeInt(valueBytes.remaining) Utils.writeByteBuffer(valueBytes, out) - out.writeInt(accumUpdates.size) - for ((key, value) <- accumUpdates) { - out.writeLong(key) - out.writeObject(value) - } - out.writeObject(metrics) + accumUpdates.foreach(out.writeObject) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val blen = in.readInt() val byteVal = new Array[Byte](blen) in.readFully(byteVal) @@ -70,13 +61,12 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - val _accumUpdates = mutable.Map[Long, Any]() + val _accumUpdates = new ArrayBuffer[AccumulableInfo] for (i <- 0 until numUpdates) { - _accumUpdates(in.readLong()) = in.readObject() + _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo] } accumUpdates = _accumUpdates } - metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index f4965994d8277..c94c4f55e9ced 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.concurrent.RejectedExecutionException +import java.util.concurrent.{ExecutorService, RejectedExecutionException} import scala.language.existentials import scala.util.control.NonFatal @@ -35,9 +35,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( - THREADS, "task-result-getter") + // Exposed for testing. + protected val getTaskResultExecutor: ExecutorService = + ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter") + + // Exposed for testing. protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { sparkEnv.closureSerializer.newInstance() @@ -45,7 +48,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } def enqueueSuccessfulTask( - taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, + tid: Long, + serializedData: ByteBuffer): Unit = { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { @@ -82,7 +87,19 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul (deserializedResult, size) } - result.metrics.setResultSize(size) + // Set the task result size in the accumulator updates received from the executors. + // We need to do this here on the driver because if we did this on the executors then + // we would have to serialize the result again after updating the size. + result.accumUpdates = result.accumUpdates.map { a => + if (a.name == Some(InternalAccumulator.RESULT_SIZE)) { + assert(a.update == Some(0L), + "task result size should not have been set on the executors") + a.copy(update = Some(size.toLong)) + } else { + a + } + } + scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 7c0b007db708e..fccd6e0699341 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -65,8 +65,10 @@ private[spark] trait TaskScheduler { * alive. Return true if the driver knows about the given block manager. Otherwise, return false, * indicating that the block manager should re-register. */ - def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean + def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean /** * Get an application ID associated with the job. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6e3ef0e54f0fd..29341dfe3043c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,7 +30,6 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -380,17 +379,17 @@ private[spark] class TaskSchedulerImpl( */ override def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + accumUpdates: Array[(Long, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - - val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { - taskMetrics.flatMap { case (id, metrics) => + // (taskId, stageId, stageAttemptId, accumUpdates) + val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized { + accumUpdates.flatMap { case (id, updates) => taskIdToTaskSetManager.get(id).map { taskSetMgr => - (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates) } } } - dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId) } def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index aa39b59d8cce4..cf97877476d54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -621,8 +621,7 @@ private[spark] class TaskSetManager( // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. // Note: "result.value()" only deserializes the value when it's called at the first time, so // here "result.value()" just returns the value and won't block other threads. - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) if (!successful(index)) { tasksSuccessful += 1 logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( @@ -653,8 +652,7 @@ private[spark] class TaskSetManager( info.markFailed() val index = info.index copiesRunning(index) -= 1 - var taskMetrics : TaskMetrics = null - + var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo] val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString val failureException: Option[Throwable] = reason match { @@ -669,7 +667,8 @@ private[spark] class TaskSetManager( None case ef: ExceptionFailure => - taskMetrics = ef.metrics.orNull + // ExceptionFailure's might have accumulator updates + accumUpdates = ef.accumUpdates if (ef.className == classOf[NotSerializableException].getName) { // If the task result wasn't serializable, there's no point in trying to re-execute it. logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying" @@ -721,7 +720,7 @@ private[spark] class TaskSetManager( // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). put(info.executorId, clock.getTimeMillis()) - sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) addPendingTask(index) if (!isZombie && state != TaskState.KILLED && reason.isInstanceOf[TaskFailedReason] @@ -793,7 +792,8 @@ private[spark] class TaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. - sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded( + tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info) } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index a57e5b0bfb865..acbe16001f5ba 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -103,8 +103,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( sorter.insertAll(aggregatedIter) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 078718ba11260..9c92a501503cd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -237,7 +237,8 @@ private[v1] object AllStagesResource { } def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = { - new AccumulableInfo(acc.id, acc.name, acc.update, acc.value) + new AccumulableInfo( + acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull) } def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 4a9f8b30525fe..b2aa8bfbe7009 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -325,12 +325,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { + val metrics = new TaskMetrics val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) + stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo, Some(metrics))) } for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); @@ -387,9 +388,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { (Some(e.toErrorString), None) } - if (!metrics.isEmpty) { + metrics.foreach { m => val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) - updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) @@ -489,19 +490,18 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) { val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { logWarning("Metrics update for task in unknown stage " + sid) new StageUIData }) val taskData = stageData.taskData.get(taskId) - taskData.map { t => + val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates) + taskData.foreach { t => if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, - t.taskMetrics) - + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.taskMetrics) // Overwrite task metrics - t.taskMetrics = Some(taskMetrics) + t.taskMetrics = Some(metrics) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 914f6183cc2a4..29c5ff0b5cf0b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -271,8 +271,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo): Elem = - {acc.name}{acc.value} + def accumulableRow(acc: AccumulableInfo): Seq[Node] = { + (acc.name, acc.value) match { + case (Some(name), Some(value)) => {name}{value} + case _ => Seq.empty[Node] + } + } val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, @@ -404,13 +408,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) - val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => - info.accumulables - .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.update.getOrElse("0").toLong } - .getOrElse(0L) - .toDouble - } + val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.peakExecutionMemory.toDouble + } val peakExecutionMemoryQuantiles = { - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") - } - val peakExecutionMemoryUsed = taskInternalAccumulables - .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.update.getOrElse("0").toLong } - .getOrElse(0L) + val externalAccumulableReadable = info.accumulables + .filterNot(_.internal) + .flatMap { a => + (a.name, a.update) match { + case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update")) + case _ => None + } + } + val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index efa22b99936af..dc8070cf8aad3 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -233,14 +233,14 @@ private[spark] object JsonProtocol { def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId - val taskMetrics = metricsUpdate.taskMetrics + val accumUpdates = metricsUpdate.accumUpdates ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ ("Executor ID" -> execId) ~ - ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ ("Stage ID" -> stageId) ~ ("Stage Attempt ID" -> stageAttemptId) ~ - ("Task Metrics" -> taskMetricsToJson(metrics)) + ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList)) }) } @@ -265,7 +265,7 @@ private[spark] object JsonProtocol { ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -284,11 +284,44 @@ private[spark] object JsonProtocol { } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { + val name = accumulableInfo.name ("ID" -> accumulableInfo.id) ~ - ("Name" -> accumulableInfo.name) ~ - ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~ - ("Value" -> accumulableInfo.value) ~ - ("Internal" -> accumulableInfo.internal) + ("Name" -> name) ~ + ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~ + ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~ + ("Internal" -> accumulableInfo.internal) ~ + ("Count Failed Values" -> accumulableInfo.countFailedValues) + } + + /** + * Serialize the value of an accumulator to JSON. + * + * For accumulators representing internal task metrics, this looks up the relevant + * [[AccumulatorParam]] to serialize the value accordingly. For all other accumulators, + * this will simply serialize the value as a string. + * + * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing. + */ + private[util] def accumValueToJson(name: Option[String], value: Any): JValue = { + import AccumulatorParam._ + if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { + (value, InternalAccumulator.getParam(name.get)) match { + case (v: Int, IntAccumulatorParam) => JInt(v) + case (v: Long, LongAccumulatorParam) => JInt(v) + case (v: String, StringAccumulatorParam) => JString(v) + case (v, UpdatedBlockStatusesAccumulatorParam) => + JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) + }) + case (v, p) => + throw new IllegalArgumentException(s"unexpected combination of accumulator value " + + s"type (${v.getClass.getName}) and param (${p.getClass.getName}) in '${name.get}'") + } + } else { + // For all external accumulators, just use strings + JString(value.toString) + } } def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { @@ -303,9 +336,9 @@ private[spark] object JsonProtocol { }.getOrElse(JNothing) val shuffleWriteMetrics: JValue = taskMetrics.shuffleWriteMetrics.map { wm => - ("Shuffle Bytes Written" -> wm.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> wm.shuffleWriteTime) ~ - ("Shuffle Records Written" -> wm.shuffleRecordsWritten) + ("Shuffle Bytes Written" -> wm.bytesWritten) ~ + ("Shuffle Write Time" -> wm.writeTime) ~ + ("Shuffle Records Written" -> wm.recordsWritten) }.getOrElse(JNothing) val inputMetrics: JValue = taskMetrics.inputMetrics.map { im => @@ -324,7 +357,6 @@ private[spark] object JsonProtocol { ("Block ID" -> id.toString) ~ ("Status" -> blockStatusToJson(status)) }) - ("Host Name" -> taskMetrics.hostname) ~ ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ ("Executor Run Time" -> taskMetrics.executorRunTime) ~ ("Result Size" -> taskMetrics.resultSize) ~ @@ -352,12 +384,12 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing) + val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ - ("Metrics" -> metrics) + ("Accumulator Updates" -> accumUpdates) case taskCommitDenied: TaskCommitDenied => ("Job ID" -> taskCommitDenied.jobID) ~ ("Partition ID" -> taskCommitDenied.partitionID) ~ @@ -619,14 +651,15 @@ private[spark] object JsonProtocol { def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { val execInfo = (json \ "Executor ID").extract[String] - val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val accumUpdates = (json \ "Metrics Updated").extract[List[JValue]].map { json => val taskId = (json \ "Task ID").extract[Long] val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] - val metrics = taskMetricsFromJson(json \ "Task Metrics") - (taskId, stageId, stageAttemptId, metrics) + val updates = + (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson) + (taskId, stageId, stageAttemptId, updates) } - SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } /** --------------------------------------------------------------------- * @@ -647,7 +680,7 @@ private[spark] object JsonProtocol { val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson(_)) + case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -675,7 +708,7 @@ private[spark] object JsonProtocol { val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson(_)) + case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -690,11 +723,43 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = (json \ "Name").extract[String] - val update = Utils.jsonOption(json \ "Update").map(_.extract[String]) - val value = (json \ "Value").extract[String] + val name = (json \ "Name").extractOpt[String] + val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) - AccumulableInfo(id, name, update, value, internal) + val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) + new AccumulableInfo(id, name, update, value, internal, countFailedValues) + } + + /** + * Deserialize the value of an accumulator from JSON. + * + * For accumulators representing internal task metrics, this looks up the relevant + * [[AccumulatorParam]] to deserialize the value accordingly. For all other + * accumulators, this will simply deserialize the value as a string. + * + * The behavior here must match that of [[accumValueToJson]]. Exposed for testing. + */ + private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = { + import AccumulatorParam._ + if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { + (value, InternalAccumulator.getParam(name.get)) match { + case (JInt(v), IntAccumulatorParam) => v.toInt + case (JInt(v), LongAccumulatorParam) => v.toLong + case (JString(v), StringAccumulatorParam) => v + case (JArray(v), UpdatedBlockStatusesAccumulatorParam) => + v.map { blockJson => + val id = BlockId((blockJson \ "Block ID").extract[String]) + val status = blockStatusFromJson(blockJson \ "Status") + (id, status) + } + case (v, p) => + throw new IllegalArgumentException(s"unexpected combination of accumulator " + + s"value in JSON ($v) and accumulator param (${p.getClass.getName}) in '${name.get}'") + } + } else { + value.extract[String] + } } def taskMetricsFromJson(json: JValue): TaskMetrics = { @@ -702,7 +767,6 @@ private[spark] object JsonProtocol { return TaskMetrics.empty } val metrics = new TaskMetrics - metrics.setHostname((json \ "Host Name").extract[String]) metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long]) metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long]) metrics.setResultSize((json \ "Result Size").extract[Long]) @@ -787,10 +851,12 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") - val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). - map(_.extract[String]).orNull - val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) + val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull + // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x + val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates()) + ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `taskCommitDenied` => diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index df9e0502e7361..5afd6d6e22c62 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -682,8 +682,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) lengths } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 625fdd57eb5d4..876c3a2283649 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -191,8 +191,6 @@ public Tuple2 answer( }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); - when(taskContext.internalMetricsToAccumulators()).thenReturn(null); - when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 5b84acf40be4e..11c97d7d9a447 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -17,18 +17,22 @@ package org.apache.spark +import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference +import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException import org.apache.spark.scheduler._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { - import InternalAccumulator._ + import AccumulatorParam._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -59,7 +63,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex longAcc.value should be (210L + maxInt * 20) } - test ("value not assignable from tasks") { + test("value not assignable from tasks") { sc = new SparkContext("local", "test") val acc : Accumulator[Int] = sc.accumulator(0) @@ -84,7 +88,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - test ("value not readable in tasks") { + test("value not readable in tasks") { val maxI = 1000 for (nThreads <- List(1, 10)) { // test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -159,193 +163,157 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } - test("internal accumulators in TaskContext") { + test("get accum") { sc = new SparkContext("local", "test") - val accums = InternalAccumulator.create(sc) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) - val internalMetricsToAccums = taskContext.internalMetricsToAccumulators - val collectedInternalAccums = taskContext.collectInternalAccumulators() - val collectedAccums = taskContext.collectAccumulators() - assert(internalMetricsToAccums.size > 0) - assert(internalMetricsToAccums.values.forall(_.isInternal)) - assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) - val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) - assert(collectedInternalAccums.size === internalMetricsToAccums.size) - assert(collectedInternalAccums.size === collectedAccums.size) - assert(collectedInternalAccums.contains(testAccum.id)) - assert(collectedAccums.contains(testAccum.id)) - } + // Don't register with SparkContext for cleanup + var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true) + val accId = acc.id + val ref = WeakReference(acc) + assert(ref.get.isDefined) + Accumulators.register(ref.get.get) - test("internal accumulators in a stage") { - val listener = new SaveInfoListener - val numPartitions = 10 - sc = new SparkContext("local", "test") - sc.addSparkListener(listener) - // Have each task add 1 to the internal accumulator - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions) - // The accumulator values should be merged in the stage - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - assert(stageAccum.value.toLong === numPartitions) - // The accumulator should be updated locally on each task - val taskAccumValues = taskInfos.map { taskInfo => - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - taskAccum.value.toLong - } - // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + // Getting a garbage collected accum should throw error + intercept[IllegalAccessError] { + Accumulators.get(accId) } - rdd.count() + + // Getting a normal accumulator. Note: this has to be separate because referencing an + // accumulator above in an `assert` would keep it from being garbage collected. + val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true, true) + Accumulators.register(acc2) + assert(Accumulators.get(acc2.id) === Some(acc2)) + + // Getting an accumulator that does not exist should return None + assert(Accumulators.get(100000).isEmpty) } - test("internal accumulators in multiple stages") { - val listener = new SaveInfoListener - val numPartitions = 10 - sc = new SparkContext("local", "test") - sc.addSparkListener(listener) - // Each stage creates its own set of internal accumulators so the - // values for the same metric should not be mixed up across stages - val rdd = sc.parallelize(1 to 100, numPartitions) - .map { i => (i, i) } - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - iter - } - .reduceByKey { case (x, y) => x + y } - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 - iter - } - .repartition(numPartitions * 2) - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - // We ran 3 stages, and the accumulator values should be distinct - val stageInfos = listener.getCompletedStageInfos - assert(stageInfos.size === 3) - val (firstStageAccum, secondStageAccum, thirdStageAccum) = - (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), - findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), - findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) - assert(firstStageAccum.value.toLong === numPartitions) - assert(secondStageAccum.value.toLong === numPartitions * 10) - assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) - } - rdd.count() + test("only external accums are automatically registered") { + val accEx = new Accumulator(0, IntAccumulatorParam, Some("external"), internal = false) + val accIn = new Accumulator(0, IntAccumulatorParam, Some("internal"), internal = true) + assert(!accEx.isInternal) + assert(accIn.isInternal) + assert(Accumulators.get(accEx.id).isDefined) + assert(Accumulators.get(accIn.id).isEmpty) } - test("internal accumulators in fully resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + test("copy") { + val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), true, false) + val acc2 = acc1.copy() + assert(acc1.id === acc2.id) + assert(acc1.value === acc2.value) + assert(acc1.name === acc2.name) + assert(acc1.isInternal === acc2.isInternal) + assert(acc1.countFailedValues === acc2.countFailedValues) + assert(acc1 !== acc2) + // Modifying one does not affect the other + acc1.add(44L) + assert(acc1.value === 500L) + assert(acc2.value === 456L) + acc2.add(144L) + assert(acc1.value === 500L) + assert(acc2.value === 600L) } - test("internal accumulators in partially resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + test("register multiple accums with same ID") { + // Make sure these are internal accums so we don't automatically register them already + val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true) + val acc2 = acc1.copy() + assert(acc1 !== acc2) + assert(acc1.id === acc2.id) + assert(Accumulators.originals.isEmpty) + assert(Accumulators.get(acc1.id).isEmpty) + Accumulators.register(acc1) + Accumulators.register(acc2) + // The second one does not override the first one + assert(Accumulators.originals.size === 1) + assert(Accumulators.get(acc1.id) === Some(acc1)) } - /** - * Return the accumulable info that matches the specified name. - */ - private def findAccumulableInfo( - accums: Iterable[AccumulableInfo], - name: String): AccumulableInfo = { - accums.find { a => a.name == name }.getOrElse { - throw new TestFailedException(s"internal accumulator '$name' not found", 0) - } + test("string accumulator param") { + val acc = new Accumulator("", StringAccumulatorParam, Some("darkness")) + assert(acc.value === "") + acc.setValue("feeds") + assert(acc.value === "feeds") + acc.add("your") + assert(acc.value === "your") // value is overwritten, not concatenated + acc += "soul" + assert(acc.value === "soul") + acc ++= "with" + assert(acc.value === "with") + acc.merge("kindness") + assert(acc.value === "kindness") } - /** - * Test whether internal accumulators are merged properly if some tasks fail. - */ - private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { - val listener = new SaveInfoListener - val numPartitions = 10 - val numFailedPartitions = (0 until numPartitions).count(failCondition) - // This says use 1 core and retry tasks up to 2 times - sc = new SparkContext("local[1, 2]", "test") - sc.addSparkListener(listener) - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => - val taskContext = TaskContext.get() - taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - // Fail the first attempts of a subset of the tasks - if (failCondition(i) && taskContext.attemptNumber() == 0) { - throw new Exception("Failing a task intentionally.") - } - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - // We should not double count values in the merged accumulator - assert(stageAccum.value.toLong === numPartitions) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - Some(taskAccum.value.toLong) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None - } - } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) - } - rdd.count() + test("list accumulator param") { + val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers")) + assert(acc.value === Seq.empty[Int]) + acc.add(Seq(1, 2)) + assert(acc.value === Seq(1, 2)) + acc += Seq(3, 4) + assert(acc.value === Seq(1, 2, 3, 4)) + acc ++= Seq(5, 6) + assert(acc.value === Seq(1, 2, 3, 4, 5, 6)) + acc.merge(Seq(7, 8)) + assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8)) + acc.setValue(Seq(9, 10)) + assert(acc.value === Seq(9, 10)) + } + + test("value is reset on the executors") { + val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false) + val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false) + val externalAccums = Seq(acc1, acc2) + val internalAccums = InternalAccumulator.create() + // Set some values; these should not be observed later on the "executors" + acc1.setValue(10) + acc2.setValue(20L) + internalAccums + .find(_.name == Some(InternalAccumulator.TEST_ACCUM)) + .get.asInstanceOf[Accumulator[Long]] + .setValue(30L) + // Simulate the task being serialized and sent to the executors. + val dummyTask = new DummyTask(internalAccums, externalAccums) + val serInstance = new JavaSerializer(new SparkConf).newInstance() + val taskSer = Task.serializeWithDependencies( + dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance) + // Now we're on the executors. + // Deserialize the task and assert that its accumulators are zero'ed out. + val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer) + val taskDeser = serInstance.deserialize[DummyTask]( + taskBytes, Thread.currentThread.getContextClassLoader) + // Assert that executors see only zeros + taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) } + taskDeser.internalAccums.foreach { a => assert(a.localValue == a.zero) } } } private[spark] object AccumulatorSuite { + import InternalAccumulator._ + /** - * Run one or more Spark jobs and verify that the peak execution memory accumulator - * is updated afterwards. + * Run one or more Spark jobs and verify that in at least one job the peak execution memory + * accumulator is updated afterwards. */ def verifyPeakExecutionMemorySet( sc: SparkContext, testName: String)(testBody: => Unit): Unit = { val listener = new SaveInfoListener sc.addSparkListener(listener) - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { jobId => - if (jobId == 0) { - // The first job is a dummy one to verify that the accumulator does not already exist - val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) - assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) - } else { - // In the subsequent jobs, verify that peak execution memory is updated - val accum = listener.getCompletedStageInfos - .flatMap(_.accumulables.values) - .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) - .getOrElse { - throw new TestFailedException( - s"peak execution memory accumulator not set in '$testName'", 0) - } - assert(accum.value.toLong > 0) - } - } - // Run the jobs - sc.parallelize(1 to 10).count() testBody + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + val isSet = accums.exists { a => + a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L) + } + if (!isSet) { + throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0) + } } } @@ -357,6 +325,10 @@ private class SaveInfoListener extends SparkListener { private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated + @GuardedBy("this") + private var exception: Throwable = null + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq @@ -365,9 +337,20 @@ private class SaveInfoListener extends SparkListener { jobCompletionCallback = callback } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + /** Throw a stored exception, if any. */ + def maybeThrowException(): Unit = synchronized { + if (exception != null) { throw exception } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { if (jobCompletionCallback != null) { - jobCompletionCallback(jobEnd.jobId) + try { + jobCompletionCallback(jobEnd.jobId) + } catch { + // Store any exception thrown here so we can throw them later in the main thread. + // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test. + case NonFatal(e) => exception = e + } } } @@ -379,3 +362,14 @@ private class SaveInfoListener extends SparkListener { completedTaskInfos += taskEnd.taskInfo } } + + +/** + * A dummy [[Task]] that contains internal and external [[Accumulator]]s. + */ +private[spark] class DummyTask( + val internalAccums: Seq[Accumulator[_]], + val externalAccums: Seq[Accumulator[_]]) + extends Task[Int](0, 0, 0, internalAccums) { + override def runTask(c: TaskContext): Int = 1 +} diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 4e678fbac6a39..80a1de6065b43 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -801,7 +801,7 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. - val taskEndReason = ExceptionFailure(null, null, null, null, null, None) + val taskEndReason = ExceptionFailure(null, null, null, null, None) sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index c7f629a14ba24..3777d77f8f5b3 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -215,14 +215,16 @@ class HeartbeatReceiverSuite val metrics = new TaskMetrics val blockManagerId = BlockManagerId(executorId, "localhost", 12345) val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId)) if (executorShouldReregister) { assert(response.reregisterBlockManager) } else { assert(!response.reregisterBlockManager) // Additionally verify that the scheduler callback is called with the correct parameters verify(scheduler).executorHeartbeatReceived( - Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + Matchers.eq(executorId), + Matchers.eq(Array(1L -> metrics.accumulatorUpdates())), + Matchers.eq(blockManagerId)) } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala new file mode 100644 index 0000000000000..630b46f828df7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.storage.{BlockId, BlockStatus} + + +class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { + import InternalAccumulator._ + import AccumulatorParam._ + + test("get param") { + assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam) + assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam) + assert(getParam(RESULT_SIZE) === LongAccumulatorParam) + assert(getParam(JVM_GC_TIME) === LongAccumulatorParam) + assert(getParam(RESULT_SERIALIZATION_TIME) === LongAccumulatorParam) + assert(getParam(MEMORY_BYTES_SPILLED) === LongAccumulatorParam) + assert(getParam(DISK_BYTES_SPILLED) === LongAccumulatorParam) + assert(getParam(PEAK_EXECUTION_MEMORY) === LongAccumulatorParam) + assert(getParam(UPDATED_BLOCK_STATUSES) === UpdatedBlockStatusesAccumulatorParam) + assert(getParam(TEST_ACCUM) === LongAccumulatorParam) + // shuffle read + assert(getParam(shuffleRead.REMOTE_BLOCKS_FETCHED) === IntAccumulatorParam) + assert(getParam(shuffleRead.LOCAL_BLOCKS_FETCHED) === IntAccumulatorParam) + assert(getParam(shuffleRead.REMOTE_BYTES_READ) === LongAccumulatorParam) + assert(getParam(shuffleRead.LOCAL_BYTES_READ) === LongAccumulatorParam) + assert(getParam(shuffleRead.FETCH_WAIT_TIME) === LongAccumulatorParam) + assert(getParam(shuffleRead.RECORDS_READ) === LongAccumulatorParam) + // shuffle write + assert(getParam(shuffleWrite.BYTES_WRITTEN) === LongAccumulatorParam) + assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam) + assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam) + // input + assert(getParam(input.READ_METHOD) === StringAccumulatorParam) + assert(getParam(input.RECORDS_READ) === LongAccumulatorParam) + assert(getParam(input.BYTES_READ) === LongAccumulatorParam) + // output + assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam) + assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam) + assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam) + // default to Long + assert(getParam(METRICS_PREFIX + "anything") === LongAccumulatorParam) + intercept[IllegalArgumentException] { + getParam("something that does not start with the right prefix") + } + } + + test("create by name") { + val executorRunTime = create(EXECUTOR_RUN_TIME) + val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES) + val shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED) + val inputReadMethod = create(input.READ_METHOD) + assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME)) + assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES)) + assert(shuffleRemoteBlocksRead.name === Some(shuffleRead.REMOTE_BLOCKS_FETCHED)) + assert(inputReadMethod.name === Some(input.READ_METHOD)) + assert(executorRunTime.value.isInstanceOf[Long]) + assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]]) + // We cannot assert the type of the value directly since the type parameter is erased. + // Instead, try casting a `Seq` of expected type and see if it fails in run time. + updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)]) + assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int]) + assert(inputReadMethod.value.isInstanceOf[String]) + // default to Long + val anything = create(METRICS_PREFIX + "anything") + assert(anything.value.isInstanceOf[Long]) + } + + test("create") { + val accums = create() + val shuffleReadAccums = createShuffleReadAccums() + val shuffleWriteAccums = createShuffleWriteAccums() + val inputAccums = createInputAccums() + val outputAccums = createOutputAccums() + // assert they're all internal + assert(accums.forall(_.isInternal)) + assert(shuffleReadAccums.forall(_.isInternal)) + assert(shuffleWriteAccums.forall(_.isInternal)) + assert(inputAccums.forall(_.isInternal)) + assert(outputAccums.forall(_.isInternal)) + // assert they all count on failures + assert(accums.forall(_.countFailedValues)) + assert(shuffleReadAccums.forall(_.countFailedValues)) + assert(shuffleWriteAccums.forall(_.countFailedValues)) + assert(inputAccums.forall(_.countFailedValues)) + assert(outputAccums.forall(_.countFailedValues)) + // assert they all have names + assert(accums.forall(_.name.isDefined)) + assert(shuffleReadAccums.forall(_.name.isDefined)) + assert(shuffleWriteAccums.forall(_.name.isDefined)) + assert(inputAccums.forall(_.name.isDefined)) + assert(outputAccums.forall(_.name.isDefined)) + // assert `accums` is a strict superset of the others + val accumNames = accums.map(_.name.get).toSet + val shuffleReadAccumNames = shuffleReadAccums.map(_.name.get).toSet + val shuffleWriteAccumNames = shuffleWriteAccums.map(_.name.get).toSet + val inputAccumNames = inputAccums.map(_.name.get).toSet + val outputAccumNames = outputAccums.map(_.name.get).toSet + assert(shuffleReadAccumNames.subsetOf(accumNames)) + assert(shuffleWriteAccumNames.subsetOf(accumNames)) + assert(inputAccumNames.subsetOf(accumNames)) + assert(outputAccumNames.subsetOf(accumNames)) + } + + test("naming") { + val accums = create() + val shuffleReadAccums = createShuffleReadAccums() + val shuffleWriteAccums = createShuffleWriteAccums() + val inputAccums = createInputAccums() + val outputAccums = createOutputAccums() + // assert that prefixes are properly namespaced + assert(SHUFFLE_READ_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(SHUFFLE_WRITE_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(INPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(OUTPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(accums.forall(_.name.get.startsWith(METRICS_PREFIX))) + // assert they all start with the expected prefixes + assert(shuffleReadAccums.forall(_.name.get.startsWith(SHUFFLE_READ_METRICS_PREFIX))) + assert(shuffleWriteAccums.forall(_.name.get.startsWith(SHUFFLE_WRITE_METRICS_PREFIX))) + assert(inputAccums.forall(_.name.get.startsWith(INPUT_METRICS_PREFIX))) + assert(outputAccums.forall(_.name.get.startsWith(OUTPUT_METRICS_PREFIX))) + } + + test("internal accumulators in TaskContext") { + val taskContext = TaskContext.empty() + val accumUpdates = taskContext.taskMetrics.accumulatorUpdates() + assert(accumUpdates.size > 0) + assert(accumUpdates.forall(_.internal)) + val testAccum = taskContext.taskMetrics.getAccum(TEST_ACCUM) + assert(accumUpdates.exists(_.id == testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findTestAccum(stageInfos.head.accumulables.values) + assert(stageAccum.value.get.asInstanceOf[Long] === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findTestAccum(taskInfo.accumulables) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.asInstanceOf[Long] === 1L) + taskAccum.value.get.asInstanceOf[Long] + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + val rdd = sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 100 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findTestAccum(stageInfos(0).accumulables.values), + findTestAccum(stageInfos(1).accumulables.values), + findTestAccum(stageInfos(2).accumulables.values)) + assert(firstStageAccum.value.get.asInstanceOf[Long] === numPartitions) + assert(secondStageAccum.value.get.asInstanceOf[Long] === numPartitions * 10) + assert(thirdStageAccum.value.get.asInstanceOf[Long] === numPartitions * 2 * 100) + } + rdd.count() + } + + // TODO: these two tests are incorrect; they don't actually trigger stage retries. + ignore("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + ignore("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + test("internal accumulators are registered for cleanups") { + sc = new SparkContext("local", "test") { + private val myCleaner = new SaveAccumContextCleaner(this) + override def cleaner: Option[ContextCleaner] = Some(myCleaner) + } + assert(Accumulators.originals.isEmpty) + sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count() + val internalAccums = InternalAccumulator.create() + // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage + assert(Accumulators.originals.size === internalAccums.size * 2) + val accumsRegistered = sc.cleaner match { + case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup + case _ => Seq.empty[Long] + } + // Make sure the same set of accumulators is registered for cleanup + assert(accumsRegistered.size === internalAccums.size * 2) + assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet) + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findTestAccum(accums: Iterable[AccumulableInfo]): AccumulableInfo = { + accums.find { a => a.name == Some(TEST_ACCUM) }.getOrElse { + fail(s"unable to find internal accumulator called $TEST_ACCUM") + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. + * TODO: make this actually retry the stage. + */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findTestAccum(stageInfos.head.accumulables.values) + // If all partitions failed, then we would resubmit the whole stage again and create a + // fresh set of internal accumulators. Otherwise, these internal accumulators do count + // failed values, so we must include the failed values. + val expectedAccumValue = + if (numPartitions == numFailedPartitions) { + numPartitions + } else { + numPartitions + numFailedPartitions + } + assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findTestAccum(taskInfo.accumulables) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.asInstanceOf[Long] === 1L) + assert(taskAccum.value.isDefined) + Some(taskAccum.value.get.asInstanceOf[Long]) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + listener.maybeThrowException() + } + + /** + * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. + */ + private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { + private val accumsRegistered = new ArrayBuffer[Long] + + override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + accumsRegistered += a.id + super.registerAccumulatorForCleanup(a) + } + + def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray + } + +} diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 9be9db01c7de9..d3359c7406e45 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -42,6 +42,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { test() } finally { logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") + // Avoid leaking map entries in tests that use accumulators without SparkContext + Accumulators.clear() } } diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index e5ec2aa1be355..15be0b194ed8e 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -17,12 +17,542 @@ package org.apache.spark.executor -import org.apache.spark.SparkFunSuite +import org.scalatest.Assertions + +import org.apache.spark._ +import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId} + class TaskMetricsSuite extends SparkFunSuite { - test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { - val taskMetrics = new TaskMetrics() - taskMetrics.mergeShuffleReadMetrics() - assert(taskMetrics.shuffleReadMetrics.isEmpty) + import AccumulatorParam._ + import InternalAccumulator._ + import StorageLevel._ + import TaskMetricsSuite._ + + test("create") { + val internalAccums = InternalAccumulator.create() + val tm1 = new TaskMetrics + val tm2 = new TaskMetrics(internalAccums) + assert(tm1.accumulatorUpdates().size === internalAccums.size) + assert(tm1.shuffleReadMetrics.isEmpty) + assert(tm1.shuffleWriteMetrics.isEmpty) + assert(tm1.inputMetrics.isEmpty) + assert(tm1.outputMetrics.isEmpty) + assert(tm2.accumulatorUpdates().size === internalAccums.size) + assert(tm2.shuffleReadMetrics.isEmpty) + assert(tm2.shuffleWriteMetrics.isEmpty) + assert(tm2.inputMetrics.isEmpty) + assert(tm2.outputMetrics.isEmpty) + // TaskMetrics constructor expects minimal set of initial accumulators + intercept[IllegalArgumentException] { new TaskMetrics(Seq.empty[Accumulator[_]]) } + } + + test("create with unnamed accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.create() ++ Seq( + new Accumulator(0, IntAccumulatorParam, None, internal = true))) + } + } + + test("create with duplicate name accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.create() ++ Seq( + new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true))) + } + } + + test("create with external accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.create() ++ Seq( + new Accumulator(0, IntAccumulatorParam, Some("x")))) + } + } + + test("create shuffle read metrics") { + import shuffleRead._ + val accums = InternalAccumulator.createShuffleReadAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(REMOTE_BLOCKS_FETCHED).setValueAny(1) + accums(LOCAL_BLOCKS_FETCHED).setValueAny(2) + accums(REMOTE_BYTES_READ).setValueAny(3L) + accums(LOCAL_BYTES_READ).setValueAny(4L) + accums(FETCH_WAIT_TIME).setValueAny(5L) + accums(RECORDS_READ).setValueAny(6L) + val sr = new ShuffleReadMetrics(accums) + assert(sr.remoteBlocksFetched === 1) + assert(sr.localBlocksFetched === 2) + assert(sr.remoteBytesRead === 3L) + assert(sr.localBytesRead === 4L) + assert(sr.fetchWaitTime === 5L) + assert(sr.recordsRead === 6L) + } + + test("create shuffle write metrics") { + import shuffleWrite._ + val accums = InternalAccumulator.createShuffleWriteAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_WRITTEN).setValueAny(1L) + accums(RECORDS_WRITTEN).setValueAny(2L) + accums(WRITE_TIME).setValueAny(3L) + val sw = new ShuffleWriteMetrics(accums) + assert(sw.bytesWritten === 1L) + assert(sw.recordsWritten === 2L) + assert(sw.writeTime === 3L) + } + + test("create input metrics") { + import input._ + val accums = InternalAccumulator.createInputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_READ).setValueAny(1L) + accums(RECORDS_READ).setValueAny(2L) + accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString) + val im = new InputMetrics(accums) + assert(im.bytesRead === 1L) + assert(im.recordsRead === 2L) + assert(im.readMethod === DataReadMethod.Hadoop) + } + + test("create output metrics") { + import output._ + val accums = InternalAccumulator.createOutputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_WRITTEN).setValueAny(1L) + accums(RECORDS_WRITTEN).setValueAny(2L) + accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString) + val om = new OutputMetrics(accums) + assert(om.bytesWritten === 1L) + assert(om.recordsWritten === 2L) + assert(om.writeMethod === DataWriteMethod.Hadoop) + } + + test("mutating values") { + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + // initial values + assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L) + assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 0L) + assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 0L) + assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 0L) + assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 0L) + assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 0L) + assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 0L) + assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 0L) + assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, + Seq.empty[(BlockId, BlockStatus)]) + // set or increment values + tm.setExecutorDeserializeTime(100L) + tm.setExecutorDeserializeTime(1L) // overwrite + tm.setExecutorRunTime(200L) + tm.setExecutorRunTime(2L) + tm.setResultSize(300L) + tm.setResultSize(3L) + tm.setJvmGCTime(400L) + tm.setJvmGCTime(4L) + tm.setResultSerializationTime(500L) + tm.setResultSerializationTime(5L) + tm.incMemoryBytesSpilled(600L) + tm.incMemoryBytesSpilled(6L) // add + tm.incDiskBytesSpilled(700L) + tm.incDiskBytesSpilled(7L) + tm.incPeakExecutionMemory(800L) + tm.incPeakExecutionMemory(8L) + val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L)) + val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L)) + tm.incUpdatedBlockStatuses(Seq(block1)) + tm.incUpdatedBlockStatuses(Seq(block2)) + // assert new values exist + assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 1L) + assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 2L) + assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 3L) + assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 4L) + assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 5L) + assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 606L) + assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 707L) + assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 808L) + assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, + Seq(block1, block2)) + } + + test("mutating shuffle read metrics values") { + import shuffleRead._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = { + assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value) + } + // create shuffle read metrics + assert(tm.shuffleReadMetrics.isEmpty) + tm.registerTempShuffleReadMetrics() + tm.mergeShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isDefined) + val sr = tm.shuffleReadMetrics.get + // initial values + assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0) + assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0) + assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 0L) + assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 0L) + assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 0L) + assertValEquals(_.recordsRead, RECORDS_READ, 0L) + // set and increment values + sr.setRemoteBlocksFetched(100) + sr.setRemoteBlocksFetched(10) + sr.incRemoteBlocksFetched(1) // 10 + 1 + sr.incRemoteBlocksFetched(1) // 10 + 1 + 1 + sr.setLocalBlocksFetched(200) + sr.setLocalBlocksFetched(20) + sr.incLocalBlocksFetched(2) + sr.incLocalBlocksFetched(2) + sr.setRemoteBytesRead(300L) + sr.setRemoteBytesRead(30L) + sr.incRemoteBytesRead(3L) + sr.incRemoteBytesRead(3L) + sr.setLocalBytesRead(400L) + sr.setLocalBytesRead(40L) + sr.incLocalBytesRead(4L) + sr.incLocalBytesRead(4L) + sr.setFetchWaitTime(500L) + sr.setFetchWaitTime(50L) + sr.incFetchWaitTime(5L) + sr.incFetchWaitTime(5L) + sr.setRecordsRead(600L) + sr.setRecordsRead(60L) + sr.incRecordsRead(6L) + sr.incRecordsRead(6L) + // assert new values exist + assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 12) + assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 24) + assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 36L) + assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 48L) + assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 60L) + assertValEquals(_.recordsRead, RECORDS_READ, 72L) + } + + test("mutating shuffle write metrics values") { + import shuffleWrite._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = { + assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value) + } + // create shuffle write metrics + assert(tm.shuffleWriteMetrics.isEmpty) + tm.registerShuffleWriteMetrics() + assert(tm.shuffleWriteMetrics.isDefined) + val sw = tm.shuffleWriteMetrics.get + // initial values + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) + assertValEquals(_.writeTime, WRITE_TIME, 0L) + // increment and decrement values + sw.incBytesWritten(100L) + sw.incBytesWritten(10L) // 100 + 10 + sw.decBytesWritten(1L) // 100 + 10 - 1 + sw.decBytesWritten(1L) // 100 + 10 - 1 - 1 + sw.incRecordsWritten(200L) + sw.incRecordsWritten(20L) + sw.decRecordsWritten(2L) + sw.decRecordsWritten(2L) + sw.incWriteTime(300L) + sw.incWriteTime(30L) + // assert new values exist + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 108L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 216L) + assertValEquals(_.writeTime, WRITE_TIME, 330L) + } + + test("mutating input metrics values") { + import input._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = { + assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value, + (x: Any, y: Any) => assert(x.toString === y.toString)) + } + // create input metrics + assert(tm.inputMetrics.isEmpty) + tm.registerInputMetrics(DataReadMethod.Memory) + assert(tm.inputMetrics.isDefined) + val in = tm.inputMetrics.get + // initial values + assertValEquals(_.bytesRead, BYTES_READ, 0L) + assertValEquals(_.recordsRead, RECORDS_READ, 0L) + assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory) + // set and increment values + in.setBytesRead(1L) + in.setBytesRead(2L) + in.incRecordsRead(1L) + in.incRecordsRead(2L) + in.setReadMethod(DataReadMethod.Disk) + // assert new values exist + assertValEquals(_.bytesRead, BYTES_READ, 2L) + assertValEquals(_.recordsRead, RECORDS_READ, 3L) + assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk) + } + + test("mutating output metrics values") { + import output._ + val accums = InternalAccumulator.create() + val tm = new TaskMetrics(accums) + def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = { + assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value, + (x: Any, y: Any) => assert(x.toString === y.toString)) + } + // create input metrics + assert(tm.outputMetrics.isEmpty) + tm.registerOutputMetrics(DataWriteMethod.Hadoop) + assert(tm.outputMetrics.isDefined) + val out = tm.outputMetrics.get + // initial values + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) + assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + // set values + out.setBytesWritten(1L) + out.setBytesWritten(2L) + out.setRecordsWritten(3L) + out.setRecordsWritten(4L) + out.setWriteMethod(DataWriteMethod.Hadoop) + // assert new values exist + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L) + // Note: this doesn't actually test anything, but there's only one DataWriteMethod + // so we can't set it to anything else + assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + } + + test("merging multiple shuffle read metrics") { + val tm = new TaskMetrics + assert(tm.shuffleReadMetrics.isEmpty) + val sr1 = tm.registerTempShuffleReadMetrics() + val sr2 = tm.registerTempShuffleReadMetrics() + val sr3 = tm.registerTempShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isEmpty) + sr1.setRecordsRead(10L) + sr2.setRecordsRead(10L) + sr1.setFetchWaitTime(1L) + sr2.setFetchWaitTime(2L) + sr3.setFetchWaitTime(3L) + tm.mergeShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isDefined) + val sr = tm.shuffleReadMetrics.get + assert(sr.remoteBlocksFetched === 0L) + assert(sr.recordsRead === 20L) + assert(sr.fetchWaitTime === 6L) + + // SPARK-5701: calling merge without any shuffle deps does nothing + val tm2 = new TaskMetrics + tm2.mergeShuffleReadMetrics() + assert(tm2.shuffleReadMetrics.isEmpty) + } + + test("register multiple shuffle write metrics") { + val tm = new TaskMetrics + val sw1 = tm.registerShuffleWriteMetrics() + val sw2 = tm.registerShuffleWriteMetrics() + assert(sw1 === sw2) + assert(tm.shuffleWriteMetrics === Some(sw1)) + } + + test("register multiple input metrics") { + val tm = new TaskMetrics + val im1 = tm.registerInputMetrics(DataReadMethod.Memory) + val im2 = tm.registerInputMetrics(DataReadMethod.Memory) + // input metrics with a different read method than the one already registered are ignored + val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop) + assert(im1 === im2) + assert(im1 !== im3) + assert(tm.inputMetrics === Some(im1)) + im2.setBytesRead(50L) + im3.setBytesRead(100L) + assert(tm.inputMetrics.get.bytesRead === 50L) + } + + test("register multiple output metrics") { + val tm = new TaskMetrics + val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) + val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) + assert(om1 === om2) + assert(tm.outputMetrics === Some(om1)) + } + + test("additional accumulables") { + val internalAccums = InternalAccumulator.create() + val tm = new TaskMetrics(internalAccums) + assert(tm.accumulatorUpdates().size === internalAccums.size) + val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a")) + val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b")) + val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c")) + val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"), + internal = true, countFailedValues = true) + tm.registerAccumulator(acc1) + tm.registerAccumulator(acc2) + tm.registerAccumulator(acc3) + tm.registerAccumulator(acc4) + acc1 += 1 + acc2 += 2 + val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap + assert(newUpdates.contains(acc1.id)) + assert(newUpdates.contains(acc2.id)) + assert(newUpdates.contains(acc3.id)) + assert(newUpdates.contains(acc4.id)) + assert(newUpdates(acc1.id).name === Some("a")) + assert(newUpdates(acc2.id).name === Some("b")) + assert(newUpdates(acc3.id).name === Some("c")) + assert(newUpdates(acc4.id).name === Some("d")) + assert(newUpdates(acc1.id).update === Some(1)) + assert(newUpdates(acc2.id).update === Some(2)) + assert(newUpdates(acc3.id).update === Some(0)) + assert(newUpdates(acc4.id).update === Some(0)) + assert(!newUpdates(acc3.id).internal) + assert(!newUpdates(acc3.id).countFailedValues) + assert(newUpdates(acc4.id).internal) + assert(newUpdates(acc4.id).countFailedValues) + assert(newUpdates.values.map(_.update).forall(_.isDefined)) + assert(newUpdates.values.map(_.value).forall(_.isEmpty)) + assert(newUpdates.size === internalAccums.size + 4) + } + + test("existing values in shuffle read accums") { + // set shuffle read accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME)) + assert(srAccum.isDefined) + srAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isDefined) + assert(tm.shuffleWriteMetrics.isEmpty) + assert(tm.inputMetrics.isEmpty) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in shuffle write accums") { + // set shuffle write accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN)) + assert(swAccum.isDefined) + swAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isEmpty) + assert(tm.shuffleWriteMetrics.isDefined) + assert(tm.inputMetrics.isEmpty) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in input accums") { + // set input accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val inAccum = accums.find(_.name === Some(input.RECORDS_READ)) + assert(inAccum.isDefined) + inAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isEmpty) + assert(tm.shuffleWriteMetrics.isEmpty) + assert(tm.inputMetrics.isDefined) + assert(tm.outputMetrics.isEmpty) } + + test("existing values in output accums") { + // set output accum before passing it into TaskMetrics + val accums = InternalAccumulator.create() + val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN)) + assert(outAccum.isDefined) + outAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm4 = new TaskMetrics(accums) + assert(tm4.shuffleReadMetrics.isEmpty) + assert(tm4.shuffleWriteMetrics.isEmpty) + assert(tm4.inputMetrics.isEmpty) + assert(tm4.outputMetrics.isDefined) + } + + test("from accumulator updates") { + val accumUpdates1 = InternalAccumulator.create().map { a => + AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues) + } + val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1) + assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1) + // Test this with additional accumulators. Only the ones registered with `Accumulators` + // will show up in the reconstructed TaskMetrics. In practice, all accumulators created + // on the driver, internal or not, should be registered with `Accumulators` at some point. + // Here we show that reconstruction will succeed even if there are unregistered accumulators. + val param = IntAccumulatorParam + val registeredAccums = Seq( + new Accumulator(0, param, Some("a"), internal = true, countFailedValues = true), + new Accumulator(0, param, Some("b"), internal = true, countFailedValues = false), + new Accumulator(0, param, Some("c"), internal = false, countFailedValues = true), + new Accumulator(0, param, Some("d"), internal = false, countFailedValues = false)) + val unregisteredAccums = Seq( + new Accumulator(0, param, Some("e"), internal = true, countFailedValues = true), + new Accumulator(0, param, Some("f"), internal = true, countFailedValues = false)) + registeredAccums.foreach(Accumulators.register) + registeredAccums.foreach { a => assert(Accumulators.originals.contains(a.id)) } + unregisteredAccums.foreach { a => assert(!Accumulators.originals.contains(a.id)) } + // set some values in these accums + registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) } + unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) } + val registeredAccumInfos = registeredAccums.map(makeInfo) + val unregisteredAccumInfos = unregisteredAccums.map(makeInfo) + val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos + val metrics2 = TaskMetrics.fromAccumulatorUpdates(accumUpdates2) + // accumulators that were not registered with `Accumulators` will not show up + assertUpdatesEquals(metrics2.accumulatorUpdates(), accumUpdates1 ++ registeredAccumInfos) + } +} + + +private[spark] object TaskMetricsSuite extends Assertions { + + /** + * Assert that the following three things are equal to `value`: + * (1) TaskMetrics value + * (2) TaskMetrics accumulator update value + * (3) Original accumulator value + */ + def assertValueEquals( + tm: TaskMetrics, + tmValue: TaskMetrics => Any, + accums: Seq[Accumulator[_]], + metricName: String, + value: Any, + assertEquals: (Any, Any) => Unit = (x: Any, y: Any) => assert(x === y)): Unit = { + assertEquals(tmValue(tm), value) + val accum = accums.find(_.name == Some(metricName)) + assert(accum.isDefined) + assertEquals(accum.get.value, value) + val accumUpdate = tm.accumulatorUpdates().find(_.name == Some(metricName)) + assert(accumUpdate.isDefined) + assert(accumUpdate.get.value === None) + assertEquals(accumUpdate.get.update, Some(value)) + } + + /** + * Assert that two lists of accumulator updates are equal. + * Note: this does NOT check accumulator ID equality. + */ + def assertUpdatesEquals( + updates1: Seq[AccumulableInfo], + updates2: Seq[AccumulableInfo]): Unit = { + assert(updates1.size === updates2.size) + updates1.zip(updates2).foreach { case (info1, info2) => + // do not assert ID equals here + assert(info1.name === info2.name) + assert(info1.update === info2.update) + assert(info1.value === info2.value) + assert(info1.internal === info2.internal) + assert(info1.countFailedValues === info2.countFailedValues) + } + } + + /** + * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * info as an accumulator update. + */ + def makeInfo(a: Accumulable[_, _]): AccumulableInfo = { + new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues) + } + } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 0e60cc8e77873..2b5e4b80e96ab 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -31,7 +31,6 @@ object MemoryTestingUtils { taskAttemptId = 0, attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = env.metricsSystem, - internalAccumulators = Seq.empty) + metricsSystem = env.metricsSystem) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 370a284d2950f..d9c71ec2eae7b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -96,8 +95,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite - extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -111,8 +109,10 @@ class DAGSchedulerSuite override def schedulingMode: SchedulingMode = SchedulingMode.NONE override def start() = {} override def stop() = {} - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -189,7 +189,8 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception): Unit = { failure = exception } } - before { + override def beforeEach(): Unit = { + super.beforeEach() sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() @@ -202,17 +203,21 @@ class DAGSchedulerSuite results.clear() mapOutputTracker = new MapOutputTrackerMaster(conf) scheduler = new DAGScheduler( - sc, - taskScheduler, - sc.listenerBus, - mapOutputTracker, - blockManagerMaster, - sc.env) + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } - after { - scheduler.stop() + override def afterEach(): Unit = { + try { + scheduler.stop() + } finally { + super.afterEach() + } } override def afterAll() { @@ -242,26 +247,31 @@ class DAGSchedulerSuite * directly through CompletionEvents. */ private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => - it.next.asInstanceOf[Tuple2[_, _]]._1 + it.next.asInstanceOf[Tuple2[_, _]]._1 /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent( - taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2)) } } } - private def completeWithAccumulator(accumId: Long, taskSet: TaskSet, - results: Seq[(TaskEndReason, Any)]) { + private def completeWithAccumulator( + accumId: Long, + taskSet: TaskSet, + results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent( + taskSet.tasks(i), + result._1, + result._2, + Seq(new AccumulableInfo( + accumId, Some(""), Some(1), None, internal = false, countFailedValues = false)))) } } } @@ -338,9 +348,12 @@ class DAGSchedulerSuite } test("equals and hashCode AccumulableInfo") { - val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true) - val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) - val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + val accInfo1 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = true, countFailedValues = false) + val accInfo2 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false) + val accInfo3 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false) assert(accInfo1 !== accInfo2) assert(accInfo2 === accInfo3) assert(accInfo2.hashCode() === accInfo3.hashCode()) @@ -464,7 +477,7 @@ class DAGSchedulerSuite override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, TaskMetrics)], + accumUpdates: Array[(Long, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None @@ -499,8 +512,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) @@ -515,12 +528,12 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( - (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -829,23 +842,17 @@ class DAGSchedulerSuite HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -882,12 +889,9 @@ class DAGSchedulerSuite HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) @@ -900,12 +904,9 @@ class DAGSchedulerSuite assert(countSubmittedMapStageAttempts() === 2) // The second ResultTask fails, with a fetch failure for the output from the second mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(1), FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Another ResubmitFailedStages event should not result in another attempt for the map @@ -920,11 +921,11 @@ class DAGSchedulerSuite } /** - * This tests the case where a late FetchFailed comes in after the map stage has finished getting - * retried and a new reduce stage starts running. - */ + * This tests the case where a late FetchFailed comes in after the map stage has finished getting + * retried and a new reduce stage starts running. + */ test("extremely late fetch failures don't cause multiple concurrent attempts for " + - "the same stage") { + "the same stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId @@ -952,12 +953,9 @@ class DAGSchedulerSuite assert(countSubmittedReduceStageAttempts() === 1) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Trigger resubmission of the failed map stage and finish the re-started map task. @@ -971,12 +969,9 @@ class DAGSchedulerSuite assert(countSubmittedReduceStageAttempts() === 2) // A late FetchFailed arrives from the second task in the original reduce stage. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(1), FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because @@ -1007,48 +1002,36 @@ class DAGSchedulerSuite assert(shuffleStage.numAvailableOutputs === 0) // should be ignored for being too old - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 0) // should work because it's a non-failed host (so the available map outputs will increase) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostB", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostB", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 1) // should be ignored for being too old - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 1) // should work because it's a new epoch, which will increase the number of available map // outputs, and also finish the stage taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(1), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1140,12 +1123,9 @@ class DAGSchedulerSuite // then one executor dies, and a task fails in stage 1 runEvent(ExecutorLost("exec-hostA")) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), - null, - null, - createFakeTaskInfo(), null)) // so we resubmit stage 0, which completes happily @@ -1155,13 +1135,10 @@ class DAGSchedulerSuite assert(stage0Resubmit.stageAttemptId === 1) val task = stage0Resubmit.tasks(0) assert(task.partitionId === 2) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( task, Success, - makeMapStatus("hostC", shuffleMapRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostC", shuffleMapRdd.partitions.length))) // now here is where things get tricky : we will now have a task set representing // the second attempt for stage 1, but we *also* have some tasks for the first attempt for @@ -1174,28 +1151,19 @@ class DAGSchedulerSuite // we'll have some tasks finish from the first attempt, and some finish from the second attempt, // so that we actually have all stage outputs, though no attempt has completed all its // tasks - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(3).tasks(0), Success, - makeMapStatus("hostC", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) - runEvent(CompletionEvent( + makeMapStatus("hostC", reduceRdd.partitions.length))) + runEvent(makeCompletionEvent( taskSets(3).tasks(1), Success, - makeMapStatus("hostC", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostC", reduceRdd.partitions.length))) // late task finish from the first attempt - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(2), Success, - makeMapStatus("hostB", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostB", reduceRdd.partitions.length))) // What should happen now is that we submit stage 2. However, we might not see an error // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But @@ -1242,21 +1210,21 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) // complete some of the tasks from the first stage, on one host - runEvent(CompletionEvent( - taskSets(0).tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) - runEvent(CompletionEvent( - taskSets(0).tasks(1), Success, - makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.length))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + Success, + makeMapStatus("hostA", reduceRdd.partitions.length))) // now that host goes down runEvent(ExecutorLost("exec-hostA")) // so we resubmit those tasks - runEvent(CompletionEvent( - taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) - runEvent(CompletionEvent( - taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null)) + runEvent(makeCompletionEvent(taskSets(0).tasks(1), Resubmitted, null)) // now complete everything on a different host complete(taskSets(0), Seq( @@ -1449,12 +1417,12 @@ class DAGSchedulerSuite // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -1469,15 +1437,15 @@ class DAGSchedulerSuite submit(finalRdd, Array(0)) // have the first stage complete normally complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) // have the second stage complete normally complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -1500,15 +1468,15 @@ class DAGSchedulerSuite cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) // complete stage 0 complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) // complete stage 1 complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. @@ -1606,6 +1574,28 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("accumulators are updated on exception failures") { + val acc1 = sc.accumulator(0L, "ingenieur") + val acc2 = sc.accumulator(0L, "boulanger") + val acc3 = sc.accumulator(0L, "agriculteur") + assert(Accumulators.get(acc1.id).isDefined) + assert(Accumulators.get(acc2.id).isDefined) + assert(Accumulators.get(acc3.id).isDefined) + val accInfo1 = new AccumulableInfo( + acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false) + val accInfo2 = new AccumulableInfo( + acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false) + val accInfo3 = new AccumulableInfo( + acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false) + val accumUpdates = Seq(accInfo1, accInfo2, accInfo3) + val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates) + submit(new MyRDD(sc, 1, Nil), Array(0)) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(Accumulators.get(acc1.id).get.value === 15L) + assert(Accumulators.get(acc2.id).get.value === 13L) + assert(Accumulators.get(acc3.id).get.value === 18L) + } + test("reduce tasks should be placed locally with map output") { // Create an shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) @@ -1614,9 +1604,9 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)))) + (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -1884,8 +1874,7 @@ class DAGSchedulerSuite submitMapStage(shuffleDep) val oldTaskSet = taskSets(0) - runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet // Pretend host A was lost @@ -1895,23 +1884,19 @@ class DAGSchedulerSuite assert(newEpoch > oldEpoch) // Suppose we also get a completed event from task 1 on the same host; this should be ignored - runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet // A completion from another task should work because it's a non-failed host - runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet // Now complete tasks in the second task set val newTaskSet = taskSets(1) assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA - runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet - runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2))) assert(results.size === 1) // Map stage job should now finally be complete assertDataStructuresEmpty() @@ -1962,5 +1947,21 @@ class DAGSchedulerSuite info } -} + private def makeCompletionEvent( + task: Task[_], + reason: TaskEndReason, + result: Any, + extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], + taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { + val accumUpdates = reason match { + case Success => + task.initialAccumulators.map { a => + new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues) + } + case ef: ExceptionFailure => ef.accumUpdates + case _ => Seq.empty[AccumulableInfo] + } + CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 761e82e6cf1ce..35215c15ea805 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. @@ -131,7 +131,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(sc.eventLogger.isDefined) val originalEvents = sc.eventLogger.get.loggedEvents val replayedEvents = eventMonster.loggedEvents - originalEvents.zip(replayedEvents).foreach { case (e1, e2) => assert(e1 === e2) } + originalEvents.zip(replayedEvents).foreach { case (e1, e2) => + // Don't compare the JSON here because accumulators in StageInfo may be out of order + JsonProtocolSuite.assertEquals( + JsonProtocol.sparkEventFromJson(e1), JsonProtocol.sparkEventFromJson(e2)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index e5ec44a9f3b6b..b3bb86db10a32 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -22,6 +22,8 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.executor.TaskMetricsSuite +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD @@ -57,8 +59,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) - val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) + val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) intercept[RuntimeException] { task.run(0, 0, null) } @@ -97,6 +98,57 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }.collect() assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + + test("accumulators are updated on exception failures") { + // This means use 1 core and 4 max task failures + sc = new SparkContext("local[1,4]", "test") + val param = AccumulatorParam.LongAccumulatorParam + // Create 2 accumulators, one that counts failed values and another that doesn't + val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) + val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) + // Fail first 3 attempts of every task. This means each task should be run 4 times. + sc.parallelize(1 to 10, 10).map { i => + acc1 += 1 + acc2 += 1 + if (TaskContext.get.attemptNumber() <= 2) { + throw new Exception("you did something wrong") + } else { + 0 + } + }.count() + // The one that counts failed values should be 4x the one that didn't, + // since we ran each task 4 times + assert(Accumulators.get(acc1.id).get.value === 40L) + assert(Accumulators.get(acc2.id).get.value === 10L) + } + + test("failed tasks collect only accumulators whose values count during failures") { + sc = new SparkContext("local", "test") + val param = AccumulatorParam.LongAccumulatorParam + val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) + val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) + val initialAccums = InternalAccumulator.create() + // Create a dummy task. We won't end up running this; we just want to collect + // accumulator updates from it. + val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) { + context = new TaskContextImpl(0, 0, 0L, 0, + new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), + SparkEnv.get.metricsSystem, + initialAccums) + context.taskMetrics.registerAccumulator(acc1) + context.taskMetrics.registerAccumulator(acc2) + override def runTask(tc: TaskContext): Int = 0 + } + // First, simulate task success. This should give us all the accumulators. + val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false) + val accumUpdates2 = (initialAccums ++ Seq(acc1, acc2)).map(TaskMetricsSuite.makeInfo) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2) + // Now, simulate task failures. This should give us only the accums that count failed values. + val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true) + val accumUpdates4 = (initialAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index cc2557c2f1df2..b5385c11a926e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -21,10 +21,15 @@ import java.io.File import java.net.URL import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal +import com.google.common.util.concurrent.MoreExecutors +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, anyLong} +import org.mockito.Mockito.{spy, times, verify} import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ @@ -33,13 +38,14 @@ import org.apache.spark.storage.TaskResultBlockId import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils} + /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. * * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) +private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false @@ -72,6 +78,31 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule } } + +/** + * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors + * _before_ modifying the results in any way. + */ +private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl) + extends TaskResultGetter(env, scheduler) { + + // Use the current thread so we can access its results synchronously + protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor() + + // DirectTaskResults that we receive from the executors + private val _taskResults = new ArrayBuffer[DirectTaskResult[_]] + + def taskResults: Seq[DirectTaskResult[_]] = _taskResults + + override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = { + // work on a copy since the super class still needs to use the buffer + val newBuffer = data.duplicate() + _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer) + super.enqueueSuccessfulTask(tsm, tid, data) + } +} + + /** * Tests related to handling task results (both direct and indirect). */ @@ -182,5 +213,39 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local Thread.currentThread.setContextClassLoader(originalClassLoader) } } + + test("task result size is set on the driver, not the executors") { + import InternalAccumulator._ + + // Set up custom TaskResultGetter and TaskSchedulerImpl spy + sc = new SparkContext("local", "test", conf) + val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + val spyScheduler = spy(scheduler) + val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler) + val newDAGScheduler = new DAGScheduler(sc, spyScheduler) + scheduler.taskResultGetter = resultGetter + sc.dagScheduler = newDAGScheduler + sc.taskScheduler = spyScheduler + sc.taskScheduler.setDAGScheduler(newDAGScheduler) + + // Just run 1 task and capture the corresponding DirectTaskResult + sc.parallelize(1 to 1, 1).count() + val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]]) + verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture()) + + // When a task finishes, the executor sends a serialized DirectTaskResult to the driver + // without setting the result size so as to avoid serializing the result again. Instead, + // the result size is set later in TaskResultGetter on the driver before passing the + // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult + // before and after the result size is set. + assert(resultGetter.taskResults.size === 1) + val resBefore = resultGetter.taskResults.head + val resAfter = captor.getValue + val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update) + val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update) + assert(resSizeBefore.exists(_ == 0L)) + assert(resSizeAfter.exists(_.toString.toLong > 0L)) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ecc18fc6e15b4..a2e74365641a6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) @@ -38,9 +37,8 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo) { taskScheduler.endedTasks(taskInfo.index) = reason } @@ -167,14 +165,17 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => + new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) + } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have - var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // Tell it the task has finished - manager.handleSuccessfulTask(0, createTaskResult(0)) + manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates)) assert(sched.endedTasks(0) === Success) assert(sched.finishedManagers.contains(manager)) } @@ -184,10 +185,15 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => + task.initialAccumulators.map { a => + new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) + } + } // First three offers should all find tasks for (i <- 0 until 3) { - var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) val task = taskOption.get assert(task.executorId === "exec1") @@ -198,14 +204,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None) // Finish the first two tasks - manager.handleSuccessfulTask(0, createTaskResult(0)) - manager.handleSuccessfulTask(1, createTaskResult(1)) + manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdatesByTask(0))) + manager.handleSuccessfulTask(1, createTaskResult(1, accumUpdatesByTask(1))) assert(sched.endedTasks(0) === Success) assert(sched.endedTasks(1) === Success) assert(!sched.finishedManagers.contains(manager)) // Finish the last task - manager.handleSuccessfulTask(2, createTaskResult(2)) + manager.handleSuccessfulTask(2, createTaskResult(2, accumUpdatesByTask(2))) assert(sched.endedTasks(2) === Success) assert(sched.finishedManagers.contains(manager)) } @@ -620,7 +626,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // multiple 1k result val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect() - assert(10 === r.size ) + assert(10 === r.size) // single 10M result val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} @@ -761,7 +767,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Regression test for SPARK-2931 sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, - ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -786,8 +792,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) } - def createTaskResult(id: Int): DirectTaskResult[Int] = { + private def createTaskResult( + id: Int, + accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() - new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) + new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 86699e7f56953..b83ffa3282e4d 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.ui.scope.RDDOperationGraphListener class StagePageSuite extends SparkFunSuite with LocalSparkContext { + private val peakExecutionMemory = 10 + test("peak execution memory only displayed if unsafe is enabled") { val unsafeConf = "spark.sql.unsafe.enabled" val conf = new SparkConf(false).set(unsafeConf, "true") @@ -52,7 +54,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val conf = new SparkConf(false).set(unsafeConf, "true") val html = renderStagePage(conf).toString().toLowerCase // verify min/25/50/75/max show task value not cumulative values - assert(html.contains("10.0 b" * 5)) + assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } /** @@ -79,14 +81,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { (1 to 2).foreach { taskId => val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) - val peakExecutionMemory = 10 - taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, - Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) taskInfo.markSuccessful() + val taskMetrics = TaskMetrics.empty + taskMetrics.incPeakExecutionMemory(peakExecutionMemory) jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) } jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 607617cbe91ca..18a16a25bfac5 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None, None), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0", true, Some("Induced failure")), @@ -269,20 +269,22 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = new TaskMetrics() - taskMetrics.setExecutorRunTime(base + 4) - taskMetrics.incDiskBytesSpilled(base + 5) - taskMetrics.incMemoryBytesSpilled(base + 6) + val accums = InternalAccumulator.create() + accums.foreach(Accumulators.register) + val taskMetrics = new TaskMetrics(accums) val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() + val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() + val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) + val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) shuffleReadMetrics.incRemoteBytesRead(base + 1) shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) taskMetrics.mergeShuffleReadMetrics() - val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() shuffleWriteMetrics.incBytesWritten(base + 3) - val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) - inputMetrics.incBytesRead(base + 7) - val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) + taskMetrics.setExecutorRunTime(base + 4) + taskMetrics.incDiskBytesSpilled(base + 5) + taskMetrics.incMemoryBytesSpilled(base + 6) + inputMetrics.setBytesRead(base + 7) outputMetrics.setBytesWritten(base + 8) taskMetrics } @@ -300,9 +302,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, 0, makeTaskMetrics(0)), - (1235L, 0, 0, makeTaskMetrics(100)), - (1236L, 1, 0, makeTaskMetrics(200))))) + (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()), + (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()), + (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates())))) var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e5ca2de4ad537..57021d1d3d528 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -22,6 +22,10 @@ import java.util.Properties import scala.collection.Map import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JArray, JInt, JString, JValue} +import org.json4s.JsonDSL._ +import org.scalatest.Assertions +import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ @@ -32,12 +36,7 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage._ class JsonProtocolSuite extends SparkFunSuite { - - val jobSubmissionTime = 1421191042750L - val jobCompletionTime = 1421191296660L - - val executorAddedTime = 1421458410000L - val executorRemovedTime = 1421458922000L + import JsonProtocolSuite._ test("SparkListenerEvent") { val stageSubmitted = @@ -82,9 +81,13 @@ class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") - val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( - (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, - hasHadoopInput = true, hasOutput = true)))) + val executorMetricsUpdate = { + // Use custom accum ID for determinism + val accumUpdates = + makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true) + .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) } + SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) + } testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -142,7 +145,7 @@ class JsonProtocolSuite extends SparkFunSuite { "Some exception") val fetchMetadataFailed = new MetadataFetchFailedException(17, 19, "metadata Fetch failed exception").toTaskEndReason - val exceptionFailure = new ExceptionFailure(exception, None) + val exceptionFailure = new ExceptionFailure(exception, Seq.empty[AccumulableInfo]) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -166,9 +169,8 @@ class JsonProtocolSuite extends SparkFunSuite { | Backward compatibility tests | * ============================== */ - test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, - None, None) + test("ExceptionFailure backward compatibility: full stack trace") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) @@ -273,14 +275,13 @@ class JsonProtocolSuite extends SparkFunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } - test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { - // Metrics about local shuffle bytes read and local read time were added in 1.3.1. + test("ShuffleReadMetrics: Local bytes read backwards compatibility") { + // Metrics about local shuffle bytes read were added in 1.3.1. val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = false, hasRecords = false) assert(metrics.shuffleReadMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } - .removeField { case (field, _) => field == "Local Read Time" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) } @@ -371,22 +372,76 @@ class JsonProtocolSuite extends SparkFunSuite { } test("AccumulableInfo backward compatibility") { - // "Internal" property of AccumulableInfo were added after 1.5.1. - val accumulableInfo = makeAccumulableInfo(1) + // "Internal" property of AccumulableInfo was added in 1.5.1 + val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true) val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) .removeField({ _._1 == "Internal" }) val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) - assert(false === oldInfo.internal) + assert(!oldInfo.internal) + // "Count Failed Values" property of AccumulableInfo was added in 2.0.0 + val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo) + .removeField({ _._1 == "Count Failed Values" }) + val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2) + assert(!oldInfo2.countFailedValues) + } + + test("ExceptionFailure backward compatibility: accumulator updates") { + // "Task Metrics" was replaced with "Accumulator Updates" in 2.0.0. For older event logs, + // we should still be able to fallback to constructing the accumulator updates from the + // "Task Metrics" field, if it exists. + val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true) + val tmJson = JsonProtocol.taskMetricsToJson(tm) + val accumUpdates = tm.accumulatorUpdates() + val exception = new SparkException("sentimental") + val exceptionFailure = new ExceptionFailure(exception, accumUpdates) + val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure) + val tmFieldJson: JValue = "Task Metrics" -> tmJson + val oldExceptionFailureJson: JValue = + exceptionFailureJson.removeField { _._1 == "Accumulator Updates" }.merge(tmFieldJson) + val oldExceptionFailure = + JsonProtocol.taskEndReasonFromJson(oldExceptionFailureJson).asInstanceOf[ExceptionFailure] + assert(exceptionFailure.className === oldExceptionFailure.className) + assert(exceptionFailure.description === oldExceptionFailure.description) + assertSeqEquals[StackTraceElement]( + exceptionFailure.stackTrace, oldExceptionFailure.stackTrace, assertStackTraceElementEquals) + assert(exceptionFailure.fullStackTrace === oldExceptionFailure.fullStackTrace) + assertSeqEquals[AccumulableInfo]( + exceptionFailure.accumUpdates, oldExceptionFailure.accumUpdates, (x, y) => x == y) } - /** -------------------------- * - | Helper test running methods | - * --------------------------- */ + test("AccumulableInfo value de/serialization") { + import InternalAccumulator._ + val blocks = Seq[(BlockId, BlockStatus)]( + (TestBlockId("meebo"), BlockStatus(StorageLevel.MEMORY_ONLY, 1L, 2L)), + (TestBlockId("feebo"), BlockStatus(StorageLevel.DISK_ONLY, 3L, 4L))) + val blocksJson = JArray(blocks.toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> JsonProtocol.blockStatusToJson(status)) + }) + testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) + testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) + testAccumValue(Some(input.READ_METHOD), "aka", JString("aka")) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + // For anything else, we just cast the value to a string + testAccumValue(Some("anything"), blocks, JString(blocks.toString)) + testAccumValue(Some("anything"), 123, JString("123")) + } + +} + + +private[spark] object JsonProtocolSuite extends Assertions { + import InternalAccumulator._ + + private val jobSubmissionTime = 1421191042750L + private val jobCompletionTime = 1421191296660L + private val executorAddedTime = 1421458410000L + private val executorRemovedTime = 1421458922000L private def testEvent(event: SparkListenerEvent, jsonString: String) { val actualJsonString = compact(render(JsonProtocol.sparkEventToJson(event))) val newEvent = JsonProtocol.sparkEventFromJson(parse(actualJsonString)) - assertJsonStringEquals(jsonString, actualJsonString) + assertJsonStringEquals(jsonString, actualJsonString, event.getClass.getSimpleName) assertEquals(event, newEvent) } @@ -440,11 +495,19 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } + private def testAccumValue(name: Option[String], value: Any, expectedJson: JValue): Unit = { + val json = JsonProtocol.accumValueToJson(name, value) + assert(json === expectedJson) + val newValue = JsonProtocol.accumValueFromJson(name, json) + val expectedValue = if (name.exists(_.startsWith(METRICS_PREFIX))) value else value.toString + assert(newValue === expectedValue) + } + /** -------------------------------- * | Util methods for comparing events | - * --------------------------------- */ + * --------------------------------- */ - private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { + private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { (event1, event2) match { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => assert(e1.properties === e2.properties) @@ -478,14 +541,17 @@ class JsonProtocolSuite extends SparkFunSuite { assert(e1.executorId === e1.executorId) case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => assert(e1.execId === e2.execId) - assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { - val (taskId1, stageId1, stageAttemptId1, metrics1) = a - val (taskId2, stageId2, stageAttemptId2, metrics2) = b - assert(taskId1 === taskId2) - assert(stageId1 === stageId2) - assert(stageAttemptId1 === stageAttemptId2) - assertEquals(metrics1, metrics2) - }) + assertSeqEquals[(Long, Int, Int, Seq[AccumulableInfo])]( + e1.accumUpdates, + e2.accumUpdates, + (a, b) => { + val (taskId1, stageId1, stageAttemptId1, updates1) = a + val (taskId2, stageId2, stageAttemptId2, updates2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b)) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -544,7 +610,6 @@ class JsonProtocolSuite extends SparkFunSuite { } private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { - assert(metrics1.hostname === metrics2.hostname) assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) assert(metrics1.resultSize === metrics2.resultSize) assert(metrics1.jvmGCTime === metrics2.jvmGCTime) @@ -601,7 +666,7 @@ class JsonProtocolSuite extends SparkFunSuite { assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) assert(r1.fullStackTrace === r2.fullStackTrace) - assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) + assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), @@ -637,10 +702,16 @@ class JsonProtocolSuite extends SparkFunSuite { assertStackTraceElementEquals) } - private def assertJsonStringEquals(json1: String, json2: String) { + private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - assert(formatJsonString(json1) === formatJsonString(json2), - s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}") + if (formatJsonString(expected) != formatJsonString(actual)) { + // scalastyle:off + // This prints something useful if the JSON strings don't match + println("=== EXPECTED ===\n" + pretty(parse(expected)) + "\n") + println("=== ACTUAL ===\n" + pretty(parse(actual)) + "\n") + // scalastyle:on + throw new TestFailedException(s"$metadata JSON did not equal", 1) + } } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -699,7 +770,7 @@ class JsonProtocolSuite extends SparkFunSuite { /** ----------------------------------- * | Util methods for constructing events | - * ------------------------------------ */ + * ------------------------------------ */ private val properties = { val p = new Properties @@ -746,8 +817,12 @@ class JsonProtocolSuite extends SparkFunSuite { taskInfo } - private def makeAccumulableInfo(id: Int, internal: Boolean = false): AccumulableInfo = - AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id, internal) + private def makeAccumulableInfo( + id: Int, + internal: Boolean = false, + countFailedValues: Boolean = false): AccumulableInfo = + new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), + internal, countFailedValues) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is @@ -764,7 +839,6 @@ class JsonProtocolSuite extends SparkFunSuite { hasOutput: Boolean, hasRecords: Boolean = true) = { val t = new TaskMetrics - t.setHostname("localhost") t.setExecutorDeserializeTime(a) t.setExecutorRunTime(b) t.setResultSize(c) @@ -774,7 +848,7 @@ class JsonProtocolSuite extends SparkFunSuite { if (hasHadoopInput) { val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) - inputMetrics.incBytesRead(d + e + f) + inputMetrics.setBytesRead(d + e + f) inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) } else { val sr = t.registerTempShuffleReadMetrics() @@ -794,7 +868,7 @@ class JsonProtocolSuite extends SparkFunSuite { val sw = t.registerShuffleWriteMetrics() sw.incBytesWritten(a + b + c) sw.incWriteTime(b + c + d) - sw.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) + sw.incRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) } // Make at most 6 blocks t.setUpdatedBlockStatuses((1 to (e % 5 + 1)).map { i => @@ -826,14 +900,16 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -881,14 +957,16 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | } @@ -919,21 +997,24 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | } @@ -962,21 +1043,24 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | } @@ -1011,26 +1095,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1044,7 +1130,7 @@ class JsonProtocolSuite extends SparkFunSuite { | "Fetch Wait Time": 900, | "Remote Bytes Read": 1000, | "Local Bytes Read": 1100, - | "Total Records Read" : 10 + | "Total Records Read": 10 | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, @@ -1098,26 +1184,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1182,26 +1270,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1273,17 +1363,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1331,17 +1423,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1405,17 +1499,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1495,17 +1591,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | } @@ -1657,51 +1755,208 @@ class JsonProtocolSuite extends SparkFunSuite { """ private val executorMetricsUpdateJsonString = - s""" - |{ - | "Event": "SparkListenerExecutorMetricsUpdate", - | "Executor ID": "exec3", - | "Metrics Updated": [ - | { - | "Task ID": 1, - | "Stage ID": 2, - | "Stage Attempt ID": 3, - | "Task Metrics": { - | "Host Name": "localhost", - | "Executor Deserialize Time": 300, - | "Executor Run Time": 400, - | "Result Size": 500, - | "JVM GC Time": 600, - | "Result Serialization Time": 700, - | "Memory Bytes Spilled": 800, - | "Disk Bytes Spilled": 0, - | "Input Metrics": { - | "Data Read Method": "Hadoop", - | "Bytes Read": 2100, - | "Records Read": 21 - | }, - | "Output Metrics": { - | "Data Write Method": "Hadoop", - | "Bytes Written": 1200, - | "Records Written": 12 - | }, - | "Updated Blocks": [ - | { - | "Block ID": "rdd_0_0", - | "Status": { - | "Storage Level": { - | "Use Disk": true, - | "Use Memory": true, - | "Deserialized": false, - | "Replication": 2 - | }, - | "Memory Size": 0, - | "Disk Size": 0 - | } - | } - | ] - | } - | }] - |} - """.stripMargin + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Accumulator Updates": [ + | { + | "ID": 0, + | "Name": "$EXECUTOR_DESERIALIZE_TIME", + | "Update": 300, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 1, + | "Name": "$EXECUTOR_RUN_TIME", + | "Update": 400, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 2, + | "Name": "$RESULT_SIZE", + | "Update": 500, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 3, + | "Name": "$JVM_GC_TIME", + | "Update": 600, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 4, + | "Name": "$RESULT_SERIALIZATION_TIME", + | "Update": 700, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 5, + | "Name": "$MEMORY_BYTES_SPILLED", + | "Update": 800, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 6, + | "Name": "$DISK_BYTES_SPILLED", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 7, + | "Name": "$PEAK_EXECUTION_MEMORY", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 8, + | "Name": "$UPDATED_BLOCK_STATUSES", + | "Update": [ + | { + | "BlockID": "rdd_0_0", + | "Status": { + | "StorageLevel": { + | "UseDisk": true, + | "UseMemory": true, + | "Deserialized": false, + | "Replication": 2 + | }, + | "MemorySize": 0, + | "DiskSize": 0 + | } + | } + | ], + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 9, + | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 10, + | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 11, + | "Name": "${shuffleRead.REMOTE_BYTES_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 12, + | "Name": "${shuffleRead.LOCAL_BYTES_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 13, + | "Name": "${shuffleRead.FETCH_WAIT_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 14, + | "Name": "${shuffleRead.RECORDS_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 15, + | "Name": "${shuffleWrite.BYTES_WRITTEN}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 16, + | "Name": "${shuffleWrite.RECORDS_WRITTEN}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 17, + | "Name": "${shuffleWrite.WRITE_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 18, + | "Name": "${input.READ_METHOD}", + | "Update": "Hadoop", + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 19, + | "Name": "${input.BYTES_READ}", + | "Update": 2100, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 20, + | "Name": "${input.RECORDS_READ}", + | "Update": 21, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 21, + | "Name": "${output.WRITE_METHOD}", + | "Update": "Hadoop", + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 22, + | "Name": "${output.BYTES_WRITTEN}", + | "Update": 1200, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 23, + | "Name": "${output.RECORDS_WRITTEN}", + | "Update": 12, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 24, + | "Name": "$TEST_ACCUM", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | } + | ] + | } + | ] + |} + """.stripMargin } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index fc7dc2181de85..968a2903f3010 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -175,6 +175,15 @@ object MimaExcludes { ) ++ Seq( // SPARK-12510 Refactor ActorReceiver to support Java ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") + ) ++ Seq( + // SPARK-12895 Implement TaskMetrics using accumulators + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") + ) ++ Seq( + // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulator.this") ) ++ Seq( // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 73dc8cb984471..75cb6d1137c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -79,17 +79,17 @@ case class Sort( sorter.setTestSpillFrequency(testSpillFrequency) } + val metrics = TaskContext.get().taskMetrics() // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. - val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled + val spillSizeBefore = metrics.memoryBytesSpilled val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) dataSize += sorter.getPeakMemoryUsage - spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 41799c596b6d3..001e9c306ac45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -418,10 +418,10 @@ class TungstenAggregationIterator( val mapMemory = hashMap.getPeakMemoryUsedBytes val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) val peakMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() dataSize += peakMemory - spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(peakMemory) } numOutputRows += 1 res diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 8222b84d33e3b..edd87c2d8ed07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -136,14 +136,17 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) val format = inputFormatClass.newInstance format match { @@ -208,6 +211,9 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } reader.getCurrentValue } @@ -228,8 +234,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index c9ea579b5e809..04640711d99d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -111,8 +111,7 @@ case class BroadcastHashJoin( val hashedRelation = broadcastRelation.value hashedRelation match { case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) case _ => } hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 6c7fa2eee5bfa..db8edd169dcfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -119,8 +119,7 @@ case class BroadcastHashOuterJoin( hashTable match { case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) case _ => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 004407b2e6925..8929dc3af1912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -66,8 +66,7 @@ case class BroadcastLeftSemiJoinHash( val hashedRelation = broadcastedRelation.value hashedRelation match { case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) case _ => } hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 52735c9d7f8c4..950dc7816241f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.metric -import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} +import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} import org.apache.spark.util.Utils /** @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( name: String, val param: SQLMetricParam[R, T]) - extends Accumulable[R, T](param.zero, param, Some(name), true) { + extends Accumulable[R, T](param.zero, param, Some(name), internal = true) { def reset(): Unit = { this.value = param.zero @@ -131,6 +131,8 @@ private[sql] object SQLMetrics { name: String, param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) + // This is an internal accumulator so we need to register it explicitly. + Accumulators.register(acc) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 83c64f755f90f..544606f1168b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -139,9 +139,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), - finishTask = false) + for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) { + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false) } } @@ -177,7 +176,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskId: Long, stageId: Int, stageAttemptID: Int, - accumulatorUpdates: Map[Long, Any], + accumulatorUpdates: Seq[AccumulableInfo], finishTask: Boolean): Unit = { _stageIdToStageMetrics.get(stageId) match { @@ -289,8 +288,10 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi for (stageId <- executionUIData.stages; stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; - accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { - accumulatorUpdate + accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { + assert(accumulatorUpdate.update.isDefined, s"accumulator update from " + + s"task did not have a partial value: ${accumulatorUpdate.name}") + (accumulatorUpdate.id, accumulatorUpdate.update.get) } }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => @@ -328,9 +329,10 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { acc => - (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) - }.toMap, + taskEnd.taskInfo.accumulables.map { a => + val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L)) + a.copy(update = Some(newValue)) + }, finishTask = true) } @@ -406,4 +408,4 @@ private[ui] class SQLStageMetrics( private[ui] class SQLTaskMetrics( val attemptId: Long, // TODO not used yet var finished: Boolean, - var accumulatorUpdates: Map[Long, Any]) + var accumulatorUpdates: Seq[AccumulableInfo]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 47308966e92cb..10ccd4b8f60db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1648,7 +1648,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("external sorting updates peak execution memory") { AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala index 9575d26fd123f..273937fa8ce90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -49,8 +49,7 @@ case class ReferenceSort( val context = TaskContext.get() context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 9c258cb31f460..c7df8b51e2f13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -71,8 +71,7 @@ class UnsafeFixedWidthAggregationMapSuite taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null, - internalAccumulators = Seq.empty)) + metricsSystem = null)) try { f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 8a95359d9de25..e03bd6a3e7d20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -117,8 +117,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { taskAttemptId = 98456, attemptNumber = 0, taskMemoryManager = taskMemMgr, - metricsSystem = null, - internalAccumulators = Seq.empty)) + metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 647a7e9a4e196..86c2c25c2c7e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution.columnar +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + +class PartitionBatchPruningSuite + extends SparkFunSuite + with BeforeAndAfterEach + with SharedSQLContext { + import testImplicits._ private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize @@ -32,30 +39,41 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - - val pruningData = sparkContext.makeRDD((1 to 100).map { key => - val string = if (((key - 1) / 10) % 2 == 0) null else key.toString - TestData(key, string) - }, 5).toDF() - pruningData.registerTempTable("pruningData") - // Enable in-memory partition pruning sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { try { sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - sqlContext.uncacheTable("pruningData") } finally { super.afterAll() } } + override protected def beforeEach(): Unit = { + super.beforeEach() + // This creates accumulators, which get cleaned up after every single test, + // so we need to do this before every test. + val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val string = if (((key - 1) / 10) % 2 == 0) null else key.toString + TestData(key, string) + }, 5).toDF() + pruningData.registerTempTable("pruningData") + sqlContext.cacheTable("pruningData") + } + + override protected def afterEach(): Unit = { + try { + sqlContext.uncacheTable("pruningData") + } finally { + super.afterEach() + } + } + // Comparisons checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1)) checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 81a159d542c67..2c408c8878470 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import org.mockito.Mockito.{mock, when} + import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -67,9 +69,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { ) private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { - val metrics = new TaskMetrics - metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) - metrics.updateAccumulators() + val metrics = mock(classOf[TaskMetrics]) + when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) => + new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)), + value = None, internal = true, countFailedValues = true) + }.toSeq) metrics } @@ -114,17 +118,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(listener.getExecutionMetrics(0).isEmpty) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) @@ -133,9 +137,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) @@ -173,9 +177,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index b46b0d2f6040a..9a24a2487a254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -140,7 +140,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) assert(peakMemoryAccumulator.size == 1) - peakMemoryAccumulator.head._2.value.toLong + peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long] } assert(sparkListener.getCompletedStageInfos.length == 2) From 32f741115bda5d7d7dbfcd9fe827ecbea7303ffa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 Jan 2016 13:27:32 -0800 Subject: [PATCH 1718/2089] [SPARK-13021][CORE] Fail fast when custom RDDs violate RDD.partition's API contract Spark's `Partition` and `RDD.partitions` APIs have a contract which requires custom implementations of `RDD.partitions` to ensure that for all `x`, `rdd.partitions(x).index == x`; in other words, the `index` reported by a repartition needs to match its position in the partitions array. If a custom RDD implementation violates this contract, then Spark has the potential to become stuck in an infinite recomputation loop when recomputing a subset of an RDD's partitions, since the tasks that are actually run will not correspond to the missing output partitions that triggered the recomputation. Here's a link to a notebook which demonstrates this problem: https://rawgit.com/JoshRosen/e520fb9a64c1c97ec985/raw/5e8a5aa8d2a18910a1607f0aa4190104adda3424/Violating%2520RDD.partitions%2520contract.html In order to guard against this infinite loop behavior, this patch modifies Spark so that it fails fast and refuses to compute RDDs' whose `partitions` violate the API contract. Author: Josh Rosen Closes #10932 from JoshRosen/SPARK-13021. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 7 +++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9dad7944144d8..be47172581b7f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -112,6 +112,9 @@ abstract class RDD[T: ClassTag]( /** * Implemented by subclasses to return the set of partitions in this RDD. This method will only * be called once, so it is safe to implement a time-consuming computation in it. + * + * The partitions in this array must satisfy the following property: + * `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }` */ protected def getPartitions: Array[Partition] @@ -237,6 +240,10 @@ abstract class RDD[T: ClassTag]( checkpointRDD.map(_.partitions).getOrElse { if (partitions_ == null) { partitions_ = getPartitions + partitions_.zipWithIndex.foreach { case (partition, index) => + require(partition.index == index, + s"partitions($index).partition == ${partition.index}, but it should equal $index") + } } partitions_ } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ef2ed445005d3..80347b800a7b4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -914,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("RDD.partitions() fails fast when partitions indicies are incorrect (SPARK-13021)") { + class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) { + + override def compute(part: Partition, context: TaskContext): Iterator[T] = { + prev.compute(part, context) + } + + override protected def getPartitions: Array[Partition] = { + prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i` + } + } + val rdd = new BadRDD(sc.parallelize(1 to 100, 100)) + val e = intercept[IllegalArgumentException] { + rdd.partitions + } + assert(e.getMessage.contains("partitions")) + } + test("nested RDDs are not supported (SPARK-5063)") { val rdd: RDD[Int] = sc.parallelize(1 to 100) val rdd2: RDD[Int] = sc.parallelize(1 to 100) From 680afabe78b77e4e63e793236453d69567d24290 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 27 Jan 2016 13:29:09 -0800 Subject: [PATCH 1719/2089] [SPARK-12938][SQL] DataFrame API for Bloom filter This PR integrates Bloom filter from spark-sketch into DataFrame. This version resorts to RDD.aggregate for building the filter. A more performant UDAF version can be built in future follow-up PRs. This PR also add 2 specify `put` version(`putBinary` and `putLong`) into `BloomFilter`, which makes it easier to build a Bloom filter over a `DataFrame`. Author: Wenchen Fan Closes #10937 from cloud-fan/bloom-filter. --- .../apache/spark/util/sketch/BloomFilter.java | 34 ++++- .../spark/util/sketch/BloomFilterImpl.java | 141 ++++++++++++------ .../spark/util/sketch/CountMinSketchImpl.java | 47 +----- .../org/apache/spark/util/sketch/Utils.java | 48 ++++++ .../spark/sql/DataFrameStatFunctions.scala | 76 +++++++++- .../apache/spark/sql/JavaDataFrameSuite.java | 31 ++++ .../apache/spark/sql/DataFrameStatSuite.scala | 22 +++ 7 files changed, 306 insertions(+), 93 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index d392fb187ad65..81772fcea0ec2 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -49,9 +49,9 @@ public enum Version { * {@code BloomFilter} binary format version 1 (all values written in big-endian order): *
      *
    • Version number, always 1 (32 bit)
    • + *
    • Number of hash functions (32 bit)
    • *
    • Total number of words of the underlying bit array (32 bit)
    • *
    • The words/longs (numWords * 64 bit)
    • - *
    • Number of hash functions (32 bit)
    • *
    */ V1(1); @@ -97,6 +97,21 @@ int getVersionNumber() { */ public abstract boolean put(Object item); + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string. + */ + public abstract boolean putString(String str); + + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put long. + */ + public abstract boolean putLong(long l); + + /** + * A specialized variant of {@link #put(Object)}, that can only be used to put byte array. + */ + public abstract boolean putBinary(byte[] bytes); + /** * Determines whether a given bloom filter is compatible with this bloom filter. For two * bloom filters to be compatible, they must have the same bit size. @@ -121,6 +136,23 @@ int getVersionNumber() { */ public abstract boolean mightContain(Object item); + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8 + * string. + */ + public abstract boolean mightContainString(String str); + + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long. + */ + public abstract boolean mightContainLong(long l); + + /** + * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte + * array. + */ + public abstract boolean mightContainBinary(byte[] bytes); + /** * Writes out this {@link BloomFilter} to an output stream in binary format. * It is the caller's responsibility to close the stream. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 1c08d07afaeaa..35107e0b389d7 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,10 +19,10 @@ import java.io.*; -public class BloomFilterImpl extends BloomFilter { +public class BloomFilterImpl extends BloomFilter implements Serializable { - private final int numHashFunctions; - private final BitArray bits; + private int numHashFunctions; + private BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { this(new BitArray(numBits), numHashFunctions); @@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) { this.numHashFunctions = numHashFunctions; } + private BloomFilterImpl() {} + @Override public boolean equals(Object other) { if (other == this) { @@ -63,55 +65,75 @@ public long bitSize() { return bits.bitSize(); } - private static long hashObjectToLong(Object item) { + @Override + public boolean put(Object item) { if (item instanceof String) { - try { - byte[] bytes = ((String) item).getBytes("utf-8"); - return hashBytesToLong(bytes); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException("Only support utf-8 string", e); - } + return putString((String) item); + } else if (item instanceof byte[]) { + return putBinary((byte[]) item); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - int h1 = Murmur3_x86_32.hashLong(longValue, 0); - int h2 = Murmur3_x86_32.hashLong(longValue, h1); - return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + return putLong(Utils.integralToLong(item)); } } - private static long hashBytesToLong(byte[] bytes) { + @Override + public boolean putString(String str) { + return putBinary(Utils.getBytesFromUTF8String(str)); + } + + @Override + public boolean putBinary(byte[] bytes) { int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); - return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; } @Override - public boolean put(Object item) { + public boolean mightContainString(String str) { + return mightContainBinary(Utils.getBytesFromUTF8String(str)); + } + + @Override + public boolean mightContainBinary(byte[] bytes) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } - // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash - // values by `h1 + i * h2` with 1 <= i <= numHashFunctions. - // Note that `CountMinSketch` use a different strategy for long type, it hash the input long - // element with every i to produce n hash values. - long hash64 = hashObjectToLong(item); - int h1 = (int) (hash64 >> 32); - int h2 = (int) hash64; + @Override + public boolean putLong(long l) { + // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n + // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy, it hash the input long element with + // every i to produce n hash values. + // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? + int h1 = Murmur3_x86_32.hashLong(l, 0); + int h2 = Murmur3_x86_32.hashLong(l, h1); + long bitSize = bits.bitSize(); boolean bitsChanged = false; for (int i = 1; i <= numHashFunctions; i++) { int combinedHash = h1 + (i * h2); @@ -125,12 +147,11 @@ public boolean put(Object item) { } @Override - public boolean mightContain(Object item) { - long bitSize = bits.bitSize(); - long hash64 = hashObjectToLong(item); - int h1 = (int) (hash64 >> 32); - int h2 = (int) hash64; + public boolean mightContainLong(long l) { + int h1 = Murmur3_x86_32.hashLong(l, 0); + int h2 = Murmur3_x86_32.hashLong(l, h1); + long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { int combinedHash = h1 + (i * h2); // Flip all the bits if it's negative (guaranteed positive number) @@ -144,6 +165,17 @@ public boolean mightContain(Object item) { return true; } + @Override + public boolean mightContain(Object item) { + if (item instanceof String) { + return mightContainString((String) item); + } else if (item instanceof byte[]) { + return mightContainBinary((byte[]) item); + } else { + return mightContainLong(Utils.integralToLong(item)); + } + } + @Override public boolean isCompatible(BloomFilter other) { if (other == null) { @@ -191,11 +223,11 @@ public void writeTo(OutputStream out) throws IOException { DataOutputStream dos = new DataOutputStream(out); dos.writeInt(Version.V1.getVersionNumber()); - bits.writeTo(dos); dos.writeInt(numHashFunctions); + bits.writeTo(dos); } - public static BloomFilterImpl readFrom(InputStream in) throws IOException { + private void readFrom0(InputStream in) throws IOException { DataInputStream dis = new DataInputStream(in); int version = dis.readInt(); @@ -203,6 +235,21 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException { throw new IOException("Unexpected Bloom filter version number (" + version + ")"); } - return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); + this.numHashFunctions = dis.readInt(); + this.bits = BitArray.readFrom(dis); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + BloomFilterImpl filter = new BloomFilterImpl(); + filter.readFrom0(in); + return filter; + } + + private void writeObject(ObjectOutputStream out) throws IOException { + writeTo(out); + } + + private void readObject(ObjectInputStream in) throws IOException { + readFrom0(in); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 8cc29e4076307..e49ae22906c4c 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { private double eps; private double confidence; - private CountMinSketchImpl() { - } + private CountMinSketchImpl() {} CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; @@ -143,23 +142,7 @@ public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - addLong(longValue, count); + addLong(Utils.integralToLong(item), count); } } @@ -201,13 +184,7 @@ private int hash(long item, int count) { } private static int[] getHashBuckets(String key, int hashCount, int max) { - byte[] b; - try { - b = key.getBytes("UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException(e); - } - return getHashBuckets(b, hashCount, max); + return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); } private static int[] getHashBuckets(byte[] b, int hashCount, int max) { @@ -225,23 +202,7 @@ public long estimateCount(Object item) { if (item instanceof String) { return estimateCountForStringItem((String) item); } else { - long longValue; - - if (item instanceof Long) { - longValue = (Long) item; - } else if (item instanceof Integer) { - longValue = ((Integer) item).longValue(); - } else if (item instanceof Short) { - longValue = ((Short) item).longValue(); - } else if (item instanceof Byte) { - longValue = ((Byte) item).longValue(); - } else { - throw new IllegalArgumentException( - "Support for " + item.getClass().getName() + " not implemented" - ); - } - - return estimateCountForLongItem(longValue); + return estimateCountForLongItem(Utils.integralToLong(item)); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java new file mode 100644 index 0000000000000..a6b33313035b0 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.UnsupportedEncodingException; + +public class Utils { + public static byte[] getBytesFromUTF8String(String str) { + try { + return str.getBytes("utf-8"); + } catch (UnsupportedEncodingException e) { + throw new IllegalArgumentException("Only support utf-8 string", e); + } + } + + public static long integralToLong(Object i) { + long longValue; + + if (i instanceof Long) { + longValue = (Long) i; + } else if (i instanceof Integer) { + longValue = ((Integer) i).longValue(); + } else if (i instanceof Short) { + longValue = ((Short) i).longValue(); + } else if (i instanceof Byte) { + longValue = ((Byte) i).longValue(); + } else { + throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName()); + } + + return longValue; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 465b12bb59d1e..b0b6995a2214f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,9 +22,10 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * :: Experimental :: @@ -390,4 +391,75 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { } ) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits)) + } + + private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require(colType == StringType || colType.isInstanceOf[IntegralType], + s"Bloom filter only supports string type and integral types, but got $colType.") + + val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { + (filter, row) => + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + filter.putBinary(row.getUTF8String(0).getBytes) + filter + } else { + (filter, row) => + // TODO: specialize it. + filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) + filter + } + + singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9cf94e72d34e2..0d4c128cb36d6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -40,6 +40,7 @@ import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -300,6 +301,7 @@ public void pivot() { Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + @Test public void testGenericLoad() { DataFrame df1 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); @@ -347,4 +349,33 @@ public void testCountMinSketch() { Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); } + + @Test + public void testBloomFilter() { + DataFrame df = context.range(1000); + + BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); + assert (filter1.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + assert (filter1.mightContain(i)); + } + + BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03); + assert (filter2.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + assert (filter2.mightContain(i * 3)); + } + + BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5); + assert (filter3.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + assert (filter3.mightContain(i)); + } + + BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5); + assert (filter4.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + assert (filter4.mightContain(i * 3)); + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f3ea5a2860ba..f01f126f7696d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -246,4 +246,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { .countMinSketch('id, depth = 10, width = 20, seed = 42) } } + + // This test only verifies some basic requirements, more correctness tests can be found in + // `BloomFilterSuite` in project spark-sketch. + test("Bloom filter") { + val df = sqlContext.range(1000) + + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(filter1.mightContain)) + + val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03) + assert(filter2.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(i => filter2.mightContain(i * 3))) + + val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter3.bitSize() == 64 * 5) + assert(0.until(1000).forall(filter3.mightContain)) + + val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5) + assert(filter4.bitSize() == 64 * 5) + assert(0.until(1000).forall(i => filter4.mightContain(i * 3))) + } } From ef96cd3c521c175878c38a1ed6eeeab0ed8346b5 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 27 Jan 2016 13:45:00 -0800 Subject: [PATCH 1720/2089] [SPARK-12865][SPARK-12866][SQL] Migrate SparkSQLParser/ExtendedHiveQlParser commands to new Parser This PR moves all the functionality provided by the SparkSQLParser/ExtendedHiveQlParser to the new Parser hierarchy (SparkQl/HiveQl). This also improves the current SET command parsing: the current implementation swallows ```set role ...``` and ```set autocommit ...``` commands, this PR respects these commands (and passes them on to Hive). This PR and https://github.com/apache/spark/pull/10723 end the use of Parser-Combinator parsers for SQL parsing. As a result we can also remove the ```AbstractSQLParser``` in Catalyst. The PR is marked WIP as long as it doesn't pass all tests. cc rxin viirya winningsix (this touches https://github.com/apache/spark/pull/10144) Author: Herman van Hovell Closes #10905 from hvanhovell/SPARK-12866. --- .../sql/catalyst/parser/ExpressionParser.g | 12 +- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 5 + .../sql/catalyst/parser/SparkSqlParser.g | 62 +++++++-- .../spark/sql/catalyst/CatalystQl.scala | 22 ++++ .../spark/sql/catalyst/parser/ASTNode.scala | 14 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/SparkQl.scala | 31 +++++ .../spark/sql/execution/SparkSQLParser.scala | 124 ------------------ .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- .../spark/sql/hive/ExtendedHiveQlParser.scala | 70 ---------- .../apache/spark/sql/hive/HiveContext.scala | 8 +- .../org/apache/spark/sql/hive/HiveQl.scala | 18 ++- ...nctions-1-4a6f611305f58bdbafb2fd89ec62d797 | 4 + ...nctions-2-97cbada21ad9efda7ce9de5891deca7c | 1 + ...nctions-4-4deaa213aff83575bbaf859f79bfdd48 | 2 + .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 16 files changed, 161 insertions(+), 226 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 957bb234e4901..0555a6ba83cbb 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -167,8 +167,8 @@ intervalLiteral ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)? ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)? ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)? - (millisecond=intervalConstant KW_MILLISECOND)? - (microsecond=intervalConstant KW_MICROSECOND)? + ((intervalConstant KW_MILLISECOND)=> millisecond=intervalConstant KW_MILLISECOND)? + ((intervalConstant KW_MICROSECOND)=> microsecond=intervalConstant KW_MICROSECOND)? -> ^(TOK_INTERVAL ^(TOK_INTERVAL_YEAR_LITERAL $year?) ^(TOK_INTERVAL_MONTH_LITERAL $month?) @@ -505,10 +505,8 @@ identifier functionIdentifier @init { gParent.pushMsg("function identifier", state); } @after { gParent.popMsg(state); } - : db=identifier DOT fn=identifier - -> Identifier[$db.text + "." + $fn.text] - | - identifier + : + identifier (DOT identifier)? -> identifier+ ; principalIdentifier @@ -553,6 +551,8 @@ nonReserved | KW_SNAPSHOT | KW_AUTOCOMMIT | KW_ANTI + | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND + | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS ; //The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e4ffc634e8bf4..4374cd7ef7200 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -327,6 +327,11 @@ KW_AUTOCOMMIT: 'AUTOCOMMIT'; KW_WEEK: 'WEEK'|'WEEKS'; KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; +KW_CLEAR: 'CLEAR'; +KW_LAZY: 'LAZY'; +KW_CACHE: 'CACHE'; +KW_UNCACHE: 'UNCACHE'; +KW_DFS: 'DFS'; // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index c146ca5914884..35bef00351d72 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -371,6 +371,13 @@ TOK_TXN_READ_WRITE; TOK_COMMIT; TOK_ROLLBACK; TOK_SET_AUTOCOMMIT; +TOK_CACHETABLE; +TOK_UNCACHETABLE; +TOK_CLEARCACHE; +TOK_SETCONFIG; +TOK_DFS; +TOK_ADDFILE; +TOK_ADDJAR; } @@ -515,6 +522,11 @@ import java.util.HashMap; xlateMap.put("KW_WEEK", "WEEK"); xlateMap.put("KW_MILLISECOND", "MILLISECOND"); xlateMap.put("KW_MICROSECOND", "MICROSECOND"); + xlateMap.put("KW_CLEAR", "CLEAR"); + xlateMap.put("KW_LAZY", "LAZY"); + xlateMap.put("KW_CACHE", "CACHE"); + xlateMap.put("KW_UNCACHE", "UNCACHE"); + xlateMap.put("KW_DFS", "DFS"); // Operators xlateMap.put("DOT", "."); @@ -687,8 +699,12 @@ catch (RecognitionException e) { // starting rule statement - : explainStatement EOF - | execStatement EOF + : explainStatement EOF + | execStatement EOF + | KW_ADD KW_JAR -> ^(TOK_ADDJAR) + | KW_ADD KW_FILE -> ^(TOK_ADDFILE) + | KW_DFS -> ^(TOK_DFS) + | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG) ; explainStatement @@ -717,6 +733,7 @@ execStatement | deleteStatement | updateStatement | sqlTransactionStatement + | cacheStatement ; loadStatement @@ -1390,7 +1407,7 @@ showStatement @init { pushMsg("show statement", state); } @after { popMsg(state); } : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) - | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) + | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES ^(TOK_FROM $db_name)? showStmtIdentifier?) | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? -> ^(TOK_SHOWCOLUMNS tableName $db_name?) | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) @@ -2438,12 +2455,11 @@ BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly exc sqlTransactionStatement @init { pushMsg("transaction statement", state); } @after { popMsg(state); } - : - startTransactionStatement - | commitStatement - | rollbackStatement - | setAutoCommitStatement - ; + : startTransactionStatement + | commitStatement + | rollbackStatement + | setAutoCommitStatement + ; startTransactionStatement : @@ -2489,3 +2505,31 @@ setAutoCommitStatement /* END user defined transaction boundaries */ + +/* +Table Caching statements. + */ +cacheStatement +@init { pushMsg("cache statement", state); } +@after { popMsg(state); } + : + cacheTableStatement + | uncacheTableStatement + | clearCacheStatement + ; + +cacheTableStatement + : + KW_CACHE (lazy=KW_LAZY)? KW_TABLE identifier (KW_AS selectStatementWithCTE)? -> ^(TOK_CACHETABLE identifier $lazy? selectStatementWithCTE?) + ; + +uncacheTableStatement + : + KW_UNCACHE KW_TABLE identifier -> ^(TOK_UNCACHETABLE identifier) + ; + +clearCacheStatement + : + KW_CLEAR KW_CACHE -> ^(TOK_CLEARCACHE) + ; + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index f531d59a75cf8..536c292ab7f34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -210,6 +210,28 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } protected def nodeToPlan(node: ASTNode): LogicalPlan = node match { + case Token("TOK_SHOWFUNCTIONS", args) => + // Skip LIKE. + val pattern = args match { + case like :: nodes if like.text.toUpperCase == "LIKE" => nodes + case nodes => nodes + } + + // Extract Database and Function name + pattern match { + case Nil => + ShowFunctions(None, None) + case Token(name, Nil) :: Nil => + ShowFunctions(None, Some(unquoteString(name))) + case Token(db, Nil) :: Token(name, Nil) :: Nil => + ShowFunctions(Some(unquoteString(db)), Some(unquoteString(name))) + case _ => + noParseRule("SHOW FUNCTIONS", node) + } + + case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => + DescribeFunction(functionName, isExtended.nonEmpty) + case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = queryArgs match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala index ec5e71042d4be..ec9812414e19f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -27,10 +27,10 @@ case class ASTNode( children: List[ASTNode], stream: TokenRewriteStream) extends TreeNode[ASTNode] { /** Cache the number of children. */ - val numChildren = children.size + val numChildren: Int = children.size /** tuple used in pattern matching. */ - val pattern = Some((token.getText, children)) + val pattern: Some[(String, List[ASTNode])] = Some((token.getText, children)) /** Line in which the ASTNode starts. */ lazy val line: Int = { @@ -55,10 +55,16 @@ case class ASTNode( } /** Origin of the ASTNode. */ - override val origin = Origin(Some(line), Some(positionInLine)) + override val origin: Origin = Origin(Some(line), Some(positionInLine)) /** Source text. */ - lazy val source = stream.toString(startIndex, stopIndex) + lazy val source: String = stream.toString(startIndex, stopIndex) + + /** Get the source text that remains after this token. */ + lazy val remainder: String = { + stream.fill() + stream.toString(stopIndex + 1, stream.size() - 1).trim() + } def text: String = token.getText diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b774da33aebe4..be28df3a51557 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -204,7 +204,7 @@ class SQLContext private[sql]( protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient - protected[sql] val sqlParser: ParserInterface = new SparkSQLParser(new SparkQl(conf)) + protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) @transient protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index f3e89ef4a71f5..f6055306b6c97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. */ @@ -27,6 +28,18 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly protected override def nodeToPlan(node: ASTNode): LogicalPlan = { node match { + case Token("TOK_SETCONFIG", Nil) => + val keyValueSeparatorIndex = node.remainder.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = node.remainder.substring(0, keyValueSeparatorIndex).trim + val value = node.remainder.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (node.remainder.nonEmpty) { + SetCommand(Some(node.remainder -> None)) + } else { + SetCommand(None) + } + // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => ExplainCommand(OneRowRelation) @@ -75,6 +88,24 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly } } + case Token("TOK_CACHETABLE", Token(tableName, Nil) :: args) => + val Seq(lzy, selectAst) = getClauses(Seq("LAZY", "TOK_QUERY"), args) + CacheTableCommand(tableName, selectAst.map(nodeToPlan), lzy.isDefined) + + case Token("TOK_UNCACHETABLE", Token(tableName, Nil) :: Nil) => + UncacheTableCommand(tableName) + + case Token("TOK_CLEARCACHE", Nil) => + ClearCacheCommand + + case Token("TOK_SHOWTABLES", args) => + val databaseName = args match { + case Nil => None + case Token("TOK_FROM", Token(dbName, Nil) :: Nil) :: Nil => Option(dbName) + case _ => noParseRule("SHOW TABLES", node) + } + ShowTablesCommand(databaseName) + case _ => super.nodeToPlan(node) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala deleted file mode 100644 index d2d8271563726..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.parsing.combinator.RegexParsers - -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StringType - -/** - * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL - * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. - * - * @param fallback A function that returns the next parser in the chain. This is a call-by-name - * parameter because this allows us to return a different dialect if we - * have to. - */ -class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParser { - - override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = - fallback.parseTableIdentifier(sql) - - // A parser for the key-value part of the "SET [key = [value ]]" syntax - private object SetCommandParser extends RegexParsers { - private val key: Parser[String] = "(?m)[^=]+".r - - private val value: Parser[String] = "(?m).*$".r - - private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)()) - - private val pair: Parser[LogicalPlan] = - (key ~ ("=".r ~> value).?).? ^^ { - case None => SetCommand(None) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) - } - - def apply(input: String): LogicalPlan = parseAll(pair, input) match { - case Success(plan, _) => plan - case x => sys.error(x.toString) - } - } - - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val FUNCTION = Keyword("FUNCTION") - protected val FUNCTIONS = Keyword("FUNCTIONS") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") - protected val UNCACHE = Keyword("UNCACHE") - - override protected lazy val start: Parser[LogicalPlan] = - cache | uncache | set | show | desc | others - - private lazy val cache: Parser[LogicalPlan] = - CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan.map(fallback.parsePlan), isLazy.isDefined) - } - - private lazy val uncache: Parser[LogicalPlan] = - ( UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) - } - | CLEAR ~ CACHE ^^^ ClearCacheCommand - ) - - private lazy val set: Parser[LogicalPlan] = - SET ~> restInput ^^ { - case input => SetCommandParser(input) - } - - // It can be the following patterns: - // SHOW FUNCTIONS; - // SHOW FUNCTIONS mydb.func1; - // SHOW FUNCTIONS func1; - // SHOW FUNCTIONS `mydb.a`.`func1.aa`; - private lazy val show: Parser[LogicalPlan] = - ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ { - case _ ~ dbName => ShowTablesCommand(dbName) - } - | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) - case None => logical.ShowFunctions(None, None) - } - ) - - private lazy val desc: Parser[LogicalPlan] = - DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ { - case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) - } - - private lazy val others: Parser[LogicalPlan] = - wholeInput ^^ { - case input => fallback.parsePlan(input) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 10ccd4b8f60db..989cb2942918e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -56,8 +56,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("show functions") { - checkAnswer(sql("SHOW functions"), - FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + def getFunctions(pattern: String): Seq[Row] = { + val regex = java.util.regex.Pattern.compile(pattern) + sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + } + checkAnswer(sql("SHOW functions"), getFunctions(".*")) + Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern => + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) + } } test("describe functions") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala deleted file mode 100644 index 313ba18f6aef0..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import scala.language.implicitConversions - -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand} - -/** - * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. - */ -private[hive] class ExtendedHiveQlParser(sqlContext: HiveContext) extends AbstractSparkSQLParser { - - val parser = new HiveQl(sqlContext.conf) - - override def parseExpression(sql: String): Expression = parser.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = - parser.parseTableIdentifier(sql) - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - - protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl - - protected lazy val hiveQl: Parser[LogicalPlan] = - restInput ^^ { - case statement => - sqlContext.executionHive.withHiveState { - parser.parsePlan(statement.trim) - } - } - - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => HiveNativeCommand(command.trim) - } - - private lazy val addFile: Parser[LogicalPlan] = - ADD ~ FILE ~> restInput ^^ { - case input => AddFile(input.trim) - } - - private lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case input => AddJar(input.trim) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index eaca3c9269bb7..1797ea54f2501 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -316,7 +316,9 @@ class HiveContext private[hive]( } protected[sql] override def parseSql(sql: String): LogicalPlan = { - super.parseSql(substitutor.substitute(hiveconf, sql)) + executionHive.withHiveState { + super.parseSql(substitutor.substitute(hiveconf, sql)) + } } override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = @@ -546,9 +548,7 @@ class HiveContext private[hive]( } @transient - protected[sql] override val sqlParser: ParserInterface = { - new SparkSQLParser(new ExtendedHiveQlParser(this)) - } + protected[sql] override val sqlParser: ParserInterface = new HiveQl(conf) @transient private val hivePlanner = new SparkPlanner(this) with HiveStrategies { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 46246f8191db1..22841ed2116d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -35,11 +35,12 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.ParseUtils._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.types._ import org.apache.spark.sql.AnalysisException @@ -113,7 +114,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_CREATEROLE", "TOK_DESCDATABASE", - "TOK_DESCFUNCTION", "TOK_DROPDATABASE", "TOK_DROPFUNCTION", @@ -151,7 +151,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_SHOW_TRANSACTIONS", "TOK_SHOWCOLUMNS", "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", "TOK_SHOWINDEXES", "TOK_SHOWLOCKS", "TOK_SHOWPARTITIONS", @@ -244,6 +243,15 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging protected override def nodeToPlan(node: ASTNode): LogicalPlan = { node match { + case Token("TOK_DFS", Nil) => + HiveNativeCommand(node.source + " " + node.remainder) + + case Token("TOK_ADDFILE", Nil) => + AddFile(node.remainder) + + case Token("TOK_ADDJAR", Nil) => + AddJar(node.remainder) + // Special drop table that also uncaches. case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") @@ -558,7 +566,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging protected override def nodeToTransformation( node: ASTNode, - child: LogicalPlan): Option[ScriptTransformation] = node match { + child: LogicalPlan): Option[logical.ScriptTransformation] = node match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -651,7 +659,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging schemaLess) Some( - ScriptTransformation( + logical.ScriptTransformation( inputExprs.map(nodeToExpr), unescapedScript, output, diff --git a/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 b/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 index 175795534fff5..f400819b67c26 100644 --- a/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 +++ b/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 @@ -1,4 +1,5 @@ case +cbrt ceil ceiling coalesce @@ -17,3 +18,6 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user diff --git a/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c b/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c index 3c25d656bda1c..19458fc86e439 100644 --- a/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c +++ b/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c @@ -2,6 +2,7 @@ assert_true case coalesce current_database +current_date decode e encode diff --git a/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 b/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 index cd2e58d04a4ef..1d05f843a7e0f 100644 --- a/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 +++ b/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 @@ -1,4 +1,6 @@ +current_date date_add +date_format date_sub datediff to_date diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9e53d8a81e753..0d62d799c8dce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.hive.{ExtendedHiveQlParser, HiveContext, HiveQl, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveContext, HiveQl, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ From d702f0c170d5c39df501e173813f8a7718e3b3c6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 27 Jan 2016 14:01:55 -0800 Subject: [PATCH 1721/2089] [HOTFIX] Fix Scala 2.11 compilation by explicitly marking annotated parameters as vals (SI-8813). Caused by #10835. Author: Andrew Or Closes #10955 from andrewor14/fix-scala211. --- core/src/main/scala/org/apache/spark/Accumulable.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index bde136141f40d..52f572b63fa95 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -57,7 +57,8 @@ import org.apache.spark.util.Utils */ class Accumulable[R, T] private ( val id: Long, - @transient initialValue: R, + // SI-8813: This must explicitly be a private val, or else scala 2.11 doesn't compile + @transient private val initialValue: R, param: AccumulableParam[R, T], val name: Option[String], internal: Boolean, From 4a091232122b51f10521a68de8b1d9eb853b563d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 27 Jan 2016 15:35:31 -0800 Subject: [PATCH 1722/2089] [SPARK-13045] [SQL] Remove ColumnVector.Struct in favor of ColumnarBatch.Row These two classes became identical as the implementation progressed. Author: Nong Li Closes #10952 from nongli/spark-13045. --- .../execution/vectorized/ColumnVector.java | 104 +----------------- .../execution/vectorized/ColumnarBatch.java | 40 ++++--- .../vectorized/ColumnarBatchSuite.scala | 8 +- 3 files changed, 32 insertions(+), 120 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index c119758d68b36..a0bf8734b6545 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -210,104 +210,6 @@ public Object get(int ordinal, DataType dataType) { } } - /** - * Holder object to return a struct. This object is intended to be reused. - */ - public static final class Struct extends InternalRow { - // The fields that make up this struct. For example, if the struct had 2 int fields, the access - // to it would be: - // int f1 = fields[0].getInt[rowId] - // int f2 = fields[1].getInt[rowId] - public final ColumnVector[] fields; - - @Override - public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); } - - @Override - public boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } - - public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); } - - @Override - public short getShort(int ordinal) { - throw new NotImplementedException(); - } - - public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); } - public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); } - - @Override - public float getFloat(int ordinal) { - throw new NotImplementedException(); - } - - public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - Array a = getByteArray(ordinal); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); - } - - @Override - public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - return fields[ordinal].getStruct(rowId); - } - - public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); } - - @Override - public MapData getMap(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - throw new NotImplementedException(); - } - - public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); } - public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); } - - @Override - public final int numFields() { - return fields.length; - } - - @Override - public InternalRow copy() { - throw new NotImplementedException(); - } - - @Override - public boolean anyNull() { - throw new NotImplementedException(); - } - - protected int rowId; - - protected Struct(ColumnVector[] fields) { - this.fields = fields; - } - } - /** * Returns the data type of this column. */ @@ -494,7 +396,7 @@ public void reset() { /** * Returns a utility object to get structs. */ - public Struct getStruct(int rowId) { + public ColumnarBatch.Row getStruct(int rowId) { resultStruct.rowId = rowId; return resultStruct; } @@ -749,7 +651,7 @@ public final int appendStruct(boolean isNull) { /** * Reusable Struct holder for getStruct(). */ - protected final Struct resultStruct; + protected final ColumnarBatch.Row resultStruct; /** * Sets up the common state and also handles creating the child columns if this is a nested @@ -779,7 +681,7 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); } this.resultArray = null; - this.resultStruct = new Struct(this.childColumns); + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index d558dae50c227..5a575811fa896 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -86,13 +86,23 @@ public void close() { * performance is lost with this translation. */ public static final class Row extends InternalRow { - private int rowId; + protected int rowId; private final ColumnarBatch parent; private final int fixedLenRowSize; + private final ColumnVector[] columns; + // Ctor used if this is a top level row. private Row(ColumnarBatch parent) { this.parent = parent; this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + this.columns = parent.columns; + } + + // Ctor used if this is a struct. + protected Row(ColumnVector[] columns) { + this.parent = null; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); + this.columns = columns; } /** @@ -103,23 +113,23 @@ public final void markFiltered() { parent.markFiltered(rowId); } + public ColumnVector[] columns() { return columns; } + @Override - public final int numFields() { - return parent.numCols(); - } + public final int numFields() { return columns.length; } @Override /** * Revisit this. This is expensive. */ public final InternalRow copy() { - UnsafeRow row = new UnsafeRow(parent.numCols()); + UnsafeRow row = new UnsafeRow(numFields()); row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize); - for (int i = 0; i < parent.numCols(); i++) { + for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = parent.schema.fields()[i].dataType(); + DataType dt = columns[i].dataType(); if (dt instanceof IntegerType) { row.setInt(i, getInt(i)); } else if (dt instanceof LongType) { @@ -141,7 +151,7 @@ public final boolean anyNull() { @Override public final boolean isNullAt(int ordinal) { - return parent.column(ordinal).getIsNull(rowId); + return columns[ordinal].getIsNull(rowId); } @Override @@ -150,7 +160,7 @@ public final boolean getBoolean(int ordinal) { } @Override - public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); } + public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override public final short getShort(int ordinal) { @@ -159,11 +169,11 @@ public final short getShort(int ordinal) { @Override public final int getInt(int ordinal) { - return parent.column(ordinal).getInt(rowId); + return columns[ordinal].getInt(rowId); } @Override - public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); } + public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override public final float getFloat(int ordinal) { @@ -172,7 +182,7 @@ public final float getFloat(int ordinal) { @Override public final double getDouble(int ordinal) { - return parent.column(ordinal).getDouble(rowId); + return columns[ordinal].getDouble(rowId); } @Override @@ -182,7 +192,7 @@ public final Decimal getDecimal(int ordinal, int precision, int scale) { @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId); + ColumnVector.Array a = columns[ordinal].getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } @@ -198,12 +208,12 @@ public final CalendarInterval getInterval(int ordinal) { @Override public final InternalRow getStruct(int ordinal, int numFields) { - return parent.column(ordinal).getStruct(rowId); + return columns[ordinal].getStruct(rowId); } @Override public final ArrayData getArray(int ordinal) { - return parent.column(ordinal).getArray(rowId); + return columns[ordinal].getArray(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 215ca9ab6b770..67cc08b6fc8ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -439,10 +439,10 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(1, 5.67) val s = column.getStruct(0) - assert(s.fields(0).getInt(0) == 123) - assert(s.fields(0).getInt(1) == 456) - assert(s.fields(1).getDouble(0) == 3.45) - assert(s.fields(1).getDouble(1) == 5.67) + assert(s.columns()(0).getInt(0) == 123) + assert(s.columns()(0).getInt(1) == 456) + assert(s.columns()(1).getDouble(0) == 3.45) + assert(s.columns()(1).getDouble(1) == 5.67) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) From c2204436a15838f2dce44e3cfb0fe58236ef6196 Mon Sep 17 00:00:00 2001 From: James Lohse Date: Thu, 28 Jan 2016 10:50:50 +0000 Subject: [PATCH 1723/2089] Provide same info as in spark-submit --help this is stated for --packages and --repositories. Without stating it for --jars, people expect a standard java classpath to work, with expansion and using a different delimiter than a comma. Currently this is only state in the --help for spark-submit "Comma-separated list of local jars to include on the driver and executor classpaths." Author: James Lohse Closes #10890 from jimlohse/patch-1. --- docs/submitting-applications.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index acbb0f298fe47..413532f2f6cfa 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,8 +177,9 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. Spark uses the following URL scheme to allow -different strategies for disseminating jars: +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. + +Spark uses the following URL scheme to allow different strategies for disseminating jars: - **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor pulls the file from the driver HTTP server. From 415d0a859b7a76f3a866ec62ab472c4050f2a01b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 28 Jan 2016 12:26:03 -0800 Subject: [PATCH 1724/2089] [SPARK-12818][SQL] Specialized integral and string types for Count-min Sketch This PR is a follow-up of #10911. It adds specialized update methods for `CountMinSketch` so that we can avoid doing internal/external row format conversion in `DataFrame.countMinSketch()`. Author: Cheng Lian Closes #10968 from liancheng/cms-specialized. --- .../spark/util/sketch/CountMinSketch.java | 34 +++++++++- .../spark/util/sketch/CountMinSketchImpl.java | 35 ++++++++-- .../spark/sql/DataFrameStatFunctions.scala | 65 +++++++++++-------- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 5692e574d4c7e..f0aac5bb00dfb 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -115,15 +115,45 @@ int getVersionNumber() { public abstract long totalCount(); /** - * Adds 1 to {@code item}. + * Increments {@code item}'s count by one. */ public abstract void add(Object item); /** - * Adds {@code count} to {@code item}. + * Increments {@code item}'s count by {@code count}. */ public abstract void add(Object item, long count); + /** + * Increments {@code item}'s count by one. + */ + public abstract void addLong(long item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addLong(long item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addString(String item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addString(String item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addBinary(byte[] item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addBinary(byte[] item, long count); + /** * Returns the estimated frequency of {@code item}. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e49ae22906c4c..c0631c6778df4 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -25,7 +25,6 @@ import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; -import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; @@ -146,27 +145,49 @@ public void add(Object item, long count) { } } - private void addString(String item, long count) { + @Override + public void addString(String item) { + addString(item, 1); + } + + @Override + public void addString(String item, long count) { + addBinary(Utils.getBytesFromUTF8String(item), count); + } + + @Override + public void addLong(long item) { + addLong(item, 1); + } + + @Override + public void addLong(long item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } - int[] buckets = getHashBuckets(item, depth, width); - for (int i = 0; i < depth; ++i) { - table[i][buckets[i]] += count; + table[i][hash(item, i)] += count; } totalCount += count; } - private void addLong(long item, long count) { + @Override + public void addBinary(byte[] item) { + addBinary(item, 1); + } + + @Override + public void addBinary(byte[] item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { - table[i][hash(item, i)] += count; + table[i][buckets[i]] += count; } totalCount += count; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b0b6995a2214f..bb3cc02800d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** @@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. * - * * @param col1 The name of the first column. Distinct items will make the first item of * each row. * @param col2 The name of the second column. Distinct items will make the column names @@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { val singleCol = df.select(col) val colType = singleCol.schema.head.dataType - require( - colType == StringType || colType.isInstanceOf[IntegralType], - s"Count-min Sketch only supports string type and integral types, " + - s"and does not support type $colType." - ) + val updater: (CountMinSketch, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` + // instead of `addString` to avoid unnecessary conversion. + case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes) + case ByteType => (sketch, row) => sketch.addLong(row.getByte(0)) + case ShortType => (sketch, row) => sketch.addLong(row.getShort(0)) + case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0)) + case LongType => (sketch, row) => sketch.addLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." + ) + } - singleCol.rdd.aggregate(zero)( - (sketch: CountMinSketch, row: Row) => { - sketch.add(row.get(0)) + singleCol.queryExecution.toRdd.aggregate(zero)( + (sketch: CountMinSketch, row: InternalRow) => { + updater(sketch, row) sketch }, - - (sketch1: CountMinSketch, sketch2: CountMinSketch) => { - sketch1.mergeInPlace(sketch2) - } + (sketch1, sketch2) => sketch1.mergeInPlace(sketch2) ) } @@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { require(colType == StringType || colType.isInstanceOf[IntegralType], s"Bloom filter only supports string type and integral types, but got $colType.") - val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { - (filter, row) => - // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` - // instead of `putString` to avoid unnecessary conversion. - filter.putBinary(row.getUTF8String(0).getBytes) - filter - } else { - (filter, row) => - // TODO: specialize it. - filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) - filter + val updater: (BloomFilter, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes) + case ByteType => (filter, row) => filter.putLong(row.getByte(0)) + case ShortType => (filter, row) => filter.putLong(row.getShort(0)) + case IntegerType => (filter, row) => filter.putLong(row.getInt(0)) + case LongType => (filter, row) => filter.putLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Bloom filter only supports string type and integral types, " + + s"and does not support type $colType." + ) } - singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + singleCol.queryExecution.toRdd.aggregate(zero)( + (filter: BloomFilter, row: InternalRow) => { + updater(filter, row) + filter + }, + (filter1, filter2) => filter1.mergeInPlace(filter2) + ) } } From 676803963fcc08aa988aa6f14be3751314e006ca Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 28 Jan 2016 13:45:28 -0800 Subject: [PATCH 1725/2089] [SPARK-12926][SQL] SQLContext to display warning message when non-sql configs are being set Users unknowingly try to set core Spark configs in SQLContext but later realise that it didn't work. eg. sqlContext.sql("SET spark.shuffle.memoryFraction=0.4"). This PR adds a warning message when such operations are done. Author: Tejas Patil Closes #10849 from tejasapatil/SPARK-12926. --- .../main/scala/org/apache/spark/sql/SQLConf.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index c9ba6700998c3..eb9da0bd4fd4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.util.Utils @@ -519,7 +520,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf { +private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -628,7 +629,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon // Only verify configs in the SQLConf object entry.valueConverter(value) } - settings.put(key, value) + setConfWithCheck(key, value) } /** Set the given Spark SQL configuration property. */ @@ -636,7 +637,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon require(entry != null, "entry cannot be null") require(value != null, s"value cannot be null for key: ${entry.key}") require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - settings.put(entry.key, entry.stringConverter(value)) + setConfWithCheck(entry.key, entry.stringConverter(value)) } /** Return the value of Spark SQL configuration property for the given key. */ @@ -699,6 +700,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon }.toSeq } + private def setConfWithCheck(key: String, value: String): Unit = { + if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) { + logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value") + } + settings.put(key, value) + } + private[spark] def unsetConf(key: String): Unit = { settings.remove(key) } From cc18a7199240bf3b03410c1ba6704fe7ce6ae38e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 13:51:55 -0800 Subject: [PATCH 1726/2089] [SPARK-13031] [SQL] cleanup codegen and improve test coverage 1. enable whole stage codegen during tests even there is only one operator supports that. 2. split doProduce() into two APIs: upstream() and doProduce() 3. generate prefix for fresh names of each operator 4. pass UnsafeRow to parent directly (avoid getters and create UnsafeRow again) 5. fix bugs and tests. Author: Davies Liu Closes #10944 from davies/gen_refactor. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++++++++------ .../aggregate/TungstenAggregate.scala | 88 +++++--- .../spark/sql/execution/basicOperators.scala | 96 +++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 334 insertions(+), 202 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad374..e6704cf8bb1f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,14 +144,23 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = { + if (freshNamePrefix == "") { + s"$name${curId.getAndIncrement}" + } else { + s"${freshNamePrefix}_$name${curId.getAndIncrement}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6ef..ec31db19b94b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de9804..ef81ba60f049f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns the RDD of InternalRow which generates the input rows. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def upstream(): RDD[InternalRow] + + /** + * Returns Java source code to process the rows from upstream. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent + ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. */ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) } + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + ctx.freshNamePrefix = nodeName + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, evals)} + """.stripMargin + } else { + doConsume(ctx, input) + } + } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doPrepare(): Unit = { + child.prepare() + } - override def supportCodegen: Boolean = true + override def doExecute(): RDD[InternalRow] = { + child.execute() + } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def supportCodegen: Boolean = false + + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" - | while (input.hasNext()) { + s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() --------> produce() + * doExecute() ---------> upstream() -------> upstream() ------> execute() + * | + * -----------------> produce() * | * doProduce() -------> produce() * | - * doProduce() ---> execute() + * doProduce() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| * | - * doConsume() <----- consume() + * doConsume() + * | + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). * @@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) + val code = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() { + protected void processNext() throws java.io.IOException { $code - } + } } - """ + """ + // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - rdd.mapPartitions { iter => + plan.upstream().mapPartitions { iter => + val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | currentRow = unsafeRow; + | currentRow = $row; | return; """.stripMargin + } else { + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } } } @@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d252..cbd2634b8900f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,9 +117,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } // The variables used as aggregation buffer @@ -127,7 +125,11 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -137,50 +139,80 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val source = + // generate variables for output + val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) + } else { + // output the aggregate buffer directly + (bufVars, "") + } + + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, s""" - | if (!$initAgg) { - | $initAgg = true; - | + | private void $doAgg() { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | } - """.stripMargin + """.stripMargin) - (rdd, source) + s""" + | if (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | $genResult + | + | ${consume(ctx, resultVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val updates = updateExpr.zipWithIndex.map { case (e, i) => + val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -190,7 +222,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | ${updates.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5ec..e7a73d5fbb4bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -153,17 +161,21 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { - val initTerm = ctx.freshName("range_initRange") + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("range_partitionEnd") + val partitionEnd = ctx.freshName("partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("range_number") + val number = ctx.freshName("number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("range_overflow") + val overflow = ctx.freshName("overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("range_value") + val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -172,38 +184,42 @@ case class Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } + """.stripMargin) - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } + | initRange(((InternalRow) input.next()).getInt(0)); | } else { | return; | } @@ -218,12 +234,6 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 989cb2942918e..51a50c1fa30e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } - - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cbae19ebd269d..82f6811503c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48143762cac0..7d6bff8295d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9a24a2487a254..a3e5243b68aba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + } assert(metrics.length == 3) assert(metrics(0) == 1) From df78a934a07a4ce5af43243be9ba5fe60b91eee6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 28 Jan 2016 14:29:47 -0800 Subject: [PATCH 1727/2089] [SPARK-9835][ML] Implement IterativelyReweightedLeastSquares solver Implement ```IterativelyReweightedLeastSquares``` solver for GLM. I consider it as a solver rather than estimator, it only used internal so I keep it ```private[ml]```. There are two limitations in the current implementation compared with R: * It can not support ```Tuple``` as response for ```Binomial``` family, such as the following code: ``` glm( cbind(using, notUsing) ~ age + education + wantsMore , family = binomial) ``` * It does not support ```offset```. Because I considered that ```RFormula``` did not support ```Tuple``` as label and ```offset``` keyword, so I simplified the implementation. But to add support for these two functions is not very hard, I can do it in follow-up PR if it is necessary. Meanwhile, we can also add R-like statistic summary for IRLS. The implementation refers R, [statsmodels](https://github.com/statsmodels/statsmodels) and [sparkGLM](https://github.com/AlteryxLabs/sparkGLM). Please focus on the main structure and overpass minor issues/docs that I will update later. Any comments and opinions will be appreciated. cc mengxr jkbradley Author: Yanbo Liang Closes #10639 from yanboliang/spark-9835. --- .../IterativelyReweightedLeastSquares.scala | 108 ++++++++++ .../spark/ml/optim/WeightedLeastSquares.scala | 7 +- ...erativelyReweightedLeastSquaresSuite.scala | 200 ++++++++++++++++++ 3 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala new file mode 100644 index 0000000000000..6aa44e6ba723e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.Logging +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg._ +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[IterativelyReweightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + */ +private[ml] class IterativelyReweightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double) extends Serializable + +/** + * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve + * certain optimization problems by an iterative method. In each step of the iterations, it + * involves solving a weighted lease squares (WLS) problem by [[WeightedLeastSquares]]. + * It can be used to find maximum likelihood estimates of a generalized linear model (GLM), + * find M-estimator in robust regression and other optimization problems. + * + * @param initialModel the initial guess model. + * @param reweightFunc the reweight function which is used to update offsets and weights + * at each iteration. + * @param fitIntercept whether to fit intercept. + * @param regParam L2 regularization parameter used by WLS. + * @param maxIter maximum number of iterations. + * @param tol the convergence tolerance. + * + * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares + * for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives, + * Journal of the Royal Statistical Society. Series B, 1984.]] + */ +private[ml] class IterativelyReweightedLeastSquares( + val initialModel: WeightedLeastSquaresModel, + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val fitIntercept: Boolean, + val regParam: Double, + val maxIter: Int, + val tol: Double) extends Logging with Serializable { + + def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + + var converged = false + var iter = 0 + + var model: WeightedLeastSquaresModel = initialModel + var oldModel: WeightedLeastSquaresModel = null + + while (iter < maxIter && !converged) { + + oldModel = model + + // Update offsets and weights using reweightFunc + val newInstances = instances.map { instance => + val (newOffset, newWeight) = reweightFunc(instance, oldModel) + Instance(newOffset, newWeight, instance.features) + } + + // Estimate new model + model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false, + standardizeLabel = false).fit(newInstances) + + // Check convergence + val oldCoefficients = oldModel.coefficients + val coefficients = model.coefficients + BLAS.axpy(-1.0, coefficients, oldCoefficients) + val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => + math.max(math.abs(x), math.abs(y)) + } + val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) + + if (maxTol < tol) { + converged = true + logInfo(s"IRLS converged in $iter iterations.") + } + + logInfo(s"Iteration $iter : relative tolerance = $maxTol") + iter = iter + 1 + + if (iter == maxIter) { + logInfo(s"IRLS reached the max number of iterations: $maxIter.") + } + + } + + new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 797870eb8ce8a..61b3642131810 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, val intercept: Double, - val diagInvAtWA: DenseVector) extends Serializable + val diagInvAtWA: DenseVector) extends Serializable { + + def predict(features: Vector): Double = { + BLAS.dot(coefficients, features) + intercept + } +} /** * Weighted least squares solver via normal equation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala new file mode 100644 index 0000000000000..604021220a139 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances1: RDD[Instance] = _ + private var instances2: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) + b <- c(1, 0, 1, 0) + w <- c(1, 2, 3, 4) + */ + instances1 = sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ), 2) + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + */ + instances2 = sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + } + + test("IRLS against GLM with Binomial errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- glm(formula, family="binomial", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.30216651 -0.04452045 + [1] 3.5651651 -1.2334085 -0.7348971 + */ + val expected = Seq( + Vectors.dense(0.0, -0.30216651, -0.04452045), + Vectors.dense(3.5651651, -1.2334085, -0.7348971)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val newInstances = instances1.map { instance => + val mu = (instance.label + 0.5) / 2.0 + val eta = math.log(mu / (1.0 - mu)) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against GLM with Poisson errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- glm(formula, family="poisson", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.09607792 0.18375613 + [1] 6.299947 3.324107 -1.081766 + */ + val expected = Seq( + Vectors.dense(0.0, -0.09607792, 0.18375613), + Vectors.dense(6.299947, 3.324107, -1.081766)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val yMean = instances2.map(_.label).mean + val newInstances = instances2.map { instance => + val mu = (instance.label + yMean) / 2.0 + val eta = math.log(mu) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against L1Regression") { + /* + R code: + + library(quantreg) + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- rq(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] 1.266667 0.400000 + [1] 29.5 17.0 -5.5 + */ + val expected = Seq( + Vectors.dense(0.0, 1.266667, 0.400000), + Vectors.dense(29.5, 17.0, -5.5)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(instances2) + val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} + +object IterativelyReweightedLeastSquaresSuite { + + def BinomialReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) + val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val w = mu * (1 - mu) * instance.weight + (z, w) + } + + def PoissonReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = math.exp(eta) + val z = eta + (instance.label - mu) / mu + val w = mu * instance.weight + (z, w) + } + + def L1RegressionReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val e = math.max(math.abs(eta - instance.label), 1e-7) + val w = 1 / e + val y = instance.label + (y, w) + } +} From abae889f08eb412cb897e4e63614ec2c93885ffd Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 28 Jan 2016 15:20:16 -0800 Subject: [PATCH 1728/2089] [SPARK-12401][SQL] Add integration tests for postgres enum types We can handle posgresql-specific enum types as strings in jdbc. So, we should just add tests and close the corresponding JIRA ticket. Author: Takeshi YAMAMURO Closes #10596 from maropu/AddTestsInIntegration. --- .../spark/sql/jdbc/PostgresIntegrationSuite.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 7d011be37067b..72bda8fe1ef10 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import java.util.Properties import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.tags.DockerTest @DockerTest @@ -39,12 +39,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.setCatalog("foo") + conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate() conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " - + "c10 integer[], c11 text[], c12 real[])").executeUpdate() + + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " - + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate() + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate() } test("Type mapping for various types") { @@ -52,7 +53,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) - assert(types.length == 13) + assert(types.length == 14) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) @@ -66,22 +67,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(classOf[Seq[Int]].isAssignableFrom(types(10))) assert(classOf[Seq[String]].isAssignableFrom(types(11))) assert(classOf[Seq[Double]].isAssignableFrom(types(12))) + assert(classOf[String].isAssignableFrom(types(13))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) assert(rows(0).getLong(3) == 123456789012345L) - assert(rows(0).getBoolean(4) == false) + assert(!rows(0).getBoolean(4)) // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) - assert(rows(0).getBoolean(7) == true) + assert(rows(0).getBoolean(7)) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") assert(rows(0).getSeq(10) == Seq(1, 2)) assert(rows(0).getSeq(11) == Seq("a", null, "b")) assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) + assert(rows(0).getString(13) == "d1") } test("Basic write test") { From 3a40c0e575fd4215302ea60c9821d31a5a138b8a Mon Sep 17 00:00:00 2001 From: Brandon Bradley Date: Thu, 28 Jan 2016 15:25:57 -0800 Subject: [PATCH 1729/2089] [SPARK-12749][SQL] add json option to parse floating-point types as DecimalType I tried to add this via `USE_BIG_DECIMAL_FOR_FLOATS` option from Jackson with no success. Added test for non-complex types. Should I add a test for complex types? Author: Brandon Bradley Closes #10936 from blbradley/spark-12749. --- python/pyspark/sql/readwriter.py | 2 ++ .../apache/spark/sql/DataFrameReader.scala | 2 ++ .../datasources/json/InferSchema.scala | 8 ++++-- .../datasources/json/JSONOptions.scala | 2 ++ .../datasources/json/JsonSuite.scala | 28 +++++++++++++++++++ 5 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 0b20022b14b8d..b1453c637f79e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -152,6 +152,8 @@ def json(self, path, schema=None): You can set the following JSON-specific options to deal with non-standard JSON files: * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ type + * `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal \ + type * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 634c1bd4739b1..2e0c6c7df967e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -252,6 +252,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * You can set the following JSON-specific options to deal with non-standard JSON files: *
  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • + *
  • `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal + * type
  • *
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • *
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 44d5e4ff7ec8b..8b773ddfcb656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -134,8 +134,12 @@ private[json] object InferSchema { val v = parser.getDecimalValue DecimalType(v.precision(), v.scale()) case FLOAT | DOUBLE => - // TODO(davies): Should we use decimal if possible? - DoubleType + if (configOptions.floatAsBigDecimal) { + val v = parser.getDecimalValue + DecimalType(v.precision(), v.scale()) + } else { + DoubleType + } } case VALUE_TRUE | VALUE_FALSE => BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index fe5b20697e40e..31a95ed461215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -34,6 +34,8 @@ private[sql] class JSONOptions( parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + val floatAsBigDecimal = + parameters.get("floatAsBigDecimal").map(_.toBoolean).getOrElse(false) val allowComments = parameters.get("allowComments").map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 00eaeb0d34e87..dd83a0e36f6f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -771,6 +771,34 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Loading a JSON dataset floatAsBigDecimal returns schema with float types as BigDecimal") { + val jsonDF = sqlContext.read.option("floatAsBigDecimal", "true").json(primitiveFieldAndType) + + val expectedSchema = StructType( + StructField("bigInteger", DecimalType(20, 0), true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DecimalType(17, -292), true) :: + StructField("integer", LongType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(BigDecimal("92233720368547758070"), + true, + BigDecimal("1.7976931348623157E308"), + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() From 4637fc08a3733ec313218fb7e4d05064d9a6262d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jan 2016 16:25:21 -0800 Subject: [PATCH 1730/2089] [SPARK-11955][SQL] Mark optional fields in merging schema for safely pushdowning filters in Parquet JIRA: https://issues.apache.org/jira/browse/SPARK-11955 Currently we simply skip pushdowning filters in parquet if we enable schema merging. However, we can actually mark particular fields in merging schema for safely pushdowning filters in parquet. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #9940 from viirya/safe-pushdown-parquet-filters. --- .../org/apache/spark/sql/types/Metadata.scala | 5 +++ .../apache/spark/sql/types/StructType.scala | 34 ++++++++++++--- .../spark/sql/types/DataTypeSuite.scala | 14 ++++-- .../datasources/parquet/ParquetFilters.scala | 37 +++++++++++----- .../datasources/parquet/ParquetRelation.scala | 13 +++--- .../parquet/ParquetFilterSuite.scala | 43 ++++++++++++++++++- 6 files changed, 117 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 9e0f9943bc638..66f123682e117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -273,4 +273,9 @@ class MetadataBuilder { map.put(key, value) this } + + def remove(key: String): this.type = { + map.remove(key) + this + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3bd733fa2d26c..da0c92864e9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -334,6 +334,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { + private[sql] val metadataKeyForOptionalField = "_OPTIONAL_" + override private[sql] def defaultConcreteType: DataType = new StructType override private[sql] def acceptsType(other: DataType): Boolean = { @@ -359,6 +361,18 @@ object StructType extends AbstractDataType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + def removeMetadata(key: String, dt: DataType): DataType = + dt match { + case StructType(fields) => + val newFields = fields.map { f => + val mb = new MetadataBuilder() + f.copy(dataType = removeMetadata(key, f.dataType), + metadata = mb.withMetadata(f.metadata).remove(key).build()) + } + StructType(newFields) + case _ => dt + } + private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, leftContainsNull), @@ -376,24 +390,32 @@ object StructType extends AbstractDataType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + // This metadata will record the fields that only exist in one of two StructTypes + val optionalMeta = new MetadataBuilder() val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse { + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + Some(leftField.copy(metadata = optionalMeta.build())) + } .foreach(newFields += _) } val leftMapped = fieldsMap(leftFields) rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) - .foreach(newFields += _) + .foreach { f => + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + newFields += f.copy(metadata = optionalMeta.build()) + } StructType(newFields) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 706ecd29d1355..c2bbca7c33f28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType(List()) val merged = left.merge(right) - assert(merged === left) + assert(DataType.equalsIgnoreCompatibleNullability(merged, left)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where left is empty") { @@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(right === merged) - + assert(DataType.equalsIgnoreCompatibleNullability(merged, right)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where both are non-empty") { @@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(merged === expected) + assert(DataType.equalsIgnoreCompatibleNullability(merged, expected)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where right contains type conflict") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index e9b734b0abf50..5a5cb5cf03d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -207,11 +207,26 @@ private[sql] object ParquetFilters { */ } + /** + * SPARK-11955: The optional fields will have metadata StructType.metadataKeyForOptionalField. + * These fields only exist in one side of merged schemas. Due to that, we can't push down filters + * using such fields, otherwise Parquet library will throw exception. Here we filter out such + * fields. + */ + private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { + case StructType(fields) => + fields.filter { f => + !f.metadata.contains(StructType.metadataKeyForOptionalField) || + !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) + }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) } + case _ => Array.empty[(String, DataType)] + } + /** * Converts data sources filters to Parquet filter predicates. */ def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { - val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + val dataTypeOf = getFieldMap(schema).toMap relaxParquetValidTypeMap @@ -231,29 +246,29 @@ private[sql] object ParquetFilters { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) => + case sources.IsNull(name) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.IsNotNull(name) => + case sources.IsNotNull(name) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.EqualTo(name, value) => + case sources.EqualTo(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) => + case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) => + case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) => + case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThan(name, value) => + case sources.LessThan(name, value) if dataTypeOf.contains(name) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) => + case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) => makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThan(name, value) => + case sources.GreaterThan(name, value) if dataTypeOf.contains(name) => makeGt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) => + case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) case sources.In(name, valueSet) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b460ec1d26047..f87590095d344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -258,7 +258,12 @@ private[sql] class ParquetRelation( job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) - CatalystWriteSupport.setSchema(dataSchema, conf) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushdowning filters. + val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + CatalystWriteSupport.setSchema(dataSchemaToWrite, conf) // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) // and `CatalystWriteSupport` (writing actual rows to Parquet files). @@ -304,10 +309,6 @@ private[sql] class ParquetRelation( val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - // When merging schemas is enabled and the column of the given filter does not exist, - // Parquet emits an exception which is an issue of Parquet (PARQUET-389). - val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown - // Parquet row group size. We will use this value as the value for // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value // of these flags are smaller than the parquet row group size. @@ -321,7 +322,7 @@ private[sql] class ParquetRelation( dataSchema, parquetBlockSize, useMetadataCache, - safeParquetFilterPushDown, + parquetFilterPushDown, assumeBinaryIsString, assumeInt96IsTimestamp) _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 97c5313f0feff..1796b3af0e37a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -379,9 +380,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( - sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), - (1 to 1).map(i => Row(i, i.toString, null))) + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = sqlContext.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. + val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = sqlContext.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) } } } From b9dfdcc63bb12bc24de96060e756889c2ceda519 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 17:01:12 -0800 Subject: [PATCH 1731/2089] Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage" This reverts commit cc18a7199240bf3b03410c1ba6704fe7ce6ae38e. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++------------ .../aggregate/TungstenAggregate.scala | 88 +++----- .../spark/sql/execution/basicOperators.scala | 96 ++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 202 insertions(+), 334 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f7..2747c315ad374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,23 +144,14 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * A prefix used to generate fresh name. - */ - var freshNamePrefix = "" - /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" - } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" - } + def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b8..d9fe76133c6ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f049f..57f4945de9804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns the RDD of InternalRow which generates the input rows. + * Returns an input RDD of InternalRow and Java source code to process them. */ - def upstream(): RDD[InternalRow] - - /** - * Returns Java source code to process the rows from upstream. - */ - def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { this.parent = parent - ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): String + protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) /** - * Consume the columns generated from current SparkPlan, call it's parent. + * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. */ - def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null) { - assert(input.length == output.length) - } - parent.consumeChild(ctx, this, input, row) + protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { + assert(columns.length == output.length) + parent.doConsume(ctx, this, columns) } - /** - * Consume the columns generated from it's child, call doConsume() or emit the rows. - */ - def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - ctx.freshNamePrefix = nodeName - if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - val evals = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - s""" - | ${evals.map(_.code).mkString("\n")} - | ${doConsume(ctx, evals)} - """.stripMargin - } else { - doConsume(ctx, input) - } - } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String } @@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def doPrepare(): Unit = { - child.prepare() - } - override def doExecute(): RDD[InternalRow] = { - child.execute() - } + override def supportCodegen: Boolean = true - override def supportCodegen: Boolean = false - - override def upstream(): RDD[InternalRow] = { - child.execute() - } - - override def doProduce(ctx: CodegenContext): String = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - s""" - | while (input.hasNext()) { + val code = s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin + (child.execute(), code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() ---------> upstream() -------> upstream() ------> execute() - * | - * -----------------> produce() + * doExecute() --------> produce() * | * doProduce() -------> produce() * | - * doProduce() + * doProduce() ---> execute() * | * consume() - * consumeChild() <-----------| + * doConsume() ------------| * | - * doConsume() - * | - * consumeChild() <----- consume() + * doConsume() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). * @@ -206,48 +162,37 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { - override def supportCodegen: Boolean = false - override def output: Seq[Attribute] = plan.output - override def outputPartitioning: Partitioning = plan.outputPartitioning - override def outputOrdering: Seq[SortOrder] = plan.outputOrdering - - override def doPrepare(): Unit = { - plan.prepare() - } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val code = plan.produce(ctx, this) + val (rdd, code) = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} + private Object[] references; + ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() throws java.io.IOException { + protected void processNext() { $code - } + } } - """ - + """ // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - plan.upstream().mapPartitions { iter => - + rdd.mapPartitions { iter => val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -258,47 +203,29 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def upstream(): RDD[InternalRow] = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { throw new UnsupportedOperationException } - override def doProduce(ctx: CodegenContext): String = { - throw new UnsupportedOperationException - } - - override def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - - if (row != null) { - // There is an UnsafeRow already + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" - | currentRow = $row; + | ${code.code.trim} + | currentRow = ${code.value}; | return; - """.stripMargin + """.stripMargin } else { - assert(input != null) - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns - s""" - | currentRow = unsafeRow; - | return; - """.stripMargin - } + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin } } @@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - (Utils.isTesting || plan.children.exists(supportCodegen)) => + plan.children.exists(supportCodegen) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - val input = apply(p) // collapse them recursively - inputs += input - InputAdapter(input) + inputs += p + InputAdapter(p) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index cbd2634b8900f..23e54f344d252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,7 +117,9 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && + // final aggregation only have one row, do not need to codegen + !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) } // The variables used as aggregation buffer @@ -125,11 +127,7 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -139,80 +137,50 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | boolean $isNull = ${ev.isNull}; + | ${ctx.javaType(e.dataType)} $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { - // evaluate aggregate results - ctx.currentVars = bufVars - val bufferAttrs = functions.flatMap(_.aggBufferAttributes) - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, bufferAttrs).gen(ctx) - } - // evaluate result expressions - ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, aggregateAttributes).gen(ctx) - } - (resultVars, s""" - | ${aggResults.map(_.code).mkString("\n")} - | ${resultVars.map(_.code).mkString("\n")} - """.stripMargin) - } else { - // output the aggregate buffer directly - (bufVars, "") - } - - val doAgg = ctx.freshName("doAgg") - ctx.addNewFunction(doAgg, + val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) + val source = s""" - | private void $doAgg() { + | if (!$initAgg) { + | $initAgg = true; + | | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $childSource + | + | // output the result + | ${consume(ctx, bufVars)} | } - """.stripMargin) + """.stripMargin - s""" - | if (!$initAgg) { - | $initAgg = true; - | $doAgg(); - | - | // output the result - | $genResult - | - | ${consume(ctx, resultVars)} - | } - """.stripMargin + (rdd, source) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output - val updateExpr = aggregateExpressions.flatMap { e => - e.mode match { - case Partial | Complete => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions - case PartialMerge | Final => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions - } + // the mode could be only Partial or PartialMerge + val updateExpr = if (modes.contains(Partial)) { + functions.flatMap(_.updateExpressions) + } else { + functions.flatMap(_.mergeExpressions) } + val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val updates = updateExpr.zipWithIndex.map { case (e, i) => - val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) + val codes = boundExpr.zipWithIndex.map { case (e, i) => + val ev = e.gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -222,7 +190,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${updates.mkString("")} + | ${codes.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb4bf..6deb72adad5ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -161,21 +153,17 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - override def upstream(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) - } - - protected override def doProduce(ctx: CodegenContext): String = { - val initTerm = ctx.freshName("initRange") + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initTerm = ctx.freshName("range_initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("partitionEnd") + val partitionEnd = ctx.freshName("range_partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("number") + val number = ctx.freshName("range_number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("overflow") + val overflow = ctx.freshName("range_overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("value") + val value = ctx.freshName("range_value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -184,42 +172,38 @@ case class Range( s"$number > $partitionEnd" } - ctx.addNewFunction("initRange", - s""" - | private void initRange(int idx) { - | $BigInt index = $BigInt.valueOf(idx); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } - | } - """.stripMargin) + val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) - s""" + val code = s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | if (input.hasNext()) { - | initRange(((InternalRow) input.next()).getInt(0)); + | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } | } else { | return; | } @@ -234,6 +218,12 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin + + (rdd, code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa30e4..989cb2942918e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,61 +1939,58 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // TODO: support subexpression elimination in whole stage codegen - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) - - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) - } + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503c23..cbae19ebd269d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,24 +335,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) - } + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7d6bff8295d2b..d48143762cac0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index a3e5243b68aba..9a24a2487a254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() - } + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() assert(metrics.length == 3) assert(metrics(0) == 1) From 66449b8dcdbc3dca126c34b42c4d0419c7648696 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Jan 2016 22:20:52 -0800 Subject: [PATCH 1732/2089] [SPARK-12968][SQL] Implement command to set current database JIRA: https://issues.apache.org/jira/browse/SPARK-12968 Implement command to set current database. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10916 from viirya/ddl-use-database. --- .../spark/sql/catalyst/analysis/Catalog.scala | 4 ++++ .../org/apache/spark/sql/execution/SparkQl.scala | 3 +++ .../apache/spark/sql/execution/commands.scala | 10 ++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 -- .../spark/sql/hive/client/ClientInterface.scala | 3 +++ .../spark/sql/hive/client/ClientWrapper.scala | 9 +++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 16 ++++++++++++++++ 9 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index a8f89ce6de457..f2f9ec59417ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -46,6 +46,10 @@ trait Catalog { def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan + def setCurrentDatabase(databaseName: String): Unit = { + throw new UnsupportedOperationException + } + /** * Returns tuples of (tableName, isTemporary) for all tables in the given database. * isTemporary is a Boolean value indicates if a table is a temporary or not. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index f6055306b6c97..a5bd8ee42dec9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -55,6 +55,9 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => + SetDatabaseCommand(cleanIdentifier(database)) + case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val Some(tableType) :: formatted :: extended :: pretty :: Nil = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 3cfa3dfd9c7ec..703e4643cbd25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -408,3 +408,13 @@ case class DescribeFunction( } } } + +case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.catalog.setCurrentDatabase(databaseName) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index ab31d45a79a2e..72da266da4d01 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" - -> "OK", + -> "", "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a9c0e9ab7caef..848aa4ec6fe56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -711,6 +711,10 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } override def unregisterAllTables(): Unit = {} + + override def setCurrentDatabase(databaseName: String): Unit = { + client.setCurrentDatabase(databaseName) + } } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 22841ed2116d1..752c037a842a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -155,8 +155,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_SHOWLOCKS", "TOK_SHOWPARTITIONS", - "TOK_SWITCHDATABASE", - "TOK_UNLOCKTABLE" ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 9d9a55edd7314..4eec3fef7408b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -109,6 +109,9 @@ private[hive] trait ClientInterface { /** Returns the name of the active database. */ def currentDatabase: String + /** Sets the name of current database. */ + def setCurrentDatabase(databaseName: String): Unit + /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ def getDatabase(name: String): HiveDatabase = { getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index ce7a305d437a5..5307e924e7e55 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -229,6 +230,14 @@ private[hive] class ClientWrapper( state.getCurrentDatabase } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + if (getDatabaseOption(databaseName).isDefined) { + state.setCurrentDatabase(databaseName) + } else { + throw new NoSuchDatabaseException + } + } + override def createDatabase(database: HiveDatabase): Unit = withHiveState { client.createDatabase( new Database( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4659d745fe78b..9632d27a2ffce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin @@ -1262,6 +1263,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("use database") { + val currentDatabase = sql("select current_database()").first().getString(0) + + sql("CREATE DATABASE hive_test_db") + sql("USE hive_test_db") + assert("hive_test_db" == sql("select current_database()").first().getString(0)) + + intercept[NoSuchDatabaseException] { + sql("USE not_existing_db") + } + + sql(s"USE $currentDatabase") + assert(currentDatabase == sql("select current_database()").first().getString(0)) + } + test("lookup hive UDF in another thread") { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") From 721ced28b522cc00b45ca7fa32a99e80ad3de2f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 28 Jan 2016 22:42:43 -0800 Subject: [PATCH 1733/2089] [SPARK-13067] [SQL] workaround for a weird scala reflection problem A simple workaround to avoid getting parameter types when convert a logical plan to json. Author: Wenchen Fan Closes #10970 from cloud-fan/reflection. --- .../spark/sql/catalyst/ScalaReflection.scala | 25 ++++++++++++++++--- .../spark/sql/catalyst/trees/TreeNode.scala | 4 +-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 643228d0eb27d..e5811efb436a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -601,6 +601,20 @@ object ScalaReflection extends ScalaReflection { getConstructorParameters(t) } + /** + * Returns the parameter names for the primary constructor of this class. + * + * Logically we should call `getConstructorParameters` and throw away the parameter types to get + * parameter names, however there are some weird scala reflection problems and this method is a + * workaround to avoid getting parameter types. + */ + def getConstructorParameterNames(cls: Class[_]): Seq[String] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + constructParams(t).map(_.name.toString) + } + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } @@ -745,6 +759,12 @@ trait ScalaReflection { def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe + constructParams(tpe).map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } + + protected def constructParams(tpe: Type): Seq[Symbol] = { val constructorSymbol = tpe.member(nme.CONSTRUCTOR) val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramss @@ -758,9 +778,6 @@ trait ScalaReflection { primaryConstructorSymbol.get.asMethod.paramss } } - - params.flatten.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - } + params.flatten } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 57e1a3c9eb226..2df0683f9fa16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -512,7 +512,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } protected def jsonFields: List[JField] = { - val fieldNames = getConstructorParameters(getClass).map(_._1) + val fieldNames = getConstructorParameterNames(getClass) val fieldValues = productIterator.toSeq ++ otherCopyArgs assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " + fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", ")) @@ -560,7 +560,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper case p: Product => try { - val fieldNames = getConstructorParameters(p.getClass).map(_._1) + val fieldNames = getConstructorParameterNames(p.getClass) val fieldValues = p.productIterator.toSeq assert(fieldNames.length == fieldValues.length) ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { From 8d3cc3de7d116190911e7943ef3233fe3b7db1bf Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Thu, 28 Jan 2016 23:34:50 -0800 Subject: [PATCH 1734/2089] [SPARK-13050][BUILD] Scalatest tags fail build with the addition of the sketch module A dependency on the spark test tags was left out of the sketch module pom file causing builds to fail when test tags were used. This dependency is found in the pom file for every other module in spark. Author: Alex Bozarth Closes #10954 from ajbozarth/spark13050. --- common/sketch/pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 67723fa421ab1..2cafe8c548f5f 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -35,6 +35,13 @@ sketch + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes From 55561e7693dd2a5bf3c7f8026c725421801fd0ec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 01:59:59 -0800 Subject: [PATCH 1735/2089] [SPARK-13031][SQL] cleanup codegen and improve test coverage 1. enable whole stage codegen during tests even there is only one operator supports that. 2. split doProduce() into two APIs: upstream() and doProduce() 3. generate prefix for fresh names of each operator 4. pass UnsafeRow to parent directly (avoid getters and create UnsafeRow again) 5. fix bugs and tests. This PR re-open #10944 and fix the bug. Author: Davies Liu Closes #10977 from davies/gen_refactor. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../sql/execution/WholeStageCodegen.scala | 188 ++++++++++++------ .../aggregate/AggregationIterator.scala | 2 +- .../aggregate/TungstenAggregate.scala | 98 ++++++--- .../spark/sql/execution/basicOperators.scala | 96 +++++---- .../spark/sql/DataFrameAggregateSuite.scala | 7 + .../org/apache/spark/sql/SQLQuerySuite.scala | 103 +++++----- .../execution/metric/SQLMetricsSuite.scala | 34 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/util/DataFrameCallbackSuite.scala | 10 +- 11 files changed, 350 insertions(+), 205 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2747c315ad374..e6704cf8bb1f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,14 +144,23 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = { + if (freshNamePrefix == "") { + s"$name${curId.getAndIncrement}" + } else { + s"${freshNamePrefix}_$name${curId.getAndIncrement}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d9fe76133c6ef..ec31db19b94b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 57f4945de9804..ef81ba60f049f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns an input RDD of InternalRow and Java source code to process them. + * Returns the RDD of InternalRow which generates the input rows. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { + def upstream(): RDD[InternalRow] + + /** + * Returns Java source code to process the rows from upstream. + */ + def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent + ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) + protected def doProduce(ctx: CodegenContext): String /** - * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. + * Consume the columns generated from current SparkPlan, call it's parent. */ - protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { - assert(columns.length == output.length) - parent.doConsume(ctx, this, columns) + def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (input != null) { + assert(input.length == output.length) + } + parent.consumeChild(ctx, this, input, row) } + /** + * Consume the columns generated from it's child, call doConsume() or emit the rows. + */ + def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + ctx.freshNamePrefix = nodeName + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + val evals = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + s""" + | ${evals.map(_.code).mkString("\n")} + | ${doConsume(ctx, evals)} + """.stripMargin + } else { + doConsume(ctx, input) + } + } /** * Generate the Java source code to process the rows from child SparkPlan. @@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } } @@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doPrepare(): Unit = { + child.prepare() + } - override def supportCodegen: Boolean = true + override def doExecute(): RDD[InternalRow] = { + child.execute() + } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def supportCodegen: Boolean = false + + override def upstream(): RDD[InternalRow] = { + child.execute() + } + + override def doProduce(ctx: CodegenContext): String = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - val code = s""" - | while (input.hasNext()) { + s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin - (child.execute(), code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() --------> produce() + * doExecute() ---------> upstream() -------> upstream() ------> execute() + * | + * -----------------> produce() * | * doProduce() -------> produce() * | - * doProduce() ---> execute() + * doProduce() * | * consume() - * doConsume() ------------| + * consumeChild() <-----------| * | - * doConsume() <----- consume() + * doConsume() + * | + * consumeChild() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). * @@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { + override def supportCodegen: Boolean = false + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + + override def doPrepare(): Unit = { + plan.prepare() + } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val (rdd, code) = plan.produce(ctx, this) + val code = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() { + protected void processNext() throws java.io.IOException { $code - } + } } - """ + """ + // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - rdd.mapPartitions { iter => + plan.upstream().mapPartitions { iter => + val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { throw new UnsupportedOperationException } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def consumeChild( + ctx: CodegenContext, + child: SparkPlan, + input: Seq[ExprCode], + row: String = null): String = { + + if (row != null) { + // There is an UnsafeRow already s""" - | currentRow = unsafeRow; + | currentRow = $row; | return; """.stripMargin + } else { + assert(input != null) + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + s""" + | ${code.code.trim} + | currentRow = ${code.value}; + | return; + """.stripMargin + } else { + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin + } } } @@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") - plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) + plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - plan.children.exists(supportCodegen) => + (Utils.isTesting || plan.children.exists(supportCodegen)) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - inputs += p - InputAdapter(p) + val input = apply(p) // collapse them recursively + inputs += input + InputAdapter(input) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 0c74df0aa5fdd..38da82c47ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -238,7 +238,7 @@ abstract class AggregationIterator( resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { - // Grouping-only: we only output values of grouping expressions. + // Grouping-only: we only output values based on grouping expressions. val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { resultProjection(currentGroupingKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 23e54f344d252..ff2f38bfd9105 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,9 +117,7 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && - // final aggregation only have one row, do not need to codegen - !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } // The variables used as aggregation buffer @@ -127,7 +125,11 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -137,60 +139,96 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | boolean $isNull = ${ev.isNull}; - | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val source = + // generate variables for output + val bufferAttrs = functions.flatMap(_.aggBufferAttributes) + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttrs).gen(ctx) + } + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + | ${aggResults.map(_.code).mkString("\n")} + | ${resultVars.map(_.code).mkString("\n")} + """.stripMargin) + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // output the aggregate buffer directly + (bufVars, "") + } else { + // no aggregate function, the result should be literals + val resultVars = resultExpressions.map(_.gen(ctx)) + (resultVars, resultVars.map(_.code).mkString("\n")) + } + + val doAgg = ctx.freshName("doAgg") + ctx.addNewFunction(doAgg, s""" - | if (!$initAgg) { - | $initAgg = true; - | + | private void $doAgg() { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | $childSource - | - | // output the result - | ${consume(ctx, bufVars)} + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | } - """.stripMargin + """.stripMargin) - (rdd, source) + s""" + | if (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | $genResult + | + | ${consume(ctx, resultVars)} + | } + """.stripMargin } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - // the mode could be only Partial or PartialMerge - val updateExpr = if (modes.contains(Partial)) { - functions.flatMap(_.updateExpressions) - } else { - functions.flatMap(_.mergeExpressions) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } } - val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output - val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val codes = boundExpr.zipWithIndex.map { case (e, i) => - val ev = e.gen(ctx) + val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => s""" - | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; | ${bufVars(i).value} = ${ev.value}; """.stripMargin } s""" - | // do aggregate and update aggregation buffer - | ${codes.mkString("")} + | // do aggregate + | ${aggVals.map(_.code).mkString("\n")} + | // update aggregation buffer + | ${updates.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6deb72adad5ec..e7a73d5fbb4bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -153,17 +161,21 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { - val initTerm = ctx.freshName("range_initRange") + override def upstream(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + + protected override def doProduce(ctx: CodegenContext): String = { + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("range_partitionEnd") + val partitionEnd = ctx.freshName("partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("range_number") + val number = ctx.freshName("number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("range_overflow") + val overflow = ctx.freshName("overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("range_value") + val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -172,38 +184,42 @@ case class Range( s"$number > $partitionEnd" } - val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | } + """.stripMargin) - val code = s""" + s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | if (input.hasNext()) { - | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } + | initRange(((InternalRow) input.next()).getInt(0)); | } else { | return; | } @@ -218,12 +234,6 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin - - (rdd, code) - } - - def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b1004bc5bc290..08fb7c9d84c0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -153,6 +153,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("agg without groups and functions") { + checkAnswer( + testData2.agg(lit(1)), + Row(1) + ) + } + test("average") { checkAnswer( testData2.agg(avg('a), mean('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 989cb2942918e..51a50c1fa30e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } - - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cbae19ebd269d..82f6811503c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d48143762cac0..7d6bff8295d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9a24a2487a254..a3e5243b68aba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + } assert(metrics.length == 3) assert(metrics(0) == 1) From e51b6eaa9e9c007e194d858195291b2b9fb27322 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 29 Jan 2016 09:22:24 -0800 Subject: [PATCH 1736/2089] [SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example * Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark. * Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community. cc mengxr jkbradley Author: Yanbo Liang Author: Joseph K. Bradley Closes #10469 from yanboliang/spark-11939. --- python/pyspark/ml/param/__init__.py | 24 +++++ python/pyspark/ml/regression.py | 30 +++++- python/pyspark/ml/tests.py | 36 +++++-- python/pyspark/ml/util.py | 142 +++++++++++++++++++++++++++- python/pyspark/ml/wrapper.py | 33 ++++--- 5 files changed, 236 insertions(+), 29 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 3da36d32c5af0..ea86d6aeb8b31 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -314,3 +314,27 @@ def _copyValues(self, to, extra=None): if p in paramMap and to.hasParam(p.name): to._set(**{p.name: paramMap[p]}) return to + + def _resetUid(self, newUid): + """ + Changes the uid of this instance. This updates both + the stored uid and the parent uid of params and param maps. + This is used by persistence (loading). + :param newUid: new uid to use + :return: same instance, but with the uid and Param.parent values + updated, including within param maps + """ + self.uid = newUid + newDefaultParamMap = dict() + newParamMap = dict() + for param in self.params: + newParam = copy.copy(param) + newParam.parent = newUid + if param in self._defaultParamMap: + newDefaultParamMap[newParam] = self._defaultParamMap[param] + if param in self._paramMap: + newParamMap[newParam] = self._paramMap[param] + param.parent = newUid + self._defaultParamMap = newDefaultParamMap + self._paramMap = newParamMap + return self diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 74a2248ed07c8..20dc6c2db91f3 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -18,9 +18,9 @@ import warnings from pyspark import since -from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.mllib.common import inherit_doc @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol): + HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): """ Linear regression. @@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lr_path = path + "/lr" + >>> lr.save(lr_path) + >>> lr2 = LinearRegression.load(lr_path) + >>> lr2.getMaxIter() + 5 + >>> model_path = path + "/lr_model" + >>> model.save(model_path) + >>> model2 = LinearRegressionModel.load(model_path) + >>> model.coefficients[0] == model2.coefficients[0] + True + >>> model.intercept == model2.intercept + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.4.0 """ @@ -106,7 +125,7 @@ def _create_model(self, java_model): return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel): +class LinearRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by LinearRegression. @@ -821,9 +840,10 @@ def predict(self, features): if __name__ == "__main__": import doctest + import pyspark.ml.regression from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.regression tests") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c45a159c460f3..54806ee336666 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -34,18 +34,22 @@ else: import unittest -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext, Row -from pyspark.sql.functions import rand +from shutil import rmtree +import tempfile + +from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.feature import * from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed -from pyspark.ml.util import keyword_only -from pyspark.ml import Estimator, Model, Pipeline, Transformer -from pyspark.ml.feature import * +from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel +from pyspark.ml.util import keyword_only from pyspark.mllib.linalg import DenseVector +from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql.functions import rand +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase class MockDataset(DataFrame): @@ -405,6 +409,26 @@ def test_fit_maximize_metric(self): self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") +class PersistenceTest(PySparkTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index cee9d67b05325..d7a813f56cd57 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,8 +15,27 @@ # limitations under the License. # -from functools import wraps +import sys import uuid +from functools import wraps + +if sys.version > '3': + basestring = str + +from pyspark import SparkContext, since +from pyspark.mllib.common import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") def keyword_only(func): @@ -52,3 +71,124 @@ def _randomUID(cls): concatenates the class name, "_", and 12 random hex chars. """ return cls.__name__ + "_" + uuid.uuid4().hex[12:] + + +@inherit_doc +class JavaMLWriter(object): + """ + .. note:: Experimental + + Utility class that can save ML instances through their Scala implementation. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, instance): + instance._transfer_params_to_java() + self._jwrite = instance._java_obj.write() + + def save(self, path): + """Save the ML instance to the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._jwrite.save(path) + + def overwrite(self): + """Overwrites if the output path already exists.""" + self._jwrite.overwrite() + return self + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + self._jwrite.context(sqlContext._ssql_ctx) + return self + + +@inherit_doc +class MLWritable(object): + """ + .. note:: Experimental + + Mixin for ML instances that provide JavaMLWriter. + + .. versionadded:: 2.0.0 + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + +@inherit_doc +class JavaMLReader(object): + """ + .. note:: Experimental + + Utility class that can load ML instances through their Scala implementation. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, clazz): + self._clazz = clazz + self._jread = self._load_java_obj(clazz).read() + + def load(self, path): + """Load the ML instance from the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + java_obj = self._jread.load(path) + instance = self._clazz() + instance._java_obj = java_obj + instance._resetUid(java_obj.uid()) + instance._transfer_params_from_java() + return instance + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + self._jread.context(sqlContext._ssql_ctx) + return self + + @classmethod + def _java_loader_class(cls, clazz): + """ + Returns the full class name of the Java ML instance. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = clazz.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, clazz.__name__]) + + @classmethod + def _load_java_obj(cls, clazz): + """Load the peer Java object of the ML instance.""" + java_class = cls._java_loader_class(clazz) + java_obj = _jvm() + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj + + +@inherit_doc +class MLReadable(object): + """ + .. note:: Experimental + + Mixin for instances that provide JavaMLReader. + + .. versionadded:: 2.0.0 + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def load(cls, path): + """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" + return cls.read().load(path) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index dd1d4b076eddd..d4d48eb2150e3 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,21 +21,10 @@ from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.ml.util import _jvm from pyspark.mllib.common import inherit_doc, _java2py, _py2java -def _jvm(): - """ - Returns the JVM view associated with SparkContext. Must be called - after SparkContext is initialized. - """ - jvm = SparkContext._jvm - if jvm: - return jvm - else: - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") - - @inherit_doc class JavaWrapper(Params): """ @@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer): __metaclass__ = ABCMeta - def __init__(self, java_model): + def __init__(self, java_model=None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, and then call _transformer_params_from_java. + + This instance can be instantiated without specifying java_model, + it will be assigned after that, but this scenario only used by + :py:class:`JavaMLReader` to load models. This is a bit of a + hack, but it is easiest since a proper fix would require + MLReader (in pyspark.ml.util) to depend on these wrappers, but + these wrappers depend on pyspark.ml.util (both directly and via + other ML classes). """ super(JavaModel, self).__init__() - self._java_obj = java_model - self.uid = java_model.uid() + if java_model is not None: + self._java_obj = java_model + self.uid = java_model.uid() def copy(self, extra=None): """ @@ -182,8 +180,9 @@ def copy(self, extra=None): if extra is None: extra = dict() that = super(JavaModel, self).copy(extra) - that._java_obj = self._java_obj.copy(self._empty_java_param_map()) - that._transfer_params_to_java() + if self._java_obj is not None: + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() return that def _call_java(self, name, *args): From e4c1162b6b3dbc8fc95cfe75c6e0bc2915575fb2 Mon Sep 17 00:00:00 2001 From: zhuol Date: Fri, 29 Jan 2016 11:54:58 -0600 Subject: [PATCH 1737/2089] [SPARK-10873] Support column sort and search for History Server. [SPARK-10873] Support column sort and search for History Server using jQuery DataTable and REST API. Before this commit, the history server was generated hard-coded html and can not support search, also, the sorting was disabled if there is any application that has more than one attempt. Supporting search and sort (over all applications rather than the 20 entries in the current page) in any case will greatly improve user experience. 1. Create the historypage-template.html for displaying application information in datables. 2. historypage.js uses jQuery to access the data from /api/v1/applications REST API, and use DataTable to display each application's information. For application that has more than one attempt, the RowsGroup is used to merge such entries while at the same time supporting sort and search. 3. "duration" and "lastUpdated" rest API are added to application's "attempts". 4. External javascirpt and css files for datatables, RowsGroup and jquery plugins are added with licenses clarified. Snapshots for how it looks like now: History page view: ![historypage](https://cloud.githubusercontent.com/assets/11683054/12184383/89bad774-b55a-11e5-84e4-b0276172976f.png) Search: ![search](https://cloud.githubusercontent.com/assets/11683054/12184385/8d3b94b0-b55a-11e5-869a-cc0ef0a4242a.png) Sort by started time: ![sort-by-started-time](https://cloud.githubusercontent.com/assets/11683054/12184387/8f757c3c-b55a-11e5-98c8-577936366566.png) Author: zhuol Closes #10648 from zhuoliu/10873. --- .rat-excludes | 10 + LICENSE | 6 + .../spark/ui/static/dataTables.bootstrap.css | 319 ++++++++++ .../ui/static/dataTables.bootstrap.min.js | 8 + .../spark/ui/static/dataTables.rowsGroup.js | 224 +++++++ .../spark/ui/static/historypage-template.html | 81 +++ .../org/apache/spark/ui/static/historypage.js | 159 +++++ .../spark/ui/static/jquery.blockUI.min.js | 6 + .../ui/static/jquery.cookies.2.2.0.min.js | 18 + .../static/jquery.dataTables.1.10.4.min.css | 1 + .../ui/static/jquery.dataTables.1.10.4.min.js | 157 +++++ .../apache/spark/ui/static/jquery.mustache.js | 592 ++++++++++++++++++ .../spark/ui/static/jsonFormatter.min.css | 1 + .../spark/ui/static/jsonFormatter.min.js | 2 + .../spark/deploy/history/HistoryPage.scala | 193 +----- .../api/v1/ApplicationListResource.scala | 14 + .../org/apache/spark/status/api/v1/api.scala | 2 + .../scala/org/apache/spark/ui/SparkUI.scala | 2 + .../scala/org/apache/spark/ui/UIUtils.scala | 11 + .../application_list_json_expectation.json | 18 +- .../completed_app_list_json_expectation.json | 18 +- .../maxDate2_app_list_json_expectation.json | 4 +- .../maxDate_app_list_json_expectation.json | 6 +- .../minDate_app_list_json_expectation.json | 14 +- .../one_app_json_expectation.json | 4 +- ...ne_app_multi_attempt_json_expectation.json | 6 +- .../deploy/history/HistoryServerSuite.scala | 43 +- project/MimaExcludes.scala | 4 + 28 files changed, 1721 insertions(+), 202 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage-template.html create mode 100644 core/src/main/resources/org/apache/spark/ui/static/historypage.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js create mode 100755 core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.css create mode 100755 core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js diff --git a/.rat-excludes b/.rat-excludes index a4f316a4aaa04..874a6ee9f4043 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,16 @@ graphlib-dot.min.js sorttable.js vis.min.js vis.min.css +dataTables.bootstrap.css +dataTables.bootstrap.min.js +dataTables.rowsGroup.js +jquery.blockUI.min.js +jquery.cookies.2.2.0.min.js +jquery.dataTables.1.10.4.min.css +jquery.dataTables.1.10.4.min.js +jquery.mustache.js +jsonFormatter.min.css +jsonFormatter.min.js .*avsc .*txt .*json diff --git a/LICENSE b/LICENSE index 9c944ac610afe..9fc29db8d3f22 100644 --- a/LICENSE +++ b/LICENSE @@ -291,3 +291,9 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) + (MIT License) datatables (http://datatables.net/license) + (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) + (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) + (MIT License) blockUI (http://jquery.malsup.com/block/) + (MIT License) RowsGroup (http://datatables.net/license/mit) + (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css new file mode 100644 index 0000000000000..faee0e50dbfea --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.css @@ -0,0 +1,319 @@ +div.dataTables_length label { + font-weight: normal; + text-align: left; + white-space: nowrap; +} + +div.dataTables_length select { + width: 75px; + display: inline-block; +} + +div.dataTables_filter { + text-align: right; +} + +div.dataTables_filter label { + font-weight: normal; + white-space: nowrap; + text-align: left; +} + +div.dataTables_filter input { + margin-left: 0.5em; + display: inline-block; +} + +div.dataTables_info { + padding-top: 8px; + white-space: nowrap; +} + +div.dataTables_paginate { + margin: 0; + white-space: nowrap; + text-align: right; +} + +div.dataTables_paginate ul.pagination { + margin: 2px 0; + white-space: nowrap; +} + +@media screen and (max-width: 767px) { + div.dataTables_length, + div.dataTables_filter, + div.dataTables_info, + div.dataTables_paginate { + text-align: center; + } +} + + +table.dataTable td, +table.dataTable th { + -webkit-box-sizing: content-box; + -moz-box-sizing: content-box; + box-sizing: content-box; +} + + +table.dataTable { + clear: both; + margin-top: 6px !important; + margin-bottom: 6px !important; + max-width: none !important; +} + +table.dataTable thead .sorting, +table.dataTable thead .sorting_asc, +table.dataTable thead .sorting_desc, +table.dataTable thead .sorting_asc_disabled, +table.dataTable thead .sorting_desc_disabled { + cursor: pointer; +} + +table.dataTable thead .sorting { background: url('../images/sort_both.png') no-repeat center right; } +table.dataTable thead .sorting_asc { background: url('../images/sort_asc.png') no-repeat center right; } +table.dataTable thead .sorting_desc { background: url('../images/sort_desc.png') no-repeat center right; } + +table.dataTable thead .sorting_asc_disabled { background: url('../images/sort_asc_disabled.png') no-repeat center right; } +table.dataTable thead .sorting_desc_disabled { background: url('../images/sort_desc_disabled.png') no-repeat center right; } + +table.dataTable thead > tr > th { + padding-left: 18px; + padding-right: 18px; +} + +table.dataTable th:active { + outline: none; +} + +/* Scrolling */ +div.dataTables_scrollHead table { + margin-bottom: 0 !important; + border-bottom-left-radius: 0; + border-bottom-right-radius: 0; +} + +div.dataTables_scrollHead table thead tr:last-child th:first-child, +div.dataTables_scrollHead table thead tr:last-child td:first-child { + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.dataTables_scrollBody table { + border-top: none; + margin-top: 0 !important; + margin-bottom: 0 !important; +} + +div.dataTables_scrollBody tbody tr:first-child th, +div.dataTables_scrollBody tbody tr:first-child td { + border-top: none; +} + +div.dataTables_scrollFoot table { + margin-top: 0 !important; + border-top: none; +} + +/* Frustratingly the border-collapse:collapse used by Bootstrap makes the column + width calculations when using scrolling impossible to align columns. We have + to use separate + */ +table.table-bordered.dataTable { + border-collapse: separate !important; +} +table.table-bordered thead th, +table.table-bordered thead td { + border-left-width: 0; + border-top-width: 0; +} +table.table-bordered tbody th, +table.table-bordered tbody td { + border-left-width: 0; + border-bottom-width: 0; +} +table.table-bordered th:last-child, +table.table-bordered td:last-child { + border-right-width: 0; +} +div.dataTables_scrollHead table.table-bordered { + border-bottom-width: 0; +} + + + + +/* + * TableTools styles + */ +.table.dataTable tbody tr.active td, +.table.dataTable tbody tr.active th { + background-color: #08C; + color: white; +} + +.table.dataTable tbody tr.active:hover td, +.table.dataTable tbody tr.active:hover th { + background-color: #0075b0 !important; +} + +.table.dataTable tbody tr.active th > a, +.table.dataTable tbody tr.active td > a { + color: white; +} + +.table-striped.dataTable tbody tr.active:nth-child(odd) td, +.table-striped.dataTable tbody tr.active:nth-child(odd) th { + background-color: #017ebc; +} + +table.DTTT_selectable tbody tr { + cursor: pointer; +} + +div.DTTT .btn { + color: #333 !important; + font-size: 12px; +} + +div.DTTT .btn:hover { + text-decoration: none !important; +} + +ul.DTTT_dropdown.dropdown-menu { + z-index: 2003; +} + +ul.DTTT_dropdown.dropdown-menu a { + color: #333 !important; /* needed only when demo_page.css is included */ +} + +ul.DTTT_dropdown.dropdown-menu li { + position: relative; +} + +ul.DTTT_dropdown.dropdown-menu li:hover a { + background-color: #0088cc; + color: white !important; +} + +div.DTTT_collection_background { + z-index: 2002; +} + +/* TableTools information display */ +div.DTTT_print_info { + position: fixed; + top: 50%; + left: 50%; + width: 400px; + height: 150px; + margin-left: -200px; + margin-top: -75px; + text-align: center; + color: #333; + padding: 10px 30px; + opacity: 0.95; + + background-color: white; + border: 1px solid rgba(0, 0, 0, 0.2); + border-radius: 6px; + + -webkit-box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5); + box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5); +} + +div.DTTT_print_info h6 { + font-weight: normal; + font-size: 28px; + line-height: 28px; + margin: 1em; +} + +div.DTTT_print_info p { + font-size: 14px; + line-height: 20px; +} + +div.dataTables_processing { + position: absolute; + top: 50%; + left: 50%; + width: 100%; + height: 60px; + margin-left: -50%; + margin-top: -25px; + padding-top: 20px; + padding-bottom: 20px; + text-align: center; + font-size: 1.2em; + background-color: white; + background: -webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0))); + background: -webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: -o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); + background: linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%); +} + + + +/* + * FixedColumns styles + */ +div.DTFC_LeftHeadWrapper table, +div.DTFC_LeftFootWrapper table, +div.DTFC_RightHeadWrapper table, +div.DTFC_RightFootWrapper table, +table.DTFC_Cloned tr.even { + background-color: white; + margin-bottom: 0; +} + +div.DTFC_RightHeadWrapper table , +div.DTFC_LeftHeadWrapper table { + border-bottom: none !important; + margin-bottom: 0 !important; + border-top-right-radius: 0 !important; + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.DTFC_RightHeadWrapper table thead tr:last-child th:first-child, +div.DTFC_RightHeadWrapper table thead tr:last-child td:first-child, +div.DTFC_LeftHeadWrapper table thead tr:last-child th:first-child, +div.DTFC_LeftHeadWrapper table thead tr:last-child td:first-child { + border-bottom-left-radius: 0 !important; + border-bottom-right-radius: 0 !important; +} + +div.DTFC_RightBodyWrapper table, +div.DTFC_LeftBodyWrapper table { + border-top: none; + margin: 0 !important; +} + +div.DTFC_RightBodyWrapper tbody tr:first-child th, +div.DTFC_RightBodyWrapper tbody tr:first-child td, +div.DTFC_LeftBodyWrapper tbody tr:first-child th, +div.DTFC_LeftBodyWrapper tbody tr:first-child td { + border-top: none; +} + +div.DTFC_RightFootWrapper table, +div.DTFC_LeftFootWrapper table { + border-top: none; + margin-top: 0 !important; +} + + +/* + * FixedHeader styles + */ +div.FixedHeader_Cloned table { + margin: 0 !important +} + diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js new file mode 100644 index 0000000000000..f0d09b9d52668 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js @@ -0,0 +1,8 @@ +/*! + DataTables Bootstrap 3 integration + ©2011-2014 SpryMedia Ltd - datatables.net/license +*/ +(function(){var f=function(c,b){c.extend(!0,b.defaults,{dom:"<'row'<'col-sm-6'l><'col-sm-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-6'i><'col-sm-6'p>>",renderer:"bootstrap"});c.extend(b.ext.classes,{sWrapper:"dataTables_wrapper form-inline dt-bootstrap",sFilterInput:"form-control input-sm",sLengthSelect:"form-control input-sm"});b.ext.renderer.pageButton.bootstrap=function(g,f,p,k,h,l){var q=new b.Api(g),r=g.oClasses,i=g.oLanguage.oPaginate,d,e,o=function(b,f){var j,m,n,a,k=function(a){a.preventDefault(); +c(a.currentTarget).hasClass("disabled")||q.page(a.data.action).draw(!1)};j=0;for(m=f.length;j",{"class":r.sPageButton+" "+ +e,"aria-controls":g.sTableId,tabindex:g.iTabIndex,id:0===p&&"string"===typeof a?g.sTableId+"_"+a:null}).append(c("",{href:"#"}).html(d)).appendTo(b),g.oApi._fnBindAction(n,{action:a},k))}};o(c(f).empty().html('
      ').children("ul"),k)};b.TableTools&&(c.extend(!0,b.TableTools.classes,{container:"DTTT btn-group",buttons:{normal:"btn btn-default",disabled:"disabled"},collection:{container:"DTTT_dropdown dropdown-menu",buttons:{normal:"",disabled:"disabled"}},print:{info:"DTTT_print_info"}, +select:{row:"active"}}),c.extend(!0,b.TableTools.DEFAULTS.oTags,{collection:{container:"ul",button:"li",liner:"a"}}))};"function"===typeof define&&define.amd?define(["jquery","datatables"],f):"object"===typeof exports?f(require("jquery"),require("datatables")):jQuery&&f(jQuery,jQuery.fn.dataTable)})(window,document); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js new file mode 100644 index 0000000000000..983c3a564fb1b --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js @@ -0,0 +1,224 @@ +/*! RowsGroup for DataTables v1.0.0 + * 2015 Alexey Shildyakov ashl1future@gmail.com + */ + +/** + * @summary RowsGroup + * @description Group rows by specified columns + * @version 1.0.0 + * @file dataTables.rowsGroup.js + * @author Alexey Shildyakov (ashl1future@gmail.com) + * @contact ashl1future@gmail.com + * @copyright Alexey Shildyakov + * + * License MIT - http://datatables.net/license/mit + * + * This feature plug-in for DataTables automatically merges columns cells + * based on it's values equality. It supports multi-column row grouping + * in according to the requested order with dependency from each previous + * requested columns. Now it supports ordering and searching. + * Please see the example.html for details. + * + * Rows grouping in DataTables can be enabled by using any one of the following + * options: + * + * * Setting the `rowsGroup` parameter in the DataTables initialisation + * to array which contains columns selectors + * (https://datatables.net/reference/type/column-selector) used for grouping. i.e. + * rowsGroup = [1, 'columnName:name', ] + * * Setting the `rowsGroup` parameter in the DataTables defaults + * (thus causing all tables to have this feature) - i.e. + * `$.fn.dataTable.defaults.RowsGroup = [0]`. + * * Creating a new instance: `new $.fn.dataTable.RowsGroup( table, columnsForGrouping );` + * where `table` is a DataTable's API instance and `columnsForGrouping` is the array + * described above. + * + * For more detailed information please see: + * + */ + +(function($){ + +ShowedDataSelectorModifier = { + order: 'current', + page: 'current', + search: 'applied', +} + +GroupedColumnsOrderDir = 'desc'; // change + + +/* + * columnsForGrouping: array of DTAPI:cell-selector for columns for which rows grouping is applied + */ +var RowsGroup = function ( dt, columnsForGrouping ) +{ + this.table = dt.table(); + this.columnsForGrouping = columnsForGrouping; + // set to True when new reorder is applied by RowsGroup to prevent order() looping + this.orderOverrideNow = false; + this.order = [] + + self = this; + $(document).on('order.dt', function ( e, settings) { + if (!self.orderOverrideNow) { + self._updateOrderAndDraw() + } + self.orderOverrideNow = false; + }) + + $(document).on('draw.dt', function ( e, settings) { + self._mergeCells() + }) + + this._updateOrderAndDraw(); +}; + + +RowsGroup.prototype = { + _getOrderWithGroupColumns: function (order, groupedColumnsOrderDir) + { + if (groupedColumnsOrderDir === undefined) + groupedColumnsOrderDir = GroupedColumnsOrderDir + + var self = this; + var groupedColumnsIndexes = this.columnsForGrouping.map(function(columnSelector){ + return self.table.column(columnSelector).index() + }) + var groupedColumnsKnownOrder = order.filter(function(columnOrder){ + return groupedColumnsIndexes.indexOf(columnOrder[0]) >= 0 + }) + var nongroupedColumnsOrder = order.filter(function(columnOrder){ + return groupedColumnsIndexes.indexOf(columnOrder[0]) < 0 + }) + var groupedColumnsKnownOrderIndexes = groupedColumnsKnownOrder.map(function(columnOrder){ + return columnOrder[0] + }) + var groupedColumnsOrder = groupedColumnsIndexes.map(function(iColumn){ + var iInOrderIndexes = groupedColumnsKnownOrderIndexes.indexOf(iColumn) + if (iInOrderIndexes >= 0) + return [iColumn, groupedColumnsKnownOrder[iInOrderIndexes][1]] + else + return [iColumn, groupedColumnsOrderDir] + }) + + groupedColumnsOrder.push.apply(groupedColumnsOrder, nongroupedColumnsOrder) + return groupedColumnsOrder; + }, + + // Workaround: the DT reset ordering to 'desc' from multi-ordering if user order on one column (without shift) + // but because we always has multi-ordering due to grouped rows this happens every time + _getInjectedMonoSelectWorkaround: function(order) + { + if (order.length === 1) { + // got mono order - workaround here + var orderingColumn = order[0][0] + var previousOrder = this.order.map(function(val){ + return val[0] + }) + var iColumn = previousOrder.indexOf(orderingColumn); + if (iColumn >= 0) { + // assume change the direction, because we already has that in previous order + return [[orderingColumn, this._toogleDirection(this.order[iColumn][1])]] + } // else This is the new ordering column. Proceed as is. + } // else got multi order - work normal + return order; + }, + + _mergeCells: function() + { + var columnsIndexes = this.table.columns(this.columnsForGrouping, ShowedDataSelectorModifier).indexes().toArray() + var showedRowsCount = this.table.rows(ShowedDataSelectorModifier)[0].length + this._mergeColumn(0, showedRowsCount - 1, columnsIndexes) + }, + + // the index is relative to the showed data + // (selector-modifier = {order: 'current', page: 'current', search: 'applied'}) index + _mergeColumn: function(iStartRow, iFinishRow, columnsIndexes) + { + var columnsIndexesCopy = columnsIndexes.slice() + currentColumn = columnsIndexesCopy.shift() + currentColumn = this.table.column(currentColumn, ShowedDataSelectorModifier) + + var columnNodes = currentColumn.nodes() + var columnValues = currentColumn.data() + + var newSequenceRow = iStartRow, + iRow; + for (iRow = iStartRow + 1; iRow <= iFinishRow; ++iRow) { + + if (columnValues[iRow] === columnValues[newSequenceRow]) { + $(columnNodes[iRow]).hide() + } else { + $(columnNodes[newSequenceRow]).show() + $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1) - newSequenceRow + 1) + + if (columnsIndexesCopy.length > 0) + this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy) + + newSequenceRow = iRow; + } + + } + $(columnNodes[newSequenceRow]).show() + $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1)- newSequenceRow + 1) + if (columnsIndexesCopy.length > 0) + this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy) + }, + + _toogleDirection: function(dir) + { + return dir == 'asc'? 'desc': 'asc'; + }, + + _updateOrderAndDraw: function() + { + this.orderOverrideNow = true; + + var currentOrder = this.table.order(); + currentOrder = this._getInjectedMonoSelectWorkaround(currentOrder); + this.order = this._getOrderWithGroupColumns(currentOrder) + // this.table.order($.extend(true, Array(), this.order)) // disable this line in order to support sorting on non-grouped columns + this.table.draw(false) + }, +}; + + +$.fn.dataTable.RowsGroup = RowsGroup; +$.fn.DataTable.RowsGroup = RowsGroup; + +// Automatic initialisation listener +$(document).on( 'init.dt', function ( e, settings ) { + if ( e.namespace !== 'dt' ) { + return; + } + + var api = new $.fn.dataTable.Api( settings ); + + if ( settings.oInit.rowsGroup || + $.fn.dataTable.defaults.rowsGroup ) + { + options = settings.oInit.rowsGroup? + settings.oInit.rowsGroup: + $.fn.dataTable.defaults.rowsGroup; + new RowsGroup( api, options ); + } +} ); + +}(jQuery)); + +/* + +TODO: Provide function which determines the all s and s with "rowspan" html-attribute is parent (groupped) for the specified or . To use in selections, editing or hover styles. + +TODO: Feature +Use saved order direction for grouped columns + Split the columns into grouped and ungrouped. + + user = grouped+ungrouped + grouped = grouped + saved = grouped+ungrouped + + For grouped uses following order: user -> saved (because 'saved' include 'grouped' after first initialisation). This should be done with saving order like for 'groupedColumns' + For ungrouped: uses only 'user' input ordering +*/ \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html new file mode 100644 index 0000000000000..66d111e439096 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -0,0 +1,81 @@ + + + diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js new file mode 100644 index 0000000000000..785abe45bc56e --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// this function works exactly the same as UIUtils.formatDuration +function formatDuration(milliseconds) { + if (milliseconds < 100) { + return milliseconds + " ms"; + } + var seconds = milliseconds * 1.0 / 1000; + if (seconds < 1) { + return seconds.toFixed(1) + " s"; + } + if (seconds < 60) { + return seconds.toFixed(0) + " s"; + } + var minutes = seconds / 60; + if (minutes < 10) { + return minutes.toFixed(1) + " min"; + } else if (minutes < 60) { + return minutes.toFixed(0) + " min"; + } + var hours = minutes / 60; + return hours.toFixed(1) + " h"; +} + +function formatDate(date) { + return date.split(".")[0].replace("T", " "); +} + +function getParameterByName(name, searchString) { + var regex = new RegExp("[\\?&]" + name + "=([^&#]*)"), + results = regex.exec(searchString); + return results === null ? "" : decodeURIComponent(results[1].replace(/\+/g, " ")); +} + +jQuery.extend( jQuery.fn.dataTableExt.oSort, { + "title-numeric-pre": function ( a ) { + var x = a.match(/title="*(-?[0-9\.]+)/)[1]; + return parseFloat( x ); + }, + + "title-numeric-asc": function ( a, b ) { + return ((a < b) ? -1 : ((a > b) ? 1 : 0)); + }, + + "title-numeric-desc": function ( a, b ) { + return ((a < b) ? 1 : ((a > b) ? -1 : 0)); + } +} ); + +$(document).ajaxStop($.unblockUI); +$(document).ajaxStart(function(){ + $.blockUI({ message: '

      Loading history summary...

      '}); +}); + +$(document).ready(function() { + $.extend( $.fn.dataTable.defaults, { + stateSave: true, + lengthMenu: [[20,40,60,100,-1], [20, 40, 60, 100, "All"]], + pageLength: 20 + }); + + historySummary = $("#history-summary"); + searchString = historySummary["context"]["location"]["search"]; + requestedIncomplete = getParameterByName("showIncomplete", searchString); + requestedIncomplete = (requestedIncomplete == "true" ? true : false); + + $.getJSON("/api/v1/applications", function(response,status,jqXHR) { + var array = []; + var hasMultipleAttempts = false; + for (i in response) { + var app = response[i]; + if (app["attempts"][0]["completed"] == requestedIncomplete) { + continue; // if we want to show for Incomplete, we skip the completed apps; otherwise skip incomplete ones. + } + var id = app["id"]; + var name = app["name"]; + if (app["attempts"].length > 1) { + hasMultipleAttempts = true; + } + var num = app["attempts"].length; + for (j in app["attempts"]) { + var attempt = app["attempts"][j]; + attempt["startTime"] = formatDate(attempt["startTime"]); + attempt["endTime"] = formatDate(attempt["endTime"]); + attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; + array.push(app_clone); + } + } + + var data = {"applications": array} + $.get("/static/historypage-template.html", function(template) { + historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); + var selector = "#history-summary-table"; + var conf = { + "columns": [ + {name: 'first'}, + {name: 'second'}, + {name: 'third'}, + {name: 'fourth'}, + {name: 'fifth'}, + {name: 'sixth', type: "title-numeric"}, + {name: 'seventh'}, + {name: 'eighth'}, + ], + }; + + var rowGroupConf = { + "rowsGroup": [ + 'first:name', + 'second:name' + ], + }; + + if (hasMultipleAttempts) { + jQuery.extend(conf, rowGroupConf); + var rowGroupCells = document.getElementsByClassName("rowGroupColumn"); + for (i = 0; i < rowGroupCells.length; i++) { + rowGroupCells[i].style='background-color: #ffffff'; + } + } + + if (!hasMultipleAttempts) { + var attemptIDCells = document.getElementsByClassName("attemptIDSpan"); + for (i = 0; i < attemptIDCells.length; i++) { + attemptIDCells[i].style.display='none'; + } + } + + var durationCells = document.getElementsByClassName("durationClass"); + for (i = 0; i < durationCells.length; i++) { + var timeInMilliseconds = parseInt(durationCells[i].title); + durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + } + + if ($(selector.concat(" tr")).length < 20) { + $.extend(conf, {paging: false}); + } + + $(selector).DataTable(conf); + $('#hisotry-summary [data-toggle="tooltip"]').tooltip(); + }); + }); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js new file mode 100644 index 0000000000000..1e84b3ec21c4d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js @@ -0,0 +1,6 @@ +/* +* jQuery BlockUI; v20131009 +* http://jquery.malsup.com/block/ +* Copyright (c) 2013 M. Alsup; Dual licensed: MIT/GPL +*/ +(function(){"use strict";function e(e){function o(o,i){var s,h,k=o==window,v=i&&void 0!==i.message?i.message:void 0;if(i=e.extend({},e.blockUI.defaults,i||{}),!i.ignoreIfBlocked||!e(o).data("blockUI.isBlocked")){if(i.overlayCSS=e.extend({},e.blockUI.defaults.overlayCSS,i.overlayCSS||{}),s=e.extend({},e.blockUI.defaults.css,i.css||{}),i.onOverlayClick&&(i.overlayCSS.cursor="pointer"),h=e.extend({},e.blockUI.defaults.themedCSS,i.themedCSS||{}),v=void 0===v?i.message:v,k&&b&&t(window,{fadeOut:0}),v&&"string"!=typeof v&&(v.parentNode||v.jquery)){var y=v.jquery?v[0]:v,m={};e(o).data("blockUI.history",m),m.el=y,m.parent=y.parentNode,m.display=y.style.display,m.position=y.style.position,m.parent&&m.parent.removeChild(y)}e(o).data("blockUI.onUnblock",i.onUnblock);var g,I,w,U,x=i.baseZ;g=r||i.forceIframe?e(''):e(''),I=i.theme?e(''):e(''),i.theme&&k?(U='"):i.theme?(U='"):U=k?'':'',w=e(U),v&&(i.theme?(w.css(h),w.addClass("ui-widget-content")):w.css(s)),i.theme||I.css(i.overlayCSS),I.css("position",k?"fixed":"absolute"),(r||i.forceIframe)&&g.css("opacity",0);var C=[g,I,w],S=k?e("body"):e(o);e.each(C,function(){this.appendTo(S)}),i.theme&&i.draggable&&e.fn.draggable&&w.draggable({handle:".ui-dialog-titlebar",cancel:"li"});var O=f&&(!e.support.boxModel||e("object,embed",k?null:o).length>0);if(u||O){if(k&&i.allowBodyStretch&&e.support.boxModel&&e("html,body").css("height","100%"),(u||!e.support.boxModel)&&!k)var E=d(o,"borderTopWidth"),T=d(o,"borderLeftWidth"),M=E?"(0 - "+E+")":0,B=T?"(0 - "+T+")":0;e.each(C,function(e,o){var t=o[0].style;if(t.position="absolute",2>e)k?t.setExpression("height","Math.max(document.body.scrollHeight, document.body.offsetHeight) - (jQuery.support.boxModel?0:"+i.quirksmodeOffsetHack+') + "px"'):t.setExpression("height",'this.parentNode.offsetHeight + "px"'),k?t.setExpression("width",'jQuery.support.boxModel && document.documentElement.clientWidth || document.body.clientWidth + "px"'):t.setExpression("width",'this.parentNode.offsetWidth + "px"'),B&&t.setExpression("left",B),M&&t.setExpression("top",M);else if(i.centerY)k&&t.setExpression("top",'(document.documentElement.clientHeight || document.body.clientHeight) / 2 - (this.offsetHeight / 2) + (blah = document.documentElement.scrollTop ? document.documentElement.scrollTop : document.body.scrollTop) + "px"'),t.marginTop=0;else if(!i.centerY&&k){var n=i.css&&i.css.top?parseInt(i.css.top,10):0,s="((document.documentElement.scrollTop ? document.documentElement.scrollTop : document.body.scrollTop) + "+n+') + "px"';t.setExpression("top",s)}})}if(v&&(i.theme?w.find(".ui-widget-content").append(v):w.append(v),(v.jquery||v.nodeType)&&e(v).show()),(r||i.forceIframe)&&i.showOverlay&&g.show(),i.fadeIn){var j=i.onBlock?i.onBlock:c,H=i.showOverlay&&!v?j:c,z=v?j:c;i.showOverlay&&I._fadeIn(i.fadeIn,H),v&&w._fadeIn(i.fadeIn,z)}else i.showOverlay&&I.show(),v&&w.show(),i.onBlock&&i.onBlock();if(n(1,o,i),k?(b=w[0],p=e(i.focusableElements,b),i.focusInput&&setTimeout(l,20)):a(w[0],i.centerX,i.centerY),i.timeout){var W=setTimeout(function(){k?e.unblockUI(i):e(o).unblock(i)},i.timeout);e(o).data("blockUI.timeout",W)}}}function t(o,t){var s,l=o==window,a=e(o),d=a.data("blockUI.history"),c=a.data("blockUI.timeout");c&&(clearTimeout(c),a.removeData("blockUI.timeout")),t=e.extend({},e.blockUI.defaults,t||{}),n(0,o,t),null===t.onUnblock&&(t.onUnblock=a.data("blockUI.onUnblock"),a.removeData("blockUI.onUnblock"));var r;r=l?e("body").children().filter(".blockUI").add("body > .blockUI"):a.find(">.blockUI"),t.cursorReset&&(r.length>1&&(r[1].style.cursor=t.cursorReset),r.length>2&&(r[2].style.cursor=t.cursorReset)),l&&(b=p=null),t.fadeOut?(s=r.length,r.stop().fadeOut(t.fadeOut,function(){0===--s&&i(r,d,t,o)})):i(r,d,t,o)}function i(o,t,i,n){var s=e(n);if(!s.data("blockUI.isBlocked")){o.each(function(){this.parentNode&&this.parentNode.removeChild(this)}),t&&t.el&&(t.el.style.display=t.display,t.el.style.position=t.position,t.parent&&t.parent.appendChild(t.el),s.removeData("blockUI.history")),s.data("blockUI.static")&&s.css("position","static"),"function"==typeof i.onUnblock&&i.onUnblock(n,i);var l=e(document.body),a=l.width(),d=l[0].style.width;l.width(a-1).width(a),l[0].style.width=d}}function n(o,t,i){var n=t==window,l=e(t);if((o||(!n||b)&&(n||l.data("blockUI.isBlocked")))&&(l.data("blockUI.isBlocked",o),n&&i.bindEvents&&(!o||i.showOverlay))){var a="mousedown mouseup keydown keypress keyup touchstart touchend touchmove";o?e(document).bind(a,i,s):e(document).unbind(a,s)}}function s(o){if("keydown"===o.type&&o.keyCode&&9==o.keyCode&&b&&o.data.constrainTabKey){var t=p,i=!o.shiftKey&&o.target===t[t.length-1],n=o.shiftKey&&o.target===t[0];if(i||n)return setTimeout(function(){l(n)},10),!1}var s=o.data,a=e(o.target);return a.hasClass("blockOverlay")&&s.onOverlayClick&&s.onOverlayClick(o),a.parents("div."+s.blockMsgClass).length>0?!0:0===a.parents().children().filter("div.blockUI").length}function l(e){if(p){var o=p[e===!0?p.length-1:0];o&&o.focus()}}function a(e,o,t){var i=e.parentNode,n=e.style,s=(i.offsetWidth-e.offsetWidth)/2-d(i,"borderLeftWidth"),l=(i.offsetHeight-e.offsetHeight)/2-d(i,"borderTopWidth");o&&(n.left=s>0?s+"px":"0"),t&&(n.top=l>0?l+"px":"0")}function d(o,t){return parseInt(e.css(o,t),10)||0}e.fn._fadeIn=e.fn.fadeIn;var c=e.noop||function(){},r=/MSIE/.test(navigator.userAgent),u=/MSIE 6.0/.test(navigator.userAgent)&&!/MSIE 8.0/.test(navigator.userAgent);document.documentMode||0;var f=e.isFunction(document.createElement("div").style.setExpression);e.blockUI=function(e){o(window,e)},e.unblockUI=function(e){t(window,e)},e.growlUI=function(o,t,i,n){var s=e('
      ');o&&s.append("

      "+o+"

      "),t&&s.append("

      "+t+"

      "),void 0===i&&(i=3e3);var l=function(o){o=o||{},e.blockUI({message:s,fadeIn:o.fadeIn!==void 0?o.fadeIn:700,fadeOut:o.fadeOut!==void 0?o.fadeOut:1e3,timeout:o.timeout!==void 0?o.timeout:i,centerY:!1,showOverlay:!1,onUnblock:n,css:e.blockUI.defaults.growlCSS})};l(),s.css("opacity"),s.mouseover(function(){l({fadeIn:0,timeout:3e4});var o=e(".blockMsg");o.stop(),o.fadeTo(300,1)}).mouseout(function(){e(".blockMsg").fadeOut(1e3)})},e.fn.block=function(t){if(this[0]===window)return e.blockUI(t),this;var i=e.extend({},e.blockUI.defaults,t||{});return this.each(function(){var o=e(this);i.ignoreIfBlocked&&o.data("blockUI.isBlocked")||o.unblock({fadeOut:0})}),this.each(function(){"static"==e.css(this,"position")&&(this.style.position="relative",e(this).data("blockUI.static",!0)),this.style.zoom=1,o(this,t)})},e.fn.unblock=function(o){return this[0]===window?(e.unblockUI(o),this):this.each(function(){t(this,o)})},e.blockUI.version=2.66,e.blockUI.defaults={message:"

      Please wait...

      ",title:null,draggable:!0,theme:!1,css:{padding:0,margin:0,width:"30%",top:"40%",left:"35%",textAlign:"center",color:"#000",border:"3px solid #aaa",backgroundColor:"#fff",cursor:"wait"},themedCSS:{width:"30%",top:"40%",left:"35%"},overlayCSS:{backgroundColor:"#000",opacity:.6,cursor:"wait"},cursorReset:"default",growlCSS:{width:"350px",top:"10px",left:"",right:"10px",border:"none",padding:"5px",opacity:.6,cursor:"default",color:"#fff",backgroundColor:"#000","-webkit-border-radius":"10px","-moz-border-radius":"10px","border-radius":"10px"},iframeSrc:/^https/i.test(window.location.href||"")?"javascript:false":"about:blank",forceIframe:!1,baseZ:1e3,centerX:!0,centerY:!0,allowBodyStretch:!0,bindEvents:!0,constrainTabKey:!0,fadeIn:200,fadeOut:400,timeout:0,showOverlay:!0,focusInput:!0,focusableElements:":input:enabled:visible",onBlock:null,onUnblock:null,onOverlayClick:null,quirksmodeOffsetHack:4,blockMsgClass:"blockMsg",ignoreIfBlocked:!1};var b=null,p=[]}"function"==typeof define&&define.amd&&define.amd.jQuery?define(["jquery"],e):e(jQuery)})(); \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js new file mode 100644 index 0000000000000..bd2dacb4eeebd --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js @@ -0,0 +1,18 @@ +/** + * Copyright (c) 2005 - 2010, James Auldridge + * All rights reserved. + * + * Licensed under the BSD, MIT, and GPL (your choice!) Licenses: + * http://code.google.com/p/cookies/wiki/License + * + */ +var jaaulde=window.jaaulde||{};jaaulde.utils=jaaulde.utils||{};jaaulde.utils.cookies=(function(){var resolveOptions,assembleOptionsString,parseCookies,constructor,defaultOptions={expiresAt:null,path:'/',domain:null,secure:false};resolveOptions=function(options){var returnValue,expireDate;if(typeof options!=='object'||options===null){returnValue=defaultOptions;}else +{returnValue={expiresAt:defaultOptions.expiresAt,path:defaultOptions.path,domain:defaultOptions.domain,secure:defaultOptions.secure};if(typeof options.expiresAt==='object'&&options.expiresAt instanceof Date){returnValue.expiresAt=options.expiresAt;}else if(typeof options.hoursToLive==='number'&&options.hoursToLive!==0){expireDate=new Date();expireDate.setTime(expireDate.getTime()+(options.hoursToLive*60*60*1000));returnValue.expiresAt=expireDate;}if(typeof options.path==='string'&&options.path!==''){returnValue.path=options.path;}if(typeof options.domain==='string'&&options.domain!==''){returnValue.domain=options.domain;}if(options.secure===true){returnValue.secure=options.secure;}}return returnValue;};assembleOptionsString=function(options){options=resolveOptions(options);return((typeof options.expiresAt==='object'&&options.expiresAt instanceof Date?'; expires='+options.expiresAt.toGMTString():'')+'; path='+options.path+(typeof options.domain==='string'?'; domain='+options.domain:'')+(options.secure===true?'; secure':''));};parseCookies=function(){var cookies={},i,pair,name,value,separated=document.cookie.split(';'),unparsedValue;for(i=0;i.sorting_1,table.dataTable.order-column tbody tr>.sorting_2,table.dataTable.order-column tbody tr>.sorting_3,table.dataTable.display tbody tr>.sorting_1,table.dataTable.display tbody tr>.sorting_2,table.dataTable.display tbody tr>.sorting_3{background-color:#f9f9f9}table.dataTable.order-column tbody tr.selected>.sorting_1,table.dataTable.order-column tbody tr.selected>.sorting_2,table.dataTable.order-column tbody tr.selected>.sorting_3,table.dataTable.display tbody tr.selected>.sorting_1,table.dataTable.display tbody tr.selected>.sorting_2,table.dataTable.display tbody tr.selected>.sorting_3{background-color:#acbad4}table.dataTable.display tbody tr.odd>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd>.sorting_1{background-color:#f1f1f1}table.dataTable.display tbody tr.odd>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd>.sorting_2{background-color:#f3f3f3}table.dataTable.display tbody tr.odd>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd>.sorting_3{background-color:#f5f5f5}table.dataTable.display tbody tr.odd.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_1{background-color:#a6b3cd}table.dataTable.display tbody tr.odd.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_2{background-color:#a7b5ce}table.dataTable.display tbody tr.odd.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_3{background-color:#a9b6d0}table.dataTable.display tbody tr.even>.sorting_1,table.dataTable.order-column.stripe tbody tr.even>.sorting_1{background-color:#f9f9f9}table.dataTable.display tbody tr.even>.sorting_2,table.dataTable.order-column.stripe tbody tr.even>.sorting_2{background-color:#fbfbfb}table.dataTable.display tbody tr.even>.sorting_3,table.dataTable.order-column.stripe tbody tr.even>.sorting_3{background-color:#fdfdfd}table.dataTable.display tbody tr.even.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_1{background-color:#acbad4}table.dataTable.display tbody tr.even.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_2{background-color:#adbbd6}table.dataTable.display tbody tr.even.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_3{background-color:#afbdd8}table.dataTable.display tbody tr:hover>.sorting_1,table.dataTable.display tbody tr.odd:hover>.sorting_1,table.dataTable.display tbody tr.even:hover>.sorting_1,table.dataTable.order-column.hover tbody tr:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_1{background-color:#eaeaea}table.dataTable.display tbody tr:hover>.sorting_2,table.dataTable.display tbody tr.odd:hover>.sorting_2,table.dataTable.display tbody tr.even:hover>.sorting_2,table.dataTable.order-column.hover tbody tr:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_2{background-color:#ebebeb}table.dataTable.display tbody tr:hover>.sorting_3,table.dataTable.display tbody tr.odd:hover>.sorting_3,table.dataTable.display tbody tr.even:hover>.sorting_3,table.dataTable.order-column.hover tbody tr:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_3{background-color:#eee}table.dataTable.display tbody tr:hover.selected>.sorting_1,table.dataTable.display tbody tr.odd:hover.selected>.sorting_1,table.dataTable.display tbody tr.even:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_1{background-color:#a1aec7}table.dataTable.display tbody tr:hover.selected>.sorting_2,table.dataTable.display tbody tr.odd:hover.selected>.sorting_2,table.dataTable.display tbody tr.even:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_2{background-color:#a2afc8}table.dataTable.display tbody tr:hover.selected>.sorting_3,table.dataTable.display tbody tr.odd:hover.selected>.sorting_3,table.dataTable.display tbody tr.even:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_3{background-color:#a4b2cb}table.dataTable.no-footer{border-bottom:1px solid #111}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.compact thead th,table.dataTable.compact thead td{padding:5px 9px}table.dataTable.compact tfoot th,table.dataTable.compact tfoot td{padding:5px 9px 3px 9px}table.dataTable.compact tbody th,table.dataTable.compact tbody td{padding:4px 5px}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable,table.dataTable th,table.dataTable td{-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box}.dataTables_wrapper{position:relative;clear:both;*zoom:1;zoom:1}.dataTables_wrapper .dataTables_length{float:left}.dataTables_wrapper .dataTables_filter{float:right;text-align:right}.dataTables_wrapper .dataTables_filter input{margin-left:0.5em}.dataTables_wrapper .dataTables_info{clear:both;float:left;padding-top:0.755em}.dataTables_wrapper .dataTables_paginate{float:right;text-align:right;padding-top:0.25em}.dataTables_wrapper .dataTables_paginate .paginate_button{box-sizing:border-box;display:inline-block;min-width:1.5em;padding:0.5em 1em;margin-left:2px;text-align:center;text-decoration:none !important;cursor:pointer;*cursor:hand;color:#333 !important;border:1px solid transparent}.dataTables_wrapper .dataTables_paginate .paginate_button.current,.dataTables_wrapper .dataTables_paginate .paginate_button.current:hover{color:#333 !important;border:1px solid #cacaca;background-color:#fff;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #fff), color-stop(100%, #dcdcdc));background:-webkit-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-moz-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-ms-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-o-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:linear-gradient(to bottom, #fff 0%, #dcdcdc 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button.disabled,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:hover,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:active{cursor:default;color:#666 !important;border:1px solid transparent;background:transparent;box-shadow:none}.dataTables_wrapper .dataTables_paginate .paginate_button:hover{color:white !important;border:1px solid #111;background-color:#585858;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #585858), color-stop(100%, #111));background:-webkit-linear-gradient(top, #585858 0%, #111 100%);background:-moz-linear-gradient(top, #585858 0%, #111 100%);background:-ms-linear-gradient(top, #585858 0%, #111 100%);background:-o-linear-gradient(top, #585858 0%, #111 100%);background:linear-gradient(to bottom, #585858 0%, #111 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button:active{outline:none;background-color:#2b2b2b;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #2b2b2b), color-stop(100%, #0c0c0c));background:-webkit-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-moz-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-ms-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-o-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:linear-gradient(to bottom, #2b2b2b 0%, #0c0c0c 100%);box-shadow:inset 0 0 3px #111}.dataTables_wrapper .dataTables_processing{position:absolute;top:50%;left:50%;width:100%;height:40px;margin-left:-50%;margin-top:-25px;padding-top:20px;text-align:center;font-size:1.2em;background-color:white;background:-webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0)));background:-webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%)}.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter,.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_processing,.dataTables_wrapper .dataTables_paginate{color:#333}.dataTables_wrapper .dataTables_scroll{clear:both}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody{*margin-top:-1px;-webkit-overflow-scrolling:touch}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody th>div.dataTables_sizing,.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody td>div.dataTables_sizing{height:0;overflow:hidden;margin:0 !important;padding:0 !important}.dataTables_wrapper.no-footer .dataTables_scrollBody{border-bottom:1px solid #111}.dataTables_wrapper.no-footer div.dataTables_scrollHead table,.dataTables_wrapper.no-footer div.dataTables_scrollBody table{border-bottom:none}.dataTables_wrapper:after{visibility:hidden;display:block;content:"";clear:both;height:0}@media screen and (max-width: 767px){.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_paginate{float:none;text-align:center}.dataTables_wrapper .dataTables_paginate{margin-top:0.5em}}@media screen and (max-width: 640px){.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter{float:none;text-align:center}.dataTables_wrapper .dataTables_filter{margin-top:0.5em}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js new file mode 100644 index 0000000000000..8885017c35d0d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js @@ -0,0 +1,157 @@ +/*! DataTables 1.10.4 + * ©2008-2014 SpryMedia Ltd - datatables.net/license + */ +(function(Da,P,l){var O=function(g){function V(a){var b,c,e={};g.each(a,function(d){if((b=d.match(/^([^A-Z]+?)([A-Z])/))&&-1!=="a aa ai ao as b fn i m o s ".indexOf(b[1]+" "))c=d.replace(b[0],b[2].toLowerCase()),e[c]=d,"o"===b[1]&&V(a[d])});a._hungarianMap=e}function G(a,b,c){a._hungarianMap||V(a);var e;g.each(b,function(d){e=a._hungarianMap[d];if(e!==l&&(c||b[e]===l))"o"===e.charAt(0)?(b[e]||(b[e]={}),g.extend(!0,b[e],b[d]),G(a[e],b[e],c)):b[e]=b[d]})}function O(a){var b=p.defaults.oLanguage,c=a.sZeroRecords; +!a.sEmptyTable&&(c&&"No data available in table"===b.sEmptyTable)&&D(a,a,"sZeroRecords","sEmptyTable");!a.sLoadingRecords&&(c&&"Loading..."===b.sLoadingRecords)&&D(a,a,"sZeroRecords","sLoadingRecords");a.sInfoThousands&&(a.sThousands=a.sInfoThousands);(a=a.sDecimal)&&cb(a)}function db(a){z(a,"ordering","bSort");z(a,"orderMulti","bSortMulti");z(a,"orderClasses","bSortClasses");z(a,"orderCellsTop","bSortCellsTop");z(a,"order","aaSorting");z(a,"orderFixed","aaSortingFixed");z(a,"paging","bPaginate"); +z(a,"pagingType","sPaginationType");z(a,"pageLength","iDisplayLength");z(a,"searching","bFilter");if(a=a.aoSearchCols)for(var b=0,c=a.length;b").css({position:"absolute",top:0,left:0,height:1,width:1,overflow:"hidden"}).append(g("
      ").css({position:"absolute",top:1,left:1,width:100, +overflow:"scroll"}).append(g('
      ').css({width:"100%",height:10}))).appendTo("body"),c=b.find(".test");a.bScrollOversize=100===c[0].offsetWidth;a.bScrollbarLeft=1!==c.offset().left;b.remove()}function gb(a,b,c,e,d,f){var h,i=!1;c!==l&&(h=c,i=!0);for(;e!==d;)a.hasOwnProperty(e)&&(h=i?b(h,a[e],e,a):a[e],i=!0,e+=f);return h}function Ea(a,b){var c=p.defaults.column,e=a.aoColumns.length,c=g.extend({},p.models.oColumn,c,{nTh:b?b:P.createElement("th"),sTitle:c.sTitle?c.sTitle:b?b.innerHTML: +"",aDataSort:c.aDataSort?c.aDataSort:[e],mData:c.mData?c.mData:e,idx:e});a.aoColumns.push(c);c=a.aoPreSearchCols;c[e]=g.extend({},p.models.oSearch,c[e]);ja(a,e,null)}function ja(a,b,c){var b=a.aoColumns[b],e=a.oClasses,d=g(b.nTh);if(!b.sWidthOrig){b.sWidthOrig=d.attr("width")||null;var f=(d.attr("style")||"").match(/width:\s*(\d+[pxem%]+)/);f&&(b.sWidthOrig=f[1])}c!==l&&null!==c&&(eb(c),G(p.defaults.column,c),c.mDataProp!==l&&!c.mData&&(c.mData=c.mDataProp),c.sType&&(b._sManualType=c.sType),c.className&& +!c.sClass&&(c.sClass=c.className),g.extend(b,c),D(b,c,"sWidth","sWidthOrig"),"number"===typeof c.iDataSort&&(b.aDataSort=[c.iDataSort]),D(b,c,"aDataSort"));var h=b.mData,i=W(h),j=b.mRender?W(b.mRender):null,c=function(a){return"string"===typeof a&&-1!==a.indexOf("@")};b._bAttrSrc=g.isPlainObject(h)&&(c(h.sort)||c(h.type)||c(h.filter));b.fnGetData=function(a,b,c){var e=i(a,b,l,c);return j&&b?j(e,b,a,c):e};b.fnSetData=function(a,b,c){return Q(h)(a,b,c)};"number"!==typeof h&&(a._rowReadObject=!0);a.oFeatures.bSort|| +(b.bSortable=!1,d.addClass(e.sSortableNone));a=-1!==g.inArray("asc",b.asSorting);c=-1!==g.inArray("desc",b.asSorting);!b.bSortable||!a&&!c?(b.sSortingClass=e.sSortableNone,b.sSortingClassJUI=""):a&&!c?(b.sSortingClass=e.sSortableAsc,b.sSortingClassJUI=e.sSortJUIAscAllowed):!a&&c?(b.sSortingClass=e.sSortableDesc,b.sSortingClassJUI=e.sSortJUIDescAllowed):(b.sSortingClass=e.sSortable,b.sSortingClassJUI=e.sSortJUI)}function X(a){if(!1!==a.oFeatures.bAutoWidth){var b=a.aoColumns;Fa(a);for(var c=0,e=b.length;c< +e;c++)b[c].nTh.style.width=b[c].sWidth}b=a.oScroll;(""!==b.sY||""!==b.sX)&&Y(a);u(a,null,"column-sizing",[a])}function ka(a,b){var c=Z(a,"bVisible");return"number"===typeof c[b]?c[b]:null}function $(a,b){var c=Z(a,"bVisible"),c=g.inArray(b,c);return-1!==c?c:null}function aa(a){return Z(a,"bVisible").length}function Z(a,b){var c=[];g.map(a.aoColumns,function(a,d){a[b]&&c.push(d)});return c}function Ga(a){var b=a.aoColumns,c=a.aoData,e=p.ext.type.detect,d,f,h,i,j,g,m,o,k;d=0;for(f=b.length;do[f])e(m.length+o[f],n);else if("string"===typeof o[f]){i=0;for(j=m.length;ib&&a[d]--; -1!=e&&c===l&& +a.splice(e,1)}function ca(a,b,c,e){var d=a.aoData[b],f,h=function(c,f){for(;c.childNodes.length;)c.removeChild(c.firstChild);c.innerHTML=v(a,b,f,"display")};if("dom"===c||(!c||"auto"===c)&&"dom"===d.src)d._aData=ma(a,d,e,e===l?l:d._aData).data;else{var i=d.anCells;if(i)if(e!==l)h(i[e],e);else{c=0;for(f=i.length;c").appendTo(h));b=0;for(c= +m.length;btr").attr("role","row");g(h).find(">tr>th, >tr>td").addClass(n.sHeaderTH);g(i).find(">tr>th, >tr>td").addClass(n.sFooterTH);if(null!==i){a=a.aoFooter[0];b=0;for(c=a.length;b=a.fnRecordsDisplay()?0:h,a.iInitDisplayStart=-1);var h=a._iDisplayStart,n=a.fnDisplayEnd();if(a.bDeferLoading)a.bDeferLoading= +!1,a.iDraw++,B(a,!1);else if(i){if(!a.bDestroying&&!jb(a))return}else a.iDraw++;if(0!==j.length){f=i?a.aoData.length:n;for(i=i?0:h;i",{"class":d? +e[0]:""}).append(g("",{valign:"top",colSpan:aa(a),"class":a.oClasses.sRowEmpty}).html(c))[0];u(a,"aoHeaderCallback","header",[g(a.nTHead).children("tr")[0],Ka(a),h,n,j]);u(a,"aoFooterCallback","footer",[g(a.nTFoot).children("tr")[0],Ka(a),h,n,j]);e=g(a.nTBody);e.children().detach();e.append(g(b));u(a,"aoDrawCallback","draw",[a]);a.bSorted=!1;a.bFiltered=!1;a.bDrawing=!1}}function M(a,b){var c=a.oFeatures,e=c.bFilter;c.bSort&&kb(a);e?fa(a,a.oPreviousSearch):a.aiDisplay=a.aiDisplayMaster.slice(); +!0!==b&&(a._iDisplayStart=0);a._drawHold=b;L(a);a._drawHold=!1}function lb(a){var b=a.oClasses,c=g(a.nTable),c=g("
      ").insertBefore(c),e=a.oFeatures,d=g("
      ",{id:a.sTableId+"_wrapper","class":b.sWrapper+(a.nTFoot?"":" "+b.sNoFooter)});a.nHolding=c[0];a.nTableWrapper=d[0];a.nTableReinsertBefore=a.nTable.nextSibling;for(var f=a.sDom.split(""),h,i,j,n,m,o,k=0;k")[0];n=f[k+1];if("'"==n||'"'==n){m="";for(o=2;f[k+o]!=n;)m+=f[k+o],o++;"H"==m?m=b.sJUIHeader: +"F"==m&&(m=b.sJUIFooter);-1!=m.indexOf(".")?(n=m.split("."),j.id=n[0].substr(1,n[0].length-1),j.className=n[1]):"#"==m.charAt(0)?j.id=m.substr(1,m.length-1):j.className=m;k+=o}d.append(j);d=g(j)}else if(">"==i)d=d.parent();else if("l"==i&&e.bPaginate&&e.bLengthChange)h=mb(a);else if("f"==i&&e.bFilter)h=nb(a);else if("r"==i&&e.bProcessing)h=ob(a);else if("t"==i)h=pb(a);else if("i"==i&&e.bInfo)h=qb(a);else if("p"==i&&e.bPaginate)h=rb(a);else if(0!==p.ext.feature.length){j=p.ext.feature;o=0;for(n=j.length;o< +n;o++)if(i==j[o].cFeature){h=j[o].fnInit(a);break}}h&&(j=a.aanFeatures,j[i]||(j[i]=[]),j[i].push(h),d.append(h))}c.replaceWith(d)}function da(a,b){var c=g(b).children("tr"),e,d,f,h,i,j,n,m,o,k;a.splice(0,a.length);f=0;for(j=c.length;f',i=e.sSearch,i=i.match(/_INPUT_/)?i.replace("_INPUT_",h):i+h,b=g("
      ",{id:!f.f?c+"_filter":null,"class":b.sFilter}).append(g("
      ").addClass(b.sLength);a.aanFeatures.l||(j[0].id=c+"_length");j.children().append(a.oLanguage.sLengthMenu.replace("_MENU_",d[0].outerHTML));g("select",j).val(a._iDisplayLength).bind("change.DT", +function(){Qa(a,g(this).val());L(a)});g(a.nTable).bind("length.dt.DT",function(b,c,f){a===c&&g("select",j).val(f)});return j[0]}function rb(a){var b=a.sPaginationType,c=p.ext.pager[b],e="function"===typeof c,d=function(a){L(a)},b=g("
      ").addClass(a.oClasses.sPaging+b)[0],f=a.aanFeatures;e||c.fnInit(a,b,d);f.p||(b.id=a.sTableId+"_paginate",a.aoDrawCallback.push({fn:function(a){if(e){var b=a._iDisplayStart,g=a._iDisplayLength,n=a.fnRecordsDisplay(),m=-1===g,b=m?0:Math.ceil(b/g),g=m?1:Math.ceil(n/ +g),n=c(b,g),o,m=0;for(o=f.p.length;mf&&(e=0)):"first"==b?e=0:"previous"==b?(e=0<=d?e-d:0,0>e&&(e=0)):"next"==b?e+d",{id:!a.aanFeatures.r?a.sTableId+"_processing":null,"class":a.oClasses.sProcessing}).html(a.oLanguage.sProcessing).insertBefore(a.nTable)[0]}function B(a,b){a.oFeatures.bProcessing&&g(a.aanFeatures.r).css("display",b?"block":"none");u(a,null,"processing",[a,b])}function pb(a){var b=g(a.nTable);b.attr("role","grid");var c=a.oScroll;if(""===c.sX&&""===c.sY)return a.nTable;var e=c.sX,d=c.sY,f=a.oClasses,h=b.children("caption"),i=h.length?h[0]._captionSide:null, +j=g(b[0].cloneNode(!1)),n=g(b[0].cloneNode(!1)),m=b.children("tfoot");c.sX&&"100%"===b.attr("width")&&b.removeAttr("width");m.length||(m=null);c=g("
      ",{"class":f.sScrollWrapper}).append(g("
      ",{"class":f.sScrollHead}).css({overflow:"hidden",position:"relative",border:0,width:e?!e?null:s(e):"100%"}).append(g("
      ",{"class":f.sScrollHeadInner}).css({"box-sizing":"content-box",width:c.sXInner||"100%"}).append(j.removeAttr("id").css("margin-left",0).append("top"===i?h:null).append(b.children("thead"))))).append(g("
      ", +{"class":f.sScrollBody}).css({overflow:"auto",height:!d?null:s(d),width:!e?null:s(e)}).append(b));m&&c.append(g("
      ",{"class":f.sScrollFoot}).css({overflow:"hidden",border:0,width:e?!e?null:s(e):"100%"}).append(g("
      ",{"class":f.sScrollFootInner}).append(n.removeAttr("id").css("margin-left",0).append("bottom"===i?h:null).append(b.children("tfoot")))));var b=c.children(),o=b[0],f=b[1],k=m?b[2]:null;e&&g(f).scroll(function(){var a=this.scrollLeft;o.scrollLeft=a;m&&(k.scrollLeft=a)});a.nScrollHead= +o;a.nScrollBody=f;a.nScrollFoot=k;a.aoDrawCallback.push({fn:Y,sName:"scrolling"});return c[0]}function Y(a){var b=a.oScroll,c=b.sX,e=b.sXInner,d=b.sY,f=b.iBarWidth,h=g(a.nScrollHead),i=h[0].style,j=h.children("div"),n=j[0].style,m=j.children("table"),j=a.nScrollBody,o=g(j),k=j.style,l=g(a.nScrollFoot).children("div"),p=l.children("table"),r=g(a.nTHead),q=g(a.nTable),t=q[0],N=t.style,J=a.nTFoot?g(a.nTFoot):null,u=a.oBrowser,w=u.bScrollOversize,y,v,x,K,z,A=[],B=[],C=[],D,E=function(a){a=a.style;a.paddingTop= +"0";a.paddingBottom="0";a.borderTopWidth="0";a.borderBottomWidth="0";a.height=0};q.children("thead, tfoot").remove();z=r.clone().prependTo(q);y=r.find("tr");x=z.find("tr");z.find("th, td").removeAttr("tabindex");J&&(K=J.clone().prependTo(q),v=J.find("tr"),K=K.find("tr"));c||(k.width="100%",h[0].style.width="100%");g.each(pa(a,z),function(b,c){D=ka(a,b);c.style.width=a.aoColumns[D].sWidth});J&&F(function(a){a.style.width=""},K);b.bCollapse&&""!==d&&(k.height=o[0].offsetHeight+r[0].offsetHeight+"px"); +h=q.outerWidth();if(""===c){if(N.width="100%",w&&(q.find("tbody").height()>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(q.outerWidth()-f)}else""!==e?N.width=s(e):h==o.width()&&o.height()h-f&&(N.width=s(h))):N.width=s(h);h=q.outerWidth();F(E,x);F(function(a){C.push(a.innerHTML);A.push(s(g(a).css("width")))},x);F(function(a,b){a.style.width=A[b]},y);g(x).height(0);J&&(F(E,K),F(function(a){B.push(s(g(a).css("width")))},K),F(function(a,b){a.style.width= +B[b]},v),g(K).height(0));F(function(a,b){a.innerHTML='
      '+C[b]+"
      ";a.style.width=A[b]},x);J&&F(function(a,b){a.innerHTML="";a.style.width=B[b]},K);if(q.outerWidth()j.offsetHeight||"scroll"==o.css("overflow-y")?h+f:h;if(w&&(j.scrollHeight>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(v-f);(""===c||""!==e)&&R(a,1,"Possible column misalignment",6)}else v="100%";k.width=s(v);i.width=s(v);J&&(a.nScrollFoot.style.width= +s(v));!d&&w&&(k.height=s(t.offsetHeight+f));d&&b.bCollapse&&(k.height=s(d),b=c&&t.offsetWidth>j.offsetWidth?f:0,t.offsetHeightj.clientHeight||"scroll"==o.css("overflow-y");u="padding"+(u.bScrollbarLeft?"Left":"Right");n[u]=m?f+"px":"0px";J&&(p[0].style.width=s(b),l[0].style.width=s(b),l[0].style[u]=m?f+"px":"0px");o.scroll();if((a.bSorted||a.bFiltered)&&!a._drawHold)j.scrollTop=0}function F(a, +b,c){for(var e=0,d=0,f=b.length,h,g;d"));i.find("tfoot th, tfoot td").css("width","");var p=i.find("tbody tr"),j=pa(a,i.find("thead")[0]);for(k=0;k").css("width",s(a)).appendTo(b||P.body),e=c[0].offsetWidth;c.remove();return e}function Eb(a,b){var c=a.oScroll;if(c.sX||c.sY)c=!c.sX?c.iBarWidth:0,b.style.width=s(g(b).outerWidth()-c)}function Db(a,b){var c=Fb(a,b);if(0>c)return null; +var e=a.aoData[c];return!e.nTr?g("").html(v(a,c,b,"display"))[0]:e.anCells[b]}function Fb(a,b){for(var c,e=-1,d=-1,f=0,h=a.aoData.length;fe&&(e=c.length,d=f);return d}function s(a){return null===a?"0px":"number"==typeof a?0>a?"0px":a+"px":a.match(/\d$/)?a+"px":a}function Gb(){if(!p.__scrollbarWidth){var a=g("

      ").css({width:"100%",height:200,padding:0})[0],b=g("

      ").css({position:"absolute",top:0,left:0,width:200,height:150,padding:0, +overflow:"hidden",visibility:"hidden"}).append(a).appendTo("body"),c=a.offsetWidth;b.css("overflow","scroll");a=a.offsetWidth;c===a&&(a=b[0].clientWidth);b.remove();p.__scrollbarWidth=c-a}return p.__scrollbarWidth}function T(a){var b,c,e=[],d=a.aoColumns,f,h,i,j;b=a.aaSortingFixed;c=g.isPlainObject(b);var n=[];f=function(a){a.length&&!g.isArray(a[0])?n.push(a):n.push.apply(n,a)};g.isArray(b)&&f(b);c&&b.pre&&f(b.pre);f(a.aaSorting);c&&b.post&&f(b.post);for(a=0;ad?1:0,0!==c)return"asc"===g.dir?c:-c;c=e[a];d=e[b];return cd?1:0}):j.sort(function(a,b){var c,h,g,i,j=n.length,l=f[a]._aSortData,p=f[b]._aSortData;for(g=0;gh?1:0})}a.bSorted=!0}function Ib(a){for(var b,c,e=a.aoColumns,d=T(a),a=a.oLanguage.oAria,f=0,h=e.length;f/g,"");var j=c.nTh;j.removeAttribute("aria-sort");c.bSortable&&(0d?d+1:3));d=0;for(f=e.length;dd?d+1:3))}a.aLastSort=e}function Hb(a,b){var c=a.aoColumns[b],e=p.ext.order[c.sSortDataType],d;e&&(d=e.call(a.oInstance,a,b,$(a,b)));for(var f,h=p.ext.type.order[c.sType+"-pre"],g=0,j=a.aoData.length;g< +j;g++)if(c=a.aoData[g],c._aSortData||(c._aSortData=[]),!c._aSortData[b]||e)f=e?d[g]:v(a,g,b,"sort"),c._aSortData[b]=h?h(f):f}function xa(a){if(a.oFeatures.bStateSave&&!a.bDestroying){var b={time:+new Date,start:a._iDisplayStart,length:a._iDisplayLength,order:g.extend(!0,[],a.aaSorting),search:yb(a.oPreviousSearch),columns:g.map(a.aoColumns,function(b,e){return{visible:b.bVisible,search:yb(a.aoPreSearchCols[e])}})};u(a,"aoStateSaveParams","stateSaveParams",[a,b]);a.oSavedState=b;a.fnStateSaveCallback.call(a.oInstance, +a,b)}}function Jb(a){var b,c,e=a.aoColumns;if(a.oFeatures.bStateSave){var d=a.fnStateLoadCallback.call(a.oInstance,a);if(d&&d.time&&(b=u(a,"aoStateLoadParams","stateLoadParams",[a,d]),-1===g.inArray(!1,b)&&(b=a.iStateDuration,!(0=e.length?[0,c[1]]:c)});g.extend(a.oPreviousSearch, +zb(d.search));b=0;for(c=d.columns.length;b=c&&(b=c-e);b-=b%e;if(-1===e||0>b)b=0;a._iDisplayStart=b}function Oa(a,b){var c=a.renderer,e=p.ext.renderer[b];return g.isPlainObject(c)&& +c[b]?e[c[b]]||e._:"string"===typeof c?e[c]||e._:e._}function A(a){return a.oFeatures.bServerSide?"ssp":a.ajax||a.sAjaxSource?"ajax":"dom"}function Va(a,b){var c=[],c=Lb.numbers_length,e=Math.floor(c/2);b<=c?c=U(0,b):a<=e?(c=U(0,c-2),c.push("ellipsis"),c.push(b-1)):(a>=b-1-e?c=U(b-(c-2),b):(c=U(a-1,a+2),c.push("ellipsis"),c.push(b-1)),c.splice(0,0,"ellipsis"),c.splice(0,0,0));c.DT_el="span";return c}function cb(a){g.each({num:function(b){return za(b,a)},"num-fmt":function(b){return za(b,a,Wa)},"html-num":function(b){return za(b, +a,Aa)},"html-num-fmt":function(b){return za(b,a,Aa,Wa)}},function(b,c){w.type.order[b+a+"-pre"]=c;b.match(/^html\-/)&&(w.type.search[b+a]=w.type.search.html)})}function Mb(a){return function(){var b=[ya(this[p.ext.iApiIndex])].concat(Array.prototype.slice.call(arguments));return p.ext.internal[a].apply(this,b)}}var p,w,q,r,t,Xa={},Nb=/[\r\n]/g,Aa=/<.*?>/g,$b=/^[\w\+\-]/,ac=/[\w\+\-]$/,Xb=RegExp("(\\/|\\.|\\*|\\+|\\?|\\||\\(|\\)|\\[|\\]|\\{|\\}|\\\\|\\$|\\^|\\-)","g"),Wa=/[',$\u00a3\u20ac\u00a5%\u2009\u202F]/g, +H=function(a){return!a||!0===a||"-"===a?!0:!1},Ob=function(a){var b=parseInt(a,10);return!isNaN(b)&&isFinite(a)?b:null},Pb=function(a,b){Xa[b]||(Xa[b]=RegExp(ua(b),"g"));return"string"===typeof a&&"."!==b?a.replace(/\./g,"").replace(Xa[b],"."):a},Ya=function(a,b,c){var e="string"===typeof a;b&&e&&(a=Pb(a,b));c&&e&&(a=a.replace(Wa,""));return H(a)||!isNaN(parseFloat(a))&&isFinite(a)},Qb=function(a,b,c){return H(a)?!0:!(H(a)||"string"===typeof a)?null:Ya(a.replace(Aa,""),b,c)?!0:null},C=function(a, +b,c){var e=[],d=0,f=a.length;if(c!==l)for(;d")[0],Yb=va.textContent!==l,Zb=/<.*?>/g;p=function(a){this.$=function(a,b){return this.api(!0).$(a,b)};this._=function(a,b){return this.api(!0).rows(a,b).data()};this.api=function(a){return a?new q(ya(this[w.iApiIndex])):new q(this)};this.fnAddData=function(a,b){var c=this.api(!0),e=g.isArray(a)&&(g.isArray(a[0])||g.isPlainObject(a[0]))? +c.rows.add(a):c.row.add(a);(b===l||b)&&c.draw();return e.flatten().toArray()};this.fnAdjustColumnSizing=function(a){var b=this.api(!0).columns.adjust(),c=b.settings()[0],e=c.oScroll;a===l||a?b.draw(!1):(""!==e.sX||""!==e.sY)&&Y(c)};this.fnClearTable=function(a){var b=this.api(!0).clear();(a===l||a)&&b.draw()};this.fnClose=function(a){this.api(!0).row(a).child.hide()};this.fnDeleteRow=function(a,b,c){var e=this.api(!0),a=e.rows(a),d=a.settings()[0],g=d.aoData[a[0][0]];a.remove();b&&b.call(this,d,g); +(c===l||c)&&e.draw();return g};this.fnDestroy=function(a){this.api(!0).destroy(a)};this.fnDraw=function(a){this.api(!0).draw(!a)};this.fnFilter=function(a,b,c,e,d,g){d=this.api(!0);null===b||b===l?d.search(a,c,e,g):d.column(b).search(a,c,e,g);d.draw()};this.fnGetData=function(a,b){var c=this.api(!0);if(a!==l){var e=a.nodeName?a.nodeName.toLowerCase():"";return b!==l||"td"==e||"th"==e?c.cell(a,b).data():c.row(a).data()||null}return c.data().toArray()};this.fnGetNodes=function(a){var b=this.api(!0); +return a!==l?b.row(a).node():b.rows().nodes().flatten().toArray()};this.fnGetPosition=function(a){var b=this.api(!0),c=a.nodeName.toUpperCase();return"TR"==c?b.row(a).index():"TD"==c||"TH"==c?(a=b.cell(a).index(),[a.row,a.columnVisible,a.column]):null};this.fnIsOpen=function(a){return this.api(!0).row(a).child.isShown()};this.fnOpen=function(a,b,c){return this.api(!0).row(a).child(b,c).show().child()[0]};this.fnPageChange=function(a,b){var c=this.api(!0).page(a);(b===l||b)&&c.draw(!1)};this.fnSetColumnVis= +function(a,b,c){a=this.api(!0).column(a).visible(b);(c===l||c)&&a.columns.adjust().draw()};this.fnSettings=function(){return ya(this[w.iApiIndex])};this.fnSort=function(a){this.api(!0).order(a).draw()};this.fnSortListener=function(a,b,c){this.api(!0).order.listener(a,b,c)};this.fnUpdate=function(a,b,c,e,d){var g=this.api(!0);c===l||null===c?g.row(b).data(a):g.cell(b,c).data(a);(d===l||d)&&g.columns.adjust();(e===l||e)&&g.draw();return 0};this.fnVersionCheck=w.fnVersionCheck;var b=this,c=a===l,e=this.length; +c&&(a={});this.oApi=this.internal=w.internal;for(var d in p.ext.internal)d&&(this[d]=Mb(d));this.each(function(){var d={},d=1t<"F"ip>'),k.renderer)? +g.isPlainObject(k.renderer)&&!k.renderer.header&&(k.renderer.header="jqueryui"):k.renderer="jqueryui":g.extend(j,p.ext.classes,d.oClasses);g(this).addClass(j.sTable);if(""!==k.oScroll.sX||""!==k.oScroll.sY)k.oScroll.iBarWidth=Gb();!0===k.oScroll.sX&&(k.oScroll.sX="100%");k.iInitDisplayStart===l&&(k.iInitDisplayStart=d.iDisplayStart,k._iDisplayStart=d.iDisplayStart);null!==d.iDeferLoading&&(k.bDeferLoading=!0,h=g.isArray(d.iDeferLoading),k._iRecordsDisplay=h?d.iDeferLoading[0]:d.iDeferLoading,k._iRecordsTotal= +h?d.iDeferLoading[1]:d.iDeferLoading);var r=k.oLanguage;g.extend(!0,r,d.oLanguage);""!==r.sUrl&&(g.ajax({dataType:"json",url:r.sUrl,success:function(a){O(a);G(m.oLanguage,a);g.extend(true,r,a);ga(k)},error:function(){ga(k)}}),n=!0);null===d.asStripeClasses&&(k.asStripeClasses=[j.sStripeOdd,j.sStripeEven]);var h=k.asStripeClasses,q=g("tbody tr:eq(0)",this);-1!==g.inArray(!0,g.map(h,function(a){return q.hasClass(a)}))&&(g("tbody tr",this).removeClass(h.join(" ")),k.asDestroyStripes=h.slice());var o= +[],s,h=this.getElementsByTagName("thead");0!==h.length&&(da(k.aoHeader,h[0]),o=pa(k));if(null===d.aoColumns){s=[];h=0;for(i=o.length;h").appendTo(this));k.nTHead=i[0];i=g(this).children("tbody");0===i.length&&(i=g("").appendTo(this));k.nTBody=i[0];i=g(this).children("tfoot");if(0===i.length&&0").appendTo(this);0===i.length||0===i.children().length?g(this).addClass(j.sNoFooter): +0a?new q(b[a],this[a]):null},filter:function(a){var b=[];if(y.filter)b=y.filter.call(this,a,this);else for(var c=0,e=this.length;c").addClass(b);g("td",c).addClass(b).html(a)[0].colSpan=aa(e);d.push(c[0])}};if(g.isArray(a)||a instanceof g)for(var h=0,i=a.length;h=0?b:h.length+b];if(typeof a==="function"){var d=Ba(c,f);return g.map(h,function(b,f){return a(f,Vb(c,f,0,0,d),j[f])?f:null})}var k=typeof a==="string"?a.match(cc):"";if(k)switch(k[2]){case "visIdx":case "visible":b= +parseInt(k[1],10);if(b<0){var l=g.map(h,function(a,b){return a.bVisible?b:null});return[l[l.length+b]]}return[ka(c,b)];case "name":return g.map(i,function(a,b){return a===k[1]?b:null})}else return g(j).filter(a).map(function(){return g.inArray(this,j)}).toArray()})},1);c.selector.cols=a;c.selector.opts=b;return c});t("columns().header()","column().header()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].nTh},1)});t("columns().footer()","column().footer()",function(){return this.iterator("column", +function(a,b){return a.aoColumns[b].nTf},1)});t("columns().data()","column().data()",function(){return this.iterator("column-rows",Vb,1)});t("columns().dataSrc()","column().dataSrc()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].mData},1)});t("columns().cache()","column().cache()",function(a){return this.iterator("column-rows",function(b,c,e,d,f){return ha(b.aoData,f,"search"===a?"_aFilterData":"_aSortData",c)},1)});t("columns().nodes()","column().nodes()",function(){return this.iterator("column-rows", +function(a,b,c,e,d){return ha(a.aoData,d,"anCells",b)},1)});t("columns().visible()","column().visible()",function(a,b){return this.iterator("column",function(c,e){if(a===l)return c.aoColumns[e].bVisible;var d=c.aoColumns,f=d[e],h=c.aoData,i,j,n;if(a!==l&&f.bVisible!==a){if(a){var m=g.inArray(!0,C(d,"bVisible"),e+1);i=0;for(j=h.length;ie;return!0};p.isDataTable=p.fnIsDataTable=function(a){var b=g(a).get(0),c=!1;g.each(p.settings,function(a,d){if(d.nTable===b||d.nScrollHead===b||d.nScrollFoot===b)c=!0});return c};p.tables=p.fnTables=function(a){return g.map(p.settings,function(b){if(!a||a&&g(b.nTable).is(":visible"))return b.nTable})};p.util={throttle:ta,escapeRegex:ua}; +p.camelToHungarian=G;r("$()",function(a,b){var c=this.rows(b).nodes(),c=g(c);return g([].concat(c.filter(a).toArray(),c.find(a).toArray()))});g.each(["on","one","off"],function(a,b){r(b+"()",function(){var a=Array.prototype.slice.call(arguments);a[0].match(/\.dt\b/)||(a[0]+=".dt");var e=g(this.tables().nodes());e[b].apply(e,a);return this})});r("clear()",function(){return this.iterator("table",function(a){na(a)})});r("settings()",function(){return new q(this.context,this.context)});r("data()",function(){return this.iterator("table", +function(a){return C(a.aoData,"_aData")}).flatten()});r("destroy()",function(a){a=a||!1;return this.iterator("table",function(b){var c=b.nTableWrapper.parentNode,e=b.oClasses,d=b.nTable,f=b.nTBody,h=b.nTHead,i=b.nTFoot,j=g(d),f=g(f),l=g(b.nTableWrapper),m=g.map(b.aoData,function(a){return a.nTr}),o;b.bDestroying=!0;u(b,"aoDestroyCallback","destroy",[b]);a||(new q(b)).columns().visible(!0);l.unbind(".DT").find(":not(tbody *)").unbind(".DT");g(Da).unbind(".DT-"+b.sInstance);d!=h.parentNode&&(j.children("thead").detach(), +j.append(h));i&&d!=i.parentNode&&(j.children("tfoot").detach(),j.append(i));j.detach();l.detach();b.aaSorting=[];b.aaSortingFixed=[];wa(b);g(m).removeClass(b.asStripeClasses.join(" "));g("th, td",h).removeClass(e.sSortable+" "+e.sSortableAsc+" "+e.sSortableDesc+" "+e.sSortableNone);b.bJUI&&(g("th span."+e.sSortIcon+", td span."+e.sSortIcon,h).detach(),g("th, td",h).each(function(){var a=g("div."+e.sSortJUIWrapper,this);g(this).append(a.contents());a.detach()}));!a&&c&&c.insertBefore(d,b.nTableReinsertBefore); +f.children().detach();f.append(m);j.css("width",b.sDestroyWidth).removeClass(e.sTable);(o=b.asDestroyStripes.length)&&f.children().each(function(a){g(this).addClass(b.asDestroyStripes[a%o])});c=g.inArray(b,p.settings);-1!==c&&p.settings.splice(c,1)})});p.version="1.10.4";p.settings=[];p.models={};p.models.oSearch={bCaseInsensitive:!0,sSearch:"",bRegex:!1,bSmart:!0};p.models.oRow={nTr:null,anCells:null,_aData:[],_aSortData:null,_aFilterData:null,_sFilterRow:null,_sRowStripe:"",src:null};p.models.oColumn= +{idx:null,aDataSort:null,asSorting:null,bSearchable:null,bSortable:null,bVisible:null,_sManualType:null,_bAttrSrc:!1,fnCreatedCell:null,fnGetData:null,fnSetData:null,mData:null,mRender:null,nTh:null,nTf:null,sClass:null,sContentPadding:null,sDefaultContent:null,sName:null,sSortDataType:"std",sSortingClass:null,sSortingClassJUI:null,sTitle:null,sType:null,sWidth:null,sWidthOrig:null};p.defaults={aaData:null,aaSorting:[[0,"asc"]],aaSortingFixed:[],ajax:null,aLengthMenu:[10,25,50,100],aoColumns:null, +aoColumnDefs:null,aoSearchCols:[],asStripeClasses:null,bAutoWidth:!0,bDeferRender:!1,bDestroy:!1,bFilter:!0,bInfo:!0,bJQueryUI:!1,bLengthChange:!0,bPaginate:!0,bProcessing:!1,bRetrieve:!1,bScrollCollapse:!1,bServerSide:!1,bSort:!0,bSortMulti:!0,bSortCellsTop:!1,bSortClasses:!0,bStateSave:!1,fnCreatedRow:null,fnDrawCallback:null,fnFooterCallback:null,fnFormatNumber:function(a){return a.toString().replace(/\B(?=(\d{3})+(?!\d))/g,this.oLanguage.sThousands)},fnHeaderCallback:null,fnInfoCallback:null, +fnInitComplete:null,fnPreDrawCallback:null,fnRowCallback:null,fnServerData:null,fnServerParams:null,fnStateLoadCallback:function(a){try{return JSON.parse((-1===a.iStateDuration?sessionStorage:localStorage).getItem("DataTables_"+a.sInstance+"_"+location.pathname))}catch(b){}},fnStateLoadParams:null,fnStateLoaded:null,fnStateSaveCallback:function(a,b){try{(-1===a.iStateDuration?sessionStorage:localStorage).setItem("DataTables_"+a.sInstance+"_"+location.pathname,JSON.stringify(b))}catch(c){}},fnStateSaveParams:null, +iStateDuration:7200,iDeferLoading:null,iDisplayLength:10,iDisplayStart:0,iTabIndex:0,oClasses:{},oLanguage:{oAria:{sSortAscending:": activate to sort column ascending",sSortDescending:": activate to sort column descending"},oPaginate:{sFirst:"First",sLast:"Last",sNext:"Next",sPrevious:"Previous"},sEmptyTable:"No data available in table",sInfo:"Showing _START_ to _END_ of _TOTAL_ entries",sInfoEmpty:"Showing 0 to 0 of 0 entries",sInfoFiltered:"(filtered from _MAX_ total entries)",sInfoPostFix:"",sDecimal:"", +sThousands:",",sLengthMenu:"Show _MENU_ entries",sLoadingRecords:"Loading...",sProcessing:"Processing...",sSearch:"Search:",sSearchPlaceholder:"",sUrl:"",sZeroRecords:"No matching records found"},oSearch:g.extend({},p.models.oSearch),sAjaxDataProp:"data",sAjaxSource:null,sDom:"lfrtip",searchDelay:null,sPaginationType:"simple_numbers",sScrollX:"",sScrollXInner:"",sScrollY:"",sServerMethod:"GET",renderer:null};V(p.defaults);p.defaults.column={aDataSort:null,iDataSort:-1,asSorting:["asc","desc"],bSearchable:!0, +bSortable:!0,bVisible:!0,fnCreatedCell:null,mData:null,mRender:null,sCellType:"td",sClass:"",sContentPadding:"",sDefaultContent:null,sName:"",sSortDataType:"std",sTitle:null,sType:null,sWidth:null};V(p.defaults.column);p.models.oSettings={oFeatures:{bAutoWidth:null,bDeferRender:null,bFilter:null,bInfo:null,bLengthChange:null,bPaginate:null,bProcessing:null,bServerSide:null,bSort:null,bSortMulti:null,bSortClasses:null,bStateSave:null},oScroll:{bCollapse:null,iBarWidth:0,sX:null,sXInner:null,sY:null}, +oLanguage:{fnInfoCallback:null},oBrowser:{bScrollOversize:!1,bScrollbarLeft:!1},ajax:null,aanFeatures:[],aoData:[],aiDisplay:[],aiDisplayMaster:[],aoColumns:[],aoHeader:[],aoFooter:[],oPreviousSearch:{},aoPreSearchCols:[],aaSorting:null,aaSortingFixed:[],asStripeClasses:null,asDestroyStripes:[],sDestroyWidth:0,aoRowCallback:[],aoHeaderCallback:[],aoFooterCallback:[],aoDrawCallback:[],aoRowCreatedCallback:[],aoPreDrawCallback:[],aoInitComplete:[],aoStateSaveParams:[],aoStateLoadParams:[],aoStateLoaded:[], +sTableId:"",nTable:null,nTHead:null,nTFoot:null,nTBody:null,nTableWrapper:null,bDeferLoading:!1,bInitialised:!1,aoOpenRows:[],sDom:null,searchDelay:null,sPaginationType:"two_button",iStateDuration:0,aoStateSave:[],aoStateLoad:[],oSavedState:null,oLoadedState:null,sAjaxSource:null,sAjaxDataProp:null,bAjaxDataGet:!0,jqXHR:null,json:l,oAjaxData:l,fnServerData:null,aoServerParams:[],sServerMethod:null,fnFormatNumber:null,aLengthMenu:null,iDraw:0,bDrawing:!1,iDrawError:-1,_iDisplayLength:10,_iDisplayStart:0, +_iRecordsTotal:0,_iRecordsDisplay:0,bJUI:null,oClasses:{},bFiltered:!1,bSorted:!1,bSortCellsTop:null,oInit:null,aoDestroyCallback:[],fnRecordsTotal:function(){return"ssp"==A(this)?1*this._iRecordsTotal:this.aiDisplayMaster.length},fnRecordsDisplay:function(){return"ssp"==A(this)?1*this._iRecordsDisplay:this.aiDisplay.length},fnDisplayEnd:function(){var a=this._iDisplayLength,b=this._iDisplayStart,c=b+a,e=this.aiDisplay.length,d=this.oFeatures,f=d.bPaginate;return d.bServerSide?!1===f||-1===a?b+e: +Math.min(b+a,this._iRecordsDisplay):!f||c>e||-1===a?e:c},oInstance:null,sInstance:null,iTabIndex:0,nScrollHead:null,nScrollFoot:null,aLastSort:[],oPlugins:{}};p.ext=w={classes:{},errMode:"alert",feature:[],search:[],internal:{},legacy:{ajax:null},pager:{},renderer:{pageButton:{},header:{}},order:{},type:{detect:[],search:{},order:{}},_unique:0,fnVersionCheck:p.fnVersionCheck,iApiIndex:0,oJUIClasses:{},sVersion:p.version};g.extend(w,{afnFiltering:w.search,aTypes:w.type.detect,ofnSearch:w.type.search, +oSort:w.type.order,afnSortData:w.order,aoFeatures:w.feature,oApi:w.internal,oStdClasses:w.classes,oPagination:w.pager});g.extend(p.ext.classes,{sTable:"dataTable",sNoFooter:"no-footer",sPageButton:"paginate_button",sPageButtonActive:"current",sPageButtonDisabled:"disabled",sStripeOdd:"odd",sStripeEven:"even",sRowEmpty:"dataTables_empty",sWrapper:"dataTables_wrapper",sFilter:"dataTables_filter",sInfo:"dataTables_info",sPaging:"dataTables_paginate paging_",sLength:"dataTables_length",sProcessing:"dataTables_processing", +sSortAsc:"sorting_asc",sSortDesc:"sorting_desc",sSortable:"sorting",sSortableAsc:"sorting_asc_disabled",sSortableDesc:"sorting_desc_disabled",sSortableNone:"sorting_disabled",sSortColumn:"sorting_",sFilterInput:"",sLengthSelect:"",sScrollWrapper:"dataTables_scroll",sScrollHead:"dataTables_scrollHead",sScrollHeadInner:"dataTables_scrollHeadInner",sScrollBody:"dataTables_scrollBody",sScrollFoot:"dataTables_scrollFoot",sScrollFootInner:"dataTables_scrollFootInner",sHeaderTH:"",sFooterTH:"",sSortJUIAsc:"", +sSortJUIDesc:"",sSortJUI:"",sSortJUIAscAllowed:"",sSortJUIDescAllowed:"",sSortJUIWrapper:"",sSortIcon:"",sJUIHeader:"",sJUIFooter:""});var Ca="",Ca="",E=Ca+"ui-state-default",ia=Ca+"css_right ui-icon ui-icon-",Wb=Ca+"fg-toolbar ui-toolbar ui-widget-header ui-helper-clearfix";g.extend(p.ext.oJUIClasses,p.ext.classes,{sPageButton:"fg-button ui-button "+E,sPageButtonActive:"ui-state-disabled",sPageButtonDisabled:"ui-state-disabled",sPaging:"dataTables_paginate fg-buttonset ui-buttonset fg-buttonset-multi ui-buttonset-multi paging_", +sSortAsc:E+" sorting_asc",sSortDesc:E+" sorting_desc",sSortable:E+" sorting",sSortableAsc:E+" sorting_asc_disabled",sSortableDesc:E+" sorting_desc_disabled",sSortableNone:E+" sorting_disabled",sSortJUIAsc:ia+"triangle-1-n",sSortJUIDesc:ia+"triangle-1-s",sSortJUI:ia+"carat-2-n-s",sSortJUIAscAllowed:ia+"carat-1-n",sSortJUIDescAllowed:ia+"carat-1-s",sSortJUIWrapper:"DataTables_sort_wrapper",sSortIcon:"DataTables_sort_icon",sScrollHead:"dataTables_scrollHead "+E,sScrollFoot:"dataTables_scrollFoot "+E, +sHeaderTH:E,sFooterTH:E,sJUIHeader:Wb+" ui-corner-tl ui-corner-tr",sJUIFooter:Wb+" ui-corner-bl ui-corner-br"});var Lb=p.ext.pager;g.extend(Lb,{simple:function(){return["previous","next"]},full:function(){return["first","previous","next","last"]},simple_numbers:function(a,b){return["previous",Va(a,b),"next"]},full_numbers:function(a,b){return["first","previous",Va(a,b),"next","last"]},_numbers:Va,numbers_length:7});g.extend(!0,p.ext.renderer,{pageButton:{_:function(a,b,c,e,d,f){var h=a.oClasses,i= +a.oLanguage.oPaginate,j,l,m=0,o=function(b,e){var k,p,r,q,s=function(b){Sa(a,b.data.action,true)};k=0;for(p=e.length;k").appendTo(b);o(r,q)}else{l=j="";switch(q){case "ellipsis":b.append("");break;case "first":j=i.sFirst;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "previous":j=i.sPrevious;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "next":j=i.sNext;l=q+(d",{"class":h.sPageButton+" "+l,"aria-controls":a.sTableId,"data-dt-idx":m,tabindex:a.iTabIndex,id:c===0&&typeof q==="string"?a.sTableId+"_"+q:null}).html(j).appendTo(b);Ua(r,{action:q},s);m++}}}};try{var k=g(P.activeElement).data("dt-idx");o(g(b).empty(),e);k!==null&&g(b).find("[data-dt-idx="+k+"]").focus()}catch(p){}}}});g.extend(p.ext.type.detect,[function(a,b){var c=b.oLanguage.sDecimal; +return Ya(a,c)?"num"+c:null},function(a){if(a&&!(a instanceof Date)&&(!$b.test(a)||!ac.test(a)))return null;var b=Date.parse(a);return null!==b&&!isNaN(b)||H(a)?"date":null},function(a,b){var c=b.oLanguage.sDecimal;return Ya(a,c,!0)?"num-fmt"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c)?"html-num"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c,!0)?"html-num-fmt"+c:null},function(a){return H(a)||"string"===typeof a&&-1!==a.indexOf("<")?"html":null}]);g.extend(p.ext.type.search, +{html:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," ").replace(Aa,""):""},string:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," "):a}});var za=function(a,b,c,e){if(0!==a&&(!a||"-"===a))return-Infinity;b&&(a=Pb(a,b));a.replace&&(c&&(a=a.replace(c,"")),e&&(a=a.replace(e,"")));return 1*a};g.extend(w.type.order,{"date-pre":function(a){return Date.parse(a)||0},"html-pre":function(a){return H(a)?"":a.replace?a.replace(/<.*?>/g,"").toLowerCase():a+""},"string-pre":function(a){return H(a)? +"":"string"===typeof a?a.toLowerCase():!a.toString?"":a.toString()},"string-asc":function(a,b){return ab?1:0},"string-desc":function(a,b){return ab?-1:0}});cb("");g.extend(!0,p.ext.renderer,{header:{_:function(a,b,c,e){g(a.nTable).on("order.dt.DT",function(d,f,h,g){if(a===f){d=c.idx;b.removeClass(c.sSortingClass+" "+e.sSortAsc+" "+e.sSortDesc).addClass(g[d]=="asc"?e.sSortAsc:g[d]=="desc"?e.sSortDesc:c.sSortingClass)}})},jqueryui:function(a,b,c,e){g("
      ").addClass(e.sSortJUIWrapper).append(b.contents()).append(g("").addClass(e.sSortIcon+ +" "+c.sSortingClassJUI)).appendTo(b);g(a.nTable).on("order.dt.DT",function(d,f,g,i){if(a===f){d=c.idx;b.removeClass(e.sSortAsc+" "+e.sSortDesc).addClass(i[d]=="asc"?e.sSortAsc:i[d]=="desc"?e.sSortDesc:c.sSortingClass);b.find("span."+e.sSortIcon).removeClass(e.sSortJUIAsc+" "+e.sSortJUIDesc+" "+e.sSortJUI+" "+e.sSortJUIAscAllowed+" "+e.sSortJUIDescAllowed).addClass(i[d]=="asc"?e.sSortJUIAsc:i[d]=="desc"?e.sSortJUIDesc:c.sSortingClassJUI)}})}}});p.render={number:function(a,b,c,e){return{display:function(d){var f= +0>d?"-":"",d=Math.abs(parseFloat(d)),g=parseInt(d,10),d=c?b+(d-g).toFixed(c).substring(2):"";return f+(e||"")+g.toString().replace(/\B(?=(\d{3})+(?!\d))/g,a)+d}}}};g.extend(p.ext.internal,{_fnExternApiFunc:Mb,_fnBuildAjax:qa,_fnAjaxUpdate:jb,_fnAjaxParameters:sb,_fnAjaxUpdateDraw:tb,_fnAjaxDataSrc:ra,_fnAddColumn:Ea,_fnColumnOptions:ja,_fnAdjustColumnSizing:X,_fnVisibleToColumnIndex:ka,_fnColumnIndexToVisible:$,_fnVisbleColumns:aa,_fnGetColumns:Z,_fnColumnTypes:Ga,_fnApplyColumnDefs:hb,_fnHungarianMap:V, +_fnCamelToHungarian:G,_fnLanguageCompat:O,_fnBrowserDetect:fb,_fnAddData:I,_fnAddTr:la,_fnNodeToDataIndex:function(a,b){return b._DT_RowIndex!==l?b._DT_RowIndex:null},_fnNodeToColumnIndex:function(a,b,c){return g.inArray(c,a.aoData[b].anCells)},_fnGetCellData:v,_fnSetCellData:Ha,_fnSplitObjNotation:Ja,_fnGetObjectDataFn:W,_fnSetObjectDataFn:Q,_fnGetDataMaster:Ka,_fnClearTable:na,_fnDeleteIndex:oa,_fnInvalidate:ca,_fnGetRowElements:ma,_fnCreateTr:Ia,_fnBuildHead:ib,_fnDrawHead:ea,_fnDraw:L,_fnReDraw:M, +_fnAddOptionsHtml:lb,_fnDetectHeader:da,_fnGetUniqueThs:pa,_fnFeatureHtmlFilter:nb,_fnFilterComplete:fa,_fnFilterCustom:wb,_fnFilterColumn:vb,_fnFilter:ub,_fnFilterCreateSearch:Pa,_fnEscapeRegex:ua,_fnFilterData:xb,_fnFeatureHtmlInfo:qb,_fnUpdateInfo:Ab,_fnInfoMacros:Bb,_fnInitialise:ga,_fnInitComplete:sa,_fnLengthChange:Qa,_fnFeatureHtmlLength:mb,_fnFeatureHtmlPaginate:rb,_fnPageChange:Sa,_fnFeatureHtmlProcessing:ob,_fnProcessingDisplay:B,_fnFeatureHtmlTable:pb,_fnScrollDraw:Y,_fnApplyToChildren:F, +_fnCalculateColumnWidths:Fa,_fnThrottle:ta,_fnConvertToWidth:Cb,_fnScrollingWidthAdjust:Eb,_fnGetWidestNode:Db,_fnGetMaxLenString:Fb,_fnStringToCss:s,_fnScrollBarWidth:Gb,_fnSortFlatten:T,_fnSort:kb,_fnSortAria:Ib,_fnSortListener:Ta,_fnSortAttachListener:Na,_fnSortingClasses:wa,_fnSortData:Hb,_fnSaveState:xa,_fnLoadState:Jb,_fnSettingsFromNode:ya,_fnLog:R,_fnMap:D,_fnBindAction:Ua,_fnCallbackReg:x,_fnCallbackFire:u,_fnLengthOverflow:Ra,_fnRenderer:Oa,_fnDataSource:A,_fnRowAttributes:La,_fnCalculateEnd:function(){}}); +g.fn.dataTable=p;g.fn.dataTableSettings=p.settings;g.fn.dataTableExt=p.ext;g.fn.DataTable=function(a){return g(this).dataTable(a).api()};g.each(p,function(a,b){g.fn.DataTable[a]=b});return g.fn.dataTable};"function"===typeof define&&define.amd?define("datatables",["jquery"],O):"object"===typeof exports?O(require("jquery")):jQuery&&!jQuery.fn.dataTable&&O(jQuery)})(window,document); diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js new file mode 100644 index 0000000000000..14925bf93d0f1 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js @@ -0,0 +1,592 @@ +/* +Shameless port of a shameless port +@defunkt => @janl => @aq + +See http://github.com/defunkt/mustache for more info. +*/ + +;(function($) { + +/*! + * mustache.js - Logic-less {{mustache}} templates with JavaScript + * http://github.com/janl/mustache.js + */ + +/*global define: false*/ + +(function (root, factory) { + if (typeof exports === "object" && exports) { + factory(exports); // CommonJS + } else { + var mustache = {}; + factory(mustache); + if (typeof define === "function" && define.amd) { + define(mustache); // AMD + } else { + root.Mustache = mustache; // ++ + ++ + } else if (requestedIncomplete) {

      No incomplete applications found!

      } else {

      No completed applications found!

      ++ -

      Did you specify the correct logging directory? - Please verify your setting of - spark.history.fs.logDirectory and whether you have the permissions to - access it.
      It is also possible that your application did not run to - completion or did not stop the SparkContext. -

      +

      Did you specify the correct logging directory? + Please verify your setting of + spark.history.fs.logDirectory and whether you have the permissions to + access it.
      It is also possible that your application did not run to + completion or did not stop the SparkContext. +

      } - } - - { + } + + + { if (requestedIncomplete) { "Back to completed applications" } else { "Show incomplete applications" } - } - -
      + } +
      +
      UIUtils.basicSparkPage(content, "History Server") } - private val appHeader = Seq( - "App ID", - "App Name", - "Started", - "Completed", - "Duration", - "Spark User", - "Last Updated") - - private val appWithAttemptHeader = Seq( - "App ID", - "App Name", - "Attempt ID", - "Started", - "Completed", - "Duration", - "Spark User", - "Last Updated") - - private def rangeIndices( - range: Seq[Int], - condition: Int => Boolean, - showIncomplete: Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => - {nextPage} ) - } - - private def attemptRow( - renderAttemptIdColumn: Boolean, - info: ApplicationHistoryInfo, - attempt: ApplicationAttemptInfo, - isFirst: Boolean): Seq[Node] = { - val uiAddress = UIUtils.prependBaseUri(HistoryServer.getAttemptURI(info.id, attempt.attemptId)) - val startTime = UIUtils.formatDate(attempt.startTime) - val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-" - val duration = - if (attempt.endTime > 0) { - UIUtils.formatDuration(attempt.endTime - attempt.startTime) - } else { - "-" - } - val lastUpdated = UIUtils.formatDate(attempt.lastUpdated) - - { - if (isFirst) { - if (info.attempts.size > 1 || renderAttemptIdColumn) { - - {info.id} - - {info.name} - } else { - {info.id} - {info.name} - } - } else { - Nil - } - } - { - if (renderAttemptIdColumn) { - if (info.attempts.size > 1 && attempt.attemptId.isDefined) { - {attempt.attemptId.get} - } else { -   - } - } else { - Nil - } - } - {startTime} - {endTime} - - {duration} - {attempt.sparkUser} - {lastUpdated} - - } - - private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - attemptRow(false, info, info.attempts.head, true) - } - - private def appWithAttemptRow(info: ApplicationHistoryInfo): Seq[Node] = { - attemptRow(true, info, info.attempts.head, true) ++ - info.attempts.drop(1).flatMap(attemptRow(true, info, _, false)) - } - - private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + Array( - "page=" + linkPage, - "showIncomplete=" + showIncomplete - ).mkString("&")) + private def makePageLink(showIncomplete: Boolean): String = { + UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 0fc0fb59d861f..0f30183682469 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -71,6 +71,13 @@ private[spark] object ApplicationsListResource { attemptId = internalAttemptInfo.attemptId, startTime = new Date(internalAttemptInfo.startTime), endTime = new Date(internalAttemptInfo.endTime), + duration = + if (internalAttemptInfo.endTime > 0) { + internalAttemptInfo.endTime - internalAttemptInfo.startTime + } else { + 0 + }, + lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, completed = internalAttemptInfo.completed ) @@ -93,6 +100,13 @@ private[spark] object ApplicationsListResource { attemptId = None, startTime = new Date(internal.startTime), endTime = new Date(internal.endTime), + duration = + if (internal.endTime > 0) { + internal.endTime - internal.startTime + } else { + 0 + }, + lastUpdated = new Date(internal.endTime), sparkUser = internal.desc.user, completed = completed )) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 3adf5b1109af4..2b0079f5fd62e 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -35,6 +35,8 @@ class ApplicationAttemptInfo private[spark]( val attemptId: Option[String], val startTime: Date, val endTime: Date, + val lastUpdated: Date, + val duration: Long, val sparkUser: String, val completed: Boolean = false) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index cf45414c4f786..6cc30eeaf5d82 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -114,6 +114,8 @@ private[spark] class SparkUI private ( attemptId = None, startTime = new Date(startTime), endTime = new Date(-1), + duration = 0, + lastUpdated = new Date(startTime), sparkUser = "", completed = false )) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 1949c4b3cbf42..4ebee9093d41c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -157,11 +157,22 @@ private[spark] object UIUtils extends Logging { def commonHeaderNodes: Seq[Node] = { + + + + + + + + + diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index d575bf2f284b9..5bbb4ceb97228 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -45,6 +55,8 @@ "attempts" : [ { "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -54,6 +66,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -63,7 +77,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index d575bf2f284b9..5bbb4ceb97228 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -45,6 +55,8 @@ "attempts" : [ { "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -54,6 +66,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -63,7 +77,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index 483632a3956ed..3f80a529a08b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -4,7 +4,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index 4b85690fd9199..508bdc17efe9f 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -13,7 +15,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 15c2de8ef99ea..5dca7d73de0cc 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -4,6 +4,8 @@ "attempts" : [ { "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -14,12 +16,16 @@ "attemptId" : "2", "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -30,12 +36,16 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -46,8 +56,10 @@ { "startTime": "2015-02-28T00:02:38.277GMT", "endTime": "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser": "irashid", "completed": true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index 07489ad96414a..cca32c791074f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -4,7 +4,9 @@ "attempts" : [ { "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index 8f3d7160c723f..1ea1779e8369d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -5,13 +5,17 @@ "attemptId" : "2", "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 18659fc0c18de..be55b2e0fe1b7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -139,7 +139,24 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers code should be (HttpServletResponse.SC_OK) jsonOpt should be ('defined) errOpt should be (None) - val json = jsonOpt.get + val jsonOrg = jsonOpt.get + + // SPARK-10873 added the lastUpdated field for each application's attempt, + // the REST API returns the last modified time of EVENT LOG file for this field. + // It is not applicable to hard-code this dynamic field in a static expected file, + // so here we skip checking the lastUpdated field's value (setting it as ""). + val json = if (jsonOrg.indexOf("lastUpdated") >= 0) { + val subStrings = jsonOrg.split(",") + for (i <- subStrings.indices) { + if (subStrings(i).indexOf("lastUpdated") >= 0) { + subStrings(i) = "\"lastUpdated\":\"\"" + } + } + subStrings.mkString(",") + } else { + jsonOrg + } + val exp = IOUtils.toString(new FileInputStream( new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json"))) // compare the ASTs so formatting differences don't cause failures @@ -241,30 +258,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("generate history page with relative links") { - val historyServer = mock[HistoryServer] - val request = mock[HttpServletRequest] - val ui = mock[SparkUI] - val link = "/history/app1" - val info = new ApplicationHistoryInfo("app1", "app1", - List(ApplicationAttemptInfo(None, 0, 2, 1, "xxx", true))) - when(historyServer.getApplicationList()).thenReturn(Seq(info)) - when(ui.basePath).thenReturn(link) - when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) - val page = new HistoryPage(historyServer) - - // when - val response = page.render(request) - - // then - val links = response \\ "a" - val justHrefs = for { - l <- links - attrs <- l.attribute("href") - } yield (attrs.toString) - justHrefs should contain (UIUtils.prependBaseUri(resource = link)) - } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 968a2903f3010..a3ae4d2b730ff 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,6 +45,10 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), // SPARK-12600 Remove SQL deprecated methods ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), From c5f745ede01831b59c57effa7de88c648b82c13d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 10:24:23 -0800 Subject: [PATCH 1738/2089] [SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen simplify(remove several unnecessary local variables) the generated code of hash expression, and avoid null check if possible. generated code comparison for `hash(int, double, string, array)`: **before:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); } /* input[1, double] */ double value5 = i.getDouble(1); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); } /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { int result10 = value1; for (int index11 = 0; index11 < value9.numElements(); index11++) { if (!value9.isNullAt(index11)) { final int element12 = value9.getInt(index11); result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10); } } value1 = result10; } } ``` **after:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); /* input[1, double] */ double value5 = i.getDouble(1); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { for (int index10 = 0; index10 < value9.numElements(); index10++) { final int element11 = value9.getInt(index10); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1); } } rowWriter14.write(0, value1); return result12; } ``` Author: Wenchen Fan Closes #10974 from cloud-fan/codegen. --- .../spark/sql/catalyst/expressions/misc.scala | 155 ++++++++---------- 1 file changed, 69 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 493e0aae01af7..8480c3f9a12f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" - val childrenHash = children.zipWithIndex.map { - case (child, dt) => - val childGen = child.gen(ctx) - val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) - s""" - ${childGen.code} - if (!${childGen.isNull}) { - ${childHash.code} - ${ev.value} = ${childHash.value}; - } - """ + val childrenHash = children.map { child => + val childGen = child.gen(ctx) + childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, ev.value, ctx) + } }.mkString("\n") + s""" int ${ev.value} = $seed; $childrenHash """ } + private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execution + } + """ + } else { + "\n" + execution + } + } + + private def nullSafeElementHash( + input: String, + index: String, + nullable: Boolean, + elementType: DataType, + result: String, + ctx: CodegenContext): String = { + val element = ctx.freshName("element") + + generateNullCheck(nullable, s"$input.isNullAt($index)") { + s""" + final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + ${computeHash(element, elementType, result, ctx)} + """ + } + } + private def computeHash( input: String, dataType: DataType, - seed: String, - ctx: CodegenContext): ExprCode = { + result: String, + ctx: CodegenContext): String = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) + + def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" + def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" + def hashBytes(b: String): String = + s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" dataType match { - case NullType => inlineValue(seed) + case NullType => "" case BooleanType => hashInt(s"$input ? 1 : 0") case ByteType | ShortType | IntegerType | DateType => hashInt(input) case LongType | TimestampType => hashLong(input) @@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression hashLong(s"$input.toUnscaledLong()") } else { val bytes = ctx.freshName("bytes") - val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" - val offset = "Platform.BYTE_ARRAY_OFFSET" - val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - ExprCode(code, "false", result) + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ } case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" - val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" - inlineValue(monthsHash) - case BinaryType => - val offset = "Platform.BYTE_ARRAY_OFFSET" - inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" + s"$result = $hasher.hashInt($input.months, $microsecondsHash);" + case BinaryType => hashBytes(input) case StringType => val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" - inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, _) => - val result = ctx.freshName("result") + case ArrayType(et, containsNull) => val index = ctx.freshName("index") - val element = ctx.freshName("element") - val elementHash = computeHash(element, et, result, ctx) - val code = - s""" - int $result = $seed; - for (int $index = 0; $index < $input.numElements(); $index++) { - if (!$input.isNullAt($index)) { - final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; - ${elementHash.code} - $result = ${elementHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} + } + """ - case MapType(kt, vt, _) => - val result = ctx.freshName("result") + case MapType(kt, vt, valueContainsNull) => val index = ctx.freshName("index") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val key = ctx.freshName("key") - val value = ctx.freshName("value") - val keyHash = computeHash(key, kt, result, ctx) - val valueHash = computeHash(value, vt, result, ctx) - val code = - s""" - int $result = $seed; - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; - ${keyHash.code} - $result = ${keyHash.value}; - if (!$values.isNullAt($index)) { - final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; - ${valueHash.code} - $result = ${valueHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, kt, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} + } + """ case StructType(fields) => - val result = ctx.freshName("result") - val fieldsHash = fields.map(_.dataType).zipWithIndex.map { - case (dt, index) => - val field = ctx.freshName("field") - val fieldHash = computeHash(field, dt, result, ctx) - s""" - if (!$input.isNullAt($index)) { - final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; - ${fieldHash.code} - $result = ${fieldHash.value}; - } - """ + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") - val code = - s""" - int $result = $seed; - $fieldsHash - """ - ExprCode(code, "false", result) - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) } } } From 5f686cc8b74ea9e36f56c31f14df90d134fd9343 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Jan 2016 11:22:12 -0800 Subject: [PATCH 1739/2089] [SPARK-12656] [SQL] Implement Intersect with Left-semi Join Our current Intersect physical operator simply delegates to RDD.intersect. We should remove the Intersect physical operator and simply transform a logical intersect into a semi-join with distinct. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins). After a search, I found one of the mainstream RDBMS did the same. In their query explain, Intersect is replaced by Left-semi Join. Left-semi Join could help outer-join elimination in Optimizer, as shown in the PR: https://github.com/apache/spark/pull/10566 Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10630 from gatorsmile/IntersectBySemiJoin. --- .../sql/catalyst/analysis/Analyzer.scala | 113 ++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 ++- .../sql/catalyst/optimizer/Optimizer.scala | 45 ++++--- .../plans/logical/basicOperators.scala | 32 +++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 5 + .../optimizer/AggregateOptimizeSuite.scala | 12 -- .../optimizer/ReplaceOperatorSuite.scala | 59 +++++++++ .../optimizer/SetOperationSuite.scala | 15 +-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../spark/sql/execution/basicOperators.scala | 12 -- .../org/apache/spark/sql/DataFrameSuite.scala | 21 ++++ 11 files changed, 211 insertions(+), 122 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 33d76eeb21287..5fe700ee00673 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -344,6 +344,63 @@ class Analyzer( } } + /** + * Generate a new logical plan for the right child with different expression IDs + * for all conflicting attributes. + */ + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + + s"between $left and $right") + + right.collect { + // Handle base relations that might appear more than once. + case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + + case oldVersion @ Window(_, windowExpressions, _, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) + } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + right + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + newRight + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -388,57 +445,11 @@ class Analyzer( .map(_.asInstanceOf[NamedExpression]) a.copy(aggregateExpressions = expanded) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - - right.collect { - // Handle base relations that might appear more than once. - case oldVersion: MultiInstanceRelation - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = oldVersion.newInstance() - (oldVersion, newVersion) - - // Handle projects that create conflicting aliases. - case oldVersion @ Project(projectList, _) - if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) - - case oldVersion @ Aggregate(_, aggregateExpressions, _) - if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - - case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => - val newOutput = oldVersion.generatorOutput.map(_.newInstance()) - (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - - case oldVersion @ Window(_, windowExpressions, _, _, child) - if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) - .nonEmpty => - (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - } - // Only handle first case, others will be fixed on the next pass. - .headOption match { - case None => - /* - * No result implies that there is a logical plan node that produces new references - * that this rule cannot handle. When that is the case, there must be another rule - * that resolves these conflicts. Otherwise, the analysis will fail. - */ - j - case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } - } - j.copy(right = newRight) - } + // To resolve duplicate expression IDs for Join and Intersect + case j @ Join(left, right, _, _) if !j.duplicateResolved => + j.copy(right = dedupRight(left, right)) + case i @ Intersect(left, right) if !i.duplicateResolved => + i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on grandchild diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f2e78d97442e3..4a2f2b8bc6e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,9 +214,8 @@ trait CheckAnalysis { s"""Only a single table generating function is allowed in a SELECT clause, found: | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) + case j: Join if !j.duplicateResolved => + val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) failAnalysis( s""" |Failure when resolving conflicting references in Join: @@ -224,6 +223,15 @@ trait CheckAnalysis { |Conflicting attributes: ${conflictingAttributes.mkString(",")} |""".stripMargin) + case i: Intersect if !i.duplicateResolved => + val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Intersect: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6addc2080648b..f156b5d10acc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Replace Operators", FixedPoint(100), + ReplaceIntersectWithSemiJoin, + ReplaceDistinctWithAggregate) :: Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down @@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] { } /** - * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. * Union: * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * - * Intersect: - * It is not safe to pushdown Projections through it because we need to get the - * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * with deterministic condition. - * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters @@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrites an expression so that it can be pushed to the right side of a - * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * Union or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) - // Push down filter through INTERSECT - case Filter(condition, Intersect(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(left, right) - Filter(nondeterministic, - Intersect( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) - // Push down filter through EXCEPT case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) @@ -1054,6 +1040,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. + * {{{ + * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2 + * }}} + * + * Note: + * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL. + * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated + * join conditions will be incorrect. + */ +object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right) => + assert(left.output.size == right.output.size) + val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e9c970cd08088..16f4b355b1b6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -90,12 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - final override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -103,15 +99,30 @@ private[sql] object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + + // Intersect are only resolved if they don't introduce ambiguous expression ids, + // since the Optimizer will convert Intersect to Join. + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + duplicateResolved } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } /** Factory for constructing new `Union` nodes. */ @@ -169,13 +180,13 @@ case class Join( } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = { childrenResolved && expressions.forall(_.resolved) && - selfJoinResolved && + duplicateResolved && condition.forall(_.dataType == BooleanType) } } @@ -249,7 +260,7 @@ case class Range( end: Long, step: Long, numSlices: Int, - output: Seq[Attribute]) extends LeafNode { + output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { require(step != 0, "step cannot be 0") val numElements: BigInt = { val safeStart = BigInt(start) @@ -262,6 +273,9 @@ case class Range( } } + override def newInstance(): Range = + Range(start, end, step, numSlices, output.map(_.newInstance())) + override def statistics: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ab680282208c8..1938bce02a177 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("self intersect should resolve duplicate expression IDs") { + val plan = testRelation.intersect(testRelation) + assertAnalysisSuccess(plan) + } + test("SPARK-8654: invalid CAST in NULL IN(...) expression") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 37148a226f293..a4a12c0d62e92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Nil } - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } - test("remove literals in grouping expression") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala new file mode 100644 index 0000000000000..f8ae5d9be2084 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", FixedPoint(100), + ReplaceDistinctWithAggregate, + ReplaceIntersectWithSemiJoin) :: Nil + } + + test("replace Intersect with Left-semi Join") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = Intersect(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Distinct with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 2283f7c008ba2..b8ea32b4dfe01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) - val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) test("union: combine unions into one unions") { @@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest { comparePlans(combinedUnionsOptimized, unionOptimized3) } - test("intersect/except: filter to each side") { - val intersectQuery = testIntersect.where('b < 10) + test("except: filter to each side") { val exceptQuery = testExcept.where('c >= 5) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - val intersectCorrectAnswer = - Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } @@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest { } test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - comparePlans(intersectOptimized, intersectQuery.analyze) comparePlans(exceptOptimized, exceptQuery.analyze) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 60fbb595e5758..9293e55141757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -298,6 +298,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + case logical.Intersect(left, right) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil @@ -340,8 +343,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil - case logical.Intersect(left, right) => - execution.Intersect(planLater(left), planLater(right)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb4bf..fd81531c9316a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -420,18 +420,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { } } -/** - * Returns the rows in left that also appear in right using the built in spark - * intersection function. - */ -case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output: Seq[Attribute] = children.head.output - - protected override def doExecute(): RDD[InternalRow] = { - left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) - } -} - /** * A plan node that does nothing but lie about the output of its child. Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 09bbe57a43ceb..4ff99bdf2937d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) } test("intersect - nullability") { From 2b027e9a386fe4009f61ad03b169335af5a9a5c6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 29 Jan 2016 12:01:13 -0800 Subject: [PATCH 1740/2089] [SPARK-12818] Polishes spark-sketch module Fixes various minor code and Javadoc styling issues. Author: Cheng Lian Closes #10985 from liancheng/sketch-polishing. --- .../apache/spark/util/sketch/BitArray.java | 2 +- .../apache/spark/util/sketch/BloomFilter.java | 111 ++++++++++-------- .../spark/util/sketch/BloomFilterImpl.java | 40 ++++--- .../spark/util/sketch/CountMinSketch.java | 26 ++-- .../spark/util/sketch/CountMinSketchImpl.java | 12 ++ .../org/apache/spark/util/sketch/Utils.java | 2 +- 6 files changed, 110 insertions(+), 83 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 2a0484e324b13..480a0a79db32d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -22,7 +22,7 @@ import java.io.IOException; import java.util.Arrays; -public final class BitArray { +final class BitArray { private final long[] data; private long bitCount; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 81772fcea0ec2..c0b425e729595 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -22,16 +22,10 @@ import java.io.OutputStream; /** - * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether - * an element is a member of a set. It returns false when the element is definitely not in the - * set, returns true when the element is probably in the set. - * - * Internally a Bloom filter is initialized with 2 information: how many space to use(number of - * bits) and how many hash values to calculate for each record. To get as lower false positive - * probability as possible, user should call {@link BloomFilter#create} to automatically pick a - * best combination of these 2 parameters. - * - * Currently the following data types are supported: + * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate + * containment test with one-sided error: if it claims that an item is contained in it, this + * might be in error, but if it claims that an item is not contained in it, then this is + * definitely true. Currently supported data types include: *
        *
      • {@link Byte}
      • *
      • {@link Short}
      • @@ -39,14 +33,17 @@ *
      • {@link Long}
      • *
      • {@link String}
      • *
      + * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that hasu + * not actually been put in the {@code BloomFilter}. * - * The implementation is largely based on the {@code BloomFilter} class from guava. + * The implementation is largely based on the {@code BloomFilter} class from Guava. */ public abstract class BloomFilter { public enum Version { /** - * {@code BloomFilter} binary format version 1 (all values written in big-endian order): + * {@code BloomFilter} binary format version 1. All values written in big-endian order: *
        *
      • Version number, always 1 (32 bit)
      • *
      • Number of hash functions (32 bit)
      • @@ -68,14 +65,13 @@ int getVersionNumber() { } /** - * Returns the false positive probability, i.e. the probability that - * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that - * has not actually been put in the {@code BloomFilter}. + * Returns the probability that {@linkplain #mightContain(Object)} erroneously return {@code true} + * for an object that has not actually been put in the {@code BloomFilter}. * - *

        Ideally, this number should be close to the {@code fpp} parameter - * passed in to create this bloom filter, or smaller. If it is - * significantly higher, it is usually the case that too many elements (more than - * expected) have been put in the {@code BloomFilter}, degenerating it. + * Ideally, this number should be close to the {@code fpp} parameter passed in + * {@linkplain #create(long, double)}, or smaller. If it is significantly higher, it is usually + * the case that too many items (more than expected) have been put in the {@code BloomFilter}, + * degenerating it. */ public abstract double expectedFpp(); @@ -85,8 +81,8 @@ int getVersionNumber() { public abstract long bitSize(); /** - * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of - * {@link #mightContain(Object)} with the same element will always return {@code true}. + * Puts an item into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@linkplain #mightContain(Object)} with the same item will always return {@code true}. * * @return true if the bloom filter's bits changed as a result of this operation. If the bits * changed, this is definitely the first time {@code object} has been added to the @@ -98,19 +94,19 @@ int getVersionNumber() { public abstract boolean put(Object item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string. + * A specialized variant of {@link #put(Object)} that only supports {@code String} items. */ - public abstract boolean putString(String str); + public abstract boolean putString(String item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put long. + * A specialized variant of {@link #put(Object)} that only supports {@code long} items. */ - public abstract boolean putLong(long l); + public abstract boolean putLong(long item); /** - * A specialized variant of {@link #put(Object)}, that can only be used to put byte array. + * A specialized variant of {@link #put(Object)} that only supports byte array items. */ - public abstract boolean putBinary(byte[] bytes); + public abstract boolean putBinary(byte[] item); /** * Determines whether a given bloom filter is compatible with this bloom filter. For two @@ -137,38 +133,36 @@ int getVersionNumber() { public abstract boolean mightContain(Object item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8 - * string. + * A specialized variant of {@link #mightContain(Object)} that only tests {@code String} items. */ - public abstract boolean mightContainString(String str); + public abstract boolean mightContainString(String item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long. + * A specialized variant of {@link #mightContain(Object)} that only tests {@code long} items. */ - public abstract boolean mightContainLong(long l); + public abstract boolean mightContainLong(long item); /** - * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte - * array. + * A specialized variant of {@link #mightContain(Object)} that only tests byte array items. */ - public abstract boolean mightContainBinary(byte[] bytes); + public abstract boolean mightContainBinary(byte[] item); /** - * Writes out this {@link BloomFilter} to an output stream in binary format. - * It is the caller's responsibility to close the stream. + * Writes out this {@link BloomFilter} to an output stream in binary format. It is the caller's + * responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** - * Reads in a {@link BloomFilter} from an input stream. - * It is the caller's responsibility to close the stream. + * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close + * the stream. */ public static BloomFilter readFrom(InputStream in) throws IOException { return BloomFilterImpl.readFrom(in); } /** - * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the + * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. * * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. @@ -197,21 +191,31 @@ private static long optimalNumOfBits(long n, double p) { static final double DEFAULT_FPP = 0.03; /** - * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}. + * Creates a {@link BloomFilter} with the expected number of insertions and a default expected + * false positive probability of 3%. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. */ public static BloomFilter create(long expectedNumItems) { return create(expectedNumItems, DEFAULT_FPP); } /** - * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick - * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter. + * Creates a {@link BloomFilter} with the expected number of insertions and expected false + * positive probability. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. */ public static BloomFilter create(long expectedNumItems, double fpp) { - assert fpp > 0.0 : "False positive probability must be > 0.0"; - assert fpp < 1.0 : "False positive probability must be < 1.0"; - long numBits = optimalNumOfBits(expectedNumItems, fpp); - return create(expectedNumItems, numBits); + if (fpp <= 0D || fpp >= 1D) { + throw new IllegalArgumentException( + "False positive probability must be within range (0.0, 1.0)" + ); + } + + return create(expectedNumItems, optimalNumOfBits(expectedNumItems, fpp)); } /** @@ -219,9 +223,14 @@ public static BloomFilter create(long expectedNumItems, double fpp) { * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. */ public static BloomFilter create(long expectedNumItems, long numBits) { - assert expectedNumItems > 0 : "Expected insertions must be > 0"; - assert numBits > 0 : "number of bits must be > 0"; - int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); - return new BloomFilterImpl(numHashFunctions, numBits); + if (expectedNumItems <= 0) { + throw new IllegalArgumentException("Expected insertions must be positive"); + } + + if (numBits <= 0) { + throw new IllegalArgumentException("Number of bits must be positive"); + } + + return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 35107e0b389d7..92c28bcb56a5a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,9 +19,10 @@ import java.io.*; -public class BloomFilterImpl extends BloomFilter implements Serializable { +class BloomFilterImpl extends BloomFilter implements Serializable { private int numHashFunctions; + private BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { @@ -77,14 +78,14 @@ public boolean put(Object item) { } @Override - public boolean putString(String str) { - return putBinary(Utils.getBytesFromUTF8String(str)); + public boolean putString(String item) { + return putBinary(Utils.getBytesFromUTF8String(item)); } @Override - public boolean putBinary(byte[] bytes) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + public boolean putBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -100,14 +101,14 @@ public boolean putBinary(byte[] bytes) { } @Override - public boolean mightContainString(String str) { - return mightContainBinary(Utils.getBytesFromUTF8String(str)); + public boolean mightContainString(String item) { + return mightContainBinary(Utils.getBytesFromUTF8String(item)); } @Override - public boolean mightContainBinary(byte[] bytes) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1); + public boolean mightContainBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -124,14 +125,14 @@ public boolean mightContainBinary(byte[] bytes) { } @Override - public boolean putLong(long l) { + public boolean putLong(long item) { // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. // Note that `CountMinSketch` use a different strategy, it hash the input long element with // every i to produce n hash values. // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? - int h1 = Murmur3_x86_32.hashLong(l, 0); - int h2 = Murmur3_x86_32.hashLong(l, h1); + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -147,9 +148,9 @@ public boolean putLong(long l) { } @Override - public boolean mightContainLong(long l) { - int h1 = Murmur3_x86_32.hashLong(l, 0); - int h2 = Murmur3_x86_32.hashLong(l, h1); + public boolean mightContainLong(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -197,7 +198,7 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep throw new IncompatibleMergeException("Cannot merge null bloom filter"); } - if (!(other instanceof BloomFilter)) { + if (!(other instanceof BloomFilterImpl)) { throw new IncompatibleMergeException( "Cannot merge bloom filter of class " + other.getClass().getName() ); @@ -211,7 +212,8 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep if (this.numHashFunctions != that.numHashFunctions) { throw new IncompatibleMergeException( - "Cannot merge bloom filters with different number of hash functions"); + "Cannot merge bloom filters with different number of hash functions" + ); } this.bits.putAll(that.bits); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index f0aac5bb00dfb..48f98680f48ca 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -22,7 +22,7 @@ import java.io.OutputStream; /** - * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in + * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in * sub-linear space. Currently, supported data types include: *

          *
        • {@link Byte}
        • @@ -31,8 +31,7 @@ *
        • {@link Long}
        • *
        • {@link String}
        • *
        - * Each {@link CountMinSketch} is initialized with a random seed, and a pair - * of parameters: + * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: *
          *
        1. relative error (or {@code eps}), and *
        2. confidence (or {@code delta}) @@ -49,16 +48,13 @@ *
        3. {@code w = ceil(-log(1 - confidence) / log(2))}
        4. *
      * - * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details, - * including proofs of the estimates and error bounds used in this implementation. - * * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { public enum Version { /** - * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): + * {@code CountMinSketch} binary format version 1. All values written in big-endian order: *
        *
      • Version number, always 1 (32 bit)
      • *
      • Total count of added items (64 bit)
      • @@ -172,14 +168,14 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException; /** - * Writes out this {@link CountMinSketch} to an output stream in binary format. - * It is the caller's responsibility to close the stream. + * Writes out this {@link CountMinSketch} to an output stream in binary format. It is the caller's + * responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** - * Reads in a {@link CountMinSketch} from an input stream. - * It is the caller's responsibility to close the stream. + * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to + * close the stream. */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); @@ -188,6 +184,10 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { /** * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random * {@code seed}. + * + * @param depth depth of the Count-min Sketch, must be positive + * @param width width of the Count-min Sketch, must be positive + * @param seed random seed */ public static CountMinSketch create(int depth, int width, int seed) { return new CountMinSketchImpl(depth, width, seed); @@ -196,6 +196,10 @@ public static CountMinSketch create(int depth, int width, int seed) { /** * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence}, * and random {@code seed}. + * + * @param eps relative error, must be positive + * @param confidence confidence, must be positive and less than 1.0 + * @param seed random seed */ public static CountMinSketch create(double eps, double confidence, int seed) { return new CountMinSketchImpl(eps, confidence, seed); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index c0631c6778df4..2acbb247b13cd 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -42,6 +42,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { private CountMinSketchImpl() {} CountMinSketchImpl(int depth, int width, int seed) { + if (depth <= 0 || width <= 0) { + throw new IllegalArgumentException("Depth and width must be both positive"); + } + this.depth = depth; this.width = width; this.eps = 2.0 / width; @@ -50,6 +54,14 @@ private CountMinSketchImpl() {} } CountMinSketchImpl(double eps, double confidence, int seed) { + if (eps <= 0D) { + throw new IllegalArgumentException("Relative error must be positive"); + } + + if (confidence <= 0D || confidence >= 1D) { + throw new IllegalArgumentException("Confidence must be within range (0.0, 1.0)"); + } + // 2/w = eps ; w = 2/eps // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) this.eps = eps; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java index a6b33313035b0..feb601d44f39d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -19,7 +19,7 @@ import java.io.UnsupportedEncodingException; -public class Utils { +class Utils { public static byte[] getBytesFromUTF8String(String str) { try { return str.getBytes("utf-8"); From e38b0baa38c6894335f187eaa4c8ea5c02d4563b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 13:45:03 -0800 Subject: [PATCH 1741/2089] [SPARK-13055] SQLHistoryListener throws ClassCastException This is an existing issue uncovered recently by #10835. The reason for the exception was because the `SQLHistoryListener` gets all sorts of accumulators, not just the ones that represent SQL metrics. For example, the listener gets the `internal.metrics.shuffleRead.remoteBlocksFetched`, which is an Int, then it proceeds to cast the Int to a Long, which fails. The fix is to mark accumulators representing SQL metrics using some internal metadata. Then we can identify which ones are SQL metrics and only process those in the `SQLHistoryListener`. Author: Andrew Or Closes #10971 from andrewor14/fix-sql-history. --- .../scala/org/apache/spark/Accumulable.scala | 8 ++++ .../apache/spark/executor/TaskMetrics.scala | 4 +- .../spark/scheduler/AccumulableInfo.scala | 5 ++- .../apache/spark/scheduler/DAGScheduler.scala | 7 +--- .../org/apache/spark/util/JsonProtocol.scala | 6 ++- .../spark/executor/TaskMetricsSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 14 ++----- .../spark/scheduler/TaskSetManagerSuite.scala | 8 +--- .../apache/spark/util/JsonProtocolSuite.scala | 16 +++++--- .../sql/execution/metric/SQLMetrics.scala | 21 +++++++++- .../spark/sql/execution/ui/SQLListener.scala | 23 +++++++---- .../execution/metric/SQLMetricsSuite.scala | 24 +++++++++++- .../sql/execution/ui/SQLListenerSuite.scala | 38 ++++++++++++++++++- 13 files changed, 133 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala index 52f572b63fa95..601b503d12c7e 100644 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ b/core/src/main/scala/org/apache/spark/Accumulable.scala @@ -22,6 +22,7 @@ import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable import scala.reflect.ClassTag +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils @@ -187,6 +188,13 @@ class Accumulable[R, T] private ( */ private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) } + /** + * Create an [[AccumulableInfo]] representation of this [[Accumulable]] with the provided values. + */ + private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal, countFailedValues) + } + // Called by Java when deserializing an object private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 8d10bf588ef1f..0a6ebcb3e0293 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -323,8 +323,8 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { * field is always empty, since this represents the partial updates recorded in this task, * not the aggregated value across multiple tasks. */ - def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a => - new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues) + def accumulatorUpdates(): Seq[AccumulableInfo] = { + accums.map { a => a.toInfo(Some(a.localValue), None) } } // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 9d45fff9213c6..cedacad44afec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi * @param value total accumulated value so far, maybe None if used on executors to describe a task * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed + * @param metadata internal metadata associated with this accumulator, if any */ @DeveloperApi case class AccumulableInfo private[spark] ( @@ -43,7 +44,9 @@ case class AccumulableInfo private[spark] ( update: Option[Any], // represents a partial update within a task value: Option[Any], private[spark] val internal: Boolean, - private[spark] val countFailedValues: Boolean) + private[spark] val countFailedValues: Boolean, + // TODO: use this to identify internal task metrics instead of encoding it in the name + private[spark] val metadata: Option[String] = None) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 897479b50010d..ee0b8a1c95fd8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1101,11 +1101,8 @@ class DAGScheduler( acc ++= partialValue // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name - stage.latestInfo.accumulables(id) = new AccumulableInfo( - id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues) - event.taskInfo.accumulables += new AccumulableInfo( - id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues) + stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) + event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value)) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index dc8070cf8aad3..a2487eeb0483a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -290,7 +290,8 @@ private[spark] object JsonProtocol { ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~ ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~ ("Internal" -> accumulableInfo.internal) ~ - ("Count Failed Values" -> accumulableInfo.countFailedValues) + ("Count Failed Values" -> accumulableInfo.countFailedValues) ~ + ("Metadata" -> accumulableInfo.metadata) } /** @@ -728,7 +729,8 @@ private[spark] object JsonProtocol { val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) - new AccumulableInfo(id, name, update, value, internal, countFailedValues) + val metadata = (json \ "Metadata").extractOpt[String] + new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } /** diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 15be0b194ed8e..67c4595ed1923 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -551,8 +551,6 @@ private[spark] object TaskMetricsSuite extends Assertions { * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ - def makeInfo(a: Accumulable[_, _]): AccumulableInfo = { - new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues) - } + def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d9c71ec2eae7b..62972a0738211 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1581,12 +1581,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(Accumulators.get(acc1.id).isDefined) assert(Accumulators.get(acc2.id).isDefined) assert(Accumulators.get(acc3.id).isDefined) - val accInfo1 = new AccumulableInfo( - acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false) - val accInfo2 = new AccumulableInfo( - acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false) - val accInfo3 = new AccumulableInfo( - acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false) + val accInfo1 = acc1.toInfo(Some(15L), None) + val accInfo2 = acc2.toInfo(Some(13L), None) + val accInfo3 = acc3.toInfo(Some(18L), None) val accumUpdates = Seq(accInfo1, accInfo2, accInfo3) val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates) submit(new MyRDD(sc, 1, Nil), Array(0)) @@ -1954,10 +1951,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { val accumUpdates = reason match { - case Success => - task.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues) - } + case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) } case ef: ExceptionFailure => ef.accumUpdates case _ => Seq.empty[AccumulableInfo] } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index a2e74365641a6..2c99dd5afb32e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -165,9 +165,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) - } + val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have @@ -186,9 +184,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => - task.initialAccumulators.map { a => - new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues) - } + task.initialAccumulators.map { a => a.toInfo(Some(0L), None) } } // First three offers should all find tasks diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 57021d1d3d528..48951c3168032 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -374,15 +374,18 @@ class JsonProtocolSuite extends SparkFunSuite { test("AccumulableInfo backward compatibility") { // "Internal" property of AccumulableInfo was added in 1.5.1 val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true) - val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Internal" }) + val accumulableInfoJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) + val oldJson = accumulableInfoJson.removeField({ _._1 == "Internal" }) val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) assert(!oldInfo.internal) // "Count Failed Values" property of AccumulableInfo was added in 2.0.0 - val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Count Failed Values" }) + val oldJson2 = accumulableInfoJson.removeField({ _._1 == "Count Failed Values" }) val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2) assert(!oldInfo2.countFailedValues) + // "Metadata" property of AccumulableInfo was added in 2.0.0 + val oldJson3 = accumulableInfoJson.removeField({ _._1 == "Metadata" }) + val oldInfo3 = JsonProtocol.accumulableInfoFromJson(oldJson3) + assert(oldInfo3.metadata.isEmpty) } test("ExceptionFailure backward compatibility: accumulator updates") { @@ -820,9 +823,10 @@ private[spark] object JsonProtocolSuite extends Assertions { private def makeAccumulableInfo( id: Int, internal: Boolean = false, - countFailedValues: Boolean = false): AccumulableInfo = + countFailedValues: Boolean = false, + metadata: Option[String] = None): AccumulableInfo = new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), - internal, countFailedValues) + internal, countFailedValues, metadata) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 950dc7816241f..6b43d273fefde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.util.Utils /** @@ -27,9 +28,16 @@ import org.apache.spark.util.Utils * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( - name: String, val param: SQLMetricParam[R, T]) + name: String, + val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name), internal = true) { + // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later + override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, Some(name), update, value, isInternal, countFailedValues, + Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + def reset(): Unit = { this.value = param.zero } @@ -73,6 +81,14 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr // Although there is a boxing here, it's fine because it's only called in SQLListener override def value: Long = _value + + // Needed for SQLListenerSuite + override def equals(other: Any): Boolean = { + other match { + case o: LongSQLMetricValue => value == o.value + case _ => false + } + } } /** @@ -126,6 +142,9 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam( private[sql] object SQLMetrics { + // Identifier for distinguishing SQL metrics from other accumulators + private[sql] val ACCUM_IDENTIFIER = "sql" + private def createLongMetric( sc: SparkContext, name: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 544606f1168b6..835e7ba6c5168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -23,7 +23,7 @@ import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric._ import org.apache.spark.ui.SparkUI @DeveloperApi @@ -314,14 +314,17 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } + +/** + * A [[SQLListener]] for rendering the SQL UI in the history server. + */ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) extends SQLListener(conf) { private var sqlTabAttached = false - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - // Do nothing + override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = { + // Do nothing; these events are not logged } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { @@ -329,9 +332,15 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { a => - val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L)) - a.copy(update = Some(newValue)) + taskEnd.taskInfo.accumulables.flatMap { a => + // Filter out accumulators that are not SQL metrics + // For now we assume all SQL metrics are Long's that have been JSON serialized as String's + if (a.metadata.exists(_ == SQLMetrics.ACCUM_IDENTIFIER)) { + val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L)) + Some(a.copy(update = Some(newValue))) + } else { + None + } }, finishTask = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503c23..2260e4870299a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{JsonProtocol, Utils} class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { @@ -356,6 +356,28 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("metrics can be loaded by history server") { + val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam) + metric += 10L + val metricInfo = metric.toInfo(Some(metric.localValue), None) + metricInfo.update match { + case Some(v: LongSQLMetricValue) => assert(v.value === 10L) + case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}") + case _ => fail("metric update is missing") + } + assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + // After serializing to JSON, the original value type is lost, but we can still + // identify that it's a SQL metric from the metadata + val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo) + val metricInfoDeser = JsonProtocol.accumulableInfoFromJson(metricInfoJson) + metricInfoDeser.update match { + case Some(v: String) => assert(v.toLong === 10L) + case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}") + case _ => fail("deserialized metric update is missing") + } + assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 2c408c8878470..085e4a49a57e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.ui.SparkUI class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ @@ -335,8 +336,43 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } + test("SPARK-13055: history listener only tracks SQL metrics") { + val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI])) + // We need to post other events for the listener to track our accumulators. + // These are largely just boilerplate unrelated to what we're trying to test. + val df = createTestDataFrame + val executionStart = SparkListenerSQLExecutionStart( + 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0) + val stageInfo = createStageInfo(0, 0) + val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0)) + val stageSubmitted = SparkListenerStageSubmitted(stageInfo) + // This task has both accumulators that are SQL metrics and accumulators that are not. + // The listener should only track the ones that are actually SQL metrics. + val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella") + val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball") + val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None) + val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None) + val taskInfo = createTaskInfo(0, 0) + taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo) + val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) + listener.onOtherEvent(executionStart) + listener.onJobStart(jobStart) + listener.onStageSubmitted(stageSubmitted) + // Before SPARK-13055, this throws ClassCastException because the history listener would + // assume that the accumulator value is of type Long, but this may not be true for + // accumulators that are not SQL metrics. + listener.onTaskEnd(taskEnd) + val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics => + stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates) + } + // Listener tracks only SQL metrics, not other accumulators + assert(trackedAccums.size === 1) + assert(trackedAccums.head === sqlMetricInfo) + } + } + class SQLListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { From 2cbc412821641cf9446c0621ffa1976bd7fc4fa1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 29 Jan 2016 16:57:34 -0800 Subject: [PATCH 1742/2089] [SPARK-13076][SQL] Rename ClientInterface -> HiveClient And ClientWrapper -> HiveClientImpl. I have some followup pull requests to introduce a new internal catalog, and I think this new naming reflects better the functionality of the two classes. Author: Reynold Xin Closes #10981 from rxin/SPARK-13076. --- .../apache/spark/sql/execution/commands.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 10 +++++----- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- ...ClientInterface.scala => HiveClient.scala} | 10 +++++----- ...ientWrapper.scala => HiveClientImpl.scala} | 20 +++++++++---------- .../spark/sql/hive/client/HiveShim.scala | 5 ++--- .../hive/client/IsolatedClientLoader.scala | 18 ++++++++--------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- .../apache/spark/sql/hive/test/TestHive.scala | 4 ++-- .../spark/sql/hive/client/VersionsSuite.scala | 4 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 12 files changed, 41 insertions(+), 42 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/client/{ClientInterface.scala => HiveClient.scala} (95%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/client/{ClientWrapper.scala => HiveClientImpl.scala} (97%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 703e4643cbd25..c6adb583f931b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -404,7 +404,7 @@ case class DescribeFunction( result } - case None => Seq(Row(s"Function: $functionName is not found.")) + case None => Seq(Row(s"Function: $functionName not found.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa30e4..2b821c1056f56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -84,7 +84,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") } test("SPARK-6743: no columns from cache") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1797ea54f2501..05863ae18350d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -79,8 +79,8 @@ class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, listener: SQLListener, - @transient private val execHive: ClientWrapper, - @transient private val metaHive: ClientInterface, + @transient private val execHive: HiveClientImpl, + @transient private val metaHive: HiveClient, isRootContext: Boolean) extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { self => @@ -193,7 +193,7 @@ class HiveContext private[hive]( * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. */ @transient - protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) { + protected[hive] lazy val executionHive: HiveClientImpl = if (execHive != null) { execHive } else { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") @@ -203,7 +203,7 @@ class HiveContext private[hive]( config = newTemporaryConfiguration(useInMemoryDerby = true), isolationOn = false, baseClassLoader = Utils.getContextOrSparkClassLoader) - loader.createClient().asInstanceOf[ClientWrapper] + loader.createClient().asInstanceOf[HiveClientImpl] } /** @@ -222,7 +222,7 @@ class HiveContext private[hive]( * in the hive-site.xml file. */ @transient - protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) { + protected[hive] lazy val metadataHive: HiveClient = if (metaHive != null) { metaHive } else { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 848aa4ec6fe56..61d0d6759ff72 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -96,7 +96,7 @@ private[hive] object HiveSerDe { } } -private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) +private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) extends Catalog with Logging { val conf = hive.conf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala similarity index 95% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 4eec3fef7408b..f681cc67041a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -60,9 +60,9 @@ private[hive] case class HiveTable( viewText: Option[String] = None) { @transient - private[client] var client: ClientInterface = _ + private[client] var client: HiveClient = _ - private[client] def withClient(ci: ClientInterface): this.type = { + private[client] def withClient(ci: HiveClient): this.type = { client = ci this } @@ -85,7 +85,7 @@ private[hive] case class HiveTable( * internal and external classloaders for a given version of Hive and thus must expose only * shared classes. */ -private[hive] trait ClientInterface { +private[hive] trait HiveClient { /** Returns the Hive Version of this client. */ def version: HiveVersion @@ -184,8 +184,8 @@ private[hive] trait ClientInterface { /** Add a jar into class loader */ def addJar(path: String): Unit - /** Return a ClientInterface as new session, that will share the class loader and Hive client */ - def newSession(): ClientInterface + /** Return a [[HiveClient]] as new session, that will share the class loader and Hive client */ + def newSession(): HiveClient /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ def withHiveState[A](f: => A): A diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala similarity index 97% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 5307e924e7e55..cf1ff55c96fc9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.{CircularBuffer, Utils} * A class that wraps the HiveClient and converts its responses to externally visible classes. * Note that this class is typically loaded with an internal classloader for each instantiation, * allowing it to interact directly with a specific isolated version of Hive. Loading this class - * with the isolated classloader however will result in it only being visible as a ClientInterface, - * not a ClientWrapper. + * with the isolated classloader however will result in it only being visible as a [[HiveClient]], + * not a [[HiveClientImpl]]. * * This class needs to interact with multiple versions of Hive, but will always be compiled with * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility @@ -55,14 +55,14 @@ import org.apache.spark.util.{CircularBuffer, Utils} * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. * @param initClassLoader the classloader used when creating the `state` field of - * this ClientWrapper. + * this [[HiveClientImpl]]. */ -private[hive] class ClientWrapper( +private[hive] class HiveClientImpl( override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) - extends ClientInterface + extends HiveClient with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @@ -77,7 +77,7 @@ private[hive] class ClientWrapper( case hive.v1_2 => new Shim_v1_2() } - // Create an internal session state for this ClientWrapper. + // Create an internal session state for this HiveClientImpl. val state = { val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. @@ -160,7 +160,7 @@ private[hive] class ClientWrapper( case e: Exception if causedByThrift(e) => caughtException = e logWarning( - "HiveClientWrapper got thrift exception, destroying client and retrying " + + "HiveClient got thrift exception, destroying client and retrying " + s"(${retryLimit - numTries} tries remaining)", e) clientLoader.cachedHive = null Thread.sleep(retryDelayMillis) @@ -199,7 +199,7 @@ private[hive] class ClientWrapper( */ def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - // Set the thread local metastore client to the client associated with this ClientWrapper. + // Set the thread local metastore client to the client associated with this HiveClientImpl. Hive.set(client) // The classloader in clientLoader could be changed after addJar, always use the latest // classloader @@ -521,8 +521,8 @@ private[hive] class ClientWrapper( runSqlHive(s"ADD JAR $path") } - def newSession(): ClientWrapper = { - clientLoader.createClient().asInstanceOf[ClientWrapper] + def newSession(): HiveClientImpl = { + clientLoader.createClient().asInstanceOf[HiveClientImpl] } def reset(): Unit = withHiveState { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index ca636b0265d41..70c10be25be9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegralType, StringType} /** - * A shim that defines the interface between ClientWrapper and the underlying Hive library used to - * talk to the metastore. Each Hive version has its own implementation of this class, defining + * A shim that defines the interface between [[HiveClientImpl]] and the underlying Hive library used + * to talk to the metastore. Each Hive version has its own implementation of this class, defining * version-specific version of needed functions. * * The guideline for writing shims is: @@ -52,7 +52,6 @@ private[client] sealed abstract class Shim { /** * Set the current SessionState to the given SessionState. Also, set the context classloader of * the current thread to the one set in the HiveConf of this given `state`. - * @param state */ def setCurrentSessionState(state: SessionState): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 010051d255fdc..dca7396ee1ab4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -124,15 +124,15 @@ private[hive] object IsolatedClientLoader extends Logging { } /** - * Creates a Hive `ClientInterface` using a classloader that works according to the following rules: + * Creates a [[HiveClient]] using a classloader that works according to the following rules: * - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` - * allowing the results of calls to the `ClientInterface` to be visible externally. + * allowing the results of calls to the [[HiveClient]] to be visible externally. * - Hive classes: new instances are loaded from `execJars`. These classes are not * accessible externally due to their custom loading. - * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`. + * - [[HiveClientImpl]]: a new copy is created for each instance of `IsolatedClassLoader`. * This new instance is able to see a specific version of hive without using reflection. Since * this is a unique instance, it is not visible externally other than as a generic - * `ClientInterface`, unless `isolationOn` is set to `false`. + * [[HiveClient]], unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures * that are not compatible across versions. @@ -179,7 +179,7 @@ private[hive] class IsolatedClientLoader( /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = - name.startsWith(classOf[ClientWrapper].getName) || + name.startsWith(classOf[HiveClientImpl].getName) || name.startsWith(classOf[Shim].getName) || barrierPrefixes.exists(name.startsWith) @@ -233,9 +233,9 @@ private[hive] class IsolatedClientLoader( } /** The isolated client interface to Hive. */ - private[hive] def createClient(): ClientInterface = { + private[hive] def createClient(): HiveClient = { if (!isolationOn) { - return new ClientWrapper(version, config, baseClassLoader, this) + return new HiveClientImpl(version, config, baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -244,10 +244,10 @@ private[hive] class IsolatedClientLoader( try { classLoader - .loadClass(classOf[ClientWrapper].getName) + .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head .newInstance(version, config, classLoader, this) - .asInstanceOf[ClientInterface] + .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => if (e.getCause().isInstanceOf[NoClassDefFoundError]) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 56cab1aee89df..d5ed838ca4b1a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -38,13 +38,13 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.client.ClientWrapper +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ private[hive] class HiveFunctionRegistry( underlying: analysis.FunctionRegistry, - executionHive: ClientWrapper) + executionHive: HiveClientImpl) extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a33223af24370..246108e0d0e11 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.client.ClientWrapper +import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -458,7 +458,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) } -private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper) +private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl) extends HiveFunctionRegistry(fr, client) { private val removedFunctions = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff10a251f3b45..1344a2cc4bd37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * A simple set of tests that call the methods of a [[HiveClient]], loading different version * of hive from maven central. These tests are simple in that they are mostly just testing to make * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. @@ -101,7 +101,7 @@ class VersionsSuite extends SparkFunSuite with Logging { private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") - private var client: ClientInterface = null + private var client: HiveClient = null versions.foreach { version => test(s"$version: create client") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 0d62d799c8dce..1ada2e325bda6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -199,7 +199,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") checkExistence(sql("describe functioN `~`"), true, "Function: ~", From e6ceac49a311faf3413acda57a6612fe806adf90 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 17:59:41 -0800 Subject: [PATCH 1743/2089] [SPARK-13096][TEST] Fix flaky verifyPeakExecutionMemorySet Previously we would assert things before all events are guaranteed to have been processed. To fix this, just block until all events are actually processed, i.e. until the listener queue is empty. https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/79/testReport/junit/org.apache.spark.util.collection/ExternalAppendOnlyMapSuite/spilling/ Author: Andrew Or Closes #10990 from andrewor14/accum-suite-less-flaky. --- core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 11c97d7d9a447..b8f2b96d7088d 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -307,6 +307,8 @@ private[spark] object AccumulatorSuite { val listener = new SaveInfoListener sc.addSparkListener(listener) testBody + // wait until all events have been processed before proceeding to assert things + sc.listenerBus.waitUntilEmpty(10 * 1000) val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) val isSet = accums.exists { a => a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L) From 70e69fc4dd619654f5d24b8b84f6a94f7705c59b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 29 Jan 2016 18:00:49 -0800 Subject: [PATCH 1744/2089] [SPARK-13088] Fix DAG viz in latest version of chrome Apparently chrome removed `SVGElement.prototype.getTransformToElement`, which is used by our JS library dagre-d3 when creating edges. The real diff can be found here: https://github.com/andrewor14/dagre-d3/commit/7d6c0002e4c74b82a02c5917876576f71e215590, which is taken from the fix in the main repo: https://github.com/cpettitt/dagre-d3/commit/1ef067f1c6ad2e0980f6f0ca471bce998784b7b2 Upstream issue: https://github.com/cpettitt/dagre-d3/issues/202 Author: Andrew Or Closes #10986 from andrewor14/fix-dag-viz. --- .../org/apache/spark/ui/static/dagre-d3.min.js | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index 2d9262b972a59..6fe8136c87ae0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -1,4 +1,5 @@ -/* This is a custom version of dagre-d3 on top of v0.4.3. The full list of commits can be found at http://github.com/andrewor14/dagre-d3/ */!function(e){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=e();else if("function"==typeof define&&define.amd)define([],e);else{var f;"undefined"!=typeof window?f=window:"undefined"!=typeof global?f=global:"undefined"!=typeof self&&(f=self),f.dagreD3=e()}}(function(){var define,module,exports;return function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ -parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdx0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){ +normalize.run(g)});time(" parentDummyChains",function(){parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎ â€â€‚         âŸã€€";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}; +})}function enterEdge(t,g,edge){var v=edge.v,w=edge.w;if(!g.hasEdge(v,w)){v=edge.w;w=edge.v}var vLabel=t.node(v),wLabel=t.node(w),tailLabel=vLabel,flip=false;if(vLabel.lim>wLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎ â€â€‚         âŸã€€";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){ +return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index Date: Fri, 29 Jan 2016 18:03:04 -0800 Subject: [PATCH 1745/2089] [SPARK-13071] Coalescing HadoopRDD overwrites existing input metrics This issue is causing tests to fail consistently in master with Hadoop 2.6 / 2.7. This is because for Hadoop 2.5+ we overwrite existing values of `InputMetrics#bytesRead` in each call to `HadoopRDD#compute`. In the case of coalesce, e.g. ``` sc.textFile(..., 4).coalesce(2).count() ``` we will call `compute` multiple times in the same task, overwriting `bytesRead` values from previous calls to `compute`. For a regression test, see `InputOutputMetricsSuite.input metrics for old hadoop with coalesce`. I did not add a new regression test because it's impossible without significant refactoring; there's a lot of existing duplicate code in this corner of Spark. This was caused by #10835. Author: Andrew Or Closes #10973 from andrewor14/fix-input-metrics-coalesce. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 7 ++++++- .../src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 7 ++++++- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 7 ++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 3204e6adceca2..e2ebd7f00d0d5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,6 +215,7 @@ class HadoopRDD[K, V]( // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { @@ -230,9 +231,13 @@ class HadoopRDD[K, V]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 4d2816e335fe3..e71d3405c0ead 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -130,6 +130,7 @@ class NewHadoopRDD[K, V]( val conf = getConf val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -139,9 +140,13 @@ class NewHadoopRDD[K, V]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index edd87c2d8ed07..9703b16c86f90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -127,6 +127,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( val conf = getConf(isDriverSide = false) val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { @@ -142,9 +143,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( case _ => None } + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(getBytesRead()) + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } From e6a02c66d53f59ba2d5c1548494ae80a385f9f5c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 20:16:11 -0800 Subject: [PATCH 1746/2089] [SPARK-12914] [SQL] generate aggregation with grouping keys This PR add support for grouping keys for generated TungstenAggregate. Spilling and performance improvements for BytesToBytesMap will be done by followup PR. Author: Davies Liu Closes #10855 from davies/gen_keys. --- .../expressions/codegen/CodeGenerator.scala | 47 ++++ .../codegen/GenerateMutableProjection.scala | 27 +- .../sql/execution/BufferedRowIterator.java | 6 +- .../aggregate/TungstenAggregate.scala | 238 ++++++++++++++++-- .../BenchmarkWholeStageCodegen.scala | 119 ++++++++- .../execution/WholeStageCodegenSuite.scala | 9 + 6 files changed, 393 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb1f7..21f9198073d74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,20 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ + def addReferenceObj(name: String, obj: Any, className: String = null): String = { + val term = freshName(name) + val idx = references.length + references += obj + val clsName = Option(className).getOrElse(obj.getClass.getName) + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. @@ -198,6 +212,39 @@ class CodegenContext { } } + /** + * Update a column in MutableRow from ExprCode. + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (dataType.isInstanceOf[DecimalType]) { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + ${setColumn(row, dataType, ordinal, "null")}; + } + """ + } else { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + $row.setNullAt($ordinal); + } + """ + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + /** * Returns the name used in accessor and setter for a Java primitive type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b94b8..5b4dc8df8622b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val updates = validExpr.zip(index).map { case (e, i) => - if (e.nullable) { - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } else { - s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } - } else { - s""" - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - """ - } - + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index b1bbb1da10a39..6acf70dbbad0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution; +import java.io.IOException; + import scala.collection.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -34,7 +36,7 @@ public class BufferedRowIterator { // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); - public boolean hasNext() { + public boolean hasNext() throws IOException { if (currentRow == null) { processNext(); } @@ -56,7 +58,7 @@ public void setInput(Iterator iter) { * * After it's called, if currentRow is still null, it means no more rows left. */ - protected void processNext() { + protected void processNext() throws IOException { if (input.hasNext()) { currentRow = input.next(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index ff2f38bfd9105..57db7262fdaf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -114,22 +116,38 @@ case class TungstenAggregate( } } + // all the mode of aggregate expressions + private val modes = aggregateExpressions.map(_.mode).distinct + override def supportCodegen: Boolean = { - groupingExpressions.isEmpty && - // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - // The variables used as aggregation buffer - private var bufVars: Seq[ExprCode] = _ - - private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { child.asInstanceOf[CodegenSupport].upstream() } protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -176,10 +194,10 @@ case class TungstenAggregate( (resultVars, resultVars.map(_.code).mkString("\n")) } - val doAgg = ctx.freshName("doAgg") + val doAgg = ctx.freshName("doAggregateWithoutKey") ctx.addNewFunction(doAgg, s""" - | private void $doAgg() { + | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | @@ -200,7 +218,7 @@ case class TungstenAggregate( """.stripMargin } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output @@ -212,7 +230,6 @@ case class TungstenAggregate( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) @@ -232,6 +249,199 @@ case class TungstenAggregate( """.stripMargin } + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) + private val bufferSchema = StructType.fromAttributes(bufferAttributes) + + // The name for HashMap + private var hashMapTerm: String = _ + + /** + * This is called by generated Java class, should be public. + */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + + /** + * Update peak execution memory, called in generated Java class. + */ + def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val metrics = TaskContext.get().taskMetrics() + metrics.incPeakExecutionMemory(mapMemory) + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + ctx.INPUT_ROW = bufferTerm + val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttributes).gen(ctx) + } + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).gen(ctx) + } + s""" + ${keyVars.map(_.code).mkString("\n")} + ${bufferVars.map(_.code).mkString("\n")} + ${aggResults.map(_.code).mkString("\n")} + ${resultVars.map(_.code).mkString("\n")} + + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + val resultRow = ctx.freshName("resultRow") + s""" + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ + + } else { + // generate result based on grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).gen(ctx) + } + s""" + ${eval.map(_.code).mkString("\n")} + ${consume(ctx, eval)} + """ + } + + val doAgg = ctx.freshName("doAggregateWithKeys") + ctx.addNewFunction(doAgg, + s""" + private void $doAgg() throws java.io.IOException { + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + $iterTerm = $hashMapTerm.iterator(); + } + """) + + s""" + if (!$initAgg) { + $initAgg = true; + $doAgg(); + } + + // output the result + while ($iterTerm.next()) { + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + } + + $thisPlan.updatePeakMemory($hashMapTerm); + $hashMapTerm.free(); + """ + } + + private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + + // create grouping key + ctx.currentVars = input + val keyCode = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val key = keyCode.value + val buffer = ctx.freshName("aggBuffer") + + // only have DeclarativeAggregate + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + + val inputAttr = bufferAttributes ++ child.output + ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input + ctx.INPUT_ROW = buffer + // TODO: support subexpression elimination + val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) + val updates = evals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) + } + + s""" + // generate grouping key + ${keyCode.code} + UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } + + // evaluate aggregate function + ${evals.map(_.code).mkString("\n")} + // update aggregate buffer + ${updates.mkString("\n")} + """ + } + override def simpleString: String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index c4aad398bfa54..2f09c8a114bc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** @@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" */ class BenchmarkWholeStageCodegen extends SparkFunSuite { - def testWholeStage(values: Int): Unit = { - val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") - val sc = SparkContext.getOrCreate(conf) - val sqlContext = SQLContext.getOrCreate(sc) + lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + lazy val sc = SparkContext.getOrCreate(conf) + lazy val sqlContext = SQLContext.getOrCreate(sc) - val benchmark = new Benchmark("Single Int Column Scan", values) + def testWholeStage(values: Int): Unit = { + val benchmark = new Benchmark("rang/filter/aggregate", values) - benchmark.addCase("Without whole stage codegen") { iter => + benchmark.addCase("Without codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "false") sqlContext.range(values).filter("(id & 1) = 1").count() } - benchmark.addCase("With whole stage codegen") { iter => + benchmark.addCase("With codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "true") sqlContext.range(values).filter("(id & 1) = 1").count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- - Without whole stage codegen 7775.53 26.97 1.00 X - With whole stage codegen 342.15 612.94 22.73 X + Without codegen 7775.53 26.97 1.00 X + With codegen 342.15 612.94 22.73 X */ benchmark.run() } - ignore("benchmark") { - testWholeStage(1024 * 1024 * 200) + def testAggregateWithKey(values: Int): Unit = { + val benchmark = new Benchmark("Aggregate with keys", values) + + benchmark.addCase("Aggregate w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + benchmark.addCase(s"Aggregate w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Aggregate w/o codegen 4254.38 4.93 1.00 X + Aggregate w codegen 2661.45 7.88 1.60 X + */ + benchmark.run() + } + + def testBytesToBytesMap(values: Int): Unit = { + val benchmark = new Benchmark("BytesToBytesMap", values) + + benchmark.addCase("hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < values) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) + s += h + i += 1 + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + while (i < values) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + if (loc.isDefined) { + value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, + loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { + loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + } + } + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + hash 662.06 79.19 1.00 X + BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X + BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + */ + benchmark.run() + } + + test("benchmark") { + // testWholeStage(1024 * 1024 * 200) + // testAggregateWithKey(20 << 20) + // testBytesToBytesMap(1024 * 1024 * 50) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 300788c88ab2f..c2516509dfbbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } + + test("Aggregate with grouping keys should be included in WholeStageCodegen") { + val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + } } From dab246f7e4664d36073ec49d9df8a11c5e998cdb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 23:37:51 -0800 Subject: [PATCH 1747/2089] [SPARK-13098] [SQL] remove GenericInternalRowWithSchema This class is only used for serialization of Python DataFrame. However, we don't require internal row there, so `GenericRowWithSchema` can also do the job. Author: Wenchen Fan Closes #10992 from cloud-fan/python. --- .../spark/sql/catalyst/expressions/rows.scala | 12 ------------ .../org/apache/spark/sql/execution/python.scala | 13 +++++-------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 387d979484f2c..be6b2530ef39e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -233,18 +233,6 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override def copy(): GenericInternalRow = this } -/** - * This is used for serialization of Python DataFrame - */ -class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) - extends GenericInternalRow(values) { - - /** No-arg constructor for serialization. */ - protected def this() = this(null, null) - - def fieldIndex(name: String): Int = schema.fieldIndex(name) -} - class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index e3a016e18db87..bf62bb05c3d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -143,7 +143,7 @@ object EvaluatePython { values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) i += 1 } - new GenericInternalRowWithSchema(values, struct) + new GenericRowWithSchema(values, struct) case (a: ArrayData, array: ArrayType) => val values = new java.util.ArrayList[Any](a.numElements()) @@ -199,10 +199,7 @@ object EvaluatePython { case (c: Long, TimestampType) => c - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) + case (c, StringType) => UTF8String.fromString(c.toString) case (c: String, BinaryType) => c.getBytes("utf-8") case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c @@ -263,11 +260,11 @@ object EvaluatePython { } /** - * Pickler for InternalRow + * Pickler for external row. */ private class RowPickler extends IObjectPickler { - private val cls = classOf[GenericInternalRowWithSchema] + private val cls = classOf[GenericRowWithSchema] // register this to Pickler and Unpickler def register(): Unit = { @@ -282,7 +279,7 @@ object EvaluatePython { } else { // it will be memorized by Pickler to save some bytes pickler.save(this) - val row = obj.asInstanceOf[GenericInternalRowWithSchema] + val row = obj.asInstanceOf[GenericRowWithSchema] // schema should always be same object for memoization pickler.save(row.schema) out.write(Opcodes.TUPLE1) From 289373b28cd2546165187de2e6a9185a1257b1e7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 30 Jan 2016 00:20:28 -0800 Subject: [PATCH 1748/2089] [SPARK-6363][BUILD] Make Scala 2.11 the default Scala version This patch changes Spark's build to make Scala 2.11 the default Scala version. To be clear, this does not mean that Spark will stop supporting Scala 2.10: users will still be able to compile Spark for Scala 2.10 by following the instructions on the "Building Spark" page; however, it does mean that Scala 2.11 will be the default Scala version used by our CI builds (including pull request builds). The Scala 2.11 compiler is faster than 2.10, so I think we'll be able to look forward to a slight speedup in our CI builds (it looks like it's about 2X faster for the Maven compile-only builds, for instance). After this patch is merged, I'll update Jenkins to add new compile-only jobs to ensure that Scala 2.10 compilation doesn't break. Author: Josh Rosen Closes #10608 from JoshRosen/SPARK-6363. --- assembly/pom.xml | 4 +-- common/sketch/pom.xml | 4 +-- core/pom.xml | 4 +-- dev/create-release/release-build.sh | 14 ++++----- dev/deps/spark-deps-hadoop-2.2 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.3 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.4 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.6 | 31 +++++++++---------- dev/deps/spark-deps-hadoop-2.7 | 31 +++++++++---------- docker-integration-tests/pom.xml | 4 +-- docs/_plugins/copy_api_dirs.rb | 2 +- docs/building-spark.md | 10 +++--- examples/pom.xml | 4 +-- external/akka/pom.xml | 4 +-- external/flume-assembly/pom.xml | 4 +-- external/flume-sink/pom.xml | 4 +-- external/flume/pom.xml | 4 +-- external/kafka-assembly/pom.xml | 4 +-- external/kafka/pom.xml | 4 +-- external/mqtt-assembly/pom.xml | 4 +-- external/mqtt/pom.xml | 4 +-- external/twitter/pom.xml | 4 +-- external/zeromq/pom.xml | 4 +-- extras/java8-tests/pom.xml | 4 +-- extras/kinesis-asl-assembly/pom.xml | 4 +-- extras/kinesis-asl/pom.xml | 4 +-- extras/spark-ganglia-lgpl/pom.xml | 4 +-- graphx/pom.xml | 4 +-- launcher/pom.xml | 4 +-- mllib/pom.xml | 4 +-- network/common/pom.xml | 4 +-- network/shuffle/pom.xml | 4 +-- network/yarn/pom.xml | 4 +-- pom.xml | 8 ++--- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 6 ++++ project/SparkBuild.scala | 12 +++---- repl/pom.xml | 8 ++--- .../scala/org/apache/spark/repl/Main.scala | 9 +++++- .../org/apache/spark/repl/ReplSuite.scala | 7 +---- sql/catalyst/pom.xml | 13 ++------ sql/core/pom.xml | 6 ++-- sql/hive-thriftserver/pom.xml | 4 +-- sql/hive/pom.xml | 4 +-- streaming/pom.xml | 4 +-- tags/pom.xml | 4 +-- tools/pom.xml | 4 +-- unsafe/pom.xml | 4 +-- yarn/pom.xml | 4 +-- 49 files changed, 186 insertions(+), 194 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 6c79f9189787d..477d4931c3a88 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-assembly_2.10 + spark-assembly_2.11 Spark Project Assembly http://spark.apache.org/ pom diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2cafe8c548f5f..442043cb51164 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sketch_2.10 + spark-sketch_2.11 jar Spark Project Sketch http://spark.apache.org/ diff --git a/core/pom.xml b/core/pom.xml index 0ab170e028ab4..be40d9936afd7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-core_2.10 + spark-core_2.11 core diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 00bf81120df65..2fd7fcc39ea28 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -134,9 +134,9 @@ if [[ "$1" == "package" ]]; then cd spark-$SPARK_VERSION-bin-$NAME - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 + # TODO There should probably be a flag to make-distribution to allow 2.10 support + if [[ $FLAGS == *scala-2.10* ]]; then + ./dev/change-scala-version.sh 2.10 fi export ZINC_PORT=$ZINC_PORT @@ -228,8 +228,8 @@ if [[ "$1" == "publish-snapshot" ]]; then $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver deploy - ./dev/change-scala-version.sh 2.11 - $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \ -DskipTests $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process @@ -266,9 +266,9 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver clean install - ./dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.10 - $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ -DskipTests $PUBLISH_PROFILES clean install # Clean-up Zinc nailgun process diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 4d9937c5cbc34..3a14499d9b4d9 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -14,13 +14,13 @@ avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -86,10 +86,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar javax.servlet-3.1.jar @@ -111,15 +110,14 @@ jets3t-0.7.1.jar jettison-1.1.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -158,19 +156,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index fd659ee20df1a..615836b3d3b77 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -102,15 +101,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -149,19 +147,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index afae3deb9ada2..f275226f1d088 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -103,15 +102,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -150,19 +148,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 5a6460136a3a0..21432a16e3659 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -109,15 +108,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -156,19 +154,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 70083e7f3d16a..20e09cd002635 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.10-0.11.2.jar -breeze_2.10-0.11.2.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.5.0.jar -chill_2.10-0.5.0.jar +chill_2.11-0.5.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.5.3.jar +jackson-module-scala_2.11-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar -jansi-1.4.jar java-xmlbuilder-1.0.jar javax.inject-1.jar javax.servlet-3.0.0.v201112011016.jar @@ -109,15 +108,14 @@ jettison-1.1.jar jetty-6.1.26.jar jetty-all-7.6.0.v20120127.jar jetty-util-6.1.26.jar -jline-2.10.5.jar jline-2.12.jar joda-time-2.9.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.10-3.2.10.jar -json4s-core_2.10-3.2.10.jar -json4s-jackson_2.10-3.2.10.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -157,19 +155,20 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.1.jar pyrolite-4.9.jar -quasiquotes_2.10-2.0.0-M8.jar reflectasm-1.07-shaded.jar -scala-compiler-2.10.5.jar -scala-library-2.10.5.jar -scala-reflect-2.10.5.jar -scalap-2.10.5.jar +scala-compiler-2.11.7.jar +scala-library-2.11.7.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.7.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.7.jar servlet-api-2.5.jar slf4j-api-1.7.10.jar slf4j-log4j12-1.7.10.jar snappy-0.2.jar snappy-java-1.1.2.jar -spire-macros_2.10-0.7.4.jar -spire_2.10-0.7.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index 78b638ecfa638..833ca29cd8218 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml - spark-docker-integration-tests_2.10 + spark-docker-integration-tests_2.11 jar Spark Project Docker Integration Tests http://spark.apache.org/ diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 174c202e37918..f926d67e6beaf 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.10/unidoc" + source = "../target/scala-2.11/unidoc" dest = "api/scala" puts "Making directory " + dest diff --git a/docs/building-spark.md b/docs/building-spark.md index e1abcf1be501d..975e1b295c8ae 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -114,13 +114,11 @@ By default Spark will build with Hive 0.13.1 bindings. mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package {% endhighlight %} -# Building for Scala 2.11 -To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: +# Building for Scala 2.10 +To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property: - ./dev/change-scala-version.sh 2.11 - mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package - -Spark does not yet support its JDBC component for Scala 2.11. + ./dev/change-scala-version.sh 2.10 + mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package # Spark Tests in Maven diff --git a/examples/pom.xml b/examples/pom.xml index 9437cee2abfdf..82baa9085b4f9 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-examples_2.10 + spark-examples_2.11 examples diff --git a/external/akka/pom.xml b/external/akka/pom.xml index 06c8e8aaabd8c..bbe644e3b32b3 100644 --- a/external/akka/pom.xml +++ b/external/akka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-akka_2.10 + spark-streaming-akka_2.11 streaming-akka diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index b2c377fe4cc9b..ac15b93c048da 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-assembly_2.10 + spark-streaming-flume-assembly_2.11 jar Spark Project External Flume Assembly http://spark.apache.org/ diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 4b6485ee0a71a..e4effe158c826 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-sink_2.10 + spark-streaming-flume-sink_2.11 streaming-flume-sink diff --git a/external/flume/pom.xml b/external/flume/pom.xml index a79656c6f7d96..d650dd034d636 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume_2.10 + spark-streaming-flume_2.11 streaming-flume diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0c466b3c4ac37..62818f5e8f434 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka-assembly_2.10 + spark-streaming-kafka-assembly_2.11 jar Spark Project External Kafka Assembly http://spark.apache.org/ diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 5180ab6dbafbd..68d52e9339b3d 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka_2.10 + spark-streaming-kafka_2.11 streaming-kafka diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index c4a1ae26ea699..ac2a3f65ed2f5 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-mqtt-assembly_2.10 + spark-streaming-mqtt-assembly_2.11 jar Spark Project External MQTT Assembly http://spark.apache.org/ diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index d3a2bf5825b08..d0d968782c7f1 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-mqtt_2.10 + spark-streaming-mqtt_2.11 streaming-mqtt diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 7b628b09ea6a5..5d4053afcbba7 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-twitter_2.10 + spark-streaming-twitter_2.11 streaming-twitter diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7781aaeed9e0c..f16bc0f319744 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-zeromq_2.10 + spark-streaming-zeromq_2.11 streaming-zeromq diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4dfe3b654df1a..0ad9c5303a36a 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - java8-tests_2.10 + java8-tests_2.11 pom Spark Project Java8 Tests POM diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 601080c2e6fbd..d1c38c7ca5d69 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kinesis-asl-assembly_2.10 + spark-streaming-kinesis-asl-assembly_2.11 jar Spark Project Kinesis Assembly http://spark.apache.org/ diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 20e2c5e0ffbee..935155eb5d362 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -19,14 +19,14 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kinesis-asl_2.10 + spark-streaming-kinesis-asl_2.11 jar Spark Kinesis Integration diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index b046a10a04d5b..bfb92791de3d8 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -19,14 +19,14 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-ganglia-lgpl_2.10 + spark-ganglia-lgpl_2.11 jar Spark Ganglia Integration diff --git a/graphx/pom.xml b/graphx/pom.xml index 388a0ef06a2b0..1813f383cdcba 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-graphx_2.10 + spark-graphx_2.11 graphx diff --git a/launcher/pom.xml b/launcher/pom.xml index 135866cea2e74..ef731948826ef 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-launcher_2.10 + spark-launcher_2.11 jar Spark Project Launcher http://spark.apache.org/ diff --git a/mllib/pom.xml b/mllib/pom.xml index 42af2b8b3e411..816f3f6830382 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-mllib_2.10 + spark-mllib_2.11 mllib diff --git a/network/common/pom.xml b/network/common/pom.xml index eda2b7307088f..bd507c2cb6c4b 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-common_2.10 + spark-network-common_2.11 jar Spark Project Networking http://spark.apache.org/ diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index f9aa7e2dd1f43..810ec10ca05b3 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_2.11 jar Spark Project Shuffle Streaming Service http://spark.apache.org/ diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a19cbb04b18c6..a28785b16e1e6 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-network-yarn_2.10 + spark-network-yarn_2.11 jar Spark Project YARN Shuffle Service http://spark.apache.org/ diff --git a/pom.xml b/pom.xml index fb7750602c425..d0387aca66d0d 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 14 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT pom Spark Project Parent POM @@ -165,7 +165,7 @@ 3.2.2 2.10.5 - 2.10 + 2.11 ${scala.version} org.scala-lang 1.9.13 @@ -2456,7 +2456,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 2.10.5 @@ -2488,7 +2488,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 2.11.7 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 41856443af49b..4adf64a5a0d86 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -95,7 +95,7 @@ object MimaBuild { // because spark-streaming-mqtt(1.6.0) depends on it. // Remove the setting on updating previousSparkVersion. val previousSparkVersion = "1.6.0" - val fullId = "spark-" + projectRef.project + "_2.10" + val fullId = "spark-" + projectRef.project + "_2.11" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a3ae4d2b730ff..3748e07f88aad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -220,6 +220,12 @@ object MimaExcludes { // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") + ) ++ Seq( + // SPARK-6363 Make Scala 2.11 the default Scala version + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") ) case v if v.startsWith("1.6") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4224a65a822b8..550b5bad8a46a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -119,11 +119,11 @@ object SparkBuild extends PomBuild { v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - if (System.getProperty("scala-2.11") == "") { - // To activate scala-2.11 profile, replace empty property value to non-empty value + if (System.getProperty("scala-2.10") == "") { + // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.11", "true") + System.setProperty("scala-2.10", "true") } profiles } @@ -382,7 +382,7 @@ object OldDeps { lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" + val fullId = id + "_2.11" Some("org.apache.spark" % fullId % "1.2.0") } @@ -390,7 +390,7 @@ object OldDeps { name := "old-deps", scalaVersion := "2.10.5", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming-flume", "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-graphx", "spark-core").map(versionArtifact(_).get intransitive()) ) @@ -704,7 +704,7 @@ object Java8TestSettings { lazy val settings = Seq( javacJVMVersion := "1.8", // Targeting Java 8 bytecode is only supported in Scala 2.11.4 and higher: - scalacJVMVersion := (if (System.getProperty("scala-2.11") == "true") "1.8" else "1.7") + scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8") ) } diff --git a/repl/pom.xml b/repl/pom.xml index efc3dd452e329..0f396c9b809bd 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-repl_2.10 + spark-repl_2.11 jar Spark Project REPL http://spark.apache.org/ @@ -159,7 +159,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 @@ -173,7 +173,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 scala-2.11/src/main/scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index bb3081d12938e..07ba28bb07545 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -33,7 +33,8 @@ object Main extends Logging { var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ - var interp = new SparkILoop // this is a public var because tests reset it. + // this is a public var because tests reset it. + var interp: SparkILoop = _ private var hasErrors = false @@ -43,6 +44,12 @@ object Main extends Logging { } def main(args: Array[String]) { + doMain(args, new SparkILoop) + } + + // Visible for testing + private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { + interp = _interp val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 63f3688c9e612..b9ed79da421a6 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -50,12 +50,7 @@ class ReplSuite extends SparkFunSuite { System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) System.setProperty("spark.master", master) - val interp = { - new SparkILoop(in, new PrintWriter(out)) - } - org.apache.spark.repl.Main.interp = interp - Main.main(Array("-classpath", classpath)) // call main - org.apache.spark.repl.Main.interp = null + Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) if (oldExecutorClasspath != null) { System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 76ca3f3bb1bfa..c2ad9b99f3ac9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-catalyst_2.10 + spark-catalyst_2.11 jar Spark Project Catalyst http://spark.apache.org/ @@ -127,13 +127,4 @@ - - - - scala-2.10 - - !scala-2.11 - - - diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 4bb55f6b7f739..89e01fc01596e 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql_2.10 + spark-sql_2.11 jar Spark Project SQL http://spark.apache.org/ @@ -44,7 +44,7 @@ org.apache.spark - spark-sketch_2.10 + spark-sketch_2.11 ${project.version} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 435e565f63458..c8d17bd468582 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive-thriftserver_2.10 + spark-hive-thriftserver_2.11 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index cd0c2aeb93a9f..14cf9acf09d5b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive_2.10 + spark-hive_2.11 jar Spark Project Hive http://spark.apache.org/ diff --git a/streaming/pom.xml b/streaming/pom.xml index 39cbd0d00f951..7d409c5d3b076 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-streaming_2.10 + spark-streaming_2.11 streaming diff --git a/tags/pom.xml b/tags/pom.xml index 9e4610dae7a65..3e8e6f6182875 100644 --- a/tags/pom.xml +++ b/tags/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-test-tags_2.10 + spark-test-tags_2.11 jar Spark Project Test Tags http://spark.apache.org/ diff --git a/tools/pom.xml b/tools/pom.xml index 30cbb6a5a59c7..b3a5ae2771241 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-tools_2.10 + spark-tools_2.11 tools diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 21fef3415adce..75fea556eeae1 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-unsafe_2.10 + spark-unsafe_2.11 jar Spark Project Unsafe http://spark.apache.org/ diff --git a/yarn/pom.xml b/yarn/pom.xml index a8c122fd40a1f..328bb6678db99 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 + spark-parent_2.11 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-yarn_2.10 + spark-yarn_2.11 jar Spark Project YARN From de283719980ae78b740e507e4d70c7ebbf6c5f74 Mon Sep 17 00:00:00 2001 From: wangyang Date: Sat, 30 Jan 2016 15:20:57 -0800 Subject: [PATCH 1749/2089] [SPARK-13100][SQL] improving the performance of stringToDate method in DateTimeUtils.scala In jdk1.7 TimeZone.getTimeZone() is synchronized, so use an instance variable to hold an GMT TimeZone object instead of instantiate it every time. Author: wangyang Closes #10994 from wangyang1992/datetimeUtil. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f18c052b68e37..a159bc6a61415 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -55,6 +55,7 @@ object DateTimeUtils { // this is year -17999, calculation: 50 * daysIn400Year final val YearZero = -17999 final val toYearZero = to2001 + 7304850 + final val TimeZoneGMT = TimeZone.getTimeZone("GMT") @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -407,7 +408,7 @@ object DateTimeUtils { segments(2) < 1 || segments(2) > 31) { return None } - val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + val c = Calendar.getInstance(TimeZoneGMT) c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) c.set(Calendar.MILLISECOND, 0) Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) From a1303de0a0e9d0c80327977abf52a79e2aa95e1f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 30 Jan 2016 23:02:49 -0800 Subject: [PATCH 1750/2089] [SPARK-13070][SQL] Better error message when Parquet schema merging fails Make sure we throw better error messages when Parquet schema merging fails. Author: Cheng Lian Author: Liang-Chi Hsieh Closes #10979 from viirya/schema-merging-failure-message. --- .../apache/spark/sql/types/StructType.scala | 6 ++-- .../datasources/parquet/ParquetRelation.scala | 33 ++++++++++++++++--- .../parquet/ParquetFilterSuite.scala | 15 +++++++++ .../parquet/ParquetSchemaSuite.scala | 30 +++++++++++++++++ 4 files changed, 77 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index da0c92864e9bd..c9e7e7fe633b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -424,13 +424,13 @@ object StructType extends AbstractDataType { if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { DecimalType(leftPrecision, leftScale) } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") } else if (leftPrecision != rightPrecision) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision") } else { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"scala $leftScale and $rightScale") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index f87590095d344..1e686d41f41db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -800,12 +800,37 @@ private[sql] object ParquetRelation extends Logging { assumeInt96IsTimestamp = assumeInt96IsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) - footers.map { footer => - ParquetRelation.readSchemaFromFooter(footer, converter) - }.reduceLeftOption(_ merge _).iterator + if (footers.isEmpty) { + Iterator.empty + } else { + var mergedSchema = ParquetRelation.readSchemaFromFooter(footers.head, converter) + footers.tail.foreach { footer => + val schema = ParquetRelation.readSchemaFromFooter(footer, converter) + try { + mergedSchema = mergedSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema of file ${footer.getFile}:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } }.collect() - partiallyMergedSchemas.reduceLeftOption(_ merge _) + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1796b3af0e37a..3ded32c450541 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -421,6 +421,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // We will remove the temporary metadata when writing Parquet file. val forPathSix = sqlContext.read.parquet(pathSix).schema assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + // sanity test: make sure optional metadata field is not wrongly set. + val pathSeven = s"${dir.getCanonicalPath}/table7" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) + val pathEight = s"${dir.getCanonicalPath}/table8" + (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) + + val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + checkAnswer( + df2, + Row(1, "1")) + + // The fields "a" and "b" exist in both two Parquet files. No metadata is set. + assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) + assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 60fa81b1ab819..d860651d421f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -449,6 +450,35 @@ class ParquetSchemaSuite extends ParquetSchemaTest { }.getMessage.contains("detected conflicting schemas")) } + test("schema merging failure error message") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema of file")) + } + + // test for second merging (after read Parquet schema in parallel done) + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema:")) + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From 0e6d92d042b0a2920d8df5959d5913ba0166a678 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 30 Jan 2016 23:05:29 -0800 Subject: [PATCH 1751/2089] [SPARK-12689][SQL] Migrate DDL parsing to the newly absorbed parser JIRA: https://issues.apache.org/jira/browse/SPARK-12689 DDLParser processes three commands: createTable, describeTable and refreshTable. This patch migrates the three commands to newly absorbed parser. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10723 from viirya/migrate-ddl-describe. --- project/MimaExcludes.scala | 5 + .../sql/catalyst/parser/ExpressionParser.g | 14 ++ .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 +- .../sql/catalyst/parser/SparkSqlParser.g | 80 +++++++- .../spark/sql/catalyst/CatalystQl.scala | 23 ++- .../org/apache/spark/sql/SQLContext.scala | 5 +- .../apache/spark/sql/execution/SparkQl.scala | 101 ++++++++- .../sql/execution/datasources/DDLParser.scala | 193 ------------------ .../spark/sql/execution/datasources/ddl.scala | 5 - .../sources/CreateTableAsSelectSuite.scala | 7 +- 10 files changed, 208 insertions(+), 229 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3748e07f88aad..8b1a7303fc5b2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -200,6 +200,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-12689 Migrate DDL parsing to the newly absorbed parser + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") ) ++ Seq( // SPARK-7799 Add "streaming-akka" project ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 0555a6ba83cbb..c162c1a0c5789 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -493,6 +493,16 @@ descFuncNames | functionIdentifier ; +//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. +looseIdentifier + : + Identifier + | looseNonReserved -> Identifier[$looseNonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + identifier : Identifier @@ -516,6 +526,10 @@ principalIdentifier | QuotedIdentifier ; +looseNonReserved + : nonReserved | KW_FROM | KW_TO + ; + //The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved //Non reserved keywords are basically the keywords that can be used as identifiers. //All the KW_* are automatically not only keywords, but also reserved keywords. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 4374cd7ef7200..e930caa291d4f 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -324,6 +324,8 @@ KW_ISOLATION: 'ISOLATION'; KW_LEVEL: 'LEVEL'; KW_SNAPSHOT: 'SNAPSHOT'; KW_AUTOCOMMIT: 'AUTOCOMMIT'; +KW_REFRESH: 'REFRESH'; +KW_OPTIONS: 'OPTIONS'; KW_WEEK: 'WEEK'|'WEEKS'; KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; @@ -470,7 +472,7 @@ Identifier fragment QuotedIdentifier : - '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); } ; WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 35bef00351d72..6591f6b0f56ce 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -142,6 +142,7 @@ TOK_UNIONTYPE; TOK_COLTYPELIST; TOK_CREATEDATABASE; TOK_CREATETABLE; +TOK_CREATETABLEUSING; TOK_TRUNCATETABLE; TOK_CREATEINDEX; TOK_CREATEINDEX_INDEXTBLNAME; @@ -371,6 +372,10 @@ TOK_TXN_READ_WRITE; TOK_COMMIT; TOK_ROLLBACK; TOK_SET_AUTOCOMMIT; +TOK_REFRESHTABLE; +TOK_TABLEPROVIDER; +TOK_TABLEOPTIONS; +TOK_TABLEOPTION; TOK_CACHETABLE; TOK_UNCACHETABLE; TOK_CLEARCACHE; @@ -660,6 +665,12 @@ import java.util.HashMap; } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { + if (input.length() > 0) { + if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') { + // When column name is backquoted, we don't care about excluded chars. + return false; + } + } for(char c : excludedCharForColumnName) { if(input.indexOf(c)>-1) { return true; @@ -781,6 +792,7 @@ ddlStatement | truncateTableStatement | alterStatement | descStatement + | refreshStatement | showStatement | metastoreCheck | createViewStatement @@ -907,12 +919,31 @@ createTableStatement @init { pushMsg("create table statement", state); } @after { popMsg(state); } : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName - ( like=KW_LIKE likeName=tableName + ( + like=KW_LIKE likeName=tableName tableRowFormat? tableFileFormat? tableLocation? tablePropertiesPrefixed? + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + ) + | + tableProvider + tableOpts? + (KW_AS selectStatementWithCTE)? + -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? + tableProvider + tableOpts? + selectStatementWithCTE? + ) | (LPAREN columnNameTypeList RPAREN)? + (p=tableProvider?) + tableOpts? tableComment? tablePartition? tableBuckets? @@ -922,8 +953,15 @@ createTableStatement tableLocation? tablePropertiesPrefixed? (KW_AS selectStatementWithCTE)? - ) - -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + -> {p != null}? + ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? + columnNameTypeList? + $p + tableOpts? + selectStatementWithCTE? + ) + -> + ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? ^(TOK_LIKETABLE $likeName?) columnNameTypeList? tableComment? @@ -935,7 +973,8 @@ createTableStatement tableLocation? tablePropertiesPrefixed? selectStatementWithCTE? - ) + ) + ) ; truncateTableStatement @@ -1379,6 +1418,13 @@ tabPartColTypeExpr : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) ; +refreshStatement +@init { pushMsg("refresh statement", state); } +@after { popMsg(state); } + : + KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName) + ; + descStatement @init { pushMsg("describe statement", state); } @after { popMsg(state); } @@ -1774,6 +1820,30 @@ showStmtIdentifier | StringLiteral ; +tableProvider +@init { pushMsg("table's provider", state); } +@after { popMsg(state); } + : + KW_USING Identifier (DOT Identifier)* + -> ^(TOK_TABLEPROVIDER Identifier+) + ; + +optionKeyValue +@init { pushMsg("table's option specification", state); } +@after { popMsg(state); } + : + (looseIdentifier (DOT looseIdentifier)*) StringLiteral + -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral) + ; + +tableOpts +@init { pushMsg("table's options", state); } +@after { popMsg(state); } + : + KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN + -> ^(TOK_TABLEOPTIONS optionKeyValue+) + ; + tableComment @init { pushMsg("table's comment", state); } @after { popMsg(state); } @@ -2132,7 +2202,7 @@ structType mapType @init { pushMsg("map type", state); } @after { popMsg(state); } - : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN + : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN -> ^(TOK_MAP $left $right) ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 536c292ab7f34..7ce2407913ade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -140,6 +140,7 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends case Token("TOK_BOOLEAN", Nil) => BooleanType case Token("TOK_STRING", Nil) => StringType case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType + case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType case Token("TOK_DOUBLE", Nil) => DoubleType case Token("TOK_DATE", Nil) => DateType @@ -156,9 +157,10 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends protected def nodeToStructField(node: ASTNode): StructField = node match { case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) + StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) => + val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build() + StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta) case _ => noParseRule("StructField", node) } @@ -222,15 +224,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Nil => ShowFunctions(None, None) case Token(name, Nil) :: Nil => - ShowFunctions(None, Some(unquoteString(name))) + ShowFunctions(None, Some(unquoteString(cleanIdentifier(name)))) case Token(db, Nil) :: Token(name, Nil) :: Nil => - ShowFunctions(Some(unquoteString(db)), Some(unquoteString(name))) + ShowFunctions(Some(unquoteString(cleanIdentifier(db))), + Some(unquoteString(cleanIdentifier(name)))) case _ => noParseRule("SHOW FUNCTIONS", node) } case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => - DescribeFunction(functionName, isExtended.nonEmpty) + DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty) case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => val (fromClause: Option[ASTNode], insertClauses, cteRelations) = @@ -611,7 +614,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C noParseRule("Select", node) } - protected val escapedIdentifier = "`([^`]+)`".r + protected val escapedIdentifier = "`(.+)`".r protected val doubleQuotedString = "\"([^\"]+)\"".r protected val singleQuotedString = "'([^']+)'".r @@ -655,7 +658,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C nodeToExpr(qualifier) match { case UnresolvedAttribute(nameParts) => UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) + case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) } /* Stars (*) */ @@ -663,7 +666,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => - UnresolvedStar(Some(target.map(_.text))) + UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text)))) /* Aggregate Functions */ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => @@ -971,7 +974,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node - val alias = getClause("TOK_TABALIAS", clauses).children.head.text + val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text) val generator = clauses.head match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index be28df3a51557..ef993c3edae37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -206,10 +206,7 @@ class SQLContext private[sql]( @transient protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) - @transient - protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser) - - protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index a5bd8ee42dec9..4174e27e9c8b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -16,11 +16,14 @@ */ package org.apache.spark.sql.execution +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.types.StructType private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. */ @@ -55,6 +58,86 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + case Token("TOK_REFRESHTABLE", nameParts :: Nil) => + val tableIdent = extractTableIdent(nameParts) + RefreshTable(tableIdent) + + case Token("TOK_CREATETABLEUSING", createTableArgs) => + val Seq( + temp, + allowExisting, + Some(tabName), + tableCols, + Some(Token("TOK_TABLEPROVIDER", providerNameParts)), + tableOpts, + tableAs) = getClauses(Seq( + "TEMPORARY", + "TOK_IFNOTEXISTS", + "TOK_TABNAME", "TOK_TABCOLLIST", + "TOK_TABLEPROVIDER", + "TOK_TABLEOPTIONS", + "TOK_QUERY"), createTableArgs) + + val tableIdent: TableIdentifier = extractTableIdent(tabName) + + val columns = tableCols.map { + case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField)) + } + + val provider = providerNameParts.map { + case Token(name, Nil) => name + }.mkString(".") + + val options: Map[String, String] = tableOpts.toSeq.flatMap { + case Token("TOK_TABLEOPTIONS", options) => + options.map { + case Token("TOK_TABLEOPTION", keysAndValue) => + val key = keysAndValue.init.map(_.text).mkString(".") + val value = unquoteString(keysAndValue.last.text) + (key, value) + } + }.toMap + + val asClause = tableAs.map(nodeToPlan(_)) + + if (temp.isDefined && allowExisting.isDefined) { + throw new AnalysisException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + if (asClause.isDefined) { + if (columns.isDefined) { + throw new AnalysisException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + CreateTableUsingAsSelect(tableIdent, + provider, + temp.isDefined, + Array.empty[String], + bucketSpec = None, + mode, + options, + asClause.get) + } else { + CreateTableUsing( + tableIdent, + columns, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => SetDatabaseCommand(cleanIdentifier(database)) @@ -68,26 +151,30 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly nodeToDescribeFallback(node) } else { tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) => nameParts match { - case Token(".", dbName :: tableName :: Nil) => + case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil => // It is describing a table with the format like "describe db.table". // TODO: Actually, a user may mean tableName.columnName. Need to resolve this // issue. - val tableIdent = extractTableIdent(nameParts) + val tableIdent = TableIdentifier( + cleanIdentifier(tableName), Some(cleanIdentifier(dbName))) datasources.DescribeCommand( UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => + case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil => // It is describing a column with the format like "describe db.table column". nodeToDescribeFallback(node) - case tableName => + case tableName :: Nil => // It is describing a table with the format like "describe table". datasources.DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.text), None), + UnresolvedRelation(TableIdentifier(cleanIdentifier(tableName.text)), None), isExtended = extended.isDefined) + case _ => + nodeToDescribeFallback(node) } // All other cases. - case _ => nodeToDescribeFallback(node) + case _ => + nodeToDescribeFallback(node) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala deleted file mode 100644 index f4766b037027d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ /dev/null @@ -1,193 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.datasources - -import scala.language.implicitConversions -import scala.util.matching.Regex - -import org.apache.spark.Logging -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ - -/** - * A parser for foreign DDL commands. - */ -class DDLParser(fallback: => ParserInterface) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - override def parseExpression(sql: String): Expression = fallback.parseExpression(sql) - - override def parseTableIdentifier(sql: String): TableIdentifier = { - fallback.parseTableIdentifier(sql) - } - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parsePlan(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => fallback.parsePlan(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = { - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = fallback.parsePlan(query.get) - CreateTableUsingAsSelect(tableIdent, - provider, - temp.isDefined, - Array.empty[String], - bucketSpec = None, - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableIdent, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - } - - // This is the same as tableIdentifier in SqlParser. - protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? ~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { - case e ~ tableIdent => - DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> tableIdentifier ^^ { - case tableIndet => - RefreshTable(tableIndet) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c3603936dfd2e..1554209be9891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -169,8 +169,3 @@ class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] override def -(key: String): Map[String, String] = baseMap - key.toLowerCase } - -/** - * The exception thrown from the DDL parser. - */ -class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 6fc9febe49707..cb88a1c83c999 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,7 +22,6 @@ import java.io.{File, IOException} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -105,7 +104,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -156,7 +155,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -173,7 +172,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("a CTAS statement with column definitions is not allowed") { - intercept[DDLException]{ + intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) From 5a8b978fabb60aa178274f86432c63680c8b351a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 31 Jan 2016 13:56:13 -0800 Subject: [PATCH 1752/2089] [SPARK-13049] Add First/last with ignore nulls to functions.scala This PR adds the ability to specify the ```ignoreNulls``` option to the functions dsl, e.g: ```df.select($"id", last($"value", ignoreNulls = true).over(Window.partitionBy($"id").orderBy($"other"))``` This PR is some where between a bug fix (see the JIRA) and a new feature. I am not sure if we should backport to 1.6. cc yhuai Author: Herman van Hovell Closes #10957 from hvanhovell/SPARK-13049. --- python/pyspark/sql/functions.py | 26 +++- python/pyspark/sql/tests.py | 10 ++ .../org/apache/spark/sql/functions.scala | 118 ++++++++++++++---- .../spark/sql/DataFrameWindowSuite.scala | 32 +++++ 4 files changed, 157 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 719eca8f5559e..0d5708526701e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -81,8 +81,6 @@ def _(): 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', 'count': 'Aggregate function: returns the number of items in a group.', 'sum': 'Aggregate function: returns the sum of all values in the expression.', 'avg': 'Aggregate function: returns the average of the values in a group.', @@ -278,6 +276,18 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.3) +def first(col, ignorenulls=False): + """Aggregate function: returns the first value in a group. + + The function by default returns the first values it sees. It will return the first non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -310,6 +320,18 @@ def isnull(col): return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.3) +def last(col, ignorenulls=False): + """Aggregate function: returns the last value in a group. + + The function by default returns the last values it sees. It will return the last non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 410efbafe0792..e30aa0a796924 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -641,6 +641,16 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.sqlCtx.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a27466176a20..b970eee4e31a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -349,19 +349,51 @@ object functions extends LegacyFunctions { } /** - * Aggregate function: returns the first value in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def first(e: Column): Column = withAggregateFunction { new First(e.expr) } - - /** - * Aggregate function: returns the first value of a column in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new First(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def first(columnName: String, ignoreNulls: Boolean): Column = { + first(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def first(e: Column): Column = first(e, ignoreNulls = false) + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(columnName: String): Column = first(Column(columnName)) /** @@ -381,20 +413,52 @@ object functions extends LegacyFunctions { def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) /** - * Aggregate function: returns the last value in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } - - /** - * Aggregate function: returns the last value of the column in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def last(columnName: String): Column = last(Column(columnName)) + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new Last(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(columnName: String, ignoreNulls: Boolean): Column = { + last(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(e: Column): Column = last(e, ignoreNulls = false) + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) /** * Aggregate function: returns the maximum value of the expression in a group. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 09a56f6f3ae28..d38842c3c0cf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -312,4 +312,36 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 3, null, null), Row("b", 2, null, null))) } + + test("last/first with ignoreNulls") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, nullStr), + ("b", 1, nullStr), + ("b", 2, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order") + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, null, null, null, null, null, null), + Row("a", 1, null, null, "x", "x", "x", "x"), + Row("a", 2, null, null, "x", "y", "y", "y"), + Row("a", 3, null, null, "x", "z", "z", "z"), + Row("a", 4, null, null, "x", null, null, "z"), + Row("b", 1, null, null, null, null, null, null), + Row("b", 2, null, null, null, null, null, null))) + } } From c1da4d421ab78772ffa52ad46e5bdfb4e5268f47 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 Jan 2016 22:43:03 -0800 Subject: [PATCH 1753/2089] [SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression The current implementation is sub-optimal: * If an expression is always nullable, e.g. `Unhex`, we can still remove null check for children if they are not nullable. * If an expression has some non-nullable children, we can still remove null check for these children and keep null check for others. This PR improves this by making the null check elimination more fine-grained. Author: Wenchen Fan Closes #10987 from cloud-fan/null-check. --- .../sql/catalyst/expressions/Expression.scala | 104 ++++++++++-------- .../expressions/codegen/CodeGenerator.scala | 32 +++++- .../spark/sql/catalyst/expressions/misc.scala | 16 +-- 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index db17ba7c84ffc..353fb92581d3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * As an example, the following does a boolean inversion (i.e. NOT). * {{{ @@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * @param f function that accepts the non-null evaluation result name of child and returns Java * code to compute the output. @@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: String => String): String = { - val eval = child.gen(ctx) + val childGen = child.gen(ctx) + val resultCode = f(childGen.value) + if (nullable) { - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; + val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${eval.isNull}) { - ${f(eval.value)} - } + $nullSafeEval """ } else { ev.isNull = "false" - eval.code + s""" + s""" + ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${f(eval.value)} + $resultCode """ } } @@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(eval1.value, eval2.value) + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + val resultCode = f(leftGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { + rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } + $nullSafeEval """ - } else { ev.isNull = "false" s""" - ${eval1.code} - ${eval2.code} + ${leftGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ @@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression { /** * Default behavior of evaluation according to the default nullability of TernaryExpression. - * If subclass of BinaryExpression override nullable, probably should also override this. + * If subclass of TernaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { val exprs = children @@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression { sys.error(s"BinaryExpressions must override either eval or nullSafeEval") /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f accepts two variable names and returns Java code to compute the output. + * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( ctx: CodegenContext, @@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression { } /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f function that accepts the 2 non-null evaluation result names of children + * @param f function that accepts the 3 non-null evaluation result names of children * and returns Java code to compute the output. */ protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): String = { - val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) + val leftGen = children(0).gen(ctx) + val midGen = children(1).gen(ctx) + val rightGen = children(2).gen(ctx) + val resultCode = f(leftGen.value, midGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { + midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { + rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + } + s""" - ${evals(0).code} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode - } - } - } + $nullSafeEval """ } else { ev.isNull = "false" s""" - ${evals(0).code} - ${evals(1).code} - ${evals(2).code} + ${leftGen.code} + ${midGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 21f9198073d74..a30aba16170a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,17 +402,37 @@ class CodegenContext { } /** - * Generates code for greater of two expressions. - * - * @param dataType data type of the expressions - * @param c1 name of the variable of expression 1's output - * @param c2 name of the variable of expression 2's output - */ + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code to do null safe execution, i.e. only execute the code when the input is not + * null by adding null check if necessary. + * + * @param nullable used to decide whether we should add null check or not. + * @param isNull the code to check if the input is null. + * @param execute the code that should only be executed when the input is not null. + */ + def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execute + } + """ + } else { + "\n" + execute + } + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8480c3f9a12f6..36e1fa1176d22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.gen(ctx) - childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") @@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression """ } - private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { - if (nullable) { - s""" - if (!$isNull) { - $execution - } - """ - } else { - "\n" + execution - } - } - private def nullSafeElementHash( input: String, index: String, @@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ctx: CodegenContext): String = { val element = ctx.freshName("element") - generateNullCheck(nullable, s"$input.isNullAt($index)") { + ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} From 6075573a93176ee8c071888e4525043d9e73b061 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 1 Feb 2016 11:02:17 -0800 Subject: [PATCH 1754/2089] [SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream Add a local property to indicate if checkpointing all RDDs that are marked with the checkpoint flag, and enable it in Streaming Author: Shixiong Zhu Closes #10934 from zsxwing/recursive-checkpoint. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 +++++ .../org/apache/spark/CheckpointSuite.scala | 21 ++++++ .../streaming/scheduler/JobGenerator.scala | 5 ++ .../streaming/scheduler/JobScheduler.scala | 7 +- .../spark/streaming/CheckpointSuite.scala | 69 +++++++++++++++++++ 5 files changed, 119 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index be47172581b7f..e8157cf4ebe7d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1542,6 +1542,15 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default, + // we stop as soon as we find the first such RDD, an optimization that allows us to write + // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both + // an RDD and its parent in every batch, in which case the parent may never be checkpointed + // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). + private val checkpointAllMarkedAncestors = + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) + .map(_.toBoolean).getOrElse(false) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -1585,6 +1594,13 @@ abstract class RDD[T: ClassTag]( if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { + if (checkpointAllMarkedAncestors) { + // TODO We can collect all the RDDs that needs to be checkpointed, and then checkpoint + // them in parallel. + // Checkpoint parents first because our lineage will be truncated after we + // checkpoint ourselves + dependencies.foreach(_.rdd.doCheckpoint()) + } checkpointData.get.checkpoint() } else { dependencies.foreach(_.rdd.doCheckpoint()) @@ -1704,6 +1720,9 @@ abstract class RDD[T: ClassTag]( */ object RDD { + private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS = + "spark.checkpoint.checkpointAllMarkedAncestors" + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 390764ba242fd..ce35856dce3f7 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -512,6 +512,27 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } + + runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean => + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true) + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false) + } + + private def testCheckpointAllMarkedAncestors( + reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString) + try { + val rdd1 = sc.parallelize(1 to 10) + checkpoint(rdd1, reliableCheckpoint) + val rdd2 = rdd1.map(_ + 1) + checkpoint(rdd2, reliableCheckpoint) + rdd2.count() + assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors) + assert(rdd2.isCheckpointed === true) + } finally { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null) + } + } } /** RDD partition that has large serialized size. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index a5a01e77639c4..a3ad5eaa40edc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} @@ -243,6 +244,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Example: BlockRDDs are created in this thread, and it needs to access BlockManager // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. SparkEnv.set(ssc.env) + + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") Try { jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch graph.generateJobs(time) // generate jobs using allocated block diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 9535c8e5b768a..3fed3d88354c7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -23,10 +23,10 @@ import scala.collection.JavaConverters._ import scala.util.Failure import org.apache.spark.Logging -import org.apache.spark.rdd.PairRDDFunctions +import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -210,6 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { s"""Streaming job from $batchLinkText""") ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") // We need to assign `eventLoop` to a temp variable. Otherwise, because // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 4a6b91fbc745e..786703eb9a84e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester checkpointWriter.stop() } + test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") { + // In this test, there are two updateStateByKey operators. The RDD DAG is as follows: + // + // batch 1 batch 2 batch 3 ... + // + // 1) input rdd input rdd input rdd + // | | | + // v v v + // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 3) map rdd --- map rdd --- map rdd ... + // | | | + // v v v + // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 5) map rdd --- map rdd --- map rdd ... + // + // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to + // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second + // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first + // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3) + // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken + // and the RDD chain will grow infinitely and cause StackOverflow. + // + // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing + // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break + // connections between layer 2 and layer 3) + ssc = new StreamingContext(master, framework, batchDuration) + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + @volatile var shouldCheckpointAllMarkedRDDs = false + @volatile var rddsCheckpointed = false + inputDStream.map(i => (i, i)) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .foreachRDD { rdd => + /** + * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors. + */ + def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = { + val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList + if (rdd.checkpointData.isDefined) { + rdd :: markedRDDs + } else { + markedRDDs + } + } + + shouldCheckpointAllMarkedRDDs = + Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)). + map(_.toBoolean).getOrElse(false) + + val stateRDDs = findAllMarkedRDDs(rdd) + rdd.count() + // Check the two state RDDs are both checkpointed + rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed) + } + ssc.start() + batchCounter.waitUntilBatchesCompleted(1, 10000) + assert(shouldCheckpointAllMarkedRDDs === true) + assert(rddsCheckpointed === true) + } + /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. From 33c8a490f7f64320c53530a57bd8d34916e3607c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 1 Feb 2016 11:22:02 -0800 Subject: [PATCH 1755/2089] [SPARK-12989][SQL] Delaying Alias Cleanup after ExtractWindowExpressions JIRA: https://issues.apache.org/jira/browse/SPARK-12989 In the rule `ExtractWindowExpressions`, we simply replace alias by the corresponding attribute. However, this will cause an issue exposed by the following case: ```scala val data = Seq(("a", "b", "c", 3), ("c", "b", "a", 3)).toDF("A", "B", "C", "num") .withColumn("Data", struct("A", "B", "C")) .drop("A") .drop("B") .drop("C") val winSpec = Window.partitionBy("Data.A", "Data.B").orderBy($"num".desc) data.select($"*", max("num").over(winSpec) as "max").explain(true) ``` In this case, both `Data.A` and `Data.B` are `alias` in `WindowSpecDefinition`. If we replace these alias expression by their alias names, we are unable to know what they are since they will not be put in `missingExpr` too. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10963 from gatorsmile/seletStarAfterColDrop. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++-- .../org/apache/spark/sql/DataFrameWindowSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5fe700ee00673..ee60fca1ad4fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -883,12 +883,13 @@ class Analyzer( if (missingExpr.nonEmpty) { extractedExprBuffer += ne } - ne.toAttribute + // alias will be cleaned in the rule CleanupAliases + ne case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. case e: Expression => // For other expressions, we extract it and replace it with an AttributeReference (with - // an interal column name, e.g. "_w0"). + // an internal column name, e.g. "_w0"). val withName = Alias(e, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index d38842c3c0cf0..2bcbb1983f7ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -344,4 +344,14 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 1, null, null, null, null, null, null), Row("b", 2, null, null, null, null, null, null))) } + + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { + val src = Seq((0, 3, 5)).toDF("a", "b", "c") + .withColumn("Data", struct("a", "b")) + .drop("a") + .drop("b") + val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc) + val df = src.select($"*", max("c").over(winSpec) as "max") + checkAnswer(df, Row(5, Row(0, 3), 5)) + } } From 8f26eb5ef6853a6666d7d9481b333de70bc501ed Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 1 Feb 2016 11:57:13 -0800 Subject: [PATCH 1756/2089] [SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences JIRA: https://issues.apache.org/jira/browse/SPARK-12705 **Scope:** This PR is a general fix for sorting reference resolution when the child's `outputSet` does not have the order-by attributes (called, *missing attributes*): - UnaryNode support is limited to `Project`, `Window`, `Aggregate`, `Distinct`, `Filter`, `RepartitionByExpression`. - We will not try to resolve the missing references inside a subquery, unless the outputSet of this subquery contains it. **General Reference Resolution Rules:** - Jump over the nodes with the following types: `Distinct`, `Filter`, `RepartitionByExpression`. Do not need to add missing attributes. The reason is their `outputSet` is decided by their `inputSet`, which is the `outputSet` of their children. - Group-by expressions in `Aggregate`: missing order-by attributes are not allowed to be added into group-by expressions since it will change the query result. Thus, in RDBMS, it is not allowed. - Aggregate expressions in `Aggregate`: if the group-by expressions in `Aggregate` contains the missing attributes but aggregate expressions do not have it, just add them into the aggregate expressions. This can resolve the analysisExceptions thrown by the three TCPDS queries. - `Project` and `Window` are special. We just need to add the missing attributes to their `projectList`. **Implementation:** 1. Traverse the whole tree in a pre-order manner to find all the resolvable missing order-by attributes. 2. Traverse the whole tree in a post-order manner to add the found missing order-by attributes to the node if their `inputSet` contains the attributes. 3. If the origins of the missing order-by attributes are different nodes, each pass only resolves the missing attributes that are from the same node. **Risk:** Low. This rule will be trigger iff ```!s.resolved && child.resolved``` is true. Thus, very few cases are affected. Author: gatorsmile Closes #10678 from gatorsmile/sortWindows. --- .../sql/catalyst/analysis/Analyzer.scala | 101 ++++++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 83 ++++++++++++++ .../sql/catalyst/analysis/TestRelations.scala | 6 ++ .../apache/spark/sql/DataFrameJoinSuite.scala | 16 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 84 ++++++++++++++- 6 files changed, 274 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ee60fca1ad4fe..a983dc1cdfebe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -452,7 +453,7 @@ class Analyzer( i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on grandchild + // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) @@ -533,38 +534,96 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(ordering, global, p @ Project(projectList, child)) - if !s.resolved && p.resolved => - val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) + // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, child: Aggregate) => sa - // If this rule was not a no-op, return the transformed plan, otherwise return the original. - if (missing.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(p.output, - Sort(newOrdering, global, - Project(projectList ++ missing, child))) - } else { - logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") + case s @ Sort(_, _, child) if !s.resolved && child.resolved => + val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child) + + if (missingResolvableAttrs.isEmpty) { + val unresolvableAttrs = s.order.filterNot(_.resolved) + logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") s // Nothing we can do here. Return original plan. + } else { + // Add the missing attributes into projectList of Project/Window or + // aggregateExpressions of Aggregate, if they are in the inputSet + // but not in the outputSet of the plan. + val newChild = child transformUp { + case p: Project => + p.copy(projectList = p.projectList ++ + missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains)) + case w: Window => + w.copy(projectList = w.projectList ++ + missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains)) + case a: Aggregate => + val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains) + val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains) + val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs + a.copy(aggregateExpressions = newAggregateExpressions) + case o => o + } + + // Add missing attributes and then project them away after the sort. + Project(child.output, + Sort(newOrdering, s.global, newChild)) } } /** - * Given a child and a grandchild that are present beneath a sort operator, try to resolve - * the sort ordering and returns it with a list of attributes that are missing from the - * child but are present in the grandchild. + * Traverse the tree until resolving the sorting attributes + * Return all the resolvable missing sorting attributes + */ + @tailrec + private def collectResolvableMissingAttrs( + ordering: Seq[SortOrder], + plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + plan match { + // Only Windows and Project have projectList-like attribute. + case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) + // If missingAttrs is non empty, that means we got it and return it; + // Otherwise, continue to traverse the tree. + if (missingAttrs.nonEmpty) { + (newOrdering, missingAttrs) + } else { + collectResolvableMissingAttrs(ordering, un.child) + } + case a: Aggregate => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child) + // For Aggregate, all the order by columns must be specified in group by clauses + if (missingAttrs.nonEmpty && + missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) { + (newOrdering, missingAttrs) + } else { + // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes + (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + // Jump over the following UnaryNode types + // The output of these types is the same as their child's output + case _: Distinct | + _: Filter | + _: RepartitionByExpression => + collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child) + // If hitting the other unsupported operators, we are unable to resolve it. + case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + } + + /** + * Try to resolve the sort ordering and returns it with a list of attributes that are missing + * from the plan but are present in the child. */ - def resolveAndFindMissing( + private def resolveAndFindMissing( ordering: Seq[SortOrder], - child: LogicalPlan, - grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) + plan: LogicalPlan, + child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + val newOrdering = resolveSortOrders(ordering, child, throws = false) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. - val missingInProject = requiredAttributes -- child.output + val missingInProject = requiredAttributes -- plan.outputSet // It is important to return the new SortOrders here, instead of waiting for the standard // resolving process as adding attributes to the project below can actually introduce // ambiguity that was not present before. @@ -719,7 +778,7 @@ class Analyzer( } } - protected def containsAggregate(condition: Expression): Boolean = { + def containsAggregate(condition: Expression): Boolean = { condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1938bce02a177..ebf885a8fe484 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } + test("resolve sort references - filter/limit") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + + // Case 1: one missing attribute is in the leaf node and another is in the unary node + val plan1 = testRelation2 + .where('a > "str").select('a, 'b) + .where('b > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected1 = testRelation2 + .where(a > "str").select(a, b, c) + .where(b > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a, b).select(a) + checkAnalysis(plan1, expected1) + + // Case 2: all the missing attributes are in the leaf node + val plan2 = testRelation2 + .where('a > "str").select('a) + .where('a > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected2 = testRelation2 + .where(a > "str").select(a, b, c) + .where(a > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan2, expected2) + } + + test("resolve sort references - join") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val h = testRelation3.output(3) + + // Case: join itself can resolve all the missing attributes + val plan = testRelation2.join(testRelation3) + .where('a > "str").select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected = testRelation2.join(testRelation3) + .where(a > "str").select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b) + checkAnalysis(plan, expected) + } + + test("resolve sort references - aggregate") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val alias_a3 = count(a).as("a3") + val alias_b = b.as("aggOrder") + + // Case 1: when the child of Sort is not Aggregate, + // the sort reference is handled by the rule ResolveSortReferences + val plan1 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .select('a, 'c, 'a3) + .orderBy('b.asc) + + val expected1 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, b) + .select(a, c, alias_a3.toAttribute, b) + .orderBy(b.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan1, expected1) + + // Case 2: when the child of Sort is Aggregate, + // the sort reference is handled by the rule ResolveAggregateFunctions + val plan2 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .orderBy('b.asc) + + val expected2 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, alias_b) + .orderBy(alias_b.toAttribute.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan2, expected2) + } + test("resolve relations") { assertAnalysisError( UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index bc07b609a3413..3741a6ba95a86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -31,6 +31,12 @@ object TestRelations { AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) + val testRelation3 = LocalRelation( + AttributeReference("e", ShortType)(), + AttributeReference("f", StringType)(), + AttributeReference("g", DoubleType)(), + AttributeReference("h", DecimalType(10, 2))()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c17be8ace9287..a5e5f156423cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - sorted columns not in join's outputSet") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3) + + checkAnswer( + df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2") + .orderBy('str_sort.asc, 'str.asc), + Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + + checkAnswer( + df2.join(df3, $"df2.int" === $"df3.int", "inner") + .select($"df2.int", $"df3.int").orderBy($"df2.str".desc), + Row(5, 5) :: Row(1, 1) :: Nil) + } + test("join - join using multiple columns and specifying join type") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4ff99bdf2937d..c02133ffc8540 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -954,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(expected === actual) } + test("Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) + } + test("SPARK-9323: DataFrame.orderBy should support nested column name") { val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1ada2e325bda6..6048b8f5a3998 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), (2 to 6).map(i => Row(i))) } - test("window function: udaf with aggregate expressin") { + test("window function: udaf with aggregate expression") { val data = Seq( WindowData(1, "a", 5), WindowData(2, "a", 6), @@ -927,6 +927,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("d", 1), + ("c", 2) + ).map(i => Row(i._1, i._2))) + } + + // todo: fix this test case by reimplementing the function ResolveAggregateFunctions + ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product) over () as c from windowData + |where product > 3 group by area, product + |having avg(month) > 0 order by avg(month), product + """.stripMargin), + Seq( + ("a", 51), + ("b", 51), + ("b", 51), + ("c", 51), + ("c", 51), + ("d", 51) + ).map(i => Row(i._1, i._2))) + } + test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") From da9146c91a33577ff81378ca7e7c38a4b1917876 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 1 Feb 2016 12:02:06 -0800 Subject: [PATCH 1757/2089] [DOCS] Fix the jar location of datanucleus in sql-programming-guid.md ISTM `lib` is better because `datanucleus` jars are located in `lib` for release builds. Author: Takeshi YAMAMURO Closes #10901 from maropu/DocFix. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fddc51379406b..550a40010e828 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1695,7 +1695,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running -the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. From 711ce048a285403241bbc9eaabffc1314162e89c Mon Sep 17 00:00:00 2001 From: Lewuathe Date: Mon, 1 Feb 2016 12:21:21 -0800 Subject: [PATCH 1758/2089] [ML][MINOR] Invalid MulticlassClassification reference in ml-guide In [ml-guide](https://spark.apache.org/docs/latest/ml-guide.html#example-model-selection-via-cross-validation), there is invalid reference to `MulticlassClassificationEvaluator` apidoc. https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator Author: Lewuathe Closes #10996 from Lewuathe/fix-typo-in-ml-guide. --- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5aafd53b584e7..f8279262e673f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -627,7 +627,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetricName` method in each of these evaluators. From 51b03b71ffc390e67b32936efba61e614a8b0d86 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Mon, 1 Feb 2016 12:45:02 -0800 Subject: [PATCH 1759/2089] [SPARK-12463][SPARK-12464][SPARK-12465][SPARK-10647][MESOS] Fix zookeeper dir with mesos conf and add docs. Fix zookeeper dir configuration used in cluster mode, and also add documentation around these settings. Author: Timothy Chen Closes #10057 from tnachen/fix_mesos_dir. --- .../deploy/mesos/MesosClusterDispatcher.scala | 6 ++--- .../mesos/MesosClusterPersistenceEngine.scala | 4 ++-- docs/configuration.md | 23 +++++++++++++++++++ docs/running-on-mesos.md | 5 +++- docs/spark-standalone.md | 23 ++++--------------- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 66e1e645007a7..9b31497adfb12 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -50,7 +50,7 @@ private[mesos] class MesosClusterDispatcher( extends Logging { private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) - private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) private val engineFactory = recoveryMode match { @@ -98,8 +98,8 @@ private[mesos] object MesosClusterDispatcher extends Logging { conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => - conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") - conf.set("spark.mesos.deploy.zookeeper.url", z) + conf.set("spark.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.deploy.zookeeper.url", z) } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index e0c547dce6d07..092d9e4182530 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -53,9 +53,9 @@ private[spark] trait MesosClusterPersistenceEngine { * all of them reuses the same connection pool. */ private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) - extends MesosClusterPersistenceEngineFactory(conf) { + extends MesosClusterPersistenceEngineFactory(conf) with Logging { - lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + lazy val zk = SparkCuratorUtil.newClient(conf) def createEngine(path: String): MesosClusterPersistenceEngine = { new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) diff --git a/docs/configuration.md b/docs/configuration.md index 74a8fb5d35a66..93b399d819ccd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1585,6 +1585,29 @@ Apart from these, the following properties are also available, and may be useful +#### Deploy + + + + + + + + + + + + + + + + + + +
        Property NameDefaultMeaniing
        spark.deploy.recoveryModeNONEThe recovery mode setting to recover submitted Spark jobs with cluster mode when it failed and relaunches. + This is only applicable for cluster mode when running with Standalone or Mesos.
        spark.deploy.zookeeper.urlNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper URL to connect to.
        spark.deploy.zookeeper.dirNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper directory to store recovery state.
        + + #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ed720f1039f94..0ef1ccb36e117 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -153,7 +153,10 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. + +The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy]. From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fe9ec3542b28..3de72bc016dd4 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -112,8 +112,8 @@ You can optionally configure the cluster further by setting environment variable SPARK_LOCAL_DIRS - Directory to use for "scratch" space in Spark, including map output files and RDDs that get - stored on disk. This should be on a fast, local disk in your system. It can also be a + Directory to use for "scratch" space in Spark, including map output files and RDDs that get + stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. @@ -341,23 +341,8 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** -In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration: - - - - - - - - - - - - - - - -
        System propertyMeaning
        spark.deploy.recoveryModeSet to ZOOKEEPER to enable standby Master recovery mode (default: NONE).
        spark.deploy.zookeeper.urlThe ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).
        spark.deploy.zookeeper.dirThe directory in ZooKeeper to store recovery state (default: /spark).
        +In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). From a41b68b954ba47284a1df312f0aaea29b0721b0a Mon Sep 17 00:00:00 2001 From: Nilanjan Raychaudhuri Date: Mon, 1 Feb 2016 13:33:24 -0800 Subject: [PATCH 1760/2089] [SPARK-12265][MESOS] Spark calls System.exit inside driver instead of throwing exception This takes over #10729 and makes sure that `spark-shell` fails with a proper error message. There is a slight behavioral change: before this change `spark-shell` would exit, while now the REPL is still there, but `sc` and `sqlContext` are not defined and the error is visible to the user. Author: Nilanjan Raychaudhuri Author: Iulian Dragos Closes #10921 from dragos/pr/10729. --- .../cluster/mesos/MesosClusterScheduler.scala | 1 + .../cluster/mesos/MesosSchedulerBackend.scala | 1 + .../cluster/mesos/MesosSchedulerUtils.scala | 21 +++++++++++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 05fda0fded7f8..e77d77208ccba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -573,6 +573,7 @@ private[spark] class MesosClusterScheduler( override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} override def error(driver: SchedulerDriver, error: String): Unit = { logError("Error received: " + error) + markErr() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index eaf0cb06d6c73..a8bf79a78cf5d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -375,6 +375,7 @@ private[spark] class MesosSchedulerBackend( override def error(d: SchedulerDriver, message: String) { inClassLoader() { logError("Mesos error: " + message) + markErr() scheduler.error(message) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 010caff3e39b2..f9f5da9bc8df6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -106,28 +106,37 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.await() return } + @volatile + var error: Option[Exception] = None + // We create a new thread that will block inside `mesosDriver.run` + // until the scheduler exists new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { setDaemon(true) - override def run() { - mesosDriver = newDriver try { + mesosDriver = newDriver val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { - System.exit(1) + error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) + markErr() } } catch { case e: Exception => { logError("driver.run() failed", e) - System.exit(1) + error = Some(e) + markErr() } } } }.start() registerLatch.await() + + // propagate any error to the calling thread. This ensures that SparkContext creation fails + // without leaving a broken context that won't be able to schedule any tasks + error.foreach(throw _) } } @@ -144,6 +153,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.countDown() } + protected def markErr(): Unit = { + registerLatch.countDown() + } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { val builder = Resource.newBuilder() .setName(name) From c9b89a0a0921ce3d52864afd4feb7f37b90f7b46 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Mon, 1 Feb 2016 13:38:38 -0800 Subject: [PATCH 1761/2089] =?UTF-8?q?[SPARK-12979][MESOS]=20Don=E2=80=99t?= =?UTF-8?q?=20resolve=20paths=20on=20the=20local=20file=20system=20in=20Me?= =?UTF-8?q?sos=20scheduler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The driver filesystem is likely different from where the executors will run, so resolving paths (and symlinks, etc.) will lead to invalid paths on executors. Author: Iulian Dragos Closes #10923 from dragos/issue/canonical-paths. --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 2 +- .../spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 58c30e7d97886..2f095b86c69ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -179,7 +179,7 @@ private[spark] class CoarseMesosSchedulerBackend( .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) if (uri.isEmpty) { - val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath + val runScript = new File(executorSparkHome, "./bin/spark-class").getPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index e77d77208ccba..8cda4ff0eb3b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -394,7 +394,7 @@ private[spark] class MesosClusterScheduler( .getOrElse { throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } - val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getPath // Sandbox points to the current directory by default with Mesos. (cmdExecutable, ".") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a8bf79a78cf5d..340f29bac9218 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -125,7 +125,7 @@ private[spark] class MesosSchedulerBackend( val executorBackendName = classOf[MesosExecutorBackend].getName if (uri.isEmpty) { - val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath + val executorPath = new File(executorSparkHome, "/bin/spark-class").getPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to From 064b029c6a15481fc4dfb147100c19a68cd1cc95 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 1 Feb 2016 13:56:14 -0800 Subject: [PATCH 1762/2089] [SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch. This includes: float, boolean, short, decimal and calendar interval. Decimal is mapped to long or byte array depending on the size and calendar interval is mapped to a struct of int and long. The only remaining type is map. The schema mapping is straightforward but we might want to revisit how we deal with this in the rest of the execution engine. Author: Nong Li Closes #10961 from nongli/spark-13043. --- .../apache/spark/sql/types/DecimalType.scala | 22 +++ .../execution/vectorized/ColumnVector.java | 180 +++++++++++++++++- .../vectorized/ColumnVectorUtils.java | 34 +++- .../execution/vectorized/ColumnarBatch.java | 46 +++-- .../vectorized/OffHeapColumnVector.java | 98 +++++++++- .../vectorized/OnHeapColumnVector.java | 94 ++++++++- .../vectorized/ColumnarBatchSuite.scala | 44 ++++- .../org/apache/spark/unsafe/Platform.java | 8 + 8 files changed, 484 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index cf5322125bd72..5dd661ee6b339 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside a long + */ + def is64BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_LONG_DIGITS + case _ => false + } + } + + /** + * Returns if dt is a DecimalType that doesn't fit inside a long + */ + def isByteArrayDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision > Decimal.MAX_LONG_DIGITS + case _ => false + } + } + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a0bf8734b6545..a5bc506a65ac2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -102,18 +105,36 @@ public Object[] array() { DataType dt = data.dataType(); Object[] list = new Object[length]; - if (dt instanceof ByteType) { + if (dt instanceof BooleanType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getBoolean(offset + i); + } + } + } else if (dt instanceof ByteType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getByte(offset + i); } } + } else if (dt instanceof ShortType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getShort(offset + i); + } + } } else if (dt instanceof IntegerType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getInt(offset + i); } } + } else if (dt instanceof FloatType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getFloat(offset + i); + } + } } else if (dt instanceof DoubleType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { @@ -126,12 +147,25 @@ public Object[] array() { list[i] = data.getLong(offset + i); } } + } else if (dt instanceof DecimalType) { + DecimalType decType = (DecimalType)dt; + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getDecimal(i, decType.precision(), decType.scale()); + } + } } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); } } + } else if (dt instanceof CalendarIntervalType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getInterval(i); + } + } } else { throw new NotImplementedException("Type " + dt); } @@ -170,7 +204,14 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -181,17 +222,22 @@ public UTF8String getUTF8String(int ordinal) { @Override public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = data.getByteArray(offset + ordinal); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); } @Override public InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); + return data.getStruct(offset + ordinal); } @Override @@ -279,6 +325,21 @@ public void reset() { */ public abstract boolean getIsNull(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Returns the value for rowId. + */ + public abstract boolean getBoolean(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -299,6 +360,26 @@ public void reset() { */ public abstract byte getByte(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract short getShort(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -351,6 +432,33 @@ public void reset() { */ public abstract long getLong(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract float getFloat(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -369,7 +477,7 @@ public void reset() { /** * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formated doubles. + * The data in src must be ieee formatted doubles. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -469,6 +577,20 @@ public final int appendNotNulls(int count) { return result; } + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendByte(byte v) { reserve(elementsAppended + 1); putByte(elementsAppended, v); @@ -491,6 +613,28 @@ public final int appendBytes(int length, byte[] src, int offset) { return result; } + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendInt(int v) { reserve(elementsAppended + 1); putInt(elementsAppended, v); @@ -535,6 +679,20 @@ public final int appendLongs(int length, long[] src, int offset) { return result; } + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); @@ -661,7 +819,8 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { this.capacity = capacity; this.type = type; - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) { + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { @@ -682,6 +841,13 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. + this.childColumns = new ColumnVector[2]; + this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 6c651a759d250..453bc15e13503 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Iterator; import java.util.List; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.commons.lang.NotImplementedException; @@ -59,19 +62,44 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) { private static void appendValue(ColumnVector dst, DataType t, Object o) { if (o == null) { - dst.appendNull(); + if (t instanceof CalendarIntervalType) { + dst.appendStruct(true); + } else { + dst.appendNull(); + } } else { - if (t == DataTypes.ByteType) { - dst.appendByte(((Byte)o).byteValue()); + if (t == DataTypes.BooleanType) { + dst.appendBoolean(((Boolean)o).booleanValue()); + } else if (t == DataTypes.ByteType) { + dst.appendByte(((Byte) o).byteValue()); + } else if (t == DataTypes.ShortType) { + dst.appendShort(((Short)o).shortValue()); } else if (t == DataTypes.IntegerType) { dst.appendInt(((Integer)o).intValue()); } else if (t == DataTypes.LongType) { dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.FloatType) { + dst.appendFloat(((Float)o).floatValue()); } else if (t == DataTypes.DoubleType) { dst.appendDouble(((Double)o).doubleValue()); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(); dst.appendByteArray(b, 0, b.length); + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + dst.appendLong(d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + dst.appendByteArray(bytes, 0, bytes.length); + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)o; + dst.appendStruct(false); + dst.getChildColumn(0).appendInt(c.months); + dst.getChildColumn(1).appendLong(c.microseconds); } else { throw new NotImplementedException("Type " + t); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 5a575811fa896..dbad5e070f1fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; @@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -150,44 +153,40 @@ public final boolean anyNull() { } @Override - public final boolean isNullAt(int ordinal) { - return columns[ordinal].getIsNull(rowId); - } + public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); } @Override - public final boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } + public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } @Override public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override - public final short getShort(int ordinal) { - throw new NotImplementedException(); - } + public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } @Override - public final int getInt(int ordinal) { - return columns[ordinal].getInt(rowId); - } + public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } @Override public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override - public final float getFloat(int ordinal) { - throw new NotImplementedException(); - } + public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } @Override - public final double getDouble(int ordinal) { - return columns[ordinal].getDouble(rowId); - } + public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -198,12 +197,17 @@ public final UTF8String getUTF8String(int ordinal) { @Override public final byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = columns[ordinal].getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public final CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 335124fd5a603..22c5e5fc81a4a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,11 +19,15 @@ import java.nio.ByteOrder; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.BooleanType; import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; import org.apache.spark.sql.types.IntegerType; import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -121,6 +125,26 @@ public final boolean getIsNull(int rowId) { return Platform.getByte(null, nulls + rowId) == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0)); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, v); + } + } + + @Override + public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + // // APIs dealing with Bytes // @@ -148,6 +172,34 @@ public final byte getByte(int rowId) { return Platform.getByte(null, data + rowId); } + // + // APIs dealing with shorts + // + + @Override + public final void putShort(int rowId, short value) { + Platform.putShort(null, data + 2 * rowId, value); + } + + @Override + public final void putShorts(int rowId, int count, short value) { + long offset = data + 2 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putShort(null, offset, value); + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, + null, data + 2 * rowId, count * 2); + } + + @Override + public final short getShort(int rowId) { + return Platform.getShort(null, data + 2 * rowId); + } + // // APIs dealing with ints // @@ -216,6 +268,41 @@ public final long getLong(int rowId) { return Platform.getLong(null, data + 8 * rowId); } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { + Platform.putFloat(null, data + rowId * 4, value); + } + + @Override + public final void putFloats(int rowId, int count, float value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putFloat(null, offset, value); + } + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { + return Platform.getFloat(null, data + rowId * 4); + } + + // // APIs dealing with doubles // @@ -241,7 +328,7 @@ public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { @Override public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex, + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId * 8, count * 8); } @@ -300,11 +387,14 @@ private final void reserveInternal(int newCapacity) { Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); this.offsetData = Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof ByteType) { + } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); - } else if (type instanceof IntegerType) { + } else if (type instanceof ShortType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + } else if (type instanceof IntegerType || type instanceof FloatType) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof LongType || type instanceof DoubleType) { + } else if (type instanceof LongType || type instanceof DoubleType || + DecimalType.is64BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 8197fa11cd4c8..32356334c031f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector { // Array for each type. Only 1 is populated for any type. private byte[] byteData; + private short[] shortData; private int[] intData; private long[] longData; + private float[] floatData; private double[] doubleData; // Only set if type is Array. @@ -104,6 +106,30 @@ public final boolean getIsNull(int rowId) { return nulls[rowId] == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + byteData[rowId] = (byte)((value) ? 1 : 0); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = v; + } + } + + @Override + public final boolean getBoolean(int rowId) { + return byteData[rowId] == 1; + } + + // + // // APIs dealing with Bytes // @@ -130,6 +156,33 @@ public final byte getByte(int rowId) { return byteData[rowId]; } + // + // APIs dealing with Shorts + // + + @Override + public final void putShort(int rowId, short value) { + shortData[rowId] = value; + } + + @Override + public final void putShorts(int rowId, int count, short value) { + for (int i = 0; i < count; ++i) { + shortData[i + rowId] = value; + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + System.arraycopy(src, srcIndex, shortData, rowId, count); + } + + @Override + public final short getShort(int rowId) { + return shortData[rowId]; + } + + // // APIs dealing with Ints // @@ -202,6 +255,31 @@ public final long getLong(int rowId) { return longData[rowId]; } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { floatData[rowId] = value; } + + @Override + public final void putFloats(int rowId, int count, float value) { + Arrays.fill(floatData, rowId, rowId + count, value); + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + System.arraycopy(src, srcIndex, floatData, rowId, count); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { return floatData[rowId]; } // // APIs dealing with doubles @@ -277,7 +355,7 @@ public final void reserve(int requiredCapacity) { // Spilt this function out since it is the slow path. private final void reserveInternal(int newCapacity) { - if (this.resultArray != null) { + if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { @@ -286,18 +364,30 @@ private final void reserveInternal(int newCapacity) { } arrayLengths = newLengths; arrayOffsets = newOffsets; + } else if (type instanceof BooleanType) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; } else if (type instanceof ByteType) { byte[] newData = new byte[newCapacity]; if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); byteData = newData; + } else if (type instanceof ShortType) { + short[] newData = new short[newCapacity]; + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + shortData = newData; } else if (type instanceof IntegerType) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; - } else if (type instanceof LongType) { + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { long[] newData = new long[newCapacity]; if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); longData = newData; + } else if (type instanceof FloatType) { + float[] newData = new float[newCapacity]; + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + floatData = newData; } else if (type instanceof DoubleType) { double[] newData = new double[newCapacity]; if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 67cc08b6fc8ba..445f311107e33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { test("Null Apis") { @@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } - private def doubleEquals(d1: Double, d2: Double): Boolean = { if (d1.isNaN && d2.isNaN) { true @@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) if (!r1.isNullAt(v._2)) { v._1.dataType match { + case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed) case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed) case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)), + "Seed = " + seed) case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), "Seed = " + seed) + case t: DecimalType => + val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal + val d2 = r2.getDecimal(v._2) + assert(d1.compare(d2) == 0, "Seed = " + seed) case StringType => assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case CalendarIntervalType => + assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval]) case ArrayType(childType, n) => val a1 = r1.getArray(v._2).array val a2 = r2.getList(v._2).toArray @@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } } + case FloatType => { + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), + "Seed = " + seed) + i += 1 + } + } + + case t: DecimalType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal + val d2 = a2(i).asInstanceOf[java.math.BigDecimal] + assert(d1.compare(d2) == 0, "Seed = " + seed) + } + i += 1 + } + case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => @@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite { * results. */ def testRandomRows(flatSchema: Boolean, numFields: Int) { - // TODO: add remaining types. Figure out why StringType doesn't work on jenkins. - val types = Array(ByteType, IntegerType, LongType, DoubleType) + // TODO: Figure out why StringType doesn't work on jenkins. + val types = Array( + BooleanType, ByteType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10), + CalendarIntervalType) val seed = System.nanoTime() - val NUM_ROWS = 500 + val NUM_ROWS = 200 val NUM_ITERS = 1000 val random = new Random(seed) var i = 0 @@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("Random flat schema") { - testRandomRows(true, 10) + testRandomRows(true, 15) } test("Random nested schema") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b29bf6a464b30..18761bfd222a2 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -27,10 +27,14 @@ public final class Platform { public static final int BYTE_ARRAY_OFFSET; + public static final int SHORT_ARRAY_OFFSET; + public static final int INT_ARRAY_OFFSET; public static final int LONG_ARRAY_OFFSET; + public static final int FLOAT_ARRAY_OFFSET; + public static final int DOUBLE_ARRAY_OFFSET; public static int getInt(Object object, long offset) { @@ -168,13 +172,17 @@ public static void throwException(Throwable t) { if (_UNSAFE != null) { BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); } else { BYTE_ARRAY_OFFSET = 0; + SHORT_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; + FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; } } From a2973fed30fbe9a0b12e1c1225359fdf55d322b4 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 1 Feb 2016 13:57:48 -0800 Subject: [PATCH 1763/2089] =?UTF-8?q?Fix=20for=20[SPARK-12854][SQL]=20Impl?= =?UTF-8?q?ement=20complex=20types=20support=20in=20Columna=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rBatch Fixes build for Scala 2.11. Author: Jacek Laskowski Closes #10946 from jaceklaskowski/SPARK-12854-fix. --- .../spark/sql/execution/vectorized/OffHeapColumnVector.java | 2 +- .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 22c5e5fc81a4a..7a224d19d15b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -367,7 +367,7 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) { } @Override - public final void loadBytes(Array array) { + public final void loadBytes(ColumnVector.Array array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; Platform.copyMemory( null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 32356334c031f..c42bbd642ecae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -331,7 +331,7 @@ public final void putArray(int rowId, int offset, int length) { } @Override - public final void loadBytes(Array array) { + public final void loadBytes(ColumnVector.Array array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } From be7a2fc0716b7d25327b6f8f683390fc62532e3b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 14:11:52 -0800 Subject: [PATCH 1764/2089] [SPARK-13078][SQL] API and test cases for internal catalog This pull request creates an internal catalog API. The creation of this API is the first step towards consolidating SQLContext and HiveContext. I envision we will have two different implementations in Spark 2.0: (1) a simple in-memory implementation, and (2) an implementation based on the current HiveClient (ClientWrapper). I took a look at what Hive's internal metastore interface/implementation, and then created this API based on it. I believe this is the minimal set needed in order to achieve all the needed functionality. Author: Reynold Xin Closes #10982 from rxin/SPARK-13078. --- .../catalyst/catalog/InMemoryCatalog.scala | 246 ++++++++++++++++ .../sql/catalyst/catalog/interface.scala | 178 ++++++++++++ .../catalyst/catalog/CatalogTestCases.scala | 263 ++++++++++++++++++ .../catalog/InMemoryCatalogSuite.scala | 23 ++ 4 files changed, 710 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala new file mode 100644 index 0000000000000..9e6dfb7e9506f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException + + +/** + * An in-memory (ephemeral) implementation of the system catalog. + * + * All public methods should be synchronized for thread-safety. + */ +class InMemoryCatalog extends Catalog { + + private class TableDesc(var table: Table) { + val partitions = new mutable.HashMap[String, TablePartition] + } + + private class DatabaseDesc(var db: Database) { + val tables = new mutable.HashMap[String, TableDesc] + val functions = new mutable.HashMap[String, Function] + } + + private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] + + private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val regex = pattern.replaceAll("\\*", ".*").r + names.filter { funcName => regex.pattern.matcher(funcName).matches() } + } + + private def existsFunction(db: String, funcName: String): Boolean = { + catalog(db).functions.contains(funcName) + } + + private def existsTable(db: String, table: String): Boolean = { + catalog(db).tables.contains(table) + } + + private def assertDbExists(db: String): Unit = { + if (!catalog.contains(db)) { + throw new AnalysisException(s"Database $db does not exist") + } + } + + private def assertFunctionExists(db: String, funcName: String): Unit = { + assertDbExists(db) + if (!existsFunction(db, funcName)) { + throw new AnalysisException(s"Function $funcName does not exists in $db database") + } + } + + private def assertTableExists(db: String, table: String): Unit = { + assertDbExists(db) + if (!existsTable(db, table)) { + throw new AnalysisException(s"Table $table does not exists in $db database") + } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized { + if (catalog.contains(dbDefinition.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Database ${dbDefinition.name} already exists.") + } + } else { + catalog.put(dbDefinition.name, new DatabaseDesc(dbDefinition)) + } + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { + if (catalog.contains(db)) { + if (!cascade) { + // If cascade is false, make sure the database is empty. + if (catalog(db).tables.nonEmpty) { + throw new AnalysisException(s"Database $db is not empty. One or more tables exist.") + } + if (catalog(db).functions.nonEmpty) { + throw new AnalysisException(s"Database $db is not empty. One or more functions exist.") + } + } + // Remove the database. + catalog.remove(db) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Database $db does not exist") + } + } + } + + override def alterDatabase(db: String, dbDefinition: Database): Unit = synchronized { + assertDbExists(db) + assert(db == dbDefinition.name) + catalog(db).db = dbDefinition + } + + override def getDatabase(db: String): Database = synchronized { + assertDbExists(db) + catalog(db).db + } + + override def listDatabases(): Seq[String] = synchronized { + catalog.keySet.toSeq + } + + override def listDatabases(pattern: String): Seq[String] = synchronized { + filterPattern(listDatabases(), pattern) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean) + : Unit = synchronized { + assertDbExists(db) + if (existsTable(db, tableDefinition.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") + } + } else { + catalog(db).tables.put(tableDefinition.name, new TableDesc(tableDefinition)) + } + } + + override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean) + : Unit = synchronized { + assertDbExists(db) + if (existsTable(db, table)) { + catalog(db).tables.remove(table) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Table $table does not exist in $db database") + } + } + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + assertTableExists(db, oldName) + val oldDesc = catalog(db).tables(oldName) + oldDesc.table = oldDesc.table.copy(name = newName) + catalog(db).tables.put(newName, oldDesc) + catalog(db).tables.remove(oldName) + } + + override def alterTable(db: String, table: String, tableDefinition: Table): Unit = synchronized { + assertTableExists(db, table) + assert(table == tableDefinition.name) + catalog(db).tables(table).table = tableDefinition + } + + override def getTable(db: String, table: String): Table = synchronized { + assertTableExists(db, table) + catalog(db).tables(table).table + } + + override def listTables(db: String): Seq[String] = synchronized { + assertDbExists(db) + catalog(db).tables.keySet.toSeq + } + + override def listTables(db: String, pattern: String): Seq[String] = synchronized { + assertDbExists(db) + filterPattern(listTables(db), pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def alterPartition(db: String, table: String, part: TablePartition) + : Unit = synchronized { + throw new UnsupportedOperationException + } + + override def alterPartitions(db: String, table: String, parts: Seq[TablePartition]) + : Unit = synchronized { + throw new UnsupportedOperationException + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction( + db: String, func: Function, ifNotExists: Boolean): Unit = synchronized { + assertDbExists(db) + + if (existsFunction(db, func.name)) { + if (!ifNotExists) { + throw new AnalysisException(s"Function $func already exists in $db database") + } + } else { + catalog(db).functions.put(func.name, func) + } + } + + override def dropFunction(db: String, funcName: String): Unit = synchronized { + assertFunctionExists(db, funcName) + catalog(db).functions.remove(funcName) + } + + override def alterFunction(db: String, funcName: String, funcDefinition: Function) + : Unit = synchronized { + assertFunctionExists(db, funcName) + if (funcName != funcDefinition.name) { + // Also a rename; remove the old one and add the new one back + catalog(db).functions.remove(funcName) + } + catalog(db).functions.put(funcName, funcDefinition) + } + + override def getFunction(db: String, funcName: String): Function = synchronized { + assertFunctionExists(db, funcName) + catalog(db).functions(funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { + assertDbExists(db) + val regex = pattern.replaceAll("\\*", ".*").r + filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala new file mode 100644 index 0000000000000..a6caf91f3304b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.AnalysisException + + +/** + * Interface for the system catalog (of columns, partitions, tables, and databases). + * + * This is only used for non-temporary items, and implementations must be thread-safe as they + * can be accessed in multiple threads. + * + * Implementations should throw [[AnalysisException]] when table or database don't exist. + */ +abstract class Catalog { + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit + + def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit + + def alterDatabase(db: String, dbDefinition: Database): Unit + + def getDatabase(db: String): Database + + def listDatabases(): Seq[String] + + def listDatabases(pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + def createTable(db: String, tableDefinition: Table, ignoreIfExists: Boolean): Unit + + def dropTable(db: String, table: String, ignoreIfNotExists: Boolean): Unit + + def renameTable(db: String, oldName: String, newName: String): Unit + + def alterTable(db: String, table: String, tableDefinition: Table): Unit + + def getTable(db: String, table: String): Table + + def listTables(db: String): Seq[String] + + def listTables(db: String, pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + // TODO: need more functions for partitioning. + + def alterPartition(db: String, table: String, part: TablePartition): Unit + + def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + def createFunction(db: String, funcDefinition: Function, ignoreIfExists: Boolean): Unit + + def dropFunction(db: String, funcName: String): Unit + + def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit + + def getFunction(db: String, funcName: String): Function + + def listFunctions(db: String, pattern: String): Seq[String] + +} + + +/** + * A function defined in the catalog. + * + * @param name name of the function + * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" + */ +case class Function( + name: String, + className: String +) + + +/** + * Storage format, used to describe how a partition or a table is stored. + */ +case class StorageFormat( + locationUri: String, + inputFormat: String, + outputFormat: String, + serde: String, + serdeProperties: Map[String, String] +) + + +/** + * A column in a table. + */ +case class Column( + name: String, + dataType: String, + nullable: Boolean, + comment: String +) + + +/** + * A partition (Hive style) defined in the catalog. + * + * @param values values for the partition columns + * @param storage storage format of the partition + */ +case class TablePartition( + values: Seq[String], + storage: StorageFormat +) + + +/** + * A table defined in the catalog. + * + * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the + * future once we have a better understanding of how we want to handle skewed columns. + */ +case class Table( + name: String, + description: String, + schema: Seq[Column], + partitionColumns: Seq[Column], + sortColumns: Seq[Column], + storage: StorageFormat, + numBuckets: Int, + properties: Map[String, String], + tableType: String, + createTime: Long, + lastAccessTime: Long, + viewOriginalText: Option[String], + viewText: Option[String]) { + + require(tableType == "EXTERNAL_TABLE" || tableType == "INDEX_TABLE" || + tableType == "MANAGED_TABLE" || tableType == "VIRTUAL_VIEW") +} + + +/** + * A database defined in the catalog. + */ +case class Database( + name: String, + description: String, + locationUri: String, + properties: Map[String, String] +) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala new file mode 100644 index 0000000000000..ab9d5ac8a20eb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException + + +/** + * A reasonable complete test suite (i.e. behaviors) for a [[Catalog]]. + * + * Implementations of the [[Catalog]] interface can create test suites by extending this. + */ +abstract class CatalogTestCases extends SparkFunSuite { + + protected def newEmptyCatalog(): Catalog + + /** + * Creates a basic catalog, with the following structure: + * + * db1 + * db2 + * - tbl1 + * - tbl2 + * - func1 + */ + private def newBasicCatalog(): Catalog = { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("db1"), ifNotExists = false) + catalog.createDatabase(newDb("db2"), ifNotExists = false) + + catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog + } + + private def newFunc(): Function = Function("funcname", "org.apache.spark.MyFunc") + + private def newDb(name: String = "default"): Database = + Database(name, name + " description", "uri", Map.empty) + + private def newTable(name: String): Table = + Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, + None, None) + + private def newFunc(name: String): Function = Function(name, "class.name") + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + test("basic create, drop and list databases") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb(), ifNotExists = false) + assert(catalog.listDatabases().toSet == Set("default")) + + catalog.createDatabase(newDb("default2"), ifNotExists = false) + assert(catalog.listDatabases().toSet == Set("default", "default2")) + } + + test("get database when a database exists") { + val db1 = newBasicCatalog().getDatabase("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } + + test("get database should throw exception when the database does not exist") { + intercept[AnalysisException] { newBasicCatalog().getDatabase("db_that_does_not_exist") } + } + + test("list databases without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases().toSet == Set("db1", "db2")) + } + + test("list databases with pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } + + test("drop database") { + val catalog = newBasicCatalog() + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("db2")) + } + + test("drop database when the database is not empty") { + // Throw exception if there are functions left + val catalog1 = newBasicCatalog() + catalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false) + catalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // Throw exception if there are tables left + val catalog2 = newBasicCatalog() + catalog2.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // When cascade is true, it should drop them + val catalog3 = newBasicCatalog() + catalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog3.listDatabases().toSet == Set("db1")) + } + + test("drop database when the database does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) + } + + test("alter database") { + val catalog = newBasicCatalog() + catalog.alterDatabase("db1", Database("db1", "new description", "lll", Map.empty)) + assert(catalog.getDatabase("db1").description == "new description") + } + + test("alter database should throw exception when the database does not exist") { + intercept[AnalysisException] { + newBasicCatalog().alterDatabase("no_db", Database("no_db", "ddd", "lll", Map.empty)) + } + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + test("drop table") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable("db2", "tbl1", ignoreIfNotExists = false) + assert(catalog.listTables("db2").toSet == Set("tbl2")) + } + + test("drop table when database / table does not exist") { + val catalog = newBasicCatalog() + + // Should always throw exception when the database does not exist + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = false) + } + + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = true) + } + + // Should throw exception when the table does not exist, if ignoreIfNotExists is false + intercept[AnalysisException] { + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = false) + } + + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = true) + } + + test("rename table") { + val catalog = newBasicCatalog() + + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable("db2", "tbl1", "tblone") + assert(catalog.listTables("db2").toSet == Set("tblone", "tbl2")) + } + + test("rename table when database / table does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { // Throw exception when the database does not exist + catalog.renameTable("unknown_db", "unknown_table", "unknown_table") + } + + intercept[AnalysisException] { // Throw exception when the table does not exist + catalog.renameTable("db2", "unknown_table", "unknown_table") + } + } + + test("alter table") { + val catalog = newBasicCatalog() + catalog.alterTable("db2", "tbl1", newTable("tbl1").copy(createTime = 10)) + assert(catalog.getTable("db2", "tbl1").createTime == 10) + } + + test("alter table when database / table does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { // Throw exception when the database does not exist + catalog.alterTable("unknown_db", "unknown_table", newTable("unknown_table")) + } + + intercept[AnalysisException] { // Throw exception when the table does not exist + catalog.alterTable("db2", "unknown_table", newTable("unknown_table")) + } + } + + test("get table") { + assert(newBasicCatalog().getTable("db2", "tbl1").name == "tbl1") + } + + test("get table when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getTable("unknown_db", "unknown_table") + } + + intercept[AnalysisException] { + catalog.getTable("db2", "unknown_table") + } + } + + test("list tables without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db1").toSet == Set.empty) + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + } + + test("list tables with pattern") { + val catalog = newBasicCatalog() + + // Test when database does not exist + intercept[AnalysisException] { catalog.listTables("unknown_db") } + + assert(catalog.listTables("db1", "*").toSet == Set.empty) + assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "*1").toSet == Set("tbl1")) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + // TODO: Add tests cases for partitions + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + // TODO: Add tests cases for functions +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala new file mode 100644 index 0000000000000..871f0a0f46a22 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +/** Test suite for the [[InMemoryCatalog]]. */ +class InMemoryCatalogSuite extends CatalogTestCases { + override protected def newEmptyCatalog(): Catalog = new InMemoryCatalog +} From 715a19d56fc934d4aec5025739ff650daf4580b7 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 Feb 2016 16:23:17 -0800 Subject: [PATCH 1765/2089] [SPARK-12637][CORE] Print stage info of finished stages properly Improve printing of StageInfo in onStageCompleted See also https://github.com/apache/spark/pull/10585 Author: Sean Owen Closes #10922 from srowen/SPARK-12637. --- .../org/apache/spark/scheduler/SparkListener.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index ed3adbd81c28e..7b09c2eded0be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -270,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) + this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) // Shuffle write @@ -297,6 +297,17 @@ class StatsReportListener extends SparkListener with Logging { taskInfoMetrics.clear() } + private def getStatusDetail(info: StageInfo): String = { + val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("") + val timeTaken = info.submissionTime.map( + x => info.completionTime.getOrElse(System.currentTimeMillis()) - x + ).getOrElse("-") + + s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + + s"Took: $timeTaken msec" + } + } private[spark] object StatsReportListener extends Logging { From 0df3cfb8ab4d584c95db6c340694e199d7b59e9e Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 1 Feb 2016 16:55:21 -0800 Subject: [PATCH 1766/2089] [SPARK-12790][CORE] Remove HistoryServer old multiple files format Removed isLegacyLogDirectory code path and updated tests andrewor14 Author: felixcheung Closes #10860 from felixcheung/historyserverformat. --- .rat-excludes | 12 +- .../deploy/history/FsHistoryProvider.scala | 124 ++---------------- .../scheduler/EventLoggingListener.scala | 2 - .../EVENT_LOG_1 => local-1422981759269} | 0 .../local-1422981759269/APPLICATION_COMPLETE | 0 .../local-1422981759269/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1422981780767} | 0 .../local-1422981780767/APPLICATION_COMPLETE | 0 .../local-1422981780767/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1425081759269} | 0 .../local-1425081759269/APPLICATION_COMPLETE | 0 .../local-1425081759269/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1426533911241} | 0 .../local-1426533911241/APPLICATION_COMPLETE | 0 .../local-1426533911241/SPARK_VERSION_1.2.0 | 0 .../EVENT_LOG_1 => local-1426633911242} | 0 .../local-1426633911242/APPLICATION_COMPLETE | 0 .../local-1426633911242/SPARK_VERSION_1.2.0 | 0 .../history/FsHistoryProviderSuite.scala | 95 +------------- .../deploy/history/HistoryServerSuite.scala | 25 +--- 20 files changed, 23 insertions(+), 235 deletions(-) rename core/src/test/resources/spark-events/{local-1422981759269/EVENT_LOG_1 => local-1422981759269} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1422981780767/EVENT_LOG_1 => local-1422981780767} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1425081759269/EVENT_LOG_1 => local-1425081759269} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1426533911241/EVENT_LOG_1 => local-1426533911241} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 rename core/src/test/resources/spark-events/{local-1426633911242/EVENT_LOG_1 => local-1426633911242} (100%) delete mode 100755 core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE delete mode 100755 core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 diff --git a/.rat-excludes b/.rat-excludes index 874a6ee9f4043..8b5061415ff4c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -73,12 +73,12 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation -local-1422981759269/* -local-1422981780767/* -local-1425081759269/* -local-1426533911241/* -local-1426633911242/* -local-1430917381534/* +local-1422981759269 +local-1422981780767 +local-1425081759269 +local-1426533911241 +local-1426633911242 +local-1430917381534 local-1430917381535_1 local-1430917381535_2 DESCRIPTION diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 22e4155cc5452..9648959dbacb9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -248,9 +248,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val logInfos: Seq[FileStatus] = statusList .filter { entry => try { - getModificationTime(entry).map { time => - time >= lastScanTime - }.getOrElse(false) + !entry.isDirectory() && (entry.getModificationTime() >= lastScanTime) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -261,9 +259,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => - val mod1 = getModificationTime(entry1).getOrElse(-1L) - val mod2 = getModificationTime(entry2).getOrElse(-1L) - mod1 >= mod2 + entry1.getModificationTime() >= entry2.getModificationTime() } logInfos.grouped(20) @@ -341,19 +337,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get }.foreach { attempt => val logPath = new Path(logDir, attempt.logPath) - // If this is a legacy directory, then add the directory to the zipStream and add - // each file to that directory. - if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { - val files = fs.listStatus(logPath) - zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) - zipStream.closeEntry() - files.foreach { file => - val path = file.getPath - zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) - } - } else { - zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) - } + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) } } finally { zipStream.close() @@ -527,12 +511,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") - val logInput = - if (isLegacyLogDirectory(eventLog)) { - openLegacyEventLog(logPath) - } else { - EventLoggingListener.openEventLog(logPath, fs) - } + val logInput = EventLoggingListener.openEventLog(logPath, fs) try { val appListener = new ApplicationEventListener val appCompleted = isApplicationCompleted(eventLog) @@ -540,9 +519,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.replay(logInput, logPath.toString, !appCompleted) // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. Some old versions of Spark generate logs without an app ID, so let - // logs generated by those versions go through. - if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + // try to show their UI. + if (appListener.appId.isDefined) { Some(new FsApplicationAttemptInfo( logPath.getName(), appListener.appName.getOrElse(NOT_STARTED), @@ -550,7 +528,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.appAttemptId, appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, + eventLog.getModificationTime(), appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted)) } else { @@ -561,91 +539,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Loads a legacy log directory. This assumes that the log directory contains a single event - * log file (along with other metadata files), which is the case for directories generated by - * the code in previous releases. - * - * @return input stream that holds one JSON record per line. - */ - private[history] def openLegacyEventLog(dir: Path): InputStream = { - val children = fs.listStatus(dir) - var eventLogPath: Path = null - var codecName: Option[String] = None - - children.foreach { child => - child.getPath().getName() match { - case name if name.startsWith(LOG_PREFIX) => - eventLogPath = child.getPath() - case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) => - codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length())) - case _ => - } - } - - if (eventLogPath == null) { - throw new IllegalArgumentException(s"$dir is not a Spark application log directory.") - } - - val codec = try { - codecName.map { c => CompressionCodec.createCodec(conf, c) } - } catch { - case e: Exception => - throw new IllegalArgumentException(s"Unknown compression codec $codecName.") - } - - val in = new BufferedInputStream(fs.open(eventLogPath)) - codec.map(_.compressedInputStream(in)).getOrElse(in) - } - - /** - * Return whether the specified event log path contains a old directory-based event log. - * Previously, the event log of an application comprises of multiple files in a directory. - * As of Spark 1.3, these files are consolidated into a single one that replaces the directory. - * See SPARK-2261 for more detail. - */ - private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDirectory - - /** - * Returns the modification time of the given event log. If the status points at an empty - * directory, `None` is returned, indicating that there isn't an event log at that location. - */ - private def getModificationTime(fsEntry: FileStatus): Option[Long] = { - if (isLegacyLogDirectory(fsEntry)) { - val statusList = fs.listStatus(fsEntry.getPath) - if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None - } else { - Some(fsEntry.getModificationTime()) - } - } - /** * Return true when the application has completed. */ private def isApplicationCompleted(entry: FileStatus): Boolean = { - if (isLegacyLogDirectory(entry)) { - fs.exists(new Path(entry.getPath(), APPLICATION_COMPLETE)) - } else { - !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS) - } - } - - /** - * Returns whether the version of Spark that generated logs records app IDs. App IDs were added - * in Spark 1.1. - */ - private def sparkVersionHasAppId(entry: FileStatus): Boolean = { - if (isLegacyLogDirectory(entry)) { - fs.listStatus(entry.getPath()) - .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } - .map { status => - val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) - version != "1.0" && version != "1.1" - } - .getOrElse(true) - } else { - true - } + !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS) } /** @@ -670,12 +568,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" - - // Constants used to parse Spark 1.0.0 log directories. - val LOG_PREFIX = "EVENT_LOG_" - val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 36f2b74f948f1..01fee46e73a80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -232,8 +232,6 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" - val SPARK_VERSION_KEY = "SPARK_VERSION" - val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) diff --git a/core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981759269 diff --git a/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981780767 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981780767 diff --git a/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1425081759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1425081759269 diff --git a/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426533911241 similarity index 100% rename from core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426533911241 diff --git a/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426633911242 similarity index 100% rename from core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426633911242 diff --git a/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 6cbf911395a84..3baa2e2ddad31 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -69,7 +69,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new File(logPath) } - test("Parse new and old application logs") { + test("Parse application logs") { val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. @@ -95,26 +95,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc None) ) - // Write an old-style application log. - val oldAppComplete = writeOldLog("old1", "1.0", None, true, - SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - - // Check for logs so that we force the older unfinished app to be loaded, to make - // sure unfinished apps are also sorted correctly. - provider.checkForLogs() - - // Write an unfinished app, old-style. - val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, - SparkListenerApplicationStart("old2", None, 2L, "test", None) - ) - - // Force a reload of data from the log directory, and check that both logs are loaded. + // Force a reload of data from the log directory, and check that logs are loaded. // Take the opportunity to check that the offset checks work as expected. updateAndCheck(provider) { list => - list.size should be (5) - list.count(_.attempts.head.completed) should be (3) + list.size should be (3) + list.count(_.attempts.head.completed) should be (2) def makeAppInfo( id: String, @@ -132,11 +117,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc newAppComplete.lastModified(), "test", true)) list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) - list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, - oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, - -1L, oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, + list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. @@ -148,38 +129,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("Parse legacy logs with compression codec set") { - val provider = new FsHistoryProvider(createTestConf()) - val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), - (classOf[SnappyCompressionCodec].getName(), true), - ("invalid.codec", false)) - - testCodecs.foreach { case (codecName, valid) => - val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null - val logDir = new File(testDir, codecName) - logDir.mkdir() - createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), - SparkListenerApplicationStart("app2", None, 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) - - val logPath = new Path(logDir.getAbsolutePath()) - try { - val logInput = provider.openLegacyEventLog(logPath) - try { - Source.fromInputStream(logInput).getLines().toSeq.size should be (2) - } finally { - logInput.close() - } - } catch { - case e: IllegalArgumentException => - valid should be (false) - } - } - } - test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, @@ -395,21 +344,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc SparkListenerLogStart("1.4") ) - // Write a 1.2 log file with no start event (= no app id), it should be ignored. - writeOldLog("v12Log", "1.2", None, false) - - // Write 1.0 and 1.1 logs, which don't have app ids. - writeOldLog("v11Log", "1.1", None, true, - SparkListenerApplicationStart("v11Log", None, 2L, "test", None), - SparkListenerApplicationEnd(3L)) - writeOldLog("v10Log", "1.0", None, true, - SparkListenerApplicationStart("v10Log", None, 2L, "test", None), - SparkListenerApplicationEnd(4L)) - updateAndCheck(provider) { list => - list.size should be (2) - list(0).id should be ("v10Log") - list(1).id should be ("v11Log") + list.size should be (0) } } @@ -499,25 +435,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } - private def writeOldLog( - fname: String, - sparkVersion: String, - codec: Option[CompressionCodec], - completed: Boolean, - events: SparkListenerEvent*): File = { - val log = new File(testDir, fname) - log.mkdir() - - val oldEventLog = new File(log, LOG_PREFIX + "1") - createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) - writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) - if (completed) { - createEmptyFile(new File(log, APPLICATION_COMPLETE)) - } - - log - } - private class SafeModeTestProvider(conf: SparkConf, clock: Clock) extends FsHistoryProvider(conf, clock) { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index be55b2e0fe1b7..40d0076eecfc8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -176,18 +176,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } } - test("download legacy logs - all attempts") { - doDownloadTest("local-1426533911241", None, legacy = true) - } - - test("download legacy logs - single attempts") { - (1 to 2). foreach { - attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) - } - } - // Test that the files are downloaded correctly, and validate them. - def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { + def doDownloadTest(appId: String, attemptId: Option[Int]): Unit = { val url = attemptId match { case Some(id) => @@ -205,22 +195,13 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers var entry = zipStream.getNextEntry entry should not be null val totalFiles = { - if (legacy) { - attemptId.map { x => 3 }.getOrElse(6) - } else { - attemptId.map { x => 1 }.getOrElse(2) - } + attemptId.map { x => 1 }.getOrElse(2) } var filesCompared = 0 while (entry != null) { if (!entry.isDirectory) { val expectedFile = { - if (legacy) { - val splits = entry.getName.split("/") - new File(new File(logDir, splits(0)), splits(1)) - } else { - new File(logDir, entry.getName) - } + new File(logDir, entry.getName) } val expected = Files.toString(expectedFile, Charsets.UTF_8) val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) From 0fff5c6e6325357a241d311e72db942c4850af34 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 23:08:11 -0800 Subject: [PATCH 1767/2089] [SPARK-13130][SQL] Make codegen variable names easier to read 1. Use lower case 2. Change long prefixes to something shorter (in this case I am changing only one: TungstenAggregate -> agg). Author: Reynold Xin Closes #11017 from rxin/SPARK-13130. --- .../spark/sql/execution/WholeStageCodegen.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f049f..02b0f423ed438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.util.Utils /** @@ -33,6 +34,12 @@ import org.apache.spark.util.Utils */ trait CodegenSupport extends SparkPlan { + /** Prefix used in the current operator's variable names. */ + private def variablePrefix: String = this match { + case _: TungstenAggregate => "agg" + case _ => nodeName.toLowerCase + } + /** * Whether this SparkPlan support whole stage codegen or not. */ @@ -53,7 +60,7 @@ trait CodegenSupport extends SparkPlan { */ def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent - ctx.freshNamePrefix = nodeName + ctx.freshNamePrefix = variablePrefix doProduce(ctx) } @@ -94,7 +101,7 @@ trait CodegenSupport extends SparkPlan { child: SparkPlan, input: Seq[ExprCode], row: String = null): String = { - ctx.freshNamePrefix = nodeName + ctx.freshNamePrefix = variablePrefix if (row != null) { ctx.currentVars = null ctx.INPUT_ROW = row From b8666fd0e2a797924eb2e94ac5558aba2a9b5140 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 1 Feb 2016 23:37:06 -0800 Subject: [PATCH 1768/2089] Closes #10662. Closes #10661 From 22ba21348b28d8b1909ccde6fe17fb9e68531e5a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 16:48:59 +0800 Subject: [PATCH 1769/2089] [SPARK-13087][SQL] Fix group by function for sort based aggregation It is not valid to call `toAttribute` on a `NamedExpression` unless we know for sure that the child produced that `NamedExpression`. The current code worked fine when the grouping expressions were simple, but when they were a derived value this blew up at execution time. Author: Michael Armbrust Closes #11013 from marmbrus/groupByFunction-master. --- .../org/apache/spark/sql/execution/aggregate/utils.scala | 5 ++--- .../spark/sql/hive/execution/AggregationQuerySuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 83379ae90f703..1e113ccd4e137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -33,15 +33,14 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) val completeAggregateAttributes = completeAggregateExpressions.map { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, + requiredChildDistributionExpressions = Some(groupingExpressions), + groupingExpressions = groupingExpressions, aggregateExpressions = completeAggregateExpressions, aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 3e4cf3f79e57c..7a9ed1eaf3dbc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -193,6 +193,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te sqlContext.dropTempTable("emptyTable") } + test("group by function") { + Seq((1, 2)).toDF("a", "b").registerTempTable("data") + + checkAnswer( + sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"), + Row(1, Array(2)) :: Nil) + } + test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( From 12a20c144f14e80ef120ddcfb0b455a805a2da23 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 10:13:54 -0800 Subject: [PATCH 1770/2089] [SPARK-10820][SQL] Support for the continuous execution of structured queries This is a follow up to 9aadcffabd226557174f3ff566927f873c71672e that extends Spark SQL to allow users to _repeatedly_ optimize and execute structured queries. A `ContinuousQuery` can be expressed using SQL, DataFrames or Datasets. The purpose of this PR is only to add some initial infrastructure which will be extended in subsequent PRs. ## User-facing API - `sqlContext.streamFrom` and `df.streamTo` return builder objects that are analogous to the `read/write` interfaces already available to executing queries in a batch-oriented fashion. - `ContinuousQuery` provides an interface for interacting with a query that is currently executing in the background. ## Internal Interfaces - `StreamExecution` - executes streaming queries in micro-batches The following are currently internal, but public APIs will be provided in a future release. - `Source` - an interface for providers of continually arriving data. A source must have a notion of an `Offset` that monotonically tracks what data has arrived. For fault tolerance, a source must be able to replay data given a start offset. - `Sink` - an interface that accepts the results of a continuously executing query. Also responsible for tracking the offset that should be resumed from in the case of a failure. ## Testing - `MemoryStream` and `MemorySink` - simple implementations of source and sink that keep all data in memory and have methods for simulating durability failures - `StreamTest` - a framework for performing actions and checking invariants on a continuous query Author: Michael Armbrust Author: Tathagata Das Author: Josh Rosen Closes #11006 from marmbrus/structured-streaming. --- .../apache/spark/sql/ContinuousQuery.scala | 30 ++ .../org/apache/spark/sql/DataFrame.scala | 8 + .../apache/spark/sql/DataStreamReader.scala | 127 +++++++ .../apache/spark/sql/DataStreamWriter.scala | 134 +++++++ .../org/apache/spark/sql/SQLContext.scala | 8 + .../datasources/ResolvedDataSource.scala | 33 +- .../spark/sql/execution/streaming/Batch.scala | 26 ++ .../execution/streaming/CompositeOffset.scala | 67 ++++ .../sql/execution/streaming/LongOffset.scala | 33 ++ .../sql/execution/streaming/Offset.scala | 37 ++ .../spark/sql/execution/streaming/Sink.scala | 47 +++ .../sql/execution/streaming/Source.scala | 36 ++ .../execution/streaming/StreamExecution.scala | 211 +++++++++++ .../execution/streaming/StreamProgress.scala | 67 ++++ .../streaming/StreamingRelation.scala | 34 ++ .../sql/execution/streaming/memory.scala | 138 +++++++ .../apache/spark/sql/sources/interfaces.scala | 21 ++ .../org/apache/spark/sql/QueryTest.scala | 74 ++-- .../org/apache/spark/sql/StreamTest.scala | 346 ++++++++++++++++++ .../sql/streaming/DataStreamReaderSuite.scala | 166 +++++++++ .../streaming/MemorySourceStressSuite.scala | 33 ++ .../spark/sql/streaming/OffsetSuite.scala | 98 +++++ .../spark/sql/streaming/StreamSuite.scala | 84 +++++ .../spark/sql/test/SharedSQLContext.scala | 2 +- 24 files changed, 1828 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala new file mode 100644 index 0000000000000..1c2c0290fc4cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * A handle to a query that is executing continuously in the background as new data arrives. + */ +trait ContinuousQuery { + + /** + * Stops the execution of this query if it is running. This method blocks until the threads + * performing execution has stopped. + */ + def stop(): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 518f9dcf94a70..6de17e5924d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1690,6 +1690,14 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) + /** + * :: Experimental :: + * Interface for starting a streaming query that will continually output results to the specified + * external sink as new data arrives. + */ + @Experimental + def streamTo: DataStreamWriter = new DataStreamWriter(this) + /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala new file mode 100644 index 0000000000000..2febc93fa49d4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala @@ -0,0 +1,127 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * An interface to reading streaming data. Use `sqlContext.streamFrom` to access these methods. + * + * {{{ + * val df = sqlContext.streamFrom + * .format("...") + * .open() + * }}} + */ +@Experimental +class DataStreamReader private[sql](sqlContext: SQLContext) extends Logging { + + /** + * Specifies the input data source format. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data streams (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data stream can + * skip the schema inference step, and thus speed up data reading. + * + * @since 2.0.0 + */ + def schema(schema: StructType): DataStreamReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data stream. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamReader = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds input options for the underlying data stream. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data stream. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamReader = { + this.options(options.asScala) + this + } + + /** + * Loads streaming input in as a [[DataFrame]], for data streams that don't require a path (e.g. + * external key-value stores). + * + * @since 2.0.0 + */ + def open(): DataFrame = { + val resolved = ResolvedDataSource.createSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + providerName = source, + options = extraOptions.toMap) + DataFrame(sqlContext, StreamingRelation(resolved)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def open(path: String): DataFrame = { + option("path", path).open() + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sqlContext.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala new file mode 100644 index 0000000000000..b325d48fcbbb1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.streaming.StreamExecution + +/** + * :: Experimental :: + * Interface used to start a streaming query query execution. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamWriter private[sql](df: DataFrame) { + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamWriter = { + this.source = source + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamWriter = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamWriter = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamWriter = { + this.options(options.asScala) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme.\ + * @since 2.0.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataStreamWriter = { + this.partitioningColumns = colNames + this + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * @since 2.0.0 + */ + def start(path: String): ContinuousQuery = { + this.extraOptions += ("path" -> path) + start() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def start(): ContinuousQuery = { + val sink = ResolvedDataSource.createSink( + df.sqlContext, + source, + extraOptions.toMap, + normalizedParCols) + + new StreamExecution(df.sqlContext, df.logicalPlan, sink) + } + + private def normalizedParCols: Seq[String] = { + partitioningColumns.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sqlContext.conf.defaultDataSourceName + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Seq[String] = Nil + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ef993c3edae37..13700be06828d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -594,6 +594,14 @@ class SQLContext private[sql]( @Experimental def read: DataFrameReader = new DataFrameReader(this) + + /** + * :: Experimental :: + * Returns a [[DataStreamReader]] than can be used to access data continuously as it arrives. + */ + @Experimental + def streamFrom: DataStreamReader = new DataStreamReader(this) + /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index cc8dcf59307f2..e3065ac5f87d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -29,11 +29,11 @@ import org.apache.hadoop.util.StringUtils import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils - case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) @@ -92,6 +92,37 @@ object ResolvedDataSource extends Logging { } } + def createSource( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + providerName: String, + options: Map[String, String]): Source = { + val provider = lookupDataSource(providerName).newInstance() match { + case s: StreamSourceProvider => s + case _ => + throw new UnsupportedOperationException( + s"Data source $providerName does not support streamed reading") + } + + provider.createSource(sqlContext, options, userSpecifiedSchema) + } + + def createSink( + sqlContext: SQLContext, + providerName: String, + options: Map[String, String], + partitionColumns: Seq[String]): Sink = { + val provider = lookupDataSource(providerName).newInstance() match { + case s: StreamSinkProvider => s + case _ => + throw new UnsupportedOperationException( + s"Data source $providerName does not support streamed writing") + } + + provider.createSink(sqlContext, options, partitionColumns) + } + + /** Create a [[ResolvedDataSource]] for reading data in. */ def apply( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala new file mode 100644 index 0000000000000..1f25eb8fc5223 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.DataFrame + +/** + * Used to pass a batch of data through a streaming query execution along with an indication + * of progress in the stream. + */ +class Batch(val end: Offset, val data: DataFrame) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala new file mode 100644 index 0000000000000..d2cb20ef8b819 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.Try + +/** + * An ordered collection of offsets, used to track the progress of processing data from one or more + * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance + * vector clock that must progress linearly forward. + */ +case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. + */ + override def compareTo(other: Offset): Int = other match { + case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => + val comparisons = offsets.zip(otherComposite.offsets).map { + case (Some(a), Some(b)) => a compareTo b + case (None, None) => 0 + case (None, _) => -1 + case (_, None) => 1 + } + val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet + nonZeroSigns.size match { + case 0 => 0 // if both empty or only 0s + case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) + case _ => // there are both 1s and -1s + throw new IllegalArgumentException( + s"Invalid comparison between non-linear histories: $this <=> $other") + } + case _ => + throw new IllegalArgumentException(s"Cannot compare $this <=> $other") + } + + private def sign(num: Int): Int = num match { + case i if i < 0 => -1 + case i if i == 0 => 0 + case i if i > 0 => 1 + } +} + +object CompositeOffset { + /** + * Returns a [[CompositeOffset]] with a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. + */ + def fill(offsets: Offset*): CompositeOffset = { + CompositeOffset(offsets.map(Option(_))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala new file mode 100644 index 0000000000000..008195af38b75 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A simple offset for sources that produce a single linear stream of data. + */ +case class LongOffset(offset: Long) extends Offset { + + override def compareTo(other: Offset): Int = other match { + case l: LongOffset => offset.compareTo(l.offset) + case _ => + throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") + } + + def +(increment: Long): LongOffset = new LongOffset(offset + increment) + def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala new file mode 100644 index 0000000000000..0f5d6445b1e2b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A offset is a monotonically increasing metric used to track progress in the computation of a + * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent + * with `equals` and `hashcode`. + */ +trait Offset extends Serializable { + + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. + */ + def compareTo(other: Offset): Int + + def >(other: Offset): Boolean = compareTo(other) > 0 + def <(other: Offset): Boolean = compareTo(other) < 0 + def <=(other: Offset): Boolean = compareTo(other) <= 0 + def >=(other: Offset): Boolean = compareTo(other) >= 0 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala new file mode 100644 index 0000000000000..1bd71b6b02ea9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * An interface for systems that can collect the results of a streaming query. + * + * When new data is produced by a query, a [[Sink]] must be able to transactionally collect the + * data and update the [[Offset]]. In the case of a failure, the sink will be recreated + * and must be able to return the [[Offset]] for all of the data that is made durable. + * This contract allows Spark to process data with exactly-once semantics, even in the case + * of failures that require the computation to be restarted. + */ +trait Sink { + /** + * Returns the [[Offset]] for all data that is currently present in the sink, if any. This + * function will be called by Spark when restarting execution in order to determine at which point + * in the input stream computation should be resumed from. + */ + def currentOffset: Option[Offset] + + /** + * Accepts a new batch of data as well as a [[Offset]] that denotes how far in the input + * data computation has progressed to. When computation restarts after a failure, it is important + * that a [[Sink]] returns the same [[Offset]] as the most recent batch of data that + * has been persisted durrably. Note that this does not necessarily have to be the + * [[Offset]] for the most recent batch of data that was given to the sink. For example, + * it is valid to buffer data before persisting, as long as the [[Offset]] is stored + * transactionally as data is eventually persisted. + */ + def addBatch(batch: Batch): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala new file mode 100644 index 0000000000000..25922979ac83e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.types.StructType + +/** + * A source of continually arriving data for a streaming query. A [[Source]] must have a + * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark + * will regularly query each [[Source]] to see if any more data is available. + */ +trait Source { + + /** Returns the schema of the data from this source */ + def schema: StructType + + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + def getNextBatch(start: Option[Offset]): Option[Batch] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala new file mode 100644 index 0000000000000..ebebb829710b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.Logging +import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. + * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any + * [[Source]] present in the query plan. Whenever new data arrives, a [[QueryExecution]] is created + * and the results are committed transactionally to the given [[Sink]]. + */ +class StreamExecution( + sqlContext: SQLContext, + private[sql] val logicalPlan: LogicalPlan, + val sink: Sink) extends ContinuousQuery with Logging { + + /** An monitor used to wait/notify when batches complete. */ + private val awaitBatchLock = new Object + + @volatile + private var batchRun = false + + /** Minimum amount of time in between the start of each batch. */ + private val minBatchTime = 10 + + /** Tracks how much data we have processed from each input source. */ + private[sql] val streamProgress = new StreamProgress + + /** All stream sources present the query plan. */ + private val sources = + logicalPlan.collect { case s: StreamingRelation => s.source } + + // Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data + // that we have already processed). + { + sink.currentOffset match { + case Some(c: CompositeOffset) => + val storedProgress = c.offsets + val sources = logicalPlan collect { + case StreamingRelation(source, _) => source + } + + assert(sources.size == storedProgress.size) + sources.zip(storedProgress).foreach { case (source, offset) => + offset.foreach(streamProgress.update(source, _)) + } + case None => // We are starting this stream for the first time. + case _ => throw new IllegalArgumentException("Expected composite offset from sink") + } + } + + logInfo(s"Stream running at $streamProgress") + + /** When false, signals to the microBatchThread that it should stop running. */ + @volatile private var shouldRun = true + + /** The thread that runs the micro-batches of this stream. */ + private[sql] val microBatchThread = new Thread("stream execution thread") { + override def run(): Unit = { + SQLContext.setActive(sqlContext) + while (shouldRun) { + attemptBatch() + Thread.sleep(minBatchTime) // TODO: Could be tighter + } + } + } + microBatchThread.setDaemon(true) + microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + } + }) + microBatchThread.start() + + @volatile + private[sql] var lastExecution: QueryExecution = null + @volatile + private[sql] var streamDeathCause: Throwable = null + + /** + * Checks to see if any new data is present in any of the sources. When new data is available, + * a batch is executed and passed to the sink, updating the currentOffsets. + */ + private def attemptBatch(): Unit = { + val startTime = System.nanoTime() + + // A list of offsets that need to be updated if this batch is successful. + // Populated while walking the tree. + val newOffsets = new ArrayBuffer[(Source, Offset)] + // A list of attributes that will need to be updated. + var replacements = new ArrayBuffer[(Attribute, Attribute)] + // Replace sources in the logical plan with data that has arrived since the last batch. + val withNewSources = logicalPlan transform { + case StreamingRelation(source, output) => + val prevOffset = streamProgress.get(source) + val newBatch = source.getNextBatch(prevOffset) + + newBatch.map { batch => + newOffsets += ((source, batch.end)) + val newPlan = batch.data.logicalPlan + + assert(output.size == newPlan.output.size) + replacements ++= output.zip(newPlan.output) + newPlan + }.getOrElse { + LocalRelation(output) + } + } + + // Rewire the plan to use the new attributes that were returned by the source. + val replacementMap = AttributeMap(replacements) + val newPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => replacementMap(a) + } + + if (newOffsets.nonEmpty) { + val optimizerStart = System.nanoTime() + + lastExecution = new QueryExecution(sqlContext, newPlan) + val executedPlan = lastExecution.executedPlan + val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 + logDebug(s"Optimized batch in ${optimizerTime}ms") + + streamProgress.synchronized { + // Update the offsets and calculate a new composite offset + newOffsets.foreach(streamProgress.update) + val newStreamProgress = logicalPlan.collect { + case StreamingRelation(source, _) => streamProgress.get(source) + } + val batchOffset = CompositeOffset(newStreamProgress) + + // Construct the batch and send it to the sink. + val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan)) + sink.addBatch(nextBatch) + } + + batchRun = true + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() + } + + val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 + logInfo(s"Compete up to $newOffsets in ${batchTime}ms") + } + + logDebug(s"Waiting for data, current: $streamProgress") + } + + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + def stop(): Unit = { + shouldRun = false + if (microBatchThread.isAlive) { microBatchThread.join() } + } + + /** + * Blocks the current thread until processing for data from the given `source` has reached at + * least the given `Offset`. This method is indented for use primarily when writing tests. + */ + def awaitOffset(source: Source, newOffset: Offset): Unit = { + def notDone = streamProgress.synchronized { + !streamProgress.contains(source) || streamProgress(source) < newOffset + } + + while (notDone) { + logInfo(s"Waiting until $newOffset at $source") + awaitBatchLock.synchronized { awaitBatchLock.wait(100) } + } + logDebug(s"Unblocked at $newOffset for $source") + } + + override def toString: String = + s""" + |=== Streaming Query === + |CurrentOffsets: $streamProgress + |Thread State: ${microBatchThread.getState} + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |$logicalPlan + """.stripMargin +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala new file mode 100644 index 0000000000000..0ded1d7152c19 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +/** + * A helper class that looks like a Map[Source, Offset]. + */ +class StreamProgress { + private val currentOffsets = new mutable.HashMap[Source, Offset] + + private[streaming] def update(source: Source, newOffset: Offset): Unit = { + currentOffsets.get(source).foreach(old => + assert(newOffset > old, s"Stream going backwards $newOffset -> $old")) + currentOffsets.put(source, newOffset) + } + + private[streaming] def update(newOffset: (Source, Offset)): Unit = + update(newOffset._1, newOffset._2) + + private[streaming] def apply(source: Source): Offset = currentOffsets(source) + private[streaming] def get(source: Source): Option[Offset] = currentOffsets.get(source) + private[streaming] def contains(source: Source): Boolean = currentOffsets.contains(source) + + private[streaming] def ++(updates: Map[Source, Offset]): StreamProgress = { + val updated = new StreamProgress + currentOffsets.foreach(updated.update) + updates.foreach(updated.update) + updated + } + + /** + * Used to create a new copy of this [[StreamProgress]]. While this class is currently mutable, + * it should be copied before being passed to user code. + */ + private[streaming] def copy(): StreamProgress = { + val copied = new StreamProgress + currentOffsets.foreach(copied.update) + copied + } + + override def toString: String = + currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") + + override def equals(other: Any): Boolean = other match { + case s: StreamProgress => currentOffsets == s.currentOffsets + case _ => false + } + + override def hashCode: Int = currentOffsets.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala new file mode 100644 index 0000000000000..e35c444348f48 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode + +object StreamingRelation { + def apply(source: Source): StreamingRelation = + StreamingRelation(source, source.schema.toAttributes) +} + +/** + * Used to link a streaming [[Source]] of data into a + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. + */ +case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode { + override def toString: String = source.toString +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala new file mode 100644 index 0000000000000..e6a0842936ea2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder} +import org.apache.spark.sql.types.StructType + +object MemoryStream { + protected val currentBlockId = new AtomicInteger(0) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] + * is primarily intended for use in unit tests as it can only replay data when the object is still + * available. + */ +case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends Source with Logging { + protected val encoder = encoderFor[A] + protected val logicalPlan = StreamingRelation(this) + protected val output = logicalPlan.output + protected val batches = new ArrayBuffer[Dataset[A]] + protected var currentOffset: LongOffset = new LongOffset(-1) + + protected def blockManager = SparkEnv.get.blockManager + + def schema: StructType = encoder.schema + + def getCurrentOffset: Offset = currentOffset + + def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { + new Dataset(sqlContext, logicalPlan) + } + + def toDF()(implicit sqlContext: SQLContext): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } + + def addData(data: TraversableOnce[A]): Offset = { + import sqlContext.implicits._ + this.synchronized { + currentOffset = currentOffset + 1 + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") + batches.append(ds) + currentOffset + } + } + + override def getNextBatch(start: Option[Offset]): Option[Batch] = synchronized { + val newBlocks = + batches.drop( + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1) + + if (newBlocks.nonEmpty) { + logDebug(s"Running [$start, $currentOffset] on blocks ${newBlocks.mkString(", ")}") + val df = newBlocks + .map(_.toDF()) + .reduceOption(_ unionAll _) + .getOrElse(sqlContext.emptyDataFrame) + + Some(new Batch(currentOffset, df)) + } else { + None + } + } + + override def toString: String = s"MemoryStream[${output.mkString(",")}]" +} + +/** + * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit + * tests and does not provide durablility. + */ +class MemorySink(schema: StructType) extends Sink with Logging { + /** An order list of batches that have been written to this [[Sink]]. */ + private var batches = new ArrayBuffer[Batch]() + + /** Used to convert an [[InternalRow]] to an external [[Row]] for comparison in testing. */ + private val externalRowConverter = RowEncoder(schema) + + override def currentOffset: Option[Offset] = synchronized { + batches.lastOption.map(_.end) + } + + override def addBatch(nextBatch: Batch): Unit = synchronized { + batches.append(nextBatch) + } + + /** Returns all rows that are stored in this [[Sink]]. */ + def allData: Seq[Row] = synchronized { + batches + .map(_.data) + .reduceOption(_ unionAll _) + .map(_.collect().toSeq) + .getOrElse(Seq.empty) + } + + /** + * Atomically drops the most recent `num` batches and resets the [[StreamProgress]] to the + * corresponding point in the input. This function can be used when testing to simulate data + * that has been lost due to buffering. + */ + def dropBatches(num: Int): Unit = synchronized { + batches.dropRight(num) + } + + override def toString: String = synchronized { + batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n") + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8911ad370aa7b..299fc6efbb046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -123,6 +124,26 @@ trait SchemaRelationProvider { schema: StructType): BaseRelation } +/** + * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. + */ +trait StreamSourceProvider { + def createSource( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): Source +} + +/** + * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. + */ +trait StreamSinkProvider { + def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink +} + /** * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ce12f788b786c..405e5891ac976 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -304,27 +304,7 @@ object QueryTest { def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case d: java.math.BigDecimal => BigDecimal(d) - // Convert array to Seq for easy equality check. - case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - case o => o - }) - } - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } val sparkAnswer = try df.collect().toSeq catch { case e: Exception => val errorMessage = @@ -338,22 +318,56 @@ object QueryTest { return Some(errorMessage) } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = + sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => s""" |Results do not match for query: |${df.queryExecution} |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + |$results + """.stripMargin } + } + + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } - return None + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val errorMessage = + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + None } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala new file mode 100644 index 0000000000000..f45abbf2496a2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.streaming._ + +/** + * A framework for implementing tests for streaming queries and sources. + * + * A test consists of a set of steps (expressed as a `StreamAction`) that are executed in order, + * blocking as necessary to let the stream catch up. For example, the following adds some data to + * a stream, blocking until it can verify that the correct values are eventually produced. + * + * {{{ + * val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + CheckAnswer(2, 3, 4)) + * }}} + * + * Note that while we do sleep to allow the other thread to progress without spinning, + * `StreamAction` checks should not depend on the amount of time spent sleeping. Instead they + * should check the actual progress of the stream before verifying the required test condition. + * + * Currently it is assumed that all streaming queries will eventually complete in 10 seconds to + * avoid hanging forever in the case of failures. However, individual suites can change this + * by overriding `streamingTimeout`. + */ +trait StreamTest extends QueryTest with Timeouts { + + implicit class RichSource(s: Source) { + def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s)) + } + + /** How long to wait for an active stream to catch up when checking a result. */ + val streamingTimout = 10.seconds + + /** A trait for actions that can be performed while testing a streaming DataFrame. */ + trait StreamAction + + /** A trait to mark actions that require the stream to be actively running. */ + trait StreamMustBeRunning + + /** + * Adds the given data to the stream. Subsuquent check answers will block until this data has + * been processed. + */ + object AddData { + def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + AddDataMemory(source, data) + } + + /** A trait that can be extended when testing other sources. */ + trait AddData extends StreamAction { + def source: Source + + /** + * Called to trigger adding the data. Should return the offset that will denote when this + * new data has been processed. + */ + def addData(): Offset + } + + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + override def toString: String = s"AddData to $source: ${data.mkString(",")}" + + override def addData(): Offset = { + source.addData(data) + } + } + + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. + * This operation automatically blocks untill all added data has been processed. + */ + object CheckAnswer { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" + } + + case class DropBatches(num: Int) extends StreamAction + + /** Stops the stream. It must currently be running. */ + case object StopStream extends StreamAction with StreamMustBeRunning + + /** Starts the stream, resuming if data has already been processed. It must not be running. */ + case object StartStream extends StreamAction + + /** Signals that a failure is expected and should not kill the test. */ + case object ExpectFailure extends StreamAction + + /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */ + def testStream(stream: Dataset[_])(actions: StreamAction*): Unit = + testStream(stream.toDF())(actions: _*) + + /** + * Executes the specified actions on the the given streaming DataFrame and provides helpful + * error messages in the case of failures or incorrect answers. + * + * Note that if the stream is not explictly started before an action that requires it to be + * running then it will be automatically started before performing any other actions. + */ + def testStream(stream: DataFrame)(actions: StreamAction*): Unit = { + var pos = 0 + var currentPlan: LogicalPlan = stream.logicalPlan + var currentStream: StreamExecution = null + val awaiting = new mutable.HashMap[Source, Offset]() + val sink = new MemorySink(stream.schema) + + @volatile + var streamDeathCause: Throwable = null + + // If the test doesn't manually start the stream, we do it automatically at the beginning. + val startedManually = + actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).contains(StartStream) + val startedTest = if (startedManually) actions else StartStream +: actions + + def testActions = actions.zipWithIndex.map { + case (a, i) => + if ((pos == i && startedManually) || (pos == (i + 1) && !startedManually)) { + "=> " + a.toString + } else { + " " + a.toString + } + }.mkString("\n") + + def currentOffsets = + if (currentStream != null) currentStream.streamProgress.toString else "not started" + + def threadState = + if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def testState = + s""" + |== Progress == + |$testActions + | + |== Stream == + |Stream state: $currentOffsets + |Thread state: $threadState + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |== Sink == + |$sink + | + |== Plan == + |${if (currentStream != null) currentStream.lastExecution else ""} + """ + + def checkState(check: Boolean, error: String) = if (!check) { + fail( + s""" + |Invalid State: $error + |$testState + """.stripMargin) + } + + val testThread = Thread.currentThread() + + try { + startedTest.foreach { action => + action match { + case StartStream => + checkState(currentStream == null, "stream already running") + + currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink) + currentStream.microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + testThread.interrupt() + } + }) + + case StopStream => + checkState(currentStream != null, "can not stop a stream that is not running") + currentStream.stop() + currentStream = null + + case DropBatches(num) => + checkState(currentStream == null, "dropping batches while running leads to corruption") + sink.dropBatches(num) + + case ExpectFailure => + try failAfter(streamingTimout) { + while (streamDeathCause == null) { + Thread.sleep(100) + } + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + fail( + s""" + |Timed out while waiting for failure. + |$testState + """.stripMargin) + } + + currentStream = null + streamDeathCause = null + + case a: AddData => + awaiting.put(a.source, a.addData()) + + case CheckAnswerRows(expectedAnswer) => + checkState(currentStream != null, "stream not running") + + // Block until all data added has been processed + awaiting.foreach { case (source, offset) => + failAfter(streamingTimout) { + currentStream.awaitOffset(source, offset) + } + } + + val allData = try sink.allData catch { + case e: Exception => + fail( + s""" + |Exception while getting data from sink $e + |$testState + """.stripMargin) + } + + QueryTest.sameRows(expectedAnswer, allData).foreach { + error => fail( + s""" + |$error + |$testState + """.stripMargin) + } + } + pos += 1 + } + } catch { + case _: InterruptedException if streamDeathCause != null => + fail( + s""" + |Stream Thread Died + |$testState + """.stripMargin) + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + fail( + s""" + |Timed out waiting for stream + |$testState + """.stripMargin) + } finally { + if (currentStream != null && currentStream.microBatchThread.isAlive) { + currentStream.stop() + } + } + } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. + * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. + * @param addData and add data action that adds the given numbers to the stream, encoding them + * as needed + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + implicit val intEncoder = ExpressionEncoder[Int]() + var dataPos = 0 + var running = true + val actions = new ArrayBuffer[StreamAction]() + + def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } + + def addRandomData() = { + val numItems = Random.nextInt(10) + val data = dataPos until (dataPos + numItems) + dataPos += numItems + actions += addData(data) + } + + (1 to iterations).foreach { i => + val rand = Random.nextDouble() + if(!running) { + rand match { + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StartStream + running = true + } + } else { + rand match { + case r if r < 0.1 => + addCheck() + + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StopStream + running = false + } + } + } + if(!running) { actions += StartStream } + addCheck() + testStream(ds)(actions: _*) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala new file mode 100644 index 0000000000000..1dab6ebf1bee9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.test + +import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest} +import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source} +import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +object LastOptions { + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var partitionColumns: Seq[String] = Nil +} + +/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */ +class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + override def createSource( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): Source = { + LastOptions.parameters = parameters + LastOptions.schema = schema + new Source { + override def getNextBatch(start: Option[Offset]): Option[Batch] = None + override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + } + } + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink = { + LastOptions.parameters = parameters + LastOptions.partitionColumns = partitionColumns + new Sink { + override def addBatch(batch: Batch): Unit = {} + override def currentOffset: Option[Offset] = None + } + } +} + +class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("resolve default source") { + sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open() + .streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + } + + test("resolve full class") { + sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test.DefaultSource") + .open() + .streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + } + + test("options") { + val map = new java.util.HashMap[String, String] + map.put("opt3", "3") + + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .open() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + LastOptions.parameters = null + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .start() + .stop() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("partitioning") { + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open() + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .start() + .stop() + assert(LastOptions.partitionColumns == Nil) + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("a") + .start() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + + + withSQLConf("spark.sql.caseSensitive" -> "false") { + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("A") + .start() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + } + + intercept[AnalysisException] { + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .partitionBy("b") + .start() + .stop() + } + } + + test("stream paths") { + val df = sqlContext.streamFrom + .format("org.apache.spark.sql.streaming.test") + .open("/test") + + assert(LastOptions.parameters("path") == "/test") + + LastOptions.parameters = null + + df.streamTo + .format("org.apache.spark.sql.streaming.test") + .start("/test") + .stop() + + assert(LastOptions.parameters("path") == "/test") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala new file mode 100644 index 0000000000000..81760d2aa8205 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class MemorySourceStressSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("memory stress test") { + val input = MemoryStream[Int] + val mapped = input.toDS().map(_ + 1) + + runStressTest(mapped, AddData(input, _: _*)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala new file mode 100644 index 0000000000000..989465826d54e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, Offset} + +trait OffsetSuite extends SparkFunSuite { + /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ + def compare(one: Offset, two: Offset): Unit = { + test(s"comparision $one <=> $two") { + assert(one < two) + assert(one <= two) + assert(one <= one) + assert(two > one) + assert(two >= one) + assert(one >= one) + assert(one == one) + assert(two == two) + assert(one != two) + assert(two != one) + } + } + + /** Creates test to check that non-equality comparisons throw exception. */ + def compareInvalid(one: Offset, two: Offset): Unit = { + test(s"invalid comparison $one <=> $two") { + intercept[IllegalArgumentException] { + assert(one < two) + } + + intercept[IllegalArgumentException] { + assert(one <= two) + } + + intercept[IllegalArgumentException] { + assert(one > two) + } + + intercept[IllegalArgumentException] { + assert(one >= two) + } + + assert(!(one == two)) + assert(!(two == one)) + assert(one != two) + assert(two != one) + } + } +} + +class LongOffsetSuite extends OffsetSuite { + val one = LongOffset(1) + val two = LongOffset(2) + compare(one, two) +} + +class CompositeOffsetSuite extends OffsetSuite { + compare( + one = CompositeOffset(Some(LongOffset(1)) :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset(None :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compareInvalid( // sizes must be same + one = CompositeOffset(Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compare( + one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compareInvalid( + one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala new file mode 100644 index 0000000000000..fbb1792596b18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class StreamSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("map with recovery") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + StartStream, + CheckAnswer(2, 3, 4), + StopStream, + AddData(inputData, 4, 5, 6), + StartStream, + CheckAnswer(2, 3, 4, 5, 6, 7)) + } + + test("join") { + // Make a table and ensure it will be broadcast. + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF().join(smallTable, $"value" === $"number") + + testStream(joined)( + AddData(inputData, 1, 2, 3), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two")), + AddData(inputData, 4), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) + } + + test("union two streams") { + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val unioned = inputData1.toDS().union(inputData2.toDS()) + + testStream(unioned)( + AddData(inputData1, 1, 3, 5), + CheckAnswer(1, 3, 5), + AddData(inputData2, 2, 4, 6), + CheckAnswer(1, 2, 3, 4, 5, 6), + StopStream, + AddData(inputData1, 7), + StartStream, + AddData(inputData2, 8), + CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8)) + } + + test("sql queries") { + val inputData = MemoryStream[Int] + inputData.toDF().registerTempTable("stream") + val evens = sql("SELECT * FROM stream WHERE value % 2 = 0") + + testStream(evens)( + AddData(inputData, 1, 2, 3, 4), + CheckAnswer(2, 4)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e7b376548787c..c341191c70bb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -36,7 +36,7 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. From 29d92181d0c49988c387d34e4a71b1afe02c29e2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 Feb 2016 10:15:40 -0800 Subject: [PATCH 1771/2089] [SPARK-13094][SQL] Add encoders for seq/array of primitives Author: Michael Armbrust Closes #11014 from marmbrus/seqEncoders. --- .../org/apache/spark/sql/SQLImplicits.scala | 63 ++++++++++++++++++- .../spark/sql/DatasetPrimitiveSuite.scala | 22 +++++++ .../org/apache/spark/sql/QueryTest.scala | 8 ++- 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index ab414799f1a42..16c4095db722a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -39,6 +39,8 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + // Primitives + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() @@ -56,13 +58,72 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() - /** @since 1.6.0 */ + /** @since 1.6.0 */ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + // Seqs + + /** @since 1.6.1 */ + implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + + // Arrays + + /** @since 1.6.1 */ + implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = + ExpressionEncoder() + /** * Creates a [[Dataset]] from an RDD. * @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index f75d0961823c4..243d13b19d6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { agged, "1", "abc", "3", "xyz", "5", "hello") } + + test("Arrays and Lists") { + checkAnswer(Seq(Seq(1)).toDS(), Seq(1)) + checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) + checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) + checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) + checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) + checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) + checkAnswer(Seq(Seq(true)).toDS(), Seq(true)) + checkAnswer(Seq(Seq("test")).toDS(), Seq("test")) + checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) + + checkAnswer(Seq(Array(1)).toDS(), Array(1)) + checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) + checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) + checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) + checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) + checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) + checkAnswer(Seq(Array(true)).toDS(), Array(true)) + checkAnswer(Seq(Array("test")).toDS(), Array("test")) + checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 405e5891ac976..5401212428d6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - if (decoded != expectedAnswer.toSet) { + // Handle the case where the return type is an array + val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false) + def normalEquality = decoded == expectedAnswer.toSet + def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet + def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq) + + if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) { val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted From b93830126cc59a26e2cfb5d7b3c17f9cfbf85988 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 2 Feb 2016 10:41:06 -0800 Subject: [PATCH 1772/2089] [SPARK-13114][SQL] Add a test for tokens more than the fields in schema https://issues.apache.org/jira/browse/SPARK-13114 This PR adds a test for tokens more than the fields in schema. Author: hyukjinkwon Closes #11020 from HyukjinKwon/SPARK-13114. --- sql/core/src/test/resources/cars-malformed.csv | 6 ++++++ .../sql/execution/datasources/csv/CSVSuite.scala | 12 ++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 sql/core/src/test/resources/cars-malformed.csv diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/cars-malformed.csv new file mode 100644 index 0000000000000..cfa378c01f1d9 --- /dev/null +++ b/sql/core/src/test/resources/cars-malformed.csv @@ -0,0 +1,6 @@ +~ All the rows here are malformed having tokens more than the schema (header). +year,make,model,comment,blank +"2012","Tesla","S","No comment",,null,null + +1997,Ford,E350,"Go get one now they are going fast",,null,null +2015,Chevy,,,, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index a79566b1f3658..fa4f137b703b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.types._ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val carsFile = "cars.csv" + private val carsMalformedFile = "cars-malformed.csv" private val carsFile8859 = "cars_iso-8859-1.csv" private val carsTsvFile = "cars.tsv" private val carsAltFile = "cars-alternative.csv" @@ -191,6 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) } + test("test for tokens more than the fields in the schema") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .option("comment", "~") + .load(testFile(carsMalformedFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + test("test with null quote character") { val cars = sqlContext.read .format("csv") From cba1d6b659288bfcd8db83a6d778155bab2bbecf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 2 Feb 2016 10:50:22 -0800 Subject: [PATCH 1773/2089] [SPARK-12631][PYSPARK][DOC] PySpark clustering parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the clustering module. Author: Bryan Cutler Closes #10610 from BryanCutler/param-desc-consistent-cluster-SPARK-12631. --- .../mllib/clustering/GaussianMixture.scala | 12 +- .../spark/mllib/clustering/KMeans.scala | 31 +- .../apache/spark/mllib/clustering/LDA.scala | 13 +- .../clustering/PowerIterationClustering.scala | 4 +- .../mllib/clustering/StreamingKMeans.scala | 6 +- python/pyspark/mllib/clustering.py | 265 +++++++++++++----- 6 files changed, 228 insertions(+), 103 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 7b203e2f40815..88dbfe3fcc9f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -45,10 +45,10 @@ import org.apache.spark.util.Utils * This is due to high-dimensional data (a) making it difficult to cluster at all (based * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. * - * @param k The number of independent Gaussians in the mixture model - * @param convergenceTol The maximum change in log-likelihood at which convergence - * is considered to have occurred. - * @param maxIterations The maximum number of iterations to perform + * @param k Number of independent Gaussians in the mixture model. + * @param convergenceTol Maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations Maximum number of iterations allowed. */ @Since("1.3.0") class GaussianMixture private ( @@ -108,7 +108,7 @@ class GaussianMixture private ( def getK: Int = k /** - * Set the maximum number of iterations to run. Default: 100 + * Set the maximum number of iterations allowed. Default: 100 */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { @@ -117,7 +117,7 @@ class GaussianMixture private ( } /** - * Return the maximum number of iterations to run + * Return the maximum number of iterations allowed */ @Since("1.3.0") def getMaxIterations: Int = maxIterations diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ca11ede4ccd47..901164a391170 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -70,13 +70,13 @@ class KMeans private ( } /** - * Maximum number of iterations to run. + * Maximum number of iterations allowed. */ @Since("1.4.0") def getMaxIterations: Int = maxIterations /** - * Set maximum number of iterations to run. Default: 20. + * Set maximum number of iterations allowed. Default: 20. */ @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { @@ -482,12 +482,15 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). - * @param seed random seed value for cluster initialization + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + * @param seed Random seed for cluster initialization. Default is to generate seed based + * on system time. */ @Since("1.3.0") def train( @@ -508,11 +511,13 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") */ @Since("0.8.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index eb802a365ed6e..81566b4779d66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -61,14 +61,13 @@ class LDA private ( ldaOptimizer = new EMLDAOptimizer) /** - * Number of topics to infer. I.e., the number of soft cluster centers. - * + * Number of topics to infer, i.e., the number of soft cluster centers. */ @Since("1.3.0") def getK: Int = k /** - * Number of topics to infer. I.e., the number of soft cluster centers. + * Set the number of topics to infer, i.e., the number of soft cluster centers. * (default = 10) */ @Since("1.3.0") @@ -222,13 +221,13 @@ class LDA private ( def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** - * Maximum number of iterations for learning. + * Maximum number of iterations allowed. */ @Since("1.3.0") def getMaxIterations: Int = maxIterations /** - * Maximum number of iterations for learning. + * Set the maximum number of iterations allowed. * (default = 20) */ @Since("1.3.0") @@ -238,13 +237,13 @@ class LDA private ( } /** - * Random seed + * Random seed for cluster initialization. */ @Since("1.3.0") def getSeed: Long = seed /** - * Random seed + * Set the random seed for cluster initialization. */ @Since("1.3.0") def setSeed(seed: Long): this.type = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 2ab0920b06363..1ab7cb393b081 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -111,7 +111,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. - * @param initMode Initialization mode. + * @param initMode Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. + * Default: random. * * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 79d217e183c62..d99b89dc49ebf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -183,7 +183,7 @@ class StreamingKMeans @Since("1.2.0") ( } /** - * Set the decay factor directly (for forgetful algorithms). + * Set the forgetfulness of the previous centroids. */ @Since("1.2.0") def setDecayFactor(a: Double): this.type = { @@ -192,7 +192,9 @@ class StreamingKMeans @Since("1.2.0") ( } /** - * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + * Set the half life and time unit ("batches" or "points"). If points, then the decay factor + * is raised to the power of number of new points and if batches, then decay factor will be + * used as is. */ @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 4e9eb96fd9da1..ad04e46e8870b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -88,8 +88,11 @@ def predict(self, x): Find the cluster that each of the points belongs to in this model. - :param x: the point (or RDD of points) to determine - compute the clusters for. + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -105,7 +108,8 @@ def computeCost(self, x): points to their nearest center) for this model on the given data. If provided with an RDD of points returns the sum. - :param point: the point or RDD of points to compute the cost(s). + :param point: + A data point (or RDD of points) to compute the cost(s). """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -143,17 +147,23 @@ def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1 """ Runs the bisecting k-means algorithm return the model. - :param rdd: input RDD to be trained on - :param k: The desired number of leaf clusters (default: 4). - The actual number could be smaller if there are no divisible - leaf clusters. - :param maxIterations: the max number of k-means iterations to - split clusters (default: 20) - :param minDivisibleClusterSize: the minimum number of points - (if >= 1.0) or the minimum proportion of points (if < 1.0) - of a divisible cluster (default: 1) - :param seed: a random seed (default: -1888008604 from - classOf[BisectingKMeans].getName.##) + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + The desired number of leaf clusters. The actual number could + be smaller if there are no divisible leaf clusters. + (default: 4) + :param maxIterations: + Maximum number of iterations allowed to split clusters. + (default: 20) + :param minDivisibleClusterSize: + Minimum number of points (if >= 1.0) or the minimum proportion + of points (if < 1.0) of a divisible cluster. + (default: 1) + :param seed: + Random seed value for cluster initialization. + (default: -1888008604 from classOf[BisectingKMeans].getName.##) """ java_model = callMLlibFunc( "trainBisectingKMeans", rdd.map(_convert_to_vector), @@ -239,8 +249,11 @@ def predict(self, x): Find the cluster that each of the points belongs to in this model. - :param x: the point (or RDD of points) to determine - compute the clusters for. + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ best = 0 best_distance = float("inf") @@ -262,7 +275,8 @@ def computeCost(self, rdd): their nearest center) for this model on the given data. - :param point: the RDD of points to compute the cost on. + :param rdd: + The RDD of points to compute the cost on. """ cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), [_convert_to_vector(c) for c in self.centers]) @@ -296,7 +310,44 @@ class KMeans(object): @since('0.9.0') def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): - """Train a k-means clustering model.""" + """ + Train a k-means clustering model. + + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of clusters to create. + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param runs: + Number of runs to execute in parallel. The best model according + to the cost function will be returned (deprecated in 1.6.0). + (default: 1) + :param initializationMode: + The initialization algorithm. This can be either "random" or + "k-means||". + (default: "k-means||") + :param seed: + Random seed value for cluster initialization. Set as None to + generate seed based on system time. + (default: None) + :param initializationSteps: + Number of steps for the k-means|| initialization mode. + This is an advanced setting -- the default of 5 is almost + always enough. + (default: 5) + :param epsilon: + Distance threshold within which a center will be considered to + have converged. If all centers move less than this Euclidean + distance, iterations are stopped. + (default: 1e-4) + :param initialModel: + Initial cluster centers can be provided as a KMeansModel object + rather than using the random or k-means|| initializationModel. + (default: None) + """ if runs != 1: warnings.warn( "Support for runs is deprecated in 1.6.0. This param will have no effect in 2.0.0.") @@ -415,8 +466,11 @@ def predict(self, x): Find the cluster to which the point 'x' or each point in RDD 'x' has maximum membership in this model. - :param x: vector or RDD of vector represents data points. - :return: cluster label or RDD of cluster labels. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + Predicted cluster label or an RDD of predicted cluster labels + if the input is an RDD. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) @@ -430,9 +484,11 @@ def predictSoft(self, x): """ Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: vector or RDD of vector represents data points. - :return: the membership value to all mixture components for vector 'x' - or each vector in RDD 'x'. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + The membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -447,8 +503,10 @@ def predictSoft(self, x): def load(cls, sc, path): """Load the GaussianMixtureModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ model = cls._load_java(sc, path) wrapper = sc._jvm.GaussianMixtureModelWrapper(model) @@ -461,19 +519,35 @@ class GaussianMixture(object): Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. - :param data: RDD of data points - :param k: Number of components - :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 - :param maxIterations: Number of iterations. Default to 100 - :param seed: Random Seed - :param initialModel: GaussianMixtureModel for initializing learning - .. versionadded:: 1.3.0 """ @classmethod @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): - """Train a Gaussian Mixture clustering model.""" + """ + Train a Gaussian Mixture clustering model. + + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of independent Gaussians in the mixture model. + :param convergenceTol: + Maximum change in log-likelihood at which convergence is + considered to have occurred. + (default: 1e-3) + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param seed: + Random seed for initial Gaussian distribution. Set as None to + generate seed based on system time. + (default: None) + :param initialModel: + Initial GMM starting point, bypassing the random + initialization. + (default: None) + """ initialModelWeights = None initialModelMu = None initialModelSigma = None @@ -574,18 +648,24 @@ class PowerIterationClustering(object): @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): """ - :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the - affinity matrix, which is the matrix A in the PIC paper. - The similarity s,,ij,, must be nonnegative. - This is a symmetric matrix and hence s,,ij,, = s,,ji,,. - For any (i, j) with nonzero similarity, there should be - either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. - Tuples with i = j are ignored, because we assume - s,,ij,, = 0.0. - :param k: Number of clusters. - :param maxIterations: Maximum number of iterations of the - PIC algorithm. - :param initMode: Initialization mode. + :param rdd: + An RDD of (i, j, s\ :sub:`ij`\) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. The + similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric + matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\ For any (i, j) with + nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or + (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored, + because it is assumed s\ :sub:`ij`\ = 0.0. + :param k: + Number of clusters. + :param maxIterations: + Maximum number of iterations of the PIC algorithm. + (default: 100) + :param initMode: + Initialization mode. This can be either "random" to use + a random vector as vertex properties, or "degree" to use + normalized sum similarities. + (default: "random") """ model = callMLlibFunc("trainPowerIterationClusteringModel", rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) @@ -625,8 +705,10 @@ class StreamingKMeansModel(KMeansModel): and new data. If it set to zero, the old centroids are completely forgotten. - :param clusterCenters: Initial cluster centers. - :param clusterWeights: List of weights assigned to each cluster. + :param clusterCenters: + Initial cluster centers. + :param clusterWeights: + List of weights assigned to each cluster. >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] >>> initWeights = [1.0, 1.0] @@ -673,11 +755,14 @@ def clusterWeights(self): def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data - :param data: Should be a RDD that represents the new data. - :param decayFactor: forgetfulness of the previous centroids. - :param timeUnit: Can be "batches" or "points". If points, then the - decay factor is raised to the power of number of new - points and if batches, it is used as it is. + :param data: + RDD with new data for the model update. + :param decayFactor: + Forgetfulness of the previous centroids. + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor + is raised to the power of number of new points and if batches, + then decay factor will be used as is. """ if not isinstance(data, RDD): raise TypeError("Data should be of an RDD, got %s." % type(data)) @@ -704,10 +789,17 @@ class StreamingKMeans(object): More details on how the centroids are updated are provided under the docs of StreamingKMeansModel. - :param k: int, number of clusters - :param decayFactor: float, forgetfulness of the previous centroids. - :param timeUnit: can be "batches" or "points". If points, then the - decayfactor is raised to the power of no. of new points. + :param k: + Number of clusters. + (default: 2) + :param decayFactor: + Forgetfulness of the previous centroids. + (default: 1.0) + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor is + raised to the power of number of new points and if batches, then + decay factor will be used as is. + (default: "batches") .. versionadded:: 1.5.0 """ @@ -870,11 +962,13 @@ def describeTopics(self, maxTermsPerTopic=None): WARNING: If vocabSize and k are large, this can return a large object! - :param maxTermsPerTopic: Maximum number of terms to collect for each topic. - (default: vocabulary size) - :return: Array over topics. Each topic is represented as a pair of matching arrays: - (term indices, term weights in topic). - Each topic's terms are sorted in order of decreasing weight. + :param maxTermsPerTopic: + Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: + Array over topics. Each topic is represented as a pair of + matching arrays: (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ if maxTermsPerTopic is None: topics = self.call("describeTopics") @@ -887,8 +981,10 @@ def describeTopics(self, maxTermsPerTopic=None): def load(cls, sc, path): """Load the LDAModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) @@ -909,17 +1005,38 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. - :param rdd: RDD of data points - :param k: Number of clusters you want - :param maxIterations: Number of iterations. Default to 20 - :param docConcentration: Concentration parameter (commonly named "alpha") - for the prior placed on documents' distributions over topics ("theta"). - :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") - for the prior placed on topics' distributions over terms. - :param seed: Random Seed - :param checkpointInterval: Period (in iterations) between checkpoints. - :param optimizer: LDAOptimizer used to perform the actual calculation. - Currently "em", "online" are supported. Default to "em". + :param rdd: + RDD of documents, which are tuples of document IDs and term + (word) count vectors. The term count vectors are "bags of + words" with a fixed-size vocabulary (where the vocabulary size + is the length of the vector). Document IDs must be unique + and >= 0. + :param k: + Number of topics to infer, i.e., the number of soft cluster + centers. + (default: 10) + :param maxIterations: + Maximum number of iterations allowed. + (default: 20) + :param docConcentration: + Concentration parameter (commonly named "alpha") for the prior + placed on documents' distributions over topics ("theta"). + (default: -1.0) + :param topicConcentration: + Concentration parameter (commonly named "beta" or "eta") for + the prior placed on topics' distributions over terms. + (default: -1.0) + :param seed: + Random seed for cluster initialization. Set as None to generate + seed based on system time. + (default: None) + :param checkpointInterval: + Period (in iterations) between checkpoints. + (default: 10) + :param optimizer: + LDAOptimizer used to perform the actual calculation. Currently + "em", "online" are supported. + (default: "em") """ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, docConcentration, topicConcentration, seed, From 358300c795025735c3b2f96c5447b1b227d4abc1 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 2 Feb 2016 11:09:40 -0800 Subject: [PATCH 1774/2089] [SPARK-13056][SQL] map column would throw NPE if value is null Jira: https://issues.apache.org/jira/browse/SPARK-13056 Create a map like { "a": "somestring", "b": null} Query like SELECT col["b"] FROM t1; NPE would be thrown. Author: Daoyuan Wang Closes #10964 from adrian-wang/npewriter. --- .../expressions/complexTypeExtractors.scala | 15 +++++++++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5256baaf432a2..9f2f82d68cca0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -218,7 +218,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { + if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { null } else { baseValue.get(index, dataType) @@ -267,6 +267,7 @@ case class GetMapValue(child: Expression, key: Expression) val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() + val values = map.valueArray() var i = 0 var found = false @@ -278,10 +279,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!found) { + if (!found || values.isNullAt(i)) { null } else { - map.valueArray().get(i, dataType) + values.get(i, dataType) } } @@ -291,10 +292,12 @@ case class GetMapValue(child: Expression, key: Expression) val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") + val values = ctx.freshName("values") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); final ArrayData $keys = $eval1.keyArray(); + final ArrayData $values = $eval1.valueArray(); int $index = 0; boolean $found = false; @@ -307,10 +310,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if ($found) { - ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; - } else { + if (!$found || $values.isNullAt($index)) { ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(values, dataType, index)}; } """ }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2b821c1056f56..79bfd4b44b70a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2055,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-13056: Null in map value causes NPE") { + val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") + withTempTable("maptest") { + df.registerTempTable("maptest") + // local optimization will by pass codegen code, so we should keep the filter `key=1` + checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring")) + checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) + } + } + test("hash function") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") withTempTable("tbl") { From b1835d727234fdff42aa8cadd17ddcf43b0bed15 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Tue, 2 Feb 2016 11:16:24 -0800 Subject: [PATCH 1775/2089] [SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication Fixes problem and verifies fix by test suite. Also - adds optional parameter: nullable (Boolean) to: SchemaUtils.appendColumn and deduplicates SchemaUtils.appendColumn functions. Author: Grzegorz Chilkiewicz Closes #10741 from grzegorz-chilkiewicz/master. --- .../spark/ml/feature/StopWordsRemover.scala | 4 +--- .../org/apache/spark/ml/util/SchemaUtils.scala | 8 +++----- .../spark/ml/feature/StopWordsRemoverSuite.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index b93c9ed382bdf..e53ef300f644b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index e71dd9eee03e3..76021ad8f4e65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -71,12 +71,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index fb217e0c1de93..a5b24c18565b9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -89,4 +89,19 @@ class StopWordsRemoverSuite .setCaseSensitive(true) testDefaultReadWrite(t) } + + test("StopWordsRemover output column already exists") { + val outputCol = "expected" + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol(outputCol) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("The", "the", "swift"), Seq("swift")) + )).toDF("raw", outputCol) + + val thrown = intercept[IllegalArgumentException] { + testStopWordsRemover(remover, dataSet) + } + assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + } } From 7f6e3ec79b77400f558ceffa10b2af011962115f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Feb 2016 11:29:20 -0800 Subject: [PATCH 1776/2089] [SPARK-13138][SQL] Add "logical" package prefix for ddl.scala ddl.scala is defined in the execution package, and yet its reference of "UnaryNode" and "Command" are logical. This was fairly confusing when I was trying to understand the ddl code. Author: Reynold Xin Closes #11021 from rxin/SPARK-13138. --- .../spark/sql/execution/datasources/ddl.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1554209be9891..a141b58d3d72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ @@ -32,7 +33,7 @@ import org.apache.spark.sql.types._ */ case class DescribeCommand( table: LogicalPlan, - isExtended: Boolean) extends LogicalPlan with Command { + isExtended: Boolean) extends LogicalPlan with logical.Command { override def children: Seq[LogicalPlan] = Seq.empty @@ -59,7 +60,7 @@ case class CreateTableUsing( temporary: Boolean, options: Map[String, String], allowExisting: Boolean, - managedIfNoPath: Boolean) extends LogicalPlan with Command { + managedIfNoPath: Boolean) extends LogicalPlan with logical.Command { override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty @@ -67,8 +68,8 @@ case class CreateTableUsing( /** * A node used to support CTAS statements and saveAsTable for the data source API. - * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer - * can analyze the logical plan that will be used to populate the table. + * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the + * analyzer can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ case class CreateTableUsingAsSelect( @@ -79,7 +80,7 @@ case class CreateTableUsingAsSelect( bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan) extends logical.UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } From be5dd881f1eff248224a92d57cfd1309cb3acf38 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 11:50:14 -0800 Subject: [PATCH 1777/2089] [SPARK-12913] [SQL] Improve performance of stat functions As benchmarked and discussed here: https://github.com/apache/spark/pull/10786/files#r50038294, benefits from codegen, the declarative aggregate function could be much faster than imperative one. Author: Davies Liu Closes #10960 from davies/stddev. --- .../catalyst/analysis/HiveTypeCoercion.scala | 18 +- .../aggregate/CentralMomentAgg.scala | 285 ++++++++---------- .../catalyst/expressions/aggregate/Corr.scala | 208 ++++--------- .../expressions/aggregate/Covariance.scala | 205 ++++--------- .../expressions/aggregate/Kurtosis.scala | 54 ---- .../expressions/aggregate/Skewness.scala | 53 ---- .../expressions/aggregate/Stddev.scala | 81 ----- .../expressions/aggregate/Variance.scala | 81 ----- .../spark/sql/catalyst/expressions/misc.scala | 18 ++ .../apache/spark/sql/execution/Window.scala | 6 +- .../aggregate/TungstenAggregate.scala | 1 - .../BenchmarkWholeStageCodegen.scala | 55 +++- .../execution/HiveCompatibilitySuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 17 +- 14 files changed, 331 insertions(+), 755 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 957ac89fa530d..57bdb164e1a0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -347,18 +347,12 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) - case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => - Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) + case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) + case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) + case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 30f602227b17d..9d2db45144817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -44,7 +42,7 @@ import org.apache.spark.sql.types._ * * @param child to compute central moments of. */ -abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { +abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate { /** * The central moment order to be computed. @@ -52,178 +50,161 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w protected def momentOrder: Int override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = true - override def dataType: DataType = DoubleType + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val avg = AttributeReference("avg", DoubleType, nullable = false)() + protected val m2 = AttributeReference("m2", DoubleType, nullable = false)() + protected val m3 = AttributeReference("m3", DoubleType, nullable = false)() + protected val m4 = AttributeReference("m4", DoubleType, nullable = false)() - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") + private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1) - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) - /** - * Size of aggregation buffer. - */ - private[this] val bufferSize = 5 + override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => - AttributeReference(s"M$i", DoubleType)() + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) } - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - // buffer offsets - private[this] val nOffset = mutableAggBufferOffset - private[this] val meanOffset = mutableAggBufferOffset + 1 - private[this] val secondMomentOffset = mutableAggBufferOffset + 2 - private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 - private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 - - // frequently used values for online updates - private[this] var delta = 0.0 - private[this] var deltaN = 0.0 - private[this] var delta2 = 0.0 - private[this] var deltaN2 = 0.0 - private[this] var n = 0.0 - private[this] var mean = 0.0 - private[this] var m2 = 0.0 - private[this] var m3 = 0.0 - private[this] var m4 = 0.0 + override val mergeExpressions: Seq[Expression] = { - /** - * Initialize all moments to zero. - */ - override def initialize(buffer: MutableRow): Unit = { - for (aggIndex <- 0 until bufferSize) { - buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val delta = avg.right - avg.left + val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val newAvg = avg.left + deltaN * n2 + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2 + // `m3.right` is not available if momentOrder < 3 + val newM3 = if (momentOrder >= 3) { + m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) + + Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left) + } else { + Literal(0.0) } + // `m4.right` is not available if momentOrder < 4 + val newM4 = if (momentOrder >= 4) { + m4.left + m4.right + + deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) + + Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) + + Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } +} - /** - * Update the central moments buffer. - */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = Cast(child, DoubleType).eval(input) - if (v != null) { - val updateValue = v match { - case d: Double => d - } - - n = buffer.getDouble(nOffset) - mean = buffer.getDouble(meanOffset) - - n += 1.0 - buffer.setDouble(nOffset, n) - delta = updateValue - mean - deltaN = delta / n - mean += deltaN - buffer.setDouble(meanOffset, mean) - - if (momentOrder >= 2) { - m2 = buffer.getDouble(secondMomentOffset) - m2 += delta * (delta - deltaN) - buffer.setDouble(secondMomentOffset, m2) - } - - if (momentOrder >= 3) { - delta2 = delta * delta - deltaN2 = deltaN * deltaN - m3 = buffer.getDouble(thirdMomentOffset) - m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) - buffer.setDouble(thirdMomentOffset, m3) - } - - if (momentOrder >= 4) { - m4 = buffer.getDouble(fourthMomentOffset) - m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + - delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(fourthMomentOffset, m4) - } - } +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + Sqrt(m2 / n)) } - /** - * Merge two central moment buffers. - */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val n1 = buffer1.getDouble(nOffset) - val n2 = buffer2.getDouble(inputAggBufferOffset) - val mean1 = buffer1.getDouble(meanOffset) - val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 - var secondMoment1 = 0.0 - var secondMoment2 = 0.0 + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + Sqrt(m2 / (n - Literal(1.0))))) + } - var thirdMoment1 = 0.0 - var thirdMoment2 = 0.0 + override def prettyName: String = "stddev_samp" +} - var fourthMoment1 = 0.0 - var fourthMoment2 = 0.0 +// Compute the population variance of a column +case class VariancePop(child: Expression) extends CentralMomentAgg(child) { - n = n1 + n2 - buffer1.setDouble(nOffset, n) - delta = mean2 - mean1 - deltaN = if (n == 0.0) 0.0 else delta / n - mean = mean1 + deltaN * n2 - buffer1.setDouble(mutableAggBufferOffset + 1, mean) + override protected def momentOrder = 2 - // higher order moments computed according to: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics - if (momentOrder >= 2) { - secondMoment1 = buffer1.getDouble(secondMomentOffset) - secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 - buffer1.setDouble(secondMomentOffset, m2) - } + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + m2 / n) + } - if (momentOrder >= 3) { - thirdMoment1 = buffer1.getDouble(thirdMomentOffset) - thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * - (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) - buffer1.setDouble(thirdMomentOffset, m3) - } + override def prettyName: String = "var_pop" +} - if (momentOrder >= 4) { - fourthMoment1 = buffer1.getDouble(fourthMomentOffset) - fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * - n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * - (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + - 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) - buffer1.setDouble(fourthMomentOffset, m4) - } +// Compute the sample variance of a column +case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + m2 / (n - Literal(1.0)))) } - /** - * Compute aggregate statistic from sufficient moments. - * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) - * needed to compute the aggregate stat. - */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any - - override final def eval(buffer: InternalRow): Any = { - val n = buffer.getDouble(nOffset) - val mean = buffer.getDouble(meanOffset) - val moments = Array.ofDim[Double](momentOrder + 1) - moments(0) = 1.0 - moments(1) = 0.0 - if (momentOrder >= 2) { - moments(2) = buffer.getDouble(secondMomentOffset) - } - if (momentOrder >= 3) { - moments(3) = buffer.getDouble(thirdMomentOffset) - } - if (momentOrder >= 4) { - moments(4) = buffer.getDouble(fourthMomentOffset) - } + override def prettyName: String = "var_samp" +} + +case class Skewness(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "skewness" + + override protected def momentOrder = 3 - getStatistic(n, mean, moments) + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) } } + +case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 4 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + n * m4 / (m2 * m2) - Literal(3.0))) + } + + override def prettyName: String = "kurtosis" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index d25f3335ffd93..e6b8214ef25e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -29,165 +28,70 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -case class Corr( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { - - def this(left: Expression, right: Expression) = - this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def children: Seq[Expression] = Seq(left, right) +case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true - override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"corr requires that both arguments are double type, " + - s"not (${left.dataType}, ${right.dataType}).") - } + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)() + protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk) + + override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) + + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dxN = dx / newN + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dxN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + val newXMk = xMk + dx * (x - newXAvg) + val newYMk = yMk + dy * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk), + If(isNull, xMk, newXMk), + If(isNull, yMk, newYMk) + ) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) + override val mergeExpressions: Seq[Expression] = { + + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) } - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("MkX", DoubleType)(), - AttributeReference("MkY", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 - private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 - private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 - private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 - private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 - private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 - private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 - private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 - private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(mutableAggBufferOffset, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) - buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / Sqrt(xMk * yMk))) } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer.getLong(mutableAggBufferOffsetPlus5) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - MkX += deltaX * (x - xAvg) - MkY += deltaY * (y - yAvg) - - buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. - if (count2 > 0) { - var xAvg = buffer1.getDouble(mutableAggBufferOffset) - var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer1.getLong(mutableAggBufferOffsetPlus5) - - val xAvg2 = buffer2.getDouble(inputAggBufferOffset) - val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) - val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) - val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) - val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) - - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 - MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 - count = totalCount - - buffer1.setDouble(mutableAggBufferOffset, xAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer1.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(mutableAggBufferOffsetPlus5) - if (count > 0) { - val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - val corr = Ck / math.sqrt(MkX * MkY) - if (corr.isNaN) { - null - } else { - corr - } - } else { - null - } - } + override def prettyName: String = "corr" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index f53b01be2a0d5..c175a8c4c77b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -17,182 +17,79 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Compute the covariance between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. - * */ -abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate - with Serializable { - override def children: Seq[Expression] = Seq(left, right) +abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true - override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"covariance requires that both arguments are double type, " + - s"not (${left.dataType}, ${right.dataType}).") - } - } - - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) - } - - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - val xAvgOffset = mutableAggBufferOffset - val yAvgOffset = mutableAggBufferOffset + 1 - val CkOffset = mutableAggBufferOffset + 2 - val countOffset = mutableAggBufferOffset + 3 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - val inputXAvgOffset = inputAggBufferOffset - val inputYAvgOffset = inputAggBufferOffset + 1 - val inputCkOffset = inputAggBufferOffset + 2 - val inputCountOffset = inputAggBufferOffset + 3 - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(xAvgOffset, 0.0) - buffer.setDouble(yAvgOffset, 0.0) - buffer.setDouble(CkOffset, 0.0) - buffer.setLong(countOffset, 0L) - } - - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(xAvgOffset) - var yAvg = buffer.getDouble(yAvgOffset) - var Ck = buffer.getDouble(CkOffset) - var count = buffer.getLong(countOffset) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - - buffer.setDouble(xAvgOffset, xAvg) - buffer.setDouble(yAvgOffset, yAvg) - buffer.setDouble(CkOffset, Ck) - buffer.setLong(countOffset, count) - } + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck) + + override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) + + override lazy val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) } - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputCountOffset) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. - if (count2 > 0) { - var xAvg = buffer1.getDouble(xAvgOffset) - var yAvg = buffer1.getDouble(yAvgOffset) - var Ck = buffer1.getDouble(CkOffset) - var count = buffer1.getLong(countOffset) + override val mergeExpressions: Seq[Expression] = { - val xAvg2 = buffer2.getDouble(inputXAvgOffset) - val yAvg2 = buffer2.getDouble(inputYAvgOffset) - val Ck2 = buffer2.getDouble(inputCkOffset) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - count = totalCount - - buffer1.setDouble(xAvgOffset, xAvg) - buffer1.setDouble(yAvgOffset, yAvg) - buffer1.setDouble(CkOffset, Ck) - buffer1.setLong(countOffset, count) - } + Seq(newN, newXAvg, newYAvg, newCk) } } -case class CovSample( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends Covariance(left, right) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(countOffset) - if (count > 1) { - val Ck = buffer.getDouble(CkOffset) - val cov = Ck / (count - 1) - if (cov.isNaN) { - null - } else { - cov - } - } else { - null - } +case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + ck / n) } + override def prettyName: String = "covar_pop" } -case class CovPopulation( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends Covariance(left, right) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(countOffset) - if (count > 0) { - val Ck = buffer.getDouble(CkOffset) - if (Ck.isNaN) { - null - } else { - Ck / count - } - } else { - null - } +case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / (n - Literal(1.0)))) } + override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala deleted file mode 100644 index c2bf2cb94116c..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Kurtosis(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "kurtosis" - - override protected val momentOrder = 4 - - // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m4 = moments(4) - - if (n == 0.0) { - null - } else if (m2 == 0.0) { - Double.NaN - } else { - n * m4 / (m2 * m2) - 3.0 - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala deleted file mode 100644 index 9411bcea2539a..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Skewness(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "skewness" - - override protected val momentOrder = 3 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m3 = moments(3) - - if (n == 0.0) { - null - } else if (m2 == 0.0) { - Double.NaN - } else { - math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala deleted file mode 100644 index eec79a9033e36..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class StddevSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "stddev_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else if (n == 1.0) { - Double.NaN - } else { - math.sqrt(moments(2) / (n - 1.0)) - } - } -} - -case class StddevPop( - child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "stddev_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else { - math.sqrt(moments(2) / n) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala deleted file mode 100644 index cf3a740305391..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class VarianceSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else if (n == 1.0) { - Double.NaN - } else { - moments(2) / (n - 1.0) - } - } -} - -case class VariancePop( - child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) { - null - } else { - moments(2) / n - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 36e1fa1176d22..f4ccadd9c563e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -424,3 +424,21 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } } + +/** + * Print the result of an expression to stderr (used for debugging codegen). + */ +case class PrintToStderr(child: Expression) extends UnaryExpression { + + override def dataType: DataType = child.dataType + + protected override def nullSafeEval(input: Any): Any = input + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, c => + s""" + | System.err.println("Result of ${child.simpleString} is " + $c); + | ${ev.value} = $c; + """.stripMargin) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 26a7340f1ae10..84154a47de393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -198,7 +198,8 @@ case class Window( functions, ordinal, child.output, - (expressions, schema) => newMutableProjection(expressions, schema)) + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) // Create the factory val factory = key match { @@ -210,7 +211,8 @@ case class Window( ordinal, functions, child.output, - (expressions, schema) => newMutableProjection(expressions, schema), + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) // Growing Frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 57db7262fdaf3..a8a81d6d6574e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -240,7 +240,6 @@ case class TungstenAggregate( | ${bufVars(i).value} = ${ev.value}; """.stripMargin } - s""" | // do aggregate | ${aggVals.map(_.code).mkString("\n")} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 2f09c8a114bc5..1ccf0e3d0656c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -59,6 +59,55 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testStatFunctions(values: Int): Unit = { + + val benchmark = new Benchmark("stat functions", values) + + benchmark.addCase("stddev w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + } + + benchmark.addCase("stddev w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + } + + benchmark.addCase("kurtosis w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + } + + benchmark.addCase("kurtosis w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + } + + + /** + Using ImperativeAggregate (as implemented in Spark 1.6): + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + stddev w/o codegen 2019.04 10.39 1.00 X + stddev w codegen 2097.29 10.00 0.96 X + kurtosis w/o codegen 2108.99 9.94 0.96 X + kurtosis w codegen 2090.69 10.03 0.97 X + + Using DeclarativeAggregate: + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + stddev w/o codegen 989.22 21.20 1.00 X + stddev w codegen 352.35 59.52 2.81 X + kurtosis w/o codegen 3636.91 5.77 0.27 X + kurtosis w codegen 369.25 56.79 2.68 X + */ + benchmark.run() + } + def testAggregateWithKey(values: Int): Unit = { val benchmark = new Benchmark("Aggregate with keys", values) @@ -147,8 +196,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } - test("benchmark") { - // testWholeStage(1024 * 1024 * 200) + // These benchmark are skipped in normal build + ignore("benchmark") { + // testWholeStage(200 << 20) + // testStddev(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 554d47d651aef..61b73fa557144 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -325,6 +325,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_ignore_protection", "protectmode", + // Hive returns null rather than NaN when n = 1 + "udaf_covar_samp", + // Spark parser treats numerical literals differently: it creates decimals instead of doubles. "udf_abs", "udf_format_number", @@ -881,7 +884,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "type_widening", "udaf_collect_set", "udaf_covar_pop", - "udaf_covar_samp", "udaf_histogram_numeric", "udf2", "udf5", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7a9ed1eaf3dbc..caf1db9ad0855 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -798,7 +798,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), - Row(null) :: Nil) + Row(Double.NaN) :: Nil) checkAnswer( sqlContext.sql( @@ -807,10 +807,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """.stripMargin), Row(1, null) :: Row(2, null) :: - Row(3, null) :: - Row(4, null) :: - Row(5, null) :: - Row(6, null) :: Nil) + Row(3, Double.NaN) :: + Row(4, Double.NaN) :: + Row(5, Double.NaN) :: + Row(6, Double.NaN) :: Nil) val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) @@ -841,11 +841,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // one row test val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") - val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0) - assert(cov_samp3 == null) - - val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) - assert(cov_pop3 == 0.0) + checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN)) + checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } test("no aggregation function (SPARK-11486)") { From d0df2ca40953ba581dce199798a168af01283cdc Mon Sep 17 00:00:00 2001 From: Gabriele Nizzoli Date: Tue, 2 Feb 2016 13:20:01 -0800 Subject: [PATCH 1778/2089] [SPARK-13121][STREAMING] java mapWithState mishandles scala Option Already merged into 1.6 branch, this PR is to commit to master the same change Author: Gabriele Nizzoli Closes #11028 from gabrielenizzoli/patch-1. --- .../src/main/scala/org/apache/spark/streaming/StateSpec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 66f646d7dc136..e6724feaee105 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -221,7 +221,7 @@ object StateSpec { mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { - mappingFunction.call(k, Optional.ofNullable(v.get), s) + mappingFunction.call(k, JavaUtils.optionToOptional(v), s) } StateSpec.function(wrappedFunc) } From b377b03531d21b1d02a8f58b3791348962e1f31b Mon Sep 17 00:00:00 2001 From: "Kevin (Sangwoo) Kim" Date: Tue, 2 Feb 2016 13:24:09 -0800 Subject: [PATCH 1779/2089] [DOCS] Update StructType.scala The example will throw error like :20: error: not found: value StructType Need to add this line: import org.apache.spark.sql.types._ Author: Kevin (Sangwoo) Kim Closes #10141 from swkimme/patch-1. --- .../src/main/scala/org/apache/spark/sql/types/StructType.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c9e7e7fe633b8..e797d83cb05be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParse * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * * val struct = * StructType( From 6de6a97728408ee2619006decf2267cc43eeea0d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 16:24:31 -0800 Subject: [PATCH 1780/2089] [SPARK-13150] [SQL] disable two flaky tests Author: Davies Liu Closes #11037 from davies/disable_flaky. --- .../sql/hive/thriftserver/HiveThriftServer2Suites.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ba3b26e1b7d49..9860e40fe8546 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -488,7 +488,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } - test("SPARK-11595 ADD JAR with input path having URL scheme") { + // TODO: enable this + ignore("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" @@ -546,7 +547,8 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override protected def extraConf: Seq[String] = "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil - test("test single session") { + // TODO: enable this + ignore("test single session") { withMultipleConnectionJdbcStatement( { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" From 672032d0ab1e43bc5a25cecdb1b96dfd35c39778 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Feb 2016 08:26:35 +0800 Subject: [PATCH 1781/2089] [SPARK-13020][SQL][TEST] fix random generator for map type when we generate map, we first randomly pick a length, then create a seq of key value pair with the expected length, and finally call `toMap`. However, `toMap` will remove all duplicated keys, which makes the actual map size much less than we expected. This PR fixes this problem by put keys in a set first, to guarantee we have enough keys to build a map with expected length. Author: Wenchen Fan Closes #10930 from cloud-fan/random-generator. --- .../apache/spark/sql/RandomDataGenerator.scala | 18 ++++++++++++++---- .../spark/sql/RandomDataGeneratorSuite.scala | 11 +++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 55efea80d1a4d..7c173cbceefed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -47,9 +47,9 @@ object RandomDataGenerator { */ private val PROBABILITY_OF_NULL: Float = 0.1f - private val MAX_STR_LEN: Int = 1024 - private val MAX_ARR_SIZE: Int = 128 - private val MAX_MAP_SIZE: Int = 128 + final val MAX_STR_LEN: Int = 1024 + final val MAX_ARR_SIZE: Int = 128 + final val MAX_MAP_SIZE: Int = 128 /** * Helper function for constructing a biased random number generator which returns "interesting" @@ -208,7 +208,17 @@ object RandomDataGenerator { forType(valueType, nullable = valueContainsNull, rand) ) yield { () => { - Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + val length = rand.nextInt(MAX_MAP_SIZE) + val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*) + // In case the number of different keys is not enough, set a max iteration to avoid + // infinite loop. + var count = 0 + while (keys.size < length && count < MAX_MAP_SIZE) { + keys += keyGenerator() + count += 1 + } + val values = Seq.fill(keys.size)(valueGenerator()) + keys.zip(values).toMap } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index b8ccdf7516d82..9fba7924e9542 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -95,4 +95,15 @@ class RandomDataGeneratorSuite extends SparkFunSuite { } } + test("check size of generated map") { + val mapType = MapType(IntegerType, IntegerType) + for (seed <- 1 to 1000) { + val generator = RandomDataGenerator.forType( + mapType, nullable = false, rand = new Random(seed)).get + val maps = Seq.fill(100)(generator().asInstanceOf[Map[Int, Int]]) + val expectedTotalElements = 100 / 2 * RandomDataGenerator.MAX_MAP_SIZE + val deviation = math.abs(maps.map(_.size).sum - expectedTotalElements) + assert(deviation.toDouble / expectedTotalElements < 2e-1) + } + } } From 21112e8a14c042ccef4312079672108a1082a95e Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 2 Feb 2016 16:33:21 -0800 Subject: [PATCH 1782/2089] [SPARK-12992] [SQL] Update parquet reader to support more types when decoding to ColumnarBatch. This patch implements support for more types when doing the vectorized decode. There are a few more types remaining but they should be very straightforward after this. This code has a few copy and paste pieces but they are difficult to eliminate due to performance considerations. Specifically, this patch adds support for: - String, Long, Byte types - Dictionary encoding for those types. Author: Nong Li Closes #10908 from nongli/spark-12992. --- .../parquet/UnsafeRowParquetRecordReader.java | 146 ++++++++++++++-- .../parquet/VectorizedPlainValuesReader.java | 45 ++++- .../parquet/VectorizedRleValuesReader.java | 160 +++++++++++++++++- .../parquet/VectorizedValuesReader.java | 5 + .../execution/vectorized/ColumnVector.java | 7 +- .../parquet/ParquetEncodingSuite.scala | 82 +++++++++ 6 files changed, 424 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 17adfec32192f..b5dddb9f11b22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.List; +import org.apache.commons.lang.NotImplementedException; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; @@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException { int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { - switch (columnReaders[i].descriptor.getType()) { - case INT32: - columnReaders[i].readIntBatch(num, columnarBatch.column(i)); - break; - default: - throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType()); - } + columnReaders[i].readBatch(num, columnarBatch.column(i)); } rowsReturned += num; columnarBatch.setNumRows(num); @@ -237,7 +233,8 @@ private void initializeInternal() throws IOException { // TODO: Be extremely cautious in what is supported. Expand this. if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && - originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE && + originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && @@ -464,6 +461,11 @@ private final class ColumnReader { */ private boolean useDictionary; + /** + * If useDictionary is true, the staging vector used to decode the ids. + */ + private ColumnVector dictionaryIds; + /** * Maximum definition level for this column. */ @@ -587,9 +589,8 @@ private boolean next() throws IOException { /** * Reads `total` values from this columnReader into column. - * TODO: implement the other encodings. */ - private void readIntBatch(int total, ColumnVector column) throws IOException { + private void readBatch(int total, ColumnVector column) throws IOException { int rowId = 0; while (total > 0) { // Compute the number of values we want to read in this page. @@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException { leftInPage = (int)(endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); - defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0); - - // Remap the values if it is dictionary encoded. if (useDictionary) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(column.getInt(i))); + // Data is dictionary encoded. We will vector decode the ids and then resolve the values. + if (dictionaryIds == null) { + dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(total); + } + // Read and decode dictionary ids. + readIntBatch(rowId, num, dictionaryIds); + decodeDictionaryIds(rowId, num, column); + } else { + switch (descriptor.getType()) { + case INT32: + readIntBatch(rowId, num, column); + break; + case INT64: + readLongBatch(rowId, num, column); + break; + case BINARY: + readBinaryBatch(rowId, num, column); + break; + default: + throw new IOException("Unsupported type: " + descriptor.getType()); } } + valuesRead += num; rowId += num; total -= num; } } + /** + * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. + */ + private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + switch (descriptor.getType()) { + case INT32: + if (column.dataType() == DataTypes.IntegerType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case INT64: + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + break; + + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; + + default: + throw new NotImplementedException("Unsupported type: " + descriptor.getType()); + } + + if (dictionaryIds.numNulls() > 0) { + // Copy the NULLs over. + // TODO: we can improve this by decoding the NULLs directly into column. This would + // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then + // just do the ID remapping as above. + for (int i = 0; i < num; ++i) { + if (dictionaryIds.getIsNull(rowId + i)) { + column.putNull(rowId + i); + } + } + } + } + + /** + * For all the read*Batch functions, reads `num` values from this columnReader into column. It + * is guaranteed that num is smaller than the number of values left in the current page. + */ + + private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.IntegerType) { + defColumn.readIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); + } else if (column.dataType() == DataTypes.ByteType) { + defColumn.readBytes( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.LongType) { + defColumn.readLongs( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.isArray()) { + defColumn.readBinarys( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readPage() throws IOException { DataPage page = pageReader.readPage(); // TODO: Why is this a visitor? diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index dac0c52ebd2cf..cec2418e46030 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -18,10 +18,13 @@ import java.io.IOException; +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.unsafe.Platform; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. @@ -52,15 +55,53 @@ public void skip(int n) { } @Override - public void readIntegers(int total, ColumnVector c, int rowId) { + public final void readIntegers(int total, ColumnVector c, int rowId) { c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public int readInteger() { + public final void readLongs(int total, ColumnVector c, int rowId) { + c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + + @Override + public final void readBytes(int total, ColumnVector c, int rowId) { + for (int i = 0; i < total; i++) { + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + c.putInt(rowId + i, buffer[offset]); + offset += 4; + } + } + + @Override + public final int readInteger() { int v = Platform.getInt(buffer, offset); offset += 4; return v; } + + @Override + public final long readLong() { + long v = Platform.getLong(buffer, offset); + offset += 8; + return v; + } + + @Override + public final byte readByte() { + return (byte)readInteger(); + } + + @Override + public final void readBinary(int total, ColumnVector v, int rowId) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + int start = offset; + offset += len; + v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 493ec9deed499..9bfd74db38766 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,12 +17,16 @@ package org.apache.spark.sql.execution.datasources.parquet; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; import org.apache.parquet.column.values.bitpacking.Packer; import org.apache.parquet.io.ParquetDecodingException; +import org.apache.parquet.io.api.Binary; + +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -35,7 +39,8 @@ * - Definition/Repetition levels * - Dictionary ids. */ -public final class VectorizedRleValuesReader extends ValuesReader { +public final class VectorizedRleValuesReader extends ValuesReader + implements VectorizedValuesReader { // Current decoding mode. The encoded data contains groups of either run length encoded data // (RLE) or bit packed data. Each group contains a header that indicates which group it is and // the number of values in the group. @@ -121,6 +126,7 @@ public int readValueDictionaryId() { return readInteger(); } + @Override public int readInteger() { if (this.currentCount == 0) { this.readNextGroup(); } @@ -138,7 +144,9 @@ public int readInteger() { /** * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader * reads the definition levels and then will read from `data` for the non-null values. - * If the value is null, c will be populated with `nullValue`. + * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only + * necessary for readIntegers because we also use it to decode dictionaryIds and want to make + * sure it always has a value in range. * * This is a batched version of this logic: * if (this.readInt() == level) { @@ -180,6 +188,154 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } } + // TODO: can this code duplication be removed without a perf penalty? + public void readBytes(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBytes(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putByte(rowId + i, data.readByte()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readLongs(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readLongs(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, data.readLong()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readBinarys(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + c.putNotNulls(rowId, n); + data.readBinary(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putNotNull(rowId + i); + data.readBinary(1, c, rowId); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + + // The RLE reader implements the vectorized decoding interface when used to decode dictionary + // IDs. This is different than the above APIs that decodes definitions levels along with values. + // Since this is only used to decode dictionary IDs, only decoding integers is supported. + @Override + public void readIntegers(int total, ColumnVector c, int rowId) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + c.putInts(rowId, n, currentValue); + break; + case PACKED: + c.putInts(rowId, n, currentBuffer, currentBufferIdx); + currentBufferIdx += n; + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + @Override + public byte readByte() { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBytes(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readLongs(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBinary(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void skip(int n) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + /** * Reads the next varint encoded int. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 49a9ed83d590a..b6ec7311c564a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -24,12 +24,17 @@ * TODO: merge this into parquet-mr. */ public interface VectorizedValuesReader { + byte readByte(); int readInteger(); + long readLong(); /* * Reads `total` values into `c` start at `c[rowId]` */ + void readBytes(int total, ColumnVector c, int rowId); void readIntegers(int total, ColumnVector c, int rowId); + void readLongs(int total, ColumnVector c, int rowId); + void readBinary(int total, ColumnVector c, int rowId); // TODO: add all the other parquet types. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a5bc506a65ac2..0514252a8e53d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -763,7 +763,12 @@ public final int appendStruct(boolean isNull) { /** * Returns the elements appended. */ - public int getElementsAppended() { return elementsAppended; } + public final int getElementsAppended() { return elementsAppended; } + + /** + * Returns true if this column is an array. + */ + public final boolean isArray() { return resultArray != null; } /** * Maximum number of rows that can be stored in this column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala new file mode 100644 index 0000000000000..cef6b79a094d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils +import org.apache.spark.sql.test.SharedSQLContext + +// TODO: this needs a lot more testing but it's currently not easy to test with the parquet +// writer abstractions. Revisit. +class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext { + import testImplicits._ + + val ROW = ((1).toByte, 2, 3L, "abc") + val NULL_ROW = ( + null.asInstanceOf[java.lang.Byte], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[String]) + + test("All Types Dictionary") { + (1 :: 1000 :: Nil).foreach { n => { + withTempPath { dir => + List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getByte(i) == 1) + assert(batch.column(1).getInt(i) == 2) + assert(batch.column(2).getLong(i) == 3) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + i += 1 + } + reader.close() + } + }} + } + + test("All Types Null") { + (1 :: 100 :: Nil).foreach { n => { + withTempPath { dir => + val data = List.fill(n)(NULL_ROW).toDF + data.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getIsNull(i)) + assert(batch.column(1).getIsNull(i)) + assert(batch.column(2).getIsNull(i)) + assert(batch.column(3).getIsNull(i)) + i += 1 + } + reader.close() + }} + } + } +} From ff71261b651a7b289ea2312abd6075da8b838ed9 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Tue, 2 Feb 2016 19:35:33 -0800 Subject: [PATCH 1783/2089] [SPARK-13122] Fix race condition in MemoryStore.unrollSafely() https://issues.apache.org/jira/browse/SPARK-13122 A race condition can occur in MemoryStore's unrollSafely() method if two threads that return the same value for currentTaskAttemptId() execute this method concurrently. This change makes the operation of reading the initial amount of unroll memory used, performing the unroll, and updating the associated memory maps atomic in order to avoid this race condition. Initial proposed fix wraps all of unrollSafely() in a memoryManager.synchronized { } block. A cleaner approach might be introduce a mechanism that synchronizes based on task attempt ID. An alternative option might be to track unroll/pending unroll memory based on block ID rather than task attempt ID. Author: Adam Budde Closes #11012 from budde/master. --- .../org/apache/spark/storage/MemoryStore.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 76aaa782b9524..024b660ce6a7b 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -255,8 +255,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this task, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisTask + // Keep track of pending unroll memory reserved by this method. + var pendingMemoryReserved = 0L // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] @@ -266,6 +266,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } else { + pendingMemoryReserved += initialMemoryThreshold } // Unroll this block safely, checking whether we have exceeded our threshold periodically @@ -278,6 +280,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) + if (keepUnrolling) { + pendingMemoryReserved += amountToRequest + } // New threshold is currentSize * memoryGrowthFactor memoryThreshold += amountToRequest } @@ -304,10 +309,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo // release the unroll memory yet. Instead, we transfer it to pending unroll memory // so `tryToPut` can further transfer it to normal storage memory later. // TODO: we can probably express this without pending unroll memory (SPARK-10907) - val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved - unrollMemoryMap(taskAttemptId) -= amountToTransferToPending + unrollMemoryMap(taskAttemptId) -= pendingMemoryReserved pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + pendingMemoryReserved } } else { // Otherwise, if we return an iterator, we can only release the unroll memory when From 99a6e3c1e8d580ce1cc497bd9362eaf16c597f77 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 19:47:44 -0800 Subject: [PATCH 1784/2089] [SPARK-12951] [SQL] support spilling in generated aggregate This PR add spilling support for generated TungstenAggregate. If spilling happened, it's not that bad to do the iterator based sort-merge-aggregate (not generated). The changes will be covered by TungstenAggregationQueryWithControlledFallbackSuite Author: Davies Liu Closes #10998 from davies/gen_spilling. --- .../aggregate/TungstenAggregate.scala | 172 +++++++++++++++--- 1 file changed, 142 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a8a81d6d6574e..f61db8594dab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -258,6 +258,7 @@ case class TungstenAggregate( // The name for HashMap private var hashMapTerm: String = _ + private var sorterTerm: String = _ /** * This is called by generated Java class, should be public. @@ -286,39 +287,98 @@ case class TungstenAggregate( GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } - /** - * Update peak execution memory, called in generated Java class. + * Called by generated Java class to finish the aggregate and return a KVIterator. */ - def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) val metrics = TaskContext.get().taskMetrics() - metrics.incPeakExecutionMemory(mapMemory) - } + metrics.incPeakExecutionMemory(peakMemory) - private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + if (sorter == null) { + // not spilled + return hashMap.iterator() + } - // create hashMap - val thisPlan = ctx.addReferenceObj("plan", this) - hashMapTerm = ctx.freshName("hashMap") - val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. + new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = newMutableProjection( + mergeExpr, + bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + subexpressionEliminationEnabled)() + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } - // Create a name for iterator from HashMap - val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + false + } + } - // generate code for output - val keyTerm = ctx.freshName("aggKey") - val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + /** + * Generate the code for output. + */ + private def generateResultCode( + ctx: CodegenContext, + keyTerm: String, + bufferTerm: String, + plan: String): String = { + if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).gen(ctx) + BoundReference(i, e.dataType, e.nullable).gen(ctx) } ctx.INPUT_ROW = bufferTerm val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => @@ -348,7 +408,7 @@ case class TungstenAggregate( // This should be the last operator in a stage, we should output UnsafeRow directly val joinerTerm = ctx.freshName("unsafeRowJoiner") ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, - s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + s"$joinerTerm = $plan.createUnsafeJoiner();") val resultRow = ctx.freshName("resultRow") s""" UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); @@ -367,6 +427,23 @@ case class TungstenAggregate( ${consume(ctx, eval)} """ } + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + sorterTerm = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, @@ -374,10 +451,15 @@ case class TungstenAggregate( private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - $iterTerm = $hashMapTerm.iterator(); + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); } """) + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + s""" if (!$initAgg) { $initAgg = true; @@ -391,8 +473,10 @@ case class TungstenAggregate( $outputCode } - $thisPlan.updatePeakMemory($hashMapTerm); - $hashMapTerm.free(); + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } """ } @@ -425,14 +509,42 @@ case class TungstenAggregate( ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) } + val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) { + val countTerm = ctx.freshName("fallbackCounter") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + ("true", "", "") + } + + // We try to do hash map based in-memory aggregation first. If there is not enough memory (the + // hash map will return null for new key), we spill the hash map to disk to free memory, then + // continue to do in-memory aggregation and spilling until all the rows had been processed. + // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key ${keyCode.code} - UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + UnsafeRow $buffer = null; + if ($checkFallback) { + // try to get the buffer from hash map + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + } if ($buffer == null) { - // failed to allocate the first page - throw new OutOfMemoryError("No enough memory for aggregation"); + if ($sorterTerm == null) { + $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + } else { + $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + } + $resetCoulter + // the hash map had be spilled, it should have enough memory now, + // try to allocate buffer again. + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } } + $incCounter // evaluate aggregate function ${evals.map(_.code).mkString("\n")} From 0557146619868002e2f7ec3c121c30bbecc918fc Mon Sep 17 00:00:00 2001 From: Imran Younus Date: Tue, 2 Feb 2016 20:38:53 -0800 Subject: [PATCH 1785/2089] [SPARK-12732][ML] bug fix in linear regression train Fixed the bug in linear regression train for the case when the target variable is constant. The two cases for `fitIntercept=true` or `fitIntercept=false` should be treated differently. Author: Imran Younus Closes #10702 from iyounus/SPARK-12732_bug_fix_in_linear_regression_train. --- .../ml/regression/LinearRegression.scala | 66 ++++++----- .../ml/regression/LinearRegressionSuite.scala | 105 ++++++++++++++++++ 2 files changed, 146 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c54e08b2ad9a5..e253f25c0ea65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -219,33 +219,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val yMean = ySummarizer.mean(0) - val yStd = math.sqrt(ySummarizer.variance(0)) - - // If the yStd is zero, then the intercept is yMean with zero coefficient; - // as a result, training is not needed. - if (yStd == 0.0) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - if (handlePersistence) instances.unpersist() - val coefficients = Vectors.sparse(numFeatures, Seq()) - val intercept = yMean - - val model = new LinearRegressionModel(uid, coefficients, intercept) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - model, - Array(0D), - $(featuresCol), - Array(0D)) - return copyValues(model.setSummary(trainingSummary)) + val rawYStd = math.sqrt(ySummarizer.variance(0)) + if (rawYStd == 0.0) { + if ($(fitIntercept) || yMean==0.0) { + // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with + // zero coefficient; as a result, training is not needed. + // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of + // the fitIntercept + if (yMean == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + if (handlePersistence) instances.unpersist() + val coefficients = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, coefficients, intercept) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + model, + Array(0D), + $(featuresCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) + } else { + require($(regParam) == 0.0, "The standard deviation of the label is zero. " + + "Model cannot be regularized.") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } } + // if y is constant (rawYStd is zero), then y cannot be scaled. In this case + // setting yStd=1.0 ensures that y is not scaled anymore in l-bfgs algorithm. + val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) val featuresMean = featuresSummarizer.mean.toArray val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 273c882c2a47f..81fc6603ccfe6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -37,6 +37,8 @@ class LinearRegressionSuite @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @transient var datasetWithSparseFeature: DataFrame = _ @transient var datasetWithWeight: DataFrame = _ + @transient var datasetWithWeightConstantLabel: DataFrame = _ + @transient var datasetWithWeightZeroLabel: DataFrame = _ /* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML @@ -92,6 +94,29 @@ class LinearRegressionSuite Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + df.const.label <- as.data.frame(cbind(A, b.const)) + */ + datasetWithWeightConstantLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + datasetWithWeightZeroLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) } test("params") { @@ -558,6 +583,86 @@ class LinearRegressionSuite } } + test("linear regression model with constant label") { + /* + R code: + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + Seq("auto", "l-bfgs", "normal").foreach { solver => + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightConstantLabel) + val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), + model1.coefficients(1)) + assert(actual1 ~== expected(idx) absTol 1e-4) + + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightZeroLabel) + val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), + model2.coefficients(1)) + assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) + idx += 1 + } + } + } + + test("regularized linear regression through origin with constant label") { + // The problem is ill-defined if fitIntercept=false, regParam is non-zero. + // An exception is thrown in this case. + Seq("auto", "l-bfgs", "normal").foreach { solver => + for (standardization <- Seq(false, true)) { + val model = new LinearRegression().setFitIntercept(false) + .setRegParam(0.1).setStandardization(standardization).setSolver(solver) + intercept[IllegalArgumentException] { + model.fit(datasetWithWeightConstantLabel) + } + } + } + } + + test("linear regression with l-bfgs when training is not needed") { + // When label is constant, l-bfgs solver returns results without training. + // There are two possibilities: If the label is non-zero but constant, + // and fitIntercept is true, then the model return yMean as intercept without training. + // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so + // no training is needed. + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightConstantLabel) + if (fitIntercept) { + assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightZeroLabel) + assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + } + } + test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer = new LinearRegression().setSolver(solver) From 335f10edad8c759bad3dbd0660ed4dd5d70ddd8b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 2 Feb 2016 21:13:54 -0800 Subject: [PATCH 1786/2089] [SPARK-7997][CORE] Add rpcEnv.awaitTermination() back to SparkEnv `rpcEnv.awaitTermination()` was not added in #10854 because some Streaming Python tests hung forever. This patch fixed the hung issue and added rpcEnv.awaitTermination() back to SparkEnv. Previously, Streaming Kafka Python tests shutdowns the zookeeper server before stopping StreamingContext. Then when stopping StreamingContext, KafkaReceiver may be hung due to https://issues.apache.org/jira/browse/KAFKA-601, hence, some thread of RpcEnv's Dispatcher cannot exit and rpcEnv.awaitTermination is hung.The patch just changed the shutdown order to fix it. Author: Shixiong Zhu Closes #11031 from zsxwing/awaitTermination. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 1 + python/pyspark/streaming/tests.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 12c7b2048a8c8..9461afdc54124 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -91,6 +91,7 @@ class SparkEnv ( metricsSystem.stop() outputCommitCoordinator.stop() rpcEnv.shutdown() + rpcEnv.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 24b812615cbb4..b33e8252a7d32 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1013,12 +1013,12 @@ def setUp(self): self._kafkaTestUtils.setup() def tearDown(self): + super(KafkaStreamTests, self).tearDown() + if self._kafkaTestUtils is not None: self._kafkaTestUtils.teardown() self._kafkaTestUtils = None - super(KafkaStreamTests, self).tearDown() - def _randomTopic(self): return "topic-%d" % random.randint(0, 10000) From e86f8f63bfa3c15659b94e831b853b1bc9ddae32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 22:13:10 -0800 Subject: [PATCH 1787/2089] [SPARK-13147] [SQL] improve readability of generated code 1. try to avoid the suffix (unique id) 2. remove the comment if there is no code generated. 3. re-arrange the order of functions 4. trop the new line for inlined blocks. Author: Davies Liu Closes #11032 from davies/better_suffix. --- .../sql/catalyst/expressions/Expression.scala | 8 +++-- .../expressions/codegen/CodeGenerator.scala | 27 ++++++++++------ .../expressions/complexTypeExtractors.scala | 31 +++++++++++-------- .../sql/execution/WholeStageCodegen.scala | 13 ++++---- .../aggregate/TungstenAggregate.scala | 14 ++++----- .../spark/sql/execution/basicOperators.scala | 7 ++++- .../BenchmarkWholeStageCodegen.scala | 2 +- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 353fb92581d3b..c73b2f8f2a316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] { val value = ctx.freshName("value") val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + if (ve.code != "") { + // Add `this` in the comment. + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + } else { + ve + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a30aba16170a9..63e19564dd861 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -156,7 +156,11 @@ class CodegenContext { /** The variable name of the input row in generated code. */ final var INPUT_ROW = "i" - private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * The map from a variable name to it's next ID. + */ + private val freshNameIds = new mutable.HashMap[String, Int] + freshNameIds += INPUT_ROW -> 1 /** * A prefix used to generate fresh name. @@ -164,16 +168,21 @@ class CodegenContext { var freshNamePrefix = "" /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Returns a term name that is unique within this instance of a `CodegenContext`. */ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" + def freshName(name: String): String = synchronized { + val fullName = if (freshNamePrefix == "") { + name + } else { + s"${freshNamePrefix}_$name" + } + if (freshNameIds.contains(fullName)) { + val id = freshNameIds(fullName) + freshNameIds(fullName) = id + 1 + s"$fullName$id" } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" + freshNameIds += fullName -> 1 + fullName } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9f2f82d68cca0..6b24fae9f3f1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -173,22 +173,26 @@ case class GetArrayStructFields( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { + val n = ctx.freshName("n") + val values = ctx.freshName("values") + val j = ctx.freshName("j") + val row = ctx.freshName("row") s""" - final int n = $eval.numElements(); - final Object[] values = new Object[n]; - for (int j = 0; j < n; j++) { - if ($eval.isNullAt(j)) { - values[j] = null; + final int $n = $eval.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($eval.isNullAt($j)) { + $values[$j] = null; } else { - final InternalRow row = $eval.getStruct(j, $numFields); - if (row.isNullAt($ordinal)) { - values[j] = null; + final InternalRow $row = $eval.getStruct($j, $numFields); + if ($row.isNullAt($ordinal)) { + $values[$j] = null; } else { - values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; } } } - ${ev.value} = new $arrayClass(values); + ${ev.value} = new $arrayClass($values); """ }) } @@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") s""" - final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) { + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; + ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; } """ }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 02b0f423ed438..14754969072f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { s""" | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); - | ${columns.map(_.code).mkString("\n")} - | ${consume(ctx, columns)} + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} | } """.stripMargin } @@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public GeneratedIterator(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} + this.references = references; + ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + protected void processNext() throws java.io.IOException { - $code + ${code.trim} } } """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index f61db8594dab2..d0244770613d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -211,9 +211,9 @@ case class TungstenAggregate( | $doAgg(); | | // output the result - | $genResult + | ${genResult.trim} | - | ${consume(ctx, resultVars)} + | ${consume(ctx, resultVars).trim} | } """.stripMargin } @@ -242,9 +242,9 @@ case class TungstenAggregate( } s""" | // do aggregate - | ${aggVals.map(_.code).mkString("\n")} + | ${aggVals.map(_.code).mkString("\n").trim} | // update aggregation buffer - | ${updates.mkString("")} + | ${updates.mkString("\n").trim} """.stripMargin } @@ -523,7 +523,7 @@ case class TungstenAggregate( // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key - ${keyCode.code} + ${keyCode.code.trim} UnsafeRow $buffer = null; if ($checkFallback) { // try to get the buffer from hash map @@ -547,9 +547,9 @@ case class TungstenAggregate( $incCounter // evaluate aggregate function - ${evals.map(_.code).mkString("\n")} + ${evals.map(_.code).mkString("\n").trim} // update aggregate buffer - ${updates.mkString("\n")} + ${updates.mkString("\n").trim} """ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fd81531c9316a..ae4422195cc4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit BindReferences.bindReference(condition, child.output)) ctx.currentVars = input val eval = expr.gen(ctx) + val nullCheck = if (expr.nullable) { + s"!${eval.isNull} &&" + } else { + s"" + } s""" | ${eval.code} - | if (!${eval.isNull} && ${eval.value}) { + | if ($nullCheck ${eval.value}) { | ${consume(ctx, ctx.currentVars)} | } """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 1ccf0e3d0656c..ec2b9ab2cbad5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // These benchmark are skipped in normal build ignore("benchmark") { // testWholeStage(200 << 20) - // testStddev(20 << 20) + // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } From 138c300f97d29cb0d04a70bea98a8a0c0548318a Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 2 Feb 2016 22:22:50 -0800 Subject: [PATCH 1788/2089] [SPARK-12957][SQL] Initial support for constraint propagation in SparkSQL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on the semantics of a query, we can derive a number of data constraints on output of each (logical or physical) operator. For instance, if a filter defines `‘a > 10`, we know that the output data of this filter satisfies 2 constraints: 1. `‘a > 10` 2. `isNotNull(‘a)` This PR proposes a possible way of keeping track of these constraints and propagating them in the logical plan, which can then help us build more advanced optimizations (such as pruning redundant filters, optimizing joins, among others). We define constraints as a set of (implicitly conjunctive) expressions. For e.g., if a filter operator has constraints = `Set(‘a > 10, ‘b < 100)`, it’s implied that the outputs satisfy both individual constraints (i.e., `‘a > 10` AND `‘b < 100`). Design Document: https://docs.google.com/a/databricks.com/document/d/1WQRgDurUBV9Y6CWOBS75PQIqJwT-6WftVa18xzm7nCo/edit?usp=sharing Author: Sameer Agarwal Closes #10844 from sameeragarwal/constraints. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 55 +++++- .../catalyst/plans/logical/LogicalPlan.scala | 2 + .../plans/logical/basicOperators.scala | 79 +++++++- .../plans/ConstraintPropagationSuite.scala | 173 ++++++++++++++++++ 4 files changed, 302 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b43b7ee71e7aa..05f5bdbfc0769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -26,6 +26,56 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def output: Seq[Attribute] + /** + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. + */ + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + } + + /** + * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions. + * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form + * `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // Currently we only propagate constraints if the condition consists of equality + // and ranges. For all other cases, we return an empty set of constraints + constraints.map { + case EqualTo(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) + } + + /** + * A sequence of expressions that describes the data property of the output rows of this + * operator. For example, if the output of this operator is column `a`, an example `constraints` + * can be `Set(a > 10, a < 20)`. + */ + lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]] + */ + protected def validConstraints: Set[Expression] = Set.empty + /** * Returns the set of attributes that are output by this node. */ @@ -59,6 +109,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { @@ -67,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { @@ -99,6 +151,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. * @return */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d859551f8c52..d8944a424156e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -305,6 +305,8 @@ abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil + + override protected def validConstraints: Set[Expression] = child.constraints } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 16f4b355b1b6c..8150ff8434762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -87,11 +87,27 @@ case class Generate( } } -case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { +case class Filter(condition: Expression, child: LogicalPlan) + extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + + override protected def validConstraints: Set[Expression] = { + child.constraints.union(splitConjunctivePredicates(condition).toSet) + } } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + + protected def leftConstraints: Set[Expression] = left.constraints + + protected def rightConstraints: Set[Expression] = { + require(left.output.size == right.output.size) + val attributeRewrites = AttributeMap(right.output.zip(left.output)) + right.constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } +} private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -106,6 +122,10 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + override protected def validConstraints: Set[Expression] = { + leftConstraints.union(rightConstraints) + } + // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. override lazy val resolved: Boolean = @@ -119,6 +139,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + override protected def validConstraints: Set[Expression] = leftConstraints + override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && @@ -157,13 +179,36 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } + + /** + * Maps the constraints containing a given (original) sequence of attributes to those with a + * given (reference) sequence of attributes. Given the nature of union, we expect that the + * mapping between the original and reference sequences are symmetric. + */ + private def rewriteConstraints( + reference: Seq[Attribute], + original: Seq[Attribute], + constraints: Set[Expression]): Set[Expression] = { + require(reference.size == original.size) + val attributeRewrites = AttributeMap(original.zip(reference)) + constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + + override protected def validConstraints: Set[Expression] = { + children + .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) + .reduce(_ intersect _) + } } case class Join( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { joinType match { @@ -180,6 +225,28 @@ case class Join( } } + override protected def validConstraints: Set[Expression] = { + joinType match { + case Inner if condition.isDefined => + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + case LeftSemi if condition.isDefined => + left.constraints + .union(splitConjunctivePredicates(condition.get).toSet) + case Inner => + left.constraints.union(right.constraints) + case LeftSemi => + left.constraints + case LeftOuter => + left.constraints + case RightOuter => + right.constraints + case FullOuter => + Set.empty[Expression] + } + } + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala new file mode 100644 index 0000000000000..b5cf91394d910 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class ConstraintPropagationSuite extends SparkFunSuite { + + private def resolveColumn(tr: LocalRelation, columnName: String): Expression = + tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + + private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { + val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) + val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _)) + if (missing.nonEmpty || extra.nonEmpty) { + fail( + s""" + |== FAIL: Constraints do not match === + |Found: ${found.mkString(",")} + |Expected: ${expected.mkString(",")} + |== Result == + |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")} + |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")} + """.stripMargin) + } + } + + test("propagating constraints in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + + verifyConstraints(tr + .where('a.attr > 10) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a")))) + + verifyConstraints(tr + .where('a.attr > 10) + .select('c.attr, 'a.attr) + .where('c.attr < 100) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "c") < 100, + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c")))) + } + + test("propagating constraints in union") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + + assert(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('e.attr > 10) + .unionAll(tr3.where('i.attr > 10))) + .analyze.constraints.isEmpty) + + verifyConstraints(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 10))) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in intersect") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr1 + .where('a.attr > 10) + .intersect(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + resolveColumn(tr1, "b") < 100, + IsNotNull(resolveColumn(tr1, "a")), + IsNotNull(resolveColumn(tr1, "b")))) + } + + test("propagating constraints in except") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + verifyConstraints(tr1 + .where('a.attr > 10) + .except(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) + } + + test("propagating constraints in inner join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + tr1.resolveQuoted("a", caseInsensitiveResolution).get === + tr2.resolveQuoted("a", caseInsensitiveResolution).get, + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-semi join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in right-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + } + + test("propagating constraints in full-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + assert(tr1.where('a.attr > 10) + .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints.isEmpty) + } +} From e9eb248edfa81d75f99c9afc2063e6b3d9ee7392 Mon Sep 17 00:00:00 2001 From: Mario Briggs Date: Wed, 3 Feb 2016 09:50:28 -0800 Subject: [PATCH 1789/2089] [SPARK-12739][STREAMING] Details of batch in Streaming tab uses two Duration columns I have clearly prefix the two 'Duration' columns in 'Details of Batch' Streaming tab as 'Output Op Duration' and 'Job Duration' Author: Mario Briggs Author: mariobriggs Closes #11022 from mariobriggs/spark-12739. --- .../main/scala/org/apache/spark/streaming/ui/BatchPage.scala | 4 ++-- .../scala/org/apache/spark/streaming/UISeleniumSuite.scala | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 7635f79a3d2d1..81de07f933f8a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -37,10 +37,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private def columns: Seq[Node] = { Output Op Id Description - Duration + Output Op Duration Status Job Id - Duration + Job Duration Stages: Succeeded/Total Tasks (for all stages): Succeeded/Total Error diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index c4ecebcacf3c8..96dd4757be855 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -143,8 +143,9 @@ class UISeleniumSuite summaryText should contain ("Total delay:") findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { - List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration", - "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") + List("Output Op Id", "Description", "Output Op Duration", "Status", "Job Id", + "Job Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", + "Error") } // Check we have 2 output op ids From c4feec26eb677bfd3bfac38e5e28eae05279956e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Feb 2016 10:38:53 -0800 Subject: [PATCH 1790/2089] [SPARK-12798] [SQL] generated BroadcastHashJoin A row from stream side could match multiple rows on build side, the loop for these matched rows should not be interrupted when emitting a row, so we buffer the output rows in a linked list, check the termination condition on producer loop (for example, Range or Aggregate). Author: Davies Liu Closes #10989 from davies/gen_join. --- .../sql/execution/BufferedRowIterator.java | 30 ++++-- .../sql/execution/WholeStageCodegen.scala | 18 ++-- .../aggregate/TungstenAggregate.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 2 + .../execution/joins/BroadcastHashJoin.scala | 92 ++++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 28 +++++- .../execution/WholeStageCodegenSuite.scala | 15 ++- 7 files changed, 169 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index 6acf70dbbad0c..ea20115770f79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -18,9 +18,11 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.LinkedList; import scala.collection.Iterator; +import org.apache.spark.TaskContext; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -31,28 +33,42 @@ * TODO: replaced it by batched columnar format. */ public class BufferedRowIterator { - protected InternalRow currentRow; + protected LinkedList currentRows = new LinkedList<>(); protected Iterator input; // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); public boolean hasNext() throws IOException { - if (currentRow == null) { + if (currentRows.isEmpty()) { processNext(); } - return currentRow != null; + return !currentRows.isEmpty(); } public InternalRow next() { - InternalRow r = currentRow; - currentRow = null; - return r; + return currentRows.remove(); } public void setInput(Iterator iter) { input = iter; } + /** + * Returns whether `processNext()` should stop processing next row from `input` or not. + * + * If it returns true, the caller should exit the loop (return from processNext()). + */ + protected boolean shouldStop() { + return !currentRows.isEmpty(); + } + + /** + * Increase the peak execution memory for current task. + */ + protected void incPeakExecutionMemory(long size) { + TaskContext.get().taskMetrics().incPeakExecutionMemory(size); + } + /** * Processes the input until have a row as output (currentRow). * @@ -60,7 +76,7 @@ public void setInput(Iterator iter) { */ protected void processNext() throws IOException { if (input.hasNext()) { - currentRow = input.next(); + currentRows.add(input.next()); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 14754969072f9..131efea20f31e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} import org.apache.spark.util.Utils /** @@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} + | if (shouldStop()) { + | return; + | } | } """.stripMargin } @@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) if (row != null) { // There is an UnsafeRow already s""" - | currentRow = $row; - | return; + | currentRows.add($row.copy()); """.stripMargin } else { assert(input != null) @@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" | ${code.code.trim} - | currentRow = ${code.value}; - | return; + | currentRows.add(${code.value}.copy()); """.stripMargin } else { // There is no columns s""" - | currentRow = unsafeRow; - | return; + | currentRows.add(unsafeRow); """.stripMargin } } @@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { + // The build side can't be compiled together + case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + b.copy(left = apply(left)) + case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + b.copy(right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively inputs += input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d0244770613d3..9d9f14f2dd014 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -471,6 +471,8 @@ case class TungstenAggregate( UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); $outputCode + + if (shouldStop()) return; } $iterTerm.close(); @@ -480,7 +482,7 @@ case class TungstenAggregate( """ } - private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key ctx.currentVars = input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ae4422195cc4b..6e51c4d84824a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -237,6 +237,8 @@ case class Range( | $overflow = true; | } | ${consume(ctx, Seq(ev))} + | + | if (shouldStop()) return; | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 04640711d99d0..8b275e886c46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -42,7 +45,7 @@ case class BroadcastHashJoin( condition: Option[Expression], left: SparkPlan, right: SparkPlan) - extends BinaryNode with HashJoin { + extends BinaryNode with HashJoin with CodegenSupport { override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), @@ -117,6 +120,87 @@ case class BroadcastHashJoin( hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } + + // the term for hash relation + private var relationTerm: String = _ + + override def upstream(): RDD[InternalRow] = { + streamedPlan.asInstanceOf[CodegenSupport].upstream() + } + + override def doProduce(ctx: CodegenContext): String = { + // create a name for HashRelation + val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + relationTerm = ctx.freshName("relation") + // TODO: create specialized HashRelation for single join key + val clsName = classOf[UnsafeHashedRelation].getName + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = ($clsName) $broadcast.value(); + | incPeakExecutionMemory($relationTerm.getUnsafeSize()); + """.stripMargin) + + s""" + | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} + """.stripMargin + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // generate the key as UnsafeRow + ctx.currentVars = input + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val keyTerm = keyVal.value + val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + + // find the matches from HashedRelation + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val row = ctx.freshName("row") + + // create variables for output + ctx.currentVars = null + ctx.INPUT_ROW = row + val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + val resultVars = buildSide match { + case BuildLeft => buildColumns ++ input + case BuildRight => input ++ buildColumns + } + + val ouputCode = if (condition.isDefined) { + // filter the output via condition + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + | ${ev.code} + | if (!${ev.isNull} && ${ev.value}) { + | ${consume(ctx, resultVars)} + | } + """.stripMargin + } else { + consume(ctx, resultVars) + } + + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $row = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $ouputCode + | } + | } + """.stripMargin + } } object BroadcastHashJoin { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index ec2b9ab2cbad5..15ba77353109c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.functions._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap @@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() } + def testBroadcastHashJoin(values: Int): Unit = { + val benchmark = new Benchmark("BroadcastHashJoin", values) + + val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + + benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X + BroadcastHashJoin w codegen 1028.40 10.20 2.97 X + */ + benchmark.run() + } + def testBytesToBytesMap(values: Int): Unit = { val benchmark = new Benchmark("BytesToBytesMap", values) @@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // testWholeStage(200 << 20) // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) - // testBytesToBytesMap(1024 * 1024 * 50) + // testBytesToBytesMap(50 << 20) + // testBroadcastHashJoin(10 << 20) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index c2516509dfbbf..9350205d791d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.functions.{avg, col, max} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { @@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } + + test("BroadcastHashJoin should be included in WholeStageCodegen") { + val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val schema = new StructType().add("k", IntegerType).add("v", StringType) + val smallDF = sqlContext.createDataFrame(rdd, schema) + val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + assert(df.queryExecution.executedPlan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) + assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) + } } From 9dd2741ebe5f9b5fa0a3b0e9c594d0e94b6226f9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 3 Feb 2016 12:31:30 -0800 Subject: [PATCH 1791/2089] [SPARK-13157] [SQL] Support any kind of input for SQL commands. The ```SparkSqlLexer``` currently swallows characters which have not been defined in the grammar. This causes problems with SQL commands, such as: ```add jar file:///tmp/ab/TestUDTF.jar```. In this example the `````` is swallowed. This PR adds an extra Lexer rule to handle such input, and makes a tiny modification to the ```ASTNode```. cc davies liancheng Author: Herman van Hovell Closes #11052 from hvanhovell/SPARK-13157. --- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 4 ++ .../spark/sql/catalyst/parser/ASTNode.scala | 4 +- .../sql/catalyst/parser/ASTNodeSuite.scala | 38 +++++++++++++++++++ .../HiveThriftServer2Suites.scala | 6 +-- 4 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index e930caa291d4f..1d07a27353dcb 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -483,3 +483,7 @@ COMMENT { $channel=HIDDEN; } ; +/* Prevent that the lexer swallows unknown characters. */ +ANY + :. + ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala index ec9812414e19f..28f7b10ed6a59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -58,12 +58,12 @@ case class ASTNode( override val origin: Origin = Origin(Some(line), Some(positionInLine)) /** Source text. */ - lazy val source: String = stream.toString(startIndex, stopIndex) + lazy val source: String = stream.toOriginalString(startIndex, stopIndex) /** Get the source text that remains after this token. */ lazy val remainder: String = { stream.fill() - stream.toString(stopIndex + 1, stream.size() - 1).trim() + stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim() } def text: String = token.getText diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala new file mode 100644 index 0000000000000..8b05f9e33d69e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ASTNodeSuite extends SparkFunSuite { + test("SPARK-13157 - remainder must return all input chars") { + val inputs = Seq( + ("add jar", "file:///tmp/ab/TestUDTF.jar"), + ("add jar", "file:///tmp/a@b/TestUDTF.jar"), + ("add jar", "c:\\windows32\\TestUDTF.jar"), + ("add jar", "some \nbad\t\tfile\r\n.\njar"), + ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"), + ("SET", "foo=bar"), + ("SET", "foo*)(@#^*@&!#^=bar") + ) + inputs.foreach { + case (command, arguments) => + val node = ParseDriver.parsePlan(s"$command $arguments", null) + assert(node.remainder === arguments) + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 9860e40fe8546..ba3b26e1b7d49 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -488,8 +488,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } - // TODO: enable this - ignore("SPARK-11595 ADD JAR with input path having URL scheme") { + test("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" @@ -547,8 +546,7 @@ class SingleSessionSuite extends HiveThriftJdbcTest { override protected def extraConf: Seq[String] = "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil - // TODO: enable this - ignore("test single session") { + test("test single session") { withMultipleConnectionJdbcStatement( { statement => val jarPath = "../hive/src/test/resources/TestUDTF.jar" From 3221eddb8f9728f65c579969a3a88baeeb7577a9 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 3 Feb 2016 15:53:10 -0800 Subject: [PATCH 1792/2089] [SPARK-3611][WEB UI] Show number of cores for each executor in application web UI Added a Cores column in the Executors UI Author: Alex Bozarth Closes #11039 from ajbozarth/spark3611. --- .../main/scala/org/apache/spark/status/api/v1/api.scala | 1 + .../scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 7 +++++++ .../main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 5 +++-- .../executor_list_json_expectation.json | 1 + 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2b0079f5fd62e..d116e68c17f18 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -57,6 +57,7 @@ class ExecutorSummary private[spark]( val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, + val totalCores: Int, val maxTasks: Int, val activeTasks: Int, val failedTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index e36b96b3e6978..e1f754999912b 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -75,6 +75,7 @@ private[ui] class ExecutorsPage( RDD Blocks Storage Memory Disk Used + Cores Active Tasks Failed Tasks Complete Tasks @@ -131,6 +132,7 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} + {info.totalCores} {taskData(info.maxTasks, info.activeTasks, info.failedTasks, info.completedTasks, info.totalTasks, info.totalDuration, info.totalGCTime)} @@ -174,6 +176,7 @@ private[ui] class ExecutorsPage( val maximumMemory = execInfo.map(_.maxMemory).sum val memoryUsed = execInfo.map(_.memoryUsed).sum val diskUsed = execInfo.map(_.diskUsed).sum + val totalCores = execInfo.map(_.totalCores).sum val totalInputBytes = execInfo.map(_.totalInputBytes).sum val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum @@ -188,6 +191,7 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(diskUsed)} + {totalCores} {taskData(execInfo.map(_.maxTasks).sum, execInfo.map(_.activeTasks).sum, execInfo.map(_.failedTasks).sum, @@ -211,6 +215,7 @@ private[ui] class ExecutorsPage( RDD Blocks Storage Memory Disk Used + Cores Active Tasks Failed Tasks Complete Tasks @@ -305,6 +310,7 @@ private[spark] object ExecutorsPage { val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed + val totalCores = listener.executorToTotalCores.getOrElse(execId, 0) val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0) val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0) val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) @@ -323,6 +329,7 @@ private[spark] object ExecutorsPage { rddBlocks, memUsed, diskUsed, + totalCores, maxTasks, activeTasks, failedTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index a9e926b158780..dcfebe92ed805 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -45,6 +45,7 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { + val executorToTotalCores = HashMap[String, Int]() val executorToTasksMax = HashMap[String, Int]() val executorToTasksActive = HashMap[String, Int]() val executorToTasksComplete = HashMap[String, Int]() @@ -65,8 +66,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap - executorToTasksMax(eid) = - executorAdded.executorInfo.totalCores / conf.getInt("spark.task.cpus", 1) + executorToTotalCores(eid) = executorAdded.executorInfo.totalCores + executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1) executorIdToData(eid) = ExecutorUIData(executorAdded.time) } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 94f8aeac55b5d..9d5d224e55176 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -4,6 +4,7 @@ "rddBlocks" : 8, "memoryUsed" : 28000128, "diskUsed" : 0, + "totalCores" : 0, "maxTasks" : 0, "activeTasks" : 0, "failedTasks" : 1, From 915a75398ecbccdbf9a1e07333104c857ae1ce5e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Feb 2016 16:10:11 -0800 Subject: [PATCH 1793/2089] [SPARK-13166][SQL] Remove DataStreamReader/Writer They seem redundant and we can simply use DataFrameReader/Writer. The new usage looks like: ```scala val df = sqlContext.read.stream("...") val handle = df.write.stream("...") handle.stop() ``` Author: Reynold Xin Closes #11062 from rxin/SPARK-13166. --- .../org/apache/spark/sql/DataFrame.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 29 +++- .../apache/spark/sql/DataFrameWriter.scala | 36 ++++- .../apache/spark/sql/DataStreamReader.scala | 127 ----------------- .../apache/spark/sql/DataStreamWriter.scala | 134 ------------------ .../org/apache/spark/sql/SQLContext.scala | 11 +- .../datasources/ResolvedDataSource.scala | 1 - .../sql/streaming/DataStreamReaderSuite.scala | 53 ++++--- 8 files changed, 86 insertions(+), 315 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6de17e5924d04..84203bbfef66a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1682,7 +1682,7 @@ class DataFrame private[sql]( /** * :: Experimental :: - * Interface for saving the content of the [[DataFrame]] out into external storage. + * Interface for saving the content of the [[DataFrame]] out into external storage or streams. * * @group output * @since 1.4.0 @@ -1690,14 +1690,6 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) - /** - * :: Experimental :: - * Interface for starting a streaming query that will continually output results to the specified - * external sink as new data arrives. - */ - @Experimental - def streamTo: DataStreamWriter = new DataStreamWriter(this) - /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2e0c6c7df967e..a58643a5ba154 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,17 +29,17 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystQl} import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType /** * :: Experimental :: * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, - * key-value stores, etc). Use [[SQLContext.read]] to access this. + * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this. * * @since 1.4.0 */ @@ -136,6 +136,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } + /** + * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path + * (e.g. external key-value stores). + * + * @since 2.0.0 + */ + def stream(): DataFrame = { + val resolved = ResolvedDataSource.createSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + providerName = source, + options = extraOptions.toMap) + DataFrame(sqlContext, StreamingRelation(resolved)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def stream(path: String): DataFrame = { + option("path", path).stream() + } + /** * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table and connection properties. @@ -165,7 +189,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. - * * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 12eb2393634a9..80447fefe1f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,17 +22,18 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.sources.HadoopFsRelation /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc). Use [[DataFrame.write]] to access this. + * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. * * @since 1.4.0 */ @@ -183,6 +184,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { df) } + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def stream(path: String): ContinuousQuery = { + option("path", path).stream() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def stream(): ContinuousQuery = { + val sink = ResolvedDataSource.createSink( + df.sqlContext, + source, + extraOptions.toMap, + normalizedParCols.getOrElse(Nil)) + + new StreamExecution(df.sqlContext, df.logicalPlan, sink) + } + /** * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. @@ -255,7 +284,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * The given column name may not be equal to any of the existing column names if we were in - * case-insensitive context. Normalize the given column name to the real one so that we don't + * case-insensitive context. Normalize the given column name to the real one so that we don't * need to care about case sensitivity afterwards. */ private def normalize(columnName: String, columnType: String): String = { @@ -339,7 +368,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. - * * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala deleted file mode 100644 index 2febc93fa49d4..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamReader.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.datasources.ResolvedDataSource -import org.apache.spark.sql.execution.streaming.StreamingRelation -import org.apache.spark.sql.types.StructType - -/** - * :: Experimental :: - * An interface to reading streaming data. Use `sqlContext.streamFrom` to access these methods. - * - * {{{ - * val df = sqlContext.streamFrom - * .format("...") - * .open() - * }}} - */ -@Experimental -class DataStreamReader private[sql](sqlContext: SQLContext) extends Logging { - - /** - * Specifies the input data source format. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamReader = { - this.source = source - this - } - - /** - * Specifies the input schema. Some data streams (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data stream can - * skip the schema inference step, and thus speed up data reading. - * - * @since 2.0.0 - */ - def schema(schema: StructType): DataStreamReader = { - this.userSpecifiedSchema = Option(schema) - this - } - - /** - * Adds an input option for the underlying data stream. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamReader = { - this.extraOptions += (key -> value) - this - } - - /** - * (Scala-specific) Adds input options for the underlying data stream. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { - this.extraOptions ++= options - this - } - - /** - * Adds input options for the underlying data stream. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { - this.options(options.asScala) - this - } - - /** - * Loads streaming input in as a [[DataFrame]], for data streams that don't require a path (e.g. - * external key-value stores). - * - * @since 2.0.0 - */ - def open(): DataFrame = { - val resolved = ResolvedDataSource.createSource( - sqlContext, - userSpecifiedSchema = userSpecifiedSchema, - providerName = source, - options = extraOptions.toMap) - DataFrame(sqlContext, StreamingRelation(resolved)) - } - - /** - * Loads input in as a [[DataFrame]], for data streams that read from some path. - * - * @since 2.0.0 - */ - def open(path: String): DataFrame = { - option("path", path).open() - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = sqlContext.conf.defaultDataSourceName - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = new scala.collection.mutable.HashMap[String, String] - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala deleted file mode 100644 index b325d48fcbbb1..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataStreamWriter.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.datasources.ResolvedDataSource -import org.apache.spark.sql.execution.streaming.StreamExecution - -/** - * :: Experimental :: - * Interface used to start a streaming query query execution. - * - * @since 2.0.0 - */ -@Experimental -final class DataStreamWriter private[sql](df: DataFrame) { - - /** - * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamWriter = { - this.source = source - this - } - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamWriter = { - this.extraOptions += (key -> value) - this - } - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter = { - this.extraOptions ++= options - this - } - - /** - * Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter = { - this.options(options.asScala) - this - } - - /** - * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme.\ - * @since 2.0.0 - */ - @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter = { - this.partitioningColumns = colNames - this - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * @since 2.0.0 - */ - def start(path: String): ContinuousQuery = { - this.extraOptions += ("path" -> path) - start() - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - def start(): ContinuousQuery = { - val sink = ResolvedDataSource.createSink( - df.sqlContext, - source, - extraOptions.toMap, - normalizedParCols) - - new StreamExecution(df.sqlContext, df.logicalPlan, sink) - } - - private def normalizedParCols: Seq[String] = { - partitioningColumns.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = df.sqlContext.conf.defaultDataSourceName - - private var extraOptions = new scala.collection.mutable.HashMap[String, String] - - private var partitioningColumns: Seq[String] = Nil - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 13700be06828d..1661fdbec5326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -579,10 +579,9 @@ class SQLContext private[sql]( DataFrame(self, LocalRelation(attrSeq, rows.toSeq)) } - /** * :: Experimental :: - * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. * {{{ * sqlContext.read.parquet("/path/to/file.parquet") * sqlContext.read.schema(schema).json("/path/to/file.json") @@ -594,14 +593,6 @@ class SQLContext private[sql]( @Experimental def read: DataFrameReader = new DataFrameReader(this) - - /** - * :: Experimental :: - * Returns a [[DataStreamReader]] than can be used to access data continuously as it arrives. - */ - @Experimental - def streamFrom: DataStreamReader = new DataStreamReader(this) - /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index e3065ac5f87d2..7702f535ad2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -122,7 +122,6 @@ object ResolvedDataSource extends Logging { provider.createSink(sqlContext, options, partitionColumns) } - /** Create a [[ResolvedDataSource]] for reading data in. */ def apply( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala index 1dab6ebf1bee9..b36b41cac9b4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -60,22 +60,22 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { import testImplicits._ test("resolve default source") { - sqlContext.streamFrom + sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open() - .streamTo + .stream() + .write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() } test("resolve full class") { - sqlContext.streamFrom + sqlContext.read .format("org.apache.spark.sql.streaming.test.DefaultSource") - .open() - .streamTo + .stream() + .write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() } @@ -83,12 +83,12 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .open() + .stream() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -96,12 +96,12 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { LastOptions.parameters = null - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .start() + .stream() .stop() assert(LastOptions.parameters("opt1") == "1") @@ -110,54 +110,53 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { } test("partitioning") { - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open() + .stream() - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Nil) - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("a") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Seq("a")) - withSQLConf("spark.sql.caseSensitive" -> "false") { - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("A") - .start() + .stream() .stop() assert(LastOptions.partitionColumns == Seq("a")) } intercept[AnalysisException] { - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") .partitionBy("b") - .start() + .stream() .stop() } } test("stream paths") { - val df = sqlContext.streamFrom + val df = sqlContext.read .format("org.apache.spark.sql.streaming.test") - .open("/test") + .stream("/test") assert(LastOptions.parameters("path") == "/test") LastOptions.parameters = null - df.streamTo + df.write .format("org.apache.spark.sql.streaming.test") - .start("/test") + .stream("/test") .stop() assert(LastOptions.parameters("path") == "/test") From de0914522fc5b2658959f9e2272b4e3162b14978 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Feb 2016 17:07:27 -0800 Subject: [PATCH 1794/2089] [SPARK-13131] [SQL] Use best and average time in benchmark Best time is stabler than average time, also added a column for nano seconds per row (which could be used to estimate contributions of each components in a query). Having best time and average time together for more information (we can see kind of variance). rate, time per row and relative are all calculated using best time. The result looks like this: ``` Intel(R) Core(TM) i7-4558U CPU 2.80GHz rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X ``` Author: Davies Liu Closes #11018 from davies/gen_bench. --- .../org/apache/spark/util/Benchmark.scala | 38 +++-- .../BenchmarkWholeStageCodegen.scala | 154 ++++++++---------- 2 files changed, 89 insertions(+), 103 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index d484cec7ae384..d1699f5c28655 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.commons.lang3.SystemUtils @@ -59,17 +60,21 @@ private[spark] class Benchmark( } println - val firstRate = results.head.avgRate + val firstBest = results.head.bestMs + val firstAvg = results.head.avgMs // The results are going to be processor specific so it is useful to include that. println(Benchmark.getProcessorName()) - printf("%-30s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") - println("-------------------------------------------------------------------------------") - results.zip(benchmarks).foreach { r => - printf("%-30s %16s %16s %14s\n", - r._2.name, - "%10.2f" format r._1.avgMs, - "%10.2f" format r._1.avgRate, - "%6.2f X" format (r._1.avgRate / firstRate)) + printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", + "Per Row(ns)", "Relative") + println("-----------------------------------------------------------------------------------" + + "--------") + results.zip(benchmarks).foreach { case (result, benchmark) => + printf("%-35s %16s %12s %13s %10s\n", + benchmark.name, + "%5.0f / %4.0f" format (result.bestMs, result.avgMs), + "%10.1f" format result.bestRate, + "%6.1f" format (1000 / result.bestRate), + "%3.1fX" format (firstBest / result.bestMs)) } println // scalastyle:on @@ -78,7 +83,7 @@ private[spark] class Benchmark( private[spark] object Benchmark { case class Case(name: String, fn: Int => Unit) - case class Result(avgMs: Double, avgRate: Double) + case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** * This should return a user helpful processor information. Getting at this depends on the OS. @@ -99,22 +104,27 @@ private[spark] object Benchmark { * the rate of the function. */ def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { - var totalTime = 0L + val runTimes = ArrayBuffer[Long]() for (i <- 0 until iters + 1) { val start = System.nanoTime() f(i) val end = System.nanoTime() - if (i != 0) totalTime += end - start + val runTime = end - start + if (i > 0) { + runTimes += runTime + } if (outputPerIteration) { // scalastyle:off - println(s"Iteration $i took ${(end - start) / 1000} microseconds") + println(s"Iteration $i took ${runTime / 1000} microseconds") // scalastyle:on } } - Result(totalTime.toDouble / 1000000 / iters, num * iters / (totalTime.toDouble / 1000)) + val best = runTimes.min + val avg = runTimes.sum / iters + Result(avg / 1000000, num / (best / 1000), best / 1000000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 15ba77353109c..33d4976403d9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -34,54 +34,47 @@ import org.apache.spark.util.Benchmark */ class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + .set("spark.sql.shuffle.partitions", "1") lazy val sc = SparkContext.getOrCreate(conf) lazy val sqlContext = SQLContext.getOrCreate(sc) - def testWholeStage(values: Int): Unit = { - val benchmark = new Benchmark("rang/filter/aggregate", values) + def runBenchmark(name: String, values: Int)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, values) - benchmark.addCase("Without codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).filter("(id & 1) = 1").count() - } - - benchmark.addCase("With codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).filter("(id & 1) = 1").count() + Seq(false, true).foreach { enabled => + benchmark.addCase(s"$name codegen=$enabled") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) + f + } } - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - Without codegen 7775.53 26.97 1.00 X - With codegen 342.15 612.94 22.73 X - */ benchmark.run() } - def testStatFunctions(values: Int): Unit = { - - val benchmark = new Benchmark("stat functions", values) - - benchmark.addCase("stddev w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() + // These benchmark are skipped in normal build + ignore("range/filter/sum") { + val N = 500 << 20 + runBenchmark("rang/filter/sum", N) { + sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect() } + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X + rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X + */ + } - benchmark.addCase("stddev w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).groupBy().agg("id" -> "stddev").collect() - } + ignore("stat functions") { + val N = 100 << 20 - benchmark.addCase("kurtosis w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + runBenchmark("stddev", N) { + sqlContext.range(N).groupBy().agg("id" -> "stddev").collect() } - benchmark.addCase("kurtosis w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect() + runBenchmark("kurtosis", N) { + sqlContext.range(N).groupBy().agg("id" -> "kurtosis").collect() } @@ -99,64 +92,56 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Using DeclarativeAggregate: Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - stddev w/o codegen 989.22 21.20 1.00 X - stddev w codegen 352.35 59.52 2.81 X - kurtosis w/o codegen 3636.91 5.77 0.27 X - kurtosis w codegen 369.25 56.79 2.68 X + stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + stddev codegen=false 5630 / 5776 18.0 55.6 1.0X + stddev codegen=true 1259 / 1314 83.0 12.0 4.5X + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X + kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X */ - benchmark.run() } - def testAggregateWithKey(values: Int): Unit = { - val benchmark = new Benchmark("Aggregate with keys", values) + ignore("aggregate with keys") { + val N = 20 << 20 - benchmark.addCase("Aggregate w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() - } - benchmark.addCase(s"Aggregate w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + runBenchmark("Aggregate w keys", N) { + sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - Aggregate w/o codegen 4254.38 4.93 1.00 X - Aggregate w codegen 2661.45 7.88 1.60 X + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X + Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X */ - benchmark.run() } - def testBroadcastHashJoin(values: Int): Unit = { - val benchmark = new Benchmark("BroadcastHashJoin", values) - + ignore("broadcast hash join") { + val N = 20 << 20 val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) - benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() - } - benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() + runBenchmark("BroadcastHashJoin", N) { + sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() } /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X - BroadcastHashJoin w codegen 1028.40 10.20 2.97 X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X + BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X */ - benchmark.run() } - def testBytesToBytesMap(values: Int): Unit = { - val benchmark = new Benchmark("BytesToBytesMap", values) + ignore("hash and BytesToBytesMap") { + val N = 50 << 20 + + val benchmark = new Benchmark("BytesToBytesMap", N) benchmark.addCase("hash") { iter => var i = 0 @@ -167,7 +152,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(2) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var s = 0 - while (i < values) { + while (i < N) { key.setInt(0, i % 1000) val h = Murmur3_x86_32.hashUnsafeWords( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) @@ -194,7 +179,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(2) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < values) { + while (i < N) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) if (loc.isDefined) { @@ -212,21 +197,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - hash 662.06 79.19 1.00 X - BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X - BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + hash 628 / 661 83.0 12.0 1.0X + BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X + BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X */ benchmark.run() } - - // These benchmark are skipped in normal build - ignore("benchmark") { - // testWholeStage(200 << 20) - // testStatFunctions(20 << 20) - // testAggregateWithKey(20 << 20) - // testBytesToBytesMap(50 << 20) - // testBroadcastHashJoin(10 << 20) - } } From a8e2ba776b20c8054918af646d8228bba1b87c9b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 3 Feb 2016 17:43:14 -0800 Subject: [PATCH 1795/2089] [SPARK-13152][CORE] Fix task metrics deprecation warning Make an internal non-deprecated version of incBytesRead and incRecordsRead so we don't have unecessary deprecation warnings in our build. Right now incBytesRead and incRecordsRead are marked as deprecated and for internal use only. We should make private[spark] versions which are not deprecated and switch to those internally so as to not clutter up the warning messages when building. cc andrewor14 who did the initial deprecation Author: Holden Karau Closes #11056 from holdenk/SPARK-13152-fix-task-metrics-deprecation-warnings. --- core/src/main/scala/org/apache/spark/CacheManager.scala | 4 ++-- .../main/scala/org/apache/spark/executor/InputMetrics.scala | 5 +++++ core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 4 ++-- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index fa8e2b953835b..923ff411ce252 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,12 +44,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(blockResult) => // Partition is already materialized, so just return its values val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) - existingMetrics.incBytesRead(blockResult.bytes) + existingMetrics.incBytesReadInternal(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] new InterruptibleIterator[T](context, iter) { override def next(): T = { - existingMetrics.incRecordsRead(1) + existingMetrics.incRecordsReadInternal(1) delegate.next() } } diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index ed9e157ce758b..6d30d3c76a9fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -81,10 +81,15 @@ class InputMetrics private ( */ def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) + // Once incBytesRead & intRecordsRead is ready to be removed from the public API + // we can remove the internal versions and make the previous public API private. + // This has been done to suppress warnings when building. @deprecated("incrementing input metrics is for internal use only", "2.0.0") def incBytesRead(v: Long): Unit = _bytesRead.add(v) + private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v) @deprecated("incrementing input metrics is for internal use only", "2.0.0") def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e2ebd7f00d0d5..805cd9fe1f638 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -260,7 +260,7 @@ class HadoopRDD[K, V]( finished = true } if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -292,7 +292,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e71d3405c0ead..f23da39eb90de 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -188,7 +188,7 @@ class NewHadoopRDD[K, V]( } havePair = false if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -219,7 +219,7 @@ class NewHadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a2487eeb0483a..38e6478d80f03 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -811,8 +811,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) val inputMetrics = metrics.registerInputMetrics(readMethod) - inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long]) + inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) } // Updated blocks diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 9703b16c86f90..3605150b3b767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -214,7 +214,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } havePair = false if (!finished) { - inputMetrics.incRecordsRead(1) + inputMetrics.incRecordsReadInternal(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -246,7 +246,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) From a64831124c215f56f124747fa241560c70cf0a36 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 3 Feb 2016 19:32:41 -0800 Subject: [PATCH 1796/2089] [SPARK-13079][SQL] Extend and implement InMemoryCatalog This is a step towards consolidating `SQLContext` and `HiveContext`. This patch extends the existing Catalog API added in #10982 to include methods for handling table partitions. In particular, a partition is identified by `PartitionSpec`, which is just a `Map[String, String]`. The Catalog is still not used by anything yet, but its API is now more or less complete and an implementation is fully tested. About 200 lines are test code. Author: Andrew Or Closes #11069 from andrewor14/catalog. --- .../catalyst/catalog/InMemoryCatalog.scala | 129 ++++++++--- .../sql/catalyst/catalog/interface.scala | 40 +++- .../catalyst/catalog/CatalogTestCases.scala | 206 +++++++++++++++++- 3 files changed, 328 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9e6dfb7e9506f..38be61c52a95e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,9 +28,10 @@ import org.apache.spark.sql.AnalysisException * All public methods should be synchronized for thread-safety. */ class InMemoryCatalog extends Catalog { + import Catalog._ private class TableDesc(var table: Table) { - val partitions = new mutable.HashMap[String, TablePartition] + val partitions = new mutable.HashMap[PartitionSpec, TablePartition] } private class DatabaseDesc(var db: Database) { @@ -46,13 +47,20 @@ class InMemoryCatalog extends Catalog { } private def existsFunction(db: String, funcName: String): Boolean = { + assertDbExists(db) catalog(db).functions.contains(funcName) } private def existsTable(db: String, table: String): Boolean = { + assertDbExists(db) catalog(db).tables.contains(table) } + private def existsPartition(db: String, table: String, spec: PartitionSpec): Boolean = { + assertTableExists(db, table) + catalog(db).tables(table).partitions.contains(spec) + } + private def assertDbExists(db: String): Unit = { if (!catalog.contains(db)) { throw new AnalysisException(s"Database $db does not exist") @@ -60,16 +68,20 @@ class InMemoryCatalog extends Catalog { } private def assertFunctionExists(db: String, funcName: String): Unit = { - assertDbExists(db) if (!existsFunction(db, funcName)) { - throw new AnalysisException(s"Function $funcName does not exists in $db database") + throw new AnalysisException(s"Function $funcName does not exist in $db database") } } private def assertTableExists(db: String, table: String): Unit = { - assertDbExists(db) if (!existsTable(db, table)) { - throw new AnalysisException(s"Table $table does not exists in $db database") + throw new AnalysisException(s"Table $table does not exist in $db database") + } + } + + private def assertPartitionExists(db: String, table: String, spec: PartitionSpec): Unit = { + if (!existsPartition(db, table, spec)) { + throw new AnalysisException(s"Partition does not exist in database $db table $table: $spec") } } @@ -77,9 +89,11 @@ class InMemoryCatalog extends Catalog { // Databases // -------------------------------------------------------------------------- - override def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit = synchronized { + override def createDatabase( + dbDefinition: Database, + ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Database ${dbDefinition.name} already exists.") } } else { @@ -88,9 +102,9 @@ class InMemoryCatalog extends Catalog { } override def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit = synchronized { + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { if (catalog.contains(db)) { if (!cascade) { // If cascade is false, make sure the database is empty. @@ -133,11 +147,13 @@ class InMemoryCatalog extends Catalog { // Tables // -------------------------------------------------------------------------- - override def createTable(db: String, tableDefinition: Table, ifNotExists: Boolean) - : Unit = synchronized { + override def createTable( + db: String, + tableDefinition: Table, + ignoreIfExists: Boolean): Unit = synchronized { assertDbExists(db) if (existsTable(db, tableDefinition.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") } } else { @@ -145,8 +161,10 @@ class InMemoryCatalog extends Catalog { } } - override def dropTable(db: String, table: String, ignoreIfNotExists: Boolean) - : Unit = synchronized { + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = synchronized { assertDbExists(db) if (existsTable(db, table)) { catalog(db).tables.remove(table) @@ -190,14 +208,67 @@ class InMemoryCatalog extends Catalog { // Partitions // -------------------------------------------------------------------------- - override def alterPartition(db: String, table: String, part: TablePartition) - : Unit = synchronized { - throw new UnsupportedOperationException + override def createPartitions( + db: String, + table: String, + parts: Seq[TablePartition], + ignoreIfExists: Boolean): Unit = synchronized { + assertTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfExists) { + val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec } + if (dupSpecs.nonEmpty) { + val dupSpecsStr = dupSpecs.mkString("\n===\n") + throw new AnalysisException( + s"The following partitions already exist in database $db table $table:\n$dupSpecsStr") + } + } + parts.foreach { p => existingParts.put(p.spec, p) } + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[PartitionSpec], + ignoreIfNotExists: Boolean): Unit = synchronized { + assertTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfNotExists) { + val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s } + if (missingSpecs.nonEmpty) { + val missingSpecsStr = missingSpecs.mkString("\n===\n") + throw new AnalysisException( + s"The following partitions do not exist in database $db table $table:\n$missingSpecsStr") + } + } + partSpecs.foreach(existingParts.remove) } - override def alterPartitions(db: String, table: String, parts: Seq[TablePartition]) - : Unit = synchronized { - throw new UnsupportedOperationException + override def alterPartition( + db: String, + table: String, + spec: Map[String, String], + newPart: TablePartition): Unit = synchronized { + assertPartitionExists(db, table, spec) + val existingParts = catalog(db).tables(table).partitions + if (spec != newPart.spec) { + // Also a change in specs; remove the old one and add the new one back + existingParts.remove(spec) + } + existingParts.put(newPart.spec, newPart) + } + + override def getPartition( + db: String, + table: String, + spec: Map[String, String]): TablePartition = synchronized { + assertPartitionExists(db, table, spec) + catalog(db).tables(table).partitions(spec) + } + + override def listPartitions(db: String, table: String): Seq[TablePartition] = synchronized { + assertTableExists(db, table) + catalog(db).tables(table).partitions.values.toSeq } // -------------------------------------------------------------------------- @@ -205,11 +276,12 @@ class InMemoryCatalog extends Catalog { // -------------------------------------------------------------------------- override def createFunction( - db: String, func: Function, ifNotExists: Boolean): Unit = synchronized { + db: String, + func: Function, + ignoreIfExists: Boolean): Unit = synchronized { assertDbExists(db) - if (existsFunction(db, func.name)) { - if (!ifNotExists) { + if (!ignoreIfExists) { throw new AnalysisException(s"Function $func already exists in $db database") } } else { @@ -222,14 +294,16 @@ class InMemoryCatalog extends Catalog { catalog(db).functions.remove(funcName) } - override def alterFunction(db: String, funcName: String, funcDefinition: Function) - : Unit = synchronized { + override def alterFunction( + db: String, + funcName: String, + funcDefinition: Function): Unit = synchronized { assertFunctionExists(db, funcName) if (funcName != funcDefinition.name) { // Also a rename; remove the old one and add the new one back catalog(db).functions.remove(funcName) } - catalog(db).functions.put(funcName, funcDefinition) + catalog(db).functions.put(funcDefinition.name, funcDefinition) } override def getFunction(db: String, funcName: String): Function = synchronized { @@ -239,7 +313,6 @@ class InMemoryCatalog extends Catalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { assertDbExists(db) - val regex = pattern.replaceAll("\\*", ".*").r filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index a6caf91f3304b..b4d7dd2f4e31c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -29,17 +29,15 @@ import org.apache.spark.sql.AnalysisException * Implementations should throw [[AnalysisException]] when table or database don't exist. */ abstract class Catalog { + import Catalog._ // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: Database, ifNotExists: Boolean): Unit + def createDatabase(dbDefinition: Database, ignoreIfExists: Boolean): Unit - def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit def alterDatabase(db: String, dbDefinition: Database): Unit @@ -71,11 +69,28 @@ abstract class Catalog { // Partitions // -------------------------------------------------------------------------- - // TODO: need more functions for partitioning. + def createPartitions( + db: String, + table: String, + parts: Seq[TablePartition], + ignoreIfExists: Boolean): Unit - def alterPartition(db: String, table: String, part: TablePartition): Unit + def dropPartitions( + db: String, + table: String, + parts: Seq[PartitionSpec], + ignoreIfNotExists: Boolean): Unit - def alterPartitions(db: String, table: String, parts: Seq[TablePartition]): Unit + def alterPartition( + db: String, + table: String, + spec: PartitionSpec, + newPart: TablePartition): Unit + + def getPartition(db: String, table: String, spec: PartitionSpec): TablePartition + + // TODO: support listing by pattern + def listPartitions(db: String, table: String): Seq[TablePartition] // -------------------------------------------------------------------------- // Functions @@ -132,11 +147,11 @@ case class Column( /** * A partition (Hive style) defined in the catalog. * - * @param values values for the partition columns + * @param spec partition spec values indexed by column name * @param storage storage format of the partition */ case class TablePartition( - values: Seq[String], + spec: Catalog.PartitionSpec, storage: StorageFormat ) @@ -176,3 +191,8 @@ case class Database( locationUri: String, properties: Map[String, String] ) + + +object Catalog { + type PartitionSpec = Map[String, String] +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index ab9d5ac8a20eb..0d8434323fcbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -27,6 +27,11 @@ import org.apache.spark.sql.AnalysisException * Implementations of the [[Catalog]] interface can create test suites by extending this. */ abstract class CatalogTestCases extends SparkFunSuite { + private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map.empty[String, String]) + private val part1 = TablePartition(Map[String, String]("a" -> "1"), storageFormat) + private val part2 = TablePartition(Map[String, String]("b" -> "2"), storageFormat) + private val part3 = TablePartition(Map[String, String]("c" -> "3"), storageFormat) + private val funcClass = "org.apache.spark.myFunc" protected def newEmptyCatalog(): Catalog @@ -41,16 +46,16 @@ abstract class CatalogTestCases extends SparkFunSuite { */ private def newBasicCatalog(): Catalog = { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb("db1"), ifNotExists = false) - catalog.createDatabase(newDb("db2"), ifNotExists = false) - + catalog.createDatabase(newDb("db1"), ignoreIfExists = false) + catalog.createDatabase(newDb("db2"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) catalog } - private def newFunc(): Function = Function("funcname", "org.apache.spark.MyFunc") + private def newFunc(): Function = Function("funcname", funcClass) private def newDb(name: String = "default"): Database = Database(name, name + " description", "uri", Map.empty) @@ -59,7 +64,7 @@ abstract class CatalogTestCases extends SparkFunSuite { Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, None, None) - private def newFunc(name: String): Function = Function(name, "class.name") + private def newFunc(name: String): Function = Function(name, funcClass) // -------------------------------------------------------------------------- // Databases @@ -67,10 +72,10 @@ abstract class CatalogTestCases extends SparkFunSuite { test("basic create, drop and list databases") { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb(), ifNotExists = false) + catalog.createDatabase(newDb(), ignoreIfExists = false) assert(catalog.listDatabases().toSet == Set("default")) - catalog.createDatabase(newDb("default2"), ifNotExists = false) + catalog.createDatabase(newDb("default2"), ignoreIfExists = false) assert(catalog.listDatabases().toSet == Set("default", "default2")) } @@ -253,11 +258,194 @@ abstract class CatalogTestCases extends SparkFunSuite { // Partitions // -------------------------------------------------------------------------- - // TODO: Add tests cases for partitions + test("basic create and list partitions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable("mydb", newTable("mytbl"), ignoreIfExists = false) + catalog.createPartitions("mydb", "mytbl", Seq(part1, part2), ignoreIfExists = false) + assert(catalog.listPartitions("mydb", "mytbl").toSet == Set(part1, part2)) + } + + test("create partitions when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("does_not_exist", "tbl1", Seq(), ignoreIfExists = false) + } + intercept[AnalysisException] { + catalog.createPartitions("db2", "does_not_exist", Seq(), ignoreIfExists = false) + } + } + + test("create partitions that already exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = false) + } + catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) + } + + test("drop partitions") { + val catalog = newBasicCatalog() + assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part2)) + val catalog2 = newBasicCatalog() + assert(catalog2.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + assert(catalog2.listPartitions("db2", "tbl2").isEmpty) + } + + test("drop partitions when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + } + } + + test("drop partitions that do not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + } + catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + } + + test("get partition") { + val catalog = newBasicCatalog() + assert(catalog.getPartition("db2", "tbl2", part1.spec) == part1) + assert(catalog.getPartition("db2", "tbl2", part2.spec) == part2) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl1", part3.spec) + } + } + + test("get partition when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getPartition("does_not_exist", "tbl1", part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition("db2", "does_not_exist", part1.spec) + } + } + + test("alter partitions") { + val catalog = newBasicCatalog() + val partSameSpec = part1.copy(storage = storageFormat.copy(serde = "myserde")) + val partNewSpec = part1.copy(spec = Map("x" -> "10")) + // alter but keep spec the same + catalog.alterPartition("db2", "tbl2", part1.spec, partSameSpec) + assert(catalog.getPartition("db2", "tbl2", part1.spec) == partSameSpec) + // alter and change spec + catalog.alterPartition("db2", "tbl2", part1.spec, partNewSpec) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl2", part1.spec) + } + assert(catalog.getPartition("db2", "tbl2", partNewSpec.spec) == partNewSpec) + } + + test("alter partition when database / table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterPartition("does_not_exist", "tbl1", part1.spec, part1) + } + intercept[AnalysisException] { + catalog.alterPartition("db2", "does_not_exist", part1.spec, part1) + } + } // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- - // TODO: Add tests cases for functions + test("basic create and list functions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction("mydb", newFunc("myfunc"), ignoreIfExists = false) + assert(catalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + } + + test("create function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("does_not_exist", newFunc(), ignoreIfExists = false) + } + } + + test("create function that already exists") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + } + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = true) + } + + test("drop function") { + val catalog = newBasicCatalog() + assert(catalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction("db2", "func1") + assert(catalog.listFunctions("db2", "*").isEmpty) + } + + test("drop function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("does_not_exist", "something") + } + } + + test("drop function that does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("db2", "does_not_exist") + } + } + + test("get function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1") == newFunc("func1")) + intercept[AnalysisException] { + catalog.getFunction("db2", "does_not_exist") + } + } + + test("get function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getFunction("does_not_exist", "func1") + } + } + + test("alter function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1").className == funcClass) + // alter func but keep name + catalog.alterFunction("db2", "func1", newFunc("func1").copy(className = "muhaha")) + assert(catalog.getFunction("db2", "func1").className == "muhaha") + // alter func and change name + catalog.alterFunction("db2", "func1", newFunc("funcky")) + intercept[AnalysisException] { + catalog.getFunction("db2", "func1") + } + assert(catalog.getFunction("db2", "funcky").className == funcClass) + } + + test("alter function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterFunction("does_not_exist", "func1", newFunc()) + } + } + + test("list functions") { + val catalog = newBasicCatalog() + catalog.createFunction("db2", newFunc("func2"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("not_me"), ignoreIfExists = false) + assert(catalog.listFunctions("db2", "*").toSet == Set("func1", "func2", "not_me")) + assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2")) + } + } From 0f81318ae217346c20894572795e1a9cee2ebc8f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 3 Feb 2016 21:05:53 -0800 Subject: [PATCH 1797/2089] [SPARK-12828][SQL] add natural join support Jira: https://issues.apache.org/jira/browse/SPARK-12828 Author: Daoyuan Wang Closes #10762 from adrian-wang/naturaljoin. --- .../sql/catalyst/parser/FromClauseParser.g | 23 +++-- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 2 + .../sql/catalyst/parser/SparkSqlParser.g | 4 + .../spark/sql/catalyst/CatalystQl.scala | 4 + .../sql/catalyst/analysis/Analyzer.scala | 43 +++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../spark/sql/catalyst/plans/joinTypes.scala | 4 + .../plans/logical/basicOperators.scala | 10 ++- .../analysis/ResolveNaturalJoinSuite.scala | 90 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 24 +++++ 11 files changed, 198 insertions(+), 11 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index 6d76afcd4ac07..e83f8a7cd1b5c 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -117,15 +117,20 @@ joinToken @init { gParent.pushMsg("join type specifier", state); } @after { gParent.popMsg(state); } : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN + | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN + | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN + | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN ; lateralView diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 1d07a27353dcb..fd1ad59207e31 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -335,6 +335,8 @@ KW_CACHE: 'CACHE'; KW_UNCACHE: 'UNCACHE'; KW_DFS: 'DFS'; +KW_NATURAL: 'NATURAL'; + // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 6591f6b0f56ce..9935678ca2ca2 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN; TOK_FULLOUTERJOIN; TOK_UNIQUEJOIN; TOK_CROSSJOIN; +TOK_NATURALJOIN; +TOK_NATURALLEFTOUTERJOIN; +TOK_NATURALRIGHTOUTERJOIN; +TOK_NATURALFULLOUTERJOIN; TOK_LOAD; TOK_EXPORT; TOK_IMPORT; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 7ce2407913ade..a42360d5629f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -520,6 +520,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case "TOK_LEFTSEMIJOIN" => LeftSemi case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) + case "TOK_NATURALJOIN" => NaturalJoin(Inner) + case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) + case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) + case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) } Join(nodeToRelation(relation1), nodeToRelation(relation2), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a983dc1cdfebe..b30ed5928fd56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -81,6 +82,7 @@ class Analyzer( ResolveAliases :: ResolveWindowOrder :: ResolveWindowFrame :: + ResolveNaturalJoin :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -1230,6 +1232,47 @@ class Analyzer( } } } + + /** + * Removes natural joins by calculating output columns based on output from two sides, + * Then apply a Project on a normal Join to eliminate natural join. + */ + object ResolveNaturalJoin extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // Should not skip unresolved nodes because natural join is always unresolved. + case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + // find common column names from both sides, should be treated like usingColumns + val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) + val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) + val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions + val newCondition = (condition ++ joinPairs.map { + case (l, r) => EqualTo(l, r) + }).reduceLeftOption(And) + // columns not in joinPairs + val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) + val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) + // we should only keep unique columns(depends on joinType) for joinCols + val projectList = joinType match { + case LeftOuter => + leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) + case RightOuter => + rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput + case FullOuter => + // in full outer join, joinCols should be non-null if there is. + val joinedCols = joinPairs.map { + case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() + } + joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + rUniqueOutput.map(_.withNullability(true)) + case _ => + rightKeys ++ lUniqueOutput ++ rUniqueOutput + } + // use Project to trim unnecessary fields + Project(projectList, Join(left, right, joinType, newCondition)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f156b5d10acc2..4ecee75048248 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -905,6 +905,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (rightFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) case FullOuter => f // DO Nothing for Full Outer Join + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } // push down the join filter into sub query scanning if applicable @@ -939,6 +940,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index a5f6764aef7ce..b10f1e63a73e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -60,3 +60,7 @@ case object FullOuter extends JoinType { case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } + +case class NaturalJoin(tpe: JoinType) extends JoinType { + override def sql: String = "NATURAL " + tpe.sql +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8150ff8434762..03a79520cbd3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -250,12 +250,20 @@ case class Join( def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. - override lazy val resolved: Boolean = { + // NaturalJoin should be ready for resolution only if everything else is resolved here + lazy val resolvedExceptNatural: Boolean = { childrenResolved && expressions.forall(_.resolved) && duplicateResolved && condition.forall(_.dataType == BooleanType) } + + // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need + // to eliminate natural before we mark it resolved. + override lazy val resolved: Boolean = joinType match { + case NaturalJoin(_) => false + case _ => resolvedExceptNatural + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala new file mode 100644 index 0000000000000..a6554fbc414b5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +class ResolveNaturalJoinSuite extends AnalysisTest { + lazy val a = 'a.string + lazy val b = 'b.string + lazy val c = 'c.string + lazy val aNotNull = a.notNull + lazy val bNotNull = b.notNull + lazy val cNotNull = c.notNull + lazy val r1 = LocalRelation(a, b) + lazy val r2 = LocalRelation(a, c) + lazy val r3 = LocalRelation(aNotNull, bNotNull) + lazy val r4 = LocalRelation(bNotNull, cNotNull) + + test("natural inner join") { + val plan = r1.join(r2, NaturalJoin(Inner), None) + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural left join") { + val plan = r1.join(r2, NaturalJoin(LeftOuter), None) + val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural right join") { + val plan = r1.join(r2, NaturalJoin(RightOuter), None) + val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(plan, expected) + } + + test("natural full outer join") { + val plan = r1.join(r2, NaturalJoin(FullOuter), None) + val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select( + Alias(Coalesce(Seq(a, a)), "a")(), b, c) + checkAnalysis(plan, expected) + } + + test("natural inner join with no nullability") { + val plan = r3.join(r4, NaturalJoin(Inner), None) + val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, cNotNull) + checkAnalysis(plan, expected) + } + + test("natural left join with no nullability") { + val plan = r3.join(r4, NaturalJoin(LeftOuter), None) + val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, c) + checkAnalysis(plan, expected) + } + + test("natural right join with no nullability") { + val plan = r3.join(r4, NaturalJoin(RightOuter), None) + val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, a, cNotNull) + checkAnalysis(plan, expected) + } + + test("natural full outer join with no nullability") { + val plan = r3.join(r4, NaturalJoin(FullOuter), None) + val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select( + Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c) + checkAnalysis(plan, expected) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 84203bbfef66a..f15b926bd27cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -474,6 +474,7 @@ class DataFrame private[sql]( val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) Alias(Coalesce(Seq(leftCol, rightCol)), col)() } + case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.") } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 79bfd4b44b70a..8ef7b61314a56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2075,4 +2075,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } + + test("natural join") { + val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") + val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") + withTempTable("nt1", "nt2") { + df1.registerTempTable("nt1") + df2.registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } + } } From c2c956bcd1a75fd01868ee9ad2939a6d3de52bc2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 3 Feb 2016 21:19:44 -0800 Subject: [PATCH 1798/2089] [ML][DOC] fix wrong api link in ml onevsrest minor fix for api link in ml onevsrest Author: Yuhao Yang Closes #11068 from hhbyyh/onevsrestDoc. --- docs/ml-classification-regression.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 8ffc997b4bf5a..9569a06472cbf 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -289,7 +289,7 @@ The example below demonstrates how to load the
        -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) for more details. {% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %}
        From d39087147ff1052b623cdba69ffbde28b266745f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 Feb 2016 23:17:51 -0800 Subject: [PATCH 1799/2089] [SPARK-13113] [CORE] Remove unnecessary bit operation when decoding page number JIRA: https://issues.apache.org/jira/browse/SPARK-13113 As we shift bits right, looks like the bitwise AND operation is unnecessary. Author: Liang-Chi Hsieh Closes #11002 from viirya/improve-decodepagenumber. --- .../main/java/org/apache/spark/memory/TaskMemoryManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d31eb449eb82e..d2a88864f7ac9 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -312,7 +312,7 @@ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) @VisibleForTesting public static int decodePageNumber(long pagePlusOffsetAddress) { - return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + return (int) (pagePlusOffsetAddress >>> OFFSET_BITS); } private static long decodeOffset(long pagePlusOffsetAddress) { From dee801adb78d6abd0abbf76b4dfa71aa296b4f0b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Feb 2016 23:43:48 -0800 Subject: [PATCH 1800/2089] [SPARK-12828][SQL] Natural join follow-up This is a small addendum to #10762 to make the code more robust again future changes. Author: Reynold Xin Closes #11070 from rxin/SPARK-12828-natural-join. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../spark/sql/catalyst/plans/joinTypes.scala | 2 ++ .../analysis/ResolveNaturalJoinSuite.scala | 6 +++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b30ed5928fd56..b59eb12419c45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1239,21 +1239,23 @@ class Analyzer( */ object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Should not skip unresolved nodes because natural join is always unresolved. case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => - // find common column names from both sides, should be treated like usingColumns + // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) - }).reduceLeftOption(And) + }).reduceOption(And) + // columns not in joinPairs val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) - // we should only keep unique columns(depends on joinType) for joinCols + + // the output list looks like: join keys, columns from left, columns from right val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) @@ -1261,13 +1263,14 @@ class Analyzer( rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput case FullOuter => // in full outer join, joinCols should be non-null if there is. - val joinedCols = joinPairs.map { - case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() - } - joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } + joinedCols ++ + lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) - case _ => + case Inner => rightKeys ++ lUniqueOutput ++ rUniqueOutput + case _ => + sys.error("Unsupported natural join type " + joinType) } // use Project to trim unnecessary fields Project(projectList, Join(left, right, joinType, newCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index b10f1e63a73e2..27a75326eba07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -62,5 +62,7 @@ case object LeftSemi extends JoinType { } case class NaturalJoin(tpe: JoinType) extends JoinType { + require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), + "Unsupported natural join type " + tpe) override def sql: String = "NATURAL " + tpe.sql } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index a6554fbc414b5..fcf4ac1967a53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest { lazy val aNotNull = a.notNull lazy val bNotNull = b.notNull lazy val cNotNull = c.notNull - lazy val r1 = LocalRelation(a, b) - lazy val r2 = LocalRelation(a, c) + lazy val r1 = LocalRelation(b, a) + lazy val r2 = LocalRelation(c, a) lazy val r3 = LocalRelation(aNotNull, bNotNull) - lazy val r4 = LocalRelation(bNotNull, cNotNull) + lazy val r4 = LocalRelation(cNotNull, bNotNull) test("natural inner join") { val plan = r1.join(r2, NaturalJoin(Inner), None) From 2eaeafe8a2aa31be9b230b8d53d3baccd32535b1 Mon Sep 17 00:00:00 2001 From: Charles Allen Date: Thu, 4 Feb 2016 10:27:25 -0800 Subject: [PATCH 1801/2089] [SPARK-12330][MESOS] Fix mesos coarse mode cleanup In the current implementation the mesos coarse scheduler does not wait for the mesos tasks to complete before ending the driver. This causes a race where the task has to finish cleaning up before the mesos driver terminates it with a SIGINT (and SIGKILL after 3 seconds if the SIGINT doesn't work). This PR causes the mesos coarse scheduler to wait for the mesos tasks to finish (with a timeout defined by `spark.mesos.coarse.shutdown.ms`) This PR also fixes a regression caused by [SPARK-10987] whereby submitting a shutdown causes a race between the local shutdown procedure and the notification of the scheduler driver disconnection. If the scheduler driver disconnection wins the race, the coarse executor incorrectly exits with status 1 (instead of the proper status 0) With this patch the mesos coarse scheduler terminates properly, the executors clean up, and the tasks are reported as `FINISHED` in the Mesos console (as opposed to `KILLED` in < 1.6 or `FAILED` in 1.6 and later) Author: Charles Allen Closes #10319 from drcrallen/SPARK-12330. --- .../CoarseGrainedExecutorBackend.scala | 8 +++- .../mesos/CoarseMesosSchedulerBackend.scala | 39 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 136cf4a84d387..3b5cb18da1b26 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.util.{Failure, Success} @@ -42,6 +43,7 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { + private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -102,19 +104,23 @@ private[spark] class CoarseGrainedExecutorBackend( } case StopExecutor => + stopping.set(true) logInfo("Driver commanded a shutdown") // Cannot shutdown here because an ack may need to be sent back to the caller. So send // a message to self to actually do the shutdown. self.send(Shutdown) case Shutdown => + stopping.set(true) executor.stop() stop() rpcEnv.shutdown() } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (driver.exists(_.address == remoteAddress)) { + if (stopping.get()) { + logInfo(s"Driver from $remoteAddress disconnected during shutdown") + } else if (driver.exists(_.address == remoteAddress)) { logError(s"Driver $remoteAddress disassociated! Shutting down.") System.exit(1) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2f095b86c69ef..722293bb7a53b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -19,11 +19,13 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{Collections, List => JList} +import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.base.Stopwatch import com.google.common.collect.HashBiMap import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} @@ -60,6 +62,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdown.ms", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdown.ms must be >= 0") + + // Synchronization protected by stateLock + private[this] var stopCalled: Boolean = false + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -245,6 +253,13 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { + if (stopCalled) { + logDebug("Ignoring offers during shutdown") + // Driver should simply return a stopped status on race + // condition between this.stop() and completing here + offers.asScala.map(_.getId).foreach(d.declineOffer) + return + } val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers.asScala) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -364,7 +379,29 @@ private[spark] class CoarseMesosSchedulerBackend( } override def stop() { - super.stop() + // Make sure we're not launching tasks during shutdown + stateLock.synchronized { + if (stopCalled) { + logWarning("Stop called multiple times, ignoring") + return + } + stopCalled = true + super.stop() + } + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. + // See SPARK-12330 + val stopwatch = new Stopwatch() + stopwatch.start() + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent + while (slaveIdsWithExecutors.nonEmpty && + stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { + Thread.sleep(100) + } + if (slaveIdsWithExecutors.nonEmpty) { + logWarning(s"Timed out waiting for ${slaveIdsWithExecutors.size} remaining executors " + + s"to terminate within $shutdownTimeoutMS ms. This may leave temporary files " + + "on the mesos nodes.") + } if (mesosDriver != null) { mesosDriver.stop() } From 62a7c28388539e6fc7d16ee3009f2cf79d8635bd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 4 Feb 2016 10:29:38 -0800 Subject: [PATCH 1802/2089] [SPARK-13164][CORE] Replace deprecated synchronized buffer in core Building with scala 2.11 results in the warning trait SynchronizedBuffer in package mutable is deprecated: Synchronization via traits is deprecated as it is inherently unreliable. Consider java.util.concurrent.ConcurrentLinkedQueue as an alternative. Investigation shows we are already using ConcurrentLinkedQueue in other locations so switch our uses of SynchronizedBuffer to ConcurrentLinkedQueue. Author: Holden Karau Closes #11059 from holdenk/SPARK-13164-replace-deprecated-synchronized-buffer-in-core. --- .../org/apache/spark/ContextCleaner.scala | 26 +++++++++---------- .../spark/deploy/client/AppClientSuite.scala | 20 +++++++------- .../org/apache/spark/rpc/RpcEnvSuite.scala | 23 ++++++++-------- .../apache/spark/util/EventLoopSuite.scala | 10 +++---- 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5a42299a0bf83..17014e4954f90 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,9 +18,9 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} @@ -57,13 +57,11 @@ private class CleanupTaskWeakReference( */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] - with SynchronizedBuffer[CleanupTaskWeakReference] + private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]() private val referenceQueue = new ReferenceQueue[AnyRef] - private val listeners = new ArrayBuffer[CleanerListener] - with SynchronizedBuffer[CleanerListener] + private val listeners = new ConcurrentLinkedQueue[CleanerListener]() private val cleaningThread = new Thread() { override def run() { keepCleaning() }} @@ -111,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener): Unit = { - listeners += listener + listeners.add(listener) } /** Start the cleaner. */ @@ -166,7 +164,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) } /** Keep cleaning RDD, shuffle, and broadcast state. */ @@ -179,7 +177,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { synchronized { reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get + referenceBuffer.remove(reference.get) task match { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) @@ -206,7 +204,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) - listeners.foreach(_.rddCleaned(rddId)) + listeners.asScala.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { case e: Exception => logError("Error cleaning RDD " + rddId, e) @@ -219,7 +217,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) blockManagerMaster.removeShuffle(shuffleId, blocking) - listeners.foreach(_.shuffleCleaned(shuffleId)) + listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { case e: Exception => logError("Error cleaning shuffle " + shuffleId, e) @@ -231,7 +229,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) - listeners.foreach(_.broadcastCleaned(broadcastId)) + listeners.asScala.foreach(_.broadcastCleaned(broadcastId)) logDebug(s"Cleaned broadcast $broadcastId") } catch { case e: Exception => logError("Error cleaning broadcast " + broadcastId, e) @@ -243,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning accumulator " + accId) Accumulators.remove(accId) - listeners.foreach(_.accumCleaned(accId)) + listeners.asScala.foreach(_.accumCleaned(accId)) logInfo("Cleaned accumulator " + accId) } catch { case e: Exception => logError("Error cleaning accumulator " + accId, e) @@ -258,7 +256,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning rdd checkpoint data " + rddId) ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) - listeners.foreach(_.checkpointCleaned(rddId)) + listeners.asScala.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index eb794b6739d5e..658779360b7a5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.deploy.client -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfterAll @@ -165,14 +167,14 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd /** Application Listener to collect events */ private class AppClientCollector extends AppClientListener with Logging { - val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val connectedIdList = new ConcurrentLinkedQueue[String]() @volatile var disconnectedCount: Int = 0 - val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] - val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String] - val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val deadReasonList = new ConcurrentLinkedQueue[String]() + val execAddedList = new ConcurrentLinkedQueue[String]() + val execRemovedList = new ConcurrentLinkedQueue[String]() def connected(id: String): Unit = { - connectedIdList += id + connectedIdList.add(id) } def disconnected(): Unit = { @@ -182,7 +184,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } def dead(reason: String): Unit = { - deadReasonList += reason + deadReasonList.add(reason) } def executorAdded( @@ -191,11 +193,11 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd hostPort: String, cores: Int, memory: Int): Unit = { - execAddedList += id + execAddedList.add(id) } def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { - execRemovedList += id + execRemovedList.add(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6f4eda8b47dde..22048003882dd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.rpc import java.io.{File, NotSerializableException} import java.nio.charset.StandardCharsets.UTF_8 import java.util.UUID -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeoutException, TimeUnit} import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -490,30 +491,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. - * @return the [[RpcEndpointRef]] and an `Seq` that contains network events. + * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. */ private def setupNetworkEndpoint( _env: RpcEnv, - name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = { + val events = new ConcurrentLinkedQueue[(Any, Any)] val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => - case m => events += "receive" -> m + case m => events.add("receive" -> m) } override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress + events.add("onConnected" -> remoteAddress) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + events.add("onDisconnected" -> remoteAddress) } override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress + events.add("onNetworkError" -> remoteAddress) } }) @@ -560,7 +561,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) } clientEnv.shutdown() @@ -568,8 +569,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { eventually(timeout(5 seconds), interval(5 millis)) { // We don't know the exact client address but at least we can verify the message type - assert(events.map(_._1).contains("onConnected")) - assert(events.map(_._1).contains("onDisconnected")) + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onDisconnected")) } } finally { clientEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index b207d497f33c2..6f7dddd4f760a 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch} -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { - val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] + val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { override def onReceive(event: Int): Unit = { - buffer += event + buffer.add(event) } override def onError(e: Throwable): Unit = {} @@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts { eventLoop.start() (1 to 100).foreach(eventLoop.post) eventually(timeout(5 seconds), interval(5 millis)) { - assert((1 to 100) === buffer.toSeq) + assert((1 to 100) === buffer.asScala.toSeq) } eventLoop.stop() } From 4120bcbaffe92da40486b469334119ed12199f4f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 10:32:16 -0800 Subject: [PATCH 1803/2089] [SPARK-13162] Standalone mode does not respect initial executors Currently the Master would always set an application's initial executor limit to infinity. If the user specified `spark.dynamicAllocation.initialExecutors`, the config would not take effect. This is similar to #11047 but for standalone mode. Author: Andrew Or Closes #11054 from andrewor14/standalone-da-initial. --- .../spark/ExecutorAllocationManager.scala | 2 ++ .../spark/deploy/ApplicationDescription.scala | 3 +++ .../spark/deploy/master/ApplicationInfo.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 16 ++++++++++++---- .../StandaloneDynamicAllocationSuite.scala | 17 ++++++++++++++++- 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 3431fc13dcb4e..db143d7341ce4 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -231,6 +231,8 @@ private[spark] class ExecutorAllocationManager( } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) + + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 78bbd5c03f4a6..c5c5c60923f4e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -29,6 +29,9 @@ private[spark] case class ApplicationDescription( // short name of compression codec used when writing event logs, if any (e.g. lzf) eventLogCodec: Option[String] = None, coresPerExecutor: Option[Int] = None, + // number of executors this application wants to start with, + // only used if dynamic allocation is enabled + initialExecutorLimit: Option[Int] = None, user: String = System.getProperty("user.name", "")) { override def toString: String = "ApplicationDescription(" + name + ")" diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 7e2cf956c7253..4ffb5283e99a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -65,7 +65,7 @@ private[spark] class ApplicationInfo( appSource = new ApplicationSource(this) nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] - executorLimit = Integer.MAX_VALUE + executorLimit = desc.initialExecutorLimit.getOrElse(Integer.MAX_VALUE) appUIUrlAtHistoryServer = None } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 16f33163789ab..d209645610c12 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,11 +19,11 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -89,8 +89,16 @@ private[spark] class SparkDeploySchedulerBackend( args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, - command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) + // If we're using dynamic allocation, set our initial executor limit to 0 for now. + // ExecutorAllocationManager will send the real initial limit to the Master later. + val initialExecutorLimit = + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index fdada0777f9a9..b7ff5c9e8c0d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -447,7 +447,23 @@ class StandaloneDynamicAllocationSuite apps = getApplications() // kill executor successfully assert(apps.head.executors.size === 1) + } + test("initial executor limit") { + val initialExecutorLimit = 1 + val myConf = appConf + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString) + sc = new SparkContext(myConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === initialExecutorLimit) + assert(apps.head.getExecutorLimit === initialExecutorLimit) + } } // =============================== @@ -540,7 +556,6 @@ class StandaloneDynamicAllocationSuite val missingExecutors = masterExecutors.toSet.diff(driverExecutors.toSet).toSeq.sorted missingExecutors.foreach { id => // Fake an executor registration so the driver knows about us - val port = System.currentTimeMillis % 65536 val endpointRef = mock(classOf[RpcEndpointRef]) val mockAddress = mock(classOf[RpcAddress]) when(endpointRef.address).thenReturn(mockAddress) From 15205da817b24ef0e349ec24d84034dc30b501f8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 10:34:43 -0800 Subject: [PATCH 1804/2089] [SPARK-13053][TEST] Unignore tests in InternalAccumulatorSuite These were ignored because they are incorrectly written; they don't actually trigger stage retries, which is what the tests are testing. These tests are now rewritten to induce stage retries through fetch failures. Note: there were 2 tests before and now there's only 1. What happened? It turns out that the case where we only resubmit a subset of of the original missing partitions is very difficult to simulate in tests without potentially introducing flakiness. This is because the `DAGScheduler` removes all map outputs associated with a given executor when this happens, and we will need multiple executors to trigger this case, and sometimes the scheduler still removes map outputs from all executors. Author: Andrew Or Closes #10969 from andrewor14/unignore-accum-test. --- .../org/apache/spark/AccumulatorSuite.scala | 52 +++++-- .../spark/InternalAccumulatorSuite.scala | 128 +++++++++--------- 2 files changed, 102 insertions(+), 78 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index b8f2b96d7088d..e0fdd45973858 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -323,35 +323,60 @@ private[spark] object AccumulatorSuite { * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. */ private class SaveInfoListener extends SparkListener { - private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] - private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] - private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + type StageId = Int + type StageAttemptId = Int - // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated + private val completedStageInfos = new ArrayBuffer[StageInfo] + private val completedTaskInfos = + new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]] + + // Callback to call when a job completes. Parameter is job ID. @GuardedBy("this") + private var jobCompletionCallback: () => Unit = null + private var calledJobCompletionCallback: Boolean = false private var exception: Throwable = null def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq - def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq + def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] = + completedTaskInfos.get((stageId, stageAttemptId)).getOrElse(Seq.empty[TaskInfo]) - /** Register a callback to be called on job end. */ - def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { - jobCompletionCallback = callback + /** + * If `jobCompletionCallback` is set, block until the next call has finished. + * If the callback failed with an exception, throw it. + */ + def awaitNextJobCompletion(): Unit = synchronized { + if (jobCompletionCallback != null) { + while (!calledJobCompletionCallback) { + wait() + } + calledJobCompletionCallback = false + if (exception != null) { + exception = null + throw exception + } + } } - /** Throw a stored exception, if any. */ - def maybeThrowException(): Unit = synchronized { - if (exception != null) { throw exception } + /** + * Register a callback to be called on job end. + * A call to this should be followed by [[awaitNextJobCompletion]]. + */ + def registerJobCompletionCallback(callback: () => Unit): Unit = { + jobCompletionCallback = callback } override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { if (jobCompletionCallback != null) { try { - jobCompletionCallback(jobEnd.jobId) + jobCompletionCallback() } catch { // Store any exception thrown here so we can throw them later in the main thread. // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test. case NonFatal(e) => exception = e + } finally { + calledJobCompletionCallback = true + notify() } } } @@ -361,7 +386,8 @@ private class SaveInfoListener extends SparkListener { } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - completedTaskInfos += taskEnd.taskInfo + completedTaskInfos.getOrElseUpdate( + (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 630b46f828df7..44a16e26f4935 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockStatus} @@ -160,7 +161,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { iter } // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => + listener.registerJobCompletionCallback { () => val stageInfos = listener.getCompletedStageInfos val taskInfos = listener.getCompletedTaskInfos assert(stageInfos.size === 1) @@ -179,6 +180,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) } rdd.count() + listener.awaitNextJobCompletion() } test("internal accumulators in multiple stages") { @@ -205,7 +207,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { iter } // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => + listener.registerJobCompletionCallback { () => // We ran 3 stages, and the accumulator values should be distinct val stageInfos = listener.getCompletedStageInfos assert(stageInfos.size === 3) @@ -220,13 +222,66 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { rdd.count() } - // TODO: these two tests are incorrect; they don't actually trigger stage retries. - ignore("internal accumulators in fully resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks - } + test("internal accumulators in resubmitted stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + + // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with + // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt. + // This should retry both stages in the scheduler. Note that we only want to fail the + // first stage attempt because we want the stage to eventually succeed. + val x = sc.parallelize(1 to 100, numPartitions) + .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter } + .groupBy(identity) + val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId + val rdd = x.mapPartitionsWithIndex { case (i, iter) => + // Fail the first stage attempt. Here we use the task attempt ID to determine this. + // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt + // ID that's < 2 * numPartitions belongs to the first attempt of this stage. + val taskContext = TaskContext.get() + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + if (isFirstStageAttempt) { + throw new FetchFailedException( + SparkEnv.get.blockManager.blockManagerId, + sid, + taskContext.partitionId(), + taskContext.partitionId(), + "simulated fetch failure") + } else { + iter + } + } - ignore("internal accumulators in partially resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { () => + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried + val mapStageId = stageInfos.head.stageId + val mapStageInfo1stAttempt = stageInfos.head + val mapStageInfo2ndAttempt = { + stageInfos.tail.find(_.stageId == mapStageId).getOrElse { + fail("expected two attempts of the same shuffle map stage.") + } + } + val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values) + val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values) + // Both map stages should have succeeded, since the fetch failure happened in the + // result stage, not the map stage. This means we should get the accumulator updates + // from all partitions. + assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions) + assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions) + // Because this test resubmitted the map stage with all missing partitions, we should have + // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is + // the case by comparing the accumulator IDs between the two attempts. + // Note: it would be good to also test the case where the map stage is resubmitted where + // only a subset of the original partitions are missing. However, this scenario is very + // difficult to construct without potentially introducing flakiness. + assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id) + } + rdd.count() + listener.awaitNextJobCompletion() } test("internal accumulators are registered for cleanups") { @@ -257,63 +312,6 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } } - /** - * Test whether internal accumulators are merged properly if some tasks fail. - * TODO: make this actually retry the stage. - */ - private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { - val listener = new SaveInfoListener - val numPartitions = 10 - val numFailedPartitions = (0 until numPartitions).count(failCondition) - // This says use 1 core and retry tasks up to 2 times - sc = new SparkContext("local[1, 2]", "test") - sc.addSparkListener(listener) - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => - val taskContext = TaskContext.get() - taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1 - // Fail the first attempts of a subset of the tasks - if (failCondition(i) && taskContext.attemptNumber() == 0) { - throw new Exception("Failing a task intentionally.") - } - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findTestAccum(stageInfos.head.accumulables.values) - // If all partitions failed, then we would resubmit the whole stage again and create a - // fresh set of internal accumulators. Otherwise, these internal accumulators do count - // failed values, so we must include the failed values. - val expectedAccumValue = - if (numPartitions == numFailedPartitions) { - numPartitions - } else { - numPartitions + numFailedPartitions - } - assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findTestAccum(taskInfo.accumulables) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.asInstanceOf[Long] === 1L) - assert(taskAccum.value.isDefined) - Some(taskAccum.value.get.asInstanceOf[Long]) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None - } - } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) - } - rdd.count() - listener.maybeThrowException() - } - /** * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. */ From 085f510ae554e2739a38ee0bc7210c4ece902f3f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 11:07:06 -0800 Subject: [PATCH 1805/2089] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #7971 (requested by yhuai) Closes #8539 (requested by srowen) Closes #8746 (requested by yhuai) Closes #9288 (requested by andrewor14) Closes #9321 (requested by andrewor14) Closes #9935 (requested by JoshRosen) Closes #10442 (requested by andrewor14) Closes #10585 (requested by srowen) Closes #10785 (requested by srowen) Closes #10832 (requested by andrewor14) Closes #10941 (requested by marmbrus) Closes #11024 (requested by andrewor14) From 33212cb9a13a6012b4c19ccfc0fb3db75de304da Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 4 Feb 2016 11:08:50 -0800 Subject: [PATCH 1806/2089] [SPARK-13168][SQL] Collapse adjacent repartition operators Spark SQL should collapse adjacent `Repartition` operators and only keep the last one. Author: Josh Rosen Closes #11064 from JoshRosen/collapse-repartition. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 16 ++++++++++++++-- ...ingSuite.scala => CollapseProjectSuite.scala} | 4 ++-- .../catalyst/optimizer/FilterPushdownSuite.scala | 2 +- .../sql/catalyst/optimizer/JoinOrderSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 +++++++++++++-- .../org/apache/spark/sql/hive/SQLBuilder.scala | 4 ++-- 6 files changed, 33 insertions(+), 10 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ProjectCollapsingSuite.scala => CollapseProjectSuite.scala} (96%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ecee75048248..a1ac93073916c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,7 +68,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughAggregate, ColumnPruning, // Operator combine - ProjectCollapsing, + CollapseRepartition, + CollapseProject, CombineFilters, CombineLimits, CombineUnions, @@ -322,7 +323,7 @@ object ColumnPruning extends Rule[LogicalPlan] { * Combines two adjacent [[Project]] operators into one and perform alias substitution, * merging the expressions into one single expression. */ -object ProjectCollapsing extends Rule[LogicalPlan] { +object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p @ Project(projectList1, Project(projectList2, child)) => @@ -390,6 +391,16 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +/** + * Combines adjacent [[Repartition]] operators by keeping only the last one. + */ +object CollapseRepartition extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) => + Repartition(numPartitions, shuffle, child) + } +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given @@ -857,6 +868,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. + * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala similarity index 96% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 85b6530481b03..f5fd5ca6beb15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -25,11 +25,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class ProjectCollapsingSuite extends PlanTest { +class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: - Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + Batch("CollapseProject", Once, CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index f9f3bd55aa578..b49ca928b6292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -42,7 +42,7 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala index 9b1e16c727647..858a0d8fde3ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -43,7 +43,7 @@ class JoinOrderSuite extends PlanTest { PushPredicateThroughGenerate, PushPredicateThroughAggregate, ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 8fca5e2167d04..adaeb513bc1b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,8 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -223,6 +222,18 @@ class PlannerSuite extends SharedSQLContext { } } + test("collapse adjacent repartitions") { + val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) + def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length + assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1) + doubleRepartitioned.queryExecution.optimizedPlan match { + case r: Repartition => + assert(r.numPartitions === 5) + assert(r.shuffle === false) + } + } + // --- Unit tests of EnsureRequirements --------------------------------------------------------- // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1654594538366..fc5725d6915ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -188,7 +188,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. - ProjectCollapsing, + CollapseProject, // Used to handle other auxiliary `Project`s added by analyzer (e.g. // `ResolveAggregateFunctions` rule) From c756bda477f458ba4aad7fdb2026263507e0ad9b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 12:04:54 -0800 Subject: [PATCH 1807/2089] [SPARK-12330][MESOS][HOTFIX] Rename timeout config The config already describes time and accepts a general format that is not restricted to ms. This commit renames the internal config to use a format that's consistent in Spark. --- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 722293bb7a53b..8ed7553258e7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -62,8 +62,8 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt - private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdown.ms", "10s") - .ensuring(_ >= 0, "spark.mesos.coarse.shutdown.ms must be >= 0") + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false From bd38dd6f75c4af0f8f32bb21a82da53fffa5e825 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 12:20:18 -0800 Subject: [PATCH 1808/2089] [SPARK-13079][SQL] InMemoryCatalog follow-ups This patch incorporates review feedback from #11069, which is already merged. Author: Andrew Or Closes #11080 from andrewor14/catalog-follow-ups. --- .../spark/sql/catalyst/catalog/interface.scala | 15 +++++++++++++++ .../sql/catalyst/catalog/CatalogTestCases.scala | 12 +++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b4d7dd2f4e31c..56aaa6bc6c2e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -39,6 +39,9 @@ abstract class Catalog { def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + /** + * Alter an existing database. This operation does not support renaming. + */ def alterDatabase(db: String, dbDefinition: Database): Unit def getDatabase(db: String): Database @@ -57,6 +60,9 @@ abstract class Catalog { def renameTable(db: String, oldName: String, newName: String): Unit + /** + * Alter an existing table. This operation does not support renaming. + */ def alterTable(db: String, table: String, tableDefinition: Table): Unit def getTable(db: String, table: String): Table @@ -81,6 +87,9 @@ abstract class Catalog { parts: Seq[PartitionSpec], ignoreIfNotExists: Boolean): Unit + /** + * Alter an existing table partition and optionally override its spec. + */ def alterPartition( db: String, table: String, @@ -100,6 +109,9 @@ abstract class Catalog { def dropFunction(db: String, funcName: String): Unit + /** + * Alter an existing function and optionally override its name. + */ def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit def getFunction(db: String, funcName: String): Function @@ -194,5 +206,8 @@ case class Database( object Catalog { + /** + * Specifications of a table partition indexed by column name. + */ type PartitionSpec = Map[String, String] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 0d8434323fcbb..45c5ceecb0eef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -27,10 +27,10 @@ import org.apache.spark.sql.AnalysisException * Implementations of the [[Catalog]] interface can create test suites by extending this. */ abstract class CatalogTestCases extends SparkFunSuite { - private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map.empty[String, String]) - private val part1 = TablePartition(Map[String, String]("a" -> "1"), storageFormat) - private val part2 = TablePartition(Map[String, String]("b" -> "2"), storageFormat) - private val part3 = TablePartition(Map[String, String]("c" -> "3"), storageFormat) + private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map()) + private val part1 = TablePartition(Map("a" -> "1"), storageFormat) + private val part2 = TablePartition(Map("b" -> "2"), storageFormat) + private val part3 = TablePartition(Map("c" -> "3"), storageFormat) private val funcClass = "org.apache.spark.myFunc" protected def newEmptyCatalog(): Catalog @@ -42,6 +42,8 @@ abstract class CatalogTestCases extends SparkFunSuite { * db2 * - tbl1 * - tbl2 + * - part1 + * - part2 * - func1 */ private def newBasicCatalog(): Catalog = { @@ -50,8 +52,8 @@ abstract class CatalogTestCases extends SparkFunSuite { catalog.createDatabase(newDb("db2"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) - catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) catalog } From 8e2f296306131e6c7c2f06d6672995d3ff8ab021 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 4 Feb 2016 12:43:16 -0800 Subject: [PATCH 1809/2089] [SPARK-13195][STREAMING] Fix NoSuchElementException when a state is not set but timeoutThreshold is defined Check the state Existence before calling get. Author: Shixiong Zhu Closes #11081 from zsxwing/SPARK-13195. --- .../org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 3 ++- .../apache/spark/streaming/rdd/MapWithStateRDDSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 1d2244eaf22b3..6ab1956bed900 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -57,7 +57,8 @@ private[streaming] object MapWithStateRDDRecord { val returned = mappingFunction(batchTime, key, Some(value), wrappedState) if (wrappedState.isRemoved) { newStateMap.remove(key) - } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { + } else if (wrappedState.isUpdated + || (wrappedState.exists && timeoutThresholdTime.isDefined)) { newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) } mappedData ++= returned diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index 5b13fd6ad611a..e8c814ba7184b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -190,6 +190,11 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + // If a state is not set but timeoutThreshold is defined, we should ignore this state. + // Previously it threw NoSuchElementException (SPARK-13195). + assertRecordUpdate(initStates = Seq(), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil) } test("states generated by MapWithStateRDD") { From 7a4b37f02cffd6d971c07716688a7cb6cee26c8b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 4 Feb 2016 12:47:32 -0800 Subject: [PATCH 1810/2089] [HOTFIX] Fix style violation caused by c756bda --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 8ed7553258e7b..b82196c86b637 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -62,8 +62,9 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt - private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") - .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") + private[this] val shutdownTimeoutMS = + conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false From 6dbfc40776514c3a5667161ebe7829f4cc9c7529 Mon Sep 17 00:00:00 2001 From: Raafat Akkad Date: Thu, 4 Feb 2016 16:09:31 -0800 Subject: [PATCH 1811/2089] [SPARK-13052] waitingApps metric doesn't show the number of apps currently in the WAITING state Author: Raafat Akkad Closes #10959 from RaafatAkkad/master. --- core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +- .../scala/org/apache/spark/deploy/master/MasterSource.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 202a1b787c21b..0f11f680b3914 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -74,7 +74,7 @@ private[deploy] class Master( val workers = new HashSet[WorkerInfo] val idToApp = new HashMap[String, ApplicationInfo] - val waitingApps = new ArrayBuffer[ApplicationInfo] + private val waitingApps = new ArrayBuffer[ApplicationInfo] val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 66a9ff38678c6..39b2647a900f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -42,6 +42,6 @@ private[spark] class MasterSource(val master: Master) extends Source { // Gauge for waiting application numbers in cluster metricRegistry.register(MetricRegistry.name("waitingApps"), new Gauge[Int] { - override def getValue: Int = master.waitingApps.size + override def getValue: Int = master.apps.filter(_.state == ApplicationState.WAITING).size }) } From e3c75c6398b1241500343ff237e9bcf78b5396f9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Feb 2016 18:37:58 -0800 Subject: [PATCH 1812/2089] [SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Bucketed Tables) JIRA: https://issues.apache.org/jira/browse/SPARK-12850 This PR is to support bucket pruning when the predicates are `EqualTo`, `EqualNullSafe`, `IsNull`, `In`, and `InSet`. Like HIVE, in this PR, the bucket pruning works when the bucketing key has one and only one column. So far, I do not find a way to verify how many buckets are actually scanned. However, I did verify it when doing the debug. Could you provide a suggestion how to do it properly? Thank you! cloud-fan yhuai rxin marmbrus BTW, we can add more cases to support complex predicate including `Or` and `And`. Please let me know if I should do it in this PR. Maybe we also need to add test cases to verify if bucket pruning works well for each data type. Author: gatorsmile Closes #10942 from gatorsmile/pruningBuckets. --- .../datasources/DataSourceStrategy.scala | 78 ++++++++- .../apache/spark/sql/sources/interfaces.scala | 15 +- .../spark/sql/sources/BucketedReadSuite.scala | 162 +++++++++++++++++- 3 files changed, 245 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index da9320ffb61c3..c24967abeb33e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.BitSet /** * A Strategy for planning scans over data sources defined using the sources API. @@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (partitionAndNormalColumnAttrs ++ projects).toSeq } + // Prune the buckets based on the pushed filters that do not contain partitioning key + // since the bucketing key is not allowed to use the columns in partitioning key + val bucketSet = getBuckets(pushedFilters, t.getBucketSpec) + val scan = buildPartitionedTableScan( l, partitionAndNormalColumnProjs, pushedFilters, + bucketSet, t.partitionSpec.partitionColumns, selectedPartitions) @@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) + // Prune the buckets based on the filters + val bucketSet = getBuckets(filters, t.getBucketSpec) pruneFilterProject( l, projects, filters, - (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil + (a, f) => + t.buildInternalScan(a.map(_.name).toArray, f, bucketSet, t.paths, confBroadcast)) :: Nil case l @ LogicalRelation(baseRelation: TableScan, _, _) => execution.PhysicalRDD.createFromDataSource( @@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { logicalRelation: LogicalRelation, projections: Seq[NamedExpression], filters: Seq[Expression], + buckets: Option[BitSet], partitionColumns: StructType, partitions: Array[Partition]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] @@ -174,7 +185,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. val dataRows = relation.buildInternalScan( - requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast) + requiredDataColumns.map(_.name).toArray, filters, buckets, Array(dir), confBroadcast) // Merges data values with partition values. mergeWithPartitionValues( @@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } } + // Get the bucket ID based on the bucketing values. + // Restriction: Bucket pruning works iff the bucketing column has one and only one column. + def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType)) + mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + val bucketIdGeneration = UnsafeProjection.create( + HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + + bucketIdGeneration(mutableRow).getInt(0) + } + + // Get the bucket BitSet by reading the filters that only contains bucketing keys. + // Note: When the returned BitSet is None, no pruning is possible. + // Restriction: Bucket pruning works iff the bucketing column has one and only one column. + private def getBuckets( + filters: Seq[Expression], + bucketSpec: Option[BucketSpec]): Option[BitSet] = { + + if (bucketSpec.isEmpty || + bucketSpec.get.numBuckets == 1 || + bucketSpec.get.bucketColumnNames.length != 1) { + // None means all the buckets need to be scanned + return None + } + + // Just get the first because bucketing pruning only works when the column has one column + val bucketColumnName = bucketSpec.get.bucketColumnNames.head + val numBuckets = bucketSpec.get.numBuckets + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.clear() + + filters.foreach { + case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, v)) + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + val hSet = list.map(e => e.eval(EmptyRow)) + hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e))) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + matchedBuckets.set(getBucketId(a, numBuckets, null)) + case _ => + } + + logInfo { + val selected = matchedBuckets.cardinality() + val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100 + s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions." + } + + // None means all the buckets need to be scanned + if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) + } + protected def prunePartitions( predicates: Seq[Expression], partitionSpec: PartitionSpec): Seq[Partition] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 299fc6efbb046..737be7dfd12f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection.BitSet /** * ::DeveloperApi:: @@ -722,6 +723,7 @@ abstract class HadoopFsRelation private[sql]( final private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], + bucketSet: Option[BitSet], inputPaths: Array[String], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val inputStatuses = inputPaths.flatMap { input => @@ -743,9 +745,16 @@ abstract class HadoopFsRelation private[sql]( // id from file name. Then read these files into a RDD(use one-partition empty RDD for empty // bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result. val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId => - groupedBucketFiles.get(bucketId).map { inputStatuses => - buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1) - }.getOrElse(sqlContext.emptyResult) + // If the current bucketId is not set in the bucket bitSet, skip scanning it. + if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)){ + sqlContext.emptyResult + } else { + // When all the buckets need a scan (i.e., bucketSet is equal to None) + // or when the current bucket need a scan (i.e., the bit of bucketId is set to true) + groupedBucketFiles.get(bucketId).map { inputStatuses => + buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1) + }.getOrElse(sqlContext.emptyResult) + } } new UnionRDD(sqlContext.sparkContext, perBucketRows) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 150d0c748631e..9ba645626fe72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -19,22 +19,28 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.Exchange -import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.execution.{Exchange, PhysicalRDD} +import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy} import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private val nullDF = (for { + i <- 0 to 50 + s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") + } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + test("read bucketed data") { - val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") withTable("bucketed_table") { df.write .format("parquet") @@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } + // To verify if the bucket pruning works, this function checks two conditions: + // 1) Check if the pruned buckets (before filtering) are empty. + // 2) Verify the final result is the same as the expected one + private def checkPrunedAnswers( + bucketSpec: BucketSpec, + bucketValues: Seq[Integer], + filterCondition: Column, + originalDataFrame: DataFrame): Unit = { + + val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k") + val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec + // Limit: bucket pruning only works when the bucket column has one and only one column + assert(bucketColumnNames.length == 1) + val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) + val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + } + + // Filter could hide the bug in bucket pruning. Thus, skipping all the filters + val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan + .find(_.isInstanceOf[PhysicalRDD]) + assert(rdd.isDefined) + + val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty) + } + // checking if all the pruned buckets are empty + assert(checkedResult.collect().forall(_ == true)) + + checkAnswer( + bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), + originalDataFrame.filter(filterCondition).orderBy("i", "j", "k")) + } + + test("read partitioning bucketed tables with bucket pruning filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + // Case 1: EqualTo + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j, + df) + + // Case 2: EqualNullSafe + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" <=> j, + df) + + // Case 3: In + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), + df) + } + } + } + + test("read non-partitioning bucketed tables with bucket pruning filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j, + df) + } + } + } + + test("read partitioning bucketed tables having null in bucketing key") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + nullDF.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + // Case 1: isNull + checkPrunedAnswers( + bucketSpec, + bucketValues = null :: Nil, + filterCondition = $"j".isNull, + nullDF) + + // Case 2: <=> null + checkPrunedAnswers( + bucketSpec, + bucketValues = null :: Nil, + filterCondition = $"j" <=> null, + nullDF) + } + } + + test("read partitioning bucketed tables having composite filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j && $"k" > $"j", + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j && $"i" > j % 5, + df) + } + } + } + private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") From 352102ed0b7be8c335553d7e0389fd7ce83f5fbf Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Thu, 4 Feb 2016 22:22:41 -0800 Subject: [PATCH 1813/2089] [SPARK-13208][CORE] Replace use of Pairs with Tuple2s Another trivial deprecation fix for Scala 2.11 Author: Jakob Odersky Closes #11089 from jodersky/SPARK-13208. --- .../scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 2 +- .../scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 ++-- .../execution/datasources/parquet/ParquetSchemaSuite.scala | 2 +- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 37ae007f880c2..13e18a56c8fd8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -230,7 +230,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { + def histogram(bucketCount: Int): (Array[scala.Double], Array[Long]) = { val result = srdd.histogram(bucketCount) (result._1, result._2) } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index c07f346bbafd5..bd61d04d42f05 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -103,7 +103,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope { + def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope { // Scala's built-in range has issues. See #SI-8782 def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { val span = max - min @@ -112,7 +112,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, - Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => + Double.PositiveInfinity)((e: Double, x: (Double, Double)) => (x._1.max(e), x._2.min(e)))) }.reduce { (maxmin1, maxmin2) => (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index d860651d421f8..90e3d50714ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -261,7 +261,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { int96AsTimestamp = true, writeLegacyParquetFormat = true) - testSchemaInference[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[(Int, String)]]( "struct", """ |message root { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 9632d27a2ffce..1337a25eb26a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -770,14 +770,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) - .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + .zipWithIndex.map {case ((value, attr), key) => HavingRow(key, value, attr)} TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() - .map(x => Pair(x.getString(0), x.getInt(1))) + .map(x => (x.getString(0), x.getInt(1))) - assert(results === Array(Pair("foo", 4))) + assert(results === Array(("foo", 4))) TestHive.reset() } From 82d84ff2dd3efb3bda20b529f09a4022586fb722 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Feb 2016 22:43:44 -0800 Subject: [PATCH 1814/2089] [SPARK-13187][SQL] Add boolean/long/double options in DataFrameReader/Writer This patch adds option function for boolean, long, and double types. This makes it slightly easier for Spark users to specify options without turning them into strings. Using the JSON data source as an example. Before this patch: ```scala sqlContext.read.option("primitivesAsString", "true").json("/path/to/json") ``` After this patch: Before this patch: ```scala sqlContext.read.option("primitivesAsString", true).json("/path/to/json") ``` Author: Reynold Xin Closes #11072 from rxin/SPARK-13187. --- .../apache/spark/sql/DataFrameReader.scala | 21 ++++++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 21 ++++++++++++++++ .../sql/streaming/DataStreamReaderSuite.scala | 25 +++++++++++++++++++ 3 files changed, 67 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a58643a5ba154..962fdadf1431d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -78,6 +78,27 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { this } + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameReader = option(key, value.toString) + /** * (Scala-specific) Adds input options for the underlying data source. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 80447fefe1f2a..8060198968988 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -95,6 +95,27 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameWriter = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameWriter = option(key, value.toString) + /** * (Scala-specific) Adds output options for the underlying data source. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala index b36b41cac9b4f..95a17f338d04d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala @@ -162,4 +162,29 @@ class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { assert(LastOptions.parameters("path") == "/test") } + test("test different data types for options") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .stream("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + + LastOptions.parameters = null + df.write + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .stream("/test") + .stop() + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + } } From 7b73f1719cff233645c7850a5dbc8ed2dc9c9a58 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 5 Feb 2016 13:44:34 -0800 Subject: [PATCH 1815/2089] [SPARK-13166][SQL] Rename DataStreamReaderWriterSuite to DataFrameReaderWriterSuite A follow up PR for #11062 because it didn't rename the test suite. Author: Shixiong Zhu Closes #11096 from zsxwing/rename. --- ...StreamReaderSuite.scala => DataFrameReaderWriterSuite.scala} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{DataStreamReaderSuite.scala => DataFrameReaderWriterSuite.scala} (98%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 95a17f338d04d..36212e4395985 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataStreamReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -56,7 +56,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataStreamReaderWriterSuite extends StreamTest with SharedSQLContext { +class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext { import testImplicits._ test("resolve default source") { From 1ed354a5362967d904e9513e5a1618676c9c67a6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Feb 2016 14:34:12 -0800 Subject: [PATCH 1816/2089] [SPARK-12939][SQL] migrate encoder resolution logic to Analyzer https://issues.apache.org/jira/browse/SPARK-12939 Now we will catch `ObjectOperator` in `Analyzer` and resolve the `fromRowExpression/deserializer` inside it. Also update the `MapGroups` and `CoGroup` to pass in `dataAttributes`, so that we can correctly resolve value deserializer(the `child.output` contains both groupking key and values, which may mess things up if they have same-name attribtues). End-to-end tests are added. follow-ups: * remove encoders from typed aggregate expression. * completely remove resolve/bind in `ExpressionEncoder` Author: Wenchen Fan Closes #10852 from cloud-fan/bug. --- .../sql/catalyst/analysis/Analyzer.scala | 82 +++++++++++++------ .../catalyst/encoders/ExpressionEncoder.scala | 50 +++++------ .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 9 +- .../sql/catalyst/plans/logical/object.scala | 68 +++++++++++---- .../encoders/EncoderResolutionSuite.scala | 8 +- .../encoders/ExpressionEncoderSuite.scala | 5 +- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../org/apache/spark/sql/GroupedDataset.scala | 3 + .../spark/sql/execution/SparkStrategies.scala | 9 +- .../apache/spark/sql/execution/objects.scala | 31 +++---- .../org/apache/spark/sql/DatasetSuite.scala | 64 +++++++++++++-- 12 files changed, 230 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b59eb12419c45..cb228cf52b433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ @@ -457,25 +458,34 @@ class Analyzer( // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => - val newOrdering = resolveSortOrders(ordering, child, throws = false) + val newOrdering = + ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. case g @ Generate(generator, join, outer, qualifier, output, child) if child.resolved && !generator.resolved => - val newG = generator transformUp { - case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } - case UnresolvedExtractValue(child, fieldExpr) => - ExtractValue(child, fieldExpr, resolver) - } + val newG = resolveExpression(generator, child, throws = true) if (newG.fastEquals(generator)) { g } else { Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } + // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator + // should be resolved by their corresponding attributes instead of children's output. + case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) => + val deserializerToAttributes = o.deserializers.map { + case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes + }.toMap + + o.transformExpressions { + case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes => + resolveDeserializer(expr, attributes) + }.getOrElse(expr) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { @@ -490,6 +500,32 @@ class Analyzer( } } + private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = { + exprs.exists { expr => + !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined + } + } + + def resolveDeserializer( + deserializer: Expression, + attributes: Seq[Attribute]): Expression = { + val unbound = deserializer transform { + case b: BoundReference => attributes(b.ordinal) + } + + resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + n.copy(outerPointer = Some(Literal.fromObject(outer))) + } + } + def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { case a: Alias => Alias(a.child, a.name)() @@ -508,23 +544,20 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { - ordering.map { order => - // Resolve SortOrder in one round. - // If throws == false or the desired attribute doesn't exist - // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. - // Else, throw exception. - try { - val newOrder = order transformUp { - case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).getOrElse(u) - case UnresolvedExtractValue(child, fieldName) if child.resolved => - ExtractValue(child, fieldName, resolver) - } - newOrder.asInstanceOf[SortOrder] - } catch { - case a: AnalysisException if !throws => order + private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = { + // Resolve expression in one round. + // If throws == false or the desired attribute doesn't exist + // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. + // Else, throw exception. + try { + expr transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) } + } catch { + case a: AnalysisException if !throws => expr } } @@ -619,7 +652,8 @@ class Analyzer( ordering: Seq[SortOrder], plan: LogicalPlan, child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, child, throws = false) + val newOrdering = + ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 64832dc114e67..58f6d0eb9e929 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -50,7 +50,7 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(typeTag[T].tpe) val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] @@ -257,12 +257,10 @@ case class ExpressionEncoder[T]( } /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the - * given schema. + * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce + * friendly error messages to explain why it fails to resolve if there is something wrong. */ - def resolve( - schema: Seq[Attribute], - outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { + def validate(schema: Seq[Attribute]): Unit = { def fail(st: StructType, maxOrdinal: Int): Unit = { throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + "but failed as the number of fields does not line up.\n" + @@ -270,6 +268,8 @@ case class ExpressionEncoder[T]( " - Target schema: " + this.schema.simpleString) } + // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all + // `BoundReference`, make sure their ordinals are all valid. var maxOrdinal = -1 fromRowExpression.foreach { case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal @@ -279,6 +279,10 @@ case class ExpressionEncoder[T]( fail(StructType.fromAttributes(schema), maxOrdinal) } + // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of + // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. + // Note that, `BoundReference` contains the expected type, but here we need the actual type, so + // we unbound it by the given `schema` and propagate the actual type to `GetStructField`. val unbound = fromRowExpression transform { case b: BoundReference => schema(b.ordinal) } @@ -299,28 +303,24 @@ case class ExpressionEncoder[T]( fail(schema, maxOrdinal) } } + } - val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the + * given schema. + */ + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { + val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer( + fromRowExpression, schema) + + // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check + // analysis, go through optimizer, etc. + val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) - val optimizedPlan = SimplifyCasts(analyzedPlan) - - // In order to construct instances of inner classes (for example those declared in a REPL cell), - // we need an instance of the outer scope. This rule substitues those outer objects into - // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` - // registry. - copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => - val outer = outerScopes.get(n.cls.getDeclaringClass.getName) - if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + - s"to the scope that this class was defined in. " + "" + - "Try moving this class out of its parent class.") - } - - n.copy(outerPointer = Some(Literal.fromObject(outer))) - }) + copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 89d40b3b2c141..d8f755a39c7ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -154,7 +154,7 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(BoundReference(i, f.dataType, f.nullable)) + constructorFor(field) ) } CreateExternalRow(fields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a1ac93073916c..902e18081bddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,10 +119,13 @@ object SamplePushDown extends Rule[LogicalPlan] { */ object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case m @ MapPartitions(_, input, _, child: ObjectOperator) - if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType => + case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => val childWithoutSerialization = child.withObjectOutput - m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization) + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 760348052739c..3f97662957b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{ObjectType, StructType} /** * A trait for logical operators that apply user defined functions to domain objects. @@ -30,6 +30,15 @@ trait ObjectOperator extends LogicalPlan { /** The serializer that is used to produce the output of this operator. */ def serializer: Seq[NamedExpression] + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + /** + * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects. + * It must also provide the attributes that are available during the resolution of each + * deserializer. + */ + def deserializers: Seq[(Expression, Seq[Attribute])] + /** * The object type that is produced by the user defined function. Note that the return type here * is the same whether or not the operator is output serialized data. @@ -44,13 +53,13 @@ trait ObjectOperator extends LogicalPlan { def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { this } else { - withNewSerializer(outputObject) + withNewSerializer(outputObject :: Nil) } /** Returns a copy of this operator with a different serializer. */ - def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy { + def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy { productIterator.map { - case c if c == serializer => newSerializer :: Nil + case c if c == serializer => newSerializer case other: AnyRef => other }.toArray } @@ -70,15 +79,16 @@ object MapPartitions { /** * A relation produced by applying `func` to each partition of the `child`. - * @param input used to extract the input to `func` from an input row. + * + * @param deserializer used to extract the input to `func` from an input row. * @param serializer use to serialize the output of `func`. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } /** Factory for constructing new `AppendColumn` nodes. */ @@ -97,16 +107,21 @@ object AppendColumns { /** * A relation produced by applying `func` to each partition of the `child`, concatenating the * resulting columns at the end of the input row. - * @param input used to extract the input to `func` from an input row. + * + * @param deserializer used to extract the input to `func` from an input row. * @param serializer use to serialize the output of `func`. */ case class AppendColumns( func: Any => Any, - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = child.output ++ newColumns + def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) + + override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } /** Factory for constructing new `MapGroups` nodes. */ @@ -114,6 +129,7 @@ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], child: LogicalPlan): MapGroups = { new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], @@ -121,6 +137,7 @@ object MapGroups { encoderFor[T].fromRowExpression, encoderFor[U].namedExpressions, groupingAttributes, + dataAttributes, child) } } @@ -129,19 +146,22 @@ object MapGroups { * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. * Func is invoked with an object representation of the grouping key an iterator containing the * object representation of all the rows with that key. - * @param keyObject used to extract the key object for each group. - * @param input used to extract the items in the iterator from an input row. + * + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. * @param serializer use to serialize the output of `func`. */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - input: Expression, + keyDeserializer: Expression, + valueDeserializer: Expression, serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], child: LogicalPlan) extends UnaryNode with ObjectOperator { - def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def deserializers: Seq[(Expression, Seq[Attribute])] = + Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes) } /** Factory for constructing new `CoGroup` nodes. */ @@ -150,8 +170,12 @@ object CoGroup { func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], + leftData: Seq[Attribute], + rightData: Seq[Attribute], left: LogicalPlan, right: LogicalPlan): CoGroup = { + require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) + CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], encoderFor[Key].fromRowExpression, @@ -160,6 +184,8 @@ object CoGroup { encoderFor[Result].namedExpressions, leftGroup, rightGroup, + leftData, + rightData, left, right) } @@ -171,15 +197,21 @@ object CoGroup { */ case class CoGroup( func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - leftObject: Expression, - rightObject: Expression, + keyDeserializer: Expression, + leftDeserializer: Expression, + rightDeserializer: Expression, serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], left: LogicalPlan, right: LogicalPlan) extends BinaryNode with ObjectOperator { + override def producedAttributes: AttributeSet = outputSet - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def deserializers: Seq[(Expression, Seq[Attribute])] = + // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve + // the `keyDeserializer` based on either of them, here we pick the left one. + Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index bc36a55ae0ea2..92a68a4dba915 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -127,7 +127,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.long, 'c.int) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct\n" + @@ -136,7 +136,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct\n" + @@ -149,7 +149,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct>\n" + @@ -158,7 +158,7 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long)) - assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + assert(intercept[AnalysisException](encoder.validate(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.\n" + " - Input schema: struct>\n" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 88c558d80a79a..e00060f9b6aff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -19,13 +19,10 @@ package org.apache.spark.sql.catalyst.encoders import java.sql.{Date, Timestamp} import java.util.Arrays -import java.util.concurrent.ConcurrentMap import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag -import com.google.common.collect.MapMaker - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} @@ -78,7 +75,7 @@ class JavaSerializable(val value: Int) extends Serializable { } class ExpressionEncoderSuite extends SparkFunSuite { - OuterScopes.outerScopes.put(getClass.getName, this) + OuterScopes.addOuterScope(this) implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f182270a08729..378763268acc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,6 +74,7 @@ class Dataset[T] private[sql]( * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + unresolvedTEncoder.validate(logicalPlan.output) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = @@ -85,7 +86,7 @@ class Dataset[T] private[sql]( */ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) - private implicit def classTag = resolvedTEncoder.clsTag + private implicit def classTag = unresolvedTEncoder.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b3f8284364782..c0e28f2dc5bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -116,6 +116,7 @@ class GroupedDataset[K, V] private[sql]( MapGroups( f, groupingAttributes, + dataAttributes, logicalPlan)) } @@ -310,6 +311,8 @@ class GroupedDataset[K, V] private[sql]( f, this.groupingAttributes, other.groupingAttributes, + this.dataAttributes, + other.dataAttributes, this.logicalPlan, other.logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9293e55141757..830bb011beab4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -306,11 +306,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.MapPartitions(f, in, out, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil - case logical.MapGroups(f, key, in, out, grouping, child) => - execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil - case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) => + case logical.MapGroups(f, key, in, out, grouping, data, child) => + execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil + case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) => execution.CoGroup( - f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil + f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, + planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 2acca1743cbb9..582dda8603f4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -53,14 +53,14 @@ trait ObjectOperator extends SparkPlan { */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with ObjectOperator { override def output: Seq[Attribute] = serializer.map(_.toAttribute) override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(input, child.output) + val getObject = generateToObject(deserializer, child.output) val outputObject = generateToRow(serializer) func(iter.map(getObject)).map(outputObject) } @@ -72,7 +72,7 @@ case class MapPartitions( */ case class AppendColumns( func: Any => Any, - input: Expression, + deserializer: Expression, serializer: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with ObjectOperator { @@ -82,7 +82,7 @@ case class AppendColumns( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(input, child.output) + val getObject = generateToObject(deserializer, child.output) val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) val outputObject = generateToRow(serializer) @@ -103,10 +103,11 @@ case class AppendColumns( */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - input: Expression, + keyDeserializer: Expression, + valueDeserializer: Expression, serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], child: SparkPlan) extends UnaryNode with ObjectOperator { override def output: Seq[Attribute] = serializer.map(_.toAttribute) @@ -121,8 +122,8 @@ case class MapGroups( child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val getKey = generateToObject(keyObject, groupingAttributes) - val getValue = generateToObject(input, child.output) + val getKey = generateToObject(keyDeserializer, groupingAttributes) + val getValue = generateToObject(valueDeserializer, dataAttributes) val outputObject = generateToRow(serializer) grouped.flatMap { case (key, rowIter) => @@ -142,12 +143,14 @@ case class MapGroups( */ case class CoGroup( func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], - keyObject: Expression, - leftObject: Expression, - rightObject: Expression, + keyDeserializer: Expression, + leftDeserializer: Expression, + rightDeserializer: Expression, serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], left: SparkPlan, right: SparkPlan) extends BinaryNode with ObjectOperator { @@ -164,9 +167,9 @@ case class CoGroup( val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val getKey = generateToObject(keyObject, leftGroup) - val getLeft = generateToObject(leftObject, left.output) - val getRight = generateToObject(rightObject, right.output) + val getKey = generateToObject(keyDeserializer, leftGroup) + val getLeft = generateToObject(leftDeserializer, leftAttr) + val getRight = generateToObject(rightDeserializer, rightAttr) val outputObject = generateToRow(serializer) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b69bb21db532b..374f4320a9239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -45,13 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } - test("SPARK-12404: Datatype Helper Serializablity") { val ds = sparkContext.parallelize(( - new Timestamp(0), - new Date(0), - java.math.BigDecimal.valueOf(1), - scala.math.BigDecimal(1)) :: Nil).toDS() + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() ds.collect() } @@ -523,7 +523,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("verify mismatching field names fail with a good error") { val ds = Seq(ClassData("a", 1)).toDS() val e = intercept[AnalysisException] { - ds.as[ClassData2].collect() + ds.as[ClassData2] } assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) } @@ -567,6 +567,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(ds1, DeepNestedStruct(NestedStruct(null))) checkAnswer(ds1.toDF(), Row(Row(null))) } + + test("support inner class in Dataset") { + val outer = new OuterClass + OuterScopes.addOuterScope(outer) + val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS() + checkAnswer(ds.map(_.a), "1", "2") + } + + test("grouping key and grouped value has field with same name") { + val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS() + val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups { + case (key, values) => key.a + values.map(_.b).sum + } + + checkAnswer(agged, "a3") + } + + test("cogroup's left and right side has field with same name") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS() + val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) { + case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum) + } + + checkAnswer(cogrouped, "a13", "b24") + } + + test("give nice error message when the real number of fields doesn't match encoder schema") { + val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + + val message = intercept[AnalysisException] { + ds.as[(String, Int, Long)] + }.message + assert(message == + "Try to map struct to Tuple3, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:int,_3:bigint>") + + val message2 = intercept[AnalysisException] { + ds.as[Tuple1[String]] + }.message + assert(message2 == + "Try to map struct to Tuple1, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string>") + } +} + +class OuterClass extends Serializable { + case class InnerClass(a: String) } case class ClassData(a: String, b: Int) From 66e1383de2650a0f06929db8109a02e32c5eaf6b Mon Sep 17 00:00:00 2001 From: Bill Chambers Date: Fri, 5 Feb 2016 14:35:39 -0800 Subject: [PATCH 1817/2089] [SPARK-13214][DOCS] update dynamicAllocation documentation Author: Bill Chambers Closes #11094 from anabranch/dynamic-docs. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 93b399d819ccd..cd9dc1bcfc113 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1169,8 +1169,8 @@ Apart from these, the following properties are also available, and may be useful false Whether to use dynamic resource allocation, which scales the number of executors registered - with this application up and down based on the workload. Note that this is currently only - available on YARN mode. For more detail, see the description + with this application up and down based on the workload. + For more detail, see the description here.

        This requires spark.shuffle.service.enabled to be set. From 0bb5b73387e60ef007b415fba69a3e1e89a4f013 Mon Sep 17 00:00:00 2001 From: Luc Bourlier Date: Fri, 5 Feb 2016 14:37:42 -0800 Subject: [PATCH 1818/2089] [SPARK-13002][MESOS] Send initial request of executors for dyn allocation Fix for [SPARK-13002](https://issues.apache.org/jira/browse/SPARK-13002) about the initial number of executors when running with dynamic allocation on Mesos. Instead of fixing it just for the Mesos case, made the change in `ExecutorAllocationManager`. It is already driving the number of executors running on Mesos, only no the initial value. The `None` and `Some(0)` are internal details on the computation of resources to reserved, in the Mesos backend scheduler. `executorLimitOption` has to be initialized correctly, otherwise the Mesos backend scheduler will, either, create to many executors at launch, or not create any executors and not be able to recover from this state. Removed the 'special case' description in the doc. It was not totally accurate, and is not needed anymore. This doesn't fix the same problem visible with Spark standalone. There is no straightforward way to send the initial value in standalone mode. Somebody knowing this part of the yarn support should review this change. Author: Luc Bourlier Closes #11047 from skyluc/issue/initial-dyn-alloc-2. --- .../mesos/CoarseMesosSchedulerBackend.scala | 13 +++++++++--- docs/running-on-mesos.md | 21 ++++++++----------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b82196c86b637..0a2d72f4dcb4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -87,10 +87,17 @@ private[spark] class CoarseMesosSchedulerBackend( val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] /** - * The total number of executors we aim to have. Undefined when not using dynamic allocation - * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. + * The total number of executors we aim to have. Undefined when not using dynamic allocation. + * Initially set to 0 when using dynamic allocation, the executor allocation manager will send + * the real initial limit later. */ - private var executorLimitOption: Option[Int] = None + private var executorLimitOption: Option[Int] = { + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + } /** * Return the current executor limit, which may be [[Int.MaxValue]] diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 0ef1ccb36e117..e1c87a8d95a32 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -246,18 +246,15 @@ In either case, HDFS runs separately from Hadoop MapReduce, without being schedu # Dynamic Resource Allocation with Mesos -Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics -of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down -since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler -can scale back up to the same amount of executors when Spark signals more executors are needed. - -Users that like to utilize this feature should launch the Mesos Shuffle Service that -provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's -termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh -scripts accordingly. - -The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos -is to launch the Shuffle Service with Marathon with a unique host constraint. +Mesos supports dynamic allocation only with coarse-grain mode, which can resize the number of +executors based on statistics of the application. For general information, +see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocation). + +The External Shuffle Service to use is the Mesos Shuffle Service. It provides shuffle data cleanup functionality +on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all slave nodes, with `spark.shuffle.service.enabled` set to `true`. + +This can also be achieved through Marathon, using a unique host constraint, and the following command: `bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`. # Configuration From 875f5079290a06c12d44de657dfeabc913e4a303 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Feb 2016 15:07:43 -0800 Subject: [PATCH 1819/2089] [SPARK-13215] [SQL] remove fallback in codegen Since we remove the configuration for codegen, we are heavily reply on codegen (also TungstenAggregate require the generated MutableProjection to update UnsafeRow), should remove the fallback, which could make user confusing, see the discussion in SPARK-13116. Author: Davies Liu Closes #11097 from davies/remove_fallback. --- .../spark/sql/execution/SparkPlan.scala | 36 ++----------------- .../spark/sql/execution/aggregate/udaf.scala | 12 ++----- .../spark/sql/execution/local/LocalNode.scala | 26 ++------------ 3 files changed, 8 insertions(+), 66 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b19b772409d83..3cc99d3c7b1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -200,47 +200,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute], useSubexprElimination: Boolean = false): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") - try { - GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } + GeneratePredicate.generate(expression, inputSchema) } protected def newOrdering( order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { - try { - GenerateOrdering.generate(order, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate ordering, fallback to interpreted", e) - new InterpretedOrdering(order, inputSchema) - } - } + GenerateOrdering.generate(order, inputSchema) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 5a19920add717..812e696338362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedMutableProjection, MutableRow} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, ImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -361,13 +361,7 @@ private[sql] case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - try { - GenerateMutableProjection.generate(children, inputAttributes)() - } catch { - case e: Exception => - log.error("Failed to generate mutable projection, fallback to interpreted", e) - new InterpretedMutableProjection(children, inputAttributes) - } + GenerateMutableProjection.generate(children, inputAttributes)() } private[this] lazy val inputToScalaConverters: Any => Any = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index a0dfe996ccd55..8726e4878106d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.Logging import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -96,33 +94,13 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema") - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } + GenerateMutableProjection.generate(expressions, inputSchema) } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } + GeneratePredicate.generate(expression, inputSchema) } } From 6883a5120c6b3a806024dfad03b8bf7371c6c0eb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Fri, 5 Feb 2016 19:00:12 -0800 Subject: [PATCH 1820/2089] [SPARK-13171][CORE] Replace future calls with Future Trivial search-and-replace to eliminate deprecation warnings in Scala 2.11. Also works with 2.10 Author: Jakob Odersky Closes #11085 from jodersky/SPARK-13171. --- .../spark/deploy/FaultToleranceTest.scala | 8 ++++---- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/JobCancellationSuite.scala | 18 +++++++++--------- .../ShuffleBlockFetcherIteratorSuite.scala | 6 +++--- .../expressions/CodeGenerationSuite.scala | 2 +- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 2 +- .../thriftserver/HiveThriftServer2Suites.scala | 6 +++--- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c0ede4b7c8734..15d220d01b461 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -22,7 +22,7 @@ import java.net.URL import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer -import scala.concurrent.{future, promise, Await} +import scala.concurrent.{Await, Future, Promise} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.language.postfixOps @@ -249,7 +249,7 @@ private object FaultToleranceTest extends App with Logging { /** This includes Client retry logic, so it may take a while if the cluster is recovering. */ private def assertUsable() = { - val f = future { + val f = Future { try { val res = sc.parallelize(0 until 10).collect() assertTrue(res.toList == (0 until 10)) @@ -283,7 +283,7 @@ private object FaultToleranceTest extends App with Logging { numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1 } - val f = future { + val f = Future { try { while (!stateValid()) { Thread.sleep(1000) @@ -405,7 +405,7 @@ private object SparkDocker { } private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = { - val ipPromise = promise[String]() + val ipPromise = Promise[String]() val outFile = File.createTempFile("fault-tolerance-test", "", Utils.createTempDir()) val outStream: FileWriter = new FileWriter(outFile) def findIpAndLog(line: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 179d3b9f20b1f..df3c286a0a66f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -394,7 +394,7 @@ private[deploy] class Worker( // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. val appIds = executors.values.map(_.appId).toSet - val cleanupFuture = concurrent.future { + val cleanupFuture = concurrent.Future { val appDirs = workDir.listFiles() if (appDirs == null) { throw new IOException("ERROR: Failed to list files in " + appDirs) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index e13a442463e8d..c347ab8dc8020 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.Semaphore import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.future +import scala.concurrent.Future import org.scalatest.BeforeAndAfter import org.scalatest.Matchers @@ -103,7 +103,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val rdd1 = rdd.map(x => x) - future { + Future { taskStartedSemaphore.acquire() sc.cancelAllJobs() taskCancelledSemaphore.release(100000) @@ -126,7 +126,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) // jobA is the one to be cancelled. - val jobA = future { + val jobA = Future { sc.setJobGroup("jobA", "this is a job to be cancelled") sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() } @@ -191,7 +191,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) // jobA is the one to be cancelled. - val jobA = future { + val jobA = Future { sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100000); i }.count() } @@ -231,7 +231,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val f2 = rdd.countAsync() // Kill one of the action. - future { + Future { sem1.acquire() f1.cancel() JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) @@ -247,7 +247,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Cancel before launching any tasks { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() - future { f.cancel() } + Future { f.cancel() } val e = intercept[SparkException] { f.get() } assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -263,7 +263,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() - future { + Future { // Wait until some tasks were launched before we cancel the job. sem.acquire() f.cancel() @@ -277,7 +277,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Cancel before launching any tasks { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) - future { f.cancel() } + Future { f.cancel() } val e = intercept[SparkException] { f.get() } assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -292,7 +292,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) - future { + Future { sem.acquire() f.cancel() } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 828153bdbfc44..c9c2fb2691d70 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,7 +21,7 @@ import java.io.InputStream import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.future +import scala.concurrent.Future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ @@ -149,7 +149,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - future { + Future { // Return the first two blocks, and wait till task completion before returning the 3rd one listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) @@ -211,7 +211,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - future { + Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 0c42e2fc7c5e5..b5413fbe2bbcc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -36,7 +36,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { import scala.concurrent.duration._ val futures = (1 to 20).map { _ => - future { + Future { GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8b275e886c46c..943ad31c0cef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -77,7 +77,7 @@ case class BroadcastHashJoin( // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { + Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. SQLExecution.withExecutionId(sparkContext, executionId) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index db8edd169dcfa..f48fc3b84864d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -76,7 +76,7 @@ case class BroadcastHashOuterJoin( // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { + Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. SQLExecution.withExecutionId(sparkContext, executionId) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ba3b26e1b7d49..865197e24caf8 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -23,7 +23,7 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{future, Await, ExecutionContext, Promise} +import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.io.Source import scala.util.{Random, Try} @@ -362,7 +362,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { try { // Start a very-long-running query that will take hours to finish, then cancel it in order // to demonstrate that cancellation works. - val f = future { + val f = Future { statement.executeQuery( "SELECT COUNT(*) FROM test_map " + List.fill(10)("join test_map").mkString(" ")) @@ -380,7 +380,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") try { - val sf = future { + val sf = Future { statement.executeQuery( "SELECT COUNT(*) FROM test_map " + List.fill(4)("join test_map").mkString(" ") From 4f28291f851b9062da3941e63de4eabb0c77f5d0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 5 Feb 2016 22:40:40 -0800 Subject: [PATCH 1821/2089] [HOTFIX] fix float part of avgRate --- core/src/main/scala/org/apache/spark/util/Benchmark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index d1699f5c28655..1bf6f821e9b31 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -124,7 +124,7 @@ private[spark] object Benchmark { } val best = runTimes.min val avg = runTimes.sum / iters - Result(avg / 1000000, num / (best / 1000), best / 1000000) + Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0) } } From 81da3bee669aaeb79ec68baaf7c99bff6e5d14fe Mon Sep 17 00:00:00 2001 From: Tommy YU Date: Sat, 6 Feb 2016 17:29:09 +0000 Subject: [PATCH 1822/2089] [SPARK-5865][API DOC] Add doc warnings for methods that return local data structures rxin srowen I work out note message for rdd.take function, please help to review. If it's fine, I can apply to all other function later. Author: Tommy YU Closes #10874 from Wenpei/spark-5865-add-warning-for-localdatastructure. --- .../apache/spark/api/java/JavaPairRDD.scala | 3 +++ .../apache/spark/api/java/JavaRDDLike.scala | 24 +++++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 3 +++ .../main/scala/org/apache/spark/rdd/RDD.scala | 15 ++++++++++++ python/pyspark/rdd.py | 17 +++++++++++++ python/pyspark/sql/dataframe.py | 6 +++++ .../org/apache/spark/sql/DataFrame.scala | 4 ++++ 7 files changed, 72 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index fb04472ee73fd..94d103588b696 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -636,6 +636,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. + * + * @note this method should only be used if the resulting data is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 7340defabfe58..37c211fe706e1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -327,6 +327,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def collect(): JList[T] = rdd.collect().toSeq.asJava @@ -465,6 +468,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def take(num: Int): JList[T] = rdd.take(num).toSeq.asJava @@ -548,6 +554,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -559,6 +568,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD using the * natural ordering for T and maintains the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of top elements to return * @return an array of top elements */ @@ -570,6 +582,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the first k (smallest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -601,6 +616,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the first k (smallest) elements from this RDD using the * natural ordering for T while maintain the order. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. * @param num k, the number of top elements to return * @return an array of top elements */ @@ -634,6 +652,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * The asynchronous version of `collect`, which returns a future for * retrieving an array containing all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsync(): JavaFutureAction[JList[T]] = { new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava) @@ -642,6 +663,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * The asynchronous version of the `take` action, which returns a * future for retrieving the first `num` elements of this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeAsync(num: Int): JavaFutureAction[JList[T]] = { new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 33f2f0b44f773..61905a8421124 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -726,6 +726,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only * one value per key is preserved in the map returned) + * + * @note this method should only be used if the resulting data is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsMap(): Map[K, V] = self.withScope { val data = self.collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e8157cf4ebe7d..a81a98b526b5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -481,6 +481,9 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator @@ -836,6 +839,9 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) @@ -1202,6 +1208,9 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @note due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ @@ -1263,6 +1272,9 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements @@ -1283,6 +1295,9 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c28594625457a..fe2264a63cf30 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -426,6 +426,9 @@ def takeSample(self, withReplacement, num, seed=None): """ Return a fixed-size sampled subset of this RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) 20 @@ -766,6 +769,8 @@ def func(it): def collect(self): """ Return a list that contains all of the elements in this RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) @@ -1213,6 +1218,9 @@ def top(self, num, key=None): """ Get the top N elements from a RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) @@ -1235,6 +1243,9 @@ def takeOrdered(self, num, key=None): Get the N elements from a RDD ordered in ascending order or as specified by the optional key function. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) @@ -1254,6 +1265,9 @@ def take(self, num): that partition to estimate the number of additional partitions needed to satisfy the limit. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + Translated from the Scala implementation in RDD#take(). >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) @@ -1511,6 +1525,9 @@ def collectAsMap(self): """ Return the key-value pairs in this RDD to the master as a dictionary. + Note that this method should only be used if the resulting data is expected + to be small, as all the data is loaded into the driver's memory. + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() >>> m[1] 2 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 90a6b5d9c0dda..3a8c8305ee3d8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -739,6 +739,9 @@ def describe(self, *cols): def head(self, n=None): """Returns the first ``n`` rows. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + :param n: int, default 1. Number of rows to return. :return: If n is greater than 1, return a list of :class:`Row`. If n is 1, return a single Row. @@ -1330,6 +1333,9 @@ def toDF(self, *cols): def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + Note that this method should only be used if the resulting Pandas's DataFrame is expected + to be small, as all the data is loaded into the driver's memory. + This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f15b926bd27cf..7aa08fb63053b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1384,6 +1384,10 @@ class DataFrame private[sql]( /** * Returns the first `n` rows. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @group action * @since 1.3.0 */ From bc8890b357811612ba6c10d96374902b9e08134f Mon Sep 17 00:00:00 2001 From: Gary King Date: Sun, 7 Feb 2016 09:13:28 +0000 Subject: [PATCH 1823/2089] [SPARK-13132][MLLIB] cache standardization param value in LogisticRegression cache the value of the standardization Param in LogisticRegression, rather than re-fetching it from the ParamMap for every index and every optimization step in the quasi-newton optimizer also, fix Param#toString to cache the stringified representation, rather than re-interpolating it on every call, so any other implementations that have similar repeated access patterns will see a benefit. this change improves training times for one of my test sets from ~7m30s to ~4m30s Author: Gary King Closes #11027 from idigary/spark-13132-optimize-logistic-regression. --- .../apache/spark/ml/classification/LogisticRegression.scala | 3 ++- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 9b2340a1f16fc..ac0124513f283 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") ( val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept if (index == numFeatures) { 0.0 } else { - if ($(standardization)) { + if (standardizationParam) { regParamL1 } else { // If `standardization` is false, we still standardize the data diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index f48923d69974b..d7d6c0f5fa16e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - override final def toString: String = s"${parent}__$name" + private[this] val stringRepresentation = s"${parent}__$name" + + override final def toString: String = stringRepresentation override final def hashCode: Int = toString.## From 140ddef373680cb08a3948a883b172dc80814170 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Sun, 7 Feb 2016 12:52:00 +0000 Subject: [PATCH 1824/2089] [SPARK-10963][STREAMING][KAFKA] make KafkaCluster public Author: cody koeninger Closes #9007 from koeninger/SPARK-10963. --- .../spark/streaming/kafka/KafkaCluster.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index d7885d7cc1ae1..8a66621a3125c 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -29,15 +29,19 @@ import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, To import kafka.consumer.{ConsumerConfig, SimpleConsumer} import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi /** + * :: DeveloperApi :: * Convenience methods for interacting with a Kafka cluster. + * See + * A Guide To The Kafka Protocol for more details on individual api calls. * @param kafkaParams Kafka * configuration parameters. * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), * NOT zookeeper servers, specified in host1:port1,host2:port2 form */ -private[spark] +@DeveloperApi class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} @@ -227,7 +231,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { // this 0 here indicates api version, in this case the original ZK backed api. private def defaultConsumerApiVersion: Short = 0 - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def getConsumerOffsets( groupId: String, topicAndPartitions: Set[TopicAndPartition] @@ -246,7 +250,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { } } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def getConsumerOffsetMetadata( groupId: String, topicAndPartitions: Set[TopicAndPartition] @@ -283,7 +287,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { Left(errs) } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def setConsumerOffsets( groupId: String, offsets: Map[TopicAndPartition, Long] @@ -301,7 +305,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { setConsumerOffsetMetadata(groupId, meta, consumerApiVersion) } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def setConsumerOffsetMetadata( groupId: String, metadata: Map[TopicAndPartition, OffsetAndMetadata] @@ -359,7 +363,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { } } -private[spark] +@DeveloperApi object KafkaCluster { type Err = ArrayBuffer[Throwable] @@ -371,7 +375,6 @@ object KafkaCluster { ) } - private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) /** @@ -379,7 +382,6 @@ object KafkaCluster { * Simple consumers connect directly to brokers, but need many of the same configs. * This subclass won't warn about missing ZK params, or presence of broker params. */ - private[spark] class SimpleConsumerConfig private(brokers: String, originalProps: Properties) extends ConsumerConfig(originalProps) { val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => @@ -391,7 +393,6 @@ object KafkaCluster { } } - private[spark] object SimpleConsumerConfig { /** * Make a consumer config without requiring group.id or zookeeper.connect, From edf4a0e62e6fdb849cca4f23a7060da5ec782b07 Mon Sep 17 00:00:00 2001 From: Nam Pham Date: Mon, 8 Feb 2016 11:06:41 -0800 Subject: [PATCH 1825/2089] [SPARK-12986][DOC] Fix pydoc warnings in mllib/regression.py I have fixed the warnings by running "make html" under "python/docs/". They are caused by not having blank lines around indented paragraphs. Author: Nam Pham Closes #11025 from nampham2/SPARK-12986. --- python/pyspark/mllib/regression.py | 34 ++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 13b3397501c0b..4dd7083d79c8c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -219,8 +219,10 @@ class LinearRegressionWithSGD(object): """ Train a linear regression model with no regularization using Stochastic Gradient Descent. This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/n ||A weights-y||^2 + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -367,8 +369,10 @@ def load(cls, sc, path): class LassoWithSGD(object): """ Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + This solves the L1-regularized least squares regression formulation + + f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -505,8 +509,10 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + This solves the L2-regularized least squares regression formulation + + f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -655,17 +661,19 @@ class IsotonicRegression(object): Only univariate (single feature) algorithm supported. Sequential PAV implementation based on: - Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + + Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. - Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf Sequential PAV parallelization based on: - Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. - "An approach to parallelizing isotonic regression." - Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. - Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] - @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + "An approach to parallelizing isotonic regression." + Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf + + See `Isotonic regression (Wikipedia) `_. .. versionadded:: 1.4.0 """ From 06f0df6df204c4722ff8a6bf909abaa32a715c41 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Feb 2016 11:38:21 -0800 Subject: [PATCH 1826/2089] [SPARK-8964] [SQL] Use Exchange to perform shuffle in Limit This patch changes the implementation of the physical `Limit` operator so that it relies on the `Exchange` operator to perform data movement rather than directly using `ShuffledRDD`. In addition to improving efficiency, this lays the necessary groundwork for further optimization of limit, such as limit pushdown or whole-stage codegen. At a high-level, this replaces the old physical `Limit` operator with two new operators, `LocalLimit` and `GlobalLimit`. `LocalLimit` performs per-partition limits, while `GlobalLimit` applies the final limit to a single partition; `GlobalLimit`'s declares that its `requiredInputDistribution` is `SinglePartition`, which will cause the planner to use an `Exchange` to perform the appropriate shuffles. Thus, a logical `Limit` appearing in the middle of a query plan will be expanded into `LocalLimit -> Exchange to one partition -> GlobalLimit`. In the old code, calling `someDataFrame.limit(100).collect()` or `someDataFrame.take(100)` would actually skip the shuffle and use a fast-path which used `executeTake()` in order to avoid computing all partitions in case only a small number of rows were requested. This patch preserves this optimization by treating logical `Limit` operators specially when they appear as the terminal operator in a query plan: if a `Limit` is the final operator, then we will plan a special `CollectLimit` physical operator which implements the old `take()`-based logic. In order to be able to match on operators only at the root of the query plan, this patch introduces a special `ReturnAnswer` logical operator which functions similar to `BroadcastHint`: this dummy operator is inserted at the root of the optimized logical plan before invoking the physical planner, allowing the planner to pattern-match on it. Author: Josh Rosen Closes #7334 from JoshRosen/remove-copy-in-limit. --- .../plans/logical/basicOperators.scala | 11 ++ .../apache/spark/sql/execution/Exchange.scala | 130 ++++++++++-------- .../spark/sql/execution/QueryExecution.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 7 +- .../spark/sql/execution/basicOperators.scala | 95 +------------ .../apache/spark/sql/execution/limit.scala | 122 ++++++++++++++++ .../spark/sql/execution/PlannerSuite.scala | 10 +- .../spark/sql/execution/SortSuite.scala | 4 +- 8 files changed, 223 insertions(+), 160 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 03a79520cbd3a..57575f9ee09ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -25,6 +25,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +/** + * When planning take() or collect() operations, this special node that is inserted at the top of + * the logical plan before invoking the query planner. + * + * Rules can pattern-match on this node in order to apply transformations that only take effect + * at the top of the logical query plan. + */ +case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3770883af1e2f..97f65f18bfdcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -57,6 +57,69 @@ case class Exchange( override def output: Seq[Attribute] = child.output + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + + override protected def doPrepare(): Unit = { + // If an ExchangeCoordinator is needed, we register this Exchange operator + // to the coordinator when we do prepare. It is important to make sure + // we register this operator right before the execution instead of register it + // in the constructor because it is possible that we create new instances of + // Exchange operators when we transform the physical plan + // (then the ExchangeCoordinator will hold references of unneeded Exchanges). + // So, we should only call registerExchange just before we start to execute + // the plan. + coordinator match { + case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) + case None => + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { + Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer) + } + + /** + * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. + * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional + * partition start indices array. If this optional array is defined, the returned + * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. + */ + private[sql] def preparePostShuffleRDD( + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { + // If an array of partition start indices is provided, we need to use this array + // to create the ShuffledRowRDD. Also, we need to update newPartitioning to + // update the number of post-shuffle partitions. + specifiedPartitionStartIndices.foreach { indices => + assert(newPartitioning.isInstanceOf[HashPartitioning]) + newPartitioning = UnknownPartitioning(indices.length) + } + new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } + } +} + +object Exchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { + Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) + } + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -82,7 +145,7 @@ case class Exchange( // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output // fewer partitions (like RangePartitioner, for example). - val conf = child.sqlContext.sparkContext.conf + val conf = SparkEnv.get.conf val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) @@ -117,30 +180,16 @@ case class Exchange( } } - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - - override protected def doPrepare(): Unit = { - // If an ExchangeCoordinator is needed, we register this Exchange operator - // to the coordinator when we do prepare. It is important to make sure - // we register this operator right before the execution instead of register it - // in the constructor because it is possible that we create new instances of - // Exchange operators when we transform the physical plan - // (then the ExchangeCoordinator will hold references of unneeded Exchanges). - // So, we should only call registerExchange just before we start to execute - // the plan. - coordinator match { - case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) - case None => - } - } - /** * Returns a [[ShuffleDependency]] that will partition rows of its child based on * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. */ - private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { - val rdd = child.execute() + private[sql] def prepareShuffleDependency( + rdd: RDD[InternalRow], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) case HashPartitioning(_, n) => @@ -160,7 +209,7 @@ case class Exchange( // We need to use an interpreted ordering here because generated orderings cannot be // serialized and this ordering needs to be created on the driver in order to be passed into // Spark core code. - implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) + implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { @@ -180,7 +229,7 @@ case class Exchange( position } case h: HashPartitioning => - val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output) + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") @@ -211,43 +260,6 @@ case class Exchange( dependency } - - /** - * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. - * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional - * partition start indices array. If this optional array is defined, the returned - * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. - */ - private[sql] def preparePostShuffleRDD( - shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], - specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { - // If an array of partition start indices is provided, we need to use this array - // to create the ShuffledRowRDD. Also, we need to update newPartitioning to - // update the number of post-shuffle partitions. - specifiedPartitionStartIndices.foreach { indices => - assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = UnknownPartitioning(indices.length) - } - new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - coordinator match { - case Some(exchangeCoordinator) => - val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) - assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) - shuffleRDD - case None => - val shuffleDependency = prepareShuffleDependency() - preparePostShuffleRDD(shuffleDependency) - } - } -} - -object Exchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 107570f9dbcc8..8616fe317034f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.planner.plan(optimizedPlan).next() + sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 830bb011beab4..ee392e4e8debe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -338,8 +338,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil + case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) => + execution.CollectLimit(limit, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child)) :: Nil + val perPartitionLimit = execution.LocalLimit(limit, planLater(child)) + val globalLimit = execution.GlobalLimit(limit, perPartitionLimit) + globalLimit :: Nil case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => @@ -358,6 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil + case logical.ReturnAnswer(child) => planLater(child) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6e51c4d84824a..f63e8a9b6d79d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.execution -import org.apache.spark.{HashPartitioner, SparkEnv} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} -import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType -import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler case class Project(projectList: Seq[NamedExpression], child: SparkPlan) @@ -306,96 +303,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { sparkContext.union(children.map(_.execute())) } -/** - * Take the first limit elements. Note that the implementation is different depending on whether - * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, - * this operator uses something similar to Spark's take method on the Spark driver. If it is not - * terminal or is invoked using execute, we first take the limit on each partition, and then - * repartition all the data to a single partition to compute the global limit. - */ -case class Limit(limit: Int, child: SparkPlan) - extends UnaryNode { - // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: - // partition local limit -> exchange into one partition -> partition local limit again - - /** We must copy rows when sort based shuffle is on */ - private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = SinglePartition - - override def executeCollect(): Array[InternalRow] = child.executeTake(limit) - - protected override def doExecute(): RDD[InternalRow] = { - val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { - child.execute().mapPartitionsInternal { iter => - iter.take(limit).map(row => (false, row.copy())) - } - } else { - child.execute().mapPartitionsInternal { iter => - val mutablePair = new MutablePair[Boolean, InternalRow]() - iter.take(limit).map(row => mutablePair.update(false, row)) - } - } - val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) - shuffled.mapPartitionsInternal(_.take(limit).map(_._2)) - } -} - -/** - * Take the first limit elements as defined by the sortOrder, and do projection if needed. - * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, - * or having a [[Project]] operator between them. - * This could have been named TopK, but Spark's top operator does the opposite in ordering - * so we name it TakeOrdered to avoid confusion. - */ -case class TakeOrderedAndProject( - limit: Int, - sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], - child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) - } - - override def outputPartitioning: Partitioning = SinglePartition - - // We need to use an interpreted ordering here because generated orderings cannot be serialized - // and this ordering needs to be created on the driver in order to be passed into Spark core code. - private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - - private def collectData(): Array[InternalRow] = { - val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) - data.map(r => proj(r).copy()) - } else { - data - } - } - - override def executeCollect(): Array[InternalRow] = { - collectData() - } - - // TODO: Terminal split should be implemented differently from non-terminal split. - // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def simpleString: String = { - val orderByString = sortOrder.mkString("[", ",", "]") - val outputString = output.mkString("[", ",", "]") - - s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" - } -} - /** * Return a new RDD that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala new file mode 100644 index 0000000000000..256f4228ae99e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ + + +/** + * Take the first `limit` elements and collect them to a single partition. + * + * This operator will be used when a logical `Limit` operation is the final operator in an + * logical plan, which happens when the user is collecting results back to the driver. + */ +case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition + override def executeCollect(): Array[InternalRow] = child.executeTake(limit) + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + protected override def doExecute(): RDD[InternalRow] = { + val shuffled = new ShuffledRowRDD( + Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer)) + shuffled.mapPartitionsInternal(_.take(limit)) + } +} + +/** + * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. + */ +trait BaseLimit extends UnaryNode { + val limit: Int + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + iter.take(limit) + } +} + +/** + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + */ +case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + +/** + * Take the first `limit` elements of the child's single output partition. + */ +case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit { + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil +} + +/** + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a Limit operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. + * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. + */ +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def outputPartitioning: Partitioning = SinglePartition + + // We need to use an interpreted ordering here because generated orderings cannot be serialized + // and this ordering needs to be created on the driver in order to be passed into Spark core code. + private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) + + private def collectData(): Array[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } + } + + override def executeCollect(): Array[InternalRow] = { + collectData() + } + + // TODO: Terminal split should be implemented differently from non-terminal split. + // TODO: Pick num splits based on |limit|. + protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index adaeb513bc1b8..a64ad4038c7c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -181,6 +181,12 @@ class PlannerSuite extends SharedSQLContext { } } + test("terminal limits use CollectLimit") { + val query = testData.select('value).limit(2) + val planned = query.queryExecution.sparkPlan + assert(planned.isInstanceOf[CollectLimit]) + } + test("PartitioningCollection") { withTempTable("normal", "small", "tiny") { testData.registerTempTable("normal") @@ -200,7 +206,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: Exchange => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } { @@ -215,7 +221,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: Exchange => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 6259453da26a1..cb6d68dc3ac46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -56,8 +56,8 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)), - (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)), + (child: SparkPlan) => GlobalLimit(10, Sort('a.asc :: Nil, global = true, child = child)), + (child: SparkPlan) => GlobalLimit(10, ReferenceSort('a.asc :: Nil, global = true, child)), sortAnswers = false ) } From 8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Feb 2016 12:06:00 -0800 Subject: [PATCH 1827/2089] [SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder nullability should only be considered as an optimization rather than part of the type system, so instead of failing analysis for mismatch nullability, we should pass analysis and add runtime null check. Author: Wenchen Fan Closes #11035 from cloud-fan/ignore-nullability. --- .../sql/catalyst/JavaTypeInference.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 29 +++-- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/expressions/objects.scala | 20 ++-- .../encoders/EncoderResolutionSuite.scala | 107 +++++------------- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 4 +- 7 files changed, 64 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 3c3717d5043aa..59ee41d02f198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -292,7 +292,7 @@ object JavaTypeInference { val setter = if (nullable) { constructor } else { - AssertNotNull(constructor, other.getName, fieldName, fieldType.toString) + AssertNotNull(constructor, Seq("currently no type path record in java")) } p.getWriteMethod.getName -> setter }.toMap diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e5811efb436a6..02cb2d9a2b118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + + // TODO: add runtime null check for primitive array val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) + + val mapFunction: Expression => Expression = p => { + val converter = constructorFor(elementType, Some(p), newTypePath) + if (nullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } + } + + val array = Invoke( + MapObjects(mapFunction, getPath, dataType), + "array", + ObjectType(classOf[Array[Any]])) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + array :: Nil) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map @@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection { newTypePath) if (!nullable) { - AssertNotNull(constructor, t.toString, fieldName, fieldType.toString) + AssertNotNull(constructor, newTypePath) } else { constructor } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cb228cf52b433..4d53b232d5510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1426,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] { fail(child, DateType, walkedTypePath) case (StringType, to: NumericType) => fail(child, to, walkedTypePath) - case _ => Cast(child, dataType) + case _ => Cast(child, dataType.asNullable) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 79fe0033b71ab..fef6825b2db5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -365,7 +365,7 @@ object MapObjects { * to handle collection elements. * @param inputData An expression that when evaluted returns a collection object. */ -case class MapObjects( +case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, inputData: Expression) extends Expression { @@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. */ -case class AssertNotNull( - child: Expression, parentType: String, fieldName: String, fieldType: String) +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) extends UnaryExpression { override def dataType: DataType = child.dataType @@ -651,6 +650,14 @@ case class AssertNotNull( override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childGen = child.gen(ctx) + val errMsg = "Null value appeared in non-nullable field:" + + walkedTypePath.mkString("\n", "\n", "\n") + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + val idx = ctx.references.length + ctx.references += errMsg + ev.isNull = "false" ev.value = childGen.value @@ -658,12 +665,7 @@ case class AssertNotNull( ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException( - "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " + - "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + - "please try to use scala.Option[_] or other nullable types " + - "(e.g. java.lang.Integer instead of int/scala.Int)." - ); + throw new RuntimeException((String) references[$idx]); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 92a68a4dba915..8b02b63c6cf3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class StringLongClass(a: String, b: Long) @@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) class EncoderResolutionSuite extends PlanTest { + private val str = UTF8String.fromString("hello") + test("real type doesn't match encoder schema but they are compatible: product") { val encoder = ExpressionEncoder[StringLongClass] - val cls = classOf[StringLongClass] - - { - val attrs = Seq('a.string, 'b.int) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - cls, - Seq( - toExternalString('a.string), - AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") - ), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) - } + // int type can be up cast to long type + val attrs1 = Seq('a.string, 'b.int) + encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) - { - val attrs = Seq('a.int, 'b.long) - val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression - val expected = NewInstance( - cls, - Seq( - toExternalString('a.int.cast(StringType)), - AssertNotNull('b.long, cls.getName, "b", "Long") - ), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) - } + // int type can be up cast to string type + val attrs2 = Seq('a.int, 'b.long) + encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] - val innerCls = classOf[StringLongClass] - val cls = classOf[ComplexClass] - val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - cls, - Seq( - AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"), - If( - 'b.struct('a.int, 'b.long).isNull, - Literal.create(null, ObjectType(innerCls)), - NewInstance( - innerCls, - Seq( - toExternalString( - GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), - AssertNotNull( - GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), - innerCls.getName, "b", "Long")), - ObjectType(innerCls), - propagateNull = false) - )), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { val encoder = ExpressionEncoder.tuple( ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) - val cls = classOf[StringLongClass] - val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - classOf[Tuple2[_, _]], - Seq( - NewInstance( - cls, - Seq( - toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), - AssertNotNull( - GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), - cls.getName, "b", "Long")), - ObjectType(cls), - propagateNull = false), - 'b.int.cast(LongType)), - ObjectType(classOf[Tuple2[_, _]]), - propagateNull = false) - compareExpressions(fromRowExpr, expected) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + } + + test("nullability of array type element should not fail analysis") { + val encoder = ExpressionEncoder[Seq[Int]] + val attrs = 'a.array(IntegerType) :: Nil + + // It should pass analysis + val bound = encoder.resolve(attrs, null).bind(attrs) + + // If no null values appear, it should works fine + bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) + + // If there is null value, it should throw runtime exception + val e = intercept[RuntimeException] { + bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) + } + assert(e.getMessage.contains("Null value appeared in non-nullable field")) } test("the real number of fields doesn't match encoder schema: tuple encoder") { @@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest { } } - private def toExternalString(e: Expression): Expression = { - Invoke(e, "toString", ObjectType(classOf[String]), Nil) - } - test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a6fb62c17d59b..1181244c8a4ed 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() { } nullabilityCheck.expect(RuntimeException.class); - nullabilityCheck.expectMessage( - "Null value appeared in non-nullable field " + - "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + nullabilityCheck.expectMessage("Null value appeared in non-nullable field"); { Row row = new GenericRow(new Object[] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 374f4320a9239..f9ba60770022d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { buildDataset(Row(Row("hello", null))).collect() }.getMessage - assert(message.contains( - "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." - )) + assert(message.contains("Null value appeared in non-nullable field")) } test("SPARK-12478: top level null field") { From 37bc203c8dd5022cb11d53b697c28a737ee85bcc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Feb 2016 12:08:58 -0800 Subject: [PATCH 1828/2089] [SPARK-13210][SQL] catch OOM when allocate memory and expand array There is a bug when we try to grow the buffer, OOM is ignore wrongly (the assert also skipped by JVM), then we try grow the array again, this one will trigger spilling free the current page, the current record we inserted will be invalid. The root cause is that JVM has less free memory than MemoryManager thought, it will OOM when allocate a page without trigger spilling. We should catch the OOM, and acquire memory again to trigger spilling. And also, we could not grow the array in `insertRecord` of `InMemorySorter` (it was there just for easy testing). Author: Davies Liu Closes #11095 from davies/fix_expand. --- .../spark/memory/TaskMemoryManager.java | 23 ++++++++++++++++++- .../shuffle/sort/ShuffleExternalSorter.java | 10 +------- .../shuffle/sort/ShuffleInMemorySorter.java | 2 +- .../unsafe/sort/UnsafeExternalSorter.java | 10 +------- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../sort/ShuffleInMemorySorterSuite.java | 6 +++++ .../sort/UnsafeInMemorySorterSuite.java | 3 +++ 7 files changed, 35 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d2a88864f7ac9..8757dff36f159 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -111,6 +111,11 @@ public class TaskMemoryManager { @GuardedBy("this") private final HashSet consumers; + /** + * The amount of memory that is acquired but not used. + */ + private long acquiredButNotUsed = 0L; + /** * Construct a new TaskMemoryManager. */ @@ -256,7 +261,20 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { } allocatedPages.set(pageNumber); } - final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + MemoryBlock page = null; + try { + page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + } catch (OutOfMemoryError e) { + logger.warn("Failed to allocate a page ({} bytes), try again.", acquired); + // there is no enough memory actually, it means the actual free memory is smaller than + // MemoryManager thought, we should keep the acquired memory. + acquiredButNotUsed += acquired; + synchronized (this) { + allocatedPages.clear(pageNumber); + } + // this could trigger spilling to free some pages. + return allocatePage(size, consumer); + } page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { @@ -378,6 +396,9 @@ public long cleanUpAllAllocatedMemory() { } Arrays.fill(pageTable, null); + // release the memory that is not used by any consumer. + memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 2c84de5bf2a5a..f97e76d7ed0d9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -320,15 +320,7 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - LongArray array; - try { - // could trigger spilling - array = allocateArray(used / 8 * 2); - } catch (OutOfMemoryError e) { - // should have trigger spilling - assert(inMemSorter.hasSpaceForAnotherRecord()); - return; - } + LongArray array = allocateArray(used / 8 * 2); // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { freeArray(array); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 58ad88e1ed87b..d74602cd205ad 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -104,7 +104,7 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(consumer.allocateArray(array.size() * 2)); + throw new IllegalStateException("There is no space for new record"); } array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); pos++; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index a6edc1ad3f665..296bf722fc178 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -293,15 +293,7 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - LongArray array; - try { - // could trigger spilling - array = allocateArray(used / 8 * 2); - } catch (OutOfMemoryError e) { - // should have trigger spilling - assert(inMemSorter.hasSpaceForAnotherRecord()); - return; - } + LongArray array = allocateArray(used / 8 * 2); // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { freeArray(array); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index d1b0bc5d11f46..cea0f0a0c6c11 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -164,7 +164,7 @@ public void expandPointerArray(LongArray newArray) { */ public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(consumer.allocateArray(array.size() * 2)); + throw new IllegalStateException("There is no space for new record"); } array.set(pos, recordPointer); pos++; diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 0328e63e45439..eb1da8e1b43eb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -75,6 +75,9 @@ public void testBasicSorting() throws Exception { // Write the records into the data page and store pointers into the sorter long position = dataPage.getBaseOffset(); for (String str : dataToSort) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + } final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); Platform.putInt(baseObject, position, strBytes.length); @@ -114,6 +117,9 @@ public void testSortingManyNumbers() throws Exception { int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + } numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); sorter.insertRecord(0, numbersToSort[i]); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 93efd033eb940..8e557ec0ab0b4 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -111,6 +111,9 @@ public int compare(long prefix1, long prefix2) { // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2)); + } // position now points to the start of a record (which holds its length). final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); From ff0af0ddfa4d198b203c3a39f8532cfbd4f4e027 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Feb 2016 14:09:14 -0800 Subject: [PATCH 1829/2089] [SPARK-13095] [SQL] improve performance for broadcast join with dimension table This PR improve the performance for Broadcast join with dimension tables, which is common in data warehouse. If the join key can fit in a long, we will use a special api `get(Long)` to get the rows from HashedRelation. If the HashedRelation only have unique keys, we will use a special api `getValue(Long)` or `getValue(InternalRow)`. If the keys can fit within a long, also the keys are dense, we will use a array of UnsafeRow, instead a hash map. TODO: will do cleanup Author: Davies Liu Closes #11065 from davies/gen_dim. --- .../sql/execution/WholeStageCodegen.scala | 1 + .../execution/joins/BroadcastHashJoin.scala | 96 +++--- .../joins/BroadcastHashOuterJoin.scala | 7 +- .../joins/BroadcastLeftSemiJoinHash.scala | 6 +- .../spark/sql/execution/joins/HashJoin.scala | 44 ++- .../sql/execution/joins/HashedRelation.scala | 297 +++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 27 +- .../execution/joins/HashedRelationSuite.scala | 29 +- 8 files changed, 438 insertions(+), 69 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 131efea20f31e..4ca2d85406bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -38,6 +38,7 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { case _: TungstenAggregate => "agg" + case _: BroadcastHashJoin => "bhj" case _ => nodeName.toLowerCase } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 943ad31c0cef5..cbd549763ac95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -90,8 +90,14 @@ case class BroadcastHashJoin( // The following line doesn't run in a job so we cannot track the metric value. However, we // have already tracked it in the above lines. So here we can use // `SQLMetrics.nullLongMetric` to ignore it. - val hashed = HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + // TODO: move this check into HashedRelation + val hashed = if (canJoinKeyFitWithinLong) { + LongHashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + } else { + HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + } sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -112,15 +118,12 @@ case class BroadcastHashJoin( streamedPlan.execute().mapPartitions { streamedIter => val hashedRelation = broadcastRelation.value - hashedRelation match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) - case _ => - } + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } + private var broadcastRelation: Broadcast[HashedRelation] = _ // the term for hash relation private var relationTerm: String = _ @@ -129,16 +132,15 @@ case class BroadcastHashJoin( } override def doProduce(ctx: CodegenContext): String = { - // create a name for HashRelation - val broadcastRelation = Await.result(broadcastFuture, timeout) + // create a name for HashedRelation + broadcastRelation = Await.result(broadcastFuture, timeout) val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) relationTerm = ctx.freshName("relation") - // TODO: create specialized HashRelation for single join key - val clsName = classOf[UnsafeHashedRelation].getName + val clsName = broadcastRelation.value.getClass.getName ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = ($clsName) $broadcast.value(); - | incPeakExecutionMemory($relationTerm.getUnsafeSize()); + | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) s""" @@ -147,23 +149,24 @@ case class BroadcastHashJoin( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // generate the key as UnsafeRow + // generate the key as UnsafeRow or Long ctx.currentVars = input - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) - val keyTerm = keyVal.value - val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + val expr = rewriteKeyExpr(streamedKeys).head + val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + (ev, ev.isNull) + } else { + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + (ev, s"${ev.value}.anyNull()") + } // find the matches from HashedRelation - val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") - val row = ctx.freshName("row") + val matched = ctx.freshName("matched") // create variables for output ctx.currentVars = null - ctx.INPUT_ROW = row + ctx.INPUT_ROW = matched val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => BoundReference(i, a.dataType, a.nullable).gen(ctx) } @@ -172,7 +175,7 @@ case class BroadcastHashJoin( case BuildRight => input ++ buildColumns } - val ouputCode = if (condition.isDefined) { + val outputCode = if (condition.isDefined) { // filter the output via condition ctx.currentVars = resultVars val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) @@ -186,20 +189,39 @@ case class BroadcastHashJoin( consume(ctx, resultVars) } - s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashRelation - | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm); - | if ($matches != null) { - | int $size = $matches.size(); - | for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $row = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} - | $ouputCode - | } - | } + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashedRelation + | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); + | if ($matched != null) { + | ${buildColumns.map(_.code).mkString("\n")} + | $outputCode + | } """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = ${anyNull} ? null : + | ($bufferType) $relationTerm.get(${keyVal.value}); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $outputCode + | } + | } + """.stripMargin + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index f48fc3b84864d..ad3275696e637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -116,12 +116,7 @@ case class BroadcastHashOuterJoin( val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value val keyGenerator = streamedKeyGenerator - - hashTable match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) - case _ => - } + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) val resultProj = resultProjection joinType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 8929dc3af1912..d0e18dfcf3d90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -64,11 +64,7 @@ case class BroadcastLeftSemiJoinHash( left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value - hashedRelation match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize) - case _ => - } + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8ef854001f4de..ecbb1ac64b7c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric - +import org.apache.spark.sql.types.{IntegralType, LongType} trait HashJoin { self: SparkPlan => @@ -47,11 +47,49 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output + /** + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. + */ + def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + var keyExpr: Expression = null + var width = 0 + keys.foreach { e => + e.dataType match { + case dt: IntegralType if dt.defaultSize <= 8 - width => + if (width == 0) { + if (e.dataType != LongType) { + keyExpr = Cast(e, LongType) + } else { + keyExpr = e + } + width = dt.defaultSize + } else { + val bits = dt.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + width -= bits + } + // TODO: support BooleanType, DateType and TimestampType + case other => + return keys + } + } + keyExpr :: Nil + } + + protected val canJoinKeyFitWithinLong: Boolean = { + val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) + val key = rewriteKeyExpr(buildKeys) + sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] + } + protected def buildSideKeyGenerator: Projection = - UnsafeProjection.create(buildKeys, buildPlan.output) + UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) protected def streamSideKeyGenerator: Projection = - UnsafeProjection.create(streamedKeys, streamedPlan.output) + UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ee7a1bdc343c0..c94d6c195b1d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -39,8 +39,23 @@ import org.apache.spark.util.collection.CompactBuffer * object. */ private[execution] sealed trait HashedRelation { + /** + * Returns matched rows. + */ def get(key: InternalRow): Seq[InternalRow] + /** + * Returns matched rows for a key that has only one column with LongType. + */ + def get(key: Long): Seq[InternalRow] = { + throw new UnsupportedOperationException + } + + /** + * Returns the size of used memory. + */ + def getMemorySize: Long = 1L // to make the test happy + // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { @@ -58,11 +73,48 @@ private[execution] sealed trait HashedRelation { } } +/** + * Interface for a hashed relation that have only one row per key. + * + * We should call getValue() for better performance. + */ +private[execution] trait UniqueHashedRelation extends HashedRelation { + + /** + * Returns the matched single row. + */ + def getValue(key: InternalRow): InternalRow + + /** + * Returns the matched single row with key that have only one column of LongType. + */ + def getValue(key: Long): InternalRow = { + throw new UnsupportedOperationException + } + + override def get(key: InternalRow): Seq[InternalRow] = { + val row = getValue(key) + if (row != null) { + CompactBuffer[InternalRow](row) + } else { + null + } + } + + override def get(key: Long): Seq[InternalRow] = { + val row = getValue(key) + if (row != null) { + CompactBuffer[InternalRow](row) + } else { + null + } + } +} /** * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ -private[joins] final class GeneralHashedRelation( +private[joins] class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { @@ -85,19 +137,14 @@ private[joins] final class GeneralHashedRelation( * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] -final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) - extends HashedRelation with Externalizable { +private[joins] class UniqueKeyHashedRelation( + private var hashTable: JavaHashMap[InternalRow, InternalRow]) + extends UniqueHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) def this() = this(null) - override def get(key: InternalRow): Seq[InternalRow] = { - val v = hashTable.get(key) - if (v eq null) null else CompactBuffer(v) - } - - def getValue(key: InternalRow): InternalRow = hashTable.get(key) + override def getValue(key: InternalRow): InternalRow = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -108,8 +155,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR } } -// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. - private[execution] object HashedRelation { @@ -208,7 +253,7 @@ private[joins] final class UnsafeHashedRelation( * * For non-broadcast joins or in local mode, return 0. */ - def getUnsafeSize: Long = { + override def getMemorySize: Long = { if (binaryMap != null) { binaryMap.getTotalMemoryConsumption } else { @@ -408,6 +453,232 @@ private[joins] object UnsafeHashedRelation { } } + // TODO: create UniqueUnsafeRelation new UnsafeHashedRelation(hashTable) } } + +/** + * An interface for a hashed relation that the key is a Long. + */ +private[joins] trait LongHashedRelation extends HashedRelation { + override def get(key: InternalRow): Seq[InternalRow] = { + get(key.getLong(0)) + } +} + +private[joins] final class GeneralLongHashedRelation( + private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) + extends LongHashedRelation with Externalizable { + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) + + override def get(key: Long): Seq[InternalRow] = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] final class UniqueLongHashedRelation( + private var hashTable: JavaHashMap[Long, UnsafeRow]) + extends UniqueHashedRelation with LongHashedRelation with Externalizable { + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) + + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } + + override def getValue(key: Long): InternalRow = { + hashTable.get(key) + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +/** + * A relation that pack all the rows into a byte array, together with offsets and sizes. + * + * All the bytes of UnsafeRow are packed together as `bytes`: + * + * [ Row0 ][ Row1 ][] ... [ RowN ] + * + * With keys: + * + * start start+1 ... start+N + * + * `offsets` are offsets of UnsafeRows in the `bytes` + * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. + * + * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: + * + * start = 3 + * offsets = [0, 0, 24] + * sizes = [24, 0, 32] + * bytes = [0 - 24][][24 - 56] + */ +private[joins] final class LongArrayRelation( + private var numFields: Int, + private var start: Long, + private var offsets: Array[Int], + private var sizes: Array[Int], + private var bytes: Array[Byte] + ) extends UniqueHashedRelation with LongHashedRelation with Externalizable { + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(0, 0L, null, null, null) + + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } + + override def getMemorySize: Long = { + offsets.length * 4 + sizes.length * 4 + bytes.length + } + + override def getValue(key: Long): InternalRow = { + val idx = (key - start).toInt + if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { + val result = new UnsafeRow(numFields) + result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) + result + } else { + null + } + } + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(numFields) + out.writeLong(start) + out.writeInt(sizes.length) + var i = 0 + while (i < sizes.length) { + out.writeInt(sizes(i)) + i += 1 + } + out.writeInt(bytes.length) + out.write(bytes) + } + + override def readExternal(in: ObjectInput): Unit = { + numFields = in.readInt() + start = in.readLong() + val length = in.readInt() + // read sizes of rows + sizes = new Array[Int](length) + offsets = new Array[Int](length) + var i = 0 + var offset = 0 + while (i < length) { + offsets(i) = offset + sizes(i) = in.readInt() + offset += sizes(i) + i += 1 + } + // read all the bytes + val total = in.readInt() + assert(total == offset) + bytes = new Array[Byte](total) + in.readFully(bytes) + } +} + +/** + * Create hashed relation with key that is long. + */ +private[joins] object LongHashedRelation { + + val DENSE_FACTOR = 0.2 + + def apply( + input: Iterator[InternalRow], + numInputRows: LongSQLMetric, + keyGenerator: Projection, + sizeEstimate: Int): HashedRelation = { + + // Use a Java hash table here because unsafe maps expect fixed size records + val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) + + // Create a mapping of key -> rows + var numFields = 0 + var keyIsUnique = true + var minKey = Long.MaxValue + var maxKey = Long.MinValue + while (input.hasNext) { + val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numFields = unsafeRow.numFields() + numInputRows += 1 + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val key = rowKey.getLong(0) + minKey = math.min(minKey, key) + maxKey = math.max(maxKey, key) + val existingMatchList = hashTable.get(key) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(key, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + if (keyIsUnique) { + if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { + // The keys are dense enough, so use LongArrayRelation + val length = (maxKey - minKey).toInt + 1 + val sizes = new Array[Int](length) + val offsets = new Array[Int](length) + var offset = 0 + var i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + offsets(i) = offset + sizes(i) = rows(0).getSizeInBytes + offset += sizes(i) + } + i += 1 + } + val bytes = new Array[Byte](offset) + i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) + } + i += 1 + } + new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) + + } else { + // all the keys are unique, one row per key. + val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueLongHashedRelation(uniqHashTable) + } + } else { + new GeneralLongHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 33d4976403d9a..f015d297048a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -22,6 +22,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.IntegerType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap @@ -122,10 +123,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("broadcast hash join") { - val N = 20 << 20 + val N = 100 << 20 val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) - runBenchmark("BroadcastHashJoin", N) { + runBenchmark("Join w long", N) { sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() } @@ -133,9 +134,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X - BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X + Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X + Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X + */ + + val dim2 = broadcast(sqlContext.range(1 << 16) + .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 ints", N) { + sqlContext.range(N).join(dim2, + (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1") + && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X + Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X */ + } ignore("hash and BytesToBytesMap") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index e5fd9e277fc61..f985dfbd8ade9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer - class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself @@ -134,4 +133,32 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { out2.flush() assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } + + test("LongArrayRelation") { + val unsafeProj = UnsafeProjection.create( + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) + val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) + val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) + val longRelation = LongHashedRelation(rows.iterator, SQLMetrics.nullLongMetric, keyProj, 100) + assert(longRelation.isInstanceOf[LongArrayRelation]) + val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] + (0 until 100).foreach { i => + val row = longArrayRelation.getValue(i) + assert(row.getInt(0) === i) + assert(row.getInt(1) === i + 1) + } + + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + longArrayRelation.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val relation = new LongArrayRelation() + relation.readExternal(in) + (0 until 100).foreach { i => + val row = longArrayRelation.getValue(i) + assert(row.getInt(0) === i) + assert(row.getInt(1) === i + 1) + } + } } From eeaf45b92695c577279f3a17d8c80ee40425e9aa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 8 Feb 2016 17:23:33 -0800 Subject: [PATCH 1830/2089] [SPARK-10620][SPARK-13054] Minor addendum to #10835 Additional changes to #10835, mainly related to style and visibility. This patch also adds back a few deprecated methods for backward compatibility. Author: Andrew Or Closes #10958 from andrewor14/task-metrics-to-accums-followups. --- .../scala/org/apache/spark/Accumulator.scala | 11 +++---- .../apache/spark/InternalAccumulator.scala | 4 +-- .../org/apache/spark/TaskContextImpl.scala | 2 +- .../org/apache/spark/TaskEndReason.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 16 +++++----- .../apache/spark/executor/TaskMetrics.scala | 20 +++++++++++-- .../apache/spark/scheduler/ResultTask.scala | 2 +- .../spark/scheduler/SparkListener.scala | 1 + .../org/apache/spark/ui/jobs/StagePage.scala | 6 ++-- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 2 +- .../spark/InternalAccumulatorSuite.scala | 6 ++-- .../spark/executor/TaskMetricsSuite.scala | 30 +++++++++---------- .../spark/scheduler/TaskContextSuite.scala | 2 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 4 +-- project/MimaExcludes.scala | 3 +- 17 files changed, 66 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 558bd447e22c5..5e8f1d4a705c3 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -60,19 +60,20 @@ import org.apache.spark.storage.{BlockId, BlockStatus} * @tparam T result type */ class Accumulator[T] private[spark] ( - @transient private[spark] val initialValue: T, + // SI-8813: This must explicitly be a private val, or else scala 2.11 doesn't compile + @transient private val initialValue: T, param: AccumulatorParam[T], name: Option[String], internal: Boolean, - override val countFailedValues: Boolean = false) + private[spark] override val countFailedValues: Boolean = false) extends Accumulable[T, T](initialValue, param, name, internal, countFailedValues) { def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { - this(initialValue, param, name, false) + this(initialValue, param, name, false /* internal */) } def this(initialValue: T, param: AccumulatorParam[T]) = { - this(initialValue, param, None, false) + this(initialValue, param, None, false /* internal */) } } @@ -84,7 +85,7 @@ private[spark] object Accumulators extends Logging { * This global map holds the original accumulator objects that are created on the driver. * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. - * TODO: Don't use a global map; these should be tied to a SparkContext at the very least. + * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051). */ @GuardedBy("Accumulators") val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index c191122c0630a..7aa9057858a04 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -119,7 +119,7 @@ private[spark] object InternalAccumulator { /** * Accumulators for tracking internal metrics. */ - def create(): Seq[Accumulator[_]] = { + def createAll(): Seq[Accumulator[_]] = { Seq[String]( EXECUTOR_DESERIALIZE_TIME, EXECUTOR_RUN_TIME, @@ -188,7 +188,7 @@ private[spark] object InternalAccumulator { * values across all tasks within each stage. */ def create(sc: SparkContext): Seq[Accumulator[_]] = { - val accums = create() + val accums = createAll() accums.foreach { accum => Accumulators.register(accum) sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum)) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 27ca46f73d8ca..1d228b6b86c55 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -32,7 +32,7 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, - initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.create()) + initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll()) extends TaskContext with Logging { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 68340cc704dae..c8f201ea9e4d5 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -118,7 +118,7 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - exceptionWrapper: Option[ThrowableSerializationWrapper], + private val exceptionWrapper: Option[ThrowableSerializationWrapper], accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]) extends TaskFailedReason { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 51c000ea5c574..00be3a240dbac 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -300,15 +300,15 @@ private[spark] class Executor( // Collect latest accumulator values to report back to the driver val accumulatorUpdates: Seq[AccumulableInfo] = - if (task != null) { - task.metrics.foreach { m => - m.setExecutorRunTime(System.currentTimeMillis() - taskStart) - m.setJvmGCTime(computeTotalGcTime() - startGCTime) + if (task != null) { + task.metrics.foreach { m => + m.setExecutorRunTime(System.currentTimeMillis() - taskStart) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) + } + task.collectAccumulatorUpdates(taskFailed = true) + } else { + Seq.empty[AccumulableInfo] } - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty[AccumulableInfo] - } val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0a6ebcb3e0293..8ff0620f837c9 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -45,13 +45,12 @@ import org.apache.spark.storage.{BlockId, BlockStatus} * these requirements. */ @DeveloperApi -class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { - +class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Serializable { import InternalAccumulator._ // Needed for Java tests def this() { - this(InternalAccumulator.create()) + this(InternalAccumulator.createAll()) } /** @@ -144,6 +143,11 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None } + @deprecated("setting updated blocks is not allowed", "2.0.0") + def updatedBlocks_=(blocks: Option[Seq[(BlockId, BlockStatus)]]): Unit = { + blocks.foreach(setUpdatedBlockStatuses) + } + // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) @@ -220,6 +224,11 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { */ def outputMetrics: Option[OutputMetrics] = _outputMetrics + @deprecated("setting OutputMetrics is for internal use only", "2.0.0") + def outputMetrics_=(om: Option[OutputMetrics]): Unit = { + _outputMetrics = om + } + /** * Get or create a new [[OutputMetrics]] associated with this task. */ @@ -296,6 +305,11 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable { */ def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics + @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0") + def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = { + _shuffleWriteMetrics = swm + } + /** * Get or create a new [[ShuffleWriteMetrics]] associated with this task. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 885f70e89fbf5..cd2736e1960c2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -49,7 +49,7 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, - _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.create()) + _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll()) extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums) with Serializable { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 7b09c2eded0be..0a45ef5283326 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -61,6 +61,7 @@ case class SparkListenerTaskEnd( taskType: String, reason: TaskEndReason, taskInfo: TaskInfo, + // may be null if the task has failed @Nullable taskMetrics: TaskMetrics) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 29c5ff0b5cf0b..0b68b88566b70 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -408,9 +408,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) - val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.peakExecutionMemory.toDouble - } + val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.peakExecutionMemory.toDouble + } val peakExecutionMemoryQuantiles = { Utils.getFormattedClassName(metricsUpdate)) ~ ("Executor ID" -> execId) ~ - ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => + ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ ("Stage ID" -> stageId) ~ ("Stage Attempt ID" -> stageAttemptId) ~ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index e0fdd45973858..4d49fe5159850 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -268,7 +268,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false) val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false) val externalAccums = Seq(acc1, acc2) - val internalAccums = InternalAccumulator.create() + val internalAccums = InternalAccumulator.createAll() // Set some values; these should not be observed later on the "executors" acc1.setValue(10) acc2.setValue(20L) diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 44a16e26f4935..c426bb7a4e809 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -87,7 +87,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } test("create") { - val accums = create() + val accums = createAll() val shuffleReadAccums = createShuffleReadAccums() val shuffleWriteAccums = createShuffleWriteAccums() val inputAccums = createInputAccums() @@ -123,7 +123,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } test("naming") { - val accums = create() + val accums = createAll() val shuffleReadAccums = createShuffleReadAccums() val shuffleWriteAccums = createShuffleWriteAccums() val inputAccums = createInputAccums() @@ -291,7 +291,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } assert(Accumulators.originals.isEmpty) sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count() - val internalAccums = InternalAccumulator.create() + val internalAccums = InternalAccumulator.createAll() // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage assert(Accumulators.originals.size === internalAccums.size * 2) val accumsRegistered = sc.cleaner match { diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 67c4595ed1923..3a1a67cdc001a 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -31,7 +31,7 @@ class TaskMetricsSuite extends SparkFunSuite { import TaskMetricsSuite._ test("create") { - val internalAccums = InternalAccumulator.create() + val internalAccums = InternalAccumulator.createAll() val tm1 = new TaskMetrics val tm2 = new TaskMetrics(internalAccums) assert(tm1.accumulatorUpdates().size === internalAccums.size) @@ -51,7 +51,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("create with unnamed accum") { intercept[IllegalArgumentException] { new TaskMetrics( - InternalAccumulator.create() ++ Seq( + InternalAccumulator.createAll() ++ Seq( new Accumulator(0, IntAccumulatorParam, None, internal = true))) } } @@ -59,7 +59,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("create with duplicate name accum") { intercept[IllegalArgumentException] { new TaskMetrics( - InternalAccumulator.create() ++ Seq( + InternalAccumulator.createAll() ++ Seq( new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true))) } } @@ -67,7 +67,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("create with external accum") { intercept[IllegalArgumentException] { new TaskMetrics( - InternalAccumulator.create() ++ Seq( + InternalAccumulator.createAll() ++ Seq( new Accumulator(0, IntAccumulatorParam, Some("x")))) } } @@ -131,7 +131,7 @@ class TaskMetricsSuite extends SparkFunSuite { } test("mutating values") { - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val tm = new TaskMetrics(accums) // initial values assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L) @@ -180,7 +180,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("mutating shuffle read metrics values") { import shuffleRead._ - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val tm = new TaskMetrics(accums) def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = { assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value) @@ -234,7 +234,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("mutating shuffle write metrics values") { import shuffleWrite._ - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val tm = new TaskMetrics(accums) def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = { assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value) @@ -267,7 +267,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("mutating input metrics values") { import input._ - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val tm = new TaskMetrics(accums) def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = { assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value, @@ -296,7 +296,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("mutating output metrics values") { import output._ - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val tm = new TaskMetrics(accums) def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = { assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value, @@ -381,7 +381,7 @@ class TaskMetricsSuite extends SparkFunSuite { } test("additional accumulables") { - val internalAccums = InternalAccumulator.create() + val internalAccums = InternalAccumulator.createAll() val tm = new TaskMetrics(internalAccums) assert(tm.accumulatorUpdates().size === internalAccums.size) val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a")) @@ -419,7 +419,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("existing values in shuffle read accums") { // set shuffle read accum before passing it into TaskMetrics - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME)) assert(srAccum.isDefined) srAccum.get.asInstanceOf[Accumulator[Long]] += 10L @@ -432,7 +432,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("existing values in shuffle write accums") { // set shuffle write accum before passing it into TaskMetrics - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN)) assert(swAccum.isDefined) swAccum.get.asInstanceOf[Accumulator[Long]] += 10L @@ -445,7 +445,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("existing values in input accums") { // set input accum before passing it into TaskMetrics - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val inAccum = accums.find(_.name === Some(input.RECORDS_READ)) assert(inAccum.isDefined) inAccum.get.asInstanceOf[Accumulator[Long]] += 10L @@ -458,7 +458,7 @@ class TaskMetricsSuite extends SparkFunSuite { test("existing values in output accums") { // set output accum before passing it into TaskMetrics - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN)) assert(outAccum.isDefined) outAccum.get.asInstanceOf[Accumulator[Long]] += 10L @@ -470,7 +470,7 @@ class TaskMetricsSuite extends SparkFunSuite { } test("from accumulator updates") { - val accumUpdates1 = InternalAccumulator.create().map { a => + val accumUpdates1 = InternalAccumulator.createAll().map { a => AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues) } val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index b3bb86db10a32..850e470ca14d6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -127,7 +127,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val param = AccumulatorParam.LongAccumulatorParam val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) - val initialAccums = InternalAccumulator.create() + val initialAccums = InternalAccumulator.createAll() // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 18a16a25bfac5..9876bded33a08 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -269,7 +269,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val accums = InternalAccumulator.create() + val accums = InternalAccumulator.createAll() accums.foreach(Accumulators.register) val taskMetrics = new TaskMetrics(accums) val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 48951c3168032..de6f408fa82be 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -508,7 +508,7 @@ private[spark] object JsonProtocolSuite extends Assertions { /** -------------------------------- * | Util methods for comparing events | - * --------------------------------- */ + * --------------------------------- */ private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { (event1, event2) match { @@ -773,7 +773,7 @@ private[spark] object JsonProtocolSuite extends Assertions { /** ----------------------------------- * | Util methods for constructing events | - * ------------------------------------ */ + * ------------------------------------ */ private val properties = { val p = new Properties diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8b1a7303fc5b2..9209094385395 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -187,7 +187,8 @@ object MimaExcludes { ) ++ Seq( // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulator.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulator.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") ) ++ Seq( // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), From 3708d13f1a282a9ebf12e3b736f1aa1712cbacd5 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 8 Feb 2016 22:21:26 -0800 Subject: [PATCH 1831/2089] [SPARK-12992] [SQL] Support vectorized decoding in UnsafeRowParquetRecordReader. WIP: running tests. Code needs a bit of clean up. This patch completes the vectorized decoding with the goal of passing the existing tests. There is still more patches to support the rest of the format spec, even just for flat schemas. This patch adds a new flag to enable the vectorized decoding. Tests were updated to try with both modes where applicable. Once this is working well, we can remove the previous code path. Author: Nong Li Closes #11055 from nongli/spark-12992-2. --- .../parquet/UnsafeRowParquetRecordReader.java | 174 +++++++++++++++-- .../parquet/VectorizedPlainValuesReader.java | 59 +++++- .../parquet/VectorizedRleValuesReader.java | 180 +++++++++++++++++- .../parquet/VectorizedValuesReader.java | 13 +- .../vectorized/ColumnVectorUtils.java | 4 + .../execution/vectorized/ColumnarBatch.java | 39 +++- .../vectorized/OffHeapColumnVector.java | 4 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 8 + .../datasources/SqlNewHadoopRDD.scala | 3 + .../parquet/CatalystSchemaConverter.scala | 3 +- .../datasources/parquet/ParquetIOSuite.scala | 86 ++++++--- .../parquet/ParquetQuerySuite.scala | 6 +- .../parquet/ParquetReadBenchmark.scala | 33 +++- .../datasources/parquet/ParquetTest.scala | 22 ++- .../spark/sql/hive/HiveParquetSuite.scala | 3 +- 16 files changed, 549 insertions(+), 90 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index b5dddb9f11b22..4576ac2a3222f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -37,6 +37,7 @@ import org.apache.parquet.schema.Type; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; @@ -44,7 +45,7 @@ import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.unsafe.types.UTF8String; import static org.apache.parquet.column.ValuesType.*; @@ -57,7 +58,7 @@ * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. * All of these can be handled efficiently and easily with codegen. */ -public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { /** * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this * batch is used up (batchIdx == numBatched), we populated the batch. @@ -110,6 +111,9 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas * and currently unsupported cases will fail with potentially difficult to diagnose errors. * This should be only turned on for development to work on this feature. * + * When this is set, the code will branch early on in the RecordReader APIs. There is no shared + * code between the path that uses the MR decoders and the vectorized ones. + * * TODOs: * - Implement all the encodings to support vectorized. * - Implement v2 page formats (just make sure we create the correct decoders). @@ -166,15 +170,23 @@ public void close() throws IOException { @Override public boolean nextKeyValue() throws IOException, InterruptedException { if (batchIdx >= numBatched) { - if (!loadBatch()) return false; + if (vectorizedDecode()) { + if (!nextBatch()) return false; + } else { + if (!loadBatch()) return false; + } } ++batchIdx; return true; } @Override - public UnsafeRow getCurrentValue() throws IOException, InterruptedException { - return rows[batchIdx - 1]; + public InternalRow getCurrentValue() throws IOException, InterruptedException { + if (vectorizedDecode()) { + return columnarBatch.getRow(batchIdx - 1); + } else { + return rows[batchIdx - 1]; + } } @Override @@ -202,20 +214,27 @@ public ColumnarBatch resultBatch(MemoryMode memMode) { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - assert(columnarBatch != null); + assert(vectorizedDecode()); columnarBatch.reset(); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned); + int num = (int)Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { columnReaders[i].readBatch(num, columnarBatch.column(i)); } rowsReturned += num; columnarBatch.setNumRows(num); + numBatched = num; + batchIdx = 0; return true; } + /** + * Returns true if we are doing a vectorized decode. + */ + private boolean vectorizedDecode() { return columnarBatch != null; } + private void initializeInternal() throws IOException { /** * Check that the requested schema is supported. @@ -613,15 +632,27 @@ private void readBatch(int total, ColumnVector column) throws IOException { decodeDictionaryIds(rowId, num, column); } else { switch (descriptor.getType()) { + case BOOLEAN: + readBooleanBatch(rowId, num, column); + break; case INT32: readIntBatch(rowId, num, column); break; case INT64: readLongBatch(rowId, num, column); break; + case FLOAT: + readFloatBatch(rowId, num, column); + break; + case DOUBLE: + readDoubleBatch(rowId, num, column); + break; case BINARY: readBinaryBatch(rowId, num, column); break; + case FIXED_LEN_BYTE_ARRAY: + readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength()); + break; default: throw new IOException("Unsupported type: " + descriptor.getType()); } @@ -645,7 +676,15 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { } } else if (column.dataType() == DataTypes.ByteType) { for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i))); + column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ShortType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); } } else { throw new NotImplementedException("Unimplemented type: " + column.dataType()); @@ -653,8 +692,36 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { break; case INT64: + if (column.dataType() == DataTypes.LongType || + DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case FLOAT: + for (int i = rowId; i < rowId + num; ++i) { + column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + } + break; + + case DOUBLE: for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + } + break; + + case FIXED_LEN_BYTE_ARRAY: + if (DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else { + throw new NotImplementedException(); } break; @@ -691,15 +758,24 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { * is guaranteed that num is smaller than the number of values left in the current page. */ + private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException { + assert(column.dataType() == DataTypes.BooleanType); + defColumn.readBooleans( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } + private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (column.dataType() == DataTypes.IntegerType) { + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (DecimalType.is64BitDecimalType(column.dataType())) { + defColumn.readIntsAsLongs( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { throw new NotImplementedException("Unimplemented type: " + column.dataType()); } @@ -707,10 +783,32 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. - // TODO: implement remaining type conversions - if (column.dataType() == DataTypes.LongType) { + if (column.dataType() == DataTypes.LongType || + DecimalType.is64BitDecimalType(column.dataType())) { defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + } + } + + private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: support implicit cast to double? + if (column.dataType() == DataTypes.FloatType) { + defColumn.readFloats( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + } + } + + private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.DoubleType) { + defColumn.readDoubles( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { throw new NotImplementedException("Unimplemented type: " + column.dataType()); } @@ -727,6 +825,24 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE } } + private void readFixedLenByteArrayBatch(int rowId, int num, + ColumnVector column, int arrayLen) throws IOException { + VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, + CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } private void readPage() throws IOException { DataPage page = pageReader.readPage(); @@ -763,7 +879,11 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset)thro "could not read page in col " + descriptor + " as the dictionary was missing for encoding " + dataEncoding); } - if (columnarBatch != null && dataEncoding == Encoding.PLAIN_DICTIONARY) { + if (vectorizedDecode()) { + if (dataEncoding != Encoding.PLAIN_DICTIONARY && + dataEncoding != Encoding.RLE_DICTIONARY) { + throw new NotImplementedException("Unsupported encoding: " + dataEncoding); + } this.dataColumn = new VectorizedRleValuesReader(); } else { this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( @@ -771,8 +891,11 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset)thro } this.useDictionary = true; } else { - if (columnarBatch != null && dataEncoding == Encoding.PLAIN) { - this.dataColumn = new VectorizedPlainValuesReader(4); + if (vectorizedDecode()) { + if (dataEncoding != Encoding.PLAIN) { + throw new NotImplementedException("Unsupported encoding: " + dataEncoding); + } + this.dataColumn = new VectorizedPlainValuesReader(); } else { this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); } @@ -791,10 +914,12 @@ private void readPageV1(DataPageV1 page) throws IOException { ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); ValuesReader dlReader; - // Initialize the decoders. Use custom ones if vectorized decoding is enabled. - if (columnarBatch != null && page.getDlEncoding() == Encoding.RLE) { + // Initialize the decoders. + if (vectorizedDecode()) { + if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { + throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding()); + } int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - assert(bitWidth != 0); // not implemented this.defColumn = new VectorizedRleValuesReader(bitWidth); dlReader = this.defColumn; } else { @@ -818,8 +943,17 @@ private void readPageV2(DataPageV2 page) throws IOException { this.pageValueCount = page.getValueCount(); this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), page.getRepetitionLevels(), descriptor); - this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), - page.getDefinitionLevels(), descriptor); + + if (vectorizedDecode()) { + int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); + this.defColumn = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); + this.defColumn.initFromBuffer( + this.pageValueCount, page.getDefinitionLevels().toByteArray()); + } else { + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + } try { initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); } catch (IOException e) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index cec2418e46030..bf3283e85329b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -32,10 +32,9 @@ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { private byte[] buffer; private int offset; - private final int byteSize; + private int bitOffset; // Only used for booleans. - public VectorizedPlainValuesReader(int byteSize) { - this.byteSize = byteSize; + public VectorizedPlainValuesReader() { } @Override @@ -46,12 +45,15 @@ public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOExce @Override public void skip() { - offset += byteSize; + throw new UnsupportedOperationException(); } @Override - public void skip(int n) { - offset += n * byteSize; + public final void readBooleans(int total, ColumnVector c, int rowId) { + // TODO: properly vectorize this + for (int i = 0; i < total; i++) { + c.putBoolean(rowId + i, readBoolean()); + } } @Override @@ -66,6 +68,18 @@ public final void readLongs(int total, ColumnVector c, int rowId) { offset += 8 * total; } + @Override + public final void readFloats(int total, ColumnVector c, int rowId) { + c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 4 * total; + } + + @Override + public final void readDoubles(int total, ColumnVector c, int rowId) { + c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + @Override public final void readBytes(int total, ColumnVector c, int rowId) { for (int i = 0; i < total; i++) { @@ -76,6 +90,18 @@ public final void readBytes(int total, ColumnVector c, int rowId) { } } + @Override + public final boolean readBoolean() { + byte b = Platform.getByte(buffer, offset); + boolean v = (b & (1 << bitOffset)) != 0; + bitOffset += 1; + if (bitOffset == 8) { + bitOffset = 0; + offset++; + } + return v; + } + @Override public final int readInteger() { int v = Platform.getInt(buffer, offset); @@ -95,6 +121,20 @@ public final byte readByte() { return (byte)readInteger(); } + @Override + public final float readFloat() { + float v = Platform.getFloat(buffer, offset); + offset += 4; + return v; + } + + @Override + public final double readDouble() { + double v = Platform.getDouble(buffer, offset); + offset += 8; + return v; + } + @Override public final void readBinary(int total, ColumnVector v, int rowId) { for (int i = 0; i < total; i++) { @@ -104,4 +144,11 @@ public final void readBinary(int total, ColumnVector v, int rowId) { v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); } } + + @Override + public final Binary readBinary(int len) { + Binary result = Binary.fromByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); + offset += len; + return result; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 9bfd74db38766..629959a73baf3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -87,13 +87,38 @@ public void initFromPage(int valueCount, byte[] page, int start) { this.offset = start; this.in = page; if (fixedWidth) { - int length = readIntLittleEndian(); - this.end = this.offset + length; + if (bitWidth != 0) { + int length = readIntLittleEndian(); + this.end = this.offset + length; + } } else { this.end = page.length; if (this.end != this.offset) init(page[this.offset++] & 255); } - this.currentCount = 0; + if (bitWidth == 0) { + // 0 bit width, treat this as an RLE run of valueCount number of 0's. + this.mode = MODE.RLE; + this.currentCount = valueCount; + this.currentValue = 0; + } else { + this.currentCount = 0; + } + } + + // Initialize the reader from a buffer. This is used for the V2 page encoding where the + // definition are in its own buffer. + public void initFromBuffer(int valueCount, byte[] data) { + this.offset = 0; + this.in = data; + this.end = data.length; + if (bitWidth == 0) { + // 0 bit width, treat this as an RLE run of valueCount number of 0's. + this.mode = MODE.RLE; + this.currentCount = valueCount; + this.currentValue = 0; + } else { + this.currentCount = 0; + } } /** @@ -126,7 +151,6 @@ public int readValueDictionaryId() { return readInteger(); } - @Override public int readInteger() { if (this.currentCount == 0) { this.readNextGroup(); } @@ -189,6 +213,72 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } // TODO: can this code duplication be removed without a perf penalty? + public void readBooleans(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBooleans(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putBoolean(rowId + i, data.readBoolean()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readIntsAsLongs(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + for (int i = 0; i < n; i++) { + c.putLong(rowId + i, data.readInteger()); + } + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, data.readInteger()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + public void readBytes(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; @@ -253,6 +343,70 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, } } + public void readFloats(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readFloats(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putFloat(rowId + i, data.readFloat()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readDoubles(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readDoubles(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putDouble(rowId + i, data.readDouble()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + public void readBinarys(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; @@ -272,7 +426,7 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putNotNull(rowId + i); - data.readBinary(1, c, rowId); + data.readBinary(1, c, rowId + i); } else { c.putNull(rowId + i); } @@ -331,10 +485,24 @@ public void readBinary(int total, ColumnVector c, int rowId) { } @Override - public void skip(int n) { + public void readBooleans(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readFloats(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readDoubles(int total, ColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } + @Override + public Binary readBinary(int len) { + throw new UnsupportedOperationException("only readInts is valid."); + } /** * Reads the next varint encoded int. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index b6ec7311c564a..88418ca53fe1e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -19,24 +19,29 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.parquet.io.api.Binary; + /** * Interface for value decoding that supports vectorized (aka batched) decoding. * TODO: merge this into parquet-mr. */ public interface VectorizedValuesReader { + boolean readBoolean(); byte readByte(); int readInteger(); long readLong(); + float readFloat(); + double readDouble(); + Binary readBinary(int len); /* * Reads `total` values into `c` start at `c[rowId]` */ + void readBooleans(int total, ColumnVector c, int rowId); void readBytes(int total, ColumnVector c, int rowId); void readIntegers(int total, ColumnVector c, int rowId); void readLongs(int total, ColumnVector c, int rowId); + void readFloats(int total, ColumnVector c, int rowId); + void readDoubles(int total, ColumnVector c, int rowId); void readBinary(int total, ColumnVector c, int rowId); - - // TODO: add all the other parquet types. - - void skip(int n); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 453bc15e13503..2aeef7f2f90fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -18,11 +18,13 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.sql.Date; import java.util.Iterator; import java.util.List; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -100,6 +102,8 @@ private static void appendValue(ColumnVector dst, DataType t, Object o) { dst.appendStruct(false); dst.getChildColumn(0).appendInt(c.months); dst.getChildColumn(1).appendLong(c.microseconds); + } else if (t instanceof DateType) { + dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { throw new NotImplementedException("Type " + t); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index dbad5e070f1fe..070d897a7158c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -23,11 +23,11 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -62,6 +62,9 @@ public final class ColumnarBatch { // Total number of rows that have been filtered. private int numRowsFiltered = 0; + // Staging row returned from getRow. + final Row row; + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) { return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode); } @@ -123,24 +126,36 @@ public final void markFiltered() { @Override /** - * Revisit this. This is expensive. + * Revisit this. This is expensive. This is currently only used in test paths. */ public final InternalRow copy() { - UnsafeRow row = new UnsafeRow(numFields()); - row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize); + GenericMutableRow row = new GenericMutableRow(columns.length); for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); } else { DataType dt = columns[i].dataType(); - if (dt instanceof IntegerType) { + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof IntegerType) { row.setInt(i, getInt(i)); } else if (dt instanceof LongType) { row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); } else if (dt instanceof DoubleType) { row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i)); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); } else { - throw new RuntimeException("Not implemented."); + throw new RuntimeException("Not implemented. " + dt); } } } @@ -315,6 +330,16 @@ public int numValidRows() { */ public ColumnVector column(int ordinal) { return columns[ordinal]; } + /** + * Returns the row in this batch at `rowId`. Returned row is reused across calls. + */ + public ColumnarBatch.Row getRow(int rowId) { + assert(rowId >= 0); + assert(rowId < numRows); + row.rowId = rowId; + return row; + } + /** * Marks this row as being filtered out. This means a subsequent iteration over the rows * in this batch will not include this row. @@ -335,5 +360,7 @@ private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { StructField field = schema.fields()[i]; columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode); } + + this.row = new Row(this); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 7a224d19d15b7..c15f3d34a4929 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.types.BooleanType; import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.DoubleType; import org.apache.spark.sql.types.FloatType; @@ -391,7 +392,8 @@ private final void reserveInternal(int newCapacity) { this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); } else if (type instanceof ShortType) { this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); - } else if (type instanceof IntegerType || type instanceof FloatType) { + } else if (type instanceof IntegerType || type instanceof FloatType || + type instanceof DateType) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index c42bbd642ecae..99548bc83beaa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -376,7 +376,7 @@ private final void reserveInternal(int newCapacity) { short[] newData = new short[newCapacity]; if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); shortData = newData; - } else if (type instanceof IntegerType) { + } else if (type instanceof IntegerType || type instanceof DateType) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index eb9da0bd4fd4c..61a7b9935af0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -345,6 +345,14 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "Enables using the custom ParquetUnsafeRowRecordReader.") + // Note: this can not be enabled all the time because the reader will not be returning UnsafeRows. + // Doing so is very expensive and we should remove this requirement instead of fixing it here. + // Initial testing seems to indicate only sort requires this. + val PARQUET_VECTORIZED_READER_ENABLED = booleanConf( + key = "spark.sql.parquet.enableVectorizedReader", + defaultValue = Some(false), + doc = "Enables vectorized parquet decoding.") + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "When true, enable filter pushdown for ORC files.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 3605150b3b767..25911334a674f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -99,6 +99,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // a subset of the types (no complex types). protected val enableUnsafeRowParquetReader: Boolean = sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean + protected val enableVectorizedParquetReader: Boolean = + sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) @@ -176,6 +178,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( parquetReader.close() } else { reader = parquetReader.asInstanceOf[RecordReader[Void, V]] + if (enableVectorizedParquetReader) parquetReader.resultBatch() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index fb97a03df60f4..1c0d53fc77854 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -65,7 +65,8 @@ private[parquet] class CatalystSchemaConverter( def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, - writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean) + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index ab48e971b507a..bd87449f920bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -114,8 +114,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val path = new Path(location.getCanonicalPath) val conf = sparkContext.hadoopConfiguration writeMetadata(parquetSchema, path, conf) - val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) - assert(sparkTypes === expectedSparkTypes) + readParquetFile(path.toString)(df => { + val sparkTypes = df.schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + }) } } @@ -142,7 +144,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + readParquetFile(dir.getCanonicalPath){ df => { + checkAnswer(df, data.collect().toSeq) + }} } } } @@ -158,7 +162,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val data = makeDateRDD() data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + readParquetFile(dir.getCanonicalPath) { df => + checkAnswer(df, data.collect().toSeq) + } } } @@ -335,9 +341,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) + readParquetFile(path.toString) { df => + checkAnswer(df, (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) + } } } @@ -363,7 +370,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) - checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) + readParquetFile(file) { df => + checkAnswer(df, newData.map(Row.fromTuple)) + } } } @@ -372,7 +381,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) - checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) + readParquetFile(file) { df => + checkAnswer(df, data.map(Row.fromTuple)) + } } } @@ -392,7 +403,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) - checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) + readParquetFile(file) { df => + checkAnswer(df, (data ++ newData).map(Row.fromTuple)) + } } } @@ -420,11 +433,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val conf = sparkContext.hadoopConfiguration writeMetadata(parquetSchema, path, conf, extraMetadata) - assertResult(sqlContext.read.parquet(path.toString).schema) { - StructType( - StructField("a", BooleanType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: - Nil) + readParquetFile(path.toString) { df => + assertResult(df.schema) { + StructType( + StructField("a", BooleanType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: + Nil) + } } } } @@ -594,30 +609,43 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val path = s"${dir.getCanonicalPath}/data" df.write.parquet(path) - val df2 = sqlContext.read.parquet(path) - assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + readParquetFile(path) { df2 => + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } } } test("read dictionary encoded decimals written as INT32") { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i32.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i32.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + } + } } test("read dictionary encoded decimals written as INT64") { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i64.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i64.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + } + } } test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-fixed-len.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-fixed-len.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + } + } } test("SPARK-12589 copy() on rows returned from reader works for strings") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 0bc64404f1648..b123d2b31efcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -45,7 +45,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => (i, i.toString)) sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withParquetTable(data, "t") { + // Query appends, don't test with both read modes. + withParquetTable(data, "t", false) { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } @@ -69,7 +70,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext (maybeInt, i.toString) } - withParquetTable(data, "t") { + // TODO: vectorized doesn't work here because it requires UnsafeRows + withParquetTable(data, "t", false) { val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") val queryOutput = selfJoin.queryExecution.analyzed.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 14be9eec9a97a..e8893073e305a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -81,6 +81,12 @@ object ParquetReadBenchmark { } } + sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + sqlContext.sql("select sum(id) from tempTable").collect() + } + } + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray // Driving the parquet reader directly without Spark. parquetReaderBenchmark.addCase("ParquetReader") { num => @@ -143,10 +149,11 @@ object ParquetReadBenchmark { /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + SQL Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- - SQL Parquet Reader 1682.6 15.58 1.00 X - SQL Parquet MR 2379.6 11.02 0.71 X + SQL Parquet Reader 1350.56 11.65 1.00 X + SQL Parquet MR 1844.09 8.53 0.73 X + SQL Parquet Vectorized 1062.04 14.81 1.27 X */ sqlBenchmark.run() @@ -185,6 +192,13 @@ object ParquetReadBenchmark { } } + benchmark.addCase("SQL Parquet Vectorized") { iter => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray benchmark.addCase("ParquetReader") { num => var sum1 = 0L @@ -202,12 +216,13 @@ object ParquetReadBenchmark { } /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - SQL Parquet Reader 2245.6 7.00 1.00 X - SQL Parquet MR 2914.2 5.40 0.77 X - ParquetReader 1544.6 10.18 1.45 X + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + SQL Parquet Reader 1737.94 6.03 1.00 X + SQL Parquet MR 2393.08 4.38 0.73 X + SQL Parquet Vectorized 1442.99 7.27 1.20 X + ParquetReader 1032.11 10.16 1.68 X */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 449fcc860fac9..5cbcccbd862dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -43,6 +43,20 @@ import org.apache.spark.sql.types.StructType */ private[sql] trait ParquetTest extends SQLTestUtils { + /** + * Reads the parquet file at `path` + */ + protected def readParquetFile(path: String, testVectorized: Boolean = true) + (f: DataFrame => Unit) = { + (true :: false :: Nil).foreach { vectorized => + if (!vectorized || testVectorized) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + f(sqlContext.read.parquet(path.toString)) + } + } + } + } + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. @@ -61,9 +75,9 @@ private[sql] trait ParquetTest extends SQLTestUtils { * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) + (data: Seq[T], testVectorized: Boolean = true) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + withParquetFile(data)(path => readParquetFile(path.toString, testVectorized)(f)) } /** @@ -72,9 +86,9 @@ private[sql] trait ParquetTest extends SQLTestUtils { * Parquet file will be dropped/deleted after `f` returns. */ protected def withParquetTable[T <: Product: ClassTag: TypeTag] - (data: Seq[T], tableName: String) + (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { - withParquetDataFrame(data) { df => + withParquetDataFrame(data, testVectorized) { df => sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 7841ffe5e03d1..b5af758a65b1c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -61,7 +61,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton } test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + // Don't run with vectorized: currently relies on UnsafeRow. + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t", false) { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") From f9307d8fc5223b4c5be07e3dc691a327f3bbfa7f Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 9 Feb 2016 08:43:46 +0000 Subject: [PATCH 1832/2089] [SPARK-13176][CORE] Use native file linking instead of external process ln Since Spark requires at least JRE 1.7, it is safe to use built-in java.nio.Files. Author: Jakob Odersky Closes #11098 from jodersky/SPARK-13176. --- .../scala/org/apache/spark/util/Utils.scala | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9ecbffbf715c5..e0c9bf02a1a20 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,6 +22,7 @@ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels +import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -34,7 +35,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} -import com.google.common.io.{ByteStreams, Files} +import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration @@ -516,7 +517,7 @@ private[spark] object Utils extends Logging { // The file does not exist in the target directory. Copy or move it there. if (removeSourceFile) { - Files.move(sourceFile, destFile) + Files.move(sourceFile.toPath, destFile.toPath) } else { logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") copyRecursive(sourceFile, destFile) @@ -534,7 +535,7 @@ private[spark] object Utils extends Logging { case (f1, f2) => filesEqualRecursive(f1, f2) } } else if (file1.isFile && file2.isFile) { - Files.equal(file1, file2) + GFiles.equal(file1, file2) } else { false } @@ -548,7 +549,7 @@ private[spark] object Utils extends Logging { val subfiles = source.listFiles() subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) } else { - Files.copy(source, dest) + Files.copy(source.toPath, dest.toPath) } } @@ -1596,30 +1597,18 @@ private[spark] object Utils extends Logging { } /** - * Creates a symlink. Note jdk1.7 has Files.createSymbolicLink but not used here - * for jdk1.6 support. Supports windows by doing copy, everything else uses "ln -sf". + * Creates a symlink. * @param src absolute path to the source * @param dst relative path for the destination */ - def symlink(src: File, dst: File) { + def symlink(src: File, dst: File): Unit = { if (!src.isAbsolute()) { throw new IOException("Source must be absolute") } if (dst.isAbsolute()) { throw new IOException("Destination must be relative") } - var cmdSuffix = "" - val linkCmd = if (isWindows) { - // refer to http://technet.microsoft.com/en-us/library/cc771254.aspx - cmdSuffix = " /s /e /k /h /y /i" - "cmd /c xcopy " - } else { - cmdSuffix = "" - "ln -sf " - } - import scala.sys.process._ - (linkCmd + src.getAbsolutePath() + " " + dst.getPath() + cmdSuffix) lines_! - ProcessLogger(line => logInfo(line)) + Files.createSymbolicLink(dst.toPath, src.toPath) } From 159198eff67ee9ead08fba60a585494ea1575147 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 9 Feb 2016 08:44:56 +0000 Subject: [PATCH 1833/2089] [SPARK-13165][STREAMING] Replace deprecated synchronizedBuffer in streaming Building with Scala 2.11 results in the warning trait SynchronizedBuffer in package mutable is deprecated: Synchronization via traits is deprecated as it is inherently unreliable. Consider java.util.concurrent.ConcurrentLinkedQueue as an alternative - we already use ConcurrentLinkedQueue elsewhere so lets replace it. Some notes about how behaviour is different for reviewers: The Seq from a SynchronizedBuffer that was implicitly converted would continue to receive updates - however when we do the same conversion explicitly on the ConcurrentLinkedQueue this isn't the case. Hence changing some of the (internal & test) APIs to pass an Iterable. toSeq is safe to use if there are no more updates. Author: Holden Karau Author: tedyu Closes #11067 from holdenk/SPARK-13165-replace-deprecated-synchronizedBuffer-in-streaming. --- .../spark/streaming/TestOutputStream.scala | 5 +- .../flume/FlumePollingStreamSuite.scala | 13 ++- .../streaming/flume/FlumeStreamSuite.scala | 16 ++-- .../kafka/DirectKafkaStreamSuite.scala | 43 +++++---- .../receiver/ReceiverSupervisorImpl.scala | 17 ++-- .../apache/spark/streaming/ui/BatchPage.scala | 2 +- .../spark/streaming/ui/BatchUIData.scala | 2 +- .../ui/StreamingJobProgressListener.scala | 21 +++-- .../spark/streaming/JavaTestUtils.scala | 4 +- .../streaming/BasicOperationsSuite.scala | 24 ++--- .../spark/streaming/CheckpointSuite.scala | 22 ++--- .../spark/streaming/InputStreamsSuite.scala | 72 +++++++-------- .../spark/streaming/MapWithStateSuite.scala | 10 +-- .../spark/streaming/MasterFailureTest.scala | 9 +- .../streaming/StreamingListenerSuite.scala | 89 +++++++++---------- .../spark/streaming/TestSuiteBase.scala | 36 ++++---- .../receiver/BlockGeneratorSuite.scala | 32 ++++--- .../StreamingJobProgressListenerSuite.scala | 2 +- .../streaming/util/RecurringTimerSuite.scala | 22 ++--- 19 files changed, 226 insertions(+), 215 deletions(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala index 57374ef515431..fc02c9fcb50ae 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -33,10 +34,10 @@ import org.apache.spark.util.Utils * The buffer contains a sequence of RDD's, each containing a sequence of items */ class TestOutputStream[T: ClassTag](parent: DStream[T], - val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) + val output: ConcurrentLinkedQueue[Seq[T]] = new ConcurrentLinkedQueue[Seq[T]]()) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() - output += collected + output.add(collected) }, false) { // This is to clear the output buffer every it is read from a checkpoint diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 60db846ffb7a2..10dcbf98bc3b6 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps @@ -102,9 +102,8 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, utils.eventsPerBatch, 5) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputQueue) outputStream.register() ssc.start() @@ -115,11 +114,11 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log // The eventually is required to ensure that all data in the batch has been processed. eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenOutputBuffer = outputBuffer.flatten - val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + val flattenOutput = outputQueue.asScala.toSeq.flatten + val headers = flattenOutput.map(_.event.getHeaders.asScala.map { case (key, value) => (key.toString, value.toString) }).map(_.asJava) - val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody)) + val bodies = flattenOutput.map(e => JavaUtils.bytesToString(e.event.getBody)) utils.assertOutput(headers.asJava, bodies.asJava) } } finally { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index b29e591c07374..38208c651805f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.flume +import java.util.concurrent.ConcurrentLinkedQueue + import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps @@ -51,14 +52,14 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w val input = (1 to 100).map { _.toString } val utils = new FlumeTestUtils try { - val outputBuffer = startContext(utils.getTestPort(), testCompression) + val outputQueue = startContext(utils.getTestPort(), testCompression) eventually(timeout(10 seconds), interval(100 milliseconds)) { utils.writeInput(input.asJava, testCompression) } eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } + val outputEvents = outputQueue.asScala.toSeq.flatten.map { _.event } outputEvents.foreach { event => event.getHeaders.get("test") should be("header") @@ -76,16 +77,15 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Setup and start the streaming context */ private def startContext( - testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = { + testPort: Int, testCompression: Boolean): (ConcurrentLinkedQueue[Seq[SparkFlumeEvent]]) = { ssc = new StreamingContext(conf, Milliseconds(200)) val flumeStream = FlumeUtils.createStream( ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputQueue) outputStream.register() ssc.start() - outputBuffer + outputQueue } /** Class to create socket channel with compression */ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 655b161734d94..8398178e9b79b 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.streaming.kafka import java.io.File +import java.util.Arrays import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -101,8 +102,7 @@ class DirectKafkaStreamSuite ssc, kafkaParams, topics) } - val allReceived = - new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + val allReceived = new ConcurrentLinkedQueue[(String, String)]() // hold a reference to the current offset ranges, so it can be used downstream var offsetRanges = Array[OffsetRange]() @@ -131,11 +131,12 @@ class DirectKafkaStreamSuite assert(partSize === rangeSize, "offset ranges are wrong") } } - stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + stream.foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { assert(allReceived.size === totalSent, - "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) } ssc.stop() } @@ -173,8 +174,8 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] - stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } + val collectedData = new ConcurrentLinkedQueue[String]() + stream.map { _._2 }.foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() val newData = Map("b" -> 10) kafkaTestUtils.sendMessages(topic, newData) @@ -219,8 +220,8 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] - stream.foreachRDD { rdd => collectedData ++= rdd.collect() } + val collectedData = new ConcurrentLinkedQueue[String]() + stream.foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() val newData = Map("b" -> 10) kafkaTestUtils.sendMessages(topic, newData) @@ -265,7 +266,7 @@ class DirectKafkaStreamSuite // This is to collect the raw data received from Kafka kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => val data = rdd.map { _._2 }.collect() - DirectKafkaStreamSuite.collectedData.appendAll(data) + DirectKafkaStreamSuite.collectedData.addAll(Arrays.asList(data: _*)) } // This is ensure all the data is eventually receiving only once @@ -335,14 +336,14 @@ class DirectKafkaStreamSuite ssc, kafkaParams, Set(topic)) } - val allReceived = - new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + val allReceived = new ConcurrentLinkedQueue[(String, String)] - stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + stream.foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { assert(allReceived.size === totalSent, - "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) // Calculate all the record number collected in the StreamingListener. assert(collector.numRecordsSubmitted.get() === totalSent) @@ -389,17 +390,16 @@ class DirectKafkaStreamSuite } } - val collectedData = - new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + val collectedData = new ConcurrentLinkedQueue[Array[String]]() // Used for assertion failure messages. def dataToString: String = - collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + collectedData.asScala.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") // This is to collect the raw data received from Kafka kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => val data = rdd.map { _._2 }.collect() - collectedData += data + collectedData.add(data) } ssc.start() @@ -415,7 +415,7 @@ class DirectKafkaStreamSuite eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { // Assert that rate estimator values are used to determine maxMessagesPerPartition. // Funky "-" in message makes the complete assertion message read better. - assert(collectedData.exists(_.size == expectedSize), + assert(collectedData.asScala.exists(_.size == expectedSize), s" - No arrays of size $expectedSize for rate $rate found in $dataToString") } } @@ -433,7 +433,7 @@ class DirectKafkaStreamSuite } object DirectKafkaStreamSuite { - val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + val collectedData = new ConcurrentLinkedQueue[String]() @volatile var total = -1L class InputInfoCollector extends StreamingListener { @@ -468,4 +468,3 @@ private[streaming] class ConstantEstimator(@volatile private var rate: Long) processingDelay: Long, schedulingDelay: Long): Option[Double] = Some(rate) } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index b774b6b9a55d1..8d4e6827d6a28 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -19,8 +19,9 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -83,7 +84,7 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) case UpdateRateLimit(eps) => logInfo(s"Received a new rate limit: $eps.") - registeredBlockGenerators.foreach { bg => + registeredBlockGenerators.asScala.foreach { bg => bg.updateRate(eps) } } @@ -92,8 +93,7 @@ private[streaming] class ReceiverSupervisorImpl( /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) - private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] - with mutable.SynchronizedBuffer[BlockGenerator] + private val registeredBlockGenerators = new ConcurrentLinkedQueue[BlockGenerator]() /** Divides received data records into data blocks for pushing in BlockManager. */ private val defaultBlockGeneratorListener = new BlockGeneratorListener { @@ -170,11 +170,11 @@ private[streaming] class ReceiverSupervisorImpl( } override protected def onStart() { - registeredBlockGenerators.foreach { _.start() } + registeredBlockGenerators.asScala.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - registeredBlockGenerators.foreach { _.stop() } + registeredBlockGenerators.asScala.foreach { _.stop() } env.rpcEnv.stop(endpoint) } @@ -194,10 +194,11 @@ private[streaming] class ReceiverSupervisorImpl( override def createBlockGenerator( blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { // Cleanup BlockGenerators that have already been stopped - registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + val stoppedGenerators = registeredBlockGenerators.asScala.filter{ _.isStopped() } + stoppedGenerators.foreach(registeredBlockGenerators.remove(_)) val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) - registeredBlockGenerators += newBlockGenerator + registeredBlockGenerators.add(newBlockGenerator) newBlockGenerator } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 81de07f933f8a..e235afad5e8f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -273,7 +273,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId - (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) + (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).toSeq.sorted) } val outputOps: Seq[(OutputOperationUIData, Seq[SparkJobId])] = diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index 3ef3689de1c45..1af60857bc770 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -33,7 +33,7 @@ private[ui] case class BatchUIData( val processingStartTime: Option[Long], val processingEndTime: Option[Long], val outputOperations: mutable.HashMap[OutputOpId, OutputOperationUIData] = mutable.HashMap(), - var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { + var outputOpIdSparkJobIdPairs: Iterable[OutputOpIdAndSparkJobId] = Seq.empty) { /** * Time taken for the first job of this batch to start processing from the time this batch diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index cacd430cf339c..30a3a98c0183c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -18,8 +18,10 @@ package org.apache.spark.streaming.ui import java.util.{LinkedHashMap, Map => JMap, Properties} +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.{ArrayBuffer, HashMap, Queue, SynchronizedBuffer} +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap, Queue} import org.apache.spark.scheduler._ import org.apache.spark.streaming.{StreamingContext, Time} @@ -41,9 +43,9 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) // we may not be able to get the corresponding BatchUIData when receiving onJobStart. So here we // cannot use a map of (Time, BatchUIData). private[ui] val batchTimeToOutputOpIdSparkJobIdPair = - new LinkedHashMap[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]] { + new LinkedHashMap[Time, ConcurrentLinkedQueue[OutputOpIdAndSparkJobId]] { override def removeEldestEntry( - p1: JMap.Entry[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]]): Boolean = { + p1: JMap.Entry[Time, ConcurrentLinkedQueue[OutputOpIdAndSparkJobId]]): Boolean = { // If a lot of "onBatchCompleted"s happen before "onJobStart" (image if // SparkContext.listenerBus is very slow), "batchTimeToOutputOpIdToSparkJobIds" // may add some information for a removed batch when processing "onJobStart". It will be a @@ -131,12 +133,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) if (outputOpIdToSparkJobIds == null) { - outputOpIdToSparkJobIds = - new ArrayBuffer[OutputOpIdAndSparkJobId]() - with SynchronizedBuffer[OutputOpIdAndSparkJobId] + outputOpIdToSparkJobIds = new ConcurrentLinkedQueue[OutputOpIdAndSparkJobId]() batchTimeToOutputOpIdSparkJobIdPair.put(batchTime, outputOpIdToSparkJobIds) } - outputOpIdToSparkJobIds += OutputOpIdAndSparkJobId(outputOpId, jobStart.jobId) + outputOpIdToSparkJobIds.add(OutputOpIdAndSparkJobId(outputOpId, jobStart.jobId)) } } @@ -256,8 +256,11 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } batchUIData.foreach { _batchUIData => - val outputOpIdToSparkJobIds = - Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime)).getOrElse(Seq.empty) + // We use an Iterable rather than explicitly converting to a seq so that updates + // will propegate + val outputOpIdToSparkJobIds: Iterable[OutputOpIdAndSparkJobId] = + Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime).asScala) + .getOrElse(Seq.empty) _batchUIData.outputOpIdSparkJobIdPairs = outputOpIdToSparkJobIds } batchUIData diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index 57b50bdfd6520..ae44fd07ac558 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -69,7 +69,7 @@ trait JavaTestBase extends TestSuiteBase { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - res.map(_.asJava).asJava + res.map(_.asJava).toSeq.asJava } /** @@ -85,7 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - res.map(entry => entry.map(_.asJava).asJava).asJava + res.map(entry => entry.map(_.asJava).asJava).toSeq.asJava } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 25e7ae8262a5f..f1c64799c6f7d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag @@ -84,9 +86,10 @@ class BasicOperationsSuite extends TestSuiteBase { withStreamingContext(setupStreams(input, operation, 2)) { ssc => val output = runStreamsWithPartitions(ssc, 3, 3) assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) + val outputArray = output.toArray + val first = outputArray(0) + val second = outputArray(1) + val third = outputArray(2) assert(first.size === 5) assert(second.size === 5) @@ -104,9 +107,10 @@ class BasicOperationsSuite extends TestSuiteBase { withStreamingContext(setupStreams(input, operation, 5)) { ssc => val output = runStreamsWithPartitions(ssc, 3, 3) assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) + val outputArray = output.toArray + val first = outputArray(0) + val second = outputArray(1) + val third = outputArray(2) assert(first.size === 2) assert(second.size === 2) @@ -645,8 +649,8 @@ class BasicOperationsSuite extends TestSuiteBase { val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val mappedStream = networkStream.map(_ + ".").persist() - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(mappedStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(mappedStream, outputQueue) outputStream.register() ssc.start() @@ -685,7 +689,7 @@ class BasicOperationsSuite extends TestSuiteBase { testServer.stop() // verify data has been received - assert(outputBuffer.size > 0) + assert(!outputQueue.isEmpty) assert(blockRdds.size > 0) assert(persistentRddIds.size > 0) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 786703eb9a84e..1f0245a397570 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.streaming import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream} +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.base.Charsets @@ -105,7 +106,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => val operatedStream = operation(inputStream) operatedStream.print() val outputStream = new TestOutputStreamWithPartitions(operatedStream, - new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + new ConcurrentLinkedQueue[Seq[Seq[V]]]) outputStream.register() ssc.checkpoint(checkpointDir) @@ -166,7 +167,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => // are written to make sure that both of them have been written. assert(checkpointFilesOfLatestTime.size === 2) } - outputStream.output.map(_.flatten) + outputStream.output.asScala.map(_.flatten).toSeq } finally { ssc.stop(stopSparkContext = stopSparkContext) @@ -591,7 +592,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester // Set up the streaming context and input streams val batchDuration = Seconds(2) // Due to 1-second resolution of setLastModified() on some OS's. val testDir = Utils.createTempDir() - val outputBuffer = new ArrayBuffer[Seq[Int]] with SynchronizedBuffer[Seq[Int]] + val outputBuffer = new ConcurrentLinkedQueue[Seq[Int]] /** * Writes a file named `i` (which contains the number `i`) to the test directory and sets its @@ -671,7 +672,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester ssc.stop() // Check that we shut down while the third batch was being processed assert(batchCounter.getNumCompletedBatches === 2) - assert(outputStream.output.flatten === Seq(1, 3)) + assert(outputStream.output.asScala.toSeq.flatten === Seq(1, 3)) } // The original StreamingContext has now been stopped. @@ -721,7 +722,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) + logInfo("Output after restart = " + outputStream.output.asScala.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() @@ -730,11 +731,11 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester assert(recordedFiles(ssc) === (1 to 9)) // Append the new output to the old buffer - outputBuffer ++= outputStream.output + outputBuffer.addAll(outputStream.output) // Verify whether all the elements received are as expected val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) - assert(outputBuffer.flatten.toSet === expectedOutput.toSet) + assert(outputBuffer.asScala.flatten.toSet === expectedOutput.toSet) } } finally { Utils.deleteRecursively(testDir) @@ -894,7 +895,8 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): + Iterable[Seq[V]] = { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) @@ -908,7 +910,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] - outputStream.output.map(_.flatten) + outputStream.output.asScala.map(_.flatten) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 75591f04ca00d..93c883362cfbe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.streaming import java.io.{BufferedWriter, File, OutputStreamWriter} import java.net.{ServerSocket, Socket, SocketException} import java.nio.charset.Charset -import java.util.concurrent.{ArrayBlockingQueue, CountDownLatch, Executors, TimeUnit} +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer, SynchronizedQueue} +import scala.collection.JavaConverters._ +import scala.collection.mutable.SynchronizedQueue import scala.language.postfixOps import com.google.common.io.Files @@ -58,8 +59,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchCounter = new BatchCounter(ssc) val networkStream = ssc.socketTextStream( "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputQueue) outputStream.register() ssc.start() @@ -90,9 +91,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -100,7 +101,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) - val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) + val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray assert(output.size === expectedOutput.size) for (i <- 0 until output.size) { assert(output(i) === expectedOutput(i)) @@ -119,8 +120,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchCounter = new BatchCounter(ssc) val networkStream = ssc.socketTextStream( "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputQueue) outputStream.register() ssc.start() @@ -156,9 +157,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { clock.setTime(existingFile.lastModified + batchDuration.milliseconds) val batchCounter = new BatchCounter(ssc) val fileStream = ssc.binaryRecordsStream(testDir.toString, 1) - val outputBuffer = new ArrayBuffer[Seq[Array[Byte]]] - with SynchronizedBuffer[Seq[Array[Byte]]] - val outputStream = new TestOutputStream(fileStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[Array[Byte]]] + val outputStream = new TestOutputStream(fileStream, outputQueue) outputStream.register() ssc.start() @@ -183,8 +183,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } val expectedOutput = input.map(i => i.toByte) - val obtainedOutput = outputBuffer.flatten.toList.map(i => i(0).toByte) - assert(obtainedOutput === expectedOutput) + val obtainedOutput = outputQueue.asScala.flatten.toList.map(i => i(0).toByte) + assert(obtainedOutput.toSeq === expectedOutput) } } finally { if (testDir != null) Utils.deleteRecursively(testDir) @@ -206,15 +206,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val numTotalRecords = numThreads * numRecordsPerThread val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false - val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] - def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) + val outputQueue = new ConcurrentLinkedQueue[Seq[Long]] + def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x) // set up the network stream using the test receiver withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => val networkStream = ssc.receiverStream[Int](testReceiver) val countStream = networkStream.count - val outputStream = new TestOutputStream(countStream, outputBuffer) + val outputStream = new TestOutputStream(countStream, outputQueue) outputStream.register() ssc.start() @@ -231,9 +231,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") assert(output.sum === numTotalRecords) } @@ -241,14 +241,14 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("queue input stream - oneAtATime = true") { val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.size > 0) // Set up the streaming context and input streams withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => val queue = new SynchronizedQueue[RDD[String]]() val queueStream = ssc.queueStream(queue, oneAtATime = true) - val outputStream = new TestOutputStream(queueStream, outputBuffer) + val outputStream = new TestOutputStream(queueStream, outputQueue) outputStream.register() ssc.start() @@ -266,9 +266,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -276,14 +276,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } + output.zipWithIndex.foreach{case (e, i) => assert(e == expectedOutput(i))} } test("queue input stream - oneAtATime = false") { - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.size > 0) val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) @@ -291,7 +289,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => val queue = new SynchronizedQueue[RDD[String]]() val queueStream = ssc.queueStream(queue, oneAtATime = false) - val outputStream = new TestOutputStream(queueStream, outputBuffer) + val outputStream = new TestOutputStream(queueStream, outputQueue) outputStream.register() ssc.start() @@ -312,9 +310,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -322,9 +320,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } + output.zipWithIndex.foreach{case (e, i) => assert(e == expectedOutput(i))} } test("test track the number of input stream") { @@ -373,8 +369,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchCounter = new BatchCounter(ssc) val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( testDir.toString, (x: Path) => true, newFilesOnly = newFilesOnly).map(_._2.toString) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(fileStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputQueue) outputStream.register() ssc.start() @@ -404,7 +400,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } else { (Seq(0) ++ input).map(_.toString).toSet } - assert(outputBuffer.flatten.toSet === expectedOutput) + assert(outputQueue.asScala.flatten.toSet === expectedOutput) } } finally { if (testDir != null) Utils.deleteRecursively(testDir) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 2984fd2b298dc..b6d6585bd8244 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.streaming import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -550,9 +551,9 @@ class MapWithStateSuite extends SparkFunSuite val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec) - val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val collectedOutputs = new ConcurrentLinkedQueue[Seq[T]] val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) - val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] + val collectedStateSnapshots = new ConcurrentLinkedQueue[Seq[(K, S)]] val stateSnapshotStream = new TestOutputStream( trackeStateStream.stateSnapshots(), collectedStateSnapshots) outputStream.register() @@ -567,7 +568,7 @@ class MapWithStateSuite extends SparkFunSuite batchCounter.waitUntilBatchesCompleted(numBatches, 10000) ssc.stop(stopSparkContext = false) - (collectedOutputs, collectedStateSnapshots) + (collectedOutputs.asScala.toSeq, collectedStateSnapshots.asScala.toSeq) } private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { @@ -583,4 +584,3 @@ class MapWithStateSuite extends SparkFunSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 7bbbdebd9b19f..a02d49eced1d5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -21,6 +21,7 @@ import java.io.{File, IOException} import java.nio.charset.Charset import java.util.UUID +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.Random @@ -215,8 +216,8 @@ object MasterFailureTest extends Logging { while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer - val outputBuffer = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[T]].output - def output = outputBuffer.flatMap(x => x) + val outputQueue = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[T]].output + def output = outputQueue.asScala.flatten // Start the thread to kill the streaming after some time killed = false @@ -257,9 +258,9 @@ object MasterFailureTest extends Logging { // Verify whether the output of each batch has only one element or no element // and then merge the new output with all the earlier output - mergedOutput ++= output + mergedOutput ++= output.toSeq totalTimeRan += timeRan - logInfo("New output = " + output) + logInfo("New output = " + output.toSeq) logInfo("Merged output = " + mergedOutput) logInfo("Time ran = " + timeRan) logInfo("Total time ran = " + totalTimeRan) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 1ed68c74db9fd..66f47394c7ac1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming -import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -62,43 +65,43 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val batchInfosSubmitted = collector.batchInfosSubmitted batchInfosSubmitted should have size 4 - batchInfosSubmitted.foreach(info => { + batchInfosSubmitted.asScala.foreach(info => { info.schedulingDelay should be (None) info.processingDelay should be (None) info.totalDelay should be (None) }) - batchInfosSubmitted.foreach { info => + batchInfosSubmitted.asScala.foreach { info => info.numRecords should be (1L) info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } - isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosSubmitted.asScala.map(_.submissionTime)) should be (true) // SPARK-6766: processingStartTime of batch info should not be None when starting val batchInfosStarted = collector.batchInfosStarted batchInfosStarted should have size 4 - batchInfosStarted.foreach(info => { + batchInfosStarted.asScala.foreach(info => { info.schedulingDelay should not be None info.schedulingDelay.get should be >= 0L info.processingDelay should be (None) info.totalDelay should be (None) }) - batchInfosStarted.foreach { info => + batchInfosStarted.asScala.foreach { info => info.numRecords should be (1L) info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } - isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosStarted.asScala.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.asScala.map(_.processingStartTime.get)) should be (true) // test onBatchCompleted val batchInfosCompleted = collector.batchInfosCompleted batchInfosCompleted should have size 4 - batchInfosCompleted.foreach(info => { + batchInfosCompleted.asScala.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -107,14 +110,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - batchInfosCompleted.foreach { info => + batchInfosCompleted.asScala.foreach { info => info.numRecords should be (1L) info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } - isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.asScala.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.asScala.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.asScala.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { @@ -129,13 +132,13 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { try { eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) - collector.startedReceiverStreamIds(0) should equal (0) - collector.stoppedReceiverStreamIds should have size 1 - collector.stoppedReceiverStreamIds(0) should equal (0) + collector.startedReceiverStreamIds.peek() should equal (0) + collector.stoppedReceiverStreamIds.size should equal (1) + collector.stoppedReceiverStreamIds.peek() should equal (0) collector.receiverErrors should have size 1 - collector.receiverErrors(0)._1 should equal (0) - collector.receiverErrors(0)._2 should include ("report error") - collector.receiverErrors(0)._3 should include ("report exception") + collector.receiverErrors.peek()._1 should equal (0) + collector.receiverErrors.peek()._2 should include ("report error") + collector.receiverErrors.peek()._3 should include ("report exception") } } finally { ssc.stop() @@ -155,8 +158,8 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { eventually(timeout(30 seconds), interval(20 millis)) { - collector.startedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) - collector.completedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) + collector.startedOutputOperationIds.asScala.take(3) should be (Seq(0, 1, 2)) + collector.completedOutputOperationIds.asScala.take(3) should be (Seq(0, 1, 2)) } } finally { ssc.stop() @@ -271,69 +274,63 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } /** Check if a sequence of numbers is in increasing order */ - def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for (i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) { - return false - } - } - true + def isInIncreasingOrder(data: Iterable[Long]): Boolean = { + !data.sliding(2).map{itr => itr.size == 2 && itr.head > itr.tail.head }.contains(true) } } /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfosCompleted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] - val batchInfosStarted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] - val batchInfosSubmitted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] + val batchInfosCompleted = new ConcurrentLinkedQueue[BatchInfo] + val batchInfosStarted = new ConcurrentLinkedQueue[BatchInfo] + val batchInfosSubmitted = new ConcurrentLinkedQueue[BatchInfo] override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { - batchInfosSubmitted += batchSubmitted.batchInfo + batchInfosSubmitted.add(batchSubmitted.batchInfo) } override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { - batchInfosStarted += batchStarted.batchInfo + batchInfosStarted.add(batchStarted.batchInfo) } override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfosCompleted += batchCompleted.batchInfo + batchInfosCompleted.add(batchCompleted.batchInfo) } } /** Listener that collects information on processed batches */ class ReceiverInfoCollector extends StreamingListener { - val startedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] - val stoppedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] - val receiverErrors = - new ArrayBuffer[(Int, String, String)] with SynchronizedBuffer[(Int, String, String)] + val startedReceiverStreamIds = new ConcurrentLinkedQueue[Int] + val stoppedReceiverStreamIds = new ConcurrentLinkedQueue[Int] + val receiverErrors = new ConcurrentLinkedQueue[(Int, String, String)] override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - startedReceiverStreamIds += receiverStarted.receiverInfo.streamId + startedReceiverStreamIds.add(receiverStarted.receiverInfo.streamId) } override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped) { - stoppedReceiverStreamIds += receiverStopped.receiverInfo.streamId + stoppedReceiverStreamIds.add(receiverStopped.receiverInfo.streamId) } override def onReceiverError(receiverError: StreamingListenerReceiverError) { - receiverErrors += ((receiverError.receiverInfo.streamId, - receiverError.receiverInfo.lastErrorMessage, receiverError.receiverInfo.lastError)) + receiverErrors.add(((receiverError.receiverInfo.streamId, + receiverError.receiverInfo.lastErrorMessage, receiverError.receiverInfo.lastError))) } } /** Listener that collects information on processed output operations */ class OutputOperationInfoCollector extends StreamingListener { - val startedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] - val completedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + val startedOutputOperationIds = new ConcurrentLinkedQueue[Int]() + val completedOutputOperationIds = new ConcurrentLinkedQueue[Int]() override def onOutputOperationStarted( outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { - startedOutputOperationIds += outputOperationStarted.outputOperationInfo.id + startedOutputOperationIds.add(outputOperationStarted.outputOperationInfo.id) } override def onOutputOperationCompleted( outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { - completedOutputOperationIds += outputOperationCompleted.outputOperationInfo.id + completedOutputOperationIds.add(outputOperationCompleted.outputOperationInfo.id) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 239b10894ad2c..82cd63bcafc97 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -18,9 +18,9 @@ package org.apache.spark.streaming import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.SynchronizedBuffer +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -87,17 +87,17 @@ class TestInputStream[T: ClassTag](_ssc: StreamingContext, input: Seq[Seq[T]], n /** * This is a output stream just for the testsuites. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * ConcurrentLinkedQueue. This queue is wiped clean on being restored from checkpoint. * - * The buffer contains a sequence of RDD's, each containing a sequence of items + * The buffer contains a sequence of RDD's, each containing a sequence of items. */ class TestOutputStream[T: ClassTag]( parent: DStream[T], - val output: SynchronizedBuffer[Seq[T]] = - new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val output: ConcurrentLinkedQueue[Seq[T]] = + new ConcurrentLinkedQueue[Seq[T]]() ) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() - output += collected + output.add(collected) }, false) { // This is to clear the output buffer every it is read from a checkpoint @@ -110,18 +110,18 @@ class TestOutputStream[T: ClassTag]( /** * This is a output stream just for the testsuites. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * ConcurrentLinkedQueue. This queue is wiped clean on being restored from checkpoint. * - * The buffer contains a sequence of RDD's, each containing a sequence of partitions, each + * The queue contains a sequence of RDD's, each containing a sequence of partitions, each * containing a sequence of items. */ class TestOutputStreamWithPartitions[T: ClassTag]( parent: DStream[T], - val output: SynchronizedBuffer[Seq[Seq[T]]] = - new ArrayBuffer[Seq[Seq[T]]] with SynchronizedBuffer[Seq[Seq[T]]]) + val output: ConcurrentLinkedQueue[Seq[Seq[T]]] = + new ConcurrentLinkedQueue[Seq[Seq[T]]]()) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.glom().collect().map(_.toSeq) - output += collected + output.add(collected) }, false) { // This is to clear the output buffer every it is read from a checkpoint @@ -322,7 +322,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val inputStream = new TestInputStream(ssc, input, numPartitions) val operatedStream = operation(inputStream) val outputStream = new TestOutputStreamWithPartitions(operatedStream, - new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + new ConcurrentLinkedQueue[Seq[Seq[V]]]) outputStream.register() ssc } @@ -347,7 +347,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) val operatedStream = operation(inputStream1, inputStream2) val outputStream = new TestOutputStreamWithPartitions(operatedStream, - new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]]) + new ConcurrentLinkedQueue[Seq[Seq[W]]]) outputStream.register() ssc } @@ -418,7 +418,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { } val timeTaken = System.currentTimeMillis() - startTime logInfo("Output generated in " + timeTaken + " milliseconds") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + output.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") @@ -426,7 +426,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { } finally { ssc.stop(stopSparkContext = true) } - output + output.asScala.toSeq } /** @@ -501,7 +501,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size withStreamingContext(setupStreams[U, V](input, operation)) { ssc => val output = runStreams[V](ssc, numBatches_, expectedOutput.size) - verifyOutput[V](output, expectedOutput, useSet) + verifyOutput[V](output.toSeq, expectedOutput, useSet) } } @@ -540,7 +540,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc => val output = runStreams[W](ssc, numBatches_, expectedOutput.size) - verifyOutput[W](output, expectedOutput, useSet) + verifyOutput[W](output.toSeq, expectedOutput, useSet) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index f5ec0ff60aa27..a1d0561bf308a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.streaming.receiver +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.reflectiveCalls @@ -84,7 +87,7 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { assert(listener.onPushBlockCalled === true) } } - listener.pushedData should contain theSameElementsInOrderAs (data1) + listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs (data1) assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly @@ -92,21 +95,24 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { val metadata2 = data2.map { _.toString } data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } assert(listener.onAddDataCalled === true) - listener.addedData should contain theSameElementsInOrderAs (data2) - listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + listener.addedData.asScala.toSeq should contain theSameElementsInOrderAs (data2) + listener.addedMetadata.asScala.toSeq should contain theSameElementsInOrderAs (metadata2) clock.advance(blockIntervalMs) // advance clock to generate blocks eventually(timeout(1 second)) { - listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + val combined = data1 ++ data2 + listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs combined } // Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly val data3 = 21 to 30 val metadata3 = "metadata" blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) - listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + val combinedMetadata = metadata2 :+ metadata3 + listener.addedMetadata.asScala.toSeq should contain theSameElementsInOrderAs (combinedMetadata) clock.advance(blockIntervalMs) // advance clock to generate blocks eventually(timeout(1 second)) { - listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + val combinedData = data1 ++ data2 ++ data3 + listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs (combinedData) } // Stop the block generator by starting the stop on a different thread and @@ -191,7 +197,7 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { assert(thread.isAlive === false) } assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped - assert(listener.pushedData === data, "All data not pushed by stop()") + assert(listener.pushedData.asScala.toSeq === data, "All data not pushed by stop()") } test("block push errors are reported") { @@ -231,15 +237,15 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { /** A listener for BlockGenerator that records the data in the callbacks */ private class TestBlockGeneratorListener extends BlockGeneratorListener { - val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] - val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] - val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val pushedData = new ConcurrentLinkedQueue[Any] + val addedData = new ConcurrentLinkedQueue[Any] + val addedMetadata = new ConcurrentLinkedQueue[Any] @volatile var onGenerateBlockCalled = false @volatile var onAddDataCalled = false @volatile var onPushBlockCalled = false override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { - pushedData ++= arrayBuffer + pushedData.addAll(arrayBuffer.asJava) onPushBlockCalled = true } override def onError(message: String, throwable: Throwable): Unit = {} @@ -247,8 +253,8 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { onGenerateBlockCalled = true } override def onAddData(data: Any, metadata: Any): Unit = { - addedData += data - addedMetadata += metadata + addedData.add(data) + addedMetadata.add(metadata) onAddDataCalled = true } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 34cd7435569e1..26b757cc2d535 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -200,7 +200,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) - batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) + batchUIData.get.outputOpIdSparkJobIdPairs.toSeq should be (Seq(OutputOpIdAndSparkJobId(0, 0))) // A lot of "onBatchCompleted"s happen before "onJobStart" for(i <- limit + 1 to limit * 2) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala index 0544972d95c03..25b70a3d089ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.streaming.util -import scala.collection.mutable +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import org.scalatest.PrivateMethodTester @@ -30,34 +32,34 @@ class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { test("basic") { val clock = new ManualClock() - val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val results = new ConcurrentLinkedQueue[Long]() val timer = new RecurringTimer(clock, 100, time => { - results += time + results.add(time) }, "RecurringTimerSuite-basic") timer.start(0) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L)) + assert(results.asScala.toSeq === Seq(0L)) } clock.advance(100) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L, 100L)) + assert(results.asScala.toSeq === Seq(0L, 100L)) } clock.advance(200) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L, 100L, 200L, 300L)) + assert(results.asScala.toSeq === Seq(0L, 100L, 200L, 300L)) } assert(timer.stop(interruptTimer = true) === 300L) } test("SPARK-10224: call 'callback' after stopping") { val clock = new ManualClock() - val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val results = new ConcurrentLinkedQueue[Long] val timer = new RecurringTimer(clock, 100, time => { - results += time + results.add(time) }, "RecurringTimerSuite-SPARK-10224") timer.start(0) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L)) + assert(results.asScala.toSeq === Seq(0L)) } @volatile var lastTime = -1L // Now RecurringTimer is waiting for the next interval @@ -77,7 +79,7 @@ class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { // Then it will find `stopped` is true and exit the loop, but it should call `callback` again // before exiting its internal thread. thread.join() - assert(results === Seq(0L, 100L, 200L)) + assert(results.asScala.toSeq === Seq(0L, 100L, 200L)) assert(lastTime === 200L) } } From ce83fe9756582e73ada21c3741d15aa9bbf385ed Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 9 Feb 2016 08:47:28 +0000 Subject: [PATCH 1834/2089] [SPARK-13201][SPARK-13200] Deprecation warning cleanups: KMeans & MFDataGenerator KMeans: Make a private non-deprecated version of setRuns API so that we can call it from the PythonAPI without deprecation warnings in our own build. Also use it internally when being called from train. Add a logWarning for non-1 values MFDataGenerator: Apparently we are calling round on an integer which now in Scala 2.11 results in a warning (it didn't make any sense before either). Figure out if this is a mistake we can just remove or if we got the types wrong somewhere. I put these two together since they are both deprecation fixes in MLlib and pretty small, but I can split them up if we would prefer it that way. Author: Holden Karau Closes #11112 from holdenk/SPARK-13201-non-deprecated-setRuns-SPARK-mathround-integer. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 2 +- .../org/apache/spark/mllib/clustering/KMeans.scala | 13 +++++++++++-- .../apache/spark/mllib/util/MFDataGenerator.scala | 3 +-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 088ec6a0c0465..93cf16e6f0c2a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -357,7 +357,7 @@ private[python] class PythonMLLibAPI extends Serializable { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) - .setRuns(runs) + .internalSetRuns(runs) .setInitializationMode(initializationMode) .setInitializationSteps(initializationSteps) .setEpsilon(epsilon) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 901164a391170..67de62bc2e848 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -119,9 +119,18 @@ class KMeans private ( @Since("0.8.0") @deprecated("Support for runs is deprecated. This param will have no effect in 2.0.0.", "1.6.0") def setRuns(runs: Int): this.type = { + internalSetRuns(runs) + } + + // Internal version of setRuns for Python API, this should be removed at the same time as setRuns + // this is done to avoid deprecation warnings in our build. + private[mllib] def internalSetRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") } + if (runs != 1) { + logWarning("Setting number of runs is deprecated and will have no effect in 2.0.0") + } this.runs = runs this } @@ -502,7 +511,7 @@ object KMeans { seed: Long): KMeansModel = { new KMeans().setK(k) .setMaxIterations(maxIterations) - .setRuns(runs) + .internalSetRuns(runs) .setInitializationMode(initializationMode) .setSeed(seed) .run(data) @@ -528,7 +537,7 @@ object KMeans { initializationMode: String): KMeansModel = { new KMeans().setK(k) .setMaxIterations(maxIterations) - .setRuns(runs) + .internalSetRuns(runs) .setInitializationMode(initializationMode) .run(data) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 8af6750da4ff3..898a09e51636c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -105,8 +105,7 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = math.min( - math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt + val testSampSize = math.min(math.round(sampSize * testSampFact).toInt, mn - sampSize) val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) From c882ec57de509895706dcafea8234238e4277a2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 9 Feb 2016 08:49:34 +0000 Subject: [PATCH 1835/2089] [SPARK-13040][DOCS] Update JDBC deprecated SPARK_CLASSPATH documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update JDBC documentation based on http://stackoverflow.com/a/30947090/219530 as SPARK_CLASSPATH is deprecated. Also, that's how it worked, it didn't work with the SPARK_CLASSPATH or the --jars alone. This would solve issue: https://issues.apache.org/jira/browse/SPARK-13040 Author: Sebastián Ramírez Closes #10948 from tiangolo/patch-docs-jdbc. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 550a40010e828..ce53a39f9f604 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1869,7 +1869,7 @@ spark classpath. For example, to connect to postgres from the Spark Shell you wo following command: {% highlight bash %} -SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell +bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9.4.1207.jar {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using From d9ba4d27f4d324a7055b9b914c75d176f3e2f71d Mon Sep 17 00:00:00 2001 From: sachin aggarwal Date: Tue, 9 Feb 2016 08:52:58 +0000 Subject: [PATCH 1836/2089] [SPARK-13177][EXAMPLES] Update ActorWordCount example to not directly use low level linked list as it is deprecated. Author: sachin aggarwal Closes #11113 from agsachin/master. --- .../apache/spark/examples/streaming/ActorWordCount.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 8e88987439ffc..9f7c7d50e5176 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import scala.collection.mutable.LinkedList +import scala.collection.mutable.LinkedHashSet import scala.reflect.ClassTag import scala.util.Random @@ -39,7 +39,7 @@ case class UnsubscribeReceiver(receiverActor: ActorRef) class FeederActor extends Actor { val rand = new Random() - var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() + val receivers = new LinkedHashSet[ActorRef]() val strings: Array[String] = Array("words ", "may ", "count ") @@ -63,11 +63,11 @@ class FeederActor extends Actor { def receive: Receive = { case SubscribeReceiver(receiverActor: ActorRef) => println("received subscribe from %s".format(receiverActor.toString)) - receivers = LinkedList(receiverActor) ++ receivers + receivers += receiverActor case UnsubscribeReceiver(receiverActor: ActorRef) => println("received unsubscribe from %s".format(receiverActor.toString)) - receivers = receivers.dropWhile(x => x eq receiverActor) + receivers -= receiverActor } } From e30121afac35439be5d42c04da6f047f7d973dd6 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Tue, 9 Feb 2016 09:05:22 +0000 Subject: [PATCH 1837/2089] [SPARK-13086][SHELL] Use the Scala REPL settings, to enable things like `-i file`. Now: ``` $ bin/spark-shell -i test.scala NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of assembly. Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). 16/01/29 17:37:38 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 16/01/29 17:37:39 INFO Main: Created spark context.. Spark context available as sc (master = local[*], app id = local-1454085459000). 16/01/29 17:37:39 INFO Main: Created sql context.. SQL context available as sqlContext. Loading test.scala... hello Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.0.0-SNAPSHOT /_/ Using Scala version 2.11.7 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_45) Type in expressions to have them evaluated. Type :help for more information. ``` Author: Iulian Dragos Closes #10984 from dragos/issue/repl-eval-file. --- .../src/main/scala/org/apache/spark/repl/Main.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 07ba28bb07545..1ae4182947c8f 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -19,7 +19,7 @@ package org.apache.spark.repl import java.io.File -import scala.tools.nsc.Settings +import scala.tools.nsc.GenericRunnerSettings import org.apache.spark.util.Utils import org.apache.spark._ @@ -56,7 +56,7 @@ object Main extends Logging { "-classpath", getAddedJars.mkString(File.pathSeparator) ) ++ args.toList - val settings = new Settings(scalaOptionError) + val settings = new GenericRunnerSettings(scalaOptionError) settings.processArguments(interpArguments, true) if (!hasErrors) { From 68ed3632c56389ab3ff4ea5d73c575f224dab4f6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 9 Feb 2016 11:23:29 +0000 Subject: [PATCH 1838/2089] [SPARK-13170][STREAMING] Investigate replacing SynchronizedQueue as it is deprecated Replace SynchronizeQueue with synchronized access to a Queue Author: Sean Owen Closes #11111 from srowen/SPARK-13170. --- .../examples/streaming/QueueStream.scala | 8 ++-- .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/dstream/QueueInputDStream.scala | 13 ++++--- .../spark/streaming/InputStreamsSuite.scala | 37 ++++++++++++------- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index 13ba9a43ec3c9..5455aed22085d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.streaming -import scala.collection.mutable.SynchronizedQueue +import scala.collection.mutable.Queue import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD @@ -34,7 +34,7 @@ object QueueStream { // Create the queue through which RDDs can be pushed to // a QueueInputDStream - val rddQueue = new SynchronizedQueue[RDD[Int]]() + val rddQueue = new Queue[RDD[Int]]() // Create the QueueInputDStream and use it do some processing val inputStream = ssc.queueStream(rddQueue) @@ -45,7 +45,9 @@ object QueueStream { // Create and push some RDDs into for (i <- 1 to 30) { - rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) + rddQueue.synchronized { + rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) + } Thread.sleep(1000) } ssc.stop() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 32bea88ec6cc0..a1b25c9f7da7f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -459,7 +459,7 @@ class StreamingContext private[streaming] ( * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. * - * @param queue Queue of RDDs + * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD */ @@ -477,7 +477,7 @@ class StreamingContext private[streaming] ( * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. * - * @param queue Queue of RDDs + * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. * Set as null if no RDD should be returned when empty diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index a8d108de6c3e1..f9c78699164ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -48,12 +48,15 @@ class QueueInputDStream[T: ClassTag]( override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue.dequeueAll(_ => true) + queue.synchronized { + if (oneAtATime && queue.nonEmpty) { + buffer += queue.dequeue() + } else { + buffer ++= queue + queue.clear() + } } - if (buffer.size > 0) { + if (buffer.nonEmpty) { if (oneAtATime) { Some(buffer.head) } else { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 93c883362cfbe..fa17b3a15c4b6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -24,7 +24,7 @@ import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ -import scala.collection.mutable.SynchronizedQueue +import scala.collection.mutable import scala.language.postfixOps import com.google.common.io.Files @@ -40,7 +40,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted} import org.apache.spark.util.{ManualClock, Utils} class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -67,7 +66,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Feed data to the server to send to the network receiver val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val expectedOutput = input.map(_.toString) - for (i <- 0 until input.size) { + for (i <- input.indices) { testServer.send(input(i).toString + "\n") Thread.sleep(500) clock.advance(batchDuration.milliseconds) @@ -102,8 +101,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { + assert(output.length === expectedOutput.size) + for (i <- output.indices) { assert(output(i) === expectedOutput(i)) } } @@ -242,11 +241,11 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) val outputQueue = new ConcurrentLinkedQueue[Seq[String]] - def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.size > 0) + def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.nonEmpty) // Set up the streaming context and input streams withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - val queue = new SynchronizedQueue[RDD[String]]() + val queue = new mutable.Queue[RDD[String]]() val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputStream = new TestOutputStream(queueStream, outputQueue) outputStream.register() @@ -256,9 +255,13 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val inputIterator = input.toIterator - for (i <- 0 until input.size) { + for (i <- input.indices) { // Enqueue more than 1 item per tick but they should dequeue one at a time - inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + inputIterator.take(2).foreach { i => + queue.synchronized { + queue += ssc.sparkContext.makeRDD(Seq(i)) + } + } clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) @@ -281,13 +284,13 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("queue input stream - oneAtATime = false") { val outputQueue = new ConcurrentLinkedQueue[Seq[String]] - def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.size > 0) + def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.nonEmpty) val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) // Set up the streaming context and input streams withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - val queue = new SynchronizedQueue[RDD[String]]() + val queue = new mutable.Queue[RDD[String]]() val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputStream = new TestOutputStream(queueStream, outputQueue) outputStream.register() @@ -298,12 +301,20 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Enqueue the first 3 items (one by one), they should be merged in the next batch val inputIterator = input.toIterator - inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + inputIterator.take(3).foreach { i => + queue.synchronized { + queue += ssc.sparkContext.makeRDD(Seq(i)) + } + } clock.advance(batchDuration.milliseconds) Thread.sleep(1000) // Enqueue the remaining items (again one by one), merged in the final batch - inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + inputIterator.foreach { i => + queue.synchronized { + queue += ssc.sparkContext.makeRDD(Seq(i)) + } + } clock.advance(batchDuration.milliseconds) Thread.sleep(1000) } From 34d0b70b309f16af263eb4e6d7c36e2ea170bc67 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Tue, 9 Feb 2016 11:01:47 -0800 Subject: [PATCH 1839/2089] [SPARK-12807][YARN] Spark External Shuffle not working in Hadoop clusters with Jackson 2.2.3 Patch to 1. Shade jackson 2.x in spark-yarn-shuffle JAR: core, databind, annotation 2. Use maven antrun to verify the JAR has the renamed classes Being Maven-based, I don't know if the verification phase kicks in on an SBT/jenkins build. It will on a `mvn install` Author: Steve Loughran Closes #10780 from steveloughran/stevel/patches/SPARK-12807-master-shuffle. --- network/yarn/pom.xml | 49 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a28785b16e1e6..3cb44324f25f2 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -35,6 +35,8 @@ network-yarn provided + ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar + org/spark-project/ @@ -70,7 +72,7 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar + ${shuffle.jar} *:* @@ -86,6 +88,15 @@ + + + com.fasterxml.jackson + org.spark-project.com.fasterxml.jackson + + com.fasterxml.jackson.** + + + @@ -96,6 +107,42 @@ + + + + org.apache.maven.plugins + maven-antrun-plugin + + + verify + + run + + + + + + + + + + + + + + + + + + Verifying dependency shading + + + + + + + + From 2dbb9164405d6f595905c7d4b32e20177f0f669f Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Tue, 9 Feb 2016 11:56:25 -0800 Subject: [PATCH 1840/2089] [SPARK-13189] Cleanup build references to Scala 2.10 Author: Luciano Resende Closes #11092 from lresende/SPARK-13189. --- LICENSE | 18 +++++++++--------- dev/audit-release/README.md | 9 +++++---- dev/audit-release/audit_release.py | 4 ++-- docker/spark-test/base/Dockerfile | 2 +- docs/_config.yml | 4 ++-- pom.xml | 2 +- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/LICENSE b/LICENSE index 9fc29db8d3f22..9b78f3bbf8878 100644 --- a/LICENSE +++ b/LICENSE @@ -249,14 +249,14 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.10.5 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.10:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.10:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.10:0.7.1 - http://spire-math.org) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/) + (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) + (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) + (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) (New BSD License) Kryo (com.esotericsoftware.kryo:kryo:2.21 - http://code.google.com/p/kryo/) (New BSD License) MinLog (com.esotericsoftware.minlog:minlog:1.2 - http://code.google.com/p/minlog/) (New BSD License) ReflectASM (com.esotericsoftware.reflectasm:reflectasm:1.07 - http://code.google.com/p/reflectasm/) @@ -283,7 +283,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) + (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index f72f8c653a265..37b2a0afb7aee 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -1,10 +1,11 @@ -# Test Application Builds -This directory includes test applications which are built when auditing releases. You can -run them locally by setting appropriate environment variables. +Test Application Builds +======================= + +This directory includes test applications which are built when auditing releases. You can run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.5 \ +$ SCALA_VERSION=2.11.7 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 972be30da1eb6..4dabb51254af7 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,8 +35,8 @@ RELEASE_KEY = "XXXXXXXX" # Your 8-digit hex RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1033" RELEASE_VERSION = "1.1.1" -SCALA_VERSION = "2.10.5" -SCALA_BINARY_VERSION = "2.10" +SCALA_VERSION = "2.11.7" +SCALA_BINARY_VERSION = "2.11" # Do not set these LOG_FILE_NAME = "spark_audit_%s" % time.strftime("%h_%m_%Y_%I_%M_%S") diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 7ba0de603dc7d..76f550f886ce4 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ENV SCALA_VERSION 2.10.5 +ENV SCALA_VERSION 2.11.7 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/_config.yml b/docs/_config.yml index dc25ff2c16c5e..8bdc68aeeac7f 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -16,8 +16,8 @@ include: # of Spark, Scala, and Mesos. SPARK_VERSION: 2.0.0-SNAPSHOT SPARK_VERSION_SHORT: 2.0.0 -SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.5" +SCALA_BINARY_VERSION: "2.11" +SCALA_VERSION: "2.11.7" MESOS_VERSION: 0.21.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/pom.xml b/pom.xml index d0387aca66d0d..4f7a0574c52dc 100644 --- a/pom.xml +++ b/pom.xml @@ -164,7 +164,7 @@ 3.4.1 3.2.2 - 2.10.5 + 2.11.7 2.11 ${scala.version} org.scala-lang From 7fe4fe630a3fc9755ebd0325bb595d76381633e8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 9 Feb 2016 13:06:36 -0800 Subject: [PATCH 1841/2089] [SPARK-12888] [SQL] [FOLLOW-UP] benchmark the new hash expression Adds the benchmark results as comments. The codegen version is slower than the interpreted version for `simple` case becasue of 3 reasons: 1. codegen version use a more complex hash algorithm than interpreted version, i.e. `Murmur3_x86_32.hashInt` vs [simple multiplication and addition](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala#L153). 2. codegen version will write the hash value to a row first and then read it out. I tried to create a `GenerateHasher` that can generate code to return hash value directly and got about 60% speed up for the `simple` case, does it worth? 3. the row in `simple` case only has one int field, so the runtime reflection may be removed because of branch prediction, which makes the interpreted version faster. The `array` case is also slow for similar reasons, e.g. array elements are of same type, so interpreted version can probably get rid of runtime reflection by branch prediction. Author: Wenchen Fan Closes #10917 from cloud-fan/hash-benchmark. --- .../org/apache/spark/util/Benchmark.scala | 4 +- .../org/apache/spark/sql/HashBenchmark.scala | 40 +++++++++++++++---- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 1bf6f821e9b31..39d1829310762 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -35,7 +35,8 @@ import org.apache.commons.lang3.SystemUtils * If outputPerIteration is true, the timing for each run will be printed to stdout. */ private[spark] class Benchmark( - name: String, valuesPerIteration: Long, + name: String, + valuesPerIteration: Long, iters: Int = 5, outputPerIteration: Boolean = false) { val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case] @@ -61,7 +62,6 @@ private[spark] class Benchmark( println val firstBest = results.head.bestMs - val firstAvg = results.head.avgMs // The results are going to be processor specific so it is useful to include that. println(Benchmark.getProcessorName()) printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 184f845b4dce2..5a929f211aaa4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -29,9 +29,7 @@ import org.apache.spark.util.Benchmark */ object HashBenchmark { - def test(name: String, schema: StructType, iters: Int): Unit = { - val numRows = 1024 * 8 - + def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = { val generator = RandomDataGenerator.forType(schema, nullable = false).get val encoder = RowEncoder(schema) val attrs = schema.toAttributes @@ -70,7 +68,14 @@ object HashBenchmark { def main(args: Array[String]): Unit = { val simple = new StructType().add("i", IntegerType) - test("simple", simple, 1024) + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Hash For simple: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 941 / 955 142.6 7.0 1.0X + codegen version 1737 / 1775 77.3 12.9 0.5X + */ + test("simple", simple, 1 << 13, 1 << 14) val normal = new StructType() .add("null", NullType) @@ -87,18 +92,39 @@ object HashBenchmark { .add("binary", BinaryType) .add("date", DateType) .add("timestamp", TimestampType) - test("normal", normal, 128) + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 2209 / 2271 0.9 1053.4 1.0X + codegen version 1887 / 2018 1.1 899.9 1.2X + */ + test("normal", normal, 1 << 10, 1 << 11) val arrayOfInt = ArrayType(IntegerType) val array = new StructType() .add("array", arrayOfInt) .add("arrayOfArray", ArrayType(arrayOfInt)) - test("array", array, 64) + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1481 / 1529 0.1 11301.7 1.0X + codegen version 2591 / 2636 0.1 19771.1 0.6X + */ + test("array", array, 1 << 8, 1 << 9) val mapOfInt = MapType(IntegerType, IntegerType) val map = new StructType() .add("map", mapOfInt) .add("mapOfMap", MapType(IntegerType, mapOfInt)) - test("map", map, 64) + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1820 / 1861 0.0 444347.2 1.0X + codegen version 205 / 223 0.0 49936.5 8.9X + */ + test("map", map, 1 << 6, 1 << 6) } } From fae830d15846f7ffdfe49eeb45e175a3cdd2c670 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 9 Feb 2016 16:31:00 -0800 Subject: [PATCH 1842/2089] [SPARK-13245][CORE] Call shuffleMetrics methods only in one thread for ShuffleBlockFetcherIterator Call shuffleMetrics's incRemoteBytesRead and incRemoteBlocksFetched when polling FetchResult from `results` so as to always use shuffleMetrics in one thread. Also fix a race condition that could cause memory leak. Author: Shixiong Zhu Closes #11138 from zsxwing/SPARK-13245. --- .../storage/ShuffleBlockFetcherIterator.scala | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c6065df64ae03..c368a39e629f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.control.NonFatal @@ -107,7 +108,8 @@ final class ShuffleBlockFetcherIterator( * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. */ - @volatile private[this] var isZombie = false + @GuardedBy("this") + private[this] var isZombie = false initialize() @@ -126,14 +128,22 @@ final class ShuffleBlockFetcherIterator( * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ private[this] def cleanup() { - isZombie = true + synchronized { + isZombie = true + } releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, _, buf) => buf.release() + case SuccessFetchResult(_, address, _, buf) => { + if (address != blockManager.blockManagerId) { + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + } + buf.release() + } case _ => } } @@ -154,13 +164,13 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { // Only add the buffer to results queue if the iterator is not zombie, // i.e. cleanup() has not been called yet. - if (!isZombie) { - // Increment the ref count because we need to pass this to a different thread. - // This needs to be released after use. - buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) - shuffleMetrics.incRemoteBytesRead(buf.size) - shuffleMetrics.incRemoteBlocksFetched(1) + ShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) + } } logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -289,7 +299,13 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, address, size, buf) => { + if (address != blockManager.blockManagerId) { + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + } + bytesInFlight -= size + } case _ => } // Send fetch requests up to maxBytesInFlight From 0e5ebac3c1f1ff58f938be59c7c9e604977d269c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Feb 2016 16:41:21 -0800 Subject: [PATCH 1843/2089] [SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate This PR improve the lookup of BytesToBytesMap by: 1. Generate code for calculate the hash code of grouping keys. 2. Do not use MemoryLocation, fetch the baseObject and offset for key and value directly (remove the indirection). Author: Davies Liu Closes #11010 from davies/gen_map. --- .../spark/unsafe/map/BytesToBytesMap.java | 108 ++++++++++-------- .../map/AbstractBytesToBytesMapSuite.java | 64 ++++++----- project/MimaExcludes.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 1 - .../UnsafeFixedWidthAggregationMap.java | 34 +++--- .../sql/execution/UnsafeKVExternalSorter.java | 4 +- .../sql/execution/WholeStageCodegen.scala | 6 +- .../aggregate/TungstenAggregate.scala | 10 +- .../sql/execution/joins/HashedRelation.scala | 17 ++- .../BenchmarkWholeStageCodegen.scala | 64 ++++++++--- 10 files changed, 182 insertions(+), 127 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 3387f9a4177ce..b55a322a1b413 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -38,7 +38,6 @@ import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; @@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); - private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); - private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; private final TaskMemoryManager taskMemoryManager; @@ -417,7 +414,19 @@ public MapIterator destructiveIterator() { * This function always return the same {@link Location} instance to avoid object allocation. */ public Location lookup(Object keyBase, long keyOffset, int keyLength) { - safeLookup(keyBase, keyOffset, keyLength, loc); + safeLookup(keyBase, keyOffset, keyLength, loc, + Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42)); + return loc; + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) { + safeLookup(keyBase, keyOffset, keyLength, loc, hash); return loc; } @@ -426,14 +435,13 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength) { * * This is a thread-safe version of `lookup`, could be used by multiple threads. */ - public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { + public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) { assert(longArray != null); if (enablePerfMetrics) { numKeyLookups++; } - final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength); - int pos = hashcode & mask; + int pos = hash & mask; int step = 1; while (true) { if (enablePerfMetrics) { @@ -441,22 +449,19 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l } if (longArray.get(pos * 2) == 0) { // This is a new key. - loc.with(pos, hashcode, false); + loc.with(pos, hash, false); return; } else { long stored = longArray.get(pos * 2 + 1); - if ((int) (stored) == hashcode) { + if ((int) (stored) == hash) { // Full hash code matches. Let's compare the keys for equality. - loc.with(pos, hashcode, true); + loc.with(pos, hash, true); if (loc.getKeyLength() == keyLength) { - final MemoryLocation keyAddress = loc.getKeyAddress(); - final Object storedkeyBase = keyAddress.getBaseObject(); - final long storedkeyOffset = keyAddress.getBaseOffset(); final boolean areEqual = ByteArrayMethods.arrayEquals( keyBase, keyOffset, - storedkeyBase, - storedkeyOffset, + loc.getKeyBase(), + loc.getKeyOffset(), keyLength ); if (areEqual) { @@ -484,13 +489,14 @@ public final class Location { private boolean isDefined; /** * The hashcode of the most recent key passed to - * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to - * avoid re-hashing the key when storing a value for that key. + * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us + * to avoid re-hashing the key when storing a value for that key. */ private int keyHashcode; - private final MemoryLocation keyMemoryLocation = new MemoryLocation(); - private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private Object baseObject; // the base object for key and value + private long keyOffset; private int keyLength; + private long valueOffset; private int valueLength; /** @@ -504,18 +510,15 @@ private void updateAddressesAndSizes(long fullKeyAddress) { taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object base, final long offset) { - long position = offset; - final int totalLength = Platform.getInt(base, position); - position += 4; - keyLength = Platform.getInt(base, position); - position += 4; + private void updateAddressesAndSizes(final Object base, long offset) { + baseObject = base; + final int totalLength = Platform.getInt(base, offset); + offset += 4; + keyLength = Platform.getInt(base, offset); + offset += 4; + keyOffset = offset; + valueOffset = offset + keyLength; valueLength = totalLength - keyLength - 4; - - keyMemoryLocation.setObjAndOffset(base, position); - - position += keyLength; - valueMemoryLocation.setObjAndOffset(base, position); } private Location with(int pos, int keyHashcode, boolean isDefined) { @@ -543,10 +546,11 @@ private Location with(MemoryBlock page, long offsetInPage) { private Location with(Object base, long offset, int length) { this.isDefined = true; this.memoryPage = null; + baseObject = base; + keyOffset = offset + 4; keyLength = Platform.getInt(base, offset); + valueOffset = offset + 4 + keyLength; valueLength = length - 4 - keyLength; - keyMemoryLocation.setObjAndOffset(base, offset + 4); - valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength); return this; } @@ -566,34 +570,44 @@ public boolean isDefined() { } /** - * Returns the address of the key defined at this position. - * This points to the first byte of the key data. - * Unspecified behavior if the key is not defined. - * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + * Returns the base object for key. */ - public MemoryLocation getKeyAddress() { + public Object getKeyBase() { assert (isDefined); - return keyMemoryLocation; + return baseObject; } /** - * Returns the length of the key defined at this position. - * Unspecified behavior if the key is not defined. + * Returns the offset for key. */ - public int getKeyLength() { + public long getKeyOffset() { assert (isDefined); - return keyLength; + return keyOffset; + } + + /** + * Returns the base object for value. + */ + public Object getValueBase() { + assert (isDefined); + return baseObject; } /** - * Returns the address of the value defined at this position. - * This points to the first byte of the value data. + * Returns the offset for value. + */ + public long getValueOffset() { + assert (isDefined); + return valueOffset; + } + + /** + * Returns the length of the key defined at this position. * Unspecified behavior if the key is not defined. - * For efficiency reasons, calls to this method always returns the same MemoryLocation object. */ - public MemoryLocation getValueAddress() { + public int getKeyLength() { assert (isDefined); - return valueMemoryLocation; + return keyLength; } /** diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 702ba5469b8b4..d8af2b336dd4d 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -39,14 +39,13 @@ import org.apache.spark.SparkConf; import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThan; @@ -142,10 +141,9 @@ public void tearDown() { protected abstract boolean useOffHeapMemoryAllocator(); - private static byte[] getByteArray(MemoryLocation loc, int size) { + private static byte[] getByteArray(Object base, long offset, int size) { final byte[] arr = new byte[size]; - Platform.copyMemory( - loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); + Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -163,13 +161,14 @@ private byte[] getRandomByteArray(int numWords) { */ private static boolean arrayEquals( byte[] expected, - MemoryLocation actualAddr, + Object base, + long offset, long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, Platform.BYTE_ARRAY_OFFSET, - actualAddr.getBaseObject(), - actualAddr.getBaseOffset(), + base, + offset, expected.length ); } @@ -212,16 +211,20 @@ public void setAndRetrieveAKey() { // reflect the result of this store without us having to call lookup() again on the same key. Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + Assert.assertArrayEquals(keyData, + getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, + getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. Assert.assertTrue( map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + Assert.assertArrayEquals(keyData, + getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, + getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes)); try { Assert.assertTrue(loc.putNewKey( @@ -283,15 +286,12 @@ private void iteratorTestBase(boolean destructive) throws Exception { while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = Platform.getLong( - valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); @@ -365,15 +365,15 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); Platform.copyMemory( - loc.getKeyAddress().getBaseObject(), - loc.getKeyAddress().getBaseOffset(), + loc.getKeyBase(), + loc.getKeyOffset(), key, Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Platform.copyMemory( - loc.getValueAddress().getBaseObject(), - loc.getValueAddress().getBaseOffset(), + loc.getValueBase(), + loc.getValueOffset(), value, Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH @@ -425,8 +425,9 @@ public void randomizedStressTest() { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(key.length, loc.getKeyLength()); Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length)); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length)); } } @@ -436,8 +437,10 @@ public void randomizedStressTest() { final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + Assert.assertTrue( + arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); } } finally { map.free(); @@ -476,8 +479,9 @@ public void randomizedTestWithRecordsLargerThanPageSize() { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(key.length, loc.getKeyLength()); Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length)); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length)); } } for (Map.Entry entry : expected.entrySet()) { @@ -486,8 +490,10 @@ public void randomizedTestWithRecordsLargerThanPageSize() { final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + Assert.assertTrue( + arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); } } finally { map.free(); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9209094385395..133894704b6cc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -40,6 +40,7 @@ object MimaExcludes { excludePackage("org.apache.spark.rpc"), excludePackage("org.spark-project.jetty"), excludePackage("org.apache.spark.unused"), + excludePackage("org.apache.spark.unsafe"), excludePackage("org.apache.spark.util.collection.unsafe"), excludePackage("org.apache.spark.sql.catalyst"), excludePackage("org.apache.spark.sql.execution"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index f4ccadd9c563e..28e4f50eee809 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" val childrenHash = children.map { child => diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 6bf9d7bd0367c..2e84178d690e6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -121,19 +121,24 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); } - public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) { + return getAggregationBufferFromUnsafeRow(key, key.hashCode()); + } + + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) { // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( - unsafeGroupingKeyRow.getBaseObject(), - unsafeGroupingKeyRow.getBaseOffset(), - unsafeGroupingKeyRow.getSizeInBytes()); + key.getBaseObject(), + key.getBaseOffset(), + key.getSizeInBytes(), + hash); if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: boolean putSucceeded = loc.putNewKey( - unsafeGroupingKeyRow.getBaseObject(), - unsafeGroupingKeyRow.getBaseOffset(), - unsafeGroupingKeyRow.getSizeInBytes(), + key.getBaseObject(), + key.getBaseOffset(), + key.getSizeInBytes(), emptyAggregationBuffer, Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length @@ -144,10 +149,9 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo } // Reset the pointer to point to the value that we just stored or looked up: - final MemoryLocation address = loc.getValueAddress(); currentAggregationBuffer.pointTo( - address.getBaseObject(), - address.getBaseOffset(), + loc.getValueBase(), + loc.getValueOffset(), loc.getValueLength() ); return currentAggregationBuffer; @@ -172,16 +176,14 @@ public KVIterator iterator() { public boolean next() { if (mapLocationIterator.hasNext()) { final BytesToBytesMap.Location loc = mapLocationIterator.next(); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); key.pointTo( - keyAddress.getBaseObject(), - keyAddress.getBaseOffset(), + loc.getKeyBase(), + loc.getKeyOffset(), loc.getKeyLength() ); value.pointTo( - valueAddress.getBaseObject(), - valueAddress.getBaseOffset(), + loc.getValueBase(), + loc.getValueOffset(), loc.getValueLength() ); return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 0da26bf376a6a..51e10b0e936b9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -97,8 +97,8 @@ public UnsafeKVExternalSorter( UnsafeRow row = new UnsafeRow(numKeyFields); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); - final Object baseObject = loc.getKeyAddress().getBaseObject(); - final long baseOffset = loc.getKeyAddress().getBaseOffset(); + final Object baseObject = loc.getKeyBase(); + final long baseOffset = loc.getKeyOffset(); // Get encoded memory address // baseObject + baseOffset point to the beginning of the key data in the map, but that diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 4ca2d85406bb7..b200239c94206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru def apply(plan: SparkPlan): SparkPlan = { if (sqlContext.conf.wholeStageEnabled) { plan.transform { - case plan: CodegenSupport if supportCodegen(plan) && - // Whole stage codegen is only useful when there are at least two levels of operators that - // support it (save at least one projection/iterator). - (Utils.isTesting || plan.children.exists(supportCodegen)) => - + case plan: CodegenSupport if supportCodegen(plan) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { // The build side can't be compiled together diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 9d9f14f2dd014..340b8f78e5c9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -501,6 +501,11 @@ case class TungstenAggregate( } } + // generate hash code for key + val hashExpr = Murmur3Hash(groupingExpressions, 42) + ctx.currentVars = input + val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx) + val inputAttr = bufferAttributes ++ child.output ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input ctx.INPUT_ROW = buffer @@ -526,10 +531,11 @@ case class TungstenAggregate( s""" // generate grouping key ${keyCode.code.trim} + ${hashEval.code.trim} UnsafeRow $buffer = null; if ($checkFallback) { // try to get the buffer from hash map - $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); } if ($buffer == null) { if ($sorterTerm == null) { @@ -540,7 +546,7 @@ case class TungstenAggregate( $resetCoulter // the hash map had be spilled, it should have enough memory now, // try to allocate buffer again. - $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); if ($buffer == null) { // failed to allocate the first page throw new OutOfMemoryError("No enough memory for aggregation"); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index c94d6c195b1d8..eb6930a14f9c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation( val map = binaryMap // avoid the compiler error val loc = new map.Location // this could be allocated in stack binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes, loc) + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) if (loc.isDefined) { val buffer = CompactBuffer[UnsafeRow]() - val base = loc.getValueAddress.getBaseObject - var offset = loc.getValueAddress.getBaseOffset - val last = loc.getValueAddress.getBaseOffset + loc.getValueLength + val base = loc.getValueBase + var offset = loc.getValueOffset + val last = offset + loc.getValueLength while (offset < last) { val numFields = Platform.getInt(base, offset) val sizeInBytes = Platform.getInt(base, offset + 4) @@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation( out.writeInt(binaryMap.numElements()) var buffer = new Array[Byte](64) - def write(addr: MemoryLocation, length: Int): Unit = { + def write(base: Object, offset: Long, length: Int): Unit = { if (buffer.length < length) { buffer = new Array[Byte](length) } - Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, - buffer, Platform.BYTE_ARRAY_OFFSET, length) + Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length) out.write(buffer, 0, length) } @@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation( // [key size] [values size] [key bytes] [values bytes] out.writeInt(loc.getKeyLength) out.writeInt(loc.getValueLength) - write(loc.getKeyAddress, loc.getKeyLength) - write(loc.getValueAddress, loc.getValueLength) + write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength) + write(loc.getValueBase, loc.getValueOffset, loc.getValueLength) } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index f015d297048a3..dc6c647a4a95f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X - Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Aggregate w keys codegen=false 2429 / 2644 8.6 115.8 1.0X + Aggregate w keys codegen=true 1535 / 1571 13.7 73.2 1.6X */ } @@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.addCase("hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(2) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var s = 0 while (i < N) { key.setInt(0, i % 1000) val h = Murmur3_x86_32.hashUnsafeWords( - key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42) + s += h + i += 1 + } + } + + benchmark.addCase("fast hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashLong(i % 1000, 42) s += h i += 1 } } + benchmark.addCase("arrayEqual") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + if (key.equals(value)) { + s += 1 + } + i += 1 + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val valueBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(2) + val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 while (i < N) { key.setInt(0, i % 65536) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 65536, 42)) if (loc.isDefined) { - value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, - loc.getValueLength) + value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) value.setInt(0, value.getInt(0) + 1) i += 1 } else { @@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - hash 628 / 661 83.0 12.0 1.0X - BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X - BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X + hash 651 / 678 80.0 12.5 1.0X + fast hash 336 / 343 155.9 6.4 1.9X + arrayEqual 417 / 428 125.0 8.0 1.6X + BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X + BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X */ benchmark.run() } From 9267bc68fab65c6a798e065a1dbe0f5171df3077 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 9 Feb 2016 17:10:55 -0800 Subject: [PATCH 1844/2089] [SPARK-10524][ML] Use the soft prediction to order categories' bins JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Author: Joseph K. Bradley Closes #8734 from viirya/dt-soft-centroids. --- .../spark/ml/tree/impl/RandomForest.scala | 42 ++-- .../spark/mllib/tree/DecisionTree.scala | 219 +++++++++--------- .../DecisionTreeClassifierSuite.scala | 36 ++- .../spark/mllib/tree/DecisionTreeSuite.scala | 30 +++ 4 files changed, 194 insertions(+), 133 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index d3376a7dff938..ea733d577a5fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging { * @param binAggregates Bin statistics. * @return tuple for best split: (Split, information gain, prediction at node) */ - private def binsToBestSplit( + private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], @@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging { * * centroidForCategories is a list: (category, centroid) */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { + val centroidForCategories = Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) } else { - Double.MaxValue - } - (featureValue, centroid) - } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the centroid of their corresponding labels. - Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { + // regression + // For categorical variables in regression and binary classification, + // the bins are ordered by the prediction. categoryStats.predict - } else { - Double.MaxValue } - (featureValue, centroid) + } else { + Double.MaxValue } + (featureValue, centroid) } logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 07ba0d8ccb2a8..51235a23711a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging { * @param binAggregates Bin statistics. * @return tuple for best split: (Split, information gain, prediction at node) */ - private def binsToBestSplit( + private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], @@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging { // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIdx, gainStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numBins = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - categoryStats.calculate() - } else { - Double.MaxValue - } - (featureValue, centroid) + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the centroid of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIdx, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = + binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numBins = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = Range(0, numBins).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) val centroid = if (categoryStats.count != 0) { - categoryStats.predict + if (binAggregates.metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) + } else { + // For categorical variables in regression, + // the bins are ordered by the prediction. + categoryStats.predict + } } else { Double.MaxValue } (featureValue, centroid) } - } - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) - (bestFeatureSplit, bestFeatureGainStats) - } }.maxBy(_._2.gain) (bestSplit, bestSplitStats, predictWithImpurity.get._1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index fda2711fed0fd..baf6b9083900f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val model = dt.fit(df) } + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val data = sc.parallelize(arr) + val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(1) + .setMaxBins(3) + val model = dt.fit(df) + model.rootNode match { + case n: InternalNode => + n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a9c935bd42445..dca8ea815aa6a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils @@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.rightNode.get.impurity === 0.0) } + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val input = sc.parallelize(arr) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) + + val model = new DecisionTree(strategy).run(input) + model.topNode.split.get match { + case Split(_, _, _, categories: List[Double]) => + assert(categories === List(1.0)) + } + } + test("Second level node building with vs. without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) From 6f710f9fd4f85370557b7705020ff16f2385e645 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 10 Feb 2016 09:45:13 +0800 Subject: [PATCH 1845/2089] [SPARK-12476][SQL] Implement JdbcRelation#unhandledFilters for removing unnecessary Spark Filter Input: SELECT * FROM jdbcTable WHERE col0 = 'xxx' Current plan: ``` == Optimized Logical Plan == Project [col0#0,col1#1] +- Filter (col0#0 = xxx) +- Relation[col0#0,col1#1] JDBCRelation(jdbc:postgresql:postgres,testRel,[Lorg.apache.spark.Partition;2ac7c683,{user=maropu, password=, driver=org.postgresql.Driver}) == Physical Plan == +- Filter (col0#0 = xxx) +- Scan JDBCRelation(jdbc:postgresql:postgres,testRel,[Lorg.apache.spark.Partition;2ac7c683,{user=maropu, password=, driver=org.postgresql.Driver})[col0#0,col1#1] PushedFilters: [EqualTo(col0,xxx)] ``` This patch enables a plan below; ``` == Optimized Logical Plan == Project [col0#0,col1#1] +- Filter (col0#0 = xxx) +- Relation[col0#0,col1#1] JDBCRelation(jdbc:postgresql:postgres,testRel,[Lorg.apache.spark.Partition;2ac7c683,{user=maropu, password=, driver=org.postgresql.Driver}) == Physical Plan == Scan JDBCRelation(jdbc:postgresql:postgres,testRel,[Lorg.apache.spark.Partition;2ac7c683,{user=maropu, password=, driver=org.postgresql.Driver})[col0#0,col1#1] PushedFilters: [EqualTo(col0,xxx)] ``` Author: Takeshi YAMAMURO Closes #10427 from maropu/RemoveFilterInJdbcScan. --- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../datasources/jdbc/JDBCRelation.scala | 5 ++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 70 +++++++++++++------ 3 files changed, 56 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index d867e144e517f..befba867bc460 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -189,7 +189,7 @@ private[sql] object JDBCRDD extends Logging { * Turns a single Filter into a String representing a SQL expression. * Returns None for an unhandled filter. */ - private def compileFilter(f: Filter): Option[String] = { + private[jdbc] def compileFilter(f: Filter): Option[String] = { Option(f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" case EqualNullSafe(attr, value) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 572be823ca87c..ee6373d03e1fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -90,6 +90,11 @@ private[sql] case class JDBCRelation( override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) + // Check if JDBCRDD.compileFilter can accept input filters + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + filters.filter(JDBCRDD.compileFilter(_).isEmpty) + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 518607543b482..7a0f7abaa1baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -22,12 +22,12 @@ import java.sql.{Date, DriverManager, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException -import org.scalatest.BeforeAndAfter -import org.scalatest.PrivateMethodTester +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.Row import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.sources._ @@ -183,26 +183,34 @@ class JDBCSuite extends SparkFunSuite } test("SELECT * WHERE (simple predicates)") { - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) + def checkPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is removed in a physical plan and + // the plan only has PhysicalRDD to scan JDBCRelation. + assert(parentPlan.isInstanceOf[PhysicalRDD]) + assert(parentPlan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) + df + } + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) .collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) .collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) .collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " + "AND THEID = 2")).collect().size == 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) - assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) // This is a test to reflect discussion in SPARK-12218. // The older versions of spark have this kind of bugs in parquet data source. @@ -210,6 +218,28 @@ class JDBCSuite extends SparkFunSuite val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") assert(df1.collect.toSet === Set(Row("mary", 2))) assert(df2.collect.toSet === Set(Row("mary", 2))) + + def checkNotPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD + // cannot compile given predicates. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] + assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter]) + df + } + assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) + assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) + } + + test("SELECT COUNT(1) WHERE (predicates)") { + // Check if an answer is correct when Filter is removed from operations such as count() which + // does not require any columns. In some data sources, e.g., Parquet, `requiredColumns` in + // org.apache.spark.sql.sources.interfaces is not given in logical plans, but some filters + // are applied for columns with Filter producing wrong results. On the other hand, JDBCRDD + // correctly handles this case by assigning `requiredColumns` properly. See PR 10427 for more + // discussions. + assert(sql("SELECT COUNT(1) FROM foobar WHERE NAME = 'mary'").collect.toSet === Set(Row(1))) } test("SELECT * WHERE (quoted strings)") { From b385ce38825de4b1420c5a0e8191e91fc8afecf5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 9 Feb 2016 18:50:06 -0800 Subject: [PATCH 1846/2089] [SPARK-13149][SQL] Add FileStreamSource `FileStreamSource` is an implementation of `org.apache.spark.sql.execution.streaming.Source`. It takes advantage of the existing `HadoopFsRelationProvider` to support various file formats. It remembers files in each batch and stores it into the metadata files so as to recover them when restarting. The metadata files are stored in the file system. There will be a further PR to clean up the metadata files periodically. This is based on the initial work from marmbrus. Author: Shixiong Zhu Closes #11034 from zsxwing/stream-df-file-source. --- .../datasources/ResolvedDataSource.scala | 2 +- .../streaming/FileStreamSource.scala | 240 ++++++++++ .../apache/spark/sql/sources/interfaces.scala | 33 +- .../org/apache/spark/sql/StreamTest.scala | 2 + .../DataFrameReaderWriterSuite.scala | 5 +- .../sql/streaming/FileStreamSourceSuite.scala | 435 ++++++++++++++++++ 6 files changed, 710 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 7702f535ad2f4..cefa8be0c6007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -104,7 +104,7 @@ object ResolvedDataSource extends Logging { s"Data source $providerName does not support streamed reading") } - provider.createSource(sqlContext, options, userSpecifiedSchema) + provider.createSource(sqlContext, userSpecifiedSchema, providerName, options) } def createSink( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala new file mode 100644 index 0000000000000..14ba9f69bb1d7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Codec + +import com.google.common.base.Charsets.UTF_8 +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} + +import org.apache.spark.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashSet + +/** + * A very simple source that reads text files from the given directory as they appear. + * + * TODO Clean up the metadata files periodically + */ +class FileStreamSource( + sqlContext: SQLContext, + metadataPath: String, + path: String, + dataSchema: Option[StructType], + providerName: String, + dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging { + + private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + private var maxBatchId = -1 + private val seenFiles = new OpenHashSet[String] + + /** Map of batch id to files. This map is also stored in `metadataPath`. */ + private val batchToMetadata = new HashMap[Long, Seq[String]] + + { + // Restore file paths from the metadata files + val existingBatchFiles = fetchAllBatchFiles() + if (existingBatchFiles.nonEmpty) { + val existingBatchIds = existingBatchFiles.map(_.getPath.getName.toInt) + maxBatchId = existingBatchIds.max + // Recover "batchToMetadata" and "seenFiles" from existing metadata files. + existingBatchIds.sorted.foreach { batchId => + val files = readBatch(batchId) + if (files.isEmpty) { + // Assert that the corrupted file must be the latest metadata file. + if (batchId != maxBatchId) { + throw new IllegalStateException("Invalid metadata files") + } + maxBatchId = maxBatchId - 1 + } else { + batchToMetadata(batchId) = files + files.foreach(seenFiles.add) + } + } + } + } + + /** Returns the schema of the data from this source */ + override lazy val schema: StructType = { + dataSchema.getOrElse { + val filesPresent = fetchAllFiles() + if (filesPresent.isEmpty) { + if (providerName == "text") { + // Add a default schema for "text" + new StructType().add("value", StringType) + } else { + throw new IllegalArgumentException("No schema specified") + } + } else { + // There are some existing files. Use them to infer the schema. + dataFrameBuilder(filesPresent.toArray).schema + } + } + } + + /** + * Returns the maximum offset that can be retrieved from the source. + * + * `synchronized` on this method is for solving race conditions in tests. In the normal usage, + * there is no race here, so the cost of `synchronized` should be rare. + */ + private def fetchMaxOffset(): LongOffset = synchronized { + val filesPresent = fetchAllFiles() + val newFiles = new ArrayBuffer[String]() + filesPresent.foreach { file => + if (!seenFiles.contains(file)) { + logDebug(s"new file: $file") + newFiles.append(file) + seenFiles.add(file) + } else { + logDebug(s"old file: $file") + } + } + + if (newFiles.nonEmpty) { + maxBatchId += 1 + writeBatch(maxBatchId, newFiles) + } + + new LongOffset(maxBatchId) + } + + /** + * For test only. Run `func` with the internal lock to make sure when `func` is running, + * the current offset won't be changed and no new batch will be emitted. + */ + def withBatchingLocked[T](func: => T): T = synchronized { + func + } + + /** Return the latest offset in the source */ + def currentOffset: LongOffset = synchronized { + new LongOffset(maxBatchId) + } + + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + override def getNextBatch(start: Option[Offset]): Option[Batch] = { + val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + val end = fetchMaxOffset() + val endId = end.offset + + if (startId + 1 <= endId) { + val files = (startId + 1 to endId).filter(_ >= 0).flatMap { batchId => + batchToMetadata.getOrElse(batchId, Nil) + }.toArray + logDebug(s"Return files from batches ${startId + 1}:$endId") + logDebug(s"Streaming ${files.mkString(", ")}") + Some(new Batch(end, dataFrameBuilder(files))) + } + else { + None + } + } + + private def fetchAllBatchFiles(): Seq[FileStatus] = { + try fs.listStatus(new Path(metadataPath)) catch { + case _: java.io.FileNotFoundException => + fs.mkdirs(new Path(metadataPath)) + Seq.empty + } + } + + private def fetchAllFiles(): Seq[String] = { + fs.listStatus(new Path(path)) + .filterNot(_.getPath.getName.startsWith("_")) + .map(_.getPath.toUri.toString) + } + + /** + * Write the metadata of a batch to disk. The file format is as follows: + * + * {{{ + * + * START + * -/a/b/c + * -/d/e/f + * ... + * END + * }}} + * + * Note: means the value of `FileStreamSource.VERSION`. Every file + * path starts with "-" so that we can know if a line is a file path easily. + */ + private def writeBatch(id: Int, files: Seq[String]): Unit = { + assert(files.nonEmpty, "create a new batch without any file") + val output = fs.create(new Path(metadataPath + "/" + id), true) + val writer = new PrintWriter(new OutputStreamWriter(output, UTF_8)) + try { + // scalastyle:off println + writer.println(FileStreamSource.VERSION) + writer.println(FileStreamSource.START_TAG) + files.foreach(file => writer.println(FileStreamSource.PATH_PREFIX + file)) + writer.println(FileStreamSource.END_TAG) + // scalastyle:on println + } finally { + writer.close() + } + batchToMetadata(id) = files + } + + /** Read the file names of the specified batch id from the metadata file */ + private def readBatch(id: Int): Seq[String] = { + val input = fs.open(new Path(metadataPath + "/" + id)) + try { + FileStreamSource.readBatch(input) + } finally { + input.close() + } + } +} + +object FileStreamSource { + + private val START_TAG = "START" + private val END_TAG = "END" + private val PATH_PREFIX = "-" + val VERSION = "FILESTREAM_V1" + + /** + * Parse a metadata file and return the content. If the metadata file is corrupted, it will return + * an empty `Seq`. + */ + def readBatch(input: InputStream): Seq[String] = { + val lines = scala.io.Source.fromInputStream(input)(Codec.UTF8).getLines().toArray + if (lines.length < 4) { + // version + start tag + end tag + at least one file path + return Nil + } + if (lines.head != VERSION) { + return Nil + } + if (lines(1) != START_TAG) { + return Nil + } + if (lines.last != END_TAG) { + return Nil + } + lines.slice(2, lines.length - 1).map(_.stripPrefix(PATH_PREFIX)) // Drop character "-" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 737be7dfd12f6..428a313ca9dc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.collection.BitSet @@ -131,8 +131,9 @@ trait SchemaRelationProvider { trait StreamSourceProvider { def createSource( sqlContext: SQLContext, - parameters: Map[String, String], - schema: Option[StructType]): Source + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source } /** @@ -169,7 +170,7 @@ trait StreamSinkProvider { * @since 1.4.0 */ @Experimental -trait HadoopFsRelationProvider { +trait HadoopFsRelationProvider extends StreamSourceProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity @@ -196,6 +197,30 @@ trait HadoopFsRelationProvider { } createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters) } + + override def createSource( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val path = parameters.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + val metadataPath = parameters.getOrElse("metadataPath", s"$path/_metadata") + + def dataFrameBuilder(files: Array[String]): DataFrame = { + val relation = createRelation( + sqlContext, + files, + schema, + partitionColumns = None, + bucketSpec = None, + parameters) + DataFrame(sqlContext, LogicalRelation(relation)) + } + + new FileStreamSource(sqlContext, metadataPath, path, schema, providerName, dataFrameBuilder) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index f45abbf2496a2..7e388ea602343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -59,6 +59,8 @@ trait StreamTest extends QueryTest with Timeouts { implicit class RichSource(s: Source) { def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s)) + + def toDS[A: Encoder](): Dataset[A] = new Dataset(sqlContext, StreamingRelation(s)) } /** How long to wait for an active stream to catch up when checking a result. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 36212e4395985..b762f9b90ed86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -33,8 +33,9 @@ object LastOptions { class DefaultSource extends StreamSourceProvider with StreamSinkProvider { override def createSource( sqlContext: SQLContext, - parameters: Map[String, String], - schema: Option[StructType]): Source = { + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { LastOptions.parameters = parameters LastOptions.schema = schema new Source { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala new file mode 100644 index 0000000000000..7a4ee0ef264d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.{ByteArrayInputStream, File, FileNotFoundException, InputStream} + +import com.google.common.base.Charsets.UTF_8 + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.FileStreamSource._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.Utils + +class FileStreamSourceTest extends StreamTest with SharedSQLContext { + + import testImplicits._ + + case class AddTextFileData(source: FileStreamSource, content: String, src: File, tmp: File) + extends AddData { + + override def addData(): Offset = { + source.withBatchingLocked { + val file = Utils.tempFileWith(new File(tmp, "text")) + stringToFile(file, content).renameTo(new File(src, file.getName)) + source.currentOffset + } + 1 + } + } + + case class AddParquetFileData( + source: FileStreamSource, + content: Seq[String], + src: File, + tmp: File) extends AddData { + + override def addData(): Offset = { + source.withBatchingLocked { + val file = Utils.tempFileWith(new File(tmp, "parquet")) + content.toDS().toDF().write.parquet(file.getCanonicalPath) + file.renameTo(new File(src, file.getName)) + source.currentOffset + } + 1 + } + } + + /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ + def createFileStreamSource( + format: String, + path: String, + schema: Option[StructType] = None): FileStreamSource = { + val reader = + if (schema.isDefined) { + sqlContext.read.format(format).schema(schema.get) + } else { + sqlContext.read.format(format) + } + reader.stream(path) + .queryExecution.analyzed + .collect { case StreamingRelation(s: FileStreamSource, _) => s } + .head + } + + val valueSchema = new StructType().add("value", StringType) +} + +class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { + + import testImplicits._ + + private def createFileStreamSourceAndGetSchema( + format: Option[String], + path: Option[String], + schema: Option[StructType] = None): StructType = { + val reader = sqlContext.read + format.foreach(reader.format) + schema.foreach(reader.schema) + val df = + if (path.isDefined) { + reader.stream(path.get) + } else { + reader.stream() + } + df.queryExecution.analyzed + .collect { case StreamingRelation(s: FileStreamSource, _) => s } + .head + .schema + } + + test("FileStreamSource schema: no path") { + val e = intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema(format = None, path = None, schema = None) + } + assert("'path' is not specified" === e.getMessage) + } + + test("FileStreamSource schema: path doesn't exist") { + intercept[FileNotFoundException] { + createFileStreamSourceAndGetSchema(format = None, path = Some("/a/b/c"), schema = None) + } + } + + test("FileStreamSource schema: text, no existing files, no schema") { + withTempDir { src => + val schema = createFileStreamSourceAndGetSchema( + format = Some("text"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + + test("FileStreamSource schema: text, existing files, no schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "a\nb\nc") + val schema = createFileStreamSourceAndGetSchema( + format = Some("text"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + + test("FileStreamSource schema: text, existing files, schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "a\nb\nc") + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("text"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + + test("FileStreamSource schema: parquet, no existing files, no schema") { + withTempDir { src => + val e = intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(new File(src, "1").getCanonicalPath), schema = None) + } + assert("No schema specified" === e.getMessage) + } + } + + test("FileStreamSource schema: parquet, existing files, no schema") { + withTempDir { src => + Seq("a", "b", "c").toDS().as("userColumn").toDF() + .write.parquet(new File(src, "1").getCanonicalPath) + val schema = createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + + test("FileStreamSource schema: parquet, existing files, schema") { + withTempPath { src => + Seq("a", "b", "c").toDS().as("oldUserColumn").toDF() + .write.parquet(new File(src, "1").getCanonicalPath) + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + + test("FileStreamSource schema: json, no existing files, no schema") { + withTempDir { src => + val e = intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + } + assert("No schema specified" === e.getMessage) + } + } + + test("FileStreamSource schema: json, existing files, no schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}") + val schema = createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("c", StringType)) + } + } + + test("FileStreamSource schema: json, existing files, schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c', '3'}") + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + + test("read from text files") { + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("read from json files") { + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + val textSource = createFileStreamSource("json", src.getCanonicalPath, Some(valueSchema)) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData( + textSource, + "{'value': 'drop1'}\n{'value': 'keep2'}\n{'value': 'keep3'}", + src, + tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData( + textSource, + "{'value': 'drop4'}\n{'value': 'keep5'}\n{'value': 'keep6'}", + src, + tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData( + textSource, + "{'value': 'drop7'}\n{'value': 'keep8'}\n{'value': 'keep9'}", + src, + tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("read from json files with inferring schema") { + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") + + val textSource = createFileStreamSource("json", src.getCanonicalPath) + + // FileStreamSource should infer the column "c" + val filtered = textSource.toDF().filter($"c" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("read from parquet files") { + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + val fileSource = createFileStreamSource("parquet", src.getCanonicalPath, Some(valueSchema)) + val filtered = fileSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddParquetFileData(fileSource, Seq("drop1", "keep2", "keep3"), src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddParquetFileData(fileSource, Seq("drop4", "keep5", "keep6"), src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddParquetFileData(fileSource, Seq("drop7", "keep8", "keep9"), src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("file stream source without schema") { + val src = Utils.createTempDir("streaming.src") + + // Only "text" doesn't need a schema + createFileStreamSource("text", src.getCanonicalPath) + + // Both "json" and "parquet" require a schema if no existing file to infer + intercept[IllegalArgumentException] { + createFileStreamSource("json", src.getCanonicalPath) + } + intercept[IllegalArgumentException] { + createFileStreamSource("parquet", src.getCanonicalPath) + } + + Utils.deleteRecursively(src) + } + + test("fault tolerance") { + def assertBatch(batch1: Option[Batch], batch2: Option[Batch]): Unit = { + (batch1, batch2) match { + case (Some(b1), Some(b2)) => + assert(b1.end === b2.end) + assert(b1.data.as[String].collect() === b2.data.as[String].collect()) + case (None, None) => + case _ => fail(s"batch ($batch1) is not equal to batch ($batch2)") + } + } + + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + val textSource2 = createFileStreamSource("text", src.getCanonicalPath) + assert(textSource2.currentOffset === textSource.currentOffset) + assertBatch(textSource2.getNextBatch(None), textSource.getNextBatch(None)) + for (f <- 0L to textSource.currentOffset.offset) { + val offset = LongOffset(f) + assertBatch(textSource2.getNextBatch(Some(offset)), textSource.getNextBatch(Some(offset))) + } + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("fault tolerance with corrupted metadata file") { + val src = Utils.createTempDir("streaming.src") + assert(new File(src, "_metadata").mkdirs()) + stringToFile( + new File(src, "_metadata/0"), + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n") + stringToFile(new File(src, "_metadata/1"), s"${FileStreamSource.VERSION}\nSTART\n-") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + // the metadata file of batch is corrupted, so currentOffset should be 0 + assert(textSource.currentOffset === LongOffset(0)) + + Utils.deleteRecursively(src) + } + + test("fault tolerance with normal metadata file") { + val src = Utils.createTempDir("streaming.src") + assert(new File(src, "_metadata").mkdirs()) + stringToFile( + new File(src, "_metadata/0"), + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n") + stringToFile( + new File(src, "_metadata/1"), + s"${FileStreamSource.VERSION}\nSTART\n-/x/y/z\nEND\n") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + assert(textSource.currentOffset === LongOffset(1)) + + Utils.deleteRecursively(src) + } + + test("readBatch") { + def stringToStream(str: String): InputStream = new ByteArrayInputStream(str.getBytes(UTF_8)) + + // Invalid metadata + assert(readBatch(stringToStream("")) === Nil) + assert(readBatch(stringToStream(FileStreamSource.VERSION)) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\n")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEN")) === Nil) + + // Valid metadata + assert(readBatch(stringToStream( + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND")) === Seq("/a/b/c")) + assert(readBatch(stringToStream( + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND\n")) === Seq("/a/b/c")) + assert(readBatch(stringToStream( + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n")) + === Seq("/a/b/c", "/e/f/g")) + } +} + +class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { + + import testImplicits._ + + test("file source stress test") { + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val ds = textSource.toDS[String]().map(_.toInt + 1) + runStressTest(ds, data => { + AddTextFileData(textSource, data.mkString("\n"), src, tmp) + }) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } +} From 9269036d8c8bb60097fd9aacfb7a89d8e873d978 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Lipt=C3=A1k?= Date: Wed, 10 Feb 2016 09:52:35 +0000 Subject: [PATCH 1847/2089] [SPARK-11565] Replace deprecated DigestUtils.shaHex call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Gábor Lipták Closes #9532 from gliptak/SPARK-11565. --- sql/catalyst/pom.xml | 4 ++++ .../org/apache/spark/sql/catalyst/expressions/misc.scala | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index c2ad9b99f3ac9..5d1d9edd251c6 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -75,6 +75,10 @@ org.antlr antlr-runtime + + commons-codec + commons-codec + target/scala-${scala.binary.version}/classes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 28e4f50eee809..dcbb594afd86e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -143,11 +143,11 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu override def inputTypes: Seq[DataType] = Seq(BinaryType) protected override def nullSafeEval(input: Any): Any = - UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]])) + UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]])) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => - s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" ) } } From 2ba9b6a2dfff8eb06b6f93024f5140e784b8be49 Mon Sep 17 00:00:00 2001 From: Jon Maurer Date: Wed, 10 Feb 2016 09:54:22 +0000 Subject: [PATCH 1848/2089] [SPARK-11518][DEPLOY, WINDOWS] Handle spaces in Windows command scripts Author: Jon Maurer Author: Jonathan Maurer Closes #10789 from tritab/cmd_updates. --- bin/beeline.cmd | 2 +- bin/load-spark-env.cmd | 6 +++--- bin/pyspark.cmd | 2 +- bin/pyspark2.cmd | 4 ++-- bin/run-example.cmd | 2 +- bin/run-example2.cmd | 15 ++++++--------- bin/spark-class.cmd | 2 +- bin/spark-class2.cmd | 10 +++++----- bin/spark-shell.cmd | 2 +- bin/spark-shell2.cmd | 2 +- bin/spark-submit.cmd | 2 +- bin/spark-submit2.cmd | 2 +- bin/sparkR.cmd | 2 +- bin/sparkR2.cmd | 4 ++-- 14 files changed, 27 insertions(+), 30 deletions(-) diff --git a/bin/beeline.cmd b/bin/beeline.cmd index 8293f311029dd..8ddaa419967a5 100644 --- a/bin/beeline.cmd +++ b/bin/beeline.cmd @@ -18,4 +18,4 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.hive.beeline.BeeLine %* +cmd /V /E /C "%SPARK_HOME%\bin\spark-class.cmd" org.apache.hive.beeline.BeeLine %* diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 59080edd294f2..0977025c2036e 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -27,7 +27,7 @@ if [%SPARK_ENV_LOADED%] == [] ( if not [%SPARK_CONF_DIR%] == [] ( set user_conf_dir=%SPARK_CONF_DIR% ) else ( - set user_conf_dir=%~dp0..\conf + set user_conf_dir=..\conf ) call :LoadSparkEnv @@ -35,8 +35,8 @@ if [%SPARK_ENV_LOADED%] == [] ( rem Setting SPARK_SCALA_VERSION if not already set. -set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 -set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 +set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11" +set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.10" if [%SPARK_SCALA_VERSION%] == [] ( diff --git a/bin/pyspark.cmd b/bin/pyspark.cmd index 7c26fbbac28b8..72d046a4ba2cf 100644 --- a/bin/pyspark.cmd +++ b/bin/pyspark.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running PySpark. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0pyspark2.cmd %* +cmd /V /E /C "%~dp0pyspark2.cmd" %* diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 51d6d15f66c69..21fe28155a596 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. @@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* +call "%SPARK_HOME%\bin\spark-submit2.cmd" pyspark-shell-main --name "PySparkShell" %* diff --git a/bin/run-example.cmd b/bin/run-example.cmd index 5b2d048d6ed50..64f6bc3728d07 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running a Spark example. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0run-example2.cmd %* +cmd /V /E /C "%~dp0run-example2.cmd" %* diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index c3e0221fb62e3..fada43581d184 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -20,12 +20,9 @@ rem set SCALA_VERSION=2.10 rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ +set SPARK_HOME=%~dp0.. -rem Export this as SPARK_HOME -set SPARK_HOME=%FWDIR% - -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" rem Test that an argument was given if not "x%1"=="x" goto arg_given @@ -36,12 +33,12 @@ if not "x%1"=="x" goto arg_given goto exit :arg_given -set EXAMPLES_DIR=%FWDIR%examples +set EXAMPLES_DIR=%SPARK_HOME%\examples rem Figure out the JAR file that our examples were packaged into. set SPARK_EXAMPLES_JAR= -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-examples*.jar") do ( +if exist "%SPARK_HOME%\RELEASE" ( + for %%d in ("%SPARK_HOME%\lib\spark-examples*.jar") do ( set SPARK_EXAMPLES_JAR=%%d ) ) else ( @@ -80,7 +77,7 @@ if "%~1" neq "" ( ) if defined ARGS set ARGS=%ARGS:~1% -call "%FWDIR%bin\spark-submit.cmd" ^ +call "%SPARK_HOME%\bin\spark-submit.cmd" ^ --master %EXAMPLE_MASTER% ^ --class %EXAMPLE_CLASS% ^ "%SPARK_EXAMPLES_JAR%" %ARGS% diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd index 19850db9e1e5d..3bf3d20cb57b5 100644 --- a/bin/spark-class.cmd +++ b/bin/spark-class.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running a Spark class. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0spark-class2.cmd %* +cmd /V /E /C "%~dp0spark-class2.cmd" %* diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index db09fa27e51a6..c4fadb822323d 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" rem Test that an argument was given if "x%1"=="x" ( @@ -32,9 +32,9 @@ rem Find assembly jar set SPARK_ASSEMBLY_JAR=0 if exist "%SPARK_HOME%\RELEASE" ( - set ASSEMBLY_DIR=%SPARK_HOME%\lib + set ASSEMBLY_DIR="%SPARK_HOME%\lib" ) else ( - set ASSEMBLY_DIR=%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION% + set ASSEMBLY_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%" ) for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( @@ -50,7 +50,7 @@ set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% rem Add the launcher build dir to the classpath if requested. if not "x%SPARK_PREPEND_CLASSES%"=="x" ( - set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% + set LAUNCH_CLASSPATH="%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH%" ) set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% @@ -62,7 +62,7 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt -"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +"%RUNNER%" -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 8f90ba5a0b3b8..991423da6ab99 100644 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark shell. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0spark-shell2.cmd %* +cmd /V /E /C "%~dp0spark-shell2.cmd" %* diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index b9b0f510d7f5d..7b5d396be888c 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* +"%SPARK_HOME%\bin\spark-submit2.cmd" --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index 8f3b84c7b971d..f121b62a53d24 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0spark-submit2.cmd %* +cmd /V /E /C spark-submit2.cmd %* diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 651376e526928..49e350fa5c416 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,4 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -%~dp0spark-class2.cmd %CLASS% %* +"%~dp0spark-class2.cmd" %CLASS% %* diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd index d7b60183ca8e0..1e5ea6a623219 100644 --- a/bin/sparkR.cmd +++ b/bin/sparkR.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running SparkR. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0sparkR2.cmd %* +cmd /V /E /C "%~dp0sparkR2.cmd" %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd index e47f22c7300bb..459b780e2ae33 100644 --- a/bin/sparkR2.cmd +++ b/bin/sparkR2.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" -call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* +call "%SPARK_HOME%\bin\spark-submit2.cmd" sparkr-shell-main %* From e834e421dec30be8dade21287165d5eb95411c73 Mon Sep 17 00:00:00 2001 From: tedyu Date: Wed, 10 Feb 2016 10:58:41 +0000 Subject: [PATCH 1849/2089] [SPARK-13203] Add scalastyle rule banning use of mutable.SynchronizedBuffer andrewor14 Please take a look Author: tedyu Closes #11134 from tedyu/master. --- scalastyle-config.xml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 967a482ba4f9b..64619d2108999 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -169,6 +169,18 @@ This file is divided into 3 sections: ]]> + + mutable\.SynchronizedBuffer + + + Class\.forName Date: Wed, 10 Feb 2016 11:02:00 +0000 Subject: [PATCH 1850/2089] [SPARK-9307][CORE][SPARK] Logging: Make it either stable or private Make Logging private[spark]. Pretty much all there is to it. Author: Sean Owen Closes #11103 from srowen/SPARK-9307. --- core/src/main/scala/org/apache/spark/Logging.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index e35e158c7e8a6..9e0a840b72e27 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -21,19 +21,15 @@ import org.apache.log4j.{Level, LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder -import org.apache.spark.annotation.Private import org.apache.spark.util.Utils /** * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. - * - * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility. - * This will likely be changed or removed in future releases. */ -@Private -trait Logging { +private[spark] trait Logging { + // Make the log field transient so that objects with Logging can // be serialized and used on another machine @transient private var log_ : Logger = null From 80cb963ad963e26c3a7f8388bdd4ffd5e99aad1a Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Wed, 10 Feb 2016 10:53:33 -0800 Subject: [PATCH 1851/2089] [SPARK-5095][MESOS] Support launching multiple mesos executors in coarse grained mesos mode. This is the next iteration of tnachen's previous PR: https://github.com/apache/spark/pull/4027 In that PR, we resolved with andrewor14 and pwendell to implement the Mesos scheduler's support of `spark.executor.cores` to be consistent with YARN and Standalone. This PR implements that resolution. This PR implements two high-level features. These two features are co-dependent, so they're implemented both here: - Mesos support for spark.executor.cores - Multiple executors per slave We at Mesosphere have been working with Typesafe on a Spark/Mesos integration test suite: https://github.com/typesafehub/mesos-spark-integration-tests, which passes for this PR. The contribution is my original work and I license the work to the project under the project's open source license. Author: Michael Gummelt Closes #10993 from mgummelt/executor_sizing. --- .../CoarseGrainedSchedulerBackend.scala | 11 +- .../mesos/CoarseMesosSchedulerBackend.scala | 375 +++++++++++------- .../cluster/mesos/MesosSchedulerBackend.scala | 4 +- .../cluster/mesos/MesosSchedulerUtils.scala | 10 +- .../CoarseMesosSchedulerBackendSuite.scala | 365 ++++++++++++----- .../mesos/MesosSchedulerBackendSuite.scala | 2 +- .../mesos/MesosSchedulerUtilsSuite.scala | 6 +- docs/configuration.md | 15 +- docs/running-on-mesos.md | 8 +- 9 files changed, 521 insertions(+), 275 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f69a3d371e5dd..0a5b09dc0d1fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -240,6 +240,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK + + logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + s"${executorData.executorHost}.") + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } @@ -309,7 +313,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + driverEndpoint = createDriverEndpointRef(properties) + } + + protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): RpcEndpointRef = { + rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) } protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 0a2d72f4dcb4b..98699e0b294ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -23,17 +23,17 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashMap, HashSet} import com.google.common.base.Stopwatch -import com.google.common.collect.HashBiMap import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} +import org.apache.spark.rpc.{RpcEndpointAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -74,17 +74,13 @@ private[spark] class CoarseMesosSchedulerBackend( private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Cores we have acquired with each Mesos task ID - val coresByTaskId = new HashMap[Int, Int] + val coresByTaskId = new HashMap[String, Int] var totalCoresAcquired = 0 - val slaveIdsWithExecutors = new HashSet[String] - - // Maping from slave Id to hostname - private val slaveIdToHost = new HashMap[String, String] - - val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] - // How many times tasks on each slave failed - val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + // SlaveID -> Slave + // This map accumulates entries for the duration of the job. Slaves are never deleted, because + // we need to maintain e.g. failure state and connection state. + private val slaves = new HashMap[String, Slave] /** * The total number of executors we aim to have. Undefined when not using dynamic allocation. @@ -105,13 +101,11 @@ private[spark] class CoarseMesosSchedulerBackend( */ private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) - private val pendingRemovedSlaveIds = new HashSet[String] - // private lock object protecting mutable state above. Using the intrinsic lock // may lead to deadlocks since the superclass might also try to lock private val stateLock = new ReentrantLock - val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) // Offer constraints private val slaveOfferConstraints = @@ -121,27 +115,31 @@ private[spark] class CoarseMesosSchedulerBackend( private val rejectOfferDurationForUnmetConstraints = getRejectOfferDurationForUnmetConstraints(sc) - // A client for talking to the external shuffle service, if it is a + // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { - Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf, "shuffle"), - securityManager, - securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled())) + Some(getShuffleClient()) } else { None } } + protected def getShuffleClient(): MesosExternalShuffleClient = { + new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf, "shuffle"), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) + } + var nextMesosTaskId = 0 @volatile var appId: String = _ - def newMesosTaskId(): Int = { + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 - id + id.toString } override def start() { @@ -156,7 +154,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -200,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + s" --driver-url $driverURL" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -208,12 +206,11 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head - val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + s" --driver-url $driverURL" + - s" --executor-id $executorId" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -268,113 +265,209 @@ private[spark] class CoarseMesosSchedulerBackend( offers.asScala.map(_.getId).foreach(d.declineOffer) return } - val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers.asScala) { + + logDebug(s"Received ${offers.size} resource offers.") + + val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => val offerAttributes = toAttributeMap(offer.getAttributesList) - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + } + + declineUnmatchedOffers(d, unmatchedOffers) + handleMatchedOffers(d, matchedOffers) + } + } + + private def declineUnmatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + for (offer <- offers) { + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val filters = Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build() + + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + + d.declineOffer(offer.getId, filters) + } + } + + /** + * Launches executors on accepted offers, and declines unused offers. Executors are launched + * round-robin on offers. + * + * @param d SchedulerDriver + * @param offers Mesos offers that match attribute constraints + */ + private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + val tasks = buildMesosTasks(offers) + for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val offerMem = getResource(offer.getResourcesList, "mem") + val offerCpus = getResource(offer.getResourcesList, "cpus") + val id = offer.getId.getValue + + if (tasks.contains(offer.getId)) { // accept + val offerTasks = tasks(offer.getId) + + logDebug(s"Accepting offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus. Launching ${offerTasks.size} Mesos tasks.") + + for (task <- offerTasks) { + val taskId = task.getTaskId + val mem = getResource(task.getResourcesList, "mem") + val cpus = getResource(task.getResourcesList, "cpus") + + logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus.") + } + + d.launchTasks( + Collections.singleton(offer.getId), + offerTasks.asJava) + } else { // decline + logDebug(s"Declining offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus") + + d.declineOffer(offer.getId) + } + } + } + + /** + * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize + * per-task memory and IO, tasks are round-robin assigned to offers. + * + * @param offers Mesos offers that match attribute constraints + * @return A map from OfferID to a list of Mesos tasks to launch on that offer + */ + private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = { + // offerID -> tasks + val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil) + + // offerID -> resources + val remainingResources = mutable.Map(offers.map(offer => + (offer.getId.getValue, offer.getResourcesList)): _*) + + var launchTasks = true + + // TODO(mgummelt): combine offers for a single slave + // + // round-robin create executors on the available offers + while (launchTasks) { + launchTasks = false + + for (offer <- offers) { val slaveId = offer.getSlaveId.getValue - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus").toInt - val id = offer.getId.getValue - if (meetsConstraints) { - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. - val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse.asJava) - .addAllResources(memResourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) - } - - // Accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) - } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + val offerId = offer.getId.getValue + val resources = remainingResources(offerId) + + if (canLaunchTask(slaveId, resources)) { + // Create a task + launchTasks = true + val taskId = newMesosTaskId() + val offerCPUs = getResource(resources, "cpus").toInt + + val taskCPUs = executorCores(offerCPUs) + val taskMemory = executorMemory(sc) + + slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) + + val (afterCPUResources, cpuResourcesToUse) = + partitionResources(resources, "cpus", taskCPUs) + val (resourcesLeft, memResourcesToUse) = + partitionResources(afterCPUResources.asJava, "mem", taskMemory) + + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder) } - } else { - // This offer does not meet constraints. We don't need to see it again. - // Decline the offer for a long period of time. - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" - + s" for $rejectOfferDurationForUnmetConstraints seconds") - d.declineOffer(offer.getId, Filters.newBuilder() - .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + + tasks(offer.getId) ::= taskBuilder.build() + remainingResources(offerId) = resourcesLeft.asJava + totalCoresAcquired += taskCPUs + coresByTaskId(taskId) = taskCPUs } } } + tasks.toMap + } + + private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + val offerMem = getResource(resources, "mem") + val offerCPUs = getResource(resources, "cpus").toInt + val cpus = executorCores(offerCPUs) + val mem = executorMemory(sc) + + cpus > 0 && + cpus <= offerCPUs && + cpus + totalCoresAcquired <= maxCores && + mem <= offerMem && + numExecutors() < executorLimit && + slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES } + private def executorCores(offerCPUs: Int): Int = { + sc.conf.getInt("spark.executor.cores", + math.min(offerCPUs, maxCores - totalCoresAcquired)) + } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val taskId = status.getTaskId.getValue.toInt - val state = status.getState - logInfo(s"Mesos task $taskId is now $state") - val slaveId: String = status.getSlaveId.getValue + val taskId = status.getTaskId.getValue + val slaveId = status.getSlaveId.getValue + val state = TaskState.fromMesos(status.getState) + + logInfo(s"Mesos task $taskId is now ${status.getState}") + stateLock.synchronized { + val slave = slaves(slaveId) + // If the shuffle service is enabled, have the driver register with each one of the // shuffle services. This allows the shuffle services to clean up state associated with // this application when the driver exits. There is currently not a great way to detect // this through Mesos, since the shuffle services are set up independently. - if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && - slaveIdToHost.contains(slaveId) && - shuffleServiceEnabled) { + if (state.equals(TaskState.RUNNING) && + shuffleServiceEnabled && + !slave.shuffleRegistered) { assume(mesosExternalShuffleClient.isDefined, "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) - val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + - s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get - .registerDriverWithShuffleService(hostname, externalShufflePort) + .registerDriverWithShuffleService(slave.hostname, externalShufflePort) + slave.shuffleRegistered = true } - if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId.get(taskId) - slaveIdsWithExecutors -= slaveId - taskIdToSlaveId.remove(taskId) + if (TaskState.isFinished(state)) { // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (TaskState.isFailed(TaskState.fromMesos(state))) { - failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 - if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { + if (TaskState.isFailed(state)) { + slave.taskFailures += 1 + + if (slave.taskFailures >= MAX_SLAVE_FAILURES) { logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } - executorTerminated(d, slaveId, s"Executor finished with state $state") + executorTerminated(d, slaveId, taskId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node d.reviveOffers() } @@ -396,20 +489,24 @@ private[spark] class CoarseMesosSchedulerBackend( stopCalled = true super.stop() } + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. // See SPARK-12330 val stopwatch = new Stopwatch() stopwatch.start() + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent - while (slaveIdsWithExecutors.nonEmpty && + while (numExecutors() > 0 && stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { Thread.sleep(100) } - if (slaveIdsWithExecutors.nonEmpty) { - logWarning(s"Timed out waiting for ${slaveIdsWithExecutors.size} remaining executors " + + if (numExecutors() > 0) { + logWarning(s"Timed out waiting for ${numExecutors()} remaining executors " + s"to terminate within $shutdownTimeoutMS ms. This may leave temporary files " + "on the mesos nodes.") } + if (mesosDriver != null) { mesosDriver.stop() } @@ -418,40 +515,25 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} /** - * Called when a slave is lost or a Mesos task finished. Update local view on - * what tasks are running and remove the terminated slave from the list of pending - * slave IDs that we might have asked to be killed. It also notifies the driver - * that an executor was removed. + * Called when a slave is lost or a Mesos task finished. Updates local view on + * what tasks are running. It also notifies the driver that an executor was removed. */ - private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + private def executorTerminated(d: SchedulerDriver, + slaveId: String, + taskId: String, + reason: String): Unit = { stateLock.synchronized { - if (slaveIdsWithExecutors.contains(slaveId)) { - val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.containsKey(slaveId)) { - val taskId: Int = slaveIdToTaskId.get(slaveId) - taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) - } - // TODO: This assumes one Spark executor per Mesos slave, - // which may no longer be true after SPARK-5095 - pendingRemovedSlaveIds -= slaveId - slaveIdsWithExecutors -= slaveId - } + removeExecutor(taskId, SlaveLost(reason)) + slaves(slaveId).taskIDs.remove(taskId) } } - private def sparkExecutorId(slaveId: String, taskId: String): String = { - s"$slaveId/$taskId" - } - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { logInfo(s"Mesos slave lost: ${slaveId.getValue}") - executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) } override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) + logInfo("Mesos executor lost: %s".format(e.getValue)) } override def applicationId(): String = @@ -471,23 +553,26 @@ private[spark] class CoarseMesosSchedulerBackend( override def doKillExecutors(executorIds: Seq[String]): Boolean = { if (mesosDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") - return false - } - - val slaveIdToTaskId = taskIdToSlaveId.inverse() - for (executorId <- executorIds) { - val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.containsKey(slaveId)) { - mesosDriver.killTask( - TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) - pendingRemovedSlaveIds += slaveId - } else { - logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + false + } else { + for (executorId <- executorIds) { + val taskId = TaskID.newBuilder().setValue(executorId).build() + mesosDriver.killTask(taskId) } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true } - // no need to adjust `executorLimitOption` since the AllocationManager already communicated - // the desired limit through a call to `doRequestTotalExecutors`. - // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] - true } + + private def numExecutors(): Int = { + slaves.values.map(_.taskIDs.size).sum + } +} + +private class Slave(val hostname: String) { + val taskIDs = new HashSet[String]() + var taskFailures = 0 + var shuffleRegistered = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 340f29bac9218..8929d8a427789 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -138,7 +138,7 @@ private[spark] class MesosSchedulerBackend( val (resourcesAfterCpu, usedCpuResources) = partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = - partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) + partitionResources(resourcesAfterCpu.asJava, "mem", executorMemory(sc)) builder.addAllResources(usedCpuResources.asJava) builder.addAllResources(usedMemResources.asJava) @@ -250,7 +250,7 @@ private[spark] class MesosSchedulerBackend( // check offers for // 1. Memory requirements // 2. CPU requirements - need at least 1 for executor, 1 for task - val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsMemoryRequirements = mem >= executorMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) val meetsRequirements = (meetsMemoryRequirements && meetsCPURequirements) || diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index f9f5da9bc8df6..a98f2f1fe5da8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -140,15 +140,15 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } } - /** - * Signal that the scheduler has registered with Mesos. - */ - protected def getResource(res: JList[Resource], name: String): Double = { + def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } + /** + * Signal that the scheduler has registered with Mesos. + */ protected def markRegistered(): Unit = { registerLatch.countDown() } @@ -337,7 +337,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ - def calculateTotalMemory(sc: SparkContext): Int = { + def executorMemory(sc: SparkContext): Int = { sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + sc.executorMemory diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index a4110d2d462de..e542aa0cfc4dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -17,19 +17,23 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util import java.util.Collections +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.Scalar -import org.mockito.Matchers +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient +import org.apache.spark.rpc.{RpcEndpointRef} import org.apache.spark.scheduler.TaskSchedulerImpl class CoarseMesosSchedulerBackendSuite extends SparkFunSuite @@ -37,6 +41,223 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with BeforeAndAfter { + var sparkConf: SparkConf = _ + var driver: SchedulerDriver = _ + var taskScheduler: TaskSchedulerImpl = _ + var backend: CoarseMesosSchedulerBackend = _ + var externalShuffleClient: MesosExternalShuffleClient = _ + var driverEndpoint: RpcEndpointRef = _ + + test("mesos supports killing and limiting executors") { + setBackend() + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val minMem = backend.executorMemory(sc) + val minCpu = 4 + val offers = List((minMem, minCpu)) + + // launches a task on a valid offer + offerResources(offers) + verifyTaskLaunched("o1") + + // kills executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("0"))) + val taskID0 = createTaskId("0") + verify(driver, times(1)).killTask(taskID0) + + // doesn't launch a new task when requested executors == 0 + offerResources(offers, 2) + verifyDeclinedOffer(driver, createOfferId("o2")) + + // Launches a new task when requested executors is positive + backend.doRequestTotalExecutors(2) + offerResources(offers, 2) + verifyTaskLaunched("o2") + } + + test("mesos supports killing and relaunching tasks with executors") { + setBackend() + + // launches a task on a valid offer + val minMem = backend.executorMemory(sc) + 1024 + val minCpu = 4 + val offer1 = (minMem, minCpu) + val offer2 = (minMem, 1) + offerResources(List(offer1, offer2)) + verifyTaskLaunched("o1") + + // accounts for a killed task + val status = createTaskStatus("0", "s1", TaskState.TASK_KILLED) + backend.statusUpdate(driver, status) + verify(driver, times(1)).reviveOffers() + + // Launches a new task on a valid offer from the same slave + offerResources(List(offer2)) + verifyTaskLaunched("o2") + } + + test("mesos supports spark.executor.cores") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + val executorMemory = backend.executorMemory(sc) + val offers = List((executorMemory * 2, executorCores + 1)) + offerResources(offers) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == executorCores) + } + + test("mesos supports unset spark.executor.cores") { + setBackend() + + val executorMemory = backend.executorMemory(sc) + val offerCores = 10 + offerResources(List((executorMemory * 2, offerCores))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == offerCores) + } + + test("mesos does not acquire more than spark.cores.max") { + val maxCores = 10 + setBackend(Map("spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory, maxCores + 1))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == maxCores) + } + + test("mesos declines offers that violate attribute constraints") { + setBackend(Map("spark.mesos.constraints" -> "x:true")) + offerResources(List((backend.executorMemory(sc), 4))) + verifyDeclinedOffer(driver, createOfferId("o1"), true) + } + + test("mesos assigns tasks round-robin on offers") { + val executorCores = 4 + val maxCores = executorCores * 2 + setBackend(Map("spark.executor.cores" -> executorCores.toString, + "spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List( + (executorMemory * 2, executorCores * 2), + (executorMemory * 2, executorCores * 2))) + + verifyTaskLaunched("o1") + verifyTaskLaunched("o2") + } + + test("mesos creates multiple executors on a single slave") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + // offer with room for two executors + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory * 2, executorCores * 2))) + + // verify two executors were started on a single offer + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 2) + } + + test("mesos doesn't register twice with the same shuffle service") { + setBackend(Map("spark.shuffle.service.enabled" -> "true")) + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + val offer2 = createOffer("o2", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer2).asJava) + verifyTaskLaunched("o2") + + val status1 = createTaskStatus("0", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status1) + + val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status2) + verify(externalShuffleClient, times(1)).registerDriverWithShuffleService(anyString, anyInt) + } + + test("mesos kills an executor when told") { + setBackend() + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + backend.doKillExecutors(List("0")) + verify(driver, times(1)).killTask(createTaskId("0")) + } + + private def verifyDeclinedOffer(driver: SchedulerDriver, + offerId: OfferID, + filter: Boolean = false): Unit = { + if (filter) { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId), anyObject[Filters]) + } else { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId)) + } + } + + private def offerResources(offers: List[(Int, Int)], startId: Int = 1): Unit = { + val mesosOffers = offers.zipWithIndex.map {case (offer, i) => + createOffer(s"o${i + startId}", s"s${i + startId}", offer._1, offer._2)} + + backend.resourceOffers(driver, mesosOffers.asJava) + } + + private def verifyTaskLaunched(offerId: String): java.util.Collection[TaskInfo] = { + val captor = ArgumentCaptor.forClass(classOf[java.util.Collection[TaskInfo]]) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + captor.capture()) + captor.getValue + } + + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { + TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setState(state) + .build + } + + + private def createOfferId(offerId: String): OfferID = { + OfferID.newBuilder().setValue(offerId).build() + } + + private def createSlaveId(slaveId: String): SlaveID = { + SlaveID.newBuilder().setValue(slaveId).build() + } + + private def createExecutorId(executorId: String): ExecutorID = { + ExecutorID.newBuilder().setValue(executorId).build() + } + + private def createTaskId(taskId: String): TaskID = { + TaskID.newBuilder().setValue(taskId).build() + } + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() @@ -47,8 +268,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder() - .setValue(offerId).build()) + builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() .setValue("f1")) .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) @@ -58,130 +278,55 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, - driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + driver: SchedulerDriver, + shuffleClient: MesosExternalShuffleClient, + endpoint: RpcEndpointRef): CoarseMesosSchedulerBackend = { val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( - masterUrl: String, - scheduler: Scheduler, - sparkUser: String, - appName: String, - conf: SparkConf, - webuiUrl: Option[String] = None, - checkpoint: Option[Boolean] = None, - failoverTimeout: Option[Double] = None, - frameworkId: Option[String] = None): SchedulerDriver = driver + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver + + override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient + + override protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): RpcEndpointRef = endpoint + markRegistered() } backend.start() backend } - var sparkConf: SparkConf = _ - - before { + private def setBackend(sparkConfVars: Map[String, String] = null) { sparkConf = (new SparkConf) .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") - sc = new SparkContext(sparkConf) - } - - test("mesos supports killing and limiting executors") { - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) - - sparkConf.set("spark.driver.host", "driverHost") - sparkConf.set("spark.driver.port", "1234") - - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) - - val taskID0 = TaskID.newBuilder().setValue("0").build() - - backend.resourceOffers(driver, mesosOffers) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) - - // simulate the allocation manager down-scaling executors - backend.doRequestTotalExecutors(0) - assert(backend.doKillExecutors(Seq("s1/0"))) - verify(driver, times(1)).killTask(taskID0) - - val mesosOffers2 = new java.util.ArrayList[Offer] - mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) - backend.resourceOffers(driver, mesosOffers2) - - verify(driver, times(1)) - .declineOffer(OfferID.newBuilder().setValue("o2").build()) - - // Verify we didn't launch any new executor - assert(backend.slaveIdsWithExecutors.size === 1) - - backend.doRequestTotalExecutors(2) - backend.resourceOffers(driver, mesosOffers2) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) + if (sparkConfVars != null) { + for (attr <- sparkConfVars) { + sparkConf.set(attr._1, attr._2) + } + } - assert(backend.slaveIdsWithExecutors.size === 2) - backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) - assert(backend.slaveIdsWithExecutors.size === 1) - } + sc = new SparkContext(sparkConf) - test("mesos supports killing and relaunching tasks with executors") { - val driver = mock[SchedulerDriver] + driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] + taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + externalShuffleClient = mock[MesosExternalShuffleClient] + driverEndpoint = mock[RpcEndpointRef] - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) + 1024 - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - val offer1 = createOffer("o1", "s1", minMem, minCpu) - mesosOffers.add(offer1) - - val offer2 = createOffer("o2", "s1", minMem, 1); - - backend.resourceOffers(driver, mesosOffers) - - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer1.getId)), - anyObject(), - anyObject[Filters]) - - // Simulate task killed, executor no longer running - val status = TaskStatus.newBuilder() - .setTaskId(TaskID.newBuilder().setValue("0").build()) - .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) - .setState(TaskState.TASK_KILLED) - .build - - backend.statusUpdate(driver, status) - assert(!backend.slaveIdsWithExecutors.contains("s1")) - - mesosOffers.clear() - mesosOffers.add(offer2) - backend.resourceOffers(driver, mesosOffers) - assert(backend.slaveIdsWithExecutors.contains("s1")) - - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer2.getId)), - anyObject(), - anyObject[Filters]) - - verify(driver, times(1)).reviveOffers() + backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index e111e2e9f6163..3fb3279073f24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -189,7 +189,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val minMem = backend.calculateTotalMemory(sc) + val minMem = backend.executorMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 2eb43b7313381..85437b2f80817 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -41,20 +41,20 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("use at-least minimum overhead") { val f = fixture when(f.sc.executorMemory).thenReturn(512) - utils.calculateTotalMemory(f.sc) shouldBe 896 + utils.executorMemory(f.sc) shouldBe 896 } test("use overhead if it is greater than minimum value") { val f = fixture when(f.sc.executorMemory).thenReturn(4096) - utils.calculateTotalMemory(f.sc) shouldBe 4505 + utils.executorMemory(f.sc) shouldBe 4505 } test("use spark.mesos.executor.memoryOverhead (if set)") { val f = fixture when(f.sc.executorMemory).thenReturn(1024) f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") - utils.calculateTotalMemory(f.sc) shouldBe 1536 + utils.executorMemory(f.sc) shouldBe 1536 } test("parse a non-empty constraint string correctly") { diff --git a/docs/configuration.md b/docs/configuration.md index cd9dc1bcfc113..b07c69cd4c987 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -825,13 +825,18 @@ Apart from these, the following properties are also available, and may be useful spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - The number of cores to use on each executor. For YARN and standalone mode only. + 1 in YARN mode, all the available cores on the worker in + standalone and Mesos coarse-grained modes. + + + The number of cores to use on each executor. - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one - executor per application will run on each worker. + In standalone and Mesos coarse-grained modes, setting this + parameter allows an application to run multiple executors on the + same worker, provided that there are enough cores on that + worker. Otherwise, only one executor per application will run on + each worker. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index e1c87a8d95a32..0df476d9b4dd7 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -277,9 +277,11 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores 0 - Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. - The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Set the extra number of cores for an executor to advertise. This + does not result in more cores allocated. It instead means that an + executor will "pretend" it has more cores, so that the driver will + send it more tasks. Use this to increase parallelism. This + setting is only used for Mesos coarse-grained mode. From 5cf20598cec4e60b53c0e40dc4243f436396e7fc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Feb 2016 11:00:38 -0800 Subject: [PATCH 1852/2089] [SPARK-13254][SQL] Fix planning of TakeOrderedAndProject operator The patch for SPARK-8964 ("use Exchange to perform shuffle in Limit" / #7334) inadvertently broke the planning of the TakeOrderedAndProject operator: because ReturnAnswer was the new root of the query plan, the TakeOrderedAndProject rule was unable to match before BasicOperators. This patch fixes this by moving the `TakeOrderedAndCollect` and `CollectLimit` rules into the same strategy. In addition, I made changes to the TakeOrderedAndProject operator in order to make its `doExecute()` method lazy and added a new TakeOrderedAndProjectSuite which tests the new code path. /cc davies and marmbrus for review. Author: Josh Rosen Closes #11145 from JoshRosen/take-ordered-and-project-fix. --- .../spark/sql/execution/SparkPlanner.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 40 +++++---- .../apache/spark/sql/execution/limit.scala | 30 +++++-- .../spark/sql/execution/PlannerSuite.scala | 44 ++++++---- .../TakeOrderedAndProjectSuite.scala | 85 +++++++++++++++++++ .../apache/spark/sql/hive/HiveContext.scala | 2 +- 6 files changed, 159 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6e9a4df828246..d1569a4ec2b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -31,7 +31,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { sqlContext.experimental.extraStrategies ++ ( DataSourceStrategy :: DDLStrategy :: - TakeOrderedAndProject :: + SpecialLimits :: Aggregation :: LeftSemiJoin :: EquiJoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ee392e4e8debe..598ddd71613b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -33,6 +33,31 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => + /** + * Plans special cases of limit operators. + */ + object SpecialLimits extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ReturnAnswer(rootPlan) => rootPlan match { + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil + case logical.Limit(IntegerLiteral(limit), child) => + execution.CollectLimit(limit, planLater(child)) :: Nil + case other => planLater(other) :: Nil + } + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil + case _ => Nil + } + } + object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys( @@ -264,18 +289,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) - object TakeOrderedAndProject extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil - case logical.Limit( - IntegerLiteral(limit), - logical.Project(projectList, logical.Sort(order, true, child))) => - execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil - case _ => Nil - } - } - object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => @@ -338,8 +351,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil - case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) => - execution.CollectLimit(limit, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => val perPartitionLimit = execution.LocalLimit(limit, planLater(child)) val globalLimit = execution.GlobalLimit(limit, perPartitionLimit) @@ -362,7 +373,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil - case logical.ReturnAnswer(child) => planLater(child) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 256f4228ae99e..04daf9d0ce2a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -83,8 +83,7 @@ case class TakeOrderedAndProject( child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) + projectList.map(_.map(_.toAttribute)).getOrElse(child.output) } override def outputPartitioning: Partitioning = SinglePartition @@ -93,7 +92,7 @@ case class TakeOrderedAndProject( // and this ordering needs to be created on the driver in order to be passed into Spark core code. private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - private def collectData(): Array[InternalRow] = { + override def executeCollect(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) if (projectList.isDefined) { val proj = UnsafeProjection.create(projectList.get, child.output) @@ -103,13 +102,26 @@ case class TakeOrderedAndProject( } } - override def executeCollect(): Array[InternalRow] = { - collectData() - } + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - // TODO: Terminal split should be implemented differently from non-terminal split. - // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) + protected override def doExecute(): RDD[InternalRow] = { + val localTopK: RDD[InternalRow] = { + child.execute().map(_.copy()).mapPartitions { iter => + org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord) + } + } + val shuffled = new ShuffledRowRDD( + Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer)) + shuffled.mapPartitions { iter => + val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + topK.map(r => proj(r)) + } else { + topK + } + } + } override def outputOrdering: Seq[SortOrder] = sortOrder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a64ad4038c7c3..250ce8f86698f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -161,30 +162,37 @@ class PlannerSuite extends SharedSQLContext { } } - test("efficient limit -> project -> sort") { - { - val query = - testData.select('key, 'value).sort('key).limit(2).logicalPlan - val planned = sqlContext.planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) - assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) - } + test("efficient terminal limit -> sort should use TakeOrderedAndProject") { + val query = testData.select('key, 'value).sort('key).limit(2) + val planned = query.queryExecution.executedPlan + assert(planned.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.output === testData.select('key, 'value).logicalPlan.output) + } - { - // We need to make sure TakeOrderedAndProject's output is correct when we push a project - // into it. - val query = - testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan - val planned = sqlContext.planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) - assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) - } + test("terminal limit -> project -> sort should use TakeOrderedAndProject") { + val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) + val planned = query.queryExecution.executedPlan + assert(planned.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.output === testData.select('value, 'key).logicalPlan.output) } - test("terminal limits use CollectLimit") { + test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { val query = testData.select('value).limit(2) val planned = query.queryExecution.sparkPlan assert(planned.isInstanceOf[CollectLimit]) + assert(planned.output === testData.select('value).logicalPlan.output) + } + + test("TakeOrderedAndProject can appear in the middle of plans") { + val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3) + val planned = query.queryExecution.executedPlan + assert(planned.find(_.isInstanceOf[TakeOrderedAndProject]).isDefined) + } + + test("CollectLimit can appear in the middle of a plan when caching is used") { + val query = testData.select('key, 'value).limit(2).cache() + val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] + assert(planned.child.isInstanceOf[CollectLimit]) } test("PartitioningCollection") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala new file mode 100644 index 0000000000000..03cb04a5f7a03 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + + +class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { + + private var rand: Random = _ + private var seed: Long = 0 + + protected override def beforeAll(): Unit = { + super.beforeAll() + seed = System.currentTimeMillis() + rand = new Random(seed) + } + + private def generateRandomInputData(): DataFrame = { + val schema = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", IntegerType, nullable = false) + val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) + sqlContext.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + } + + /** + * Adds a no-op filter to the child plan in order to prevent executeCollect() from being + * called directly on the child plan. + */ + private def noOpFilter(plan: SparkPlan): SparkPlan = Filter(Literal(true), plan) + + val limit = 250 + val sortOrder = 'a.desc :: 'b.desc :: Nil + + test("TakeOrderedAndProject.doExecute without project") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProject(limit, sortOrder, None, input)), + input => + GlobalLimit(limit, + LocalLimit(limit, + Sort(sortOrder, global = true, input))), + sortAnswers = false) + } + } + + test("TakeOrderedAndProject.doExecute with project") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProject(limit, sortOrder, Some(Seq(input.output.last)), input)), + input => + GlobalLimit(limit, + LocalLimit(limit, + Project(Seq(input.output.last), + Sort(sortOrder, global = true, input)))), + sortAnswers = false) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 05863ae18350d..2433b54ffcb8e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -559,7 +559,7 @@ class HiveContext private[hive]( HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrderedAndProject, + SpecialLimits, InMemoryScans, HiveTableScans, DataSinks, From 39cc620e9c9bf3722992da8d39a928ba638ff21d Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 10 Feb 2016 14:07:50 -0600 Subject: [PATCH 1853/2089] [SPARK-13163][WEB UI] Column width on new History Server DataTables not getting set correctly The column width for the new DataTables now adjusts for the current page rather than being hard-coded for the entire table's data. Author: Alex Bozarth Closes #11057 from ajbozarth/spark13163. --- .../src/main/resources/org/apache/spark/ui/static/historypage.js | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 785abe45bc56e..689c92e86129e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -118,6 +118,7 @@ $(document).ready(function() { {name: 'seventh'}, {name: 'eighth'}, ], + "autoWidth": false }; var rowGroupConf = { From 4b80026f077edb25a2a4563abaf326240e89d4af Mon Sep 17 00:00:00 2001 From: zhuol Date: Wed, 10 Feb 2016 14:23:41 -0600 Subject: [PATCH 1854/2089] [SPARK-13126] fix the right margin of history page. The right margin of the history page is little bit off. A simple fix for that issue. Author: zhuol Closes #11029 from zhuoliu/13126. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 3fee22e39166a..513b81315b3d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -35,7 +35,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val providerConfig = parent.getProviderConfig() val content = -
        +
          {providerConfig.map { case (k, v) =>
        • {k}: {v}
        • }} From ce3bdaeeff718b7b5809eed15f7e4b5188e9fc7c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Feb 2016 12:44:14 -0800 Subject: [PATCH 1855/2089] [HOTFIX] Fix Scala 2.10 build break in TakeOrderedAndProjectSuite. --- .../spark/sql/execution/TakeOrderedAndProjectSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 03cb04a5f7a03..a4c6d072f33a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -63,7 +63,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { input => GlobalLimit(limit, LocalLimit(limit, - Sort(sortOrder, global = true, input))), + Sort(sortOrder, true, input))), sortAnswers = false) } } @@ -78,7 +78,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { GlobalLimit(limit, LocalLimit(limit, Project(Seq(input.output.last), - Sort(sortOrder, global = true, input)))), + Sort(sortOrder, true, input)))), sortAnswers = false) } } From 5947fa8fa1f95d8fc1537c1e37bc16bae8fe7988 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 10 Feb 2016 13:34:02 -0800 Subject: [PATCH 1856/2089] [SPARK-13057][SQL] Add benchmark codes and the performance results for implemented compression schemes for InMemoryRelation This pr adds benchmark codes for in-memory cache compression to make future developments and discussions more smooth. Author: Takeshi YAMAMURO Closes #10965 from maropu/ImproveColumnarCache. --- .../CompressionSchemeBenchmark.scala | 240 ++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala new file mode 100644 index 0000000000000..95eb5cf912e2a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar.compression + +import java.nio.ByteBuffer +import java.nio.ByteOrder + +import org.apache.commons.lang3.RandomStringUtils +import org.apache.commons.math3.distribution.LogNormalDistribution + +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.execution.columnar.BOOLEAN +import org.apache.spark.sql.execution.columnar.INT +import org.apache.spark.sql.execution.columnar.LONG +import org.apache.spark.sql.execution.columnar.NativeColumnType +import org.apache.spark.sql.execution.columnar.SHORT +import org.apache.spark.sql.execution.columnar.STRING +import org.apache.spark.sql.types.AtomicType +import org.apache.spark.util.Benchmark +import org.apache.spark.util.Utils._ + +/** + * Benchmark to decoders using various compression schemes. + */ +object CompressionSchemeBenchmark extends AllCompressionSchemes { + + private[this] def allocateLocal(size: Int): ByteBuffer = { + ByteBuffer.allocate(size).order(ByteOrder.nativeOrder) + } + + private[this] def genLowerSkewData() = { + val rng = new LogNormalDistribution(0.0, 0.01) + () => rng.sample + } + + private[this] def genHigherSkewData() = { + val rng = new LogNormalDistribution(0.0, 1.0) + () => rng.sample + } + + private[this] def runBenchmark[T <: AtomicType]( + name: String, + iters: Int, + count: Int, + tpe: NativeColumnType[T], + input: ByteBuffer): Unit = { + + val benchmark = new Benchmark(name, iters * count) + + schemes.filter(_.supports(tpe)).map { scheme => + def toRow(d: Any) = new GenericInternalRow(Array[Any](d)) + val encoder = scheme.encoder(tpe) + for (i <- 0 until count) { + encoder.gatherCompressibilityStats(toRow(tpe.extract(input)), 0) + } + input.rewind() + + val label = s"${getFormattedClassName(scheme)}(${encoder.compressionRatio.formatted("%.3f")})" + benchmark.addCase(label)({ i: Int => + val compressedSize = if (encoder.compressedSize == 0) { + input.remaining() + } else { + encoder.compressedSize + } + + val buf = allocateLocal(4 + compressedSize) + val rowBuf = new GenericMutableRow(1) + val compressedBuf = encoder.compress(input, buf) + input.rewind() + + for (n <- 0L until iters) { + compressedBuf.rewind.position(4) + val decoder = scheme.decoder(compressedBuf, tpe) + while (decoder.hasNext) { + decoder.next(rowBuf, 0) + } + } + }) + } + + benchmark.run() + } + + def bitDecode(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * BOOLEAN.defaultSize) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Decode: Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 124.98 536.96 1.00 X + // RunLengthEncoding(2.494) 631.37 106.29 0.20 X + // BooleanBitSet(0.125) 1200.36 55.91 0.10 X + val g = { + val rng = genLowerSkewData() + () => (rng().toInt % 2).toByte + } + for (i <- 0 until count) { + testData.put(i * BOOLEAN.defaultSize, g()) + } + runBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) + } + + def shortDecode(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * SHORT.defaultSize) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 376.87 178.07 1.00 X + // RunLengthEncoding(1.498) 831.59 80.70 0.45 X + val g1 = genLowerSkewData() + for (i <- 0 until count) { + testData.putShort(i * SHORT.defaultSize, g1().toShort) + } + runBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 426.83 157.23 1.00 X + // RunLengthEncoding(1.996) 845.56 79.37 0.50 X + val g2 = genHigherSkewData() + for (i <- 0 until count) { + testData.putShort(i * SHORT.defaultSize, g2().toShort) + } + runBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) + } + + def intDecode(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * INT.defaultSize) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode(Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 325.16 206.39 1.00 X + // RunLengthEncoding(0.997) 1219.44 55.03 0.27 X + // DictionaryEncoding(0.500) 955.51 70.23 0.34 X + // IntDelta(0.250) 1146.02 58.56 0.28 X + val g1 = genLowerSkewData() + for (i <- 0 until count) { + testData.putInt(i * INT.defaultSize, g1().toInt) + } + runBenchmark("INT Decode(Lower Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode(Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 1133.45 59.21 1.00 X + // RunLengthEncoding(1.334) 1399.00 47.97 0.81 X + // DictionaryEncoding(0.501) 1032.87 64.97 1.10 X + // IntDelta(0.250) 948.02 70.79 1.20 X + val g2 = genHigherSkewData() + for (i <- 0 until count) { + testData.putInt(i * INT.defaultSize, g2().toInt) + } + runBenchmark("INT Decode(Higher Skew)", iters, count, INT, testData) + } + + def longDecode(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * LONG.defaultSize) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode(Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 1101.07 60.95 1.00 X + // RunLengthEncoding(0.756) 1372.57 48.89 0.80 X + // DictionaryEncoding(0.250) 947.80 70.81 1.16 X + // LongDelta(0.125) 721.51 93.01 1.53 X + val g1 = genLowerSkewData() + for (i <- 0 until count) { + testData.putLong(i * LONG.defaultSize, g1().toLong) + } + runBenchmark("LONG Decode(Lower Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode(Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 986.71 68.01 1.00 X + // RunLengthEncoding(1.013) 1348.69 49.76 0.73 X + // DictionaryEncoding(0.251) 865.48 77.54 1.14 X + // LongDelta(0.125) 816.90 82.15 1.21 X + val g2 = genHigherSkewData() + for (i <- 0 until count) { + testData.putLong(i * LONG.defaultSize, g2().toLong) + } + runBenchmark("LONG Decode(Higher Skew)", iters, count, LONG, testData) + } + + def stringDecode(iters: Int): Unit = { + val count = 65536 + val strLen = 8 + val tableSize = 16 + val testData = allocateLocal(count * (4 + strLen)) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Decode: Avg Time(ms) Avg Rate(M/s) Relative Rate + // ------------------------------------------------------------------------------- + // PassThrough(1.000) 2277.05 29.47 1.00 X + // RunLengthEncoding(0.893) 2624.35 25.57 0.87 X + // DictionaryEncoding(0.167) 2672.28 25.11 0.85 X + val g = { + val dataTable = (0 until tableSize).map(_ => RandomStringUtils.randomAlphabetic(strLen)) + val rng = genHigherSkewData() + () => dataTable(rng().toInt % tableSize) + } + for (i <- 0 until count) { + testData.putInt(strLen) + testData.put(g().getBytes) + } + testData.rewind() + runBenchmark("STRING Decode", iters, count, STRING, testData) + } + + def main(args: Array[String]): Unit = { + bitDecode(1024) + shortDecode(1024) + intDecode(1024) + longDecode(1024) + stringDecode(1024) + } +} From 29c547303f886b96b74b411ac70f0fa81113f086 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 10 Feb 2016 13:34:53 -0800 Subject: [PATCH 1857/2089] [SPARK-12414][CORE] Remove closure serializer Remove spark.closure.serializer option and use JavaSerializer always CC andrewor14 rxin I see there's a discussion in the JIRA but just thought I'd offer this for a look at what the change would be. Author: Sean Owen Closes #11150 from srowen/SPARK-12414. --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 5 ++--- .../org/apache/spark/serializer/KryoSerializerSuite.scala | 3 +-- docs/configuration.md | 7 ------- docs/streaming-programming-guide.md | 2 -- 4 files changed, 3 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 9461afdc54124..204f7356f7ef8 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -35,7 +35,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -277,8 +277,7 @@ object SparkEnv extends Logging { "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val closureSerializer = instantiateClassFromConf[Serializer]( - "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") + val closureSerializer = new JavaSerializer(conf) def registerOrLookupEndpoint( name: String, endpointCreator: => RpcEndpoint): diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index f869bcd708619..27d063630be9d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -282,8 +282,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("kryo with fold") { val control = 1 :: 2 :: Nil // zeroValue must not be a ClassWithoutNoArgConstructor instance because it will be - // serialized by spark.closure.serializer but spark.closure.serializer only supports - // the default Java serializer. + // serialized by the Java serializer. val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) .fold(null)((t1, t2) => { val t1x = if (t1 == null) 0 else t1.x diff --git a/docs/configuration.md b/docs/configuration.md index b07c69cd4c987..dd2cde81941db 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -586,13 +586,6 @@ Apart from these, the following properties are also available, and may be useful Whether to compress broadcast variables before sending them. Generally a good idea. - - spark.closure.serializer - org.apache.spark.serializer.
          JavaSerializer - - Serializer class to use for closures. Currently only the Java serializer is supported. - - spark.io.compression.codec lz4 diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 7e681b67cf0c2..677f5ff7bea8b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2163,8 +2163,6 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. This is controlled by the ```spark.closure.serializer``` property. However, at this time, Kryo serialization cannot be enabled for closure serialization. This may be resolved in a future release. - * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. From 0902e20288366db6270f3a444e66114b1b63a3e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 10 Feb 2016 16:45:06 -0800 Subject: [PATCH 1858/2089] [SPARK-13146][SQL] Management API for continuous queries ### Management API for Continuous Queries **API for getting status of each query** - Whether active or not - Unique name of each query - Status of the sources and sinks - Exceptions **API for managing each query** - Immediately stop an active query - Waiting for a query to be terminated, correctly or with error **API for managing multiple queries** - Listing all active queries - Getting an active query by name - Waiting for any one of the active queries to be terminated **API for listening to query life cycle events** - ContinuousQueryListener API for query start, progress and termination events. Author: Tathagata Das Closes #11030 from tdas/streaming-df-management-api. --- .../apache/spark/sql/ContinuousQuery.scala | 72 ++++- .../spark/sql/ContinuousQueryException.scala | 54 ++++ .../spark/sql/ContinuousQueryManager.scala | 193 +++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 14 +- .../org/apache/spark/sql/SQLContext.scala | 12 + .../org/apache/spark/sql/SinkStatus.scala | 34 ++ .../org/apache/spark/sql/SourceStatus.scala | 34 ++ .../ContinuousQueryListenerBus.scala | 82 +++++ .../execution/streaming/StreamExecution.scala | 215 +++++++++--- .../execution/streaming/StreamProgress.scala | 4 + .../sql/execution/streaming/memory.scala | 20 +- .../sql/util/ContinuousQueryListener.scala | 67 ++++ .../org/apache/spark/sql/StreamTest.scala | 252 ++++++++++++--- .../ContinuousQueryManagerSuite.scala | 306 ++++++++++++++++++ .../sql/streaming/ContinuousQuerySuite.scala | 139 ++++++++ .../DataFrameReaderWriterSuite.scala | 69 +++- .../util/ContinuousQueryListenerSuite.scala | 222 +++++++++++++ 17 files changed, 1680 insertions(+), 109 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index 1c2c0290fc4cd..eb69804c39b5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -17,14 +17,84 @@ package org.apache.spark.sql +import org.apache.spark.annotation.Experimental + /** + * :: Experimental :: * A handle to a query that is executing continuously in the background as new data arrives. + * All these methods are thread-safe. + * @since 2.0.0 */ +@Experimental trait ContinuousQuery { /** - * Stops the execution of this query if it is running. This method blocks until the threads + * Returns the name of the query. + * @since 2.0.0 + */ + def name: String + + /** + * Returns the SQLContext associated with `this` query + * @since 2.0.0 + */ + def sqlContext: SQLContext + + /** + * Whether the query is currently active or not + * @since 2.0.0 + */ + def isActive: Boolean + + /** + * Returns the [[ContinuousQueryException]] if the query was terminated by an exception. + * @since 2.0.0 + */ + def exception: Option[ContinuousQueryException] + + /** + * Returns current status of all the sources. + * @since 2.0.0 + */ + def sourceStatuses: Array[SourceStatus] + + /** Returns current status of the sink. */ + def sinkStatus: SinkStatus + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. + * If the query has terminated with an exception, then the exception will be thrown. + * + * If the query has terminated, then all subsequent calls to this method will either return + * immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws ContinuousQueryException, if `this` query has terminated with an exception. + * + * @since 2.0.0 + */ + def awaitTermination(): Unit + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. + * If the query has terminated with an exception, then the exception will be throw. + * Otherwise, it returns whether the query has terminated or not within the `timeoutMs` + * milliseconds. + * + * If the query has terminated, then all subsequent calls to this method will either return + * `true` immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws ContinuousQueryException, if `this` query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitTermination(timeoutMs: Long): Boolean + + /** + * Stops the execution of this query if it is running. This method blocks until the threads * performing execution has stopped. + * @since 2.0.0 */ def stop(): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala new file mode 100644 index 0000000000000..67dd9dbe23726 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} + +/** + * :: Experimental :: + * Exception that stopped a [[ContinuousQuery]]. + * @param query Query that caused the exception + * @param message Message of this exception + * @param cause Internal cause of this exception + * @param startOffset Starting offset (if known) of the range of data in which exception occurred + * @param endOffset Ending offset (if known) of the range of data in exception occurred + * @since 2.0.0 + */ +@Experimental +class ContinuousQueryException private[sql]( + val query: ContinuousQuery, + val message: String, + val cause: Throwable, + val startOffset: Option[Offset] = None, + val endOffset: Option[Offset] = None + ) extends Exception(message, cause) { + + /** Time when the exception occurred */ + val time: Long = System.currentTimeMillis + + override def toString(): String = { + val causeStr = + s"${cause.getMessage} ${cause.getStackTrace.take(10).mkString("", "\n|\t", "\n")}" + s""" + |$causeStr + | + |${query.asInstanceOf[StreamExecution].toDebugString} + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala new file mode 100644 index 0000000000000..13142d0e61f71 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution} +import org.apache.spark.sql.util.ContinuousQueryListener + +/** + * :: Experimental :: + * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active + * on a [[SQLContext]]. + * + * @since 2.0.0 + */ +@Experimental +class ContinuousQueryManager(sqlContext: SQLContext) { + + private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) + private val activeQueries = new mutable.HashMap[String, ContinuousQuery] + private val activeQueriesLock = new Object + private val awaitTerminationLock = new Object + + private var lastTerminatedQuery: ContinuousQuery = null + + /** + * Returns a list of active queries associated with this SQLContext + * + * @since 2.0.0 + */ + def active: Array[ContinuousQuery] = activeQueriesLock.synchronized { + activeQueries.values.toArray + } + + /** + * Returns an active query from this SQLContext or throws exception if bad name + * + * @since 2.0.0 + */ + def get(name: String): ContinuousQuery = activeQueriesLock.synchronized { + activeQueries.get(name).getOrElse { + throw new IllegalArgumentException(s"There is no active query with name $name") + } + } + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the + * creation of the context, or since `resetTerminated()` was called. If any query was terminated + * with an exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return immediately (if the query was terminated by `query.stop()`), + * or throw the exception immediately (if the query was terminated with exception). Use + * `resetTerminated()` to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, + * if any query has terminated with exception, then `awaitAnyTermination()` will + * throw any of the exception. For correctly documenting exceptions across multiple queries, + * users need to stop all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws ContinuousQueryException, if any query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitAnyTermination(): Unit = { + awaitTerminationLock.synchronized { + while (lastTerminatedQuery == null) { + awaitTerminationLock.wait(10) + } + if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) { + throw lastTerminatedQuery.exception.get + } + } + } + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the + * creation of the context, or since `resetTerminated()` was called. Returns whether any query + * has terminated or not (multiple may have terminated). If any query has terminated with an + * exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return `true` immediately (if the query was terminated by `query.stop()`), + * or throw the exception immediately (if the query was terminated with exception). Use + * `resetTerminated()` to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, + * if any query has terminated with exception, then `awaitAnyTermination()` will + * throw any of the exception. For correctly documenting exceptions across multiple queries, + * users need to stop all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws ContinuousQueryException, if any query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitAnyTermination(timeoutMs: Long): Boolean = { + + val startTime = System.currentTimeMillis + def isTimedout = System.currentTimeMillis - startTime >= timeoutMs + + awaitTerminationLock.synchronized { + while (!isTimedout && lastTerminatedQuery == null) { + awaitTerminationLock.wait(10) + } + if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) { + throw lastTerminatedQuery.exception.get + } + lastTerminatedQuery != null + } + } + + /** + * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to + * wait for new terminations. + * + * @since 2.0.0 + */ + def resetTerminated(): Unit = { + awaitTerminationLock.synchronized { + lastTerminatedQuery = null + } + } + + /** + * Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]]. + * + * @since 2.0.0 + */ + def addListener(listener: ContinuousQueryListener): Unit = { + listenerBus.addListener(listener) + } + + /** + * Deregister a [[ContinuousQueryListener]]. + * + * @since 2.0.0 + */ + def removeListener(listener: ContinuousQueryListener): Unit = { + listenerBus.removeListener(listener) + } + + /** Post a listener event */ + private[sql] def postListenerEvent(event: ContinuousQueryListener.Event): Unit = { + listenerBus.post(event) + } + + /** Start a query */ + private[sql] def startQuery(name: String, df: DataFrame, sink: Sink): ContinuousQuery = { + activeQueriesLock.synchronized { + if (activeQueries.contains(name)) { + throw new IllegalArgumentException( + s"Cannot start query with name $name as a query with that name is already active") + } + val query = new StreamExecution(sqlContext, name, df.logicalPlan, sink) + query.start() + activeQueries.put(name, query) + query + } + } + + /** Notify (by the ContinuousQuery) that the query has been terminated */ + private[sql] def notifyQueryTermination(terminatedQuery: ContinuousQuery): Unit = { + activeQueriesLock.synchronized { + activeQueries -= terminatedQuery.name + } + awaitTerminationLock.synchronized { + if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) { + lastTerminatedQuery = terminatedQuery + } + awaitTerminationLock.notifyAll() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 8060198968988..d6bdd3d825565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -205,6 +205,17 @@ final class DataFrameWriter private[sql](df: DataFrame) { df) } + /** + * Specifies the name of the [[ContinuousQuery]] that can be started with `stream()`. + * This name must be unique among all the currently active queries in the associated SQLContext. + * + * @since 2.0.0 + */ + def queryName(queryName: String): DataFrameWriter = { + this.extraOptions += ("queryName" -> queryName) + this + } + /** * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with @@ -230,7 +241,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { extraOptions.toMap, normalizedParCols.getOrElse(Nil)) - new StreamExecution(df.sqlContext, df.logicalPlan, sink) + df.sqlContext.continuousQueryManager.startQuery( + extraOptions.getOrElse("queryName", StreamExecution.nextName), df, sink) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1661fdbec5326..050a1031c0561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -181,6 +181,8 @@ class SQLContext private[sql]( @transient lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + protected[sql] lazy val continuousQueryManager = new ContinuousQueryManager(this) + @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) @@ -835,6 +837,16 @@ class SQLContext private[sql]( DataFrame(this, ShowTablesCommand(Some(databaseName))) } + /** + * Returns a [[ContinuousQueryManager]] that allows managing all the + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context. + * + * @since 2.0.0 + */ + def streams: ContinuousQueryManager = { + continuousQueryManager + } + /** * Returns the names of tables in the current database as an array. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala new file mode 100644 index 0000000000000..ce21451b2c9c7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, Sink} + +/** + * :: Experimental :: + * Status and metrics of a streaming [[Sink]]. + * + * @param description Description of the source corresponding to this status + * @param offset Current offset up to which data has been written by the sink + * @since 2.0.0 + */ +@Experimental +class SinkStatus private[sql]( + val description: String, + val offset: Option[Offset]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala new file mode 100644 index 0000000000000..2479e67e369ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, Source} + +/** + * :: Experimental :: + * Status and metrics of a streaming [[Source]]. + * + * @param description Description of the source corresponding to this status + * @param offset Current offset of the source, if known + * @since 2.0.0 + */ +@Experimental +class SourceStatus private[sql] ( + val description: String, + val offset: Option[Offset]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala new file mode 100644 index 0000000000000..b1d24b6cfc0bd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} +import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.sql.util.ContinuousQueryListener._ +import org.apache.spark.util.ListenerBus + +/** + * A bus to forward events to [[ContinuousQueryListener]]s. This one will wrap received + * [[ContinuousQueryListener.Event]]s as WrappedContinuousQueryListenerEvents and send them to the + * Spark listener bus. It also registers itself with Spark listener bus, so that it can receive + * WrappedContinuousQueryListenerEvents, unwrap them as ContinuousQueryListener.Events and + * dispatch them to ContinuousQueryListener. + */ +class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus) + extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] { + + sparkListenerBus.addListener(this) + + /** + * Post a ContinuousQueryListener event to the Spark listener bus asynchronously. This event will + * be dispatched to all ContinuousQueryListener in the thread of the Spark listener bus. + */ + def post(event: ContinuousQueryListener.Event) { + event match { + case s: QueryStarted => + postToAll(s) + case _ => + sparkListenerBus.post(new WrappedContinuousQueryListenerEvent(event)) + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case WrappedContinuousQueryListenerEvent(e) => + postToAll(e) + case _ => + } + } + + override protected def doPostEvent( + listener: ContinuousQueryListener, + event: ContinuousQueryListener.Event): Unit = { + event match { + case queryStarted: QueryStarted => + listener.onQueryStarted(queryStarted) + case queryProgress: QueryProgress => + listener.onQueryProgress(queryProgress) + case queryTerminated: QueryTerminated => + listener.onQueryTerminated(queryTerminated) + case _ => + } + } + + /** + * Wrapper for StreamingListenerEvent as SparkListenerEvent so that it can be posted to Spark + * listener bus. + */ + private case class WrappedContinuousQueryListenerEvent( + streamingListenerEvent: ContinuousQueryListener.Event) extends SparkListenerEvent { + + // Do not log streaming events in event log as history server does not support these events. + protected[spark] override def logEvent: Boolean = false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index ebebb829710b2..bc7c520930f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -17,16 +17,20 @@ package org.apache.spark.sql.execution.streaming -import java.lang.Thread.UncaughtExceptionHandler +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.sql.util.ContinuousQueryListener._ /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. @@ -35,15 +39,15 @@ import org.apache.spark.sql.execution.QueryExecution * and the results are committed transactionally to the given [[Sink]]. */ class StreamExecution( - sqlContext: SQLContext, + val sqlContext: SQLContext, + override val name: String, private[sql] val logicalPlan: LogicalPlan, val sink: Sink) extends ContinuousQuery with Logging { /** An monitor used to wait/notify when batches complete. */ private val awaitBatchLock = new Object - - @volatile - private var batchRun = false + private val startLatch = new CountDownLatch(1) + private val terminationLatch = new CountDownLatch(1) /** Minimum amount of time in between the start of each batch. */ private val minBatchTime = 10 @@ -55,9 +59,92 @@ class StreamExecution( private val sources = logicalPlan.collect { case s: StreamingRelation => s.source } - // Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data - // that we have already processed). - { + /** Defines the internal state of execution */ + @volatile + private var state: State = INITIALIZED + + @volatile + private[sql] var lastExecution: QueryExecution = null + + @volatile + private[sql] var streamDeathCause: ContinuousQueryException = null + + /** The thread that runs the micro-batches of this stream. */ + private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") { + override def run(): Unit = { runBatches() } + } + + /** Whether the query is currently active or not */ + override def isActive: Boolean = state == ACTIVE + + /** Returns current status of all the sources. */ + override def sourceStatuses: Array[SourceStatus] = { + sources.map(s => new SourceStatus(s.toString, streamProgress.get(s))).toArray + } + + /** Returns current status of the sink. */ + override def sinkStatus: SinkStatus = new SinkStatus(sink.toString, sink.currentOffset) + + /** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. */ + override def exception: Option[ContinuousQueryException] = Option(streamDeathCause) + + /** + * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event + * has been posted to all the listeners. + */ + private[sql] def start(): Unit = { + microBatchThread.setDaemon(true) + microBatchThread.start() + startLatch.await() // Wait until thread started and QueryStart event has been posted + } + + /** + * Repeatedly attempts to run batches as data arrives. + * + * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted + * so that listeners are guaranteed to get former event before the latter. Furthermore, this + * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns. + */ + private def runBatches(): Unit = { + try { + // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners, + // so must mark this as ACTIVE first. + state = ACTIVE + postEvent(new QueryStarted(this)) // Assumption: Does not throw exception. + + // Unblock starting thread + startLatch.countDown() + + // While active, repeatedly attempt to run batches. + SQLContext.setActive(sqlContext) + populateStartOffsets() + logInfo(s"Stream running at $streamProgress") + while (isActive) { + attemptBatch() + Thread.sleep(minBatchTime) // TODO: Could be tighter + } + } catch { + case _: InterruptedException if state == TERMINATED => // interrupted by stop() + case NonFatal(e) => + streamDeathCause = new ContinuousQueryException( + this, + s"Query $name terminated with exception: ${e.getMessage}", + e, + Some(streamProgress.toCompositeOffset(sources))) + logError(s"Query $name terminated with error", e) + } finally { + state = TERMINATED + sqlContext.streams.notifyQueryTermination(StreamExecution.this) + postEvent(new QueryTerminated(this)) + terminationLatch.countDown() + } + } + + /** + * Populate the start offsets to start the execution at the current offsets stored in the sink + * (i.e. avoid reprocessing data that we have already processed). + */ + private def populateStartOffsets(): Unit = { sink.currentOffset match { case Some(c: CompositeOffset) => val storedProgress = c.offsets @@ -74,37 +161,8 @@ class StreamExecution( } } - logInfo(s"Stream running at $streamProgress") - - /** When false, signals to the microBatchThread that it should stop running. */ - @volatile private var shouldRun = true - - /** The thread that runs the micro-batches of this stream. */ - private[sql] val microBatchThread = new Thread("stream execution thread") { - override def run(): Unit = { - SQLContext.setActive(sqlContext) - while (shouldRun) { - attemptBatch() - Thread.sleep(minBatchTime) // TODO: Could be tighter - } - } - } - microBatchThread.setDaemon(true) - microBatchThread.setUncaughtExceptionHandler( - new UncaughtExceptionHandler { - override def uncaughtException(t: Thread, e: Throwable): Unit = { - streamDeathCause = e - } - }) - microBatchThread.start() - - @volatile - private[sql] var lastExecution: QueryExecution = null - @volatile - private[sql] var streamDeathCause: Throwable = null - /** - * Checks to see if any new data is present in any of the sources. When new data is available, + * Checks to see if any new data is present in any of the sources. When new data is available, * a batch is executed and passed to the sink, updating the currentOffsets. */ private def attemptBatch(): Unit = { @@ -150,36 +208,43 @@ class StreamExecution( streamProgress.synchronized { // Update the offsets and calculate a new composite offset newOffsets.foreach(streamProgress.update) - val newStreamProgress = logicalPlan.collect { - case StreamingRelation(source, _) => streamProgress.get(source) - } - val batchOffset = CompositeOffset(newStreamProgress) // Construct the batch and send it to the sink. + val batchOffset = streamProgress.toCompositeOffset(sources) val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan)) sink.addBatch(nextBatch) } - batchRun = true awaitBatchLock.synchronized { // Wake up any threads that are waiting for the stream to progress. awaitBatchLock.notifyAll() } val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 - logInfo(s"Compete up to $newOffsets in ${batchTime}ms") + logInfo(s"Completed up to $newOffsets in ${batchTime}ms") + postEvent(new QueryProgress(this)) } logDebug(s"Waiting for data, current: $streamProgress") } + private def postEvent(event: ContinuousQueryListener.Event) { + sqlContext.streams.postListenerEvent(event) + } + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. */ - def stop(): Unit = { - shouldRun = false - if (microBatchThread.isAlive) { microBatchThread.join() } + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state = TERMINATED + if (microBatchThread.isAlive) { + microBatchThread.interrupt() + microBatchThread.join() + } + logInfo(s"Query $name was stopped") } /** @@ -198,14 +263,60 @@ class StreamExecution( logDebug(s"Unblocked at $newOffset for $source") } - override def toString: String = + override def awaitTermination(): Unit = { + if (state == INITIALIZED) { + throw new IllegalStateException("Cannot wait for termination on a query that has not started") + } + terminationLatch.await() + if (streamDeathCause != null) { + throw streamDeathCause + } + } + + override def awaitTermination(timeoutMs: Long): Boolean = { + if (state == INITIALIZED) { + throw new IllegalStateException("Cannot wait for termination on a query that has not started") + } + require(timeoutMs > 0, "Timeout has to be positive") + terminationLatch.await(timeoutMs, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } else { + !isActive + } + } + + override def toString: String = { + s"Continuous Query - $name [state = $state]" + } + + def toDebugString: String = { + val deathCauseStr = if (streamDeathCause != null) { + "Error:\n" + stackTraceToString(streamDeathCause.cause) + } else "" s""" - |=== Streaming Query === - |CurrentOffsets: $streamProgress + |=== Continuous Query === + |Name: $name + |Current Offsets: $streamProgress + | + |Current State: $state |Thread State: ${microBatchThread.getState} - |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} | + |Logical Plan: |$logicalPlan + | + |$deathCauseStr """.stripMargin + } + + trait State + case object INITIALIZED extends State + case object ACTIVE extends State + case object TERMINATED extends State } +private[sql] object StreamExecution { + private val nextId = new AtomicInteger() + + def nextName: String = s"query-${nextId.getAndIncrement}" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 0ded1d7152c19..d45b9bd9838c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -55,6 +55,10 @@ class StreamProgress { copied } + private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { + CompositeOffset(source.map(get)) + } + override def toString: String = currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index e6a0842936ea2..8124df15af4a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType object MemoryStream { @@ -46,14 +47,13 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val logicalPlan = StreamingRelation(this) protected val output = logicalPlan.output protected val batches = new ArrayBuffer[Dataset[A]] + protected var currentOffset: LongOffset = new LongOffset(-1) protected def blockManager = SparkEnv.get.blockManager def schema: StructType = encoder.schema - def getCurrentOffset: Offset = currentOffset - def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { new Dataset(sqlContext, logicalPlan) } @@ -62,6 +62,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) new DataFrame(sqlContext, logicalPlan) } + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + def addData(data: TraversableOnce[A]): Offset = { import sqlContext.implicits._ this.synchronized { @@ -110,6 +114,7 @@ class MemorySink(schema: StructType) extends Sink with Logging { } override def addBatch(nextBatch: Batch): Unit = synchronized { + nextBatch.data.collect() // 'compute' the batch's data and record the batch batches.append(nextBatch) } @@ -131,8 +136,13 @@ class MemorySink(schema: StructType) extends Sink with Logging { batches.dropRight(num) } - override def toString: String = synchronized { - batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n") + def toDebugString: String = synchronized { + batches.map { b => + val dataStr = try b.data.collect().mkString(" ") catch { + case NonFatal(e) => "[Error converting to string]" + } + s"${b.end}: $dataStr" + }.mkString("\n") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala new file mode 100644 index 0000000000000..73c78d1b62bbc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.ContinuousQuery +import org.apache.spark.sql.util.ContinuousQueryListener._ + +/** + * :: Experimental :: + * Interface for listening to events related to [[ContinuousQuery ContinuousQueries]]. + * @note The methods are not thread-safe as they may be called from different threads. + */ +@Experimental +abstract class ContinuousQueryListener { + + /** + * Called when a query is started. + * @note This is called synchronously with + * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.stream()`]], + * that is, `onQueryStart` will be called on all listeners before `DataFrameWriter.stream()` + * returns the corresponding [[ContinuousQuery]]. + */ + def onQueryStarted(queryStarted: QueryStarted) + + /** Called when there is some status update (ingestion rate updated, etc. */ + def onQueryProgress(queryProgress: QueryProgress) + + /** Called when a query is stopped, with or without error */ + def onQueryTerminated(queryTerminated: QueryTerminated) +} + + +/** + * :: Experimental :: + * Companion object of [[ContinuousQueryListener]] that defines the listener events. + */ +@Experimental +object ContinuousQueryListener { + + /** Base type of [[ContinuousQueryListener]] events */ + trait Event + + /** Event representing the start of a query */ + class QueryStarted private[sql](val query: ContinuousQuery) extends Event + + /** Event representing any progress updates in a query */ + class QueryProgress private[sql](val query: ContinuousQuery) extends Event + + /** Event representing that termination of a query */ + class QueryTerminated private[sql](val query: ContinuousQuery) extends Event +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 7e388ea602343..62710e72fba8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -21,9 +21,16 @@ import java.lang.Thread.UncaughtExceptionHandler import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.experimental.macros +import scala.reflect.ClassTag import scala.util.Random +import scala.util.control.NonFatal -import org.scalatest.concurrent.Timeouts +import org.scalatest.Assertions +import org.scalatest.concurrent.{Eventually, Timeouts} +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} @@ -64,7 +71,7 @@ trait StreamTest extends QueryTest with Timeouts { } /** How long to wait for an active stream to catch up when checking a result. */ - val streamingTimout = 10.seconds + val streamingTimeout = 10.seconds /** A trait for actions that can be performed while testing a streaming DataFrame. */ trait StreamAction @@ -128,7 +135,38 @@ trait StreamTest extends QueryTest with Timeouts { case object StartStream extends StreamAction /** Signals that a failure is expected and should not kill the test. */ - case object ExpectFailure extends StreamAction + case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { + val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] + override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]" + } + + /** Assert that a body is true */ + class Assert(condition: => Boolean, val message: String = "") extends StreamAction { + def run(): Unit = { Assertions.assert(condition) } + override def toString: String = s"Assert(, $message)" + } + + object Assert { + def apply(condition: => Boolean, message: String = ""): Assert = new Assert(condition, message) + def apply(message: String)(body: => Unit): Assert = new Assert( { body; true }, message) + def apply(body: => Unit): Assert = new Assert( { body; true }, "") + } + + /** Assert that a condition on the active query is true */ + class AssertOnQuery(val condition: StreamExecution => Boolean, val message: String) + extends StreamAction { + override def toString: String = s"AssertOnQuery(, $message)" + } + + object AssertOnQuery { + def apply(condition: StreamExecution => Boolean, message: String = ""): AssertOnQuery = { + new AssertOnQuery(condition, message) + } + + def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = { + new AssertOnQuery(condition, message) + } + } /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */ def testStream(stream: Dataset[_])(actions: StreamAction*): Unit = @@ -145,6 +183,7 @@ trait StreamTest extends QueryTest with Timeouts { var pos = 0 var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null + var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Source, Offset]() val sink = new MemorySink(stream.schema) @@ -170,6 +209,7 @@ trait StreamTest extends QueryTest with Timeouts { def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def testState = s""" |== Progress == @@ -181,16 +221,49 @@ trait StreamTest extends QueryTest with Timeouts { |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} | |== Sink == - |$sink + |${sink.toDebugString} | |== Plan == |${if (currentStream != null) currentStream.lastExecution else ""} - """ + """.stripMargin + + def verify(condition: => Boolean, message: String): Unit = { + try { + Assertions.assert(condition) + } catch { + case NonFatal(e) => + failTest(message, e) + } + } + + def eventually[T](message: String)(func: => T): T = { + try { + Eventually.eventually(Timeout(streamingTimeout)) { + func + } + } catch { + case NonFatal(e) => + failTest(message, e) + } + } + + def failTest(message: String, cause: Throwable = null) = { - def checkState(check: Boolean, error: String) = if (!check) { + // Recursively pretty print a exception with truncated stacktrace and internal cause + def exceptionToString(e: Throwable, prefix: String = ""): String = { + val base = s"$prefix${e.getMessage}" + + e.getStackTrace.take(10).mkString(s"\n$prefix", s"\n$prefix\t", "\n") + if (e.getCause != null) { + base + s"\n$prefix\tCaused by: " + exceptionToString(e.getCause, s"$prefix\t") + } else { + base + } + } + val c = Option(cause).map(exceptionToString(_)) + val m = if (message != null && message.size > 0) Some(message) else None fail( s""" - |Invalid State: $error + |${(m ++ c).mkString(": ")} |$testState """.stripMargin) } @@ -201,9 +274,13 @@ trait StreamTest extends QueryTest with Timeouts { startedTest.foreach { action => action match { case StartStream => - checkState(currentStream == null, "stream already running") - - currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink) + verify(currentStream == null, "stream already running") + lastStream = currentStream + currentStream = + sqlContext + .streams + .startQuery(StreamExecution.nextName, stream, sink) + .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { override def uncaughtException(t: Thread, e: Throwable): Unit = { @@ -213,77 +290,100 @@ trait StreamTest extends QueryTest with Timeouts { }) case StopStream => - checkState(currentStream != null, "can not stop a stream that is not running") - currentStream.stop() - currentStream = null + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.microBatchThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while stopping and waiting for microbatchthread to terminate.") + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { + lastStream = currentStream + currentStream = null + } case DropBatches(num) => - checkState(currentStream == null, "dropping batches while running leads to corruption") + verify(currentStream == null, "dropping batches while running leads to corruption") sink.dropBatches(num) - case ExpectFailure => - try failAfter(streamingTimout) { - while (streamDeathCause == null) { - Thread.sleep(100) + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[ContinuousQueryException] { + currentStream.awaitTermination() } + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.microBatchThread.isAlive) + } + verify(thrownException.query.eq(currentStream), + s"incorrect query reference in exception") + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") } catch { case _: InterruptedException => case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - fail( - s""" - |Timed out while waiting for failure. - |$testState - """.stripMargin) + failTest("Timed out while waiting for failure") + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + streamDeathCause = null } - currentStream = null - streamDeathCause = null + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when not stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") case a: AddData => awaiting.put(a.source, a.addData()) case CheckAnswerRows(expectedAnswer) => - checkState(currentStream != null, "stream not running") + verify(currentStream != null, "stream not running") // Block until all data added has been processed awaiting.foreach { case (source, offset) => - failAfter(streamingTimout) { + failAfter(streamingTimeout) { currentStream.awaitOffset(source, offset) } } val allData = try sink.allData catch { case e: Exception => - fail( - s""" - |Exception while getting data from sink $e - |$testState - """.stripMargin) + failTest("Exception while getting data from sink", e) } QueryTest.sameRows(expectedAnswer, allData).foreach { - error => fail( - s""" - |$error - |$testState - """.stripMargin) + error => failTest(error) } } pos += 1 } } catch { case _: InterruptedException if streamDeathCause != null => - fail( - s""" - |Stream Thread Died - |$testState - """.stripMargin) + failTest("Stream Thread Died") case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - fail( - s""" - |Timed out waiting for stream - |$testState - """.stripMargin) + failTest("Timed out waiting for stream") } finally { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop() @@ -335,7 +435,8 @@ trait StreamTest extends QueryTest with Timeouts { case r if r < 0.7 => // AddData addRandomData() - case _ => // StartStream + case _ => // StopStream + addCheck() actions += StopStream running = false } @@ -345,4 +446,59 @@ trait StreamTest extends QueryTest with Timeouts { addCheck() testStream(ds)(actions: _*) } + + + object AwaitTerminationTester { + + trait ExpectedBehavior + + /** Expect awaitTermination to not be blocked */ + case object ExpectNotBlocked extends ExpectedBehavior + + /** Expect awaitTermination to get blocked */ + case object ExpectBlocked extends ExpectedBehavior + + /** Expect awaitTermination to throw an exception */ + case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E]) + extends ExpectedBehavior + + private val DEFAULT_TEST_TIMEOUT = 1 second + + def test( + expectedBehavior: ExpectedBehavior, + awaitTermFunc: () => Unit, + testTimeout: Span = DEFAULT_TEST_TIMEOUT + ): Unit = { + + expectedBehavior match { + case ExpectNotBlocked => + withClue("Got blocked when expected non-blocking.") { + failAfter(testTimeout) { + awaitTermFunc() + } + } + + case ExpectBlocked => + withClue("Was not blocked when expected.") { + intercept[TestFailedDueToTimeoutException] { + failAfter(testTimeout) { + awaitTermFunc() + } + } + } + + case e: ExpectException[_] => + val thrownException = + withClue(s"Did not throw ${e.t.runtimeClass.getSimpleName} when expected.") { + intercept[ContinuousQueryException] { + failAfter(testTimeout) { + awaitTermFunc() + } + } + } + assert(thrownException.cause.getClass === e.t.runtimeClass, + "exception of incorrect type was throw") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala new file mode 100644 index 0000000000000..daf08efca4e3a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.concurrent.Future +import scala.util.Random +import scala.util.control.NonFatal + +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} +import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation} +import org.apache.spark.sql.test.SharedSQLContext + +class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import AwaitTerminationTester._ + import testImplicits._ + + override val streamingTimeout = 20.seconds + + before { + assert(sqlContext.streams.active.isEmpty) + sqlContext.streams.resetTerminated() + } + + after { + assert(sqlContext.streams.active.isEmpty) + sqlContext.streams.resetTerminated() + } + + test("listing") { + val (m1, ds1) = makeDataset + val (m2, ds2) = makeDataset + val (m3, ds3) = makeDataset + + withQueriesOn(ds1, ds2, ds3) { queries => + require(queries.size === 3) + assert(sqlContext.streams.active.toSet === queries.toSet) + val (q1, q2, q3) = (queries(0), queries(1), queries(2)) + + assert(sqlContext.streams.get(q1.name).eq(q1)) + assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(sqlContext.streams.get(q3.name).eq(q3)) + intercept[IllegalArgumentException] { + sqlContext.streams.get("non-existent-name") + } + + q1.stop() + + assert(sqlContext.streams.active.toSet === Set(q2, q3)) + val ex1 = withClue("no error while getting non-active query") { + intercept[IllegalArgumentException] { + sqlContext.streams.get(q1.name) + } + } + assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") + assert(sqlContext.streams.get(q2.name).eq(q2)) + + m2.addData(0) // q2 should terminate with error + + eventually(Timeout(streamingTimeout)) { + require(!q2.isActive) + require(q2.exception.isDefined) + } + val ex2 = withClue("no error while getting non-active query") { + intercept[IllegalArgumentException] { + sqlContext.streams.get(q2.name).eq(q2) + } + } + + assert(sqlContext.streams.active.toSet === Set(q3)) + } + } + + test("awaitAnyTermination without timeout and resetTerminated") { + val datasets = Seq.fill(5)(makeDataset._2) + withQueriesOn(datasets: _*) { queries => + require(queries.size === datasets.size) + assert(sqlContext.streams.active.toSet === queries.toSet) + + // awaitAnyTermination should be blocking + testAwaitAnyTermination(ExpectBlocked) + + // Stop a query asynchronously and see if it is reported through awaitAnyTermination + val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false) + testAwaitAnyTermination(ExpectNotBlocked) + require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should be non-blocking + testAwaitAnyTermination(ExpectNotBlocked) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination(ExpectBlocked) + + // Terminate a query asynchronously with exception and see awaitAnyTermination throws + // the exception + val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) + testAwaitAnyTermination(ExpectException[SparkException]) + require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should throw the exception + testAwaitAnyTermination(ExpectException[SparkException]) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination(ExpectBlocked) + + // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws + // the exception + val q3 = stopRandomQueryAsync(10 milliseconds, withError = false) + testAwaitAnyTermination(ExpectNotBlocked) + require(!q3.isActive) + val q4 = stopRandomQueryAsync(10 milliseconds, withError = true) + eventually(Timeout(streamingTimeout)) { require(!q4.isActive) } + // After q4 terminates with exception, awaitAnyTerm should start throwing exception + testAwaitAnyTermination(ExpectException[SparkException]) + } + } + + test("awaitAnyTermination with timeout and resetTerminated") { + val datasets = Seq.fill(6)(makeDataset._2) + withQueriesOn(datasets: _*) { queries => + require(queries.size === datasets.size) + assert(sqlContext.streams.active.toSet === queries.toSet) + + // awaitAnyTermination should be blocking or non-blocking depending on timeout values + testAwaitAnyTermination( + ExpectBlocked, + awaitTimeout = 2 seconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 50 milliseconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + // Stop a query asynchronously within timeout and awaitAnyTerm should be unblocked + val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false) + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 1 second, + expectedReturnedValue = true, + testBehaviorFor = 2 seconds) + require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should be non-blocking even if timeout is high + testAwaitAnyTermination( + ExpectNotBlocked, awaitTimeout = 2 seconds, expectedReturnedValue = true) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination( + ExpectBlocked, + awaitTimeout = 2 seconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + // Terminate a query asynchronously with exception within timeout, awaitAnyTermination should + // throws the exception + val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 1 second, + testBehaviorFor = 2 seconds) + require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should throw the exception + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 1 second, + testBehaviorFor = 2 seconds) + + // Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked + sqlContext.streams.resetTerminated() + val q3 = stopRandomQueryAsync(1 second, withError = true) + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 100 milliseconds, + expectedReturnedValue = false, + testBehaviorFor = 2 seconds) + + // After that query is stopped, awaitAnyTerm should throw exception + eventually(Timeout(streamingTimeout)) { require(!q3.isActive) } // wait for query to stop + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 100 milliseconds, + testBehaviorFor = 2 seconds) + + + // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws + // the exception + sqlContext.streams.resetTerminated() + + val q4 = stopRandomQueryAsync(10 milliseconds, withError = false) + testAwaitAnyTermination( + ExpectNotBlocked, awaitTimeout = 1 second, expectedReturnedValue = true) + require(!q4.isActive) + val q5 = stopRandomQueryAsync(10 milliseconds, withError = true) + eventually(Timeout(streamingTimeout)) { require(!q5.isActive) } + // After q5 terminates with exception, awaitAnyTerm should start throwing exception + testAwaitAnyTermination(ExpectException[SparkException], awaitTimeout = 100 milliseconds) + } + } + + + /** Run a body of code by defining a query each on multiple datasets */ + private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = { + failAfter(streamingTimeout) { + val queries = withClue("Error starting queries") { + datasets.map { ds => + @volatile var query: StreamExecution = null + try { + val df = ds.toDF + query = sqlContext + .streams + .startQuery(StreamExecution.nextName, df, new MemorySink(df.schema)) + .asInstanceOf[StreamExecution] + } catch { + case NonFatal(e) => + if (query != null) query.stop() + throw e + } + query + } + } + try { + body(queries) + } finally { + queries.foreach(_.stop()) + } + } + } + + /** Test the behavior of awaitAnyTermination */ + private def testAwaitAnyTermination( + expectedBehavior: ExpectedBehavior, + expectedReturnedValue: Boolean = false, + awaitTimeout: Span = null, + testBehaviorFor: Span = 2 seconds + ): Unit = { + + def awaitTermFunc(): Unit = { + if (awaitTimeout != null && awaitTimeout.toMillis > 0) { + val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis) + assert(returnedValue === expectedReturnedValue, "Returned value does not match expected") + } else { + sqlContext.streams.awaitAnyTermination() + } + } + + AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor) + } + + /** Stop a random active query either with `stop()` or with an error */ + private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): ContinuousQuery = { + + import scala.concurrent.ExecutionContext.Implicits.global + + val activeQueries = sqlContext.streams.active + val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) + Future { + Thread.sleep(stopAfter.toMillis) + if (withError) { + logDebug(s"Terminating query ${queryToStop.name} with error") + queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect { + case StreamingRelation(memoryStream, _) => + memoryStream.asInstanceOf[MemoryStream[Int]].addData(0) + } + } else { + logDebug(s"Stopping query ${queryToStop.name}") + queryToStop.stop() + } + } + queryToStop + } + + private def makeDataset: (MemoryStream[Int], Dataset[Int]) = { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS.map(6 / _) + (inputData, mapped) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala new file mode 100644 index 0000000000000..dac1a398ff5e6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution} +import org.apache.spark.sql.test.SharedSQLContext + +class ContinuousQuerySuite extends StreamTest with SharedSQLContext { + + import AwaitTerminationTester._ + import testImplicits._ + + test("lifecycle states and awaitTermination") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + testStream(mapped)( + AssertOnQuery(_.isActive === true), + AssertOnQuery(_.exception.isEmpty), + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + TestAwaitTermination(ExpectBlocked), + TestAwaitTermination(ExpectBlocked, timeoutMs = 2000), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false), + StopStream, + AssertOnQuery(_.isActive === false), + AssertOnQuery(_.exception.isEmpty), + TestAwaitTermination(ExpectNotBlocked), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true), + StartStream, + AssertOnQuery(_.isActive === true), + AddData(inputData, 0), + ExpectFailure[SparkException], + AssertOnQuery(_.isActive === false), + TestAwaitTermination(ExpectException[SparkException]), + TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000), + TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10), + AssertOnQuery( + q => q.exception.get.startOffset.get === q.streamProgress.toCompositeOffset(Seq(inputData)), + "incorrect start offset on exception") + ) + } + + test("source and sink statuses") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(6 / _) + + testStream(mapped)( + AssertOnQuery(_.sourceStatuses.length === 1), + AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.sourceStatuses(0).offset === None), + AssertOnQuery(_.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.sinkStatus.offset === None), + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))), + AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))), + AddData(inputData, 1, 2), + CheckAnswer(6, 3, 6, 3), + AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))), + AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))), + AddData(inputData, 0), + ExpectFailure[SparkException], + AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))), + AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))) + ) + } + + /** + * A [[StreamAction]] to test the behavior of `ContinuousQuery.awaitTermination()`. + * + * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) + * @param timeoutMs Timeout in milliseconds + * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) + * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used + */ + case class TestAwaitTermination( + expectedBehavior: ExpectedBehavior, + timeoutMs: Int = -1, + expectedReturnValue: Boolean = false + ) extends AssertOnQuery( + TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue), + "Error testing awaitTermination behavior" + ) { + override def toString(): String = { + s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " + + s"expectedReturnValue = $expectedReturnValue)" + } + } + + object TestAwaitTermination { + + /** + * Tests the behavior of `ContinuousQuery.awaitTermination`. + * + * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) + * @param timeoutMs Timeout in milliseconds + * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) + * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used + */ + def assertOnQueryCondition( + expectedBehavior: ExpectedBehavior, + timeoutMs: Int, + expectedReturnValue: Boolean + )(q: StreamExecution): Boolean = { + + def awaitTermFunc(): Unit = { + if (timeoutMs <= 0) { + q.awaitTermination() + } else { + val returnedValue = q.awaitTermination(timeoutMs) + assert(returnedValue === expectedReturnValue, "Returned value does not match expected") + } + } + AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + true // If the control reached here, then everything worked as expected + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index b762f9b90ed86..f060c6f623b6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.streaming.test -import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{AnalysisException, ContinuousQuery, SQLContext, StreamTest} import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source} import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.test.SharedSQLContext @@ -57,9 +59,13 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext { +class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ + after { + sqlContext.streams.active.foreach(_.stop()) + } + test("resolve default source") { sqlContext.read .format("org.apache.spark.sql.streaming.test") @@ -188,4 +194,63 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext { assert(LastOptions.parameters("boolOpt") == "false") assert(LastOptions.parameters("doubleOpt") == "6.7") } + + test("unique query names") { + + /** Start a query with a specific name */ + def startQueryWithName(name: String = ""): ContinuousQuery = { + sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + .write + .format("org.apache.spark.sql.streaming.test") + .queryName(name) + .stream() + } + + /** Start a query without specifying a name */ + def startQueryWithoutName(): ContinuousQuery = { + sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + .write + .format("org.apache.spark.sql.streaming.test") + .stream() + } + + /** Get the names of active streams */ + def activeStreamNames: Set[String] = { + val streams = sqlContext.streams.active + val names = streams.map(_.name).toSet + assert(streams.length === names.size, s"names of active queries are not unique: $names") + names + } + + val q1 = startQueryWithName("name") + + // Should not be able to start another query with the same name + intercept[IllegalArgumentException] { + startQueryWithName("name") + } + assert(activeStreamNames === Set("name")) + + // Should be able to start queries with other names + val q3 = startQueryWithName("another-name") + assert(activeStreamNames === Set("name", "another-name")) + + // Should be able to start queries with auto-generated names + val q4 = startQueryWithoutName() + assert(activeStreamNames.contains(q4.name)) + + // Should not be able to start a query with same auto-generated name + intercept[IllegalArgumentException] { + startQueryWithName(q4.name) + } + + // Should be able to start query with that name after stopping the previous query + q1.stop() + val q5 = startQueryWithName("name") + assert(activeStreamNames.contains("name")) + sqlContext.streams.active.foreach(_.stop()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala new file mode 100644 index 0000000000000..d6cc6ad86bf61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.util.control.NonFatal + +import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester._ +import org.scalatest.concurrent.AsyncAssertions.Waiter +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated} + +class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + assert(sqlContext.streams.active.isEmpty) + assert(addedListeners.isEmpty) + } + + test("single listener") { + val listener = new QueryStatusCollector + val input = MemoryStream[Int] + withListenerAdded(listener) { + testStream(input.toDS)( + StartStream, + Assert("Incorrect query status in onQueryStarted") { + val status = listener.startStatus + assert(status != null) + assert(status.active == true) + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).description.contains("Memory")) + + // The source and sink offsets must be None as this must be called before the + // batches have started + assert(status.sourceStatuses(0).offset === None) + assert(status.sinkStatus.offset === None) + + // No progress events or termination events + assert(listener.progressStatuses.isEmpty) + assert(listener.terminationStatus === null) + }, + AddDataMemory(input, Seq(1, 2, 3)), + CheckAnswer(1, 2, 3), + Assert("Incorrect query status in onQueryProgress") { + eventually(Timeout(streamingTimeout)) { + + // There should be only on progress event as batch has been processed + assert(listener.progressStatuses.size === 1) + val status = listener.progressStatuses.peek() + assert(status != null) + assert(status.active == true) + assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) + assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))) + + // No termination events + assert(listener.terminationStatus === null) + } + }, + StopStream, + Assert("Incorrect query status in onQueryTerminated") { + eventually(Timeout(streamingTimeout)) { + val status = listener.terminationStatus + assert(status != null) + + assert(status.active === false) // must be inactive by the time onQueryTerm is called + assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) + assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))) + } + listener.checkAsyncErrors() + } + ) + } + } + + test("adding and removing listener") { + def isListenerActive(listener: QueryStatusCollector): Boolean = { + listener.reset() + testStream(MemoryStream[Int].toDS)( + StartStream, + StopStream + ) + listener.startStatus != null + } + + try { + val listener1 = new QueryStatusCollector + val listener2 = new QueryStatusCollector + + sqlContext.streams.addListener(listener1) + assert(isListenerActive(listener1) === true) + assert(isListenerActive(listener2) === false) + sqlContext.streams.addListener(listener2) + assert(isListenerActive(listener1) === true) + assert(isListenerActive(listener2) === true) + sqlContext.streams.removeListener(listener1) + assert(isListenerActive(listener1) === false) + assert(isListenerActive(listener2) === true) + } finally { + addedListeners.foreach(sqlContext.streams.removeListener) + } + } + + test("event ordering") { + val listener = new QueryStatusCollector + withListenerAdded(listener) { + for (i <- 1 to 100) { + listener.reset() + require(listener.startStatus === null) + testStream(MemoryStream[Int].toDS)( + StartStream, + Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"), + StopStream, + Assert { listener.checkAsyncErrors() } + ) + } + } + } + + + private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { + @volatile var query: StreamExecution = null + try { + failAfter(1 minute) { + sqlContext.streams.addListener(listener) + body + } + } finally { + sqlContext.streams.removeListener(listener) + } + } + + private def addedListeners(): Array[ContinuousQueryListener] = { + val listenerBusMethod = + PrivateMethod[ContinuousQueryListenerBus]('listenerBus) + val listenerBus = sqlContext.streams invokePrivate listenerBusMethod() + listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener]) + } + + class QueryStatusCollector extends ContinuousQueryListener { + + private val asyncTestWaiter = new Waiter // to catch errors in the async listener events + + @volatile var startStatus: QueryStatus = null + @volatile var terminationStatus: QueryStatus = null + val progressStatuses = new ConcurrentLinkedQueue[QueryStatus] + + def reset(): Unit = { + startStatus = null + terminationStatus = null + progressStatuses.clear() + + // To reset the waiter + try asyncTestWaiter.await(timeout(1 milliseconds)) catch { + case NonFatal(e) => + } + } + + def checkAsyncErrors(): Unit = { + asyncTestWaiter.await(timeout(streamingTimeout)) + } + + + override def onQueryStarted(queryStarted: QueryStarted): Unit = { + asyncTestWaiter { + startStatus = QueryStatus(queryStarted.query) + } + } + + override def onQueryProgress(queryProgress: QueryProgress): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryProgress called before onQueryStarted") + progressStatuses.add(QueryStatus(queryProgress.query)) + } + } + + override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryTerminated called before onQueryStarted") + terminationStatus = QueryStatus(queryTerminated.query) + } + asyncTestWaiter.dismiss() + } + } + + case class QueryStatus( + active: Boolean, + expection: Option[Exception], + sourceStatuses: Array[SourceStatus], + sinkStatus: SinkStatus) + + object QueryStatus { + def apply(query: ContinuousQuery): QueryStatus = { + QueryStatus(query.isActive, query.exception, query.sourceStatuses, query.sinkStatus) + } + } +} From 719973b05ef6d6b9fbb83d76aebac6454ae84fad Mon Sep 17 00:00:00 2001 From: raela Date: Wed, 10 Feb 2016 17:00:54 -0800 Subject: [PATCH 1859/2089] [SPARK-13274] Fix Aggregator Links on GroupedDataset Scala API Update Aggregator links to point to #org.apache.spark.sql.expressions.Aggregator Author: raela Closes #11158 from raelawang/master. --- .../scala/org/apache/spark/sql/GroupedDataset.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c0e28f2dc5bd6..53cb8eb524947 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -101,7 +101,8 @@ class GroupedDataset[K, V] private[sql]( * * This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an [[Aggregator]]. + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group @@ -128,7 +129,8 @@ class GroupedDataset[K, V] private[sql]( * * This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an [[Aggregator]]. + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group @@ -148,7 +150,8 @@ class GroupedDataset[K, V] private[sql]( * * This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an [[Aggregator]]. + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group @@ -169,7 +172,8 @@ class GroupedDataset[K, V] private[sql]( * * This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an [[Aggregator]]. + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group From 663cc400f3b927633e47df07eea409da0e9ae70e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Feb 2016 10:44:39 +0800 Subject: [PATCH 1860/2089] [SPARK-12725][SQL] Resolving Name Conflicts in SQL Generation and Name Ambiguity Caused by Internally Generated Expressions Some analysis rules generate aliases or auxiliary attribute references with the same name but different expression IDs. For example, `ResolveAggregateFunctions` introduces `havingCondition` and `aggOrder`, and `DistinctAggregationRewriter` introduces `gid`. This is OK for normal query execution since these attribute references get expression IDs. However, it's troublesome when converting resolved query plans back to SQL query strings since expression IDs are erased. Here's an example Spark 1.6.0 snippet for illustration: ```scala sqlContext.range(10).select('id as 'a, 'id as 'b).registerTempTable("t") sqlContext.sql("SELECT SUM(a) FROM t GROUP BY a, b ORDER BY COUNT(a), COUNT(b)").explain(true) ``` The above code produces the following resolved plan: ``` == Analyzed Logical Plan == _c0: bigint Project [_c0#101L] +- Sort [aggOrder#102L ASC,aggOrder#103L ASC], true +- Aggregate [a#47L,b#48L], [(sum(a#47L),mode=Complete,isDistinct=false) AS _c0#101L,(count(a#47L),mode=Complete,isDistinct=false) AS aggOrder#102L,(count(b#48L),mode=Complete,isDistinct=false) AS aggOrder#103L] +- Subquery t +- Project [id#46L AS a#47L,id#46L AS b#48L] +- LogicalRDD [id#46L], MapPartitionsRDD[44] at range at :26 ``` Here we can see that both aggregate expressions in `ORDER BY` are extracted into an `Aggregate` operator, and both of them are named `aggOrder` with different expression IDs. The solution is to automatically add the expression IDs into the attribute name for the Alias and AttributeReferences that are generated by Analyzer in SQL Generation. In this PR, it also resolves another issue. Users could use the same name as the internally generated names. The duplicate names should not cause name ambiguity. When resolving the column, Catalyst should not pick the column that is internally generated. Could you review the solution? marmbrus liancheng I did not set the newly added flag for all the alias and attribute reference generated by Analyzers. Please let me know if I should do it? Thank you! Author: gatorsmile Closes #11050 from gatorsmile/namingConflicts. --- .../sql/catalyst/analysis/Analyzer.scala | 20 +++++--- .../DistinctAggregationRewriter.scala | 3 +- .../expressions/namedExpressions.scala | 47 ++++++++++++------- .../sql/catalyst/planning/patterns.scala | 7 ++- .../catalyst/plans/logical/LogicalPlan.scala | 5 +- .../spark/sql/catalyst/trees/TreeNode.scala | 2 + .../analysis/AnalysisErrorSuite.scala | 11 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 16 +++++-- .../sql/hive/LogicalPlanToSQLSuite.scala | 3 +- 9 files changed, 78 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4d53b232d5510..62b241f05270a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -416,9 +416,10 @@ class Analyzer( case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => val newChildren = expandStarExpressions(args, child) UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil - case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => + case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => val newChildren = expandStarExpressions(args, child) - Alias(child = f.copy(children = newChildren), name)() :: Nil + Alias(child = f.copy(children = newChildren), name)( + isGenerated = a.isGenerated) :: Nil case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child, resolver) @@ -528,7 +529,7 @@ class Analyzer( def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { - case a: Alias => Alias(a.child, a.name)() + case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) case other => other } } @@ -734,7 +735,10 @@ class Analyzer( // Try resolving the condition of the filter as though it is in the aggregate clause val aggregatedCondition = - Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child) + Aggregate( + grouping, + Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator @@ -759,7 +763,8 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) + val aliasedOrdering = + unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true)) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -1190,7 +1195,7 @@ class Analyzer( leafNondeterministic.map { e => val ne = e match { case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() + case _ => Alias(e, "_nondeterministic")(isGenerated = true) } new TreeNodeRef(e) -> ne } @@ -1355,7 +1360,8 @@ object CleanupAliases extends Rule[LogicalPlan] { def trimNonTopLevelAliases(e: Expression): Expression = e match { case a: Alias => - Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + Alias(trimAliases(a.child), a.name)( + a.exprId, a.qualifiers, a.explicitMetadata, a.isGenerated) case other => trimAliases(other) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 4e7d1341028ca..5dfce89bd68a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,7 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP // Aggregation strategy can handle the query with single distinct if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. - val gid = new AttributeReference("gid", IntegerType, false)() + val gid = + new AttributeReference("gid", IntegerType, false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7983501ada9bd..207b8a0a88556 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -79,6 +79,9 @@ trait NamedExpression extends Expression { /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty + /** Returns true if the expression is generated by Catalyst */ + def isGenerated: java.lang.Boolean = false + /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression @@ -114,16 +117,21 @@ abstract class Attribute extends LeafExpression with NamedExpression { * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. * - * @param child the computation being performed - * @param name the name to be associated with the result of computing [[child]]. + * @param child The computation being performed + * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. + * @param qualifiers A list of strings that can be used to referred to this attribute in a fully + * qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. + * @param isGenerated A flag to indicate if this alias is generated by Catalyst */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil, - val explicitMetadata: Option[Metadata] = None) + val explicitMetadata: Option[Metadata] = None, + override val isGenerated: java.lang.Boolean = false) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -148,11 +156,13 @@ case class Alias(child: Expression, name: String)( } def newInstance(): NamedExpression = - Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata) + Alias(child, name)( + qualifiers = qualifiers, explicitMetadata = explicitMetadata, isGenerated = isGenerated) override def toAttribute: Attribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) + AttributeReference(name, child.dataType, child.nullable, metadata)( + exprId, qualifiers, isGenerated) } else { UnresolvedAttribute(name) } @@ -161,7 +171,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifiers :: explicitMetadata :: Nil + exprId :: qualifiers :: explicitMetadata :: isGenerated :: Nil } override def equals(other: Any): Boolean = other match { @@ -174,7 +184,8 @@ case class Alias(child: Expression, name: String)( override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") - s"${child.sql} AS $qualifiersString`$name`" + val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name" + s"${child.sql} AS $qualifiersString`$aliasName`" } } @@ -187,9 +198,10 @@ case class Alias(child: Expression, name: String)( * @param metadata The metadata of this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. - * @param qualifiers a list of strings that can be used to referred to this attribute in a fully + * @param qualifiers A list of strings that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. + * @param isGenerated A flag to indicate if this reference is generated by Catalyst */ case class AttributeReference( name: String, @@ -197,7 +209,8 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifiers: Seq[String] = Nil) + val qualifiers: Seq[String] = Nil, + override val isGenerated: java.lang.Boolean = false) extends Attribute with Unevaluable { /** @@ -234,7 +247,8 @@ case class AttributeReference( } override def newInstance(): AttributeReference = - AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) + AttributeReference(name, dataType, nullable, metadata)( + qualifiers = qualifiers, isGenerated = isGenerated) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -243,7 +257,7 @@ case class AttributeReference( if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers, isGenerated) } } @@ -251,7 +265,7 @@ case class AttributeReference( if (name == newName) { this } else { - AttributeReference(newName, dataType, nullable)(exprId, qualifiers) + AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifiers, isGenerated) } } @@ -262,7 +276,7 @@ case class AttributeReference( if (newQualifiers.toSet == qualifiers.toSet) { this } else { - AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers, isGenerated) } } @@ -270,12 +284,12 @@ case class AttributeReference( if (exprId == newExprId) { this } else { - AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers) + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers, isGenerated) } } override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifiers :: Nil + exprId :: qualifiers :: isGenerated :: Nil } override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -287,7 +301,8 @@ case class AttributeReference( override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") - s"$qualifiersString`$name`" + val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name" + s"$qualifiersString`$attrRefName`" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f0ee124e88a9f..7302b63646d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -78,10 +78,13 @@ object PhysicalOperation extends PredicateHelper { private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + aliases.get(ref) + .map(Alias(_, name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated)) + .getOrElse(a) case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + aliases.get(a) + .map(Alias(_, a.name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated)).getOrElse(a) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index d8944a424156e..18b7bde906fda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -139,7 +139,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { case a: Alias => // As the root of the expression, Alias will always take an arbitrary exprId, we need // to erase that for equality testing. - val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) + val cleanedExprId = + Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) BindReferences.bindReference(cleanedExprId, input, allowFailures = true) case other => BindReferences.bindReference(other, input, allowFailures = true) } @@ -222,7 +223,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { + if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { Option((attribute.withName(nameParts.head), nameParts.tail.toList)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2df0683f9fa16..30df2a84f62c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -656,6 +656,8 @@ object TreeNode { case t if t <:< definitions.DoubleTpe => value.asInstanceOf[JDouble].num: java.lang.Double + case t if t <:< localTypeOf[java.lang.Boolean] => + value.asInstanceOf[JBool].value: java.lang.Boolean case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fc35959f20547..e0cec09742eba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @BeanInfo @@ -176,6 +176,13 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation.select('abcd), "cannot resolve" :: "abcd" :: Nil) + errorTest( + "unresolved attributes with a generated name", + testRelation2.groupBy('a)(max('b)) + .where(sum('b) > 0) + .orderBy('havingCondition.asc), + "cannot resolve" :: "havingCondition" :: Nil) + errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c02133ffc8540..3ea4adcaa6424 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -998,12 +998,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") { val df = Seq(1 -> 2).toDF("i", "j") - val query = df.groupBy('i) - .agg(max('j).as("aggOrdering")) + val query1 = df.groupBy('i) + .agg(max('j).as("aggOrder")) .orderBy(sum('j)) - checkAnswer(query, Row(1, 2)) + checkAnswer(query1, Row(1, 2)) + + // In the plan, there are two attributes having the same name 'havingCondition' + // One is a user-provided alias name; another is an internally generated one. + val query2 = df.groupBy('i) + .agg(max('j).as("havingCondition")) + .where(sum('j) > 0) + .orderBy('havingCondition.asc) + checkAnswer(query2, Row(1, 2)) } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 1f731db26f387..129bfe0a7dfd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -92,12 +92,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") } - // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query // execution since these aliases have different expression ID. But this introduces name collision // when converting resolved plans back to SQL query strings as expression IDs are stripped. - ignore("aggregate function in order by clause with multiple order keys") { + test("aggregate function in order by clause with multiple order keys") { checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") } From 0f09f0226983cdc409ef504dff48395787dc844f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Feb 2016 11:08:21 +0800 Subject: [PATCH 1861/2089] [SPARK-13205][SQL] SQL Generation Support for Self Join This PR addresses two issues: - Self join does not work in SQL Generation - When creating new instances for `LogicalRelation`, `metastoreTableIdentifier` is lost. liancheng Could you please review the code changes? Thank you! Author: gatorsmile Closes #11084 from gatorsmile/selfJoinInSQLGen. --- .../sql/execution/datasources/LogicalRelation.scala | 6 +++++- .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 10 +++++++++- .../apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 8 ++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index fa97f3d7199ed..0e0748ff32df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -76,7 +76,11 @@ case class LogicalRelation( /** Used to lookup original attribute capitalization */ val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance(): this.type = LogicalRelation(relation).asInstanceOf[this.type] + def newInstance(): this.type = + LogicalRelation( + relation, + expectedOutputAttributes, + metastoreTableIdentifier).asInstanceOf[this.type] override def simpleString: String = s"Relation[${output.mkString(",")}] $relation" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index fc5725d6915ea..4b75e60f8d5f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -142,7 +142,15 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi Some(s"`$database`.`$table`") case Subquery(alias, child) => - toSQL(child).map(childSQL => s"($childSQL) AS $alias") + toSQL(child).map( childSQL => + child match { + // Parentheses is not used for persisted data source relations + // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 + case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => + s"$childSQL AS $alias" + case _ => + s"($childSQL) AS $alias" + }) case Join(left, right, joinType, condition) => for { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 129bfe0a7dfd8..80ae312d913de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -104,6 +104,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") } + test("self join") { + checkHiveQl("SELECT x.key FROM t1 x JOIN t1 y ON x.key = y.key") + } + + test("self join with group by") { + checkHiveQl("SELECT x.key, COUNT(*) FROM t1 x JOIN t1 y ON x.key = y.key group by x.key") + } + test("three-child union") { checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0") } From b5761d150b66ee0ae5f1be897d9d7a1abb039884 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Feb 2016 20:13:38 -0800 Subject: [PATCH 1862/2089] [SPARK-12706] [SQL] grouping() and grouping_id() Grouping() returns a column is aggregated or not, grouping_id() returns the aggregation levels. grouping()/grouping_id() could be used with window function, but does not work in having/sort clause, will be fixed by another PR. The GROUPING__ID/grouping_id() in Hive is wrong (according to docs), we also did it wrongly, this PR change that to match the behavior in most databases (also the docs of Hive). Author: Davies Liu Closes #10677 from davies/grouping. --- python/pyspark/sql/dataframe.py | 22 ++++---- python/pyspark/sql/functions.py | 44 ++++++++++++++++ .../spark/sql/catalyst/CatalystQl.scala | 13 +++-- .../sql/catalyst/analysis/Analyzer.scala | 48 ++++++++++++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 ++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/aggregate/interfaces.scala | 1 - .../sql/catalyst/expressions/grouping.scala | 23 +++++++++ .../plans/logical/basicOperators.scala | 2 +- .../org/apache/spark/sql/functions.scala | 46 +++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 44 ++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 50 +++++++++++++++++++ .../execution/HiveCompatibilitySuite.scala | 8 +-- ...CUBE #1-0-63b61fb3f0e74226001ad279be440864 | 12 ++--- ...llup #1-0-a78e3dbf242f240249e36b3d3fd0926a | 12 ++--- ...llup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 | 20 ++++---- ...llup #3-0-9257085d123728730be96b6d9fbb84ce | 20 ++++---- 17 files changed, 309 insertions(+), 63 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3a8c8305ee3d8..3104e41407114 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -887,8 +887,8 @@ def groupBy(self, *cols): [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> sorted(df.groupBy(df.name).avg().collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] - >>> df.groupBy(['name', df.age]).count().collect() - [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] + >>> sorted(df.groupBy(['name', df.age]).count().collect()) + [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)] """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData @@ -900,15 +900,15 @@ def rollup(self, *cols): Create a multi-dimensional rollup for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. - >>> df.rollup('name', df.age).count().show() + >>> df.rollup("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| +-----+----+-----+ - |Alice| 2| 1| - | Bob| 5| 1| - | Bob|null| 1| | null|null| 2| |Alice|null| 1| + |Alice| 2| 1| + | Bob|null| 1| + | Bob| 5| 1| +-----+----+-----+ """ jgd = self._jdf.rollup(self._jcols(*cols)) @@ -921,17 +921,17 @@ def cube(self, *cols): Create a multi-dimensional cube for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. - >>> df.cube('name', df.age).count().show() + >>> df.cube("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| +-----+----+-----+ + | null|null| 2| | null| 2| 1| - |Alice| 2| 1| - | Bob| 5| 1| | null| 5| 1| - | Bob|null| 1| - | null|null| 2| |Alice|null| 1| + |Alice| 2| 1| + | Bob|null| 1| + | Bob| 5| 1| +-----+----+-----+ """ jgd = self._jdf.cube(self._jcols(*cols)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0d5708526701e..680493e0e689e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -288,6 +288,50 @@ def first(col, ignorenulls=False): return Column(jc) +@since(2.0) +def grouping(col): + """ + Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + or not, returns 1 for aggregated or 0 for not aggregated in the result set. + + >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show() + +-----+--------------+--------+ + | name|grouping(name)|sum(age)| + +-----+--------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+--------------+--------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.grouping(_to_java_column(col)) + return Column(jc) + + +@since(2.0) +def grouping_id(*cols): + """ + Aggregate function: returns the level of grouping, equals to + + (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + + Note: the list of columns should match with grouping columns exactly, or empty (means all the + grouping columns). + + >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() + +-----+------------+--------+ + | name|groupingid()|sum(age)| + +-----+------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+------------+--------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index a42360d5629f8..8099751900a42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -186,8 +186,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C * * The bitmask denotes the grouping expressions validity for a grouping set, * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. + * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of + * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively. */ protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { val (keyASTs, setASTs) = children.partition { @@ -198,12 +198,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val keys = keyASTs.map(nodeToExpr) val keyMap = keyASTs.zipWithIndex.toMap + val mask = (1 << keys.length) - 1 val bitmasks: Seq[Int] = setASTs.map { case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => - columns.foldLeft(0)((bitmap, col) => { - val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2) - bitmap | 1 << keyIndex.getOrElse( + columns.foldLeft(mask)((bitmap, col) => { + val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse( throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list")) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (keys.length - 1 - keyIndex)) }) case _ => sys.error("Expect GROUPING SETS clause") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 62b241f05270a..c0fa79612a007 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -238,14 +238,39 @@ class Analyzer( } }.toMap - val aggregations: Seq[NamedExpression] = x.aggregations.map { - // If an expression is an aggregate (contains a AggregateExpression) then we dont change - // it so that the aggregation is computed on the unmodified value of its argument - // expressions. - case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr - // If not then its a grouping expression and we need to use the modified (with nulls from - // Expand) value of the expression. - case expr => expr.transformDown { + val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => + // collect all the found AggregateExpression, so we can check an expression is part of + // any AggregateExpression or not. + val aggsBuffer = ArrayBuffer[Expression]() + // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. + def isPartOfAggregation(e: Expression): Boolean = { + aggsBuffer.exists(a => a.find(_ eq e).isDefined) + } + expr.transformDown { + // AggregateExpression should be computed on the unmodified value of its argument + // expressions, so we should not replace any references to grouping expression + // inside it. + case e: AggregateExpression => + aggsBuffer += e + e + case e if isPartOfAggregation(e) => e + case e: GroupingID => + if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { + gid + } else { + throw new AnalysisException( + s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + + s"grouping columns (${x.groupByExprs.mkString(",")})") + } + case Grouping(col: Expression) => + val idx = x.groupByExprs.indexOf(col) + if (idx >= 0) { + Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)), + Literal(1)), ByteType) + } else { + throw new AnalysisException(s"Column of grouping ($col) can't be found " + + s"in grouping columns ${x.groupByExprs.mkString(",")}") + } case e => groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) }.asInstanceOf[NamedExpression] @@ -819,8 +844,11 @@ class Analyzer( } } + private def isAggregateExpression(e: Expression): Boolean = { + e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] + } def containsAggregate(condition: Expression): Boolean = { - condition.find(_.isInstanceOf[AggregateExpression]).isDefined + condition.find(isAggregateExpression).isDefined } } @@ -1002,7 +1030,7 @@ class Analyzer( _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). - case wf : WindowFunction => + case wf: WindowFunction => val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4a2f2b8bc6e4c..fe053b9a0b27f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -70,6 +70,11 @@ trait CheckAnalysis { failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + case g: Grouping => + failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup") + case g: GroupingID => + failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") + case w @ WindowExpression(AggregateExpression(_, _, true), _) => failAnalysis(s"Distinct window functions are not supported: $w") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d9009e3848e58..1be97c7b81197 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -291,6 +291,8 @@ object FunctionRegistry { // grouping sets expression[Cube]("cube"), expression[Rollup]("rollup"), + expression[Grouping]("grouping"), + expression[GroupingID]("grouping_id"), // window functions expression[Lead]("lead"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 561fa3321d8fa..f88a57a254b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -344,4 +344,3 @@ abstract class DeclarativeAggregate def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 2997ee879d479..a204060630050 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -41,3 +41,26 @@ trait GroupingSet extends Expression with CodegenFallback { case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} + +/** + * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. + * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. + */ +case class Grouping(child: Expression) extends Expression with Unevaluable { + override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) + override def children: Seq[Expression] = child :: Nil + override def dataType: DataType = ByteType + override def nullable: Boolean = false +} + +/** + * GroupingID is a function that computes the level of grouping. + * + * If groupByExprs is empty, it means all grouping expressions in GroupingSets. + */ +case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable { + override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) + override def children: Seq[Expression] = groupByExprs + override def dataType: DataType = IntegerType + override def nullable: Boolean = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 57575f9ee09ab..e8e0a78904a32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -412,7 +412,7 @@ private[sql] object Expand { var bit = exprs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set += exprs(bit) + if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1) bit -= 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b970eee4e31a7..d34d377ab66dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -396,6 +396,52 @@ object functions extends LegacyFunctions { */ def first(columnName: String): Column = first(Column(columnName)) + + /** + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping(e: Column): Column = Column(Grouping(e.expr)) + + /** + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping(columnName: String): Column = grouping(Column(columnName)) + + /** + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * grouping columns). + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) + + /** + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping_id(colName: String, colNames: String*): Column = { + grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) + } + /** * Aggregate function: returns the kurtosis of the values in a group. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 08fb7c9d84c0b..78bf6c1bcebf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType @@ -98,6 +99,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(cube0.where("date IS NULL").count > 0) } + test("grouping and grouping_id") { + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping("course"), grouping("year"), grouping_id("course", "year")), + Row("Java", 2012, 0, 0, 0) :: + Row("Java", 2013, 0, 0, 0) :: + Row("Java", null, 0, 1, 1) :: + Row("dotNET", 2012, 0, 0, 0) :: + Row("dotNET", 2013, 0, 0, 0) :: + Row("dotNET", null, 0, 1, 1) :: + Row(null, 2012, 1, 0, 2) :: + Row(null, 2013, 1, 0, 2) :: + Row(null, null, 1, 1, 3) :: Nil + ) + + intercept[AnalysisException] { + courseSales.groupBy().agg(grouping("course")).explain() + } + intercept[AnalysisException] { + courseSales.groupBy().agg(grouping_id("course")).explain() + } + } + + test("grouping/grouping_id inside window function") { + + val w = Window.orderBy(sum("earnings")) + checkAnswer( + courseSales.cube("course", "year") + .agg(sum("earnings"), + grouping_id("course", "year"), + rank().over(Window.partitionBy(grouping_id("course", "year")).orderBy(sum("earnings")))), + Row("Java", 2012, 20000.0, 0, 2) :: + Row("Java", 2013, 30000.0, 0, 3) :: + Row("Java", null, 50000.0, 1, 1) :: + Row("dotNET", 2012, 15000.0, 0, 1) :: + Row("dotNET", 2013, 48000.0, 0, 4) :: + Row("dotNET", null, 63000.0, 1, 2) :: + Row(null, 2012, 35000.0, 2, 1) :: + Row(null, 2013, 78000.0, 2, 2) :: + Row(null, null, 113000.0, 3, 1) :: Nil + ) + } + test("rollup overlapping columns") { checkAnswer( testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8ef7b61314a56..f665a1c87bd78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2055,6 +2055,56 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("grouping sets") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(course, year)"), + Row("Java", null, 50000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: Nil + ) + + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(course)"), + Row("Java", null, 50000.0) :: + Row("dotNET", null, 63000.0) :: Nil + ) + + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(year)"), + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: Nil + ) + } + + test("grouping and grouping_id") { + checkAnswer( + sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" + + " from courseSales group by cube(course, year)"), + Row("Java", 2012, 0, 0, 0) :: + Row("Java", 2013, 0, 0, 0) :: + Row("Java", null, 0, 1, 1) :: + Row("dotNET", 2012, 0, 0, 0) :: + Row("dotNET", 2013, 0, 0, 0) :: + Row("dotNET", null, 0, 1, 1) :: + Row(null, 2012, 1, 0, 2) :: + Row(null, 2013, 1, 0, 2) :: + Row(null, null, 1, 1, 3) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year, grouping(course) from courseSales group by course, year") + } + assert(error.getMessage contains "grouping() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year, grouping_id(course, year) from courseSales group by course, year") + } + assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + test("SPARK-13056: Null in map value causes NPE") { val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") withTempTable("maptest") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 61b73fa557144..9097c1a1d3117 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -328,6 +328,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive returns null rather than NaN when n = 1 "udaf_covar_samp", + // The implementation of GROUPING__ID in Hive is wrong (not match with doc). + "groupby_grouping_id1", + "groupby_grouping_id2", + "groupby_grouping_sets1", + // Spark parser treats numerical literals differently: it creates decimals instead of doubles. "udf_abs", "udf_format_number", @@ -503,9 +508,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby11", "groupby12", "groupby1_limit", - "groupby_grouping_id1", - "groupby_grouping_id2", - "groupby_grouping_sets1", "groupby_grouping_sets2", "groupby_grouping_sets3", "groupby_grouping_sets4", diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 index dac1b84b916d7..c066aeead822e 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 @@ -1,6 +1,6 @@ -500 NULL 0 -91 0 1 -84 1 1 -105 2 1 -113 3 1 -107 4 1 +500 NULL 1 +91 0 0 +84 1 0 +105 2 0 +113 3 0 +107 4 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a index dac1b84b916d7..c066aeead822e 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a @@ -1,6 +1,6 @@ -500 NULL 0 -91 0 1 -84 1 1 -105 2 1 -113 3 1 -107 4 1 +500 NULL 1 +91 0 0 +84 1 0 +105 2 0 +113 3 0 +107 4 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 index 1eea4a9b23687..fcacbe3f69227 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 @@ -1,10 +1,10 @@ -1 0 5 3 -1 0 15 3 -1 0 25 3 -1 0 60 3 -1 0 75 3 -1 0 80 3 -1 0 100 3 -1 0 140 3 -1 0 145 3 -1 0 150 3 +1 0 5 0 +1 0 15 0 +1 0 25 0 +1 0 60 0 +1 0 75 0 +1 0 80 0 +1 0 100 0 +1 0 140 0 +1 0 145 0 +1 0 150 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce index 1eea4a9b23687..fcacbe3f69227 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce @@ -1,10 +1,10 @@ -1 0 5 3 -1 0 15 3 -1 0 25 3 -1 0 60 3 -1 0 75 3 -1 0 80 3 -1 0 100 3 -1 0 140 3 -1 0 145 3 -1 0 150 3 +1 0 5 0 +1 0 15 0 +1 0 25 0 +1 0 60 0 +1 0 75 0 +1 0 80 0 +1 0 100 0 +1 0 140 0 +1 0 145 0 +1 0 150 0 From 8f744fe3d931c2380613b8e5bafa1bb1fd292839 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Feb 2016 23:23:01 -0800 Subject: [PATCH 1863/2089] [SPARK-13234] [SQL] remove duplicated SQL metrics For lots of SQL operators, we have metrics for both of input and output, the number of input rows should be exactly the number of output rows of child, we could only have metrics for output rows. After we improved the performance using whole stage codegen, the overhead of SQL metrics are not trivial anymore, we should avoid that if it's not necessary. This PR remove all the SQL metrics for number of input rows, add SQL metric of number of output rows for all LeafNode. All remove the SQL metrics from those operators that have the same number of rows from input and output (for example, Projection, we may don't need that). The new SQL UI will looks like: ![metrics](https://cloud.githubusercontent.com/assets/40902/12965227/63614e5e-d009-11e5-88b3-84fea04f9c20.png) Author: Davies Liu Closes #11163 from davies/remove_metrics. --- .../spark/sql/execution/ExistingRDD.scala | 12 +++++- .../spark/sql/execution/LocalTableScan.scala | 12 +++++- .../aggregate/SortBasedAggregate.scala | 3 -- .../SortBasedAggregationIterator.scala | 3 -- .../aggregate/TungstenAggregate.scala | 3 -- .../TungstenAggregationIterator.scala | 3 -- .../spark/sql/execution/basicOperators.scala | 21 ++++------- .../columnar/InMemoryColumnarTableScan.scala | 14 ++++++- .../execution/joins/BroadcastHashJoin.scala | 18 ++------- .../joins/BroadcastHashOuterJoin.scala | 26 +------------ .../joins/BroadcastLeftSemiJoinHash.scala | 13 ++----- .../joins/BroadcastNestedLoopJoin.scala | 8 ---- .../execution/joins/CartesianProduct.scala | 14 +------ .../spark/sql/execution/joins/HashJoin.scala | 2 - .../sql/execution/joins/HashSemiJoin.scala | 7 +--- .../sql/execution/joins/HashedRelation.scala | 10 +---- .../sql/execution/joins/LeftSemiJoinBNL.scala | 6 --- .../execution/joins/LeftSemiJoinHash.scala | 12 ++---- .../sql/execution/joins/SortMergeJoin.scala | 14 +------ .../execution/joins/SortMergeOuterJoin.scala | 18 +-------- .../execution/joins/HashedRelationSuite.scala | 14 ++----- .../execution/metric/SQLMetricsSuite.scala | 37 ------------------- .../sql/util/DataFrameCallbackSuite.scala | 8 ++-- .../sql/hive/execution/HiveTableScan.scala | 10 ++++- 24 files changed, 80 insertions(+), 208 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 92cfd5f841c51..cad7e25a32788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -103,8 +104,11 @@ private[sql] case class PhysicalRDD( override val outputPartitioning: Partitioning = UnknownPartitioning(0)) extends LeafNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - if (isUnsafeRow) { + val unsafeRow = if (isUnsafeRow) { rdd } else { rdd.mapPartitionsInternal { iter => @@ -112,6 +116,12 @@ private[sql] case class PhysicalRDD( iter.map(proj) } } + + val numOutputRows = longMetric("numOutputRows") + unsafeRow.map { r => + numOutputRows += 1 + r + } } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 59057bf9666ef..f8aec9e7a1d1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -29,6 +30,9 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + private val unsafeRows: Array[InternalRow] = { val proj = UnsafeProjection.create(output, output) rows.map(r => proj(r).copy()).toArray @@ -36,7 +40,13 @@ private[sql] case class LocalTableScan( private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) - protected override def doExecute(): RDD[InternalRow] = rdd + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + rdd.map { r => + numOutputRows += 1 + r + } + } override def executeCollect(): Array[InternalRow] = { unsafeRows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 06a3991459f08..9fcfea8381ac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -46,7 +46,6 @@ case class SortBasedAggregate( AttributeSet(aggregateBufferAttributes) override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -68,7 +67,6 @@ case class SortBasedAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsInternal { iter => // Because the constructor of an aggregation iterator will read at least the first row, @@ -89,7 +87,6 @@ case class SortBasedAggregate( resultExpressions, (expressions, inputSchema) => newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), - numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 6501634ff998b..8f974980bb323 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -35,7 +35,6 @@ class SortBasedAggregationIterator( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends AggregationIterator( groupingExpressions, @@ -97,7 +96,6 @@ class SortBasedAggregationIterator( val inputRow = inputIterator.next() nextGroupingKey = groupingProjection(inputRow).copy() firstRowInNextGroup = inputRow.copy() - numInputRows += 1 sortedInputHasNewGroup = true } else { // This inputIter is empty. @@ -122,7 +120,6 @@ class SortBasedAggregationIterator( // Get the grouping key. val currentRow = inputIterator.next() val groupingKey = groupingProjection(currentRow) - numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 340b8f78e5c9d..a6950f805a113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -47,7 +47,6 @@ case class TungstenAggregate( require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) @@ -77,7 +76,6 @@ case class TungstenAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") @@ -102,7 +100,6 @@ case class TungstenAggregate( child.output, iter, testFallbackStartsAt, - numInputRows, numOutputRows, dataSize, spillSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 001e9c306ac45..c4f65948357e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -85,7 +85,6 @@ class TungstenAggregationIterator( originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], - numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) @@ -179,14 +178,12 @@ class TungstenAggregationIterator( val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) while (inputIter.hasNext) { val newInput = inputIter.next() - numInputRows += 1 processRow(buffer, newInput) } } else { var i = 0 while (inputIter.hasNext) { val newInput = inputIter.next() - numInputRows += 1 val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null if (i < fallbackStartsAt) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f63e8a9b6d79d..949acb9aca762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -29,9 +29,6 @@ import org.apache.spark.util.random.PoissonSampler case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with CodegenSupport { - override private[sql] lazy val metrics = Map( - "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def upstream(): RDD[InternalRow] = { @@ -55,14 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) } protected override def doExecute(): RDD[InternalRow] = { - val numRows = longMetric("numRows") child.execute().mapPartitionsInternal { iter => val project = UnsafeProjection.create(projectList, child.output, subexpressionEliminationEnabled) - iter.map { row => - numRows += 1 - project(row) - } + iter.map(project) } } @@ -74,7 +67,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit override def output: Seq[Attribute] = child.output private[sql] override lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def upstream(): RDD[InternalRow] = { @@ -104,12 +96,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit } protected override def doExecute(): RDD[InternalRow] = { - val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsInternal { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => - numInputRows += 1 val r = predicate(row) if (r) numOutputRows += 1 r @@ -135,9 +125,7 @@ case class Sample( upperBound: Double, withReplacement: Boolean, seed: Long, - child: SparkPlan) - extends UnaryNode -{ + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output protected override def doExecute(): RDD[InternalRow] = { @@ -163,6 +151,9 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def upstream(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) } @@ -241,6 +232,7 @@ case class Range( } protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") sqlContext .sparkContext .parallelize(0 until numSlices, numSlices) @@ -283,6 +275,7 @@ case class Range( overflow = true } + numOutputRows += 1 unsafeRow.setLong(0, ret) unsafeRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 9084b74d1a741..4858140229d45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel @@ -216,6 +217,9 @@ private[sql] case class InMemoryColumnarTableScan( @transient relation: InMemoryRelation) extends LeafNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = attributes // The cached version does not change the outputPartitioning of the original SparkPlan. @@ -286,6 +290,8 @@ private[sql] case class InMemoryColumnarTableScan( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + if (enableAccumulators) { readPartitions.setValue(0) readBatches.setValue(0) @@ -332,12 +338,18 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } + // update SQL metrics + val withMetrics = cachedBatchesToScan.map { batch => + numOutputRows += batch.numRows + batch + } + val columnTypes = requestedColumnDataTypes.map { case udt: UserDefinedType[_] => udt.sqlType case other => other }.toArray val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { readPartitions += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index cbd549763ac95..35c7963b48c4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -48,8 +48,6 @@ case class BroadcastHashJoin( extends BinaryNode with HashJoin with CodegenSupport { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) val timeout: Duration = { @@ -70,11 +68,6 @@ case class BroadcastHashJoin( // for the same query. @transient private lazy val broadcastFuture = { - val numBuildRows = buildSide match { - case BuildLeft => longMetric("numLeftRows") - case BuildRight => longMetric("numRightRows") - } - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) Future { @@ -84,7 +77,6 @@ case class BroadcastHashJoin( // Note that we use .execute().collect() because we don't want to convert data to Scala // types val input: Array[InternalRow] = buildPlan.execute().map { row => - numBuildRows += 1 row.copy() }.collect() // The following line doesn't run in a job so we cannot track the metric value. However, we @@ -93,10 +85,10 @@ case class BroadcastHashJoin( // TODO: move this check into HashedRelation val hashed = if (canJoinKeyFitWithinLong) { LongHashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + input.iterator, buildSideKeyGenerator, input.size) } else { HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) + input.iterator, buildSideKeyGenerator, input.size) } sparkContext.broadcast(hashed) } @@ -108,10 +100,6 @@ case class BroadcastHashJoin( } protected override def doExecute(): RDD[InternalRow] = { - val numStreamedRows = buildSide match { - case BuildLeft => longMetric("numRightRows") - case BuildRight => longMetric("numLeftRows") - } val numOutputRows = longMetric("numOutputRows") val broadcastRelation = Await.result(broadcastFuture, timeout) @@ -119,7 +107,7 @@ case class BroadcastHashJoin( streamedPlan.execute().mapPartitions { streamedIter => val hashedRelation = broadcastRelation.value TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) - hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) + hashJoin(streamedIter, hashedRelation, numOutputRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ad3275696e637..5e8c8ca043629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -44,8 +44,6 @@ case class BroadcastHashOuterJoin( right: SparkPlan) extends BinaryNode with HashOuterJoin { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) val timeout = { @@ -66,14 +64,6 @@ case class BroadcastHashOuterJoin( // for the same query. @transient private lazy val broadcastFuture = { - val numBuildRows = joinType match { - case RightOuter => longMetric("numLeftRows") - case LeftOuter => longMetric("numRightRows") - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) Future { @@ -83,14 +73,9 @@ case class BroadcastHashOuterJoin( // Note that we use .execute().collect() because we don't want to convert data to Scala // types val input: Array[InternalRow] = buildPlan.execute().map { row => - numBuildRows += 1 row.copy() }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. - val hashed = HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) + val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -101,13 +86,6 @@ case class BroadcastHashOuterJoin( } override def doExecute(): RDD[InternalRow] = { - val numStreamedRows = joinType match { - case RightOuter => longMetric("numRightRows") - case LeftOuter => longMetric("numLeftRows") - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } val numOutputRows = longMetric("numOutputRows") val broadcastRelation = Await.result(broadcastFuture, timeout) @@ -122,7 +100,6 @@ case class BroadcastHashOuterJoin( joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { - numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) @@ -130,7 +107,6 @@ case class BroadcastHashOuterJoin( case RightOuter => streamedIter.flatMap(currentRow => { - numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index d0e18dfcf3d90..4f1cfd2e8171b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -36,36 +36,31 @@ case class BroadcastLeftSemiJoinHash( condition: Option[Expression]) extends BinaryNode with HashSemiJoin { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") val input = right.execute().map { row => - numRightRows += 1 row.copy() }.collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) + val hashSet = buildKeyHashSet(input.toIterator) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitionsInternal { streamIter => - hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) + hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows) } } else { val hashRelation = - HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) + HashedRelation(input.toIterator, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) - hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) + hashSemiJoin(streamIter, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index e55f8694781a3..4585cbda929ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -36,8 +36,6 @@ case class BroadcastNestedLoopJoin( // TODO: Override requiredChildDistribution. override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) /** BuildRight means the right relation <=> the broadcast relation. */ @@ -73,15 +71,10 @@ case class BroadcastNestedLoopJoin( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { - val (numStreamedRows, numBuildRows) = buildSide match { - case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) - case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) - } val numOutputRows = longMetric("numOutputRows") val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map { row => - numBuildRows += 1 row.copy() }.collect().toIndexedSeq) @@ -98,7 +91,6 @@ case class BroadcastNestedLoopJoin( streamedIter.foreach { streamedRow => var i = 0 var streamRowMatched = false - numStreamedRows += 1 while (i < broadcastedRelation.value.size) { val broadcastedRow = broadcastedRelation.value(i) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 93d32e1fb93ae..e417079b61b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -82,23 +82,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod override def output: Seq[Attribute] = left.output ++ right.output override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") - val leftResults = left.execute().map { row => - numLeftRows += 1 - row.asInstanceOf[UnsafeRow] - } - val rightResults = right.execute().map { row => - numRightRows += 1 - row.asInstanceOf[UnsafeRow] - } + val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] + val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) pair.mapPartitionsInternal { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ecbb1ac64b7c0..332a748d3bfc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -99,7 +99,6 @@ trait HashJoin { protected def hashJoin( streamIter: Iterator[InternalRow], - numStreamRows: LongSQLMetric, hashedRelation: HashedRelation, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { @@ -126,7 +125,6 @@ trait HashJoin { // find the next match while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() - numStreamRows += 1 val key = joinKeys(currentStreamedRow) if (!key.anyNull) { currentHashMatches = hashedRelation.get(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 3e0f74cd98c21..0220e0b8a7c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -43,14 +43,13 @@ trait HashSemiJoin { newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected def buildKeyHashSet( - buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { + buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { val currentRow = buildIter.next() - numBuildRows += 1 val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -65,12 +64,10 @@ trait HashSemiJoin { protected def hashSemiJoin( streamIter: Iterator[InternalRow], - numStreamRows: LongSQLMetric, hashSet: java.util.Set[InternalRow], numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator streamIter.filter(current => { - numStreamRows += 1 val key = joinKeys(current) val r = !key.anyNull && hashSet.contains(key) if (r) numOutputRows += 1 @@ -80,13 +77,11 @@ trait HashSemiJoin { protected def hashSemiJoin( streamIter: Iterator[InternalRow], - numStreamRows: LongSQLMetric, hashedRelation: HashedRelation, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter { current => - numStreamRows += 1 val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index eb6930a14f9c1..0978570d429f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -159,18 +159,17 @@ private[joins] class UniqueKeyHashedRelation( private[execution] object HashedRelation { def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { - apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) + apply(localNode.asIterator, keyGenerator) } def apply( input: Iterator[InternalRow], - numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { if (keyGenerator.isInstanceOf[UnsafeProjection]) { return UnsafeHashedRelation( - input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } // TODO: Use Spark's HashMap implementation. @@ -184,7 +183,6 @@ private[execution] object HashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { currentRow = input.next() - numInputRows += 1 val rowKey = keyGenerator(currentRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -427,7 +425,6 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - numInputRows: LongSQLMetric, keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { @@ -437,7 +434,6 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] - numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -604,7 +600,6 @@ private[joins] object LongHashedRelation { def apply( input: Iterator[InternalRow], - numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int): HashedRelation = { @@ -619,7 +614,6 @@ private[joins] object LongHashedRelation { while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() - numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val key = rowKey.getLong(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 82498ee395649..ce758d63b36f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -34,8 +34,6 @@ case class LeftSemiJoinBNL( // TODO: Override requiredChildDistribution. override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def outputPartitioning: Partitioning = streamed.outputPartitioning @@ -52,13 +50,10 @@ case class LeftSemiJoinBNL( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map { row => - numRightRows += 1 row.copy() }.collect().toIndexedSeq) @@ -66,7 +61,6 @@ case class LeftSemiJoinBNL( val joinedRow = new JoinedRow streamedIter.filter(streamedRow => { - numLeftRows += 1 var i = 0 var matched = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 25b3b5ca2377f..d8d3045ccf5c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -36,8 +36,6 @@ case class LeftSemiJoinHash( condition: Option[Expression]) extends BinaryNode with HashSemiJoin { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def outputPartitioning: Partitioning = left.outputPartitioning @@ -46,17 +44,15 @@ case class LeftSemiJoinHash( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter, numRightRows) - hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) + val hashSet = buildKeyHashSet(buildIter) + hashSemiJoin(streamIter, hashSet, numOutputRows) } else { - val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) - hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + hashSemiJoin(streamIter, hashRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 322a954b4f792..cd8a5670e2301 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -37,8 +37,6 @@ case class SortMergeJoin( right: SparkPlan) extends BinaryNode { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = left.output ++ right.output @@ -60,8 +58,6 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => @@ -89,9 +85,7 @@ case class SortMergeJoin( rightKeyGenerator, keyOrdering, RowIterator.fromScala(leftIter), - numLeftRows, - RowIterator.fromScala(rightIter), - numRightRows + RowIterator.fromScala(rightIter) ) private[this] val joinRow = new JoinedRow private[this] val resultProjection: (InternalRow) => InternalRow = @@ -157,9 +151,7 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - numStreamedRows: LongSQLMetric, - bufferedIter: RowIterator, - numBufferedRows: LongSQLMetric) { + bufferedIter: RowIterator) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -284,7 +276,6 @@ private[joins] class SortMergeJoinScanner( if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) - numStreamedRows += 1 true } else { streamedRow = null @@ -302,7 +293,6 @@ private[joins] class SortMergeJoinScanner( while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) - numBufferedRows += 1 foundRow = !bufferedRowKey.anyNull } if (!foundRow) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index ed41ad2a005eb..40a6c93b59cbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -40,8 +40,6 @@ case class SortMergeOuterJoin( right: SparkPlan) extends BinaryNode { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = { @@ -96,8 +94,6 @@ case class SortMergeOuterJoin( UnsafeProjection.create(rightKeys, right.output) override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => @@ -119,9 +115,7 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - numLeftRows, - bufferedIter = RowIterator.fromScala(rightIter), - numRightRows + bufferedIter = RowIterator.fromScala(rightIter) ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -133,9 +127,7 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - numRightRows, - bufferedIter = RowIterator.fromScala(leftIter), - numLeftRows + bufferedIter = RowIterator.fromScala(leftIter) ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -149,9 +141,7 @@ case class SortMergeOuterJoin( rightKeyGenerator = createRightKeyGenerator(), keyOrdering, leftIter = RowIterator.fromScala(leftIter), - numLeftRows, rightIter = RowIterator.fromScala(rightIter), - numRightRows, boundCondition, leftNullRow, rightNullRow) @@ -289,9 +279,7 @@ private class SortMergeFullOuterJoinScanner( rightKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], leftIter: RowIterator, - numLeftRows: LongSQLMetric, rightIter: RowIterator, - numRightRows: LongSQLMetric, boundCondition: InternalRow => Boolean, leftNullRow: InternalRow, rightNullRow: InternalRow) { @@ -321,7 +309,6 @@ private class SortMergeFullOuterJoinScanner( if (leftIter.advanceNext()) { leftRow = leftIter.getRow leftRowKey = leftKeyGenerator(leftRow) - numLeftRows += 1 true } else { leftRow = null @@ -338,7 +325,6 @@ private class SortMergeFullOuterJoinScanner( if (rightIter.advanceNext()) { rightRow = rightIter.getRow rightRowKey = rightKeyGenerator(rightRow) - numRightRows += 1 true } else { rightRow = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index f985dfbd8ade9..04dd809df17ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -36,8 +36,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") - val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) + val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -47,13 +46,11 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) === data2) - assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") - val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) + val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -66,19 +63,17 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(uniqHashed.getValue(data(1)) === data(1)) assert(uniqHashed.getValue(data(2)) === data(2)) assert(uniqHashed.getValue(InternalRow(10)) === null) - assert(numDataRows.value.value === data.length) } test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray val buildKey = Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) @@ -100,7 +95,6 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) - assert(numDataRows.value.value === data.length) val os2 = new ByteArrayOutputStream() val out2 = new ObjectOutputStream(os2) @@ -139,7 +133,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) - val longRelation = LongHashedRelation(rows.iterator, SQLMetrics.nullLongMetric, keyProj, 100) + val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) assert(longRelation.isInstanceOf[LongArrayRelation]) val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] (0 until 100).foreach { i => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2260e4870299a..d24625a5351ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -113,23 +113,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } - test("Project metrics") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L -> ("Project", Map( - "number of rows" -> 2L))) - ) - } - test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) val df = person.filter('age < 25) testSparkPlanMetrics(df, 1, Map( 0L -> ("Filter", Map( - "number of input rows" -> 2L, "number of output rows" -> 1L))) ) } @@ -149,10 +138,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, "number of output rows" -> 2L)), 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 2L, "number of output rows" -> 1L))) ) @@ -160,10 +147,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, "number of output rows" -> 4L)), 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 4L, "number of output rows" -> 3L))) ) } @@ -181,8 +166,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { testSparkPlanMetrics(df, 1, Map( 1L -> ("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 4L, - "number of right rows" -> 2L, "number of output rows" -> 4L))) ) } @@ -201,8 +184,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { testSparkPlanMetrics(df, 1, Map( 1L -> ("SortMergeOuterJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 6L, - "number of right rows" -> 2L, "number of output rows" -> 8L))) ) @@ -211,8 +192,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { testSparkPlanMetrics(df2, 1, Map( 1L -> ("SortMergeOuterJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 2L, - "number of right rows" -> 6L, "number of output rows" -> 8L))) ) } @@ -226,8 +205,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> ("BroadcastHashJoin", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, "number of output rows" -> 2L))) ) } @@ -240,16 +217,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") testSparkPlanMetrics(df, 2, Map( 0L -> ("BroadcastHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, "number of output rows" -> 5L))) ) val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") testSparkPlanMetrics(df3, 2, Map( 0L -> ("BroadcastHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, "number of output rows" -> 6L))) ) } @@ -265,8 +238,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") testSparkPlanMetrics(df, 3, Map( 1L -> ("BroadcastNestedLoopJoin", Map( - "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 2L, "number of output rows" -> 12L))) ) } @@ -280,8 +251,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 2, Map( 0L -> ("BroadcastLeftSemiJoinHash", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, "number of output rows" -> 2L))) ) } @@ -295,8 +264,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(df2, $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 1, Map( 0L -> ("LeftSemiJoinHash", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, "number of output rows" -> 2L))) ) } @@ -310,8 +277,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(df2, $"key" < $"key2", "leftsemi") testSparkPlanMetrics(df, 2, Map( 0L -> ("LeftSemiJoinBNL", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, "number of output rows" -> 2L))) ) } @@ -326,8 +291,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "SELECT * FROM testData2 JOIN testDataForJoin") testSparkPlanMetrics(df, 1, Map( 1L -> ("CartesianProduct", Map( - "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 4L, // right is read twice "number of output rows" -> 12L))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index a3e5243b68aba..d3191d3aead95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -92,7 +92,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - metrics += qe.executedPlan.longMetric("numInputRows").value.value + metrics += qe.executedPlan.longMetric("numOutputRows").value.value } } sqlContext.listenerManager.register(listener) @@ -105,9 +105,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } assert(metrics.length == 3) - assert(metrics(0) == 1) - assert(metrics(1) == 1) - assert(metrics(2) == 2) + assert(metrics(0) === 1) + assert(metrics(1) === 1) + assert(metrics(2) === 2) sqlContext.listenerManager.unregister(listener) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index eff8833e9232e..235b80b7c697c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.util.Utils @@ -52,6 +53,9 @@ case class HiveTableScan( require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(partitionPruningPred.flatMap(_.references)) @@ -146,9 +150,13 @@ case class HiveTableScan( prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) } } + val numOutputRows = longMetric("numOutputRows") rdd.mapPartitionsInternal { iter => val proj = UnsafeProjection.create(schema) - iter.map(proj) + iter.map { r => + numOutputRows += 1 + proj(r) + } } } From 1842c55d89ae99a610a955ce61633a9084e000f2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 11 Feb 2016 08:30:58 +0100 Subject: [PATCH 1864/2089] [SPARK-13276] Catch bad characters at the end of a Table Identifier/Expression string The parser currently parses the following strings without a hitch: * Table Identifier: * `a.b.c` should fail, but results in the following table identifier `a.b` * `table!#` should fail, but results in the following table identifier `table` * Expression * `1+2 r+e` should fail, but results in the following expression `1 + 2` This PR fixes this by adding terminated rules for both expression parsing and table identifier parsing. cc cloud-fan (we discussed this in https://github.com/apache/spark/pull/10649) jayadevanmurali (this causes your PR https://github.com/apache/spark/pull/11051 to fail) Author: Herman van Hovell Closes #11159 from hvanhovell/SPARK-13276. --- .../spark/sql/catalyst/parser/SparkSqlParser.g | 12 ++++++++++++ .../spark/sql/catalyst/parser/ParseDriver.scala | 4 ++-- .../apache/spark/sql/catalyst/CatalystQlSuite.scala | 5 +++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 9935678ca2ca2..9f2a5eb35cba5 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -722,6 +722,18 @@ statement | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG) ; +// Rule for expression parsing +singleNamedExpression + : + namedExpression EOF + ; + +// Rule for table name parsing +singleTableName + : + tableName EOF + ; + explainStatement @init { pushMsg("explain statement", state); } @after { popMsg(state); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index f8e4f21451192..9ff41f5bece5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -35,12 +35,12 @@ object ParseDriver extends Logging { /** Create an Expression ASTNode from a SQL command. */ def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.namedExpression().getTree + parser.singleNamedExpression().getTree } /** Create an TableIdentifier ASTNode from a SQL command. */ def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.tableName().getTree + parser.singleTableName().getTree } private def parse( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 6d25de98cebc4..682b77dc65ac9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -134,14 +134,15 @@ class CatalystQlSuite extends PlanTest { Literal("o") :: UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) :: Nil, false))) + + intercept[AnalysisException](parser.parseExpression("1 - f('o', o(bar)) hello * world")) } test("table identifier") { assert(TableIdentifier("q") === parser.parseTableIdentifier("q")) assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q")) intercept[AnalysisException](parser.parseTableIdentifier("")) - // TODO parser swallows third identifier. - // intercept[AnalysisException](parser.parseTableIdentifier("d.q.g")) + intercept[AnalysisException](parser.parseTableIdentifier("d.q.g")) } test("parse union/except/intersect") { From e88bff12795a6134e2e7204996b603e948380e18 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Feb 2016 08:40:27 +0100 Subject: [PATCH 1865/2089] [SPARK-13235][SQL] Removed an Extra Distinct from the Plan when Using Union in SQL Currently, the parser added two `Distinct` operators in the plan if we are using `Union` or `Union Distinct` in the SQL. This PR is to remove the extra `Distinct` from the plan. For example, before the fix, the following query has a plan with two `Distinct` ```scala sql("select * from t0 union select * from t0").explain(true) ``` ``` == Parsed Logical Plan == 'Project [unresolvedalias(*,None)] +- 'Subquery u_2 +- 'Distinct +- 'Project [unresolvedalias(*,None)] +- 'Subquery u_1 +- 'Distinct +- 'Union :- 'Project [unresolvedalias(*,None)] : +- 'UnresolvedRelation `t0`, None +- 'Project [unresolvedalias(*,None)] +- 'UnresolvedRelation `t0`, None == Analyzed Logical Plan == id: bigint Project [id#16L] +- Subquery u_2 +- Distinct +- Project [id#16L] +- Subquery u_1 +- Distinct +- Union :- Project [id#16L] : +- Subquery t0 : +- Relation[id#16L] ParquetRelation +- Project [id#16L] +- Subquery t0 +- Relation[id#16L] ParquetRelation == Optimized Logical Plan == Aggregate [id#16L], [id#16L] +- Aggregate [id#16L], [id#16L] +- Union :- Project [id#16L] : +- Relation[id#16L] ParquetRelation +- Project [id#16L] +- Relation[id#16L] ParquetRelation ``` After the fix, the plan is changed without the extra `Distinct` as follows: ``` == Parsed Logical Plan == 'Project [unresolvedalias(*,None)] +- 'Subquery u_1 +- 'Distinct +- 'Union :- 'Project [unresolvedalias(*,None)] : +- 'UnresolvedRelation `t0`, None +- 'Project [unresolvedalias(*,None)] +- 'UnresolvedRelation `t0`, None == Analyzed Logical Plan == id: bigint Project [id#17L] +- Subquery u_1 +- Distinct +- Union :- Project [id#16L] : +- Subquery t0 : +- Relation[id#16L] ParquetRelation +- Project [id#16L] +- Subquery t0 +- Relation[id#16L] ParquetRelation == Optimized Logical Plan == Aggregate [id#17L], [id#17L] +- Union :- Project [id#16L] : +- Relation[id#16L] ParquetRelation +- Project [id#16L] +- Relation[id#16L] ParquetRelation ``` Author: gatorsmile Closes #11120 from gatorsmile/unionDistinct. --- .../sql/catalyst/parser/SparkSqlParser.g | 28 +--------------- .../spark/sql/catalyst/CatalystQlSuite.scala | 33 +++++++++++++++++-- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 9f2a5eb35cba5..24483ccb5d50e 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -2370,34 +2370,8 @@ setOpSelectStatement[CommonTree t, boolean topLevel] u=setOperator LPAREN b=simpleSelectStatement RPAREN | u=setOperator b=simpleSelectStatement) - -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? - ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - ^($u {$setOpSelectStatement.tree} $b) - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) - -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}? + -> {$setOpSelectStatement.tree != null}? ^($u {$setOpSelectStatement.tree} $b) - -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}? - ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - ^($u {$t} $b) - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) -> ^($u {$t} $b) )+ o=orderByClause? diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 682b77dc65ac9..8d7d6b5bf52ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { @@ -45,6 +45,35 @@ class CatalystQlSuite extends PlanTest { comparePlans(parsed, expected) } + test("test Union Distinct operator") { + val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1") + val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") + val expected = + Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, + Subquery("u_1", + Distinct( + Union( + Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, + UnresolvedRelation(TableIdentifier("t0"), None)), + Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, + UnresolvedRelation(TableIdentifier("t1"), None)))))) + comparePlans(parsed1, expected) + comparePlans(parsed2, expected) + } + + test("test Union All operator") { + val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") + val expected = + Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, + Subquery("u_1", + Union( + Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, + UnresolvedRelation(TableIdentifier("t0"), None)), + Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, + UnresolvedRelation(TableIdentifier("t1"), None))))) + comparePlans(parsed, expected) + } + test("support hive interval literal") { def checkInterval(sql: String, result: CalendarInterval): Unit = { val parsed = parser.parsePlan(sql) From 18bcbbdd84e80222d1d29530831c6d68d02e7593 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 10 Feb 2016 23:52:19 -0800 Subject: [PATCH 1866/2089] [SPARK-13270][SQL] Remove extra new lines in whole stage codegen and include pipeline plan in comments. Author: Nong Li Closes #11155 from nongli/spark-13270. --- .../expressions/codegen/CodeFormatter.scala | 14 ++++++++++++++ .../spark/sql/execution/WholeStageCodegen.scala | 8 ++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 9b8b6382d753d..9d99bbffbe13e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -25,6 +25,20 @@ package org.apache.spark.sql.catalyst.expressions.codegen */ object CodeFormatter { def format(code: String): String = new CodeFormatter().addLines(code).result() + def stripExtraNewLines(input: String): String = { + val code = new StringBuilder + var lastLine: String = "dummy" + input.split('\n').foreach { l => + val line = l.trim() + val skip = line == "" && (lastLine == "" || lastLine.endsWith("{")) + if (!skip) { + code.append(line) + code.append("\n") + } + lastLine = line + } + code.result() + } } private class CodeFormatter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index b200239c94206..30f74fc14f6c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -237,6 +237,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) return new GeneratedIterator(references); } + /** Codegened pipeline for: + * ${plan.treeString.trim} + */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; @@ -256,8 +259,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) """ // try to compile, helpful for debug - // println(s"${CodeFormatter.format(source)}") - CodeGenerator.compile(source) + val cleanedSource = CodeFormatter.stripExtraNewLines(source) + // println(s"${CodeFormatter.format(cleanedSource)}") + CodeGenerator.compile(cleanedSource) plan.upstream().mapPartitions { iter => From c2f21d88981789fe3366f2c4040445aeff5bf083 Mon Sep 17 00:00:00 2001 From: Sasaki Toru Date: Thu, 11 Feb 2016 09:30:36 +0000 Subject: [PATCH 1867/2089] [SPARK-13264][DOC] Removed multi-byte characters in spark-env.sh.template In spark-env.sh.template, there are multi-byte characters, this PR will remove it. Author: Sasaki Toru Closes #11149 from sasakitoa/remove_multibyte_in_sparkenv. --- R/pkg/R/serialize.R | 2 +- conf/spark-env.sh.template | 2 +- docs/sql-programming-guide.md | 2 +- .../org/apache/spark/ml/regression/LinearRegressionSuite.scala | 2 +- sql/README.md | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 095ddb9aed2e7..70e87a93e610f 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -54,7 +54,7 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") - # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # Checking types is needed here, since 'is.na' only handles atomic vectors, # lists and pairlists if (type %in% c("integer", "character", "logical", "double", "numeric")) { if (is.na(object)) { diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 771251f90ee36..a031cd6a722f9 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -41,7 +41,7 @@ # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) -# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) +# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: 'default') # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. # - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ce53a39f9f604..d246100f3e6aa 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2389,7 +2389,7 @@ let user control table caching explicitly: CACHE TABLE logs_last_month; UNCACHE TABLE logs_last_month; -**NOTE:** `CACHE TABLE tbl` is now __eager__ by default not __lazy__. Don’t need to trigger cache materialization manually anymore. +**NOTE:** `CACHE TABLE tbl` is now __eager__ by default not __lazy__. Don't need to trigger cache materialization manually anymore. Spark SQL newly introduced a statement to let user control table caching whether or not lazy since Spark 1.2.0: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 81fc6603ccfe6..3ae108d822de7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -956,7 +956,7 @@ class LinearRegressionSuite V1 -3.7271 2.9032 -1.284 0.3279 V2 3.0100 0.6022 4.998 0.0378 * --- - Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 + Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 (Dispersion parameter for gaussian family taken to be 17.4376) diff --git a/sql/README.md b/sql/README.md index a13bdab6d457f..9ea271d33d856 100644 --- a/sql/README.md +++ b/sql/README.md @@ -5,7 +5,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. + - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. From f9ae99fee13681e436fde9899b6a189746348ba1 Mon Sep 17 00:00:00 2001 From: Junyang Date: Thu, 11 Feb 2016 09:33:11 +0000 Subject: [PATCH 1868/2089] [SPARK-13074][CORE] Add JavaSparkContext. getPersistentRDDs method The "getPersistentRDDs()" is a useful API of SparkContext to get cached RDDs. However, the JavaSparkContext does not have this API. Add a simple getPersistentRDDs() to get java.util.Map for Java users. Author: Junyang Closes #10978 from flyjy/master. --- .../org/apache/spark/api/java/JavaSparkContext.scala | 10 ++++++++++ .../src/test/java/org/apache/spark/JavaAPISuite.java | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 01433ca2efc14..f1aebbcd39638 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -774,6 +774,16 @@ class JavaSparkContext(val sc: SparkContext) /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs(): Unit = sc.cancelAllJobs() + + /** + * Returns an Java map of JavaRDDs that have marked themselves as persistent via cache() call. + * Note that this does not necessarily mean the caching or computation was successful. + */ + def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { + sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) + .asJava.asInstanceOf[JMap[java.lang.Integer, JavaRDD[_]]] + } + } object JavaSparkContext { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8117ad9e60641..e6a4ab7550c2a 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1811,4 +1811,16 @@ public void testRegisterKryoClasses() { conf.get("spark.kryo.classesToRegister")); } + @Test + public void testGetPersistentRDDs() { + java.util.Map> cachedRddsMap = sc.getPersistentRDDs(); + Assert.assertTrue(cachedRddsMap.isEmpty()); + JavaRDD rdd1 = sc.parallelize(Arrays.asList("a", "b")).setName("RDD1").cache(); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("c", "d")).setName("RDD2").cache(); + cachedRddsMap = sc.getPersistentRDDs(); + Assert.assertEquals(2, cachedRddsMap.size()); + Assert.assertEquals("RDD1", cachedRddsMap.get(0).name()); + Assert.assertEquals("RDD2", cachedRddsMap.get(1).name()); + } + } From 13c17cbb0530d52b9a08d5197017c96501d99e8c Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Thu, 11 Feb 2016 08:50:27 -0600 Subject: [PATCH 1869/2089] [SPARK-13124][WEB UI] Fixed CSS and JS issues caused by addition of JQuery DataTables Made sure the old tables continue to use the old css and the new DataTables use the new css. Also fixed it so the Safari Web Inspector doesn't throw errors when on the new DataTables pages. Author: Alex Bozarth Closes #11038 from ajbozarth/spark13124. --- .../spark/ui/static/jsonFormatter.min.js | 1 - .../spark/deploy/history/HistoryPage.scala | 2 +- .../scala/org/apache/spark/ui/UIUtils.scala | 31 ++++++++++++------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js b/core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js index f2ffcecd7ff51..aa86fa2a0c327 100755 --- a/core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/jsonFormatter.min.js @@ -1,2 +1 @@ (function($){$.fn.jsonFormatter=function(n){var _settings,u=new Date,r=new RegExp,i=function(n,t,i){for(var r="",u=0;u0&&t.charAt(t.length-1)!="\n"&&(t=t+"\n"),r+t},f=function(n,t){for(var r,u,f="",i=0;i").join(">"));var o=""+t+n+t+r+"<\/span>";return f&&(o=i(u,o)),o},_processObject=function(n,e,o,s,h){var c="",l=o?",<\/span> ":"",v=typeof n,a="",y,p,k,w,b;if($.isArray(n))if(n.length==0)c+=i(e,"[ ]<\/span>"+l,h);else{for(a=_settings.collapsible?"<\/span>":"",c+=i(e,"[<\/span>"+a,h),y=0;y":"";c+=i(e,a+"]<\/span>"+l)}else if(v=="object")if(n==null)c+=t("null","",l,e,s,"jsonFormatter-null");else if(n.constructor==u.constructor)c+=t("new Date("+n.getTime()+") /*"+n.toLocaleString()+"*/","",l,e,s,"Date");else if(n.constructor==r.constructor)c+=t("new RegExp("+n+")","",l,e,s,"RegExp");else{p=0;for(w in n)p++;if(p==0)c+=i(e,"{ }<\/span>"+l,h);else{a=_settings.collapsible?"<\/span>":"";c+=i(e,"{<\/span>"+a,h);k=0;for(w in n)b=_settings.quoteKeys?'"':"",c+=i(e+1,""+b+w+b+"<\/span>: "+_processObject(n[w],e+1,++k":"";c+=i(e,a+"}<\/span>"+l)}}else v=="number"?c+=t(n,"",l,e,s,"jsonFormatter-number"):v=="boolean"?c+=t(n,"",l,e,s,"jsonFormatter-boolean"):v=="function"?n.constructor==r.constructor?c+=t("new RegExp("+n+")","",l,e,s,"RegExp"):(n=f(e,n),c+=t(n,"",l,e,s,"jsonFormatter-function")):c+=v=="undefined"?t("undefined","",l,e,s,"jsonFormatter-null"):t(n.toString().split("\\").join("\\\\").split('"').join('\\"'),'"',l,e,s,"jsonFormatter-string");return c},e=function(element){var json=$(element).html(),obj,original;json.trim()==""&&(json='""');try{obj=eval("["+json+"]")}catch(exception){return}html=_processObject(obj[0],0,!1,!1,!1);original=$(element).wrapInner("
          <\/div>");_settings.hideOriginal===!0&&$(".jsonFormatter-original",original).hide();original.append("
          "+html+"<\/PRE>")},o=function(){var n=$(this).next();n.length<1||($(this).hasClass("jsonFormatter-expanded")==!0?(n.hide(),$(this).removeClass("jsonFormatter-expanded").addClass("jsonFormatter-collapsed")):(n.show(),$(this).removeClass("jsonFormatter-collapsed").addClass("jsonFormatter-expanded")))};return _settings=$.extend({tab:"  ",quoteKeys:!0,collapsible:!0,hideOriginal:!0},n),this.each(function(n,t){e(t);$(t).on("click",".jsonFormatter-expander",o)})}})(jQuery);
          -//# sourceMappingURL=jsonFormatter.min.js.map
          diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
          index 513b81315b3d2..cab7faefe89c3 100644
          --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
          +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
          @@ -69,7 +69,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("")
                       
                     
        - UIUtils.basicSparkPage(content, "History Server") + UIUtils.basicSparkPage(content, "History Server", true) } private def makePageLink(showIncomplete: Boolean): String = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 4ebee9093d41c..ddd7f713fe417 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -157,22 +157,11 @@ private[spark] object UIUtils extends Logging { def commonHeaderNodes: Seq[Node] = { - - - - - - - - - @@ -189,6 +178,20 @@ private[spark] object UIUtils extends Logging { } + def dataTablesHeaderNodes: Seq[Node] = { + + + + + + + + + + } + /** Returns a spark page with correctly formatted headers */ def headerSparkPage( title: String, @@ -244,10 +247,14 @@ private[spark] object UIUtils extends Logging { } /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */ - def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = { + def basicSparkPage( + content: => Seq[Node], + title: String, + useDataTables: Boolean = false): Seq[Node] = { {commonHeaderNodes} + {if (useDataTables) dataTablesHeaderNodes else Seq.empty} {title} From 219a74a7c2d3b858224c4738190ccc92d7cbf06d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 11 Feb 2016 10:10:36 -0800 Subject: [PATCH 1870/2089] [STREAMING][TEST] Fix flaky streaming.FailureSuite Under some corner cases, the test suite failed to shutdown the SparkContext causing cascaded failures. This fix does two things - Makes sure no SparkContext is active after every test - Makes sure StreamingContext is always shutdown (prevents leaking of StreamingContexts as well, just in case) Author: Tathagata Das Closes #11166 from tdas/fix-failuresuite. --- .../test/scala/org/apache/spark/streaming/FailureSuite.scala | 5 ++++- .../scala/org/apache/spark/streaming/MasterFailureTest.scala | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 6a0b0a1d47bc4..31e159e968c1c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.util.Utils /** @@ -43,6 +43,9 @@ class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { Utils.deleteRecursively(directory) } StreamingContext.getActive().foreach { _.stop() } + + // Stop SparkContext if active + SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop() } test("multiple failures with map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index a02d49eced1d5..faa9c4f0cbd6a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -242,6 +242,8 @@ object MasterFailureTest extends Logging { } } catch { case e: Exception => logError("Error running streaming context", e) + } finally { + ssc.stop() } if (killingThread.isAlive) { killingThread.interrupt() @@ -250,7 +252,6 @@ object MasterFailureTest extends Logging { // to null after the next test creates the new SparkContext and fail the test. killingThread.join() } - ssc.stop() logInfo("Has been killed = " + killed) logInfo("Is last output generated = " + isLastOutputGenerated) From e31c80737b7f4d8baa02230788e3963433cb3ef9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 11 Feb 2016 21:09:44 +0100 Subject: [PATCH 1871/2089] [SPARK-13277][SQL] ANTLR ignores other rule using the USING keyword JIRA: https://issues.apache.org/jira/browse/SPARK-13277 There is an ANTLR warning during compilation: warning(200): org/apache/spark/sql/catalyst/parser/SparkSqlParser.g:938:7: Decision can match input such as "KW_USING Identifier" using multiple alternatives: 2, 3 As a result, alternative(s) 3 were disabled for that input This patch is to fix it. Author: Liang-Chi Hsieh Closes #11168 from viirya/fix-parser-using. --- .../org/apache/spark/sql/catalyst/parser/SparkSqlParser.g | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 24483ccb5d50e..e1908a8e03b30 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -949,7 +949,7 @@ createTableStatement tablePropertiesPrefixed? ) | - tableProvider + (tableProvider) => tableProvider tableOpts? (KW_AS selectStatementWithCTE)? -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? From 0d50a22084eea91d4efb0a3ed3fa59b8d9680795 Mon Sep 17 00:00:00 2001 From: jayadevanmurali Date: Thu, 11 Feb 2016 21:21:03 +0100 Subject: [PATCH 1872/2089] [SPARK-12982][SQL] Add table name validation in temp table registration Add the table name validation at the temp table creation Author: jayadevanmurali Closes #11051 from jayadevanmurali/branch-0.2-SPARK-12982. --- .../main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 050a1031c0561..d58b99655c1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -720,7 +720,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(TableIdentifier(tableName), df.logicalPlan) + catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3ea4adcaa6424..99ba2e2061258 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1305,4 +1305,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"), Row(1, "a")) } + + test("SPARK-12982: Add table name validation in temp table registration") { + val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col") + // invalid table name test as below + intercept[AnalysisException](df.registerTempTable("t~")) + // valid table name test as below + df.registerTempTable("table1") + // another invalid table name test as below + intercept[AnalysisException](df.registerTempTable("#$@sum")) + // another invalid table name test as below + intercept[AnalysisException](df.registerTempTable("table!#")) + } } From 50fa6fd1b365d5db7e2b2c59624a365cef0d1696 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Thu, 11 Feb 2016 13:28:03 -0800 Subject: [PATCH 1873/2089] [SPARK-13279] Remove O(n^2) operation from scheduler. This commit removes an unnecessary duplicate check in addPendingTask that meant that scheduling a task set took time proportional to (# tasks)^2. Author: Sital Kedia Closes #11167 from sitalkedia/fix_stuck_driver and squashes the following commits: 3fe1af8 [Sital Kedia] [SPARK-13279] Remove unnecessary duplicate check in addPendingTask function --- .../apache/spark/scheduler/TaskSetManager.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index cf97877476d54..4b19beb43fd6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -114,9 +114,14 @@ private[spark] class TaskSetManager( // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. + // back at the head of the stack. These collections may contain duplicates + // for two reasons: + // (1): Tasks are only removed lazily; when a task is launched, it remains + // in all the pending lists except the one that it was launched from. + // (2): Tasks may be re-added to these lists multiple times as a result + // of failures. + // Duplicates are handled in dequeueTaskFromList, which ensures that a + // task hasn't already started running before launching it. private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each host. Similar to pendingTasksForExecutor, @@ -181,9 +186,7 @@ private[spark] class TaskSetManager( private def addPendingTask(index: Int) { // Utility method that adds `index` to a list only if it's not already there def addTo(list: ArrayBuffer[Int]) { - if (!list.contains(index)) { - list += index - } + list += index } for (loc <- tasks(index).preferredLocations) { From c86009ceb9613201b41319245526a13b1f0b5451 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Feb 2016 13:31:13 -0800 Subject: [PATCH 1874/2089] Revert "[SPARK-13279] Remove O(n^2) operation from scheduler." This reverts commit 50fa6fd1b365d5db7e2b2c59624a365cef0d1696. --- .../apache/spark/scheduler/TaskSetManager.scala | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 4b19beb43fd6b..cf97877476d54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -114,14 +114,9 @@ private[spark] class TaskSetManager( // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. These collections may contain duplicates - // for two reasons: - // (1): Tasks are only removed lazily; when a task is launched, it remains - // in all the pending lists except the one that it was launched from. - // (2): Tasks may be re-added to these lists multiple times as a result - // of failures. - // Duplicates are handled in dequeueTaskFromList, which ensures that a - // task hasn't already started running before launching it. + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each host. Similar to pendingTasksForExecutor, @@ -186,7 +181,9 @@ private[spark] class TaskSetManager( private def addPendingTask(index: Int) { // Utility method that adds `index` to a list only if it's not already there def addTo(list: ArrayBuffer[Int]) { - list += index + if (!list.contains(index)) { + list += index + } } for (loc <- tasks(index).preferredLocations) { From efb65e09bcfa4542348f5cd37fe5c14047b862e5 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 11 Feb 2016 15:00:23 -0800 Subject: [PATCH 1875/2089] [SPARK-13265][ML] Refactoring of basic ML import/export for other file system besides HDFS jkbradley I tried to improve the function to export a model. When I tried to export a model to S3 under Spark 1.6, we couldn't do that. So, it should offer S3 besides HDFS. Can you review it when you have time? Thanks! Author: Yu ISHIKAWA Closes #11151 from yu-iskw/SPARK-13265. --- .../scala/org/apache/spark/ml/util/ReadWrite.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 8484b1f801066..7b2504361a6ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.util import java.io.IOException -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.json4s._ -import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonDSL._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} @@ -75,13 +75,14 @@ abstract class MLWriter extends BaseReadWrite with Logging { @throws[IOException]("If the input path already exists but overwrite is not enabled.") def save(path: String): Unit = { val hadoopConf = sc.hadoopConfiguration - val fs = FileSystem.get(hadoopConf) - val p = new Path(path) - if (fs.exists(p)) { + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + if (fs.exists(qualifiedOutputPath)) { if (shouldOverwrite) { logInfo(s"Path $path already exists. It will be overwritten.") // TODO: Revert back to the original content if save is not successful. - fs.delete(p, true) + fs.delete(qualifiedOutputPath, true) } else { throw new IOException( s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") From 574571c87098795a2206a113ee9ed4bafba8f00f Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 11 Feb 2016 15:05:34 -0800 Subject: [PATCH 1876/2089] [SPARK-11515][ML] QuantileDiscretizer should take random seed cc jkbradley Author: Yu ISHIKAWA Closes #9535 from yu-iskw/SPARK-11515. --- .../spark/ml/feature/QuantileDiscretizer.scala | 15 ++++++++++----- .../ml/feature/QuantileDiscretizerSuite.scala | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 8fd0ce2f2e26c..2a294d3881829 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param.{IntParam, _} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * Params for [[QuantileDiscretizer]]. */ -private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait QuantileDiscretizerBase extends Params + with HasInputCol with HasOutputCol with HasSeed { /** * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must @@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + override def transformSchema(schema: StructType): StructType = { validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) @@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String) } override def fit(dataset: DataFrame): Bucketizer = { - val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets)) + val samples = QuantileDiscretizer + .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) .map { case Row(feature: Double) => feature } val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) val splits = QuantileDiscretizer.getSplits(candidates) @@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi /** * Sampling from the given dataset to collect quantile statistics. */ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { + private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") val requiredSamples = math.max(numBins * numBins, 10000) val fraction = math.min(requiredSamples / dataset.count(), 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() + dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 722f1abde4359..4fde42972f01b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite { val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket) + .setNumBuckets(numBucket).setSeed(1) val result = discretizer.fit(df).transform(df) val transformedFeatures = result.select("result").collect() From c8f667d7c1a0b02685e17b6f498879b05ced9b9d Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Thu, 11 Feb 2016 15:50:33 -0800 Subject: [PATCH 1877/2089] [SPARK-13037][ML][PYSPARK] PySpark ml.recommendation support export/import PySpark ml.recommendation support export/import. Author: Kai Jiang Closes #11044 from vectorijk/spark-13037. --- python/pyspark/ml/recommendation.py | 31 +++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 08180a2f25eb9..ef9448855e175 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -16,7 +16,7 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc @@ -26,7 +26,8 @@ @inherit_doc -class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed): +class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, + MLWritable, MLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -81,6 +82,27 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> als_path = path + "/als" + >>> als.save(als_path) + >>> als2 = ALS.load(als_path) + >>> als.getMaxIter() + 5 + >>> model_path = path + "/als_model" + >>> model.save(model_path) + >>> model2 = ALSModel.load(model_path) + >>> model.rank == model2.rank + True + >>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect()) + True + >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.4.0 """ @@ -274,7 +296,7 @@ def getNonnegative(self): return self.getOrDefault(self.nonnegative) -class ALSModel(JavaModel): +class ALSModel(JavaModel, MLWritable, MLReadable): """ Model fitted by ALS. @@ -308,9 +330,10 @@ def itemFactors(self): if __name__ == "__main__": import doctest + import pyspark.ml.recommendation from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.recommendation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.recommendation tests") From 2426eb3e167fece19831070594247e9481dbbe2a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 11 Feb 2016 15:53:45 -0800 Subject: [PATCH 1878/2089] [MINOR][ML][PYSPARK] Cleanup test cases of clustering.py Test cases should be removed from annotation of ```setXXX``` function, otherwise it will be parts of [Python API docs](https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans.setInitMode). cc mengxr jkbradley Author: Yanbo Liang Closes #10975 from yanboliang/clustering-cleanup. --- python/pyspark/ml/clustering.py | 15 --------------- python/pyspark/ml/tests.py | 9 +++++++++ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 60d1c9aaec988..12afb88563633 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -113,10 +113,6 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, def setK(self, value): """ Sets the value of :py:attr:`k`. - - >>> algo = KMeans().setK(10) - >>> algo.getK() - 10 """ self._paramMap[self.k] = value return self @@ -132,13 +128,6 @@ def getK(self): def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. - - >>> algo = KMeans() - >>> algo.getInitMode() - 'k-means||' - >>> algo = algo.setInitMode("random") - >>> algo.getInitMode() - 'random' """ self._paramMap[self.initMode] = value return self @@ -154,10 +143,6 @@ def getInitMode(self): def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. - - >>> algo = KMeans().setInitSteps(10) - >>> algo.getInitSteps() - 10 """ self._paramMap[self.initSteps] = value return self diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 54806ee336666..e93a4e157b931 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -39,6 +39,7 @@ from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params @@ -243,6 +244,14 @@ def test_params(self): "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", "seed: random seed. (default: 41, current: 43)"])) + def test_kmeans_param(self): + algo = KMeans() + self.assertEqual(algo.getInitMode(), "k-means||") + algo.setK(10) + self.assertEqual(algo.getK(), 10) + algo.setInitSteps(10) + self.assertEqual(algo.getInitSteps(), 10) + def test_hasseed(self): noSeedSpecd = TestParams() withSeedSpecd = TestParams(seed=42) From 30e00955663dfe6079effe4465bbcecedb5d93b9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 11 Feb 2016 15:55:40 -0800 Subject: [PATCH 1879/2089] [SPARK-13035][ML][PYSPARK] PySpark ml.clustering support export/import PySpark ml.clustering support export/import. Author: Yanbo Liang Closes #10999 from yanboliang/spark-13035. --- python/pyspark/ml/clustering.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 12afb88563633..f156eda125bf2 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,7 +16,7 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc @@ -24,7 +24,7 @@ __all__ = ['KMeans', 'KMeansModel'] -class KMeansModel(JavaModel): +class KMeansModel(JavaModel, MLWritable, MLReadable): """ Model fitted by KMeans. @@ -46,7 +46,8 @@ def computeCost(self, dataset): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, + MLWritable, MLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -69,6 +70,25 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> kmeans_path = path + "/kmeans" + >>> kmeans.save(kmeans_path) + >>> kmeans2 = KMeans.load(kmeans_path) + >>> kmeans2.getK() + 2 + >>> model_path = path + "/kmeans_model" + >>> model.save(model_path) + >>> model2 = KMeansModel.load(model_path) + >>> model.clusterCenters()[0] == model2.clusterCenters()[0] + array([ True, True], dtype=bool) + >>> model.clusterCenters()[1] == model2.clusterCenters()[1] + array([ True, True], dtype=bool) + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.5.0 """ @@ -157,9 +177,10 @@ def getInitSteps(self): if __name__ == "__main__": import doctest + import pyspark.ml.clustering from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.clustering tests") From b35467388612167f0bc3d17142c21a406f6c620d Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 11 Feb 2016 16:42:44 -0800 Subject: [PATCH 1880/2089] [SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error Pyspark Params class has a method `hasParam(paramName)` which returns `True` if the class has a parameter by that name, but throws an `AttributeError` otherwise. There is not currently a way of getting a Boolean to indicate if a class has a parameter. With Spark 2.0 we could modify the existing behavior of `hasParam` or add an additional method with this functionality. In Python: ```python from pyspark.ml.classification import NaiveBayes nb = NaiveBayes() print nb.hasParam("smoothing") print nb.hasParam("notAParam") ``` produces: > True > AttributeError: 'NaiveBayes' object has no attribute 'notAParam' However, in Scala: ```scala import org.apache.spark.ml.classification.NaiveBayes val nb = new NaiveBayes() nb.hasParam("smoothing") nb.hasParam("notAParam") ``` produces: > true > false cc holdenk Author: sethah Closes #10962 from sethah/SPARK-13047. --- python/pyspark/ml/param/__init__.py | 7 +++++-- python/pyspark/ml/tests.py | 9 +++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index ea86d6aeb8b31..bbf83f0310dd7 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -179,8 +179,11 @@ def hasParam(self, paramName): Tests whether this instance contains a param with a given (string) name. """ - param = self._resolveParam(paramName) - return param in self.params + if isinstance(paramName, str): + p = getattr(self, paramName, None) + return isinstance(p, Param) + else: + raise TypeError("hasParam(): paramName must be a string") @since("1.4.0") def getOrDefault(self, param): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e93a4e157b931..5fcfa9e61f6da 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -209,6 +209,11 @@ def test_param(self): self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") self.assertTrue(maxIter.parent == testParams.uid) + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + def test_params(self): testParams = TestParams() maxIter = testParams.maxIter @@ -218,7 +223,7 @@ def test_params(self): params = testParams.params self.assertEqual(params, [inputCol, maxIter, seed]) - self.assertTrue(testParams.hasParam(maxIter)) + self.assertTrue(testParams.hasParam(maxIter.name)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -227,7 +232,7 @@ def test_params(self): self.assertTrue(testParams.isSet(maxIter)) self.assertEqual(testParams.getMaxIter(), 100) - self.assertTrue(testParams.hasParam(inputCol)) + self.assertTrue(testParams.hasParam(inputCol.name)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) From a5257048d74359c3fa7810009be1d60d370e2896 Mon Sep 17 00:00:00 2001 From: Liu Xiang Date: Thu, 11 Feb 2016 17:28:37 -0800 Subject: [PATCH 1881/2089] [SPARK-12765][ML][COUNTVECTORIZER] fix CountVectorizer.transform's lost transformSchema https://issues.apache.org/jira/browse/SPARK-12765 Author: Liu Xiang Closes #10720 from sloth2012/sloth. --- .../main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 10dcda2382f48..d5cb05f29bf69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -210,6 +210,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) if (broadcastDict.isEmpty) { val dict = vocabulary.zipWithIndex.toMap broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) From b10af5e238ce2051be2bf4d7ddda181d34cbb69a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 11 Feb 2016 18:00:03 -0800 Subject: [PATCH 1882/2089] [SPARK-12915][SQL] add SQL metrics of numOutputRows for whole stage codegen This PR add SQL metrics (numOutputRows) for generated operators (same as non-generated), the cost is about 0.2 nano seconds per row. gen metrics Author: Davies Liu Closes #11170 from davies/gen_metric. --- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 7 ++++ .../sql/execution/WholeStageCodegen.scala | 19 ++++++++++- .../aggregate/TungstenAggregate.scala | 5 +++ .../spark/sql/execution/basicOperators.scala | 8 ++++- .../execution/joins/BroadcastHashJoin.scala | 7 +++- .../BenchmarkWholeStageCodegen.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 34 +++++++++---------- .../sql/util/DataFrameCallbackSuite.scala | 18 +++++----- 9 files changed, 71 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7aa08fb63053b..c5b2b7d11893c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1775,7 +1775,7 @@ class DataFrame private[sql]( private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = { try { df.queryExecution.executedPlan.foreach { plan => - plan.metrics.valuesIterator.foreach(_.reset()) + plan.resetMetrics() } val start = System.nanoTime() val result = action(df) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 3cc99d3c7b1b2..c72b8dc70708f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -77,6 +77,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty + /** + * Reset all the metrics. + */ + private[sql] def resetMetrics(): Unit = { + metrics.valuesIterator.foreach(_.reset()) + } + /** * Return a LongSQLMetric according to the name. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 30f74fc14f6c6..f35efb5b24b1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} -import org.apache.spark.util.Utils +import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric} /** * An interface for those physical operators that support codegen. @@ -42,6 +42,19 @@ trait CodegenSupport extends SparkPlan { case _ => nodeName.toLowerCase } + /** + * Creates a metric using the specified name. + * + * @return name of the variable representing the metric + */ + def metricTerm(ctx: CodegenContext, name: String): String = { + val metric = ctx.addReferenceObj(name, longMetric(name)) + val value = ctx.freshName("metricValue") + val cls = classOf[LongSQLMetricValue].getName + ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();") + value + } + /** * Whether this SparkPlan support whole stage codegen or not. */ @@ -316,6 +329,10 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } + private[sql] override def resetMetrics(): Unit = { + plan.foreach(_.resetMetrics()) + } + override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a6950f805a113..852203f3743dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -202,6 +202,7 @@ case class TungstenAggregate( | } """.stripMargin) + val numOutput = metricTerm(ctx, "numOutputRows") s""" | if (!$initAgg) { | $initAgg = true; @@ -210,6 +211,7 @@ case class TungstenAggregate( | // output the result | ${genResult.trim} | + | $numOutput.add(1); | ${consume(ctx, resultVars).trim} | } """.stripMargin @@ -297,6 +299,7 @@ case class TungstenAggregate( val peakMemory = Math.max(mapMemory, sorterMemory) val metrics = TaskContext.get().taskMetrics() metrics.incPeakExecutionMemory(peakMemory) + // TODO: update data size and spill size if (sorter == null) { // not spilled @@ -456,6 +459,7 @@ case class TungstenAggregate( val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + val numOutput = metricTerm(ctx, "numOutputRows") s""" if (!$initAgg) { @@ -465,6 +469,7 @@ case class TungstenAggregate( // output the result while ($iterTerm.next()) { + $numOutput.add(1); UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); $outputCode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 949acb9aca762..4b82d5563460b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.types.LongType import org.apache.spark.util.random.PoissonSampler @@ -78,6 +78,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val numOutput = metricTerm(ctx, "numOutputRows") val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -90,6 +91,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit s""" | ${eval.code} | if ($nullCheck ${eval.value}) { + | $numOutput.add(1); | ${consume(ctx, ctx.currentVars)} | } """.stripMargin @@ -159,6 +161,8 @@ case class Range( } protected override def doProduce(ctx: CodegenContext): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") val partitionEnd = ctx.freshName("partitionEnd") @@ -204,6 +208,8 @@ case class Range( | } else { | $partitionEnd = end.longValue(); | } + | + | $numOutput.add(($partitionEnd - $number) / ${step}L); | } """.stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 35c7963b48c4a..985e74011daa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -163,6 +163,7 @@ case class BroadcastHashJoin( case BuildRight => input ++ buildColumns } + val numOutput = metricTerm(ctx, "numOutputRows") val outputCode = if (condition.isDefined) { // filter the output via condition ctx.currentVars = resultVars @@ -170,11 +171,15 @@ case class BroadcastHashJoin( s""" | ${ev.code} | if (!${ev.isNull} && ${ev.value}) { + | $numOutput.add(1); | ${consume(ctx, resultVars)} | } """.stripMargin } else { - consume(ctx, resultVars) + s""" + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin } if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index dc6c647a4a95f..1c7e69f30fb48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -63,7 +63,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X - rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X + rang/filter/sum codegen=true 897 / 1022 584.6 1.7 16.4X */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d24625a5351ee..f4bc9e501c21c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -298,24 +298,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) - } + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index d3191d3aead95..15a95623d1e5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.sql.{functions, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegen} import org.apache.spark.sql.test.SharedSQLContext class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { @@ -92,17 +92,19 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - metrics += qe.executedPlan.longMetric("numOutputRows").value.value + val metric = qe.executedPlan match { + case w: WholeStageCodegen => w.plan.longMetric("numOutputRows") + case other => other.longMetric("numOutputRows") + } + metrics += metric.value.value } } sqlContext.listenerManager.register(listener) - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() - } + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() assert(metrics.length == 3) assert(metrics(0) === 1) From 8121a4b1cb4d7efa84a5e9e8e16d6656cdb79b85 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 11 Feb 2016 18:23:44 -0800 Subject: [PATCH 1883/2089] [SPARK-13277][BUILD] Follow-up ANTLR warnings are treated as build errors It is possible to create faulty but legal ANTLR grammars. ANTLR will produce warnings but also a valid compileable parser. This PR makes sure we treat such warnings as build errors. cc rxin / viirya Author: Herman van Hovell Closes #11174 from hvanhovell/ANTLR-warnings-as-errors. --- project/SparkBuild.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 550b5bad8a46a..6eba58c87c39c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -433,9 +433,12 @@ object Catalyst { } // Generate the parser. - antlr.process - if (antlr.getNumErrors > 0) { - log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) + antlr.process() + val errorState = org.antlr.tool.ErrorManager.getErrorState + if (errorState.errors > 0) { + sys.error("ANTLR: Caught %d build errors.".format(errorState.errors)) + } else if (errorState.warnings > 0) { + sys.error("ANTLR: Caught %d build warnings.".format(errorState.warnings)) } // Return all generated java files. From 5f1c359069545e75dfe83757c67a4be80428d342 Mon Sep 17 00:00:00 2001 From: Earthson Lu Date: Thu, 11 Feb 2016 18:31:46 -0800 Subject: [PATCH 1884/2089] [SPARK-12746][ML] ArrayType(_, true) should also accept ArrayType(_, false) https://issues.apache.org/jira/browse/SPARK-12746 Author: Earthson Lu Closes #10697 from Earthson/SPARK-12746. --- .../scala/org/apache/spark/ml/feature/CountVectorizer.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index d5cb05f29bf69..a6dfe58e56f0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -71,7 +71,8 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { validateParams() - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } From d3e2e202994e063856c192e9fdd0541777b88e0e Mon Sep 17 00:00:00 2001 From: Tommy YU Date: Thu, 11 Feb 2016 18:38:49 -0800 Subject: [PATCH 1885/2089] [SPARK-13153][PYSPARK] ML persistence failed when handle no default value parameter Fix this defect by check default value exist or not. yanboliang Please help to review. Author: Tommy YU Closes #11043 from Wenpei/spark-13153-handle-param-withnodefaultvalue. --- python/pyspark/ml/wrapper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index d4d48eb2150e3..f8feaa1dfa2be 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -79,8 +79,9 @@ def _transfer_params_from_java(self): for param in self.params: if self._java_obj.hasParam(param.name): java_param = self._java_obj.getParam(param.name) - value = _java2py(sc, self._java_obj.getOrDefault(java_param)) - self._paramMap[param] = value + if self._java_obj.isDefined(java_param): + value = _java2py(sc, self._java_obj.getOrDefault(java_param)) + self._paramMap[param] = value @staticmethod def _empty_java_param_map(): From a2c7dcf61f33fa1897c950d2d905651103c170ea Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 11 Feb 2016 21:37:53 -0600 Subject: [PATCH 1886/2089] [SPARK-7889][WEBUI] HistoryServer updates UI for incomplete apps When the HistoryServer is showing an incomplete app, it needs to check if there is a newer version of the app available. It does this by checking if a version of the app has been loaded with a larger *filesize*. If so, it detaches the current UI, attaches the new one, and redirects back to the same URL to show the new UI. https://issues.apache.org/jira/browse/SPARK-7889 Author: Steve Loughran Author: Imran Rashid Closes #11118 from squito/SPARK-7889-alternate. --- .../deploy/history/ApplicationCache.scala | 665 ++++++++++++++++++ .../history/ApplicationHistoryProvider.scala | 42 +- .../deploy/history/FsHistoryProvider.scala | 149 +++- .../spark/deploy/history/HistoryPage.scala | 2 +- .../spark/deploy/history/HistoryServer.scala | 78 +- .../scheduler/EventLoggingListener.scala | 7 + .../history/ApplicationCacheSuite.scala | 488 +++++++++++++ .../deploy/history/HistoryServerSuite.scala | 224 +++++- docs/monitoring.md | 70 +- project/MimaExcludes.scala | 3 + 10 files changed, 1654 insertions(+), 74 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala create mode 100644 core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala new file mode 100644 index 0000000000000..e2fda29044385 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala @@ -0,0 +1,665 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.util.NoSuchElementException +import javax.servlet.{DispatcherType, Filter, FilterChain, FilterConfig, ServletException, ServletRequest, ServletResponse} +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.codahale.metrics.{Counter, MetricRegistry, Timer} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification} +import org.eclipse.jetty.servlet.FilterHolder + +import org.apache.spark.Logging +import org.apache.spark.metrics.source.Source +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Clock + +/** + * Cache for applications. + * + * Completed applications are cached for as long as there is capacity for them. + * Incompleted applications have their update time checked on every + * retrieval; if the cached entry is out of date, it is refreshed. + * + * @note there must be only one instance of [[ApplicationCache]] in a + * JVM at a time. This is because a static field in [[ApplicationCacheCheckFilterRelay]] + * keeps a reference to the cache so that HTTP requests on the attempt-specific web UIs + * can probe the current cache to see if the attempts have changed. + * + * Creating multiple instances will break this routing. + * @param operations implementation of record access operations + * @param retainedApplications number of retained applications + * @param clock time source + */ +private[history] class ApplicationCache( + val operations: ApplicationCacheOperations, + val retainedApplications: Int, + val clock: Clock) extends Logging { + + /** + * Services the load request from the cache. + */ + private val appLoader = new CacheLoader[CacheKey, CacheEntry] { + + /** the cache key doesn't match a cached entry, or the entry is out-of-date, so load it. */ + override def load(key: CacheKey): CacheEntry = { + loadApplicationEntry(key.appId, key.attemptId) + } + + } + + /** + * Handler for callbacks from the cache of entry removal. + */ + private val removalListener = new RemovalListener[CacheKey, CacheEntry] { + + /** + * Removal event notifies the provider to detach the UI. + * @param rm removal notification + */ + override def onRemoval(rm: RemovalNotification[CacheKey, CacheEntry]): Unit = { + metrics.evictionCount.inc() + val key = rm.getKey + logDebug(s"Evicting entry ${key}") + operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().ui) + } + } + + /** + * The cache of applications. + * + * Tagged as `protected` so as to allow subclasses in tests to accesss it directly + */ + protected val appCache: LoadingCache[CacheKey, CacheEntry] = { + CacheBuilder.newBuilder() + .maximumSize(retainedApplications) + .removalListener(removalListener) + .build(appLoader) + } + + /** + * The metrics which are updated as the cache is used. + */ + val metrics = new CacheMetrics("history.cache") + + init() + + /** + * Perform any startup operations. + * + * This includes declaring this instance as the cache to use in the + * [[ApplicationCacheCheckFilterRelay]]. + */ + private def init(): Unit = { + ApplicationCacheCheckFilterRelay.setApplicationCache(this) + } + + /** + * Stop the cache. + * This will reset the relay in [[ApplicationCacheCheckFilterRelay]]. + */ + def stop(): Unit = { + ApplicationCacheCheckFilterRelay.resetApplicationCache() + } + + /** + * Get an entry. + * + * Cache fetch/refresh will have taken place by the time this method returns. + * @param appAndAttempt application to look up in the format needed by the history server web UI, + * `appId/attemptId` or `appId`. + * @return the entry + */ + def get(appAndAttempt: String): SparkUI = { + val parts = splitAppAndAttemptKey(appAndAttempt) + get(parts._1, parts._2) + } + + /** + * Get the Spark UI, converting a lookup failure from an exception to `None`. + * @param appAndAttempt application to look up in the format needed by the history server web UI, + * `appId/attemptId` or `appId`. + * @return the entry + */ + def getSparkUI(appAndAttempt: String): Option[SparkUI] = { + try { + val ui = get(appAndAttempt) + Some(ui) + } catch { + case NonFatal(e) => e.getCause() match { + case nsee: NoSuchElementException => + None + case cause: Exception => throw cause + } + } + } + + /** + * Get the associated spark UI. + * + * Cache fetch/refresh will have taken place by the time this method returns. + * @param appId application ID + * @param attemptId optional attempt ID + * @return the entry + */ + def get(appId: String, attemptId: Option[String]): SparkUI = { + lookupAndUpdate(appId, attemptId)._1.ui + } + + /** + * Look up the entry; update it if needed. + * @param appId application ID + * @param attemptId optional attempt ID + * @return the underlying cache entry -which can have its timestamp changed, and a flag to + * indicate that the entry has changed + */ + private def lookupAndUpdate(appId: String, attemptId: Option[String]): (CacheEntry, Boolean) = { + metrics.lookupCount.inc() + val cacheKey = CacheKey(appId, attemptId) + var entry = appCache.getIfPresent(cacheKey) + var updated = false + if (entry == null) { + // no entry, so fetch without any post-fetch probes for out-of-dateness + // this will trigger a callback to loadApplicationEntry() + entry = appCache.get(cacheKey) + } else if (!entry.completed) { + val now = clock.getTimeMillis() + log.debug(s"Probing at time $now for updated application $cacheKey -> $entry") + metrics.updateProbeCount.inc() + updated = time(metrics.updateProbeTimer) { + entry.updateProbe() + } + if (updated) { + logDebug(s"refreshing $cacheKey") + metrics.updateTriggeredCount.inc() + appCache.refresh(cacheKey) + // and repeat the lookup + entry = appCache.get(cacheKey) + } else { + // update the probe timestamp to the current time + entry.probeTime = now + } + } + (entry, updated) + } + + /** + * This method is visible for testing. + * + * It looks up the cached entry *and returns a clone of it*. + * This ensures that the cached entries never leak + * @param appId application ID + * @param attemptId optional attempt ID + * @return a new entry with shared SparkUI, but copies of the other fields. + */ + def lookupCacheEntry(appId: String, attemptId: Option[String]): CacheEntry = { + val entry = lookupAndUpdate(appId, attemptId)._1 + new CacheEntry(entry.ui, entry.completed, entry.updateProbe, entry.probeTime) + } + + /** + * Probe for an application being updated. + * @param appId application ID + * @param attemptId attempt ID + * @return true if an update has been triggered + */ + def checkForUpdates(appId: String, attemptId: Option[String]): Boolean = { + val (entry, updated) = lookupAndUpdate(appId, attemptId) + updated + } + + /** + * Size probe, primarily for testing. + * @return size + */ + def size(): Long = appCache.size() + + /** + * Emptiness predicate, primarily for testing. + * @return true if the cache is empty + */ + def isEmpty: Boolean = appCache.size() == 0 + + /** + * Time a closure, returning its output. + * @param t timer + * @param f function + * @tparam T type of return value of time + * @return the result of the function. + */ + private def time[T](t: Timer)(f: => T): T = { + val timeCtx = t.time() + try { + f + } finally { + timeCtx.close() + } + } + + /** + * Load the Spark UI via [[ApplicationCacheOperations.getAppUI()]], + * then attach it to the web UI via [[ApplicationCacheOperations.attachSparkUI()]]. + * + * If the application is incomplete, it has the [[ApplicationCacheCheckFilter]] + * added as a filter to the HTTP requests, so that queries on the UI will trigger + * update checks. + * + * The generated entry contains the UI and the current timestamp. + * The timer [[metrics.loadTimer]] tracks the time taken to load the UI. + * + * @param appId application ID + * @param attemptId optional attempt ID + * @return the cache entry + * @throws NoSuchElementException if there is no matching element + */ + @throws[NoSuchElementException] + def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { + + logDebug(s"Loading application Entry $appId/$attemptId") + metrics.loadCount.inc() + time(metrics.loadTimer) { + operations.getAppUI(appId, attemptId) match { + case Some(LoadedAppUI(ui, updateState)) => + val completed = ui.getApplicationInfoList.exists(_.attempts.last.completed) + if (completed) { + // completed spark UIs are attached directly + operations.attachSparkUI(appId, attemptId, ui, completed) + } else { + // incomplete UIs have the cache-check filter put in front of them. + ApplicationCacheCheckFilterRelay.registerFilter(ui, appId, attemptId) + operations.attachSparkUI(appId, attemptId, ui, completed) + } + // build the cache entry + val now = clock.getTimeMillis() + val entry = new CacheEntry(ui, completed, updateState, now) + logDebug(s"Loaded application $appId/$attemptId -> $entry") + entry + case None => + metrics.lookupFailureCount.inc() + // guava's cache logs via java.util log, so is of limited use. Hence: our own message + logInfo(s"Failed to load application attempt $appId/$attemptId") + throw new NoSuchElementException(s"no application with application Id '$appId'" + + attemptId.map { id => s" attemptId '$id'" }.getOrElse(" and no attempt Id")) + } + } + } + + /** + * Split up an `applicationId/attemptId` or `applicationId` key into the separate pieces. + * + * @param appAndAttempt combined key + * @return a tuple of the application ID and, if present, the attemptID + */ + def splitAppAndAttemptKey(appAndAttempt: String): (String, Option[String]) = { + val parts = appAndAttempt.split("/") + require(parts.length == 1 || parts.length == 2, s"Invalid app key $appAndAttempt") + val appId = parts(0) + val attemptId = if (parts.length > 1) Some(parts(1)) else None + (appId, attemptId) + } + + /** + * Merge an appId and optional attempt Id into a key of the form `applicationId/attemptId`. + * + * If there is an `attemptId`; `applicationId` if not. + * @param appId application ID + * @param attemptId optional attempt ID + * @return a unified string + */ + def mergeAppAndAttemptToKey(appId: String, attemptId: Option[String]): String = { + appId + attemptId.map { id => s"/$id" }.getOrElse("") + } + + /** + * String operator dumps the cache entries and metrics. + * @return a string value, primarily for testing and diagnostics + */ + override def toString: String = { + val sb = new StringBuilder(s"ApplicationCache(" + + s" retainedApplications= $retainedApplications)") + sb.append(s"; time= ${clock.getTimeMillis()}") + sb.append(s"; entry count= ${appCache.size()}\n") + sb.append("----\n") + appCache.asMap().asScala.foreach { + case(key, entry) => sb.append(s" $key -> $entry\n") + } + sb.append("----\n") + sb.append(metrics) + sb.append("----\n") + sb.toString() + } +} + +/** + * An entry in the cache. + * + * @param ui Spark UI + * @param completed Flag to indicated that the application has completed (and so + * does not need refreshing). + * @param updateProbe function to call to see if the application has been updated and + * therefore that the cached value needs to be refreshed. + * @param probeTime Times in milliseconds when the probe was last executed. + */ +private[history] final class CacheEntry( + val ui: SparkUI, + val completed: Boolean, + val updateProbe: () => Boolean, + var probeTime: Long) { + + /** string value is for test assertions */ + override def toString: String = { + s"UI $ui, completed=$completed, probeTime=$probeTime" + } +} + +/** + * Cache key: compares on `appId` and then, if non-empty, `attemptId`. + * The [[hashCode()]] function uses the same fields. + * @param appId application ID + * @param attemptId attempt ID + */ +private[history] final case class CacheKey(appId: String, attemptId: Option[String]) { + + override def toString: String = { + appId + attemptId.map { id => s"/$id" }.getOrElse("") + } +} + +/** + * Metrics of the cache + * @param prefix prefix to register all entries under + */ +private[history] class CacheMetrics(prefix: String) extends Source { + + /* metrics: counters and timers */ + val lookupCount = new Counter() + val lookupFailureCount = new Counter() + val evictionCount = new Counter() + val loadCount = new Counter() + val loadTimer = new Timer() + val updateProbeCount = new Counter() + val updateProbeTimer = new Timer() + val updateTriggeredCount = new Counter() + + /** all the counters: for registration and string conversion. */ + private val counters = Seq( + ("lookup.count", lookupCount), + ("lookup.failure.count", lookupFailureCount), + ("eviction.count", evictionCount), + ("load.count", loadCount), + ("update.probe.count", updateProbeCount), + ("update.triggered.count", updateTriggeredCount)) + + /** all metrics, including timers */ + private val allMetrics = counters ++ Seq( + ("load.timer", loadTimer), + ("update.probe.timer", updateProbeTimer)) + + /** + * Name of metric source + */ + override val sourceName = "ApplicationCache" + + override val metricRegistry: MetricRegistry = new MetricRegistry + + /** + * Startup actions. + * This includes registering metrics with [[metricRegistry]] + */ + private def init(): Unit = { + allMetrics.foreach { case (name, metric) => + metricRegistry.register(MetricRegistry.name(prefix, name), metric) + } + } + + override def toString: String = { + val sb = new StringBuilder() + counters.foreach { case (name, counter) => + sb.append(name).append(" = ").append(counter.getCount).append('\n') + } + sb.toString() + } +} + +/** + * API for cache events. That is: loading an App UI; and for + * attaching/detaching the UI to and from the Web UI. + */ +private[history] trait ApplicationCacheOperations { + + /** + * Get the application UI and the probe neededed to see if it has been updated. + * @param appId application ID + * @param attemptId attempt ID + * @return If found, the Spark UI and any history information to be used in the cache + */ + def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] + + /** + * Attach a reconstructed UI. + * @param appId application ID + * @param attemptId attempt ID + * @param ui UI + * @param completed flag to indicate that the UI has completed + */ + def attachSparkUI( + appId: String, + attemptId: Option[String], + ui: SparkUI, + completed: Boolean): Unit + + /** + * Detach a Spark UI. + * + * @param ui Spark UI + */ + def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit + +} + +/** + * This is a servlet filter which intercepts HTTP requests on application UIs and + * triggers checks for updated data. + * + * If the application cache indicates that the application has been updated, + * the filter returns a 302 redirect to the caller, asking them to re-request the web + * page. + * + * Because the application cache will detach and then re-attach the UI, when the caller + * repeats that request, it will now pick up the newly-updated web application. + * + * This does require the caller to handle 302 requests. Because of the ambiguity + * in how POST and PUT operations are responded to (that is, should a 307 be + * processed directly), the filter does not filter those requests. + * As the current web UIs are read-only, this is not an issue. If it were ever to + * support more HTTP verbs, then some support may be required. Perhaps, rather + * than sending a redirect, simply updating the value so that the next + * request will pick it up. + * + * Implementation note: there's some abuse of a shared global entry here because + * the configuration data passed to the servlet is just a string:string map. + */ +private[history] class ApplicationCacheCheckFilter() extends Filter with Logging { + + import ApplicationCacheCheckFilterRelay._ + var appId: String = _ + var attemptId: Option[String] = _ + + /** + * Bind the app and attempt ID, throwing an exception if no application ID was provided. + * @param filterConfig configuration + */ + override def init(filterConfig: FilterConfig): Unit = { + + appId = Option(filterConfig.getInitParameter(APP_ID)) + .getOrElse(throw new ServletException(s"Missing Parameter $APP_ID")) + attemptId = Option(filterConfig.getInitParameter(ATTEMPT_ID)) + logDebug(s"initializing filter $this") + } + + /** + * Filter the request. + * Either the caller is given a 302 redirect to the current URL, or the + * request is passed on to the SparkUI servlets. + * + * @param request HttpServletRequest + * @param response HttpServletResponse + * @param chain the rest of the request chain + */ + override def doFilter( + request: ServletRequest, + response: ServletResponse, + chain: FilterChain): Unit = { + + // nobody has ever implemented any other kind of servlet, yet + // this check is universal, just in case someone does exactly + // that on your classpath + if (!(request.isInstanceOf[HttpServletRequest])) { + throw new ServletException("This filter only works for HTTP/HTTPS") + } + val httpRequest = request.asInstanceOf[HttpServletRequest] + val httpResponse = response.asInstanceOf[HttpServletResponse] + val requestURI = httpRequest.getRequestURI + val operation = httpRequest.getMethod + + // if the request is for an attempt, check to see if it is in need of delete/refresh + // and have the cache update the UI if so + if (operation=="HEAD" || operation=="GET" + && checkForUpdates(requestURI, appId, attemptId)) { + // send a redirect back to the same location. This will be routed + // to the *new* UI + logInfo(s"Application Attempt $appId/$attemptId updated; refreshing") + val queryStr = Option(httpRequest.getQueryString).map("?" + _).getOrElse("") + val redirectUrl = httpResponse.encodeRedirectURL(requestURI + queryStr) + httpResponse.sendRedirect(redirectUrl) + } else { + chain.doFilter(request, response) + } + } + + override def destroy(): Unit = { + } + + override def toString: String = s"ApplicationCacheCheckFilter for $appId/$attemptId" +} + +/** + * Global state for the [[ApplicationCacheCheckFilter]] instances, so that they can relay cache + * probes to the cache. + * + * This is an ugly workaround for the limitation of servlets and filters in the Java servlet + * API; they are still configured on the model of a list of classnames and configuration + * strings in a `web.xml` field, rather than a chain of instances wired up by hand or + * via an injection framework. There is no way to directly configure a servlet filter instance + * with a reference to the application cache which is must use: some global state is needed. + * + * Here, [[ApplicationCacheCheckFilter]] is that global state; it relays all requests + * to the singleton [[ApplicationCache]] + * + * The field `applicationCache` must be set for the filters to work - + * this is done during the construction of [[ApplicationCache]], which requires that there + * is only one cache serving requests through the WebUI. + * + * *Important* In test runs, if there is more than one [[ApplicationCache]], the relay logic + * will break: filters may not find instances. Tests must not do that. + * + */ +private[history] object ApplicationCacheCheckFilterRelay extends Logging { + // name of the app ID entry in the filter configuration. Mandatory. + val APP_ID = "appId" + + // name of the attempt ID entry in the filter configuration. Optional. + val ATTEMPT_ID = "attemptId" + + // namer of the filter to register + val FILTER_NAME = "org.apache.spark.deploy.history.ApplicationCacheCheckFilter" + + /** the application cache to relay requests to */ + @volatile + private var applicationCache: Option[ApplicationCache] = None + + /** + * Set the application cache. Logs a warning if it is overwriting an existing value + * @param cache new cache + */ + def setApplicationCache(cache: ApplicationCache): Unit = { + applicationCache.foreach( c => logWarning(s"Overwriting application cache $c")) + applicationCache = Some(cache) + } + + /** + * Reset the application cache + */ + def resetApplicationCache(): Unit = { + applicationCache = None + } + + /** + * Check to see if there has been an update + * @param requestURI URI the request came in on + * @param appId application ID + * @param attemptId attempt ID + * @return true if an update was loaded for the app/attempt + */ + def checkForUpdates(requestURI: String, appId: String, attemptId: Option[String]): Boolean = { + + logDebug(s"Checking $appId/$attemptId from $requestURI") + applicationCache match { + case Some(cache) => + try { + cache.checkForUpdates(appId, attemptId) + } catch { + case ex: Exception => + // something went wrong. Keep going with the existing UI + logWarning(s"When checking for $appId/$attemptId from $requestURI", ex) + false + } + + case None => + logWarning("No application cache instance defined") + false + } + } + + + /** + * Register a filter for the web UI which checks for updates to the given app/attempt + * @param ui Spark UI to attach filters to + * @param appId application ID + * @param attemptId attempt ID + */ + def registerFilter( + ui: SparkUI, + appId: String, + attemptId: Option[String] ): Unit = { + require(ui != null) + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) + val holder = new FilterHolder() + holder.setClassName(FILTER_NAME) + holder.setInitParameter(APP_ID, appId) + attemptId.foreach( id => holder.setInitParameter(ATTEMPT_ID, id)) + require(ui.getHandlers != null, "null handlers") + ui.getHandlers.foreach { handler => + handler.addFilter(holder, "/*", enumDispatcher) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 5f5e0fe1c34d7..44661edfff90b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -33,7 +33,42 @@ private[spark] case class ApplicationAttemptInfo( private[spark] case class ApplicationHistoryInfo( id: String, name: String, - attempts: List[ApplicationAttemptInfo]) + attempts: List[ApplicationAttemptInfo]) { + + /** + * Has this application completed? + * @return true if the most recent attempt has completed + */ + def completed: Boolean = { + attempts.nonEmpty && attempts.head.completed + } +} + +/** + * A probe which can be invoked to see if a loaded Web UI has been updated. + * The probe is expected to be relative purely to that of the UI returned + * in the same [[LoadedAppUI]] instance. That is, whenever a new UI is loaded, + * the probe returned with it is the one that must be used to check for it + * being out of date; previous probes must be discarded. + */ +private[history] abstract class HistoryUpdateProbe { + /** + * Return true if the history provider has a later version of the application + * attempt than the one against this probe was constructed. + * @return + */ + def isUpdated(): Boolean +} + +/** + * All the information returned from a call to `getAppUI()`: the new UI + * and any required update state. + * @param ui Spark UI + * @param updateProbe probe to call to check on the update state of this application attempt + */ +private[history] case class LoadedAppUI( + ui: SparkUI, + updateProbe: () => Boolean) private[history] abstract class ApplicationHistoryProvider { @@ -49,9 +84,10 @@ private[history] abstract class ApplicationHistoryProvider { * * @param appId The application ID. * @param attemptId The application attempt ID (or None if there is no attempt ID). - * @return The application's UI, or None if application is not found. + * @return a [[LoadedAppUI]] instance containing the application's UI and any state information + * for update probes, or `None` if the application/attempt is not found. */ - def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] + def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] /** * Called when the server is shutting down. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9648959dbacb9..f885798760464 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID import java.util.concurrent.{Executors, ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -33,7 +33,6 @@ import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -42,6 +41,31 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * A class that provides application history from event logs stored in the file system. * This provider checks for new finished applications in the background periodically and * renders the history application UI by parsing the associated event logs. + * + * == How new and updated attempts are detected == + * + * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any + * entries in the log dir whose modification time is greater than the last scan time + * are considered new or updated. These are replayed to create a new [[FsApplicationAttemptInfo]] + * entry and update or create a matching [[FsApplicationHistoryInfo]] element in the list + * of applications. + * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the + * [[FsApplicationAttemptInfo]] is replaced by another one with a larger log size. + * - When [[updateProbe()]] is invoked to check if a loaded [[SparkUI]] + * instance is out of date, the log size of the cached instance is checked against the app last + * loaded by [[checkForLogs]]. + * + * The use of log size, rather than simply relying on modification times, is needed to + * address the following issues + * - some filesystems do not appear to update the `modtime` value whenever data is flushed to + * an open file output stream. Changes to the history may not be picked up. + * - the granularity of the `modtime` field may be 2+ seconds. Rapid changes to the FS can be + * missed. + * + * Tracking filesize works given the following invariant: the logs get bigger + * as new events are added. If a format was used in which this did not hold, the mechanism would + * break. Simple streaming of JSON-formatted events, as is implemented today, implicitly + * maintains this invariant. */ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) extends ApplicationHistoryProvider with Logging { @@ -77,9 +101,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() .setNameFormat("spark-history-task-%d").setDaemon(true).build()) - // The modification time of the newest log detected during the last scan. This is used - // to ignore logs that are older during subsequent scans, to avoid processing data that - // is already known. + // The modification time of the newest log detected during the last scan. Currently only + // used for logging msgs (logs are re-scanned based on file size, rather than modtime) private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted @@ -87,6 +110,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() + val fileToAppInfo = new mutable.HashMap[Path, FsApplicationAttemptInfo]() + // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] @@ -176,18 +201,21 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. + logDebug(s"Scheduling update thread every $UPDATE_INTERVAL_S seconds") pool.scheduleWithFixedDelay(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. pool.scheduleWithFixedDelay(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } + } else { + logDebug("Background update thread disabled for testing") } } override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values - override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => @@ -210,7 +238,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) - ui + LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize)) } } } @@ -243,12 +271,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private[history] def checkForLogs(): Unit = { try { val newLastScanTime = getNewLastScanTime() + logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) + // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => try { - !entry.isDirectory() && (entry.getModificationTime() >= lastScanTime) + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && prevFileSize < entry.getLen() } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -262,6 +293,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) entry1.getModificationTime() >= entry2.getModificationTime() } + if (logInfos.nonEmpty) { + logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") + } logInfos.grouped(20) .map { batch => replayExecutor.submit(new Runnable { @@ -356,7 +390,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { - case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully: $r") case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + "The application may have not started.") } @@ -511,6 +545,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") + // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, + // and when we read the file here. That is OK -- it may result in an unnecessary refresh + // when there is no update, but will not result in missing an update. We *must* prevent + // an error the other way -- if we report a size bigger (ie later) than the file that is + // actually read, we may never refresh the app. FileStatus is guaranteed to be static + // after it's created, so we get a file size that is no bigger than what is actually read. val logInput = EventLoggingListener.openEventLog(logPath, fs) try { val appListener = new ApplicationEventListener @@ -521,7 +561,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Without an app ID, new logs will render incorrectly in the listing page, so do not list or // try to show their UI. if (appListener.appId.isDefined) { - Some(new FsApplicationAttemptInfo( + val attemptInfo = new FsApplicationAttemptInfo( logPath.getName(), appListener.appName.getOrElse(NOT_STARTED), appListener.appId.getOrElse(logPath.getName()), @@ -530,7 +570,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.endTime.getOrElse(-1L), eventLog.getModificationTime(), appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted)) + appCompleted, + eventLog.getLen() + ) + fileToAppInfo(logPath) = attemptInfo + Some(attemptInfo) } else { None } @@ -564,12 +608,77 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) } + /** + * String description for diagnostics + * @return a summary of the component state + */ + override def toString: String = { + val header = s""" + | FsHistoryProvider: logdir=$logDir, + | last scan time=$lastScanTime + | Cached application count =${applications.size}} + """.stripMargin + val sb = new StringBuilder(header) + applications.foreach(entry => sb.append(entry._2).append("\n")) + sb.toString + } + + /** + * Look up an application attempt + * @param appId application ID + * @param attemptId Attempt ID, if set + * @return the matching attempt, if found + */ + def lookup(appId: String, attemptId: Option[String]): Option[FsApplicationAttemptInfo] = { + applications.get(appId).flatMap { appInfo => + appInfo.attempts.find(_.attemptId == attemptId) + } + } + + /** + * Return true iff a newer version of the UI is available. The check is based on whether the + * fileSize for the currently loaded UI is smaller than the file size the last time + * the logs were loaded. + * + * This is a very cheap operation -- the work of loading the new attempt was already done + * by [[checkForLogs]]. + * @param appId application to probe + * @param attemptId attempt to probe + * @param prevFileSize the file size of the logs for the currently displayed UI + */ + private def updateProbe( + appId: String, + attemptId: Option[String], + prevFileSize: Long)(): Boolean = { + lookup(appId, attemptId) match { + case None => + logDebug(s"Application Attempt $appId/$attemptId not found") + false + case Some(latest) => + prevFileSize < latest.fileSize + } + } } private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" } +/** + * Application attempt information. + * + * @param logPath path to the log file, or, for a legacy log, its directory + * @param name application name + * @param appId application ID + * @param attemptId optional attempt ID + * @param startTime start time (from playback) + * @param endTime end time (from playback). -1 if the application is incomplete. + * @param lastUpdated the modification time of the log file when this entry was built by replaying + * the history. + * @param sparkUser user running the application + * @param completed flag to indicate whether or not the application has completed. + * @param fileSize the size of the log file the last time the file was scanned for changes + */ private class FsApplicationAttemptInfo( val logPath: String, val name: String, @@ -579,10 +688,24 @@ private class FsApplicationAttemptInfo( endTime: Long, lastUpdated: Long, sparkUser: String, - completed: Boolean = true) + completed: Boolean, + val fileSize: Long) extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed) + attemptId, startTime, endTime, lastUpdated, sparkUser, completed) { + /** extend the superclass string value with the extra attributes of this class */ + override def toString: String = { + s"FsApplicationAttemptInfo($name, $appId," + + s" ${super.toString}, source=$logPath, size=$fileSize" + } +} + +/** + * Application history information + * @param id application ID + * @param name application name + * @param attempts list of attempts, most recent first. + */ private class FsApplicationHistoryInfo( id: String, override val name: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index cab7faefe89c3..2fad1120cdc8a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -30,7 +30,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean val allApps = parent.getApplicationList() - .filter(_.attempts.head.completed != requestedIncomplete) + .filter(_.completed != requestedIncomplete) val allAppsSize = allApps.size val providerConfig = parent.getProviderConfig() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 1f13d7db348ec..076bdc5c058e1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -23,7 +23,6 @@ import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.util.control.NonFatal -import com.google.common.cache._ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -31,7 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.{ShutdownHookManager, SystemClock, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -50,31 +49,16 @@ class HistoryServer( securityManager: SecurityManager, port: Int) extends WebUI(securityManager, securityManager.getSSLOptions("historyServer"), port, conf) - with Logging with UIRoot { + with Logging with UIRoot with ApplicationCacheOperations { // How many applications to retain private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) - private val appLoader = new CacheLoader[String, SparkUI] { - override def load(key: String): SparkUI = { - val parts = key.split("/") - require(parts.length == 1 || parts.length == 2, s"Invalid app key $key") - val ui = provider - .getAppUI(parts(0), if (parts.length > 1) Some(parts(1)) else None) - .getOrElse(throw new NoSuchElementException(s"no app with key $key")) - attachSparkUI(ui) - ui - } - } + // application + private val appCache = new ApplicationCache(this, retainedApplications, new SystemClock()) - private val appCache = CacheBuilder.newBuilder() - .maximumSize(retainedApplications) - .removalListener(new RemovalListener[String, SparkUI] { - override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = { - detachSparkUI(rm.getValue()) - } - }) - .build(appLoader) + // and its metrics, for testing as well as monitoring + val cacheMetrics = appCache.metrics private val loaderServlet = new HttpServlet { protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { @@ -117,17 +101,7 @@ class HistoryServer( } def getSparkUI(appKey: String): Option[SparkUI] = { - try { - val ui = appCache.get(appKey) - Some(ui) - } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - None - - case cause: Exception => throw cause - } - } + appCache.getSparkUI(appKey) } initialize() @@ -160,21 +134,36 @@ class HistoryServer( override def stop() { super.stop() provider.stop() + appCache.stop() } /** Attach a reconstructed UI to this server. Only valid after bind(). */ - private def attachSparkUI(ui: SparkUI) { + override def attachSparkUI( + appId: String, + attemptId: Option[String], + ui: SparkUI, + completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") ui.getHandlers.foreach(attachHandler) addFilters(ui.getHandlers, conf) } /** Detach a reconstructed UI from this server. Only valid after bind(). */ - private def detachSparkUI(ui: SparkUI) { + override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") ui.getHandlers.foreach(detachHandler) } + /** + * Get the application UI and whether or not it is completed + * @param appId application ID + * @param attemptId attempt ID + * @return If found, the Spark UI and any history information to be used in the cache + */ + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { + provider.getAppUI(appId, attemptId) + } + /** * Returns a list of available applications, in descending order according to their end time. * @@ -202,9 +191,15 @@ class HistoryServer( */ def getProviderConfig(): Map[String, String] = provider.getConfig() + /** + * Load an application UI and attach it to the web server. + * @param appId application ID + * @param attemptId optional attempt ID + * @return true if the application was found and loaded. + */ private def loadAppUi(appId: String, attemptId: Option[String]): Boolean = { try { - appCache.get(appId + attemptId.map { id => s"/$id" }.getOrElse("")) + appCache.get(appId, attemptId) true } catch { case NonFatal(e) => e.getCause() match { @@ -216,6 +211,17 @@ class HistoryServer( } } + /** + * String value for diagnostics. + * @return a multi-line description of the server state. + */ + override def toString: String = { + s""" + | History Server; + | provider = $provider + | cache = $appCache + """.stripMargin + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 01fee46e73a80..8354e2a6112a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -224,6 +224,13 @@ private[spark] class EventLoggingListener( } } fileSystem.rename(new Path(logPath + IN_PROGRESS), target) + // touch file to ensure modtime is current across those filesystems where rename() + // does not set it, -and which support setTimes(); it's a no-op on most object stores + try { + fileSystem.setTimes(target, System.currentTimeMillis(), -1) + } catch { + case e: Exception => logDebug(s"failed to set time of $target", e) + } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala new file mode 100644 index 0000000000000..de6680c61006d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.util.{Date, NoSuchElementException} +import javax.servlet.Filter +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import scala.language.postfixOps + +import com.codahale.metrics.Counter +import com.google.common.cache.LoadingCache +import com.google.common.util.concurrent.UncheckedExecutionException +import org.eclipse.jetty.servlet.ServletContextHandler +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.Matchers +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.status.api.v1.{ApplicationAttemptInfo => AttemptInfo, ApplicationInfo} +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{Clock, ManualClock, Utils} + +class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar with Matchers { + + /** + * subclass with access to the cache internals + * @param retainedApplications number of retained applications + */ + class TestApplicationCache( + operations: ApplicationCacheOperations = new StubCacheOperations(), + retainedApplications: Int, + clock: Clock = new ManualClock(0)) + extends ApplicationCache(operations, retainedApplications, clock) { + + def cache(): LoadingCache[CacheKey, CacheEntry] = appCache + } + + /** + * Stub cache operations. + * The state is kept in a map of [[CacheKey]] to [[CacheEntry]], + * the `probeTime` field in the cache entry setting the timestamp of the entry + */ + class StubCacheOperations extends ApplicationCacheOperations with Logging { + + /** map to UI instances, including timestamps, which are used in update probes */ + val instances = mutable.HashMap.empty[CacheKey, CacheEntry] + + /** Map of attached spark UIs */ + val attached = mutable.HashMap.empty[CacheKey, SparkUI] + + var getAppUICount = 0L + var attachCount = 0L + var detachCount = 0L + var updateProbeCount = 0L + + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { + logDebug(s"getAppUI($appId, $attemptId)") + getAppUICount += 1 + instances.get(CacheKey(appId, attemptId)).map( e => + LoadedAppUI(e.ui, updateProbe(appId, attemptId, e.probeTime))) + } + + override def attachSparkUI( + appId: String, + attemptId: Option[String], + ui: SparkUI, + completed: Boolean): Unit = { + logDebug(s"attachSparkUI($appId, $attemptId, $ui)") + attachCount += 1 + attached += (CacheKey(appId, attemptId) -> ui) + } + + def putAndAttach( + appId: String, + attemptId: Option[String], + completed: Boolean, + started: Long, + ended: Long, + timestamp: Long): SparkUI = { + val ui = putAppUI(appId, attemptId, completed, started, ended, timestamp) + attachSparkUI(appId, attemptId, ui, completed) + ui + } + + def putAppUI( + appId: String, + attemptId: Option[String], + completed: Boolean, + started: Long, + ended: Long, + timestamp: Long): SparkUI = { + val ui = newUI(appId, attemptId, completed, started, ended) + putInstance(appId, attemptId, ui, completed, timestamp) + ui + } + + def putInstance( + appId: String, + attemptId: Option[String], + ui: SparkUI, + completed: Boolean, + timestamp: Long): Unit = { + instances += (CacheKey(appId, attemptId) -> + new CacheEntry(ui, completed, updateProbe(appId, attemptId, timestamp), timestamp)) + } + + /** + * Detach a reconstructed UI + * + * @param ui Spark UI + */ + override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { + logDebug(s"detachSparkUI($appId, $attemptId, $ui)") + detachCount += 1 + var name = ui.getAppName + val key = CacheKey(appId, attemptId) + attached.getOrElse(key, { throw new java.util.NoSuchElementException() }) + attached -= key + } + + /** + * Lookup from the internal cache of attached UIs + */ + def getAttached(appId: String, attemptId: Option[String]): Option[SparkUI] = { + attached.get(CacheKey(appId, attemptId)) + } + + /** + * The update probe. + * @param appId application to probe + * @param attemptId attempt to probe + * @param updateTime timestamp of this UI load + */ + private[history] def updateProbe( + appId: String, + attemptId: Option[String], + updateTime: Long)(): Boolean = { + updateProbeCount += 1 + logDebug(s"isUpdated($appId, $attemptId, ${updateTime})") + val entry = instances.get(CacheKey(appId, attemptId)).get + val updated = entry.probeTime > updateTime + logDebug(s"entry = $entry; updated = $updated") + updated + } + } + + /** + * Create a new UI. The info/attempt info classes here are from the package + * `org.apache.spark.status.api.v1`, not the near-equivalents from the history package + */ + def newUI( + name: String, + attemptId: Option[String], + completed: Boolean, + started: Long, + ended: Long): SparkUI = { + val info = new ApplicationInfo(name, name, Some(1), Some(1), Some(1), Some(64), + Seq(new AttemptInfo(attemptId, new Date(started), new Date(ended), + new Date(ended), ended - started, "user", completed))) + val ui = mock[SparkUI] + when(ui.getApplicationInfoList).thenReturn(List(info).iterator) + when(ui.getAppName).thenReturn(name) + when(ui.appName).thenReturn(name) + val handler = new ServletContextHandler() + when(ui.getHandlers).thenReturn(Seq(handler)) + ui + } + + /** + * Test operations on completed UIs: they are loaded on demand, entries + * are removed on overload. + * + * This effectively tests the original behavior of the history server's cache. + */ + test("Completed UI get") { + val operations = new StubCacheOperations() + val clock = new ManualClock(1) + implicit val cache = new ApplicationCache(operations, 2, clock) + val metrics = cache.metrics + // cache misses + val app1 = "app-1" + assertNotFound(app1, None) + assertMetric("lookupCount", metrics.lookupCount, 1) + assertMetric("lookupFailureCount", metrics.lookupFailureCount, 1) + assert(1 === operations.getAppUICount, "getAppUICount") + assertNotFound(app1, None) + assert(2 === operations.getAppUICount, "getAppUICount") + assert(0 === operations.attachCount, "attachCount") + + val now = clock.getTimeMillis() + // add the entry + operations.putAppUI(app1, None, true, now, now, now) + + // make sure its local + operations.getAppUI(app1, None).get + operations.getAppUICount = 0 + // now expect it to be found + val cacheEntry = cache.lookupCacheEntry(app1, None) + assert(1 === cacheEntry.probeTime) + assert(cacheEntry.completed) + // assert about queries made of the opereations + assert(1 === operations.getAppUICount, "getAppUICount") + assert(1 === operations.attachCount, "attachCount") + + // and in the map of attached + assert(operations.getAttached(app1, None).isDefined, s"attached entry '1' from $cache") + + // go forward in time + clock.setTime(10) + val time2 = clock.getTimeMillis() + val cacheEntry2 = cache.get(app1) + // no more refresh as this is a completed app + assert(1 === operations.getAppUICount, "getAppUICount") + assert(0 === operations.updateProbeCount, "updateProbeCount") + assert(0 === operations.detachCount, "attachCount") + + // evict the entry + operations.putAndAttach("2", None, true, time2, time2, time2) + operations.putAndAttach("3", None, true, time2, time2, time2) + cache.get("2") + cache.get("3") + + // there should have been a detachment here + assert(1 === operations.detachCount, s"detach count from $cache") + // and entry app1 no longer attached + assert(operations.getAttached(app1, None).isEmpty, s"get($app1) in $cache") + val appId = "app1" + val attemptId = Some("_01") + val time3 = clock.getTimeMillis() + operations.putAppUI(appId, attemptId, false, time3, 0, time3) + // expect an error here + assertNotFound(appId, None) + } + + test("Test that if an attempt ID is is set, it must be used in lookups") { + val operations = new StubCacheOperations() + val clock = new ManualClock(1) + implicit val cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) + val appId = "app1" + val attemptId = Some("_01") + operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0, 0) + assertNotFound(appId, None) + } + + /** + * Test that incomplete apps are not probed for updates during the time window, + * but that they are checked if that window has expired and they are not completed. + * Then, if they have changed, the old entry is replaced by a new one. + */ + test("Incomplete apps refreshed") { + val operations = new StubCacheOperations() + val clock = new ManualClock(50) + val window = 500 + implicit val cache = new ApplicationCache(operations, retainedApplications = 5, clock = clock) + val metrics = cache.metrics + // add the incomplete app + // add the entry + val started = clock.getTimeMillis() + val appId = "app1" + val attemptId = Some("001") + operations.putAppUI(appId, attemptId, false, started, 0, started) + val firstEntry = cache.lookupCacheEntry(appId, attemptId) + assert(started === firstEntry.probeTime, s"timestamp in $firstEntry") + assert(!firstEntry.completed, s"entry is complete: $firstEntry") + assertMetric("lookupCount", metrics.lookupCount, 1) + + assert(0 === operations.updateProbeCount, "expected no update probe on that first get") + + val checkTime = window * 2 + clock.setTime(checkTime) + val entry3 = cache.lookupCacheEntry(appId, attemptId) + assert(firstEntry !== entry3, s"updated entry test from $cache") + assertMetric("lookupCount", metrics.lookupCount, 2) + assertMetric("updateProbeCount", metrics.updateProbeCount, 1) + assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 0) + assert(1 === operations.updateProbeCount, s"refresh count in $cache") + assert(0 === operations.detachCount, s"detach count") + assert(entry3.probeTime === checkTime) + + val updateTime = window * 3 + // update the cached value + val updatedApp = operations.putAppUI(appId, attemptId, true, started, updateTime, updateTime) + val endTime = window * 10 + clock.setTime(endTime) + logDebug(s"Before operation = $cache") + val entry5 = cache.lookupCacheEntry(appId, attemptId) + assertMetric("lookupCount", metrics.lookupCount, 3) + assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + // the update was triggered + assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 1) + assert(updatedApp === entry5.ui, s"UI {$updatedApp} did not match entry {$entry5} in $cache") + + // at which point, the refreshes stop + clock.setTime(window * 20) + assertCacheEntryEquals(appId, attemptId, entry5) + assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + } + + /** + * Assert that a metric counter has a specific value; failure raises an exception + * including the cache's toString value + * @param name counter name (for exceptions) + * @param counter counter + * @param expected expected value. + * @param cache cache + */ + def assertMetric( + name: String, + counter: Counter, + expected: Long) + (implicit cache: ApplicationCache): Unit = { + val actual = counter.getCount + if (actual != expected) { + // this is here because Scalatest loses stack depth + throw new Exception(s"Wrong $name value - expected $expected but got $actual in $cache") + } + } + + /** + * Look up the cache entry and assert that it maches in the expected value. + * This assertion works if the two CacheEntries are different -it looks at the fields. + * UI are compared on object equality; the timestamp and completed flags directly. + * @param appId application ID + * @param attemptId attempt ID + * @param expected expected value + * @param cache app cache + */ + def assertCacheEntryEquals( + appId: String, + attemptId: Option[String], + expected: CacheEntry) + (implicit cache: ApplicationCache): Unit = { + val actual = cache.lookupCacheEntry(appId, attemptId) + val errorText = s"Expected get($appId, $attemptId) -> $expected, but got $actual from $cache" + assert(expected.ui === actual.ui, errorText + " SparkUI reference") + assert(expected.completed === actual.completed, errorText + " -completed flag") + assert(expected.probeTime === actual.probeTime, errorText + " -timestamp") + } + + /** + * Assert that a key wasn't found in cache or loaded. + * + * Looks for the specific nested exception raised by [[ApplicationCache]] + * @param appId application ID + * @param attemptId attempt ID + * @param cache app cache + */ + def assertNotFound( + appId: String, + attemptId: Option[String]) + (implicit cache: ApplicationCache): Unit = { + val ex = intercept[UncheckedExecutionException] { + cache.get(appId, attemptId) + } + var cause = ex.getCause + assert(cause !== null) + if (!cause.isInstanceOf[NoSuchElementException]) { + throw cause + } + } + + test("Large Scale Application Eviction") { + val operations = new StubCacheOperations() + val clock = new ManualClock(0) + val size = 5 + // only two entries are retained, so we expect evictions to occurr on lookups + implicit val cache: ApplicationCache = new TestApplicationCache(operations, + retainedApplications = size, clock = clock) + + val attempt1 = Some("01") + + val ids = new ListBuffer[String]() + // build a list of applications + val count = 100 + for (i <- 1 to count ) { + val appId = f"app-$i%04d" + ids += appId + clock.advance(10) + val t = clock.getTimeMillis() + operations.putAppUI(appId, attempt1, true, t, t, t) + } + // now go through them in sequence reading them, expect evictions + ids.foreach { id => + cache.get(id, attempt1) + } + logInfo(cache.toString) + val metrics = cache.metrics + + assertMetric("loadCount", metrics.loadCount, count) + assertMetric("evictionCount", metrics.evictionCount, count - size) +} + + test("Attempts are Evicted") { + val operations = new StubCacheOperations() + implicit val cache: ApplicationCache = new TestApplicationCache(operations, + retainedApplications = 4) + val metrics = cache.metrics + val appId = "app1" + val attempt1 = Some("01") + val attempt2 = Some("02") + val attempt3 = Some("03") + operations.putAppUI(appId, attempt1, true, 100, 110, 110) + operations.putAppUI(appId, attempt2, true, 200, 210, 210) + operations.putAppUI(appId, attempt3, true, 300, 310, 310) + val attempt4 = Some("04") + operations.putAppUI(appId, attempt4, true, 400, 410, 410) + val attempt5 = Some("05") + operations.putAppUI(appId, attempt5, true, 500, 510, 510) + + def expectLoadAndEvictionCounts(expectedLoad: Int, expectedEvictionCount: Int): Unit = { + assertMetric("loadCount", metrics.loadCount, expectedLoad) + assertMetric("evictionCount", metrics.evictionCount, expectedEvictionCount) + } + + // first entry + cache.get(appId, attempt1) + expectLoadAndEvictionCounts(1, 0) + + // second + cache.get(appId, attempt2) + expectLoadAndEvictionCounts(2, 0) + + // no change + cache.get(appId, attempt2) + expectLoadAndEvictionCounts(2, 0) + + // eviction time + cache.get(appId, attempt3) + cache.size() should be(3) + cache.get(appId, attempt4) + expectLoadAndEvictionCounts(4, 0) + cache.get(appId, attempt5) + expectLoadAndEvictionCounts(5, 1) + cache.get(appId, attempt5) + expectLoadAndEvictionCounts(5, 1) + + } + + test("Instantiate Filter") { + // this is a regression test on the filter being constructable + val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) + val instance = clazz.newInstance() + instance shouldBe a [Filter] + } + + test("redirect includes query params") { + val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) + val filter = clazz.newInstance().asInstanceOf[ApplicationCacheCheckFilter] + filter.appId = "local-123" + val cache = mock[ApplicationCache] + when(cache.checkForUpdates(any(), any())).thenReturn(true) + ApplicationCacheCheckFilterRelay.setApplicationCache(cache) + val request = mock[HttpServletRequest] + when(request.getMethod()).thenReturn("GET") + when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/") + when(request.getQueryString()).thenReturn("id=2") + val resp = mock[HttpServletResponse] + when(resp.encodeRedirectURL(any())).thenAnswer(new Answer[String](){ + override def answer(invocationOnMock: InvocationOnMock): String = { + invocationOnMock.getArguments()(0).asInstanceOf[String] + } + }) + filter.doFilter(request, resp, null) + verify(resp).sendRedirect("http://localhost:18080/history/local-123/jobs/job/?id=2") + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 40d0076eecfc8..4b05469c42055 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -21,16 +21,28 @@ import java.net.{HttpURLConnection, URL} import java.util.zip.ZipInputStream import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.codahale.metrics.Counter import com.google.common.base.Charsets import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} -import org.mockito.Mockito.when +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods._ +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.concurrent.Eventually import org.scalatest.mock.MockitoSugar +import org.scalatest.selenium.WebBrowser -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.ui.{SparkUI, UIUtils} -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.UIData.JobUIData +import org.apache.spark.util.{ResetSystemProperties, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -44,7 +56,8 @@ import org.apache.spark.util.ResetSystemProperties * are considered part of Spark's public api. */ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar - with JsonTestUtils with ResetSystemProperties { + with JsonTestUtils with Eventually with WebBrowser with LocalSparkContext + with ResetSystemProperties { private val logDir = new File("src/test/resources/spark-events") private val expRoot = new File("src/test/resources/HistoryServerExpectations/") @@ -56,7 +69,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers def init(): Unit = { val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) - .set("spark.history.fs.updateInterval", "0") + .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -256,6 +269,204 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("incomplete apps get refreshed") { + + implicit val webDriver: WebDriver = new HtmlUnitDriver + implicit val formats = org.json4s.DefaultFormats + + // this test dir is explictly deleted on successful runs; retained for diagnostics when + // not + val logDir = Utils.createDirectory(System.getProperty("java.io.tmpdir", "logs")) + + // a new conf is used with the background thread set and running at its fastest + // alllowed refresh rate (1Hz) + val myConf = new SparkConf() + .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) + .set("spark.eventLog.dir", logDir.getAbsolutePath) + .set("spark.history.fs.update.interval", "1s") + .set("spark.eventLog.enabled", "true") + .set("spark.history.cache.window", "250ms") + .remove("spark.testing") + val provider = new FsHistoryProvider(myConf) + val securityManager = new SecurityManager(myConf) + + sc = new SparkContext("local", "test", myConf) + val logDirUri = logDir.toURI + val logDirPath = new Path(logDirUri) + val fs = FileSystem.get(logDirUri, sc.hadoopConfiguration) + + def listDir(dir: Path): Seq[FileStatus] = { + val statuses = fs.listStatus(dir) + statuses.flatMap( + stat => if (stat.isDirectory) listDir(stat.getPath) else Seq(stat)) + } + + def dumpLogDir(msg: String = ""): Unit = { + if (log.isDebugEnabled) { + logDebug(msg) + listDir(logDirPath).foreach { status => + val s = status.toString + logDebug(s) + } + } + } + + // stop the server with the old config, and start the new one + server.stop() + server = new HistoryServer(myConf, provider, securityManager, 18080) + server.initialize() + server.bind() + val port = server.boundPort + val metrics = server.cacheMetrics + + // assert that a metric has a value; if not dump the whole metrics instance + def assertMetric(name: String, counter: Counter, expected: Long): Unit = { + val actual = counter.getCount + if (actual != expected) { + // this is here because Scalatest loses stack depth + fail(s"Wrong $name value - expected $expected but got $actual" + + s" in metrics\n$metrics") + } + } + + // build a URL for an app or app/attempt plus a page underneath + def buildURL(appId: String, suffix: String): URL = { + new URL(s"http://localhost:$port/history/$appId$suffix") + } + + // build a rest URL for the application and suffix. + def applications(appId: String, suffix: String): URL = { + new URL(s"http://localhost:$port/api/v1/applications/$appId$suffix") + } + + val historyServerRoot = new URL(s"http://localhost:$port/") + + // start initial job + val d = sc.parallelize(1 to 10) + d.count() + val stdInterval = interval(100 milliseconds) + val appId = eventually(timeout(20 seconds), stdInterval) { + val json = getContentAndCode("applications", port)._2.get + val apps = parse(json).asInstanceOf[JArray].arr + apps should have size 1 + (apps.head \ "id").extract[String] + } + + val appIdRoot = buildURL(appId, "") + val rootAppPage = HistoryServerSuite.getUrl(appIdRoot) + logDebug(s"$appIdRoot ->[${rootAppPage.length}] \n$rootAppPage") + // sanity check to make sure filter is chaining calls + rootAppPage should not be empty + + def getAppUI: SparkUI = { + provider.getAppUI(appId, None).get.ui + } + + // selenium isn't that useful on failures...add our own reporting + def getNumJobs(suffix: String): Int = { + val target = buildURL(appId, suffix) + val targetBody = HistoryServerSuite.getUrl(target) + try { + go to target.toExternalForm + findAll(cssSelector("tbody tr")).toIndexedSeq.size + } catch { + case ex: Exception => + throw new Exception(s"Against $target\n$targetBody", ex) + } + } + // use REST API to get #of jobs + def getNumJobsRestful(): Int = { + val json = HistoryServerSuite.getUrl(applications(appId, "/jobs")) + val jsonAst = parse(json) + val jobList = jsonAst.asInstanceOf[JArray] + jobList.values.size + } + + // get a list of app Ids of all apps in a given state. REST API + def listApplications(completed: Boolean): Seq[String] = { + val json = parse(HistoryServerSuite.getUrl(applications("", ""))) + logDebug(s"${JsonMethods.pretty(json)}") + json match { + case JNothing => Seq() + case apps: JArray => + apps.filter(app => { + (app \ "attempts") match { + case attempts: JArray => + val state = (attempts.children.head \ "completed").asInstanceOf[JBool] + state.value == completed + case _ => false + } + }).map(app => (app \ "id").asInstanceOf[JString].values) + case _ => Seq() + } + } + + def completedJobs(): Seq[JobUIData] = { + getAppUI.jobProgressListener.completedJobs + } + + def activeJobs(): Seq[JobUIData] = { + getAppUI.jobProgressListener.activeJobs.values.toSeq + } + + activeJobs() should have size 0 + completedJobs() should have size 1 + getNumJobs("") should be (1) + getNumJobs("/jobs") should be (1) + getNumJobsRestful() should be (1) + assert(metrics.lookupCount.getCount > 1, s"lookup count too low in $metrics") + + // dump state before the next bit of test, which is where update + // checking really gets stressed + dumpLogDir("filesystem before executing second job") + logDebug(s"History Server: $server") + + val d2 = sc.parallelize(1 to 10) + d2.count() + dumpLogDir("After second job") + + val stdTimeout = timeout(10 seconds) + logDebug("waiting for UI to update") + eventually(stdTimeout, stdInterval) { + assert(2 === getNumJobs(""), + s"jobs not updated, server=$server\n dir = ${listDir(logDirPath)}") + assert(2 === getNumJobs("/jobs"), + s"job count under /jobs not updated, server=$server\n dir = ${listDir(logDirPath)}") + getNumJobsRestful() should be(2) + } + + d.count() + d.count() + eventually(stdTimeout, stdInterval) { + assert(4 === getNumJobsRestful(), s"two jobs back-to-back not updated, server=$server\n") + } + val jobcount = getNumJobs("/jobs") + assert(!provider.getListing().head.completed) + + listApplications(false) should contain(appId) + + // stop the spark context + resetSparkContext() + // check the app is now found as completed + eventually(stdTimeout, stdInterval) { + assert(provider.getListing().head.completed, + s"application never completed, server=$server\n") + } + + // app becomes observably complete + eventually(stdTimeout, stdInterval) { + listApplications(true) should contain (appId) + } + // app is no longer incomplete + listApplications(false) should not contain(appId) + + assert(jobcount === getNumJobs("/jobs")) + + // no need to retain the test dir now the tests complete + logDir.deleteOnExit(); + + } + def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path")) } @@ -275,6 +486,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers out.write(json) out.close() } + } object HistoryServerSuite { diff --git a/docs/monitoring.md b/docs/monitoring.md index cedceb2958023..c37f6fb20dd6e 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -38,11 +38,25 @@ You can start the history server by executing: ./sbin/start-history-server.sh -When using the file-system provider class (see spark.history.provider below), the base logging -directory must be supplied in the spark.history.fs.logDirectory configuration option, -and should contain sub-directories that each represents an application's event logs. This creates a -web interface at `http://:18080` by default. The history server can be configured as -follows: +This creates a web interface at `http://:18080` by default, listing incomplete +and completed applications and attempts, and allowing them to be viewed + +When using the file-system provider class (see `spark.history.provider` below), the base logging +directory must be supplied in the `spark.history.fs.logDirectory` configuration option, +and should contain sub-directories that each represents an application's event logs. + +The spark jobs themselves must be configured to log events, and to log them to the same shared, +writeable directory. For example, if the server was configured with a log directory of +`hdfs://namenode/shared/spark-logs`, then the client-side options would be: + +``` +spark.eventLog.enabled true +spark.eventLog.dir hdfs://namenode/shared/spark-logs +``` + +The history server can be configured as follows: + +### Environment Variables @@ -69,11 +83,13 @@ follows:
        Environment VariableMeaning
        +### Spark configuration options + - + @@ -82,15 +98,21 @@ follows: @@ -112,7 +134,7 @@ follows:
        Property NameDefaultMeaning
        spark.history.providerorg.apache.spark.deploy.history.FsHistoryProviderorg.apache.spark.deploy.history.FsHistoryProvider Name of the class implementing the application history backend. Currently there is only one implementation, provided by Spark, which looks for application logs stored in the file system.spark.history.fs.logDirectory file:/tmp/spark-events - Directory that contains application event logs to be loaded by the history server + For the filesystem history provider, the URL to the directory containing application event + logs to load. This can be a local file:// path, + an HDFS path hdfs://namenode/shared/spark-logs + or that of an alternative filesystem supported by the Hadoop APIs.
        spark.history.fs.update.interval 10s - The period at which information displayed by this history server is updated. - Each update checks for any changes made to the event logs in persisted storage. + The period at which the the filesystem history provider checks for new or + updated logs in the log directory. A shorter interval detects new applications faster, + at the expense of more server load re-reading updated applications. + As soon as an update has completed, listings of the completed and incomplete applications + will reflect the changes.
        spark.history.kerberos.enabled false - Indicates whether the history server should use kerberos to login. This is useful + Indicates whether the history server should use kerberos to login. This is required if the history server is accessing HDFS files on a secure Hadoop cluster. If this is true, it uses the configs spark.history.kerberos.principal and spark.history.kerberos.keytab. @@ -156,15 +178,15 @@ follows: spark.history.fs.cleaner.interval 1d - How often the job history cleaner checks for files to delete. - Files are only deleted if they are older than spark.history.fs.cleaner.maxAge. + How often the filesystem job history cleaner checks for files to delete. + Files are only deleted if they are older than spark.history.fs.cleaner.maxAge
        spark.history.fs.cleaner.maxAge 7d - Job history files older than this will be deleted when the history cleaner runs. + Job history files older than this will be deleted when the filesystem history cleaner runs.
        @@ -172,7 +194,25 @@ follows: Note that in all of these UIs, the tables are sortable by clicking their headers, making it easy to identify slow tasks, data skew, etc. -Note that the history server only displays completed Spark jobs. One way to signal the completion of a Spark job is to stop the Spark Context explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` to handle the Spark Context setup and tear down, and still show the job history on the UI. +Note + +1. The history server displays both completed and incomplete Spark jobs. If an application makes +multiple attempts after failures, the failed attempts will be displayed, as well as any ongoing +incomplete attempt or the final successful attempt. + +2. Incomplete applications are only updated intermittently. The time between updates is defined +by the interval between checks for changed files (`spark.history.fs.update.interval`). +On larger clusters the update interval may be set to large values. +The way to view a running application is actually to view its own web UI. + +3. Applications which exited without registering themselves as completed will be listed +as incomplete —even though they are no longer running. This can happen if an application +crashes. + +2. One way to signal the completion of a Spark job is to stop the Spark Context +explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` construct +to handle the Spark Context setup and tear down. + ## REST API @@ -249,7 +289,7 @@ These endpoints have been strongly versioned to make it easier to develop applic * New endpoints may be added * New fields may be added to existing endpoints * New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. -* Api versions may be dropped, but only after at least one minor release of co-existing with a new api version +* Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 133894704b6cc..8611106db0cf0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -233,6 +233,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") + ) ++ Seq( + // SPARK-7889 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI") ) case v if v.startsWith("1.6") => Seq( From 894921d813a259f2f266fde7d86d2ecb5a0af24b Mon Sep 17 00:00:00 2001 From: Sanket Date: Thu, 11 Feb 2016 22:40:00 -0800 Subject: [PATCH 1887/2089] [SPARK-6166] Limit number of in flight outbound requests This JIRA is related to https://github.com/apache/spark/pull/5852 Had to do some minor rework and test to make sure it works with current version of spark. Author: Sanket Closes #10838 from redsanket/limit-outbound-connections. --- .../org/apache/spark/SecurityManager.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../storage/ShuffleBlockFetcherIterator.scala | 40 ++++++++++++++----- .../ShuffleBlockFetcherIteratorSuite.scala | 9 +++-- docs/configuration.md | 10 +++++ 5 files changed, 49 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0675957e16680..6132fa349ef2d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -69,7 +69,7 @@ import org.apache.spark.util.Utils * * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spengo, etc. It also supports multiple different login + * Basic, Digest, Form, Spnego, etc. It also supports multiple different login * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService * to authenticate using DIGEST-MD5 via a single user and the shared secret. * Since we are using DIGEST-MD5, the shared secret is not passed on the wire diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index acbe16001f5ba..dc182f596335b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -46,7 +46,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) // Wrap the streams for compression based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c368a39e629f0..478a928acd03c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -47,6 +47,7 @@ import org.apache.spark.util.Utils * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -54,7 +55,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - maxBytesInFlight: Long) + maxBytesInFlight: Long, + maxReqsInFlight: Int) extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -102,6 +104,9 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + private[this] val shuffleMetrics = context.taskMetrics().registerTempShuffleReadMetrics() /** @@ -118,7 +123,7 @@ final class ShuffleBlockFetcherIterator( private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf, _) => buf.release() case _ => } currentResult = null @@ -137,7 +142,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf) => { + case SuccessFetchResult(_, address, _, buf, _) => { if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) @@ -153,9 +158,11 @@ final class ShuffleBlockFetcherIterator( logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) bytesInFlight += req.size + reqsInFlight += 1 // so we can look up the size of each blockID val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap + val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) val address = req.address @@ -169,7 +176,10 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) + remainingBlocks -= blockId + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) } } logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) @@ -249,7 +259,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -268,6 +278,9 @@ final class ShuffleBlockFetcherIterator( val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) + assert ((0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight) // Send out initial requests for blocks, up to our maxBytesInFlight fetchUpToMaxBytes() @@ -299,12 +312,16 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, address, size, buf) => { + case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => { if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } bytesInFlight -= size + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } } case _ => } @@ -315,7 +332,7 @@ final class ShuffleBlockFetcherIterator( case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) - case SuccessFetchResult(blockId, address, _, buf) => + case SuccessFetchResult(blockId, address, _, buf, _) => try { (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) } catch { @@ -328,7 +345,9 @@ final class ShuffleBlockFetcherIterator( private def fetchUpToMaxBytes(): Unit = { // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) { sendRequest(fetchRequests.dequeue()) } } @@ -406,13 +425,14 @@ object ShuffleBlockFetcherIterator { * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. + * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ private[storage] case class SuccessFetchResult( blockId: BlockId, address: BlockManagerId, size: Long, - buf: ManagedBuffer) - extends FetchResult { + buf: ManagedBuffer, + isNetworkReqDone: Boolean) extends FetchResult { require(buf != null) require(size >= 0) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c9c2fb2691d70..e3ec99685f73c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -99,7 +99,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - 48 * 1024 * 1024) + 48 * 1024 * 1024, + Int.MaxValue) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -171,7 +172,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - 48 * 1024 * 1024) + 48 * 1024 * 1024, + Int.MaxValue) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -233,7 +235,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - 48 * 1024 * 1024) + 48 * 1024 * 1024, + Int.MaxValue) // Continue only after the mock calls onBlockFetchFailure sem.acquire() diff --git a/docs/configuration.md b/docs/configuration.md index dd2cde81941db..0dbfe3b0796ba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -391,6 +391,16 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. + + spark.reducer.maxReqsInFlight + Int.MaxValue + + This configuration limits the number of remote requests to fetch blocks at any given point. + When the number of hosts in the cluster increase, it might lead to very large number + of in-bound connections to one or more nodes, causing the workers to fail under load. + By allowing it to limit the number of fetch requests, this scenario can be mitigated. + + spark.shuffle.compress true From a183dda6ab597e5b7ead58bbaa696f836b16e524 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 12 Feb 2016 01:45:45 -0800 Subject: [PATCH 1888/2089] [SPARK-12974][ML][PYSPARK] Add Python API for spark.ml bisecting k-means Add Python API for spark.ml bisecting k-means. Author: Yanbo Liang Closes #10889 from yanboliang/spark-12974. --- python/pyspark/ml/clustering.py | 125 +++++++++++++++++++++++++++++++- 1 file changed, 124 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index f156eda125bf2..91278d570a642 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -21,7 +21,7 @@ from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -__all__ = ['KMeans', 'KMeansModel'] +__all__ = ['KMeans', 'KMeansModel', 'BisectingKMeans', 'BisectingKMeansModel'] class KMeansModel(JavaModel, MLWritable, MLReadable): @@ -175,6 +175,129 @@ def getInitSteps(self): return self.getOrDefault(self.initSteps) +class BisectingKMeansModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by BisectingKMeans. + + .. versionadded:: 2.0.0 + """ + + @since("2.0.0") + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy arrays.""" + return [c.toArray() for c in self._call_java("clusterCenters")] + + @since("2.0.0") + def computeCost(self, dataset): + """ + Computes the sum of squared distances between the input points + and their corresponding cluster centers. + """ + return self._call_java("computeCost", dataset) + + +@inherit_doc +class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed): + """ + .. note:: Experimental + + A bisecting k-means algorithm based on the paper "A comparison of document clustering + techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. + The algorithm starts from a single cluster that contains all points. + Iteratively it finds divisible clusters on the bottom level and bisects each of them using + k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + The bisecting steps of clusters on the same level are grouped together to increase parallelism. + If bisecting all divisible clusters on the bottom level would result more than `k` leaf + clusters, larger clusters get higher priority. + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0) + >>> model = bkm.fit(df) + >>> centers = model.clusterCenters() + >>> len(centers) + 2 + >>> model.computeCost(df) + 2.000... + >>> transformed = model.transform(df).select("features", "prediction") + >>> rows = transformed.collect() + >>> rows[0].prediction == rows[1].prediction + True + >>> rows[2].prediction == rows[3].prediction + True + + .. versionadded:: 2.0.0 + """ + + k = Param(Params._dummy(), "k", "number of clusters to create") + minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize", + "the minimum number of points (if >= 1.0) " + + "or the minimum proportion") + + @keyword_only + def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, + seed=None, k=4, minDivisibleClusterSize=1.0): + """ + __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ + seed=None, k=4, minDivisibleClusterSize=1.0) + """ + super(BisectingKMeans, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", + self.uid) + self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, + seed=None, k=4, minDivisibleClusterSize=1.0): + """ + setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ + seed=None, k=4, minDivisibleClusterSize=1.0) + Sets params for BisectingKMeans. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + self._paramMap[self.k] = value + return self + + @since("2.0.0") + def getK(self): + """ + Gets the value of `k` or its default value. + """ + return self.getOrDefault(self.k) + + @since("2.0.0") + def setMinDivisibleClusterSize(self, value): + """ + Sets the value of :py:attr:`minDivisibleClusterSize`. + """ + self._paramMap[self.minDivisibleClusterSize] = value + return self + + @since("2.0.0") + def getMinDivisibleClusterSize(self): + """ + Gets the value of `minDivisibleClusterSize` or its default value. + """ + return self.getOrDefault(self.minDivisibleClusterSize) + + def _create_model(self, java_model): + return BisectingKMeansModel(java_model) + + if __name__ == "__main__": import doctest import pyspark.ml.clustering From 64515e5fbfd694d06fdbc28040fce7baf90a32aa Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Feb 2016 02:13:06 -0800 Subject: [PATCH 1889/2089] [SPARK-13154][PYTHON] Add linting for pydocs We should have lint rules using sphinx to automatically catch the pydoc issues that are sometimes introduced. Right now ./dev/lint-python will skip building the docs if sphinx isn't present - but it might make sense to fail hard - just a matter of if we want to insist all PySpark developers have sphinx present. Author: Holden Karau Closes #11109 from holdenk/SPARK-13154-add-pydoc-lint-for-docs. --- dev/lint-python | 24 ++++++++++++++++++++++++ python/docs/conf.py | 3 +++ 2 files changed, 27 insertions(+) diff --git a/dev/lint-python b/dev/lint-python index 1765a07d2f22b..068337d273f82 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -24,6 +24,8 @@ PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/r PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" +SPHINXBUILD=${SPHINXBUILD:=sphinx-build} +SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" cd "$SPARK_ROOT_DIR" @@ -96,6 +98,28 @@ fi rm "$PEP8_REPORT_PATH" +# Check that the documentation builds acceptably, skip check if sphinx is not installed. +if hash "$SPHINXBUILD" 2> /dev/null; then + cd python/docs + make clean + # Treat warnings as errors so we stop correctly + SPHINXOPTS="-a -W" make html &> "$SPHINX_REPORT_PATH" || lint_status=1 + if [ "$lint_status" -ne 0 ]; then + echo "pydoc checks failed." + cat "$SPHINX_REPORT_PATH" + echo "re-running make html to print full warning list" + make clean + SPHINXOPTS="-a" make html + else + echo "pydoc checks passed." + fi + rm "$SPHINX_REPORT_PATH" + cd ../.. +else + echo >&2 "The $SPHINXBUILD command was not found. Skipping pydoc checks for now" +fi + + # for to_be_checked in "$PATHS_TO_CHECK" # do # pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" diff --git a/python/docs/conf.py b/python/docs/conf.py index 365d6af514177..d35bf73c30510 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -334,3 +334,6 @@ # If false, no index is generated. #epub_use_index = True + +# Skip sample endpoint link (not expected to resolve) +linkcheck_ignore = [r'https://kinesis.us-east-1.amazonaws.com'] From 5b805df279d744543851f06e5a0d741354ef485b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 12 Feb 2016 09:34:18 -0800 Subject: [PATCH 1890/2089] [SPARK-12705] [SQL] push missing attributes for Sort The current implementation of ResolveSortReferences can only push one missing attributes into it's child, it failed to analyze TPCDS Q98, because of there are two missing attributes in that (one from Window, another from Aggregate). Author: Davies Liu Closes #11153 from davies/resolve_sort. --- .../sql/catalyst/analysis/Analyzer.scala | 133 +++++++----------- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 15 ++ 3 files changed, 67 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c0fa79612a007..26c3d286b19fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -598,98 +597,69 @@ class Analyzer( // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(_, _, child) if !s.resolved && child.resolved => - val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child) - - if (missingResolvableAttrs.isEmpty) { - val unresolvableAttrs = s.order.filterNot(_.resolved) - logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") - s // Nothing we can do here. Return original plan. - } else { - // Add the missing attributes into projectList of Project/Window or - // aggregateExpressions of Aggregate, if they are in the inputSet - // but not in the outputSet of the plan. - val newChild = child transformUp { - case p: Project => - p.copy(projectList = p.projectList ++ - missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains)) - case w: Window => - w.copy(projectList = w.projectList ++ - missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains)) - case a: Aggregate => - val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains) - val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains) - val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs - a.copy(aggregateExpressions = newAggregateExpressions) - case o => o - } - + case s @ Sort(order, _, child) if !s.resolved && child.resolved => + val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) + val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(child.output, - Sort(newOrdering, s.global, newChild)) + Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) + } else if (newOrder != order) { + s.copy(order = newOrder) + } else { + s } } /** - * Traverse the tree until resolving the sorting attributes - * Return all the resolvable missing sorting attributes - */ - @tailrec - private def collectResolvableMissingAttrs( - ordering: Seq[SortOrder], - plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + * Add the missing attributes into projectList of Project/Window or aggregateExpressions of + * Aggregate. + */ + private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { + if (missingAttrs.isEmpty) { + return plan + } plan match { - // Only Windows and Project have projectList-like attribute. - case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] => - val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) - // If missingAttrs is non empty, that means we got it and return it; - // Otherwise, continue to traverse the tree. - if (missingAttrs.nonEmpty) { - (newOrdering, missingAttrs) - } else { - collectResolvableMissingAttrs(ordering, un.child) - } + case p: Project => + val missing = missingAttrs -- p.child.outputSet + Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) + case w: Window => + val missing = missingAttrs -- w.child.outputSet + w.copy(projectList = w.projectList ++ missingAttrs, + child = addMissingAttr(w.child, missing)) case a: Aggregate => - val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child) - // For Aggregate, all the order by columns must be specified in group by clauses - if (missingAttrs.nonEmpty && - missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) { - (newOrdering, missingAttrs) - } else { - // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes - (Seq.empty[SortOrder], Seq.empty[Attribute]) + // all the missing attributes should be grouping expressions + // TODO: push down AggregateExpression + missingAttrs.foreach { attr => + if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { + throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") + } } - // Jump over the following UnaryNode types - // The output of these types is the same as their child's output - case _: Distinct | - _: Filter | - _: RepartitionByExpression => - collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child) - // If hitting the other unsupported operators, we are unable to resolve it. - case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) + val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs + a.copy(aggregateExpressions = newAggregateExpressions) + case u: UnaryNode => + u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) + case other => + throw new AnalysisException(s"Can't add $missingAttrs to $other") } } /** - * Try to resolve the sort ordering and returns it with a list of attributes that are missing - * from the plan but are present in the child. - */ - private def resolveAndFindMissing( - ordering: Seq[SortOrder], - plan: LogicalPlan, - child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = - ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) - // Construct a set that contains all of the attributes that we need to evaluate the - // ordering. - val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) - // Figure out which ones are missing from the projection, so that we can add them and - // remove them after the sort. - val missingInProject = requiredAttributes -- plan.outputSet - // It is important to return the new SortOrders here, instead of waiting for the standard - // resolving process as adding attributes to the project below can actually introduce - // ambiguity that was not present before. - (newOrdering, missingInProject.toSeq) + * Resolve the expression on a specified logical plan and it's child (recursively), until + * the expression is resolved or meet a non-unary node or Subquery. + */ + private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { + val resolved = resolveExpression(expr, plan) + if (resolved.resolved) { + resolved + } else { + plan match { + case u: UnaryNode if !u.isInstanceOf[Subquery] => + resolveExpressionRecursively(resolved, u.child) + case other => resolved + } + } } } @@ -782,8 +752,7 @@ class Analyzer( filter } - case sort @ Sort(sortOrder, global, aggregate: Aggregate) - if aggregate.resolved => + case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ebf885a8fe484..f85ae24e0459b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -90,7 +90,7 @@ class AnalysisSuite extends AnalysisTest { .where(a > "str").select(a, b, c) .where(b > "str").select(a, b, c) .sortBy(b.asc, c.desc) - .select(a, b).select(a) + .select(a) checkAnalysis(plan1, expected1) // Case 2: all the missing attributes are in the leaf node diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6048b8f5a3998..be864f79d6b7e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -978,6 +978,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ("d", 1), ("c", 2) ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, sum(product) / sum(sum(product)) over (partition by area) as c1 + |from windowData group by area, month order by month, c1 + """.stripMargin), + Seq( + ("d", 1.0), + ("a", 1.0), + ("b", 0.4666666666666667), + ("b", 0.5333333333333333), + ("c", 0.45), + ("c", 0.55) + ).map(i => Row(i._1, i._2))) } // todo: fix this test case by reimplementing the function ResolveAggregateFunctions From c4d5ad80c8091c961646a82e85ecbc335b8ffe2d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Feb 2016 10:08:19 -0800 Subject: [PATCH 1891/2089] [SPARK-13282][SQL] LogicalPlan toSql should just return a String Previously we were using Option[String] and None to indicate the case when Spark fails to generate SQL. It is easier to just use exceptions to propagate error cases, rather than having for comprehension everywhere. I also introduced a "build" function that simplifies string concatenation (i.e. no need to reason about whether we have an extra space or not). Author: Reynold Xin Closes #11171 from rxin/SPARK-13282. --- .../apache/spark/sql/hive/SQLBuilder.scala | 224 ++++++++---------- .../hive/execution/CreateViewAsSelect.scala | 8 +- .../sql/hive/ExpressionSQLBuilderSuite.scala | 3 +- .../sql/hive/LogicalPlanToSQLSuite.scala | 29 +-- .../spark/sql/hive/SQLBuilderTest.scala | 12 +- .../hive/execution/HiveComparisonTest.scala | 21 +- 6 files changed, 141 insertions(+), 156 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 4b75e60f8d5f3..d7bae913f8720 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong +import scala.util.control.NonFatal + import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -37,16 +39,10 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) - def toSQL: Option[String] = { + def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) - val maybeSQL = try { - toSQL(canonicalizedPlan) - } catch { case cause: UnsupportedOperationException => - logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") - None - } - - if (maybeSQL.isDefined) { + try { + val generatedSQL = toSQL(canonicalizedPlan) logDebug( s"""Built SQL query string successfully from given logical plan: | @@ -54,10 +50,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi |${logicalPlan.treeString} |# Canonicalized logical plan: |${canonicalizedPlan.treeString} - |# Built SQL query string: - |${maybeSQL.get} + |# Generated SQL: + |$generatedSQL """.stripMargin) - } else { + generatedSQL + } catch { case NonFatal(e) => logDebug( s"""Failed to build SQL query string from given logical plan: | @@ -66,128 +63,113 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi |# Canonicalized logical plan: |${canonicalizedPlan.treeString} """.stripMargin) + throw e } - - maybeSQL } - private def projectToSQL( - projectList: Seq[NamedExpression], - child: LogicalPlan, - isDistinct: Boolean): Option[String] = { - for { - childSQL <- toSQL(child) - listSQL = projectList.map(_.sql).mkString(", ") - maybeFrom = child match { - case OneRowRelation => " " - case _ => " FROM " - } - distinct = if (isDistinct) " DISTINCT " else " " - } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL" - } + private def toSQL(node: LogicalPlan): String = node match { + case Distinct(p: Project) => + projectToSQL(p, isDistinct = true) - private def aggregateToSQL( - groupingExprs: Seq[Expression], - aggExprs: Seq[Expression], - child: LogicalPlan): Option[String] = { - val aggSQL = aggExprs.map(_.sql).mkString(", ") - val groupingSQL = groupingExprs.map(_.sql).mkString(", ") - val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " - val maybeFrom = child match { - case OneRowRelation => " " - case _ => " FROM " - } + case p: Project => + projectToSQL(p, isDistinct = false) - toSQL(child).map { childSQL => - s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" - } - } + case p: Aggregate => + aggregateToSQL(p) - private def toSQL(node: LogicalPlan): Option[String] = node match { - case Distinct(Project(list, child)) => - projectToSQL(list, child, isDistinct = true) - - case Project(list, child) => - projectToSQL(list, child, isDistinct = false) - - case Aggregate(groupingExprs, aggExprs, child) => - aggregateToSQL(groupingExprs, aggExprs, child) - - case Limit(limit, child) => - for { - childSQL <- toSQL(child) - limitSQL = limit.sql - } yield s"$childSQL LIMIT $limitSQL" - - case Filter(condition, child) => - for { - childSQL <- toSQL(child) - whereOrHaving = child match { - case _: Aggregate => "HAVING" - case _ => "WHERE" - } - conditionSQL = condition.sql - } yield s"$childSQL $whereOrHaving $conditionSQL" - - case Union(children) if children.length > 1 => - val childrenSql = children.map(toSQL(_)) - if (childrenSql.exists(_.isEmpty)) { - None - } else { - Some(childrenSql.map(_.get).mkString(" UNION ALL ")) + case p: Limit => + s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}" + + case p: Filter => + val whereOrHaving = p.child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + build(toSQL(p.child), whereOrHaving, p.condition.sql) + + case p: Union if p.children.length > 1 => + val childrenSql = p.children.map(toSQL(_)) + childrenSql.mkString(" UNION ALL ") + + case p: Subquery => + p.child match { + // Persisted data source relation + case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => + s"`$database`.`$table`" + // Parentheses is not used for persisted data source relations + // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 + case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => + build(toSQL(p.child), "AS", p.alias) + case _ => + build("(" + toSQL(p.child) + ")", "AS", p.alias) } - // Persisted data source relation - case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) => - Some(s"`$database`.`$table`") - - case Subquery(alias, child) => - toSQL(child).map( childSQL => - child match { - // Parentheses is not used for persisted data source relations - // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 - case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => - s"$childSQL AS $alias" - case _ => - s"($childSQL) AS $alias" - }) - - case Join(left, right, joinType, condition) => - for { - leftSQL <- toSQL(left) - rightSQL <- toSQL(right) - joinTypeSQL = joinType.sql - conditionSQL = condition.map(" ON " + _.sql).getOrElse("") - } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" - - case MetastoreRelation(database, table, alias) => - val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") - Some(s"`$database`.`$table`$aliasSQL") + case p: Join => + build( + toSQL(p.left), + p.joinType.sql, + "JOIN", + toSQL(p.right), + p.condition.map(" ON " + _.sql).getOrElse("")) + + case p: MetastoreRelation => + build( + s"`${p.databaseName}`.`${p.tableName}`", + p.alias.map(a => s" AS `$a`").getOrElse("") + ) case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) if orders.map(_.child) == partitionExprs => - for { - childSQL <- toSQL(child) - partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") - } yield s"$childSQL CLUSTER BY $partitionExprsSQL" - - case Sort(orders, global, child) => - for { - childSQL <- toSQL(child) - ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") - orderOrSort = if (global) "ORDER" else "SORT" - } yield s"$childSQL $orderOrSort BY $ordersSQL" - - case RepartitionByExpression(partitionExprs, child, _) => - for { - childSQL <- toSQL(child) - partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") - } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" + build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", ")) + + case p: Sort => + build( + toSQL(p.child), + if (p.global) "ORDER BY" else "SORT BY", + p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + ) + + case p: RepartitionByExpression => + build( + toSQL(p.child), + "DISTRIBUTE BY", + p.partitionExpressions.map(_.sql).mkString(", ") + ) case OneRowRelation => - Some("") + "" - case _ => None + case _ => + throw new UnsupportedOperationException(s"unsupported plan $node") + } + + /** + * Turns a bunch of string segments into a single string and separate each segment by a space. + * The segments are trimmed so only a single space appears in the separation. + * For example, `build("a", " b ", " c")` becomes "a b c". + */ + private def build(segments: String*): String = segments.map(_.trim).mkString(" ") + + private def projectToSQL(plan: Project, isDistinct: Boolean): String = { + build( + "SELECT", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child) + ) + } + + private def aggregateToSQL(plan: Aggregate): String = { + val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ") + build( + "SELECT", + plan.aggregateExpressions.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child), + if (groupingSQL.isEmpty) "" else "GROUP BY", + groupingSQL + ) } object Canonicalizer extends RuleExecutor[LogicalPlan] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 31bda56e8a163..5da58a73e1e36 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import scala.util.control.NonFatal + import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Alias @@ -72,7 +74,9 @@ private[hive] case class CreateViewAsSelect( private def prepareTable(sqlContext: SQLContext): HiveTable = { val expandedText = if (sqlContext.conf.canonicalView) { - rebuildViewQueryString(sqlContext).getOrElse(wrapViewTextWithSelect) + try rebuildViewQueryString(sqlContext) catch { + case NonFatal(e) => wrapViewTextWithSelect + } } else { wrapViewTextWithSelect } @@ -112,7 +116,7 @@ private[hive] case class CreateViewAsSelect( s"SELECT $viewOutput FROM ($viewText) $viewName" } - private def rebuildViewQueryString(sqlContext: SQLContext): Option[String] = { + private def rebuildViewQueryString(sqlContext: SQLContext): String = { val logicalPlan = if (tableDesc.schema.isEmpty) { child } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala index 3a6eb57add4e3..3fb6543b1a1d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -33,8 +33,7 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") checkSQL(Literal(2.5D), "2.5") checkSQL( - Literal(Timestamp.valueOf("2016-01-01 00:00:00")), - "TIMESTAMP('2016-01-01 00:00:00.0')") + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 80ae312d913de..dc8ac7e47ffec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.util.control.NonFatal + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils @@ -46,29 +48,28 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { private def checkHiveQl(hiveQl: String): Unit = { val df = sql(hiveQl) - val convertedSQL = new SQLBuilder(df).toSQL - if (convertedSQL.isEmpty) { - fail( - s"""Cannot convert the following HiveQL query plan back to SQL query string: - | - |# Original HiveQL query string: - |$hiveQl - | - |# Resolved query plan: - |${df.queryExecution.analyzed.treeString} - """.stripMargin) + val convertedSQL = try new SQLBuilder(df).toSQL catch { + case NonFatal(e) => + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) } - val sqlString = convertedSQL.get try { - checkAnswer(sql(sqlString), df) + checkAnswer(sql(convertedSQL), df) } catch { case cause: Throwable => fail( s"""Failed to execute converted SQL string or got wrong answer: | |# Converted SQL query string: - |$sqlString + |$convertedSQL | |# Original HiveQL query string: |$hiveQl diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala index a5e209ac9db3b..4adc5c11160ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.util.control.NonFatal + import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -40,9 +42,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { } protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { - val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL - - if (maybeSQL.isEmpty) { + val generatedSQL = try new SQLBuilder(plan, hiveContext).toSQL catch { case NonFatal(e) => fail( s"""Cannot convert the following logical query plan to SQL: | @@ -50,10 +50,8 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - val actualSQL = maybeSQL.get - try { - assert(actualSQL === expectedSQL) + assert(generatedSQL === expectedSQL) } catch { case cause: Throwable => fail( @@ -65,7 +63,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + checkAnswer(sqlContext.sql(generatedSQL), new DataFrame(sqlContext, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 207bb814f0a27..af4c44e578c84 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -412,21 +412,22 @@ abstract class HiveComparisonTest originalQuery } else { numTotalQueries += 1 - new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + try { + val sql = new SQLBuilder(originalQuery.analyzed, TestHive).toSQL numConvertibleQueries += 1 logInfo( s""" - |### Running SQL generation round-trip test {{{ - |${originalQuery.analyzed.treeString} - |Original SQL: - |$queryString - | - |Generated SQL: - |$sql - |}}} + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} """.stripMargin.trim) new TestHive.QueryExecution(sql) - }.getOrElse { + } catch { case NonFatal(e) => logInfo( s""" |### Cannot convert the following logical plan back to SQL {{{ From ac7d6af1cafc6b159d1df6cf349bb0c7ffca01cd Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 12 Feb 2016 11:54:58 -0800 Subject: [PATCH 1892/2089] [SPARK-13260][SQL] count(*) does not work with CSV data source https://issues.apache.org/jira/browse/SPARK-13260 This is a quicky fix for `count(*)`. When the `requiredColumns` is empty, currently it returns `sqlContext.sparkContext.emptyRDD[Row]` which does not have the count. Just like JSON datasource, this PR lets the CSV datasource count the rows but do not parse each set of tokens. Author: hyukjinkwon Closes #11169 from HyukjinKwon/SPARK-13260. --- .../datasources/csv/CSVRelation.scala | 84 +++++++++---------- .../execution/datasources/csv/CSVSuite.scala | 2 +- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index dc449fea956f8..f8e3a1b6d46ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -197,52 +197,48 @@ object CSVRelation extends Logging { } else { requiredFields } - if (requiredColumns.isEmpty) { - sqlContext.sparkContext.emptyRDD[Row] - } else { - val safeRequiredIndices = new Array[Int](safeRequiredFields.length) - schemaFields.zipWithIndex.filter { - case (field, _) => safeRequiredFields.contains(field) - }.foreach { - case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index - } - val rowArray = new Array[Any](safeRequiredIndices.length) - val requiredSize = requiredFields.length - tokenizedRDD.flatMap { tokens => - if (params.dropMalformed && schemaFields.length != tokens.size) { - logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") - None - } else if (params.failFast && schemaFields.length != tokens.size) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(params.delimiter.toString)}") + val safeRequiredIndices = new Array[Int](safeRequiredFields.length) + schemaFields.zipWithIndex.filter { + case (field, _) => safeRequiredFields.contains(field) + }.foreach { + case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index + } + val rowArray = new Array[Any](safeRequiredIndices.length) + val requiredSize = requiredFields.length + tokenizedRDD.flatMap { tokens => + if (params.dropMalformed && schemaFields.length != tokens.size) { + logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } else if (params.failFast && schemaFields.length != tokens.size) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: " + + s"${tokens.mkString(params.delimiter.toString)}") + } else { + val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) { + tokens ++ new Array[String](schemaFields.length - tokens.size) + } else if (params.permissive && schemaFields.length < tokens.size) { + tokens.take(schemaFields.length) } else { - val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.size) { - tokens ++ new Array[String](schemaFields.length - tokens.size) - } else if (params.permissive && schemaFields.length < tokens.size) { - tokens.take(schemaFields.length) - } else { - tokens - } - try { - var index: Int = 0 - var subIndex: Int = 0 - while (subIndex < safeRequiredIndices.length) { - index = safeRequiredIndices(subIndex) - val field = schemaFields(index) - rowArray(subIndex) = CSVTypeCast.castTo( - indexSafeTokens(index), - field.dataType, - field.nullable, - params.nullValue) - subIndex = subIndex + 1 - } - Some(Row.fromSeq(rowArray.take(requiredSize))) - } catch { - case NonFatal(e) if params.dropMalformed => - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") - None + tokens + } + try { + var index: Int = 0 + var subIndex: Int = 0 + while (subIndex < safeRequiredIndices.length) { + index = safeRequiredIndices(subIndex) + val field = schemaFields(index) + rowArray(subIndex) = CSVTypeCast.castTo( + indexSafeTokens(index), + field.dataType, + field.nullable, + params.nullValue) + subIndex = subIndex + 1 } + Some(Row.fromSeq(rowArray.take(requiredSize))) + } catch { + case NonFatal(e) if params.dropMalformed => + logWarning("Parse exception. " + + s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index fa4f137b703b4..9d1f4569ad5e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -56,7 +56,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val numRows = if (withHeader) numCars else numCars + 1 // schema assert(df.schema.fieldNames.length === numColumns) - assert(df.collect().length === numRows) + assert(df.count === numRows) if (checkHeader) { if (withHeader) { From 90de6b2fae71d05415610be70300625c409f6092 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 12 Feb 2016 12:43:13 -0800 Subject: [PATCH 1893/2089] [SPARK-12962] [SQL] [PySpark] PySpark support covar_samp and covar_pop PySpark support ```covar_samp``` and ```covar_pop```. cc rxin davies marmbrus Author: Yanbo Liang Closes #10876 from yanboliang/spark-12962. --- python/pyspark/sql/functions.py | 41 ++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 680493e0e689e..416d722bbac3c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -250,17 +250,46 @@ def corr(col1, col2): """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. - >>> a = [x * x - 2 * x + 3.5 for x in range(20)] - >>> b = range(20) - >>> corrDf = sqlContext.createDataFrame(zip(a, b)) - >>> corrDf = corrDf.agg(corr(corrDf._1, corrDf._2).alias('c')) - >>> corrDf.selectExpr('abs(c - 0.9572339139475857) < 1e-16 as t').collect() - [Row(t=True)] + >>> a = range(20) + >>> b = [2 * x for x in range(20)] + >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df.agg(corr("a", "b").alias('c')).collect() + [Row(c=1.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2))) +@since(2.0) +def covar_pop(col1, col2): + """Returns a new :class:`Column` for the population covariance of ``col1`` + and ``col2``. + + >>> a = [1] * 10 + >>> b = [1] * 10 + >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df.agg(covar_pop("a", "b").alias('c')).collect() + [Row(c=0.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.covar_pop(_to_java_column(col1), _to_java_column(col2))) + + +@since(2.0) +def covar_samp(col1, col2): + """Returns a new :class:`Column` for the sample covariance of ``col1`` + and ``col2``. + + >>> a = [1] * 10 + >>> b = [1] * 10 + >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df.agg(covar_samp("a", "b").alias('c')).collect() + [Row(c=0.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.covar_samp(_to_java_column(col1), _to_java_column(col2))) + + @since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. From 42d656814f756599a2bc426f0e1f32bd4cc4470f Mon Sep 17 00:00:00 2001 From: vijaykiran Date: Fri, 12 Feb 2016 14:24:24 -0800 Subject: [PATCH 1894/2089] [SPARK-12630][PYSPARK] [DOC] PySpark classification parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the classification module. Author: vijaykiran Author: Bryan Cutler Closes #11183 from BryanCutler/pyspark-consistent-param-classification-SPARK-12630. --- python/pyspark/mllib/classification.py | 261 ++++++++++++++----------- 1 file changed, 143 insertions(+), 118 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 9e6f17ef6e942..b24592c3798e6 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -94,16 +94,19 @@ class LogisticRegressionModel(LinearClassificationModel): Classification model trained using Multinomial/Binary Logistic Regression. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. (Only used - in Binary Logistic Regression. In Multinomial Logistic - Regression, the intercepts will not be a single value, - so the intercepts will be part of the weights.) - :param numFeatures: the dimension of the features. - :param numClasses: the number of possible outcomes for k classes - classification problem in Multinomial Logistic Regression. - By default, it is binary logistic regression so numClasses - will be set to 2. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. (Only used in Binary Logistic + Regression. In Multinomial Logistic Regression, the intercepts will + not bea single value, so the intercepts will be part of the + weights.) + :param numFeatures: + The dimension of the features. + :param numClasses: + The number of possible outcomes for k classes classification problem + in Multinomial Logistic Regression. By default, it is binary + logistic regression so numClasses will be set to 2. >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -189,8 +192,8 @@ def numFeatures(self): @since('1.4.0') def numClasses(self): """ - Number of possible outcomes for k classes classification problem in Multinomial - Logistic Regression. + Number of possible outcomes for k classes classification problem + in Multinomial Logistic Regression. """ return self._numClasses @@ -272,37 +275,42 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.01). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization - - "l2" for using L2 regularization - - None for no regularization - - (default: "l2") - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param regType: + The type of regularizer used for training our model. + Allowed values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization (default) + - None for no regularization + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), @@ -323,38 +331,43 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.01). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization - - "l2" for using L2 regularization - - None for no regularization - - (default: "l2") - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param corrections: The number of corrections used in the - LBFGS update (default: 10). - :param tolerance: The convergence tolerance of iterations - for L-BFGS (default: 1e-4). - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before - training. (default: True) - :param numClasses: The number of classes (i.e., outcomes) a - label can take in Multinomial Logistic - Regression (default: 2). + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param regType: + The type of regularizer used for training our model. + Allowed values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization (default) + - None for no regularization + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param corrections: + The number of corrections used in the LBFGS update. + (default: 10) + :param tolerance: + The convergence tolerance of iterations for L-BFGS. + (default: 1e-4) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param numClasses: + The number of classes (i.e., outcomes) a label can take in + Multinomial Logistic Regression. + (default: 2) >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -387,8 +400,10 @@ class SVMModel(LinearClassificationModel): """ Model for Support Vector Machines (SVMs). - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -490,37 +505,42 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ Train a support vector machine on the given data. - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization - - "l2" for using L2 regularization - - None for no regularization - - (default: "l2") - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regType: + The type of regularizer used for training our model. + Allowed values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization (default) + - None for no regularization + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), @@ -536,11 +556,13 @@ class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. - :param labels: list of labels. - :param pi: log of class priors, whose dimension is C, - number of labels. - :param theta: log of class conditional probabilities, whose - dimension is C-by-D, where D is number of features. + :param labels: + List of labels. + :param pi: + Log of class priors, whose dimension is C, number of labels. + :param theta: + Log of class conditional probabilities, whose dimension is C-by-D, + where D is number of features. >>> data = [ ... LabeledPoint(0.0, [0.0, 0.0]), @@ -639,8 +661,11 @@ def train(cls, data, lambda_=1.0): it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). The input feature values must be nonnegative. - :param data: RDD of LabeledPoint. - :param lambda_: The smoothing parameter (default: 1.0). + :param data: + RDD of LabeledPoint. + :param lambda_: + The smoothing parameter. + (default: 1.0) """ first = data.first() if not isinstance(first, LabeledPoint): @@ -652,9 +677,9 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Train or predict a logistic regression model on streaming data. Training uses - Stochastic Gradient Descent to update the model based on each new batch of - incoming data from a DStream. + Train or predict a logistic regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model based on + each new batch of incoming data from a DStream. Each batch of data is assumed to be an RDD of LabeledPoints. The number of data points per batch can vary, but the number From 38bc6018e922877fdf43c90b24cc0438262eb157 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Fri, 12 Feb 2016 14:57:31 -0800 Subject: [PATCH 1895/2089] [SPARK-5095] Fix style in mesos coarse grained scheduler code andrewor14 This addressed your style comments from #10993 Author: Michael Gummelt Closes #11187 from mgummelt/fix_mesos_style. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 10 ++++++---- .../mesos/CoarseMesosSchedulerBackendSuite.scala | 12 ++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 98699e0b294ce..f803cc7a36a9a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -124,6 +124,7 @@ private[spark] class CoarseMesosSchedulerBackend( } } + // This method is factored out for testability protected def getShuffleClient(): MesosExternalShuffleClient = { new MesosExternalShuffleClient( SparkTransportConf.fromSparkConf(conf, "shuffle"), @@ -518,10 +519,11 @@ private[spark] class CoarseMesosSchedulerBackend( * Called when a slave is lost or a Mesos task finished. Updates local view on * what tasks are running. It also notifies the driver that an executor was removed. */ - private def executorTerminated(d: SchedulerDriver, - slaveId: String, - taskId: String, - reason: String): Unit = { + private def executorTerminated( + d: SchedulerDriver, + slaveId: String, + taskId: String, + reason: String): Unit = { stateLock.synchronized { removeExecutor(taskId, SlaveLost(reason)) slaves(slaveId).taskIDs.remove(taskId) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index e542aa0cfc4dd..5e01d95fc3bbb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -41,12 +41,12 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with BeforeAndAfter { - var sparkConf: SparkConf = _ - var driver: SchedulerDriver = _ - var taskScheduler: TaskSchedulerImpl = _ - var backend: CoarseMesosSchedulerBackend = _ - var externalShuffleClient: MesosExternalShuffleClient = _ - var driverEndpoint: RpcEndpointRef = _ + private var sparkConf: SparkConf = _ + private var driver: SchedulerDriver = _ + private var taskScheduler: TaskSchedulerImpl = _ + private var backend: CoarseMesosSchedulerBackend = _ + private var externalShuffleClient: MesosExternalShuffleClient = _ + private var driverEndpoint: RpcEndpointRef = _ test("mesos supports killing and limiting executors") { setBackend() From 62b1c07e7e88fe9c951c59cf022dfd52f160cfeb Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Fri, 12 Feb 2016 15:00:39 -0800 Subject: [PATCH 1896/2089] [SPARK-5095] remove flaky test Overrode the start() method, which was previously starting a thread causing a race condition. I believe this should fix the flaky test. Author: Michael Gummelt Closes #11164 from mgummelt/fix_mesos_tests. --- .../cluster/mesos/CoarseMesosSchedulerBackendSuite.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 5e01d95fc3bbb..5db7535d36e18 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -300,6 +300,11 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite override protected def createDriverEndpointRef( properties: ArrayBuffer[(String, String)]): RpcEndpointRef = endpoint + // override to avoid race condition with the driver thread on `mesosDriver` + override def startScheduler(newDriver: SchedulerDriver): Unit = { + mesosDriver = newDriver + } + markRegistered() } backend.start() From 2228f074e1bc11ee452925e10f780eaf24faf9e5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 12 Feb 2016 17:32:15 -0800 Subject: [PATCH 1897/2089] [SPARK-13293][SQL] generate Expand Expand suffer from create the UnsafeRow from same input multiple times, with codegen, it only need to copy some of the columns. After this, we can see 3X improvements (from 43 seconds to 13 seconds) on a TPCDS query (Q67) that have eight columns in Rollup. Ideally, we could mask some of the columns based on bitmask, I'd leave that in the future, because currently Aggregation (50 ns) is much slower than that just copy the variables (1-2 ns). Author: Davies Liu Closes #11177 from davies/gen_expand. --- .../apache/spark/sql/execution/Expand.scala | 124 +++++++++++++++++- .../BenchmarkWholeStageCodegen.scala | 17 +++ 2 files changed, 140 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index c3683cc4e7aac..d26a0b74674a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.execution +import scala.collection.immutable.IndexedSeq + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Apply the all of the GroupExpressions to every input row, hence we will get @@ -35,7 +39,10 @@ case class Expand( projections: Seq[Seq[Expression]], output: Seq[Attribute], child: SparkPlan) - extends UnaryNode { + extends UnaryNode with CodegenSupport { + + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) // The GroupExpressions can output data with arbitrary partitioning, so set it // as UNKNOWN partitioning @@ -48,6 +55,8 @@ case class Expand( (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitions { iter => val groups = projections.map(projection).toArray new Iterator[InternalRow] { @@ -71,9 +80,122 @@ case class Expand( idx = 0 } + numOutputRows += 1 result } } } } + + override def upstream(): RDD[InternalRow] = { + child.asInstanceOf[CodegenSupport].upstream() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + /* + * When the projections list looks like: + * expr1A, exprB, expr1C + * expr2A, exprB, expr2C + * ... + * expr(N-1)A, exprB, expr(N-1)C + * + * i.e. column A and C have different values for each output row, but column B stays constant. + * + * The generated code looks something like (note that B is only computed once in declaration): + * + * // part 1: declare all the columns + * colA = ... + * colB = ... + * colC = ... + * + * // part 2: code that computes the columns + * for (row = 0; row < N; row++) { + * switch (row) { + * case 0: + * colA = ... + * colC = ... + * case 1: + * colA = ... + * colC = ... + * ... + * case N - 1: + * colA = ... + * colC = ... + * } + * // increment metrics and consume output values + * } + * + * We use a for loop here so we only includes one copy of the consume code and avoid code + * size explosion. + */ + + // Set input variables + ctx.currentVars = input + + // Tracks whether a column has the same output for all rows. + // Size of sameOutput array should equal N. + // If sameOutput(i) is true, then the i-th column has the same value for all output rows given + // an input row. + val sameOutput: Array[Boolean] = output.indices.map { colIndex => + projections.map(p => p(colIndex)).toSet.size == 1 + }.toArray + + // Part 1: declare variables for each column + // If a column has the same value for all output rows, then we also generate its computation + // right after declaration. Otherwise its value is computed in the part 2. + val outputColumns = output.indices.map { col => + val firstExpr = projections.head(col) + if (sameOutput(col)) { + // This column is the same across all output rows. Just generate code for it here. + BindReferences.bindReference(firstExpr, child.output).gen(ctx) + } else { + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + """.stripMargin + ExprCode(code, isNull, value) + } + } + + // Part 2: switch/case statements + val cases = projections.zipWithIndex.map { case (exprs, row) => + var updateCode = "" + for (col <- exprs.indices) { + if (!sameOutput(col)) { + val ev = BindReferences.bindReference(exprs(col), child.output).gen(ctx) + updateCode += + s""" + |${ev.code} + |${outputColumns(col).isNull} = ${ev.isNull}; + |${outputColumns(col).value} = ${ev.value}; + """.stripMargin + } + } + + s""" + |case $row: + | ${updateCode.trim} + | break; + """.stripMargin + } + + val numOutput = metricTerm(ctx, "numOutputRows") + val i = ctx.freshName("i") + s""" + |${outputColumns.map(_.code).mkString("\n").trim} + |for (int $i = 0; $i < ${projections.length}; $i ++) { + | switch ($i) { + | ${cases.mkString("\n").trim} + | } + | $numOutput.add(1); + | ${consume(ctx, outputColumns)} + |} + """.stripMargin + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 1c7e69f30fb48..4a151179bf6f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -157,6 +157,23 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } + ignore("rube") { + val N = 5 << 20 + + runBenchmark("cube", N) { + sqlContext.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") + .cube("k1", "k2").sum("id").collect() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + cube codegen=false 3188 / 3392 1.6 608.2 1.0X + cube codegen=true 1239 / 1394 4.2 236.3 2.6X + */ + } + ignore("hash and BytesToBytesMap") { val N = 50 << 20 From 374c4b2869fc50570a68819cf0ece9b43ddeb34b Mon Sep 17 00:00:00 2001 From: markpavey Date: Sat, 13 Feb 2016 08:39:43 +0000 Subject: [PATCH 1898/2089] [SPARK-13142][WEB UI] Problem accessing Web UI /logPage/ on Microsoft Windows Due to being on a Windows platform I have been unable to run the tests as described in the "Contributing to Spark" instructions. As the change is only to two lines of code in the Web UI, which I have manually built and tested, I am submitting this pull request anyway. I hope this is OK. Is it worth considering also including this fix in any future 1.5.x releases (if any)? I confirm this is my own original work and license it to the Spark project under its open source license. Author: markpavey Closes #11135 from markpavey/JIRA_SPARK-13142_WindowsWebUILogFix. --- .../scala/org/apache/spark/deploy/worker/ui/LogPage.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 49803a27a5b00..0ca90640aed2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker - private val workDir = parent.workDir + private val workDir = new File(parent.workDir.toURI.normalize().getPath) private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { @@ -138,7 +138,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with } // Verify that the normalized path of the log directory is in the working directory - val normalizedUri = new URI(logDirectory).normalize() + val normalizedUri = new File(logDirectory).toURI.normalize() val normalizedLogDir = new File(normalizedUri.getPath) if (!Utils.isInDirectory(workDir, normalizedLogDir)) { return ("Error: invalid log directory " + logDirectory, 0, 0, 0) From e3441e3f68923224d5b576e6112917cf1fe1f89a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 13 Feb 2016 15:56:20 -0800 Subject: [PATCH 1899/2089] [SPARK-12363][MLLIB] Remove setRun and fix PowerIterationClustering failed test JIRA: https://issues.apache.org/jira/browse/SPARK-12363 This issue is pointed by yanboliang. When `setRuns` is removed from PowerIterationClustering, one of the tests will be failed. I found that some `dstAttr`s of the normalized graph are not correct values but 0.0. By setting `TripletFields.All` in `mapTriplets` it can work. Author: Liang-Chi Hsieh Author: Xiangrui Meng Closes #10539 from viirya/fix-poweriter. --- .../PowerIterationClusteringExample.scala | 53 +++++-------- .../clustering/PowerIterationClustering.scala | 24 +++--- .../PowerIterationClusteringSuite.scala | 79 ++++++++++--------- python/pyspark/mllib/clustering.py | 25 ++++-- 4 files changed, 96 insertions(+), 85 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 9208d8e245881..bb9c1cbca99cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -40,27 +40,23 @@ import org.apache.spark.rdd.RDD * n: Number of sampled points on innermost circle.. There are proportionally more points * within the outer/larger circles * maxIterations: Number of Power Iterations - * outerRadius: radius of the outermost of the concentric circles * }}} * * Here is a sample run and output: * - * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 - * - * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], - * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * ./bin/run-example mllib.PowerIterationClusteringExample -k 2 --n 10 --maxIterations 15 * + * Cluster assignments: 1 -> [0,1,2,3,4,5,6,7,8,9], + * 0 -> [10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] * * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ object PowerIterationClusteringExample { case class Params( - input: String = null, - k: Int = 3, - numPoints: Int = 5, - maxIterations: Int = 10, - outerRadius: Double = 3.0 + k: Int = 2, + numPoints: Int = 10, + maxIterations: Int = 15 ) extends AbstractParams[Params] def main(args: Array[String]) { @@ -69,7 +65,7 @@ object PowerIterationClusteringExample { val parser = new OptionParser[Params]("PowerIterationClusteringExample") { head("PowerIterationClusteringExample: an example PIC app using concentric circles.") opt[Int]('k', "k") - .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .text(s"number of circles (clusters), default: ${defaultParams.k}") .action((x, c) => c.copy(k = x)) opt[Int]('n', "n") .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") @@ -77,9 +73,6 @@ object PowerIterationClusteringExample { opt[Int]("maxIterations") .text(s"number of iterations, default: ${defaultParams.maxIterations}") .action((x, c) => c.copy(maxIterations = x)) - opt[Double]('r', "r") - .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") - .action((x, c) => c.copy(outerRadius = x)) } parser.parse(args, defaultParams).map { params => @@ -97,20 +90,21 @@ object PowerIterationClusteringExample { Logger.getRootLogger.setLevel(Level.WARN) - val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints) val model = new PowerIterationClustering() .setK(params.k) .setMaxIterations(params.maxIterations) + .setInitializationMode("degree") .run(circlesRdd) val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) - val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignments = clusters.toList.sortBy { case (k, v) => v.length } val assignmentsStr = assignments .map { case (k, v) => s"$k -> ${v.sorted.mkString("[", ",", "]")}" - }.mkString(",") + }.mkString(", ") val sizesStr = assignments.map { - _._2.size + _._2.length }.sorted.mkString("(", ",", ")") println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") @@ -124,20 +118,17 @@ object PowerIterationClusteringExample { } } - def generateCirclesRdd(sc: SparkContext, - nCircles: Int = 3, - nPoints: Int = 30, - outerRadius: Double): RDD[(Long, Long, Double)] = { - - val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} - val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} - val points = (0 until nCircles).flatMap { cx => - generateCircle(radii(cx), groupSizes(cx)) + def generateCirclesRdd( + sc: SparkContext, + nCircles: Int, + nPoints: Int): RDD[(Long, Long, Double)] = { + val points = (1 to nCircles).flatMap { i => + generateCircle(i, i * nPoints) }.zipWithIndex val rdd = sc.parallelize(points) val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => if (i0 < i1) { - Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1)))) } else { None } @@ -148,11 +139,9 @@ object PowerIterationClusteringExample { /** * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel */ - def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = { - val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) - val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double)): Double = { val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) - coeff * math.exp(expCoeff * ssquares) + math.exp(-ssquares / 2.0) } } // scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 1ab7cb393b081..feacafec7930f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -25,7 +25,6 @@ import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD @@ -264,10 +263,12 @@ object PowerIterationClustering extends Logging { }, mergeMsg = _ + _, TripletFields.EdgeOnly) - GraphImpl.fromExistingRDDs(vD, graph.edges) + Graph(vD, graph.edges) .mapTriplets( e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), - TripletFields.Src) + new TripletFields(/* useSrc */ true, + /* useDst */ false, + /* useEdge */ true)) } /** @@ -293,10 +294,12 @@ object PowerIterationClustering extends Logging { }, mergeMsg = _ + _, TripletFields.EdgeOnly) - GraphImpl.fromExistingRDDs(vD, gA.edges) + Graph(vD, gA.edges) .mapTriplets( e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), - TripletFields.Src) + new TripletFields(/* useSrc */ true, + /* useDst */ false, + /* useEdge */ true)) } /** @@ -317,7 +320,7 @@ object PowerIterationClustering extends Logging { }, preservesPartitioning = true).cache() val sum = r.values.map(math.abs).sum() val v0 = r.mapValues(x => x / sum) - GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + Graph(VertexRDD(v0), g.edges) } /** @@ -332,7 +335,7 @@ object PowerIterationClustering extends Logging { def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { val sum = g.vertices.values.sum() val v0 = g.vertices.mapValues(_ / sum) - GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + Graph(VertexRDD(v0), g.edges) } /** @@ -357,7 +360,9 @@ object PowerIterationClustering extends Logging { val v = curG.aggregateMessages[Double]( sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr), mergeMsg = _ + _, - TripletFields.Dst).cache() + new TripletFields(/* useSrc */ false, + /* useDst */ true, + /* useEdge */ true)).cache() // normalize v val norm = v.values.map(math.abs).sum() logInfo(s"$msgPrefix: norm(v) = $norm.") @@ -370,7 +375,7 @@ object PowerIterationClustering extends Logging { diffDelta = math.abs(delta - prevDelta) logInfo(s"$msgPrefix: diff(delta) = $diffDelta.") // update v - curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges) + curG = Graph(VertexRDD(v1), g.edges) prevDelta = delta } curG.vertices @@ -387,7 +392,6 @@ object PowerIterationClustering extends Logging { val points = v.mapValues(x => Vectors.dense(x)).cache() val model = new KMeans() .setK(k) - .setRuns(5) .setSeed(0L) .run(points.values) points.mapValues(p => model.predict(p)).cache() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 189000512155f..3d81d375c716e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -30,62 +30,65 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon import org.apache.spark.mllib.clustering.PowerIterationClustering._ + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + test("power iteration clustering") { - /* - We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for - edge (3, 4). - - 15-14 -13 -12 - | | - 4 . 3 - 2 11 - | | x | | - 5 0 - 1 10 - | | - 6 - 7 - 8 - 9 - */ + // Generate two circles following the example in the PIC paper. + val r1 = 1.0 + val n1 = 10 + val r2 = 4.0 + val n2 = 40 + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + val similarities = for (i <- 1 until n; j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) + } - val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), - (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge - (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), - (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) val model = new PowerIterationClustering() .setK(2) + .setMaxIterations(40) .run(sc.parallelize(similarities, 2)) val predictions = Array.fill(2)(mutable.Set.empty[Long]) model.assignments.collect().foreach { a => predictions(a.cluster) += a.id } - assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) val model2 = new PowerIterationClustering() .setK(2) + .setMaxIterations(10) .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) model2.assignments.collect().foreach { a => predictions2(a.cluster) += a.id } - assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) } test("power iteration clustering on graph") { - /* - We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for - edge (3, 4). - - 15-14 -13 -12 - | | - 4 . 3 - 2 11 - | | x | | - 5 0 - 1 10 - | | - 6 - 7 - 8 - 9 - */ - - val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), - (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge - (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), - (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + // Generate two circles following the example in the PIC paper. + val r1 = 1.0 + val n1 = 10 + val r2 = 4.0 + val n2 = 40 + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + val similarities = for (i <- 1 until n; j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) + } val edges = similarities.flatMap { case (i, j, s) => if (i != j) { @@ -98,22 +101,24 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon val model = new PowerIterationClustering() .setK(2) + .setMaxIterations(40) .run(graph) val predictions = Array.fill(2)(mutable.Set.empty[Long]) model.assignments.collect().foreach { a => predictions(a.cluster) += a.id } - assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) val model2 = new PowerIterationClustering() .setK(2) + .setMaxIterations(10) .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) model2.assignments.collect().foreach { a => predictions2(a.cluster) += a.id } - assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) } test("normalize and powerIter") { diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index ad04e46e8870b..5a5bf59dd5fe3 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -571,12 +571,25 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): Model produced by [[PowerIterationClustering]]. - >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0), - ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), - ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0), - ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)] - >>> rdd = sc.parallelize(data, 2) - >>> model = PowerIterationClustering.train(rdd, 2, 100) + >>> import math + >>> def genCircle(r, n): + ... points = [] + ... for i in range(0, n): + ... theta = 2.0 * math.pi * i / n + ... points.append((r * math.cos(theta), r * math.sin(theta))) + ... return points + >>> def sim(x, y): + ... dist2 = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1]) + ... return math.exp(-dist2 / 2.0) + >>> r1 = 1.0 + >>> n1 = 10 + >>> r2 = 4.0 + >>> n2 = 40 + >>> n = n1 + n2 + >>> points = genCircle(r1, n1) + genCircle(r2, n2) + >>> similarities = [(i, j, sim(points[i], points[j])) for i in range(1, n) for j in range(0, i)] + >>> rdd = sc.parallelize(similarities, 2) + >>> model = PowerIterationClustering.train(rdd, 2, 40) >>> model.k 2 >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) From 610196f93a3a6de5af6a2af29a964be4e30f6e28 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 13 Feb 2016 18:03:53 -0800 Subject: [PATCH 1900/2089] Closes #11185 From 388cd9ea8db2e438ebef9dfb894298f843438c43 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 13 Feb 2016 21:05:48 -0800 Subject: [PATCH 1901/2089] [SPARK-13172][CORE][SQL] Stop using RichException.getStackTrace it is deprecated Replace `getStackTraceString` with `Utils.exceptionString` Author: Sean Owen Closes #11182 from srowen/SPARK-13172. --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 +++--- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 4 ++-- .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 3 ++- .../hive/thriftserver/SparkExecuteStatementOperation.scala | 5 +++-- .../org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- .../spark/streaming/receiver/ReceiverSupervisor.scala | 4 ++-- 7 files changed, 14 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ee0b8a1c95fd8..379dc14ad7cd1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -981,7 +981,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return } @@ -1017,7 +1017,7 @@ class DAGScheduler( // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) + abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return } @@ -1044,7 +1044,7 @@ class DAGScheduler( } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2626f5a16dfb8..fe2c8299a0d91 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -617,7 +617,7 @@ object JarCreationTest extends Logging { Utils.classForName(args(1)) } catch { case t: Throwable => - exception = t + "\n" + t.getStackTraceString + exception = t + "\n" + Utils.exceptionString(t) exception = exception.replaceAll("\n", "\n\t") } Option(exception).toSeq.iterator diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 62972a0738211..d8849d59482e6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} -import org.apache.spark.util.CallSite +import org.apache.spark.util.{CallSite, Utils} class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -1665,7 +1665,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } // Does not include message, ONLY stack trace. - val stackTraceString = e.getStackTraceString + val stackTraceString = Utils.exceptionString(e) // should actually include the RDD operation that invoked the method: assert(stackTraceString.contains("org.apache.spark.rdd.RDD.count")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index e028d22a54ba0..cf26d4843d84f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types.DataType +import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. @@ -82,7 +83,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { s""" |Code generation of $expression failed: |$e - |${e.getStackTraceString} + |${Utils.exceptionString(e)} """.stripMargin) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index cd2167c4ecb18..8fef22cf777f6 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ +import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, @@ -231,7 +232,7 @@ private[hive] class SparkExecuteStatementOperation( if (getStatus().getState() == OperationState.CANCELED) { return } else { - setState(OperationState.ERROR); + setState(OperationState.ERROR) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and @@ -241,7 +242,7 @@ private[hive] class SparkExecuteStatementOperation( logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( - statementId, e.getMessage, e.getStackTraceString) + statementId, e.getMessage, SparkUtils.exceptionString(e)) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 8932ce9503a3d..f141a9bd0ff81 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -217,7 +217,7 @@ object SparkSubmitClassLoaderTest extends Logging { Utils.classForName(args(1)) } catch { case t: Throwable => - exception = t + "\n" + t.getStackTraceString + exception = t + "\n" + Utils.exceptionString(t) exception = exception.replaceAll("\n", "\n\t") } Option(exception).toSeq.iterator diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index d0195fb14f0a3..9cde5ae080c0a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.util.{ThreadUtils, Utils} @@ -174,7 +174,7 @@ private[streaming] abstract class ReceiverSupervisor( } } catch { case NonFatal(t) => - logError("Error stopping receiver " + streamId + t.getStackTraceString) + logError(s"Error stopping receiver $streamId ${Utils.exceptionString(t)}") } } From 354d4c24be892271bd9a9eab6ceedfbc5d671c9c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 13 Feb 2016 21:06:31 -0800 Subject: [PATCH 1902/2089] [SPARK-13296][SQL] Move UserDefinedFunction into sql.expressions. This pull request has the following changes: 1. Moved UserDefinedFunction into expressions package. This is more consistent with how we structure the packages for window functions and UDAFs. 2. Moved UserDefinedPythonFunction into execution.python package, so we don't have a random private class in the top level sql package. 3. Move everything in execution/python.scala into the newly created execution.python package. Most of the diffs are just straight copy-paste. Author: Reynold Xin Closes #11181 from rxin/SPARK-13296. --- project/MimaExcludes.scala | 8 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 6 +- .../org/apache/spark/sql/DataFrame.scala | 3 +- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../python/BatchPythonEvaluation.scala | 104 ++++++++++ .../EvaluatePython.scala} | 187 ++---------------- .../execution/python/ExtractPythonUDFs.scala | 79 ++++++++ .../sql/execution/python/PythonUDF.scala | 44 +++++ .../python/UserDefinedPythonFunction.scala | 51 +++++ .../UserDefinedFunction.scala | 39 +--- .../org/apache/spark/sql/functions.scala | 1 + .../apache/spark/sql/hive/HiveContext.scala | 2 +- 15 files changed, 320 insertions(+), 217 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/{python.scala => python/EvaluatePython.scala} (56%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => expressions}/UserDefinedFunction.scala (57%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8611106db0cf0..6abab7f126500 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -235,7 +235,13 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") ) ++ Seq( // SPARK-7889 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"), + // SPARK-13296 + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") ) case v if v.startsWith("1.6") => Seq( diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3104e41407114..83b034fe77435 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -262,7 +262,7 @@ def take(self, num): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe( + port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( self._jdf, num) return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 416d722bbac3c..5fc1cc2cae10a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1652,9 +1652,9 @@ def _create_judf(self, name): jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ - judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, - sc.pythonExec, sc.pythonVer, broadcast_vars, - sc._javaAccumulator, jdt) + judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( + name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer, + broadcast_vars, sc._javaAccumulator, jdt) return judf def __del__(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c5b2b7d11893c..76c09a285dc40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator +import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d58b99655c1eb..c7d1096a1384c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -193,7 +193,7 @@ class SQLContext private[sql]( protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUDFs :: + python.ExtractPythonUDFs :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) @@ -915,7 +915,7 @@ class SQLContext private[sql]( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f87a88d49744f..ecfc170bee3ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 598ddd71613b4..73fd22b38e1d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -369,8 +369,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.RepartitionByExpression(expressions, child, nPartitions) => execution.Exchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child, _) => - BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case e @ python.EvaluatePython(udf, child, _) => + python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala new file mode 100644 index 0000000000000..00df0195279c3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -0,0 +1,104 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.PythonRunner +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{StructField, StructType} + + +/** + * A physical plan that evalutes a [[PythonUDF]], one partition of tuples at a time. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. + */ +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + + def children: Seq[SparkPlan] = child :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + + inputRDD.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + val fields = udf.children.map(_.dataType) + val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + queue.add(row) + EvaluatePython.toJava(currentRow(row), schema) + }.toArray + pickle.dumps(toBePickled) + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val row = new GenericMutableRow(1) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + resultProj(joined(queue.poll(), row)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala similarity index 56% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index bf62bb05c3d93..8c46516594a2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -15,106 +15,41 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.python import java.io.OutputStream -import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ -import net.razorvine.pickle._ +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext} -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil} -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ -private[spark] case class PythonUDF( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging { - - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - - override def nullable: Boolean = true -} +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { -/** - * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated - * alone in a batch. - * - * This has the limitation that the input to the Python UDF is not allowed include attributes from - * multiple child operators. - */ -private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Skip EvaluatePython nodes. - case plan: EvaluatePython => plan - - case plan: LogicalPlan if plan.resolved => - // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) - if (udfs.isEmpty) { - // If there aren't any, we are done. - plan - } else { - // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) - // If there is more than one, we will add another evaluation operator in a subsequent pass. - udfs.find(_.resolved) match { - case Some(udf) => - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } - } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) - - case None => - // If there is no Python UDF that is resolved, skip this round. - plan - } - } - } + def output: Seq[Attribute] = child.output :+ resultAttribute + + // References should not include the produced attribute. + override def references: AttributeSet = udf.references } + object EvaluatePython { def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) @@ -221,7 +156,7 @@ object EvaluatePython { if (array.length != fields.length) { throw new IllegalStateException( s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." + s"${fields.length} fields are required while ${array.length} values are provided." ) } new GenericInternalRow(array.zip(fields).map { @@ -235,7 +170,6 @@ object EvaluatePython { case (c, _) => null } - private val module = "pyspark.sql.types" /** @@ -287,7 +221,7 @@ object EvaluatePython { out.write(Opcodes.MARK) var i = 0 - while (i < row.values.size) { + while (i < row.values.length) { pickler.save(row.values(i)) i += 1 } @@ -298,6 +232,7 @@ object EvaluatePython { } private[this] var registered = false + /** * This should be called before trying to serialize any above classes un cluster mode, * this should be put in the closure @@ -324,91 +259,3 @@ object EvaluatePython { } } } - -/** - * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. - */ -case class EvaluatePython( - udf: PythonUDF, - child: LogicalPlan, - resultAttribute: AttributeReference) - extends logical.UnaryNode { - - def output: Seq[Attribute] = child.output :+ resultAttribute - - // References should not include the produced attribute. - override def references: AttributeSet = udf.references -} - -/** - * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * - * Python evaluation works by sending the necessary (projected) input data via a socket to an - * external Python process, and combine the result from the Python process with the original row. - * - * For each row we send to Python, we also put it in a queue. For each output row from Python, - * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and eventually run out of memory. - */ -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - - inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() - - val pickle = new Pickler - val currentRow = newMutableProjection(udf.children, child.output)() - val fields = udf.children.map(_.dataType) - val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - queue.add(row) - EvaluatePython.toJava(currentRow(row), schema) - }.toArray - pickle.dumps(toBePickled) - } - - val context = TaskContext.get() - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, - bufferSize, - reuseWorker - ).compute(inputIterator, context.partitionId(), context) - - val unpickle = new Unpickler - val row = new GenericMutableRow(1) - val joined = new JoinedRow - val resultProj = UnsafeProjection.create(output, output) - - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - resultProj(joined(queue.poll(), row)) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala new file mode 100644 index 0000000000000..6e76e9569febb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -0,0 +1,79 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // Skip EvaluatePython nodes. + case plan: EvaluatePython => plan + + case plan: LogicalPlan if plan.resolved => + // Extract any PythonUDFs from the current operator. + val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) + if (udfs.isEmpty) { + // If there aren't any, we are done. + plan + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + udfs.find(_.resolved) match { + case Some(udf) => + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute + }.withNewChildren(newChildren)) + + case None => + // If there is no Python UDF that is resolved, skip this round. + plan + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala new file mode 100644 index 0000000000000..0e53a0c4737bb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.{Accumulator, Logging} +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} +import org.apache.spark.sql.types.DataType + +/** + * A serialized version of a Python lambda function. + */ +case class PythonUDF( + name: String, + command: Array[Byte], + envVars: java.util.Map[String, String], + pythonIncludes: java.util.List[String], + pythonExec: String, + pythonVer: String, + broadcastVars: java.util.List[Broadcast[PythonBroadcast]], + accumulator: Accumulator[java.util.List[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with Unevaluable with Logging { + + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" + + override def nullable: Boolean = true +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala new file mode 100644 index 0000000000000..79ac1c85c0be5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.Column +import org.apache.spark.sql.types.DataType + +/** + * A user-defined Python function. This is used by the Python API. + */ +case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: java.util.Map[String, String], + pythonIncludes: java.util.List[String], + pythonExec: String, + pythonVer: String, + broadcastVars: java.util.List[Broadcast[PythonBroadcast]], + accumulator: Accumulator[java.util.List[Array[Byte]]], + dataType: DataType) { + + def builder(e: Seq[Expression]): PythonUDF = { + PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, + accumulator, dataType, e) + } + + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ + def apply(exprs: Column*): Column = { + val udf = builder(exprs.map(_.expr)) + Column(udf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala similarity index 57% rename from sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala rename to sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 2fb3bf07aa60b..bd35d19aa20bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -15,16 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.expressions -import java.util.{List => JList, Map => JMap} - -import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.PythonUDF +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions import org.apache.spark.sql.types.DataType /** @@ -50,30 +46,3 @@ case class UserDefinedFunction protected[sql] ( Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) } } - -/** - * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. - * This is used by Python API. - */ -private[sql] case class UserDefinedPythonFunction( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType) { - - def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) - } - - /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ - def apply(exprs: Column*): Column = { - val udf = builder(exprs.map(_.expr)) - Column(udf) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d34d377ab66dc..e4ab6b4f23486 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2433b54ffcb8e..ac174aa6bfa6d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -465,7 +465,7 @@ class HiveContext private[hive]( catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUDFs :: + python.ExtractPythonUDFs :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) From 331293c30242dc43e54a25171ca51a1c9330ae44 Mon Sep 17 00:00:00 2001 From: Amit Dev Date: Sun, 14 Feb 2016 11:41:27 +0000 Subject: [PATCH 1903/2089] [SPARK-13300][DOCUMENTATION] Added pygments.rb dependancy Looks like pygments.rb gem is also required for jekyll build to work. At least on Ubuntu/RHEL I could not do build without this dependency. So added this to steps. Author: Amit Dev Closes #11180 from amitdev/master. --- docs/README.md | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/docs/README.md b/docs/README.md index 1f4fd3e56ed5f..bcea93e1f3b6d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,15 +10,18 @@ whichever version of Spark you currently have checked out of revision control. ## Prerequisites The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala, -Python and R. To get started you can run the following commands +Python and R. - $ sudo gem install jekyll - $ sudo gem install jekyll-redirect-from +You need to have [Ruby](https://www.ruby-lang.org/en/documentation/installation/) and +[Python](https://docs.python.org/2/using/unix.html#getting-and-installing-the-latest-version-of-python) +installed. Also install the following libraries: +```sh + $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments + # Following is needed only for generating API docs $ sudo pip install sphinx $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' - - +``` ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as @@ -38,14 +41,16 @@ compiled files. $ jekyll build You can modify the default Jekyll build as follows: - +```sh # Skip generating API docs (which takes a while) $ SKIP_API=1 jekyll build + # Serve content locally on port 4000 $ jekyll serve --watch + # Build the site with extra features used on the live page $ PRODUCTION=1 jekyll build - +``` ## API Docs (Scaladoc, Sphinx, roxygen2) @@ -59,7 +64,7 @@ When you run `jekyll` in the `docs` directory, it will also copy over the scalad Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs [Sphinx](http://sphinx-doc.org/). +PySpark docs using [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. From 22e9723d6208f2cd2dfa26487ea1c041cb9d7dcd Mon Sep 17 00:00:00 2001 From: Claes Redestad Date: Sun, 14 Feb 2016 11:49:37 +0000 Subject: [PATCH 1904/2089] [SPARK-13278][CORE] Launcher fails to start with JDK 9 EA See http://openjdk.java.net/jeps/223 for more information about the JDK 9 version string scheme. Author: Claes Redestad Closes #11160 from cl4es/master. --- .../org/apache/spark/util/UtilsSuite.scala | 6 ++++-- .../spark/launcher/CommandBuilderUtils.java | 20 ++++++++++++++++--- .../launcher/CommandBuilderUtilsSuite.java | 12 +++++++++++ project/SparkBuild.scala | 6 ++++-- 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index bc926c280c7cd..7c6778b065467 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -784,8 +784,10 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { signal(pid, "SIGKILL") } - val v: String = System.getProperty("java.version") - if (v >= "1.8.0") { + val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) + var majorVersion = versionParts(0).toInt + if (majorVersion == 1) majorVersion = versionParts(1).toInt + if (majorVersion >= 8) { // Java8 added a way to forcibly terminate a process. We'll make sure that works by // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On // older versions of java, this will *not* terminate. diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index d30c2ec5f87bb..e328c8a341c28 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -322,11 +322,9 @@ static void addPermGenSizeOpt(List cmd) { if (getJavaVendor() == JavaVendor.IBM) { return; } - String[] version = System.getProperty("java.version").split("\\."); - if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { + if (javaMajorVersion(System.getProperty("java.version")) > 7) { return; } - for (String arg : cmd) { if (arg.startsWith("-XX:MaxPermSize=")) { return; @@ -336,4 +334,20 @@ static void addPermGenSizeOpt(List cmd) { cmd.add("-XX:MaxPermSize=256m"); } + /** + * Get the major version of the java version string supplied. This method + * accepts any JEP-223-compliant strings (9-ea, 9+100), as well as legacy + * version strings such as 1.7.0_79 + */ + static int javaMajorVersion(String javaVersion) { + String[] version = javaVersion.split("[+.\\-]+"); + int major = Integer.parseInt(version[0]); + // if major > 1, we're using the JEP-223 version string, e.g., 9-ea, 9+120 + // otherwise the second number is the major version + if (major > 1) { + return major; + } else { + return Integer.parseInt(version[1]); + } + } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index bc513ec9b3d10..4fafc43ef293b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -87,6 +87,18 @@ public void testPythonArgQuoting() { assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } + @Test + public void testJavaMajorVersion() { + assertEquals(6, javaMajorVersion("1.6.0_50")); + assertEquals(7, javaMajorVersion("1.7.0_79")); + assertEquals(8, javaMajorVersion("1.8.0_66")); + assertEquals(9, javaMajorVersion("9-ea")); + assertEquals(9, javaMajorVersion("9+100")); + assertEquals(9, javaMajorVersion("9")); + assertEquals(9, javaMajorVersion("9.1.0")); + assertEquals(10, javaMajorVersion("10")); + } + private void testOpt(String opts, List expected) { assertEquals(String.format("test string failed to parse: [[ %s ]]", opts), expected, parseOptionString(opts)); diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6eba58c87c39c..646efb4d09301 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -167,8 +167,10 @@ object SparkBuild extends PomBuild { publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn, javacOptions in (Compile, doc) ++= { - val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) - if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty + val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) + var major = versionParts(0).toInt + if (major == 1) major = versionParts(1).toInt + if (major >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty }, javacJVMVersion := "1.7", From 7cb4d74c98c2f1765b48a549f62e47b53ed29b38 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Sun, 14 Feb 2016 16:00:20 -0800 Subject: [PATCH 1905/2089] [SPARK-13185][SQL] Reuse Calendar object in DateTimeUtils.StringToDate method to improve performance The java `Calendar` object is expensive to create. I have a sub query like this `SELECT a, b, c FROM table UV WHERE (datediff(UV.visitDate, '1997-01-01')>=0 AND datediff(UV.visitDate, '2015-01-01')<=0))` The table stores `visitDate` as String type and has 3 billion records. A `Calendar` object is created every time `DateTimeUtils.stringToDate` is called. By reusing the `Calendar` object, I saw about 20 seconds performance improvement for this stage. Author: Carson Wang Closes #11090 from carsonwang/SPARK-13185. --- .../apache/spark/sql/catalyst/util/DateTimeUtils.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index a159bc6a61415..f184d72285d45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -59,6 +59,13 @@ object DateTimeUtils { @transient lazy val defaultTimeZone = TimeZone.getDefault + // Reuse the Calendar object in each thread as it is expensive to create in each method call. + private val threadLocalGmtCalendar = new ThreadLocal[Calendar] { + override protected def initialValue: Calendar = { + Calendar.getInstance(TimeZoneGMT) + } + } + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { override protected def initialValue: TimeZone = { @@ -408,7 +415,8 @@ object DateTimeUtils { segments(2) < 1 || segments(2) > 31) { return None } - val c = Calendar.getInstance(TimeZoneGMT) + val c = threadLocalGmtCalendar.get() + c.clear() c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) c.set(Calendar.MILLISECOND, 0) Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) From a8bbc4f50ef3faacf4b7fe865a29144ea87f0796 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 14 Feb 2016 17:32:21 -0800 Subject: [PATCH 1906/2089] [SPARK-12503][SPARK-12505] Limit pushdown in UNION ALL and OUTER JOIN This patch adds a new optimizer rule for performing limit pushdown. Limits will now be pushed down in two cases: - If a limit is on top of a `UNION ALL` operator, then a partition-local limit operator will be pushed to each of the union operator's children. - If a limit is on top of an `OUTER JOIN` then a partition-local limit will be pushed to one side of the join. For `LEFT OUTER` and `RIGHT OUTER` joins, the limit will be pushed to the left and right side, respectively. For `FULL OUTER` join, we will only push limits when at most one of the inputs is already limited: if one input is limited we will push a smaller limit on top of it and if neither input is limited then we will limit the input which is estimated to be larger. These optimizations were proposed previously by gatorsmile in #10451 and #10454, but those earlier PRs were closed and deferred for later because at that time Spark's physical `Limit` operator would trigger a full shuffle to perform global limits so there was a chance that pushdowns could actually harm performance by causing additional shuffles/stages. In #7334, we split the `Limit` operator into separate `LocalLimit` and `GlobalLimit` operators, so we can now push down only local limits (which don't require extra shuffles). This patch is based on both of gatorsmile's patches, with changes and simplifications due to partition-local-limiting. When we push down the limit, we still keep the original limit in place, so we need a mechanism to ensure that the optimizer rule doesn't keep pattern-matching once the limit has been pushed down. In order to handle this, this patch adds a `maxRows` method to `SparkPlan` which returns the maximum number of rows that the plan can compute, then defines the pushdown rules to only push limits to children if the children's maxRows are greater than the limit's maxRows. This idea is carried over from #10451; see that patch for additional discussion. Author: Josh Rosen Closes #11121 from JoshRosen/limit-pushdown-2. --- .../sql/catalyst/optimizer/Optimizer.scala | 70 ++++++++- .../catalyst/plans/logical/LogicalPlan.scala | 8 + .../plans/logical/basicOperators.scala | 60 +++++++- .../optimizer/LimitPushdownSuite.scala | 145 ++++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 8 +- .../apache/spark/sql/hive/SQLBuilder.scala | 12 +- 6 files changed, 294 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 902e18081bddf..567010f23fc8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,6 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughProject, PushPredicateThroughGenerate, PushPredicateThroughAggregate, + // LimitPushDown, // Disabled until we have whole-stage codegen for limit ColumnPruning, // Operator combine CollapseRepartition, @@ -129,6 +130,69 @@ object EliminateSerialization extends Rule[LogicalPlan] { } } +/** + * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. + */ +object LimitPushDown extends Rule[LogicalPlan] { + + private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { + plan match { + case GlobalLimit(expr, child) => child + case _ => plan + } + } + + private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRows) match { + case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case _ => plan + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Adding extra Limits below UNION ALL for children which are not Limit or do not have Limit + // descendants whose maxRow is larger. This heuristic is valid assuming there does not exist any + // Limit push-down rule that is unable to infer the value of maxRows. + // Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to + // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to + // pushdown Limit. + case LocalLimit(exp, Union(children)) => + LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the + // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side + // because we need to ensure that rows from the limited side still have an opportunity to match + // against all candidates from the non-limited side. We also need to ensure that this limit + // pushdown rule will not eventually introduce limits on both sides if it is applied multiple + // times. Therefore: + // - If one side is already limited, stack another limit on top if the new limit is smaller. + // The redundant limit will be collapsed by the CombineLimits rule. + // - If neither side is limited, limit the side that is estimated to be bigger. + case LocalLimit(exp, join @ Join(left, right, joinType, condition)) => + val newJoin = joinType match { + case RightOuter => join.copy(right = maybePushLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case FullOuter => + (left.maxRows, right.maxRows) match { + case (None, None) => + if (left.statistics.sizeInBytes >= right.statistics.sizeInBytes) { + join.copy(left = maybePushLimit(exp, left)) + } else { + join.copy(right = maybePushLimit(exp, right)) + } + case (Some(_), Some(_)) => join + case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + + } + case _ => join + } + LocalLimit(exp, newJoin) + } +} + /** * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. @@ -985,8 +1049,12 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] { */ object CombineLimits extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ll @ GlobalLimit(le, nl @ GlobalLimit(ne, grandChild)) => + GlobalLimit(Least(Seq(ne, le)), grandChild) + case ll @ LocalLimit(le, nl @ LocalLimit(ne, grandChild)) => + LocalLimit(Least(Seq(ne, le)), grandChild) case ll @ Limit(le, nl @ Limit(ne, grandChild)) => - Limit(If(LessThan(ne, le), ne, le), grandChild) + Limit(Least(Seq(ne, le)), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 18b7bde906fda..35e0f5d5639a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -90,6 +90,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) } + /** + * Returns the maximum number of rows that this plan may compute. + * + * Any operator that a Limit can be pushed passed should override this function (e.g., Union). + * Any operator that can push through a Limit should override this function (e.g., Project). + */ + def maxRows: Option[Long] = None + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e8e0a78904a32..502d898fea86c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -38,6 +38,7 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override def maxRows: Option[Long] = child.maxRows override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { @@ -56,6 +57,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. @@ -102,6 +104,8 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = child.maxRows + override protected def validConstraints: Set[Expression] = { child.constraints.union(splitConjunctivePredicates(condition).toSet) } @@ -144,6 +148,14 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation left.output.length == right.output.length && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && duplicateResolved + + override def maxRows: Option[Long] = { + if (children.exists(_.maxRows.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRows).min) + } + } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { @@ -166,6 +178,13 @@ object Union { } case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { + override def maxRows: Option[Long] = { + if (children.exists(_.maxRows.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRows).sum) + } + } // updating nullability to make all the children consistent override def output: Seq[Attribute] = @@ -305,6 +324,7 @@ case class InsertIntoTable( /** * A container for holding named common table expressions (CTEs) and a query plan. * This operator will be removed during analysis and the relations will be substituted into child. + * * @param child The final query of this CTE. * @param cteRelations Queries that this CTE defined, * key is the alias of the CTE definition, @@ -331,6 +351,7 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = child.maxRows } /** Factory for constructing new `Range` nodes. */ @@ -384,6 +405,7 @@ case class Aggregate( } override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def maxRows: Option[Long] = child.maxRows } case class Window( @@ -505,6 +527,7 @@ trait GroupingAnalytics extends UnaryNode { * to generated by a UNION ALL of multiple simple GROUP BY clauses. * * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer + * * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected * GroupBy expressions * @param groupByExprs The Group By expressions candidates, take effective only if the @@ -537,9 +560,42 @@ case class Pivot( } } -case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +object Limit { + def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { + GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) + } + + def unapply(p: GlobalLimit): Option[(Expression, LogicalPlan)] = { + p match { + case GlobalLimit(le1, LocalLimit(le2, child)) if le1 == le2 => Some((le1, child)) + case _ => None + } + } +} + +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } + override lazy val statistics: Statistics = { + val limit = limitExpr.eval().asInstanceOf[Int] + val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum + Statistics(sizeInBytes = sizeInBytes) + } +} +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum @@ -576,6 +632,7 @@ case class Sample( * Returns a new logical plan that dedups input rows. */ case class Distinct(child: LogicalPlan) extends UnaryNode { + override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output } @@ -594,6 +651,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) * A relation with one row. This is used in "SELECT ..." without a from clause. */ case object OneRowRelation extends LeafNode { + override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala new file mode 100644 index 0000000000000..fc1e99458179c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Add +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class LimitPushdownSuite extends PlanTest { + + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Limit pushdown", FixedPoint(100), + LimitPushDown, + CombineLimits, + ConstantFolding, + BooleanSimplification) :: Nil + } + + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + private val x = testRelation.subquery('x) + private val y = testRelation.subquery('y) + + // Union --------------------------------------------------------------------------------------- + + test("Union: limit to each side") { + val unionQuery = Union(testRelation, testRelation2).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with constant-foldable limit expressions") { + val unionQuery = Union(testRelation, testRelation2).limit(Add(1, 1)) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with the new limit number") { + val unionQuery = Union(testRelation, testRelation2.limit(3)).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: no limit to both sides if children having smaller limit values") { + val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each sides if children having larger limit values") { + val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) + val unionQuery = testLimitUnion.limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + // Outer join ---------------------------------------------------------------------------------- + + test("left outer join") { + val originalQuery = x.join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join") { + val originalQuery = x.join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("larger limits are not pushed on top of smaller ones in right outer join") { + val originalQuery = x.join(y.limit(5), RightOuter).limit(10) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(10, x.join(Limit(5, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and both sides have same statistics") { + assert(x.statistics.sizeInBytes === y.statistics.sizeInBytes) + val originalQuery = x.join(y, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and left side has larger statistics") { + val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) + assert(xBig.statistics.sizeInBytes > y.statistics.sizeInBytes) + val originalQuery = xBig.join(y, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and right side has larger statistics") { + val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) + assert(x.statistics.sizeInBytes < yBig.statistics.sizeInBytes) + val originalQuery = x.join(yBig, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where both sides are limited") { + val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73fd22b38e1d6..042c99db4dcff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -351,10 +351,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil - case logical.Limit(IntegerLiteral(limit), child) => - val perPartitionLimit = execution.LocalLimit(limit, planLater(child)) - val globalLimit = execution.GlobalLimit(limit, perPartitionLimit) - globalLimit :: Nil + case logical.LocalLimit(IntegerLiteral(limit), child) => + execution.LocalLimit(limit, planLater(child)) :: Nil + case logical.GlobalLimit(IntegerLiteral(limit), child) => + execution.GlobalLimit(limit, planLater(child)) :: Nil case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index d7bae913f8720..bf5edb4759fbd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -77,8 +77,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Aggregate => aggregateToSQL(p) - case p: Limit => - s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}" + case Limit(limitExpr, child) => + s"${toSQL(child)} LIMIT ${limitExpr.sql}" case p: Filter => val whereOrHaving = p.child match { @@ -203,7 +203,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi wrapChildWithSubquery(plan) case plan @ Project(_, - _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit + _: Subquery + | _: Filter + | _: Join + | _: MetastoreRelation + | OneRowRelation + | _: LocalLimit + | _: GlobalLimit ) => plan case plan: Project => From 56d49397e01306637edf23bbb4f3b9d396cdc6ff Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Mon, 15 Feb 2016 09:20:49 +0000 Subject: [PATCH 1907/2089] [SPARK-12995][GRAPHX] Remove deprecate APIs from Pregel Author: Takeshi YAMAMURO Closes #10918 from maropu/RemoveDeprecateInPregel. --- .../scala/org/apache/spark/graphx/Graph.scala | 49 ----------------- .../org/apache/spark/graphx/GraphXUtils.scala | 27 ++++++++++ .../org/apache/spark/graphx/Pregel.scala | 6 +-- .../apache/spark/graphx/impl/GraphImpl.scala | 25 --------- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 11 ---- .../org/apache/spark/graphx/GraphSuite.scala | 52 +++---------------- project/MimaExcludes.scala | 6 +++ 7 files changed, 42 insertions(+), 134 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 869caa340f52b..fe884d0022500 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -340,55 +340,6 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] - /** - * Aggregates values from the neighboring edges and vertices of each vertex. The user supplied - * `mapFunc` function is invoked on each edge of the graph, generating 0 or more "messages" to be - * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of - * the map phase destined to each vertex. - * - * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead. - * - * @tparam A the type of "message" to be sent to each vertex - * - * @param mapFunc the user defined map function which returns 0 or - * more messages to neighboring vertices - * - * @param reduceFunc the user defined reduce function which should - * be commutative and associative and is used to combine the output - * of the map phase - * - * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if - * desired. This is done by specifying a set of "active" vertices and an edge direction. The - * `sendMsg` function will then run only on edges connected to active vertices by edges in the - * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with - * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges - * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be - * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` - * will be run on edges with *both* vertices in the active set. The active set must have the - * same index as the graph's vertices. - * - * @example We can use this function to compute the in-degree of each - * vertex - * {{{ - * val rawGraph: Graph[(),()] = Graph.textFile("twittergraph") - * val inDeg: RDD[(VertexId, Int)] = - * mapReduceTriplets[Int](et => Iterator((et.dst.id, 1)), _ + _) - * }}} - * - * @note By expressing computation at the edge level we achieve - * maximum parallelism. This is one of the core functions in the - * Graph API in that enables neighborhood level computation. For - * example this function can be used to count neighbors satisfying a - * predicate or implement PageRank. - * - */ - @deprecated("use aggregateMessages", "1.2.0") - def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) - : VertexRDD[A] - /** * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala index 8ec33e140000e..ef0b943fc3c38 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.graphx +import scala.reflect.ClassTag + import org.apache.spark.SparkConf import org.apache.spark.graphx.impl._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap @@ -24,6 +26,7 @@ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.{BitSet, OpenHashSet} object GraphXUtils { + /** * Registers classes that GraphX uses with Kryo. */ @@ -42,4 +45,28 @@ object GraphXUtils { classOf[OpenHashSet[Int]], classOf[OpenHashSet[Long]])) } + + /** + * A proxy method to map the obsolete API to the new one. + */ + private[graphx] def mapReduceTriplets[VD: ClassTag, ED: ClassTag, A: ClassTag]( + g: Graph[VD, ED], + mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], + reduceFunc: (A, A) => A, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { + def sendMsg(ctx: EdgeContext[VD, ED, A]) { + mapFunc(ctx.toEdgeTriplet).foreach { kv => + val id = kv._1 + val msg = kv._2 + if (id == ctx.srcId) { + ctx.sendToSrc(msg) + } else { + assert(id == ctx.dstId) + ctx.sendToDst(msg) + } + } + } + g.aggregateMessagesWithActiveSet( + sendMsg, reduceFunc, TripletFields.All, activeSetOpt) + } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 796082721d696..3ba73b4c966f6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -121,7 +121,7 @@ object Pregel extends Logging { { var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() // compute the messages - var messages = g.mapReduceTriplets(sendMsg, mergeMsg) + var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) var activeMessages = messages.count() // Loop var prevG: Graph[VD, ED] = null @@ -135,8 +135,8 @@ object Pregel extends Logging { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. - messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 81182adbc6389..c5cb533b13a07 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -187,31 +187,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( // Lower level transformation methods // /////////////////////////////////////////////////////////////////////////////////////////////// - override def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { - - def sendMsg(ctx: EdgeContext[VD, ED, A]) { - mapFunc(ctx.toEdgeTriplet).foreach { kv => - val id = kv._1 - val msg = kv._2 - if (id == ctx.srcId) { - ctx.sendToSrc(msg) - } else { - assert(id == ctx.dstId) - ctx.sendToDst(msg) - } - } - } - - val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true) - - aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt) - } - override def aggregateMessagesWithActiveSet[A: ClassTag]( sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 16300e0740790..78a5cb057d14a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -39,17 +39,6 @@ object SVDPlusPlus { var gamma7: Double) extends Serializable - /** - * This method is now replaced by the updated version of `run()` and returns exactly - * the same result. - */ - @deprecated("Call run()", "1.4.0") - def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) - : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = - { - run(edges, conf) - } - /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 2fbc6f069d48d..f497e001dfa4f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -221,7 +221,8 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { val vertices: RDD[(VertexId, Int)] = sc.parallelize(Array((1L, 1), (2L, 2))) val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse - val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) + val result = GraphXUtils.mapReduceTriplets[Int, Int, Int]( + graph, et => Iterator((et.dstId, et.srcAttr)), _ + _) assert(result.collect().toSet === Set((1L, 2))) } } @@ -281,49 +282,6 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { } } - test("mapReduceTriplets") { - withSpark { sc => - val n = 5 - val star = starGraph(sc, n).mapVertices { (_, _) => 0 }.cache() - val starDeg = star.joinVertices(star.degrees){ (vid, oldV, deg) => deg } - val neighborDegreeSums = starDeg.mapReduceTriplets( - edge => Iterator((edge.srcId, edge.dstAttr), (edge.dstId, edge.srcAttr)), - (a: Int, b: Int) => a + b) - assert(neighborDegreeSums.collect().toSet === (0 to n).map(x => (x, n)).toSet) - - // activeSetOpt - val allPairs = for (x <- 1 to n; y <- 1 to n) yield (x: VertexId, y: VertexId) - val complete = Graph.fromEdgeTuples(sc.parallelize(allPairs, 3), 0) - val vids = complete.mapVertices((vid, attr) => vid).cache() - val active = vids.vertices.filter { case (vid, attr) => attr % 2 == 0 } - val numEvenNeighbors = vids.mapReduceTriplets(et => { - // Map function should only run on edges with destination in the active set - if (et.dstId % 2 != 0) { - throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) - } - Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet - assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) - - // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) - val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() - val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => - newOpt.getOrElse(old) - } - val numOddNeighbors = changedGraph.mapReduceTriplets(et => { - // Map function should only run on edges with source in the active set - if (et.srcId % 2 != 1) { - throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) - } - Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet - assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) - - } - } - test("aggregateMessages") { withSpark { sc => val n = 5 @@ -347,7 +305,8 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } - val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( + val neighborDegreeSums = GraphXUtils.mapReduceTriplets[Int, Int, Int]( + reverseStarDegrees, et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) @@ -420,7 +379,8 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), numEdgePartitions) val graph = Graph.fromEdgeTuples(edges, 1) - val neighborAttrSums = graph.mapReduceTriplets[Int]( + val neighborAttrSums = GraphXUtils.mapReduceTriplets[Int, Int, Int]( + graph, et => Iterator((et.dstId, et.srcAttr)), _ + _) assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) } finally { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6abab7f126500..65375a3ea787b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -242,6 +242,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") + ) ++ Seq( + // SPARK-12995 Remove deprecated APIs in graphx + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") ) case v if v.startsWith("1.6") => Seq( From adb548365012552e991d51740bfd3c25abf0adec Mon Sep 17 00:00:00 2001 From: JeremyNixon Date: Mon, 15 Feb 2016 09:25:13 +0000 Subject: [PATCH 1908/2089] [SPARK-13312][MLLIB] Update java train-validation-split example in ml-guide Response to JIRA https://issues.apache.org/jira/browse/SPARK-13312. This contribution is my original work and I license the work to this project. Author: JeremyNixon Closes #11199 from JeremyNixon/update_train_val_split_example. --- docs/ml-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index f8279262e673f..1770aabf6f5da 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -870,7 +870,7 @@ import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} // Prepare training and test data. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) val lr = new LinearRegression() @@ -913,7 +913,7 @@ import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.tuning.*; import org.apache.spark.sql.DataFrame; -DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_linear_regression_data.txt"); // Prepare training and test data. DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); From cbeb006f23838b2f19e700e20b25003aeb3dfb01 Mon Sep 17 00:00:00 2001 From: seddonm1 Date: Mon, 15 Feb 2016 20:15:27 -0800 Subject: [PATCH 1909/2089] [SPARK-13097][ML] Binarizer allowing Double AND Vector input types This enhancement extends the existing SparkML Binarizer [SPARK-5891] to allow Vector in addition to the existing Double input column type. A use case for this enhancement is for when a user wants to Binarize many similar feature columns at once using the same threshold value (for example a binary threshold applied to many pixels in an image). This contribution is my original work and I license the work to the project under the project's open source license. viirya mengxr Author: seddonm1 Closes #10976 from seddonm1/master. --- .../apache/spark/ml/feature/Binarizer.scala | 62 ++++++++++++++----- .../spark/ml/feature/BinarizerSuite.scala | 36 +++++++++++ 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 544cf05a30d48..2f8e3a0371a41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.feature +import scala.collection.mutable.ArrayBuilder + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types._ /** * :: Experimental :: @@ -62,28 +65,53 @@ final class Binarizer(override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { - transformSchema(dataset.schema, logging = true) + val outputSchema = transformSchema(dataset.schema, logging = true) + val schema = dataset.schema + val inputType = schema($(inputCol)).dataType val td = $(threshold) - val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } - val outputColName = $(outputCol) - val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata() - dataset.select(col("*"), - binarizer(col($(inputCol))).as(outputColName, metadata)) + + val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } + val binarizerVector = udf { (data: Vector) => + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + + data.foreachActive { (index, value) => + if (value > td) { + indices += index + values += 1.0 + } + } + + Vectors.sparse(data.size, indices.result(), values.result()).compressed + } + + val metadata = outputSchema($(outputCol)).metadata + + inputType match { + case DoubleType => + dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) + case _: VectorUDT => + dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) + } } override def transformSchema(schema: StructType): StructType = { - validateParams() - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - - val inputFields = schema.fields + val inputType = schema($(inputCol)).dataType val outputColName = $(outputCol) - require(inputFields.forall(_.name != outputColName), - s"Output column $outputColName already exists.") - - val attr = BinaryAttribute.defaultAttr.withName(outputColName) - val outputFields = inputFields :+ attr.toStructField() - StructType(outputFields) + val outCol: StructField = inputType match { + case DoubleType => + BinaryAttribute.defaultAttr.withName(outputColName).toStructField() + case _: VectorUDT => + new StructField(outputColName, new VectorUDT, true) + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ outCol) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 6d2d8fe714444..714b9db3aa19f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } + test("Binarize vector of continuous features with default parameter") { + val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("Binarize vector of continuous features with setter") { + val threshold: Double = 0.2 + val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(threshold) + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("read/write") { val t = new Binarizer() .setInputCol("myInputCol") From e4675c240255207c5dd812fa657e6aca2dc9cfeb Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Mon, 15 Feb 2016 20:17:21 -0800 Subject: [PATCH 1910/2089] [SPARK-13018][DOCS] Replace example code in mllib-pmml-model-export.md using include_example Replace example code in mllib-pmml-model-export.md using include_example https://issues.apache.org/jira/browse/SPARK-13018 The example code in the user guide is embedded in the markdown and hence it is not easy to test. It would be nice to automatically test them. This JIRA is to discuss options to automate example code testing and see what we can do in Spark 1.6. Goal is to move actual example code to spark/examples and test compilation in Jenkins builds. Then in the markdown, we can reference part of the code to show in the user guide. This requires adding a Jekyll tag that is similar to https://github.com/jekyll/jekyll/blob/master/lib/jekyll/tags/include.rb, e.g., called include_example. `{% include_example scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala %}` Jekyll will find `examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala` and pick code blocks marked "example" and replace code block in `{% highlight %}` in the markdown. See more sub-tasks in parent ticket: https://issues.apache.org/jira/browse/SPARK-11337 Author: Xin Ren Closes #11126 from keypointt/SPARK-13018. --- docs/mllib-pmml-model-export.md | 35 +---------- .../mllib/PMMLModelExportExample.scala | 59 +++++++++++++++++++ 2 files changed, 62 insertions(+), 32 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index b532ad907dfc5..58ed5a0e9d702 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -45,41 +45,12 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a
        To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. +As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats. + Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors) for details on the API. Here a complete example of building a KMeansModel and print it out in PMML format: -{% highlight scala %} -import org.apache.spark.mllib.clustering.KMeans -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() - -// Cluster the data into two classes using KMeans -val numClusters = 2 -val numIterations = 20 -val clusters = KMeans.train(parsedData, numClusters, numIterations) - -// Export to PMML -println("PMML Model:\n" + clusters.toPMML) -{% endhighlight %} - -As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats: - -{% highlight scala %} -// Export the model to a String in PMML format -clusters.toPMML - -// Export the model to a local file in PMML format -clusters.toPMML("/tmp/kmeans.xml") - -// Export the model to a directory on a distributed file system in PMML format -clusters.toPMML(sc,"/tmp/kmeans") - -// Export the model to the OutputStream in PMML format -clusters.toPMML(System.out) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala %} For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala new file mode 100644 index 0000000000000..d74d74a37fb11 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object PMMLModelExportExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PMMLModelExportExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/kmeans_data.txt") + val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + + // Cluster the data into two classes using KMeans + val numClusters = 2 + val numIterations = 20 + val clusters = KMeans.train(parsedData, numClusters, numIterations) + + // Export to PMML to a String in PMML format + println("PMML Model:\n" + clusters.toPMML) + + // Export the model to a local file in PMML format + clusters.toPMML("/tmp/kmeans.xml") + + // Export the model to a directory on a distributed file system in PMML format + clusters.toPMML(sc, "/tmp/kmeans") + + // Export the model to the OutputStream in PMML format + clusters.toPMML(System.out) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println From fee739f07b3bc37dd65682e93e60e0add848f583 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 15 Feb 2016 23:16:58 -0800 Subject: [PATCH 1911/2089] [SPARK-13221] [SQL] Fixing GroupingSets when Aggregate Functions Containing GroupBy Columns Using GroupingSets will generate a wrong result when Aggregate Functions containing GroupBy columns. This PR is to fix it. Since the code changes are very small. Maybe we also can merge it to 1.6 For example, the following query returns a wrong result: ```scala sql("select course, sum(earnings) as sum from courseSales group by course, earnings" + " grouping sets((), (course), (course, earnings))" + " order by course, sum").show() ``` Before the fix, the results are like ``` [null,null] [Java,null] [Java,20000.0] [Java,30000.0] [dotNET,null] [dotNET,5000.0] [dotNET,10000.0] [dotNET,48000.0] ``` After the fix, the results become correct: ``` [null,113000.0] [Java,20000.0] [Java,30000.0] [Java,50000.0] [dotNET,5000.0] [dotNET,10000.0] [dotNET,48000.0] [dotNET,63000.0] ``` UPDATE: This PR also deprecated the external column: GROUPING__ID. Author: gatorsmile Closes #11100 from gatorsmile/groupingSets. --- .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 34 ++++++ ...CUBE #1-0-63b61fb3f0e74226001ad279be440864 | 6 - ...CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 | 10 -- ...pingSet-0-8c14c24670a4b06c440346277ce9cf1c | 10 -- ...llup #1-0-a78e3dbf242f240249e36b3d3fd0926a | 6 - ...llup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 | 10 -- ...llup #3-0-9257085d123728730be96b6d9fbb84ce | 10 -- .../sql/hive/execution/HiveQuerySuite.scala | 54 --------- .../sql/hive/execution/SQLQuerySuite.scala | 110 ++++++++++++++++++ 10 files changed, 155 insertions(+), 107 deletions(-) delete mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 delete mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 delete mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c delete mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a delete mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 delete mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 26c3d286b19fa..004c1caaffec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -209,13 +209,23 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } + private def hasGroupingId(expr: Seq[Expression]): Boolean = { + expr.exists(_.collectFirst { + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u + }.isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) - case x: GroupingSets => + case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) => + failAnalysis( + s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead") + // Ensure all the expressions have been resolved. + case x: GroupingSets if x.expressions.forall(_.resolved) => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() // Expand works by setting grouping expressions to null as determined by the bitmasks. To diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f665a1c87bd78..b3e179755a19b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2040,6 +2040,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("grouping sets when aggregate functions containing groupBy columns") { + checkAnswer( + sql("select course, sum(earnings) as sum from courseSales group by course, earnings " + + "grouping sets((), (course), (course, earnings)) " + + "order by course, sum"), + Row(null, 113000.0) :: + Row("Java", 20000.0) :: + Row("Java", 30000.0) :: + Row("Java", 50000.0) :: + Row("dotNET", 5000.0) :: + Row("dotNET", 10000.0) :: + Row("dotNET", 48000.0) :: + Row("dotNET", 63000.0) :: Nil + ) + + checkAnswer( + sql("select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " + + "group by course, earnings grouping sets((), (course), (course, earnings)) " + + "order by course, sum"), + Row(null, 113000.0, 3) :: + Row("Java", 20000.0, 0) :: + Row("Java", 30000.0, 0) :: + Row("Java", 50000.0, 1) :: + Row("dotNET", 5000.0, 0) :: + Row("dotNET", 10000.0, 0) :: + Row("dotNET", 48000.0, 0) :: + Row("dotNET", 63000.0, 1) :: Nil + ) + } + test("cube") { checkAnswer( sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), @@ -2103,6 +2133,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("select course, year, grouping_id(course, year) from courseSales group by course, year") } assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year, grouping__id from courseSales group by cube(course, year)") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") } test("SPARK-13056: Null in map value causes NPE") { diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 deleted file mode 100644 index c066aeead822e..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 +++ /dev/null @@ -1,6 +0,0 @@ -500 NULL 1 -91 0 0 -84 1 0 -105 2 0 -113 3 0 -107 4 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 deleted file mode 100644 index c7cb747c0a659..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 +++ /dev/null @@ -1,10 +0,0 @@ -1 NULL -3 2 -1 NULL -1 2 -1 NULL 3 2 -1 NULL 4 2 -1 NULL 5 2 -1 NULL 6 2 -1 NULL 12 2 -1 NULL 14 2 -1 NULL 15 2 -1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c deleted file mode 100644 index c7cb747c0a659..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c +++ /dev/null @@ -1,10 +0,0 @@ -1 NULL -3 2 -1 NULL -1 2 -1 NULL 3 2 -1 NULL 4 2 -1 NULL 5 2 -1 NULL 6 2 -1 NULL 12 2 -1 NULL 14 2 -1 NULL 15 2 -1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a deleted file mode 100644 index c066aeead822e..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a +++ /dev/null @@ -1,6 +0,0 @@ -500 NULL 1 -91 0 0 -84 1 0 -105 2 0 -113 3 0 -107 4 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 deleted file mode 100644 index fcacbe3f69227..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 +++ /dev/null @@ -1,10 +0,0 @@ -1 0 5 0 -1 0 15 0 -1 0 25 0 -1 0 60 0 -1 0 75 0 -1 0 80 0 -1 0 100 0 -1 0 140 0 -1 0 145 0 -1 0 150 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce deleted file mode 100644 index fcacbe3f69227..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce +++ /dev/null @@ -1,10 +0,0 @@ -1 0 5 0 -1 0 15 0 -1 0 25 0 -1 0 60 0 -1 0 75 0 -1 0 80 0 -1 0 100 0 -1 0 140 0 -1 0 145 0 -1 0 150 0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 1337a25eb26a3..3208ebc9ffb2e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -123,60 +123,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertBroadcastNestedLoopJoin(spark_10484_4) } - createQueryTest("SPARK-8976 Wrong Result for Rollup #1", - """ - SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for Rollup #2", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM src group by key%5, key-5 - WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for Rollup #3", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 - WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for CUBE #1", - """ - SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for CUBE #2", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 - WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for GroupingSet", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 - GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - createQueryTest("insert table with generator with column name", """ | CREATE TABLE gen_tmp (key Int); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index be864f79d6b7e..c2b0daa66b013 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1551,6 +1551,116 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-8976 Wrong Result for Rollup #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + + test("SPARK-8976 Wrong Result for Rollup #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("SPARK-8976 Wrong Result for Rollup #3") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("SPARK-8976 Wrong Result for CUBE #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + + test("SPARK-8976 Wrong Result for CUBE #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("SPARK-8976 Wrong Result for GroupingSet") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-10562: partition by column with mixed case name") { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 827ed1c06785692d14857bd41f1fd94a0853874a Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 16 Feb 2016 13:01:21 +0000 Subject: [PATCH 1912/2089] Correct SparseVector.parse documentation There's a small typo in the SparseVector.parse docstring (which says that it returns a DenseVector rather than a SparseVector), which seems to be incorrect. Author: Miles Yucht Closes #11213 from mgyucht/fix-sparsevector-docs. --- python/pyspark/mllib/linalg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 131b855bf99c3..abf00a4737948 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -558,7 +558,7 @@ def __reduce__(self): @staticmethod def parse(s): """ - Parse string representation back into the DenseVector. + Parse string representation back into the SparseVector. >>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )') SparseVector(4, {0: 4.0, 1: 5.0}) From 00c72d27bf2e3591c4068fb344fa3edf1662ad81 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Tue, 16 Feb 2016 13:03:28 +0000 Subject: [PATCH 1913/2089] [SPARK-12247][ML][DOC] Documentation for spark.ml's ALS and collaborative filtering in general This documents the implementation of ALS in `spark.ml` with example code in scala, java and python. Author: BenFradet Closes #10411 from BenFradet/SPARK-12247. --- data/mllib/als/sample_movielens_movies.txt | 100 ---------- docs/_data/menu-ml.yaml | 2 + docs/ml-collaborative-filtering.md | 148 ++++++++++++++ docs/mllib-collaborative-filtering.md | 30 +-- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaALSExample.java | 125 ++++++++++++ examples/src/main/python/ml/als_example.py | 57 ++++++ .../apache/spark/examples/ml/ALSExample.scala | 82 ++++++++ .../spark/examples/ml/MovieLensALS.scala | 182 ------------------ .../apache/spark/ml/recommendation/ALS.scala | 2 +- 10 files changed, 431 insertions(+), 298 deletions(-) delete mode 100644 data/mllib/als/sample_movielens_movies.txt create mode 100644 docs/ml-collaborative-filtering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java create mode 100644 examples/src/main/python/ml/als_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala diff --git a/data/mllib/als/sample_movielens_movies.txt b/data/mllib/als/sample_movielens_movies.txt deleted file mode 100644 index 934a0253849e1..0000000000000 --- a/data/mllib/als/sample_movielens_movies.txt +++ /dev/null @@ -1,100 +0,0 @@ -0::Movie 0::Romance|Comedy -1::Movie 1::Action|Anime -2::Movie 2::Romance|Thriller -3::Movie 3::Action|Romance -4::Movie 4::Anime|Comedy -5::Movie 5::Action|Action -6::Movie 6::Action|Comedy -7::Movie 7::Anime|Comedy -8::Movie 8::Comedy|Action -9::Movie 9::Anime|Thriller -10::Movie 10::Action|Anime -11::Movie 11::Action|Anime -12::Movie 12::Anime|Comedy -13::Movie 13::Thriller|Action -14::Movie 14::Anime|Comedy -15::Movie 15::Comedy|Thriller -16::Movie 16::Anime|Romance -17::Movie 17::Thriller|Action -18::Movie 18::Action|Comedy -19::Movie 19::Anime|Romance -20::Movie 20::Action|Anime -21::Movie 21::Romance|Thriller -22::Movie 22::Romance|Romance -23::Movie 23::Comedy|Comedy -24::Movie 24::Anime|Action -25::Movie 25::Comedy|Comedy -26::Movie 26::Anime|Romance -27::Movie 27::Anime|Anime -28::Movie 28::Thriller|Anime -29::Movie 29::Anime|Romance -30::Movie 30::Thriller|Romance -31::Movie 31::Thriller|Romance -32::Movie 32::Comedy|Anime -33::Movie 33::Comedy|Comedy -34::Movie 34::Anime|Anime -35::Movie 35::Action|Thriller -36::Movie 36::Anime|Romance -37::Movie 37::Romance|Anime -38::Movie 38::Thriller|Romance -39::Movie 39::Romance|Comedy -40::Movie 40::Action|Anime -41::Movie 41::Comedy|Thriller -42::Movie 42::Comedy|Action -43::Movie 43::Thriller|Anime -44::Movie 44::Anime|Action -45::Movie 45::Comedy|Romance -46::Movie 46::Comedy|Action -47::Movie 47::Romance|Comedy -48::Movie 48::Action|Comedy -49::Movie 49::Romance|Romance -50::Movie 50::Comedy|Romance -51::Movie 51::Action|Action -52::Movie 52::Thriller|Action -53::Movie 53::Action|Action -54::Movie 54::Romance|Thriller -55::Movie 55::Anime|Romance -56::Movie 56::Comedy|Action -57::Movie 57::Action|Anime -58::Movie 58::Thriller|Romance -59::Movie 59::Thriller|Comedy -60::Movie 60::Anime|Comedy -61::Movie 61::Comedy|Action -62::Movie 62::Comedy|Romance -63::Movie 63::Romance|Thriller -64::Movie 64::Romance|Action -65::Movie 65::Anime|Romance -66::Movie 66::Comedy|Action -67::Movie 67::Thriller|Anime -68::Movie 68::Thriller|Romance -69::Movie 69::Action|Comedy -70::Movie 70::Thriller|Thriller -71::Movie 71::Action|Comedy -72::Movie 72::Thriller|Romance -73::Movie 73::Comedy|Action -74::Movie 74::Action|Action -75::Movie 75::Action|Action -76::Movie 76::Comedy|Comedy -77::Movie 77::Comedy|Comedy -78::Movie 78::Comedy|Comedy -79::Movie 79::Thriller|Thriller -80::Movie 80::Comedy|Anime -81::Movie 81::Comedy|Anime -82::Movie 82::Romance|Anime -83::Movie 83::Comedy|Thriller -84::Movie 84::Anime|Action -85::Movie 85::Thriller|Anime -86::Movie 86::Romance|Anime -87::Movie 87::Thriller|Thriller -88::Movie 88::Romance|Thriller -89::Movie 89::Action|Anime -90::Movie 90::Anime|Romance -91::Movie 91::Anime|Thriller -92::Movie 92::Action|Comedy -93::Movie 93::Romance|Thriller -94::Movie 94::Thriller|Comedy -95::Movie 95::Action|Action -96::Movie 96::Thriller|Romance -97::Movie 97::Thriller|Thriller -98::Movie 98::Thriller|Comedy -99::Movie 99::Thriller|Romance diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index 2eea9a917a4cc..3fd3ee2823f75 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -6,5 +6,7 @@ url: ml-classification-regression.html - text: Clustering url: ml-clustering.html +- text: Collaborative filtering + url: ml-collaborative-filtering.html - text: Advanced topics url: ml-advanced.html diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md new file mode 100644 index 0000000000000..4514a358e12f2 --- /dev/null +++ b/docs/ml-collaborative-filtering.md @@ -0,0 +1,148 @@ +--- +layout: global +title: Collaborative Filtering - spark.ml +displayTitle: Collaborative Filtering - spark.ml +--- + +* Table of contents +{:toc} + +## Collaborative filtering + +[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) +is commonly used for recommender systems. These techniques aim to fill in the +missing entries of a user-item association matrix. `spark.ml` currently supports +model-based collaborative filtering, in which users and products are described +by a small set of latent factors that can be used to predict missing entries. +`spark.ml` uses the [alternating least squares +(ALS)](http://dl.acm.org/citation.cfm?id=1608614) +algorithm to learn these latent factors. The implementation in `spark.ml` has the +following parameters: + +* *numBlocks* is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10). +* *rank* is the number of latent factors in the model (defaults to 10). +* *maxIter* is the maximum number of iterations to run (defaults to 10). +* *regParam* specifies the regularization parameter in ALS (defaults to 1.0). +* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for + *implicit feedback* data (defaults to `false` which means using *explicit feedback*). +* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the + *baseline* confidence in preference observations (defaults to 1.0). +* *nonnegative* specifies whether or not to use nonnegative constraints for least squares (defaults to `false`). + +### Explicit vs. implicit feedback + +The standard approach to matrix factorization based collaborative filtering treats +the entries in the user-item matrix as *explicit* preferences given by the user to the item, +for example, users giving ratings to movies. + +It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken +from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data +as numbers representing the *strength* in observations of user actions (such as the number of clicks, +or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of +confidence in observed user preferences, rather than explicit ratings given to items. The model +then tries to find latent factors that can be used to predict the expected preference of a user for +an item. + +### Scaling of the regularization parameter + +We scale the regularization parameter `regParam` in solving each least squares problem by +the number of ratings the user generated in updating user factors, +or the number of ratings the product received in updating product factors. +This approach is named "ALS-WR" and discussed in the paper +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +It makes `regParam` less dependent on the scale of the dataset, so we can apply the +best parameter learned from a sampled subset to the full dataset and expect similar performance. + +## Examples + +
        +
        + +In the following example, we load rating data from the +[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row +consisting of a user, a movie, a rating and a timestamp. +We then train an ALS model which assumes, by default, that the ratings are +explicit (`implicitPrefs` is `false`). +We evaluate the recommendation model by measuring the root-mean-square error of +rating prediction. + +Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.ml.recommendation.ALS) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ALSExample.scala %} + +If the rating matrix is derived from another source of information (i.e. it is +inferred from other signals), you can set `implicitPrefs` to `true` to get +better results: + +{% highlight scala %} +val als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setImplicitPrefs(true) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating") +{% endhighlight %} + +
        + +
        + +In the following example, we load rating data from the +[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row +consisting of a user, a movie, a rating and a timestamp. +We then train an ALS model which assumes, by default, that the ratings are +explicit (`implicitPrefs` is `false`). +We evaluate the recommendation model by measuring the root-mean-square error of +rating prediction. + +Refer to the [`ALS` Java docs](api/java/org/apache/spark/ml/recommendation/ALS.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaALSExample.java %} + +If the rating matrix is derived from another source of information (i.e. it is +inferred from other signals), you can set `implicitPrefs` to `true` to get +better results: + +{% highlight java %} +ALS als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setImplicitPrefs(true) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating"); +{% endhighlight %} + +
        + +
        + +In the following example, we load rating data from the +[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row +consisting of a user, a movie, a rating and a timestamp. +We then train an ALS model which assumes, by default, that the ratings are +explicit (`implicitPrefs` is `False`). +We evaluate the recommendation model by measuring the root-mean-square error of +rating prediction. + +Refer to the [`ALS` Python docs](api/python/pyspark.ml.html#pyspark.ml.recommendation.ALS) +for more details on the API. + +{% include_example python/ml/als_example.py %} + +If the rating matrix is derived from another source of information (i.e. it is +inferred from other signals), you can set `implicitPrefs` to `True` to get +better results: + +{% highlight python %} +als = ALS(maxIter=5, regParam=0.01, implicitPrefs=True, + userCol="userId", itemCol="movieId", ratingCol="rating") +{% endhighlight %} + +
        +
        diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 1ebb4654aef12..b8f0566d8735e 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -31,17 +31,18 @@ following parameters: ### Explicit vs. implicit feedback The standard approach to matrix factorization based collaborative filtering treats -the entries in the user-item matrix as *explicit* preferences given by the user to the item. +the entries in the user-item matrix as *explicit* preferences given by the user to the item, +for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken -from -[Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). -Essentially instead of trying to model the matrix of ratings directly, this approach treats the data -as a combination of binary preferences and *confidence values*. The ratings are then related to the -level of confidence in observed user preferences, rather than explicit ratings given to items. The -model then tries to find latent factors that can be used to predict the expected preference of a -user for an item. +from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data +as numbers representing the *strength* in observations of user actions (such as the number of clicks, +or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of +confidence in observed user preferences, rather than explicit ratings given to items. The model +then tries to find latent factors that can be used to predict the expected preference of a user for +an item. ### Scaling of the regularization parameter @@ -50,9 +51,8 @@ the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper "[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". -It makes `lambda` less dependent on the scale of the dataset. -So we can apply the best parameter learned from a sampled subset to the full dataset -and expect similar performance. +It makes `lambda` less dependent on the scale of the dataset, so we can apply the +best parameter learned from a sampled subset to the full dataset and expect similar performance. ## Examples @@ -64,11 +64,11 @@ We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.rec method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. -Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API. +Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for more details on the API. {% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %} -If the rating matrix is derived from another source of information (e.g., it is inferred from +If the rating matrix is derived from another source of information (i.e. it is inferred from other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} @@ -85,7 +85,7 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: -Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. +Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for more details on the API. {% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
        @@ -99,7 +99,7 @@ Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.rec {% include_example python/mllib/recommendation_example.py %} -If the rating matrix is derived from other source of information (i.e., it is inferred from other +If the rating matrix is derived from other source of information (i.e. it is inferred from other signals), you can use the trainImplicit method to get better results. {% highlight python %} diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 7ef91a178ccd1..fa5e90603505d 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -71,6 +71,7 @@ We list major functionality from both below, with links to detailed guides. * [Extracting, transforming and selecting features](ml-features.html) * [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) +* [Collaborative filtering](ml-collaborative-filtering.html) * [Advanced topics](ml-advanced.html) Some techniques are not available yet in spark.ml, most notably dimensionality reduction diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java new file mode 100644 index 0000000000000..90d2ac2b13bda --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.io.Serializable; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.recommendation.ALS; +import org.apache.spark.ml.recommendation.ALSModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.types.DataTypes; +// $example off$ + +public class JavaALSExample { + + // $example on$ + public static class Rating implements Serializable { + private int userId; + private int movieId; + private float rating; + private long timestamp; + + public Rating() {} + + public Rating(int userId, int movieId, float rating, long timestamp) { + this.userId = userId; + this.movieId = movieId; + this.rating = rating; + this.timestamp = timestamp; + } + + public int getUserId() { + return userId; + } + + public int getMovieId() { + return movieId; + } + + public float getRating() { + return rating; + } + + public long getTimestamp() { + return timestamp; + } + + public static Rating parseRating(String str) { + String[] fields = str.split("::"); + if (fields.length != 4) { + throw new IllegalArgumentException("Each line must contain 4 fields"); + } + int userId = Integer.parseInt(fields[0]); + int movieId = Integer.parseInt(fields[1]); + float rating = Float.parseFloat(fields[2]); + long timestamp = Long.parseLong(fields[3]); + return new Rating(userId, movieId, rating, timestamp); + } + } + // $example off$ + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaALSExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt") + .map(new Function() { + public Rating call(String str) { + return Rating.parseRating(str); + } + }); + DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class); + DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + // Build the recommendation model using ALS on the training data + ALS als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating"); + ALSModel model = als.fit(training); + + // Evaluate the model by computing the RMSE on the test data + DataFrame rawPredictions = model.transform(test); + DataFrame predictions = rawPredictions + .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType)) + .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType)); + + RegressionEvaluator evaluator = new RegressionEvaluator() + .setMetricName("rmse") + .setLabelCol("rating") + .setPredictionCol("prediction"); + Double rmse = evaluator.evaluate(predictions); + System.out.println("Root-mean-square error = " + rmse); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py new file mode 100644 index 0000000000000..f61c8ab5d6328 --- /dev/null +++ b/examples/src/main/python/ml/als_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext + +# $example on$ +import math + +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.recommendation import ALS +from pyspark.sql import Row +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ALSExample") + sqlContext = SQLContext(sc) + + # $example on$ + lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt") + parts = lines.map(lambda l: l.split("::")) + ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]), + rating=float(p[2]), timestamp=long(p[3]))) + ratings = sqlContext.createDataFrame(ratingsRDD) + (training, test) = ratings.randomSplit([0.8, 0.2]) + + # Build the recommendation model using ALS on the training data + als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating") + model = als.fit(training) + + # Evaluate the model by computing the RMSE on the test data + rawPredictions = model.transform(test) + predictions = rawPredictions\ + .withColumn("rating", rawPredictions.rating.cast("double"))\ + .withColumn("prediction", rawPredictions.prediction.cast("double")) + evaluator =\ + RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction") + rmse = evaluator.evaluate(predictions) + print("Root-mean-square error = " + str(rmse)) + # $example off$ + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala new file mode 100644 index 0000000000000..a79e15c767e1f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.recommendation.ALS +// $example off$ +import org.apache.spark.sql.SQLContext +// $example on$ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType +// $example off$ + +object ALSExample { + + // $example on$ + case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) + object Rating { + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) + } + } + // $example off$ + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ALSExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val ratings = sc.textFile("data/mllib/als/sample_movielens_ratings.txt") + .map(Rating.parseRating) + .toDF() + val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2)) + + // Build the recommendation model using ALS on the training data + val als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating") + val model = als.fit(training) + + // Evaluate the model by computing the RMSE on the test data + val predictions = model.transform(test) + .withColumn("rating", col("rating").cast(DoubleType)) + .withColumn("prediction", col("prediction").cast(DoubleType)) + + val evaluator = new RegressionEvaluator() + .setMetricName("rmse") + .setLabelCol("rating") + .setPredictionCol("prediction") + val rmse = evaluator.evaluate(predictions) + println(s"Root-mean-square error = $rmse") + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala deleted file mode 100644 index 02ed746954f23..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.recommendation.ALS -import org.apache.spark.sql.{Row, SQLContext} - -/** - * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). - * Run with - * {{{ - * bin/run-example ml.MovieLensALS - * }}} - */ -object MovieLensALS { - - case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) - - object Rating { - def parseRating(str: String): Rating = { - val fields = str.split("::") - assert(fields.size == 4) - Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) - } - } - - case class Movie(movieId: Int, title: String, genres: Seq[String]) - - object Movie { - def parseMovie(str: String): Movie = { - val fields = str.split("::") - assert(fields.size == 3) - Movie(fields(0).toInt, fields(1), fields(2).split("\\|")) - } - } - - case class Params( - ratings: String = null, - movies: String = null, - maxIter: Int = 10, - regParam: Double = 0.1, - rank: Int = 10, - numBlocks: Int = 10) extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("MovieLensALS") { - head("MovieLensALS: an example app for ALS on MovieLens data.") - opt[String]("ratings") - .required() - .text("path to a MovieLens dataset of ratings") - .action((x, c) => c.copy(ratings = x)) - opt[String]("movies") - .required() - .text("path to a MovieLens dataset of movies") - .action((x, c) => c.copy(movies = x)) - opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}") - .action((x, c) => c.copy(rank = x)) - opt[Int]("maxIter") - .text(s"max number of iterations, default: ${defaultParams.maxIter}") - .action((x, c) => c.copy(maxIter = x)) - opt[Double]("regParam") - .text(s"regularization parameter, default: ${defaultParams.regParam}") - .action((x, c) => c.copy(regParam = x)) - opt[Int]("numBlocks") - .text(s"number of blocks, default: ${defaultParams.numBlocks}") - .action((x, c) => c.copy(numBlocks = x)) - note( - """ - |Example command line to run this app: - | - | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ - | examples/target/scala-*/spark-examples-*.jar \ - | --rank 10 --maxIter 15 --regParam 0.1 \ - | --movies data/mllib/als/sample_movielens_movies.txt \ - | --ratings data/mllib/als/sample_movielens_ratings.txt - """.stripMargin) - } - - parser.parse(args, defaultParams).map { params => - run(params) - } getOrElse { - System.exit(1) - } - } - - def run(params: Params) { - val conf = new SparkConf().setAppName(s"MovieLensALS with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - - val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() - - val numRatings = ratings.count() - val numUsers = ratings.map(_.userId).distinct().count() - val numMovies = ratings.map(_.movieId).distinct().count() - - println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - - val splits = ratings.randomSplit(Array(0.8, 0.2), 0L) - val training = splits(0).cache() - val test = splits(1).cache() - - val numTraining = training.count() - val numTest = test.count() - println(s"Training: $numTraining, test: $numTest.") - - ratings.unpersist(blocking = false) - - val als = new ALS() - .setUserCol("userId") - .setItemCol("movieId") - .setRank(params.rank) - .setMaxIter(params.maxIter) - .setRegParam(params.regParam) - .setNumBlocks(params.numBlocks) - - val model = als.fit(training.toDF()) - - val predictions = model.transform(test.toDF()).cache() - - // Evaluate the model. - // TODO: Create an evaluator to compute RMSE. - val mse = predictions.select("rating", "prediction").rdd - .flatMap { case Row(rating: Float, prediction: Float) => - val err = rating.toDouble - prediction - val err2 = err * err - if (err2.isNaN) { - None - } else { - Some(err2) - } - }.mean() - val rmse = math.sqrt(mse) - println(s"Test RMSE = $rmse.") - - // Inspect false positives. - // Note: We reference columns in 2 ways: - // (1) predictions("movieId") lets us specify the movieId column in the predictions - // DataFrame, rather than the movieId column in the movies DataFrame. - // (2) $"userId" specifies the userId column in the predictions DataFrame. - // We could also write predictions("userId") but do not have to since - // the movies DataFrame does not have a column "userId." - val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF() - val falsePositives = predictions.join(movies) - .where((predictions("movieId") === movies("movieId")) - && ($"rating" <= 1) && ($"prediction" >= 4)) - .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction") - val numFalsePositives = falsePositives.count() - println(s"Found $numFalsePositives false positives") - if (numFalsePositives > 0) { - println(s"Example false positives:") - falsePositives.limit(100).collect().foreach(println) - } - - sc.stop() - } -} -// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1481c8268a198..8f49423af89ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -496,7 +496,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } /** - * Solves a nonnegative least squares problem with L2 regularizatin: + * Solves a nonnegative least squares problem with L2 regularization: * * min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^ * subject to x >= 0 From 19dc69de795eb08f3bab4988ad88732bf8ca7bae Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 16 Feb 2016 10:54:44 -0800 Subject: [PATCH 1914/2089] [SPARK-12976][SQL] Add LazilyGenerateOrdering and use it for RangePartitioner of Exchange. Add `LazilyGenerateOrdering` to support generated ordering for `RangePartitioner` of `Exchange` instead of `InterpretedOrdering`. Author: Takuya UESHIN Closes #10894 from ueshin/issues/SPARK-12976. --- .../codegen/GenerateOrdering.scala | 37 +++++++++++++++++++ .../apache/spark/sql/execution/Exchange.scala | 6 +-- .../apache/spark/sql/execution/limit.scala | 7 ++-- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 6de57537ec078..5756f201fd60f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import java.io.ObjectInputStream + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Inherits some default implementation for Java from `Ordering[Row]` @@ -138,3 +141,37 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } + +/** + * A lazily generated row ordering comparator. + */ +class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) extends Ordering[InternalRow] { + + def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = + this(ordering.map(BindReferences.bindReference(_, inputSchema))) + + @transient + private[this] var generatedOrdering = GenerateOrdering.generate(ordering) + + def compare(a: InternalRow, b: InternalRow): Int = { + generatedOrdering.compare(a, b) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + generatedOrdering = GenerateOrdering.generate(ordering) + } +} + +object LazilyGeneratedOrdering { + + /** + * Creates a [[LazilyGeneratedOrdering]] for the given schema, in natural ascending order. + */ + def forSchema(schema: StructType): LazilyGeneratedOrdering = { + new LazilyGeneratedOrdering(schema.zipWithIndex.map { + case (field, ordinal) => + SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending) + }) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 97f65f18bfdcc..e30adefc69ece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -206,10 +207,7 @@ object Exchange { val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } - // We need to use an interpreted ordering here because generated orderings cannot be - // serialized and this ordering needs to be created on the driver in order to be passed into - // Spark core code. - implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes) + implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 04daf9d0ce2a6..ef76847bcb2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ @@ -88,11 +89,8 @@ case class TakeOrderedAndProject( override def outputPartitioning: Partitioning = SinglePartition - // We need to use an interpreted ordering here because generated orderings cannot be serialized - // and this ordering needs to be created on the driver in order to be passed into Spark core code. - private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - override def executeCollect(): Array[InternalRow] = { + val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) if (projectList.isDefined) { val proj = UnsafeProjection.create(projectList.get, child.output) @@ -105,6 +103,7 @@ case class TakeOrderedAndProject( private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) protected override def doExecute(): RDD[InternalRow] = { + val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val localTopK: RDD[InternalRow] = { child.execute().map(_.copy()).mapPartitions { iter => org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord) From c7d00a24da317c9601a9239ac1cf185fb6647352 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 16 Feb 2016 11:25:43 -0800 Subject: [PATCH 1915/2089] [SPARK-13280][STREAMING] Use a better logger name for FileBasedWriteAheadLog. The new logger name is under the org.apache.spark namespace. The detection of the caller name was also enhanced a bit to ignore some common things that show up in the call stack. Author: Marcelo Vanzin Closes #11165 from vanzin/SPARK-13280. --- .../util/FileBasedWriteAheadLog.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 15ad2e27d372f..314263f26ee60 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -57,12 +57,18 @@ private[streaming] class FileBasedWriteAheadLog( import FileBasedWriteAheadLog._ private val pastLogs = new ArrayBuffer[LogInfo] - private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") + private val callerName = getCallerName - private val threadpoolName = s"WriteAheadLogManager $callerNameTag" + private val threadpoolName = { + "WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("") + } private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20) private val executionContext = ExecutionContext.fromExecutorService(threadpool) - override protected val logName = s"WriteAheadLogManager $callerNameTag" + + override protected def logName = { + getClass.getName.stripSuffix("$") + + callerName.map("_" + _).getOrElse("").replaceAll("[ ]", "_") + } private var currentLogPath: Option[String] = None private var currentLogWriter: FileBasedWriteAheadLogWriter = null @@ -253,8 +259,12 @@ private[streaming] object FileBasedWriteAheadLog { } def getCallerName(): Option[String] = { - val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) - stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split("\\.").lastOption) + val blacklist = Seq("WriteAheadLog", "Logging", "java.lang", "scala.") + Thread.currentThread.getStackTrace() + .map(_.getClassName) + .find { c => !blacklist.exists(c.contains) } + .flatMap(_.split("\\.").lastOption) + .flatMap(_.split("\\$\\$").headOption) } /** Convert a sequence of files to a sequence of sorted LogInfo objects */ From 5f37aad48cb729a80c4cc25347460f12aafec9fb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 16 Feb 2016 12:06:30 -0800 Subject: [PATCH 1916/2089] [SPARK-13308] ManagedBuffers passed to OneToOneStreamManager need to be freed in non-error cases ManagedBuffers that are passed to `OneToOneStreamManager.registerStream` need to be freed by the manager once it's done using them. However, the current code only frees them in certain error-cases and not during typical operation. This isn't a major problem today, but it will cause memory leaks after we implement better locking / pinning in the BlockManager (see #10705). This patch modifies the relevant network code so that the ManagedBuffers are freed as soon as the messages containing them are processed by the lower-level Netty message sending code. /cc zsxwing for review. Author: Josh Rosen Closes #11193 from JoshRosen/add-missing-release-calls-in-network-layer. --- .../spark/network/buffer/ManagedBuffer.java | 6 ++- .../network/buffer/NettyManagedBuffer.java | 2 +- .../network/protocol/MessageEncoder.java | 7 ++- .../network/protocol/MessageWithHeader.java | 28 ++++++++++- .../server/OneForOneStreamManager.java | 1 - .../protocol/MessageWithHeaderSuite.java | 34 +++++++++++-- .../server/OneForOneStreamManagerSuite.java | 50 +++++++++++++++++++ 7 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index a415db593a788..1861f8d7fd8f3 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -65,7 +65,11 @@ public abstract class ManagedBuffer { public abstract ManagedBuffer release(); /** - * Convert the buffer into an Netty object, used to write the data out. + * Convert the buffer into an Netty object, used to write the data out. The return value is either + * a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}. + * + * If this method returns a ByteBuf, then that buffer's reference count will be incremented and + * the caller will be responsible for releasing this new reference. */ public abstract Object convertToNetty() throws IOException; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index c806bfa45bef3..4c8802af7ae67 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -64,7 +64,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - return buf.duplicate(); + return buf.duplicate().retain(); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index abca22347b783..664df57feca4f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -54,6 +54,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro body = in.body().convertToNetty(); isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { + in.body().release(); if (in instanceof AbstractResponseMessage) { AbstractResponseMessage resp = (AbstractResponseMessage) in; // Re-encode this message as a failure response. @@ -80,8 +81,10 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro in.encode(header); assert header.writableBytes() == 0; - if (body != null && bodyLength > 0) { - out.add(new MessageWithHeader(header, body, bodyLength)); + if (body != null) { + // We transfer ownership of the reference on in.body() to MessageWithHeader. + // This reference will be freed when MessageWithHeader.deallocate() is called. + out.add(new MessageWithHeader(in.body(), header, body, bodyLength)); } else { out.add(header); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index d686a951467cf..66227f96a1a21 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.channels.WritableByteChannel; +import javax.annotation.Nullable; import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; @@ -26,6 +27,8 @@ import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; +import org.apache.spark.network.buffer.ManagedBuffer; + /** * A wrapper message that holds two separate pieces (a header and a body). * @@ -33,15 +36,35 @@ */ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + @Nullable private final ManagedBuffer managedBuffer; private final ByteBuf header; private final int headerLength; private final Object body; private final long bodyLength; private long totalBytesTransferred; - MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + /** + * Construct a new MessageWithHeader. + * + * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to + * be passed in so that the buffer can be freed when this message is + * deallocated. Ownership of the caller's reference to this buffer is + * transferred to this class, so if the caller wants to continue to use the + * ManagedBuffer in other messages then they will need to call retain() on + * it before passing it to this constructor. This may be null if and only if + * `body` is a {@link FileRegion}. + * @param header the message header. + * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}. + * @param bodyLength the length of the message body, in bytes. + */ + MessageWithHeader( + @Nullable ManagedBuffer managedBuffer, + ByteBuf header, + Object body, + long bodyLength) { Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, "Body must be a ByteBuf or a FileRegion."); + this.managedBuffer = managedBuffer; this.header = header; this.headerLength = header.readableBytes(); this.body = body; @@ -99,6 +122,9 @@ public long transferTo(final WritableByteChannel target, final long position) th protected void deallocate() { header.release(); ReferenceCountUtil.release(body); + if (managedBuffer != null) { + managedBuffer.release(); + } } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index e671854da1cae..ea9e735e0a173 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,7 +20,6 @@ import java.util.Iterator; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index 6c98e733b462f..fbbe4b7014ff2 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -26,9 +26,13 @@ import io.netty.channel.FileRegion; import io.netty.util.AbstractReferenceCounted; import org.junit.Test; +import org.mockito.Mockito; import static org.junit.Assert.*; +import org.apache.spark.network.TestManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.util.ByteArrayWritableChannel; public class MessageWithHeaderSuite { @@ -46,20 +50,43 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { ByteBuf header = Unpooled.copyLong(42); - ByteBuf body = Unpooled.copyLong(84); - MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); + assertEquals(1, header.refCnt()); + assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); + ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer); + Object body = managedBuf.convertToNetty(); + assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt()); + assertEquals(1, header.refCnt()); + + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); assertEquals(42, result.readLong()); assertEquals(84, result.readLong()); + + assert(msg.release()); + assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); + assertEquals(0, header.refCnt()); + } + + @Test + public void testDeallocateReleasesManagedBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); + ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); + assertEquals(2, body.refCnt()); + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); + assert(msg.release()); + Mockito.verify(managedBuf, Mockito.times(1)).release(); + assertEquals(0, body.refCnt()); } private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { ByteBuf header = Unpooled.copyLong(42); int headerLength = header.readableBytes(); TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); - MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count()); ByteBuf result = doWrite(msg, totalWrites / writesPerCall); assertEquals(headerLength + region.count(), result.readableBytes()); @@ -67,6 +94,7 @@ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Excep for (long i = 0; i < 8; i++) { assertEquals(i, result.readLong()); } + assert(msg.release()); } private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java new file mode 100644 index 0000000000000..c647525d8f1bd --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import org.junit.Test; +import org.mockito.Mockito; + +import org.apache.spark.network.TestManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; + +public class OneForOneStreamManagerSuite { + + @Test + public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { + OneForOneStreamManager manager = new OneForOneStreamManager(); + List buffers = new ArrayList<>(); + TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10)); + TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); + buffers.add(buffer1); + buffers.add(buffer2); + long streamId = manager.registerStream("appId", buffers.iterator()); + + Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); + manager.registerChannel(dummyChannel, streamId); + + manager.connectionTerminated(dummyChannel); + + Mockito.verify(buffer1, Mockito.times(1)).release(); + Mockito.verify(buffer2, Mockito.times(1)).release(); + } +} From 7218c0eba957e0a079a407b79c3a050cce9647b2 Mon Sep 17 00:00:00 2001 From: junhao Date: Tue, 16 Feb 2016 19:43:17 -0800 Subject: [PATCH 1917/2089] [SPARK-11627] Add initial input rate limit for spark streaming backpressure mechanism. https://issues.apache.org/jira/browse/SPARK-11627 Spark Streaming backpressure mechanism has no initial input rate limit, it might cause OOM exception. In the firest batch task ,receivers receive data at the maximum speed they can reach,it might exhaust executors memory resources. Add a initial input rate limit value can make sure the Streaming job execute success in the first batch,then the backpressure mechanism can adjust receiving rate adaptively. Author: junhao Closes #9593 from junhaoMg/junhao-dev. --- docs/configuration.md | 8 ++++++++ .../apache/spark/streaming/receiver/RateLimiter.scala | 9 ++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 0dbfe3b0796ba..a2c0dfe76ca45 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1470,6 +1470,14 @@ Apart from these, the following properties are also available, and may be useful if they are set (see below). + + spark.streaming.backpressure.initialRate + not set + + This is the initial maximum receiving rate at which each receiver will receive data for the + first batch when the backpressure mechanism is enabled. + + spark.streaming.blockInterval 200ms diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index bca1fbc8fda2f..6a1b672220bd2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -36,7 +36,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { // treated as an upper limit private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue) - private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble) + private lazy val rateLimiter = GuavaRateLimiter.create(getInitialRateLimit().toDouble) def waitToPush() { rateLimiter.acquire() @@ -61,4 +61,11 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { rateLimiter.setRate(newRate) } } + + /** + * Get the initial rateLimit to initial rateLimiter + */ + private def getInitialRateLimit(): Long = { + math.min(conf.getLong("spark.streaming.backpressure.initialRate", maxRateLimit), maxRateLimit) + } } From 1e1e31e03df14f2e7a9654e640fb2796cf059fe0 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Tue, 16 Feb 2016 22:27:34 -0800 Subject: [PATCH 1918/2089] [SPARK-13279] Remove O(n^2) operation from scheduler. This commit removes an unnecessary duplicate check in addPendingTask that meant that scheduling a task set took time proportional to (# tasks)^2. Author: Sital Kedia Closes #11175 from sitalkedia/fix_stuck_driver. --- .../spark/scheduler/TaskSetManager.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index cf97877476d54..05cfa52f16b3e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -114,9 +114,14 @@ private[spark] class TaskSetManager( // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. + // back at the head of the stack. These collections may contain duplicates + // for two reasons: + // (1): Tasks are only removed lazily; when a task is launched, it remains + // in all the pending lists except the one that it was launched from. + // (2): Tasks may be re-added to these lists multiple times as a result + // of failures. + // Duplicates are handled in dequeueTaskFromList, which ensures that a + // task hasn't already started running before launching it. private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each host. Similar to pendingTasksForExecutor, @@ -179,23 +184,16 @@ private[spark] class TaskSetManager( /** Add a task to all the pending-task lists that it should be on. */ private def addPendingTask(index: Int) { - // Utility method that adds `index` to a list only if it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!list.contains(index)) { - list += index - } - } - for (loc <- tasks(index).preferredLocations) { loc match { case e: ExecutorCacheTaskLocation => - addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer)) + pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index case e: HDFSCacheTaskLocation => { val exe = sched.getExecutorsAliveOnHost(loc.host) exe match { case Some(set) => { for (e <- set) { - addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer)) + pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index } logInfo(s"Pending task $index has a cached location at ${e.host} " + ", where there are executors " + set.mkString(",")) @@ -206,14 +204,14 @@ private[spark] class TaskSetManager( } case _ => Unit } - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) + pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index } } if (tasks(index).preferredLocations == Nil) { - addTo(pendingTasksWithNoPrefs) + pendingTasksWithNoPrefs += index } allPendingTasks += index // No point scanning this whole list to find the old task there From 04e8afe362f3bf36097688f9574c66002535ac0f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 17 Feb 2016 00:21:15 -0800 Subject: [PATCH 1919/2089] [SPARK-13357][SQL] Use generated projection and ordering for TakeOrderedAndProjectNode `TakeOrderedAndProjectNode` should use generated projection and ordering like other `LocalNode`s. Author: Takuya UESHIN Closes #11230 from ueshin/issues/SPARK-13357. --- .../sql/execution/local/TakeOrderedAndProjectNode.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index ae672fbca8d83..ca68b7677ce83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.util.BoundedPriorityQueue case class TakeOrderedAndProjectNode( @@ -30,7 +31,7 @@ case class TakeOrderedAndProjectNode( child: LocalNode) extends UnaryLocalNode(conf) { private[this] var projection: Option[Projection] = _ - private[this] var ord: InterpretedOrdering = _ + private[this] var ord: Ordering[InternalRow] = _ private[this] var iterator: Iterator[InternalRow] = _ private[this] var currentRow: InternalRow = _ @@ -41,8 +42,8 @@ case class TakeOrderedAndProjectNode( override def open(): Unit = { child.open() - projection = projectList.map(new InterpretedProjection(_, child.output)) - ord = new InterpretedOrdering(sortOrder, child.output) + projection = projectList.map(UnsafeProjection.create(_, child.output)) + ord = GenerateOrdering.generate(sortOrder, child.output) // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) while (child.next()) { From a7c74d7563926573c01baf613708a0f105a03e57 Mon Sep 17 00:00:00 2001 From: "Christopher C. Aycock" Date: Wed, 17 Feb 2016 11:24:18 -0800 Subject: [PATCH 1920/2089] [SPARK-13350][DOCS] Config doc updated to state that PYSPARK_PYTHON's default is "python2.7" Author: Christopher C. Aycock Closes #11239 from chrisaycock/master. --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index a2c0dfe76ca45..f2443e98574cf 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1655,7 +1655,7 @@ The following variables can be set in `spark-env.sh`: PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is python). + Python binary executable to use for PySpark in both driver and workers (default is python2.7 if available, otherwise python). PYSPARK_DRIVER_PYTHON From 1eac3800087a804c4b58a4af8b65eeb0879b6528 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 17 Feb 2016 15:05:39 -0800 Subject: [PATCH 1921/2089] [SPARK-13109][BUILD] Fix SBT publishLocal issue Add local ivy repo to the SBT build file to fix this. Scaladoc compile error is fixed. Author: jerryshao Closes #11001 from jerryshao/SPARK-13109. --- project/SparkBuild.scala | 3 ++- .../spark/sql/execution/vectorized/OffHeapColumnVector.java | 1 + .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 646efb4d09301..1ba6a075134c1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -155,7 +155,8 @@ object SparkBuild extends PomBuild { // Override SBT's default resolvers: resolvers := Seq( DefaultMavenRepository, - Resolver.mavenLocal + Resolver.mavenLocal, + Resolver.file("local", file(Path.userHome.absolutePath + "/.ivy2/local"))(Resolver.ivyStylePatterns) ), externalResolvers := resolvers.value, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index c15f3d34a4929..e38ed051219b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,6 +19,7 @@ import java.nio.ByteOrder; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.BooleanType; import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 99548bc83beaa..3502d31bd1dfa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.vectorized; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; From 97ee85daf68345cf5c3c11ae5bf288cc697bdf9e Mon Sep 17 00:00:00 2001 From: shijinkui Date: Wed, 17 Feb 2016 15:08:22 -0800 Subject: [PATCH 1922/2089] [SPARK-12953][EXAMPLES] RDDRelation writer set overwrite mode https://issues.apache.org/jira/browse/SPARK-12953 fix error when run RDDRelation.main(): "path file:/Users/sjk/pair.parquet already exists" Set DataFrameWriter's mode to SaveMode.Overwrite Author: shijinkui Closes #10864 from shijinkui/set_mode. --- .../scala/org/apache/spark/examples/sql/RDDRelation.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2cc56f04e5c1f..a2f0fcd0e486e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,8 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SaveMode, SQLContext} // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -58,8 +57,8 @@ object RDDRelation { // Queries can also be written using a LINQ-like Scala DSL. df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) - // Write out an RDD as a parquet file. - df.write.parquet("pair.parquet") + // Write out an RDD as a parquet file with overwrite mode. + df.write.mode(SaveMode.Overwrite).parquet("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. val parquetFile = sqlContext.read.parquet("pair.parquet") From 9451fed52cb8a00c706b582a0b51d8cd832f9350 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 17 Feb 2016 16:17:20 -0800 Subject: [PATCH 1923/2089] [SPARK-13344][TEST] Fix harmless accumulator not found exceptions See [JIRA](https://issues.apache.org/jira/browse/SPARK-13344) for more detail. This was caused by #10835. Author: Andrew Or Closes #11222 from andrewor14/fix-test-accum-exceptions. --- .../org/apache/spark/AccumulatorSuite.scala | 8 ++++++++ .../spark/InternalAccumulatorSuite.scala | 8 ++++++++ .../scala/org/apache/spark/SparkFunSuite.scala | 18 ++++++++++++++---- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 4d49fe5159850..8acd0439b6900 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -34,6 +34,14 @@ import org.apache.spark.serializer.JavaSerializer class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { import AccumulatorParam._ + override def afterEach(): Unit = { + try { + Accumulators.clear() + } finally { + super.afterEach() + } + } + implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index c426bb7a4e809..474550608ba2f 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -28,6 +28,14 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { import InternalAccumulator._ import AccumulatorParam._ + override def afterEach(): Unit = { + try { + Accumulators.clear() + } finally { + super.afterEach() + } + } + test("get param") { assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam) assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index d3359c7406e45..99366a32c4e16 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -18,14 +18,26 @@ package org.apache.spark // scalastyle:off -import org.scalatest.{FunSuite, Outcome} +import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} /** * Base abstract class for all unit tests in Spark for handling common functionality. */ -private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +private[spark] abstract class SparkFunSuite + extends FunSuite + with BeforeAndAfterAll + with Logging { // scalastyle:on + protected override def afterAll(): Unit = { + try { + // Avoid leaking map entries in tests that use accumulators without SparkContext + Accumulators.clear() + } finally { + super.afterAll() + } + } + /** * Log the suite name and the test name before and after each test. * @@ -42,8 +54,6 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { test() } finally { logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") - // Avoid leaking map entries in tests that use accumulators without SparkContext - Accumulators.clear() } } From 0088b252bf56d391025a19b70af03f3ce6d7574c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 17 Feb 2016 18:56:19 -0800 Subject: [PATCH 1924/2089] [MINOR][MLLIB] fix mllib compile warnings This PR fixes some warnings found by `build/sbt mllib/test:compile`. Author: Xiangrui Meng Closes #11227 from mengxr/fix-mllib-warnings-201602. --- .../spark/ml/classification/DecisionTreeClassifierSuite.scala | 4 ++++ .../test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala | 2 ++ 2 files changed, 6 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index baf6b9083900f..9169bcd390dce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -305,7 +305,11 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) + case other => + fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") } + case other => + fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index b9e997c207bc7..dc44c58e97eb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.fpm +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils From b84404865b90fc07858a8b28c2a7028f96517167 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 17 Feb 2016 19:03:29 -0800 Subject: [PATCH 1925/2089] [SPARK-13324][CORE][BUILD] Update plugin, test, example dependencies for 2.x Phase 1: update plugin versions, test dependencies, some example and third-party versions Author: Sean Owen Closes #11206 from srowen/SPARK-13324. --- build/mvn | 2 +- dev/deps/spark-deps-hadoop-2.2 | 8 ++++---- dev/deps/spark-deps-hadoop-2.3 | 8 ++++---- dev/deps/spark-deps-hadoop-2.4 | 8 ++++---- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- docs/building-spark.md | 2 +- examples/pom.xml | 6 +++--- pom.xml | 34 +++++++++++++++++----------------- 9 files changed, 42 insertions(+), 42 deletions(-) diff --git a/build/mvn b/build/mvn index 63ca9c98067d0..58058c04b8911 100755 --- a/build/mvn +++ b/build/mvn @@ -69,7 +69,7 @@ install_app() { # Install maven under the build/ folder install_mvn() { - local MVN_VERSION="3.3.3" + local MVN_VERSION="3.3.9" install_app \ "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 3a14499d9b4d9..47482fd3adeeb 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -96,7 +96,7 @@ javax.servlet-api-3.0.1.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar jaxb-impl-2.2.3-1.jar -jcl-over-slf4j-1.7.10.jar +jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar jersey-client-1.9.jar jersey-core-1.9.jar @@ -121,7 +121,7 @@ json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar -jul-to-slf4j-1.7.10.jar +jul-to-slf4j-1.7.16.jar kryo-2.21.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar @@ -164,8 +164,8 @@ scala-reflect-2.11.7.jar scala-xml_2.11-1.0.2.jar scalap-2.11.7.jar servlet-api-2.5.jar -slf4j-api-1.7.10.jar -slf4j-log4j12-1.7.10.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.jar spire-macros_2.11-0.7.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 615836b3d3b77..dd93abf4f395b 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -90,7 +90,7 @@ javax.servlet-3.0.0.v201112011016.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar jaxb-impl-2.2.3-1.jar -jcl-over-slf4j-1.7.10.jar +jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar jersey-core-1.9.jar jersey-guice-1.9.jar @@ -112,7 +112,7 @@ json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar -jul-to-slf4j-1.7.10.jar +jul-to-slf4j-1.7.16.jar kryo-2.21.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar @@ -155,8 +155,8 @@ scala-reflect-2.11.7.jar scala-xml_2.11-1.0.2.jar scalap-2.11.7.jar servlet-api-2.5.jar -slf4j-api-1.7.10.jar -slf4j-log4j12-1.7.10.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.jar spire-macros_2.11-0.7.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index f275226f1d088..987a6d7f593f6 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -90,7 +90,7 @@ javax.servlet-3.0.0.v201112011016.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar jaxb-impl-2.2.3-1.jar -jcl-over-slf4j-1.7.10.jar +jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar jersey-client-1.9.jar jersey-core-1.9.jar @@ -113,7 +113,7 @@ json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar -jul-to-slf4j-1.7.10.jar +jul-to-slf4j-1.7.16.jar kryo-2.21.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar @@ -156,8 +156,8 @@ scala-reflect-2.11.7.jar scala-xml_2.11-1.0.2.jar scalap-2.11.7.jar servlet-api-2.5.jar -slf4j-api-1.7.10.jar -slf4j-log4j12-1.7.10.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.jar spire-macros_2.11-0.7.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 21432a16e3659..9f158485d2803 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -96,7 +96,7 @@ javax.servlet-3.0.0.v201112011016.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar jaxb-impl-2.2.3-1.jar -jcl-over-slf4j-1.7.10.jar +jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar jersey-client-1.9.jar jersey-core-1.9.jar @@ -119,7 +119,7 @@ json4s-jackson_2.11-3.2.10.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar -jul-to-slf4j-1.7.10.jar +jul-to-slf4j-1.7.16.jar kryo-2.21.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar @@ -162,8 +162,8 @@ scala-reflect-2.11.7.jar scala-xml_2.11-1.0.2.jar scalap-2.11.7.jar servlet-api-2.5.jar -slf4j-api-1.7.10.jar -slf4j-log4j12-1.7.10.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.jar spire-macros_2.11-0.7.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 20e09cd002635..6cbb9ab9591ce 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -96,7 +96,7 @@ javax.servlet-3.0.0.v201112011016.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar jaxb-impl-2.2.3-1.jar -jcl-over-slf4j-1.7.10.jar +jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar jersey-client-1.9.jar jersey-core-1.9.jar @@ -120,7 +120,7 @@ jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar -jul-to-slf4j-1.7.10.jar +jul-to-slf4j-1.7.16.jar kryo-2.21.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar @@ -163,8 +163,8 @@ scala-reflect-2.11.7.jar scala-xml_2.11-1.0.2.jar scalap-2.11.7.jar servlet-api-2.5.jar -slf4j-api-1.7.10.jar -slf4j-log4j12-1.7.10.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.jar spire-macros_2.11-0.7.4.jar diff --git a/docs/building-spark.md b/docs/building-spark.md index 975e1b295c8ae..adf798847c3c3 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,7 +7,7 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.3.3 or newer and Java 7+. +Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. The Spark build can supply a suitable Maven binary; see below. # Building with `build/mvn` diff --git a/examples/pom.xml b/examples/pom.xml index 82baa9085b4f9..3a3f547915015 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -249,7 +249,7 @@ com.twitter algebird-core_${scala.binary.version} - 0.9.0 + 0.11.0 org.scalacheck @@ -259,7 +259,7 @@ org.apache.cassandra cassandra-all - 1.2.6 + 1.2.19 com.google.guava @@ -318,7 +318,7 @@ com.github.scopt scopt_${scala.binary.version} - 3.2.0 + 3.3.0 test @@ -679,25 +679,25 @@ org.scalatest scalatest_${scala.binary.version} - 2.2.1 + 2.2.6 test org.mockito mockito-core - 1.9.5 + 1.10.19 test org.scalacheck scalacheck_${scala.binary.version} - 1.11.3 + 1.12.5 test junit junit - 4.11 + 4.12 test @@ -722,7 +722,7 @@ com.spotify docker-client shaded - 3.2.1 + 3.4.0 test @@ -750,13 +750,13 @@ mysql mysql-connector-java - 5.1.34 + 5.1.38 test org.postgresql postgresql - 9.3-1102-jdbc41 + 9.4.1207.jre7 test @@ -1820,7 +1820,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.9.1 + 1.10 net.alchim31.maven @@ -1883,7 +1883,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.3 + 3.5.1 ${java.version} ${java.version} @@ -1904,7 +1904,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18.1 + 2.19.1 @@ -2017,7 +2017,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.6.1 + 3.0.0 @@ -2045,12 +2045,12 @@ org.apache.maven.plugins maven-assembly-plugin - 2.5.5 + 2.6 org.apache.maven.plugins maven-shade-plugin - 2.4.1 + 2.4.3 org.apache.maven.plugins From 892b2dd6dd00d2088c75ab3c8443e8b3e44e5803 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 17 Feb 2016 22:14:45 -0500 Subject: [PATCH 1926/2089] Add github pull request template --- .github/PULL_REQUEST_TEMPLATE | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE new file mode 100644 index 0000000000000..39bcc84225b4a --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE @@ -0,0 +1,12 @@ +## What changes were proposed in this pull request? + +(Please fill in changes proposed in this fix) + + +## How was the this patch tested? + +(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) + + +(If this patch involves UI changes, please attach a screenshot; otherwise, remove this) + From 78562535feb6e214520b29e0bbdd4b1302f01e93 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 18 Feb 2016 12:14:30 -0800 Subject: [PATCH 1927/2089] [SPARK-13371][CORE][STRING] TaskSetManager.dequeueSpeculativeTask compares Option and String directly. ## What changes were proposed in this pull request? Fix some comparisons between unequal types that cause IJ warnings and in at least one case a likely bug (TaskSetManager) ## How was the this patch tested? Running Jenkins tests Author: Sean Owen Closes #11253 from srowen/SPARK-13371. --- .../scala/org/apache/spark/deploy/FaultToleranceTest.scala | 2 +- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- core/src/test/scala/org/apache/spark/CheckpointSuite.scala | 4 ++-- .../src/test/scala/org/apache/spark/PartitioningSuite.scala | 4 ++-- .../test/scala/org/apache/spark/ui/UISeleniumSuite.scala | 6 +++++- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 15d220d01b461..434aadd2c61ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -252,7 +252,7 @@ private object FaultToleranceTest extends App with Logging { val f = Future { try { val res = sc.parallelize(0 until 10).collect() - assertTrue(res.toList == (0 until 10)) + assertTrue(res.toList == (0 until 10).toList) true } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 05cfa52f16b3e..2b0eab71694bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -338,7 +338,7 @@ private[spark] class TaskSetManager( if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { for (rack <- sched.getRackForHost(host)) { for (index <- speculatableTasks if canRunOnHost(index)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) + val racks = tasks(index).preferredLocations.map(_.host).flatMap(sched.getRackForHost) if (racks.contains(rack)) { speculatableTasks -= index return Some((index, TaskLocality.RACK_LOCAL)) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index ce35856dce3f7..9f94e36324536 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -54,7 +54,7 @@ trait RDDCheckpointTester { self: SparkFunSuite => // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull + val parentDependency = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName val numPartitions = operatedRDD.partitions.length @@ -82,7 +82,7 @@ trait RDDCheckpointTester { self: SparkFunSuite => } // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) + assert(operatedRDD.dependencies.head != parentDependency) // Test whether the partitions have been changed from its earlier partitions assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index aa8028792cb41..3d31c7864e760 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -163,8 +163,8 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva val hashP2 = new HashPartitioner(2) assert(rangeP2 === rangeP2) assert(hashP2 === hashP2) - assert(hashP2 != rangeP2) - assert(rangeP2 != hashP2) + assert(hashP2 !== rangeP2) + assert(rangeP2 !== hashP2) } test("partitioner preservation") { diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index aa22f3ba2b4d1..b0a35fe8c3319 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -289,7 +289,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B JInt(stageId) <- stage \ "stageId" JInt(attemptId) <- stage \ "attemptId" } { - val exp = if (attemptId == 0 && stageId == 1) StageStatus.FAILED else StageStatus.COMPLETE + val exp = if (attemptId.toInt == 0 && stageId.toInt == 1) { + StageStatus.FAILED + } else { + StageStatus.COMPLETE + } status should be (exp.name()) } From 26f38bb83c423e512955ca25775914dae7e5bbf0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Feb 2016 13:07:41 -0800 Subject: [PATCH 1928/2089] [SPARK-13351][SQL] fix column pruning on Expand Currently, the columns in projects of Expand that are not used by Aggregate are not pruned, this PR fix that. Author: Davies Liu Closes #11225 from davies/fix_pruning_expand. --- .../sql/catalyst/optimizer/Optimizer.scala | 10 ++++++ .../optimizer/ColumnPruningSuite.scala | 33 +++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 567010f23fc8a..55c168d552a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -300,6 +300,16 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, e @ Expand(projects, output, child)) + if (e.outputSet -- a.references).nonEmpty => + val newOutput = output.filter(a.references.contains(_)) + val newProjects = projects.map { proj => + proj.zip(output).filter { case (e, a) => + newOutput.contains(a) + }.unzip._1 + } + a.copy(child = Expand(newProjects, newOutput, child)) + case a @ Aggregate(_, _, e @ Expand(_, _, child)) if (child.outputSet -- e.references -- a.references).nonEmpty => a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 81f3928035619..c890fffc40167 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Explode +import org.apache.spark.sql.catalyst.expressions.{Explode, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.StringType @@ -96,5 +96,34 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Column pruning for Expand") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = + Aggregate( + Seq('aa, 'gid), + Seq(sum('c).as("sum")), + Expand( + Seq( + Seq('a, 'b, 'c, Literal.create(null, StringType), 1), + Seq('a, 'b, 'c, 'a, 2)), + Seq('a, 'b, 'c, 'aa.int, 'gid.int), + input)).analyze + val optimized = Optimize.execute(query) + + val expected = + Aggregate( + Seq('aa, 'gid), + Seq(sum('c).as("sum")), + Expand( + Seq( + Seq('c, Literal.create(null, StringType), 1), + Seq('c, 'a, 2)), + Seq('c, 'aa.int, 'gid.int), + Project(Seq('c, 'a), + input))).analyze + + comparePlans(optimized, expected) + } + // todo: add more tests for column pruning } From 95e1ab223e87fc216f3256d404fe3be50d111a9d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Feb 2016 15:15:06 -0800 Subject: [PATCH 1929/2089] [SPARK-13237] [SQL] generated broadcast outer join This PR support codegen for broadcast outer join. In order to reduce the duplicated codes, this PR merge HashJoin and HashOuterJoin together (also BroadcastHashJoin and BroadcastHashOuterJoin). Author: Davies Liu Closes #11130 from davies/gen_out. --- .../spark/sql/execution/SparkStrategies.scala | 16 +- .../sql/execution/WholeStageCodegen.scala | 8 +- .../execution/joins/BroadcastHashJoin.scala | 253 ++++++++++++++---- .../joins/BroadcastHashOuterJoin.scala | 121 --------- .../spark/sql/execution/joins/HashJoin.scala | 111 +++++++- .../sql/execution/joins/HashOuterJoin.scala | 153 ----------- .../org/apache/spark/sql/JoinSuite.scala | 5 +- .../BenchmarkWholeStageCodegen.scala | 131 ++++++++- .../execution/joins/BroadcastJoinSuite.scala | 2 +- .../sql/execution/joins/InnerJoinSuite.scala | 2 +- .../sql/execution/joins/OuterJoinSuite.scala | 9 +- .../execution/metric/SQLMetricsSuite.scala | 8 +- 12 files changed, 448 insertions(+), 371 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 042c99db4dcff..382654afacb89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Inner joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastHashJoin( - leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - joins.BroadcastHashJoin( - leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => @@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys( RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index f35efb5b24b1f..8626f54eb413c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue /** * An interface for those physical operators that support codegen. @@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { case _: TungstenAggregate => "agg" - case _: BroadcastHashJoin => "bhj" + case _: BroadcastHashJoin => "join" case _ => nodeName.toLowerCase } @@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { // The build side can't be compiled together - case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) => b.copy(left = apply(left)) - case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => b.copy(right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 985e74011daa7..a64da225800a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -24,8 +24,9 @@ import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, @@ -105,75 +107,144 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - val hashedRelation = broadcastRelation.value - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) - hashJoin(streamedIter, hashedRelation, numOutputRows) + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) + val keyGenerator = streamSideKeyGenerator + val resultProj = createResultProjection + + joinType match { + case Inner => + hashJoin(streamedIter, hashTable, numOutputRows) + + case LeftOuter => + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) + } + + case RightOuter => + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) + } + + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } } } - private var broadcastRelation: Broadcast[HashedRelation] = _ - // the term for hash relation - private var relationTerm: String = _ - override def upstream(): RDD[InternalRow] = { streamedPlan.asInstanceOf[CodegenSupport].upstream() } override def doProduce(ctx: CodegenContext): String = { + streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (joinType == Inner) { + codegenInner(ctx, input) + } else { + // LeftOuter and RightOuter + codegenOuter(ctx, input) + } + } + + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ + private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation - broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcastRelation = Await.result(broadcastFuture, timeout) val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - relationTerm = ctx.freshName("relation") + val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = ($clsName) $broadcast.value(); | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) - - s""" - | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} - """.stripMargin + (broadcastRelation, relationTerm) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // generate the key as UnsafeRow or Long + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + if (canJoinKeyFitWithinLong) { + // generate the join key as Long val expr = rewriteKeyExpr(streamedKeys).head val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) (ev, ev.isNull) } else { + // generate the join key as UnsafeRow val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) (ev, s"${ev.value}.anyNull()") } + } - // find the matches from HashedRelation - val matched = ctx.freshName("matched") - - // create variables for output + /** + * Generates the code for variable of build side. + */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = matched - val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => - BoundReference(i, a.dataType, a.nullable).gen(ctx) + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) + if (joinType == Inner) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) + } } + } + + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) val resultVars = buildSide match { - case BuildLeft => buildColumns ++ input - case BuildRight => input ++ buildColumns + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars } - val numOutput = metricTerm(ctx, "numOutputRows") + val outputCode = if (condition.isDefined) { // filter the output via condition ctx.currentVars = resultVars val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) s""" - | ${ev.code} - | if (!${ev.isNull} && ${ev.value}) { - | $numOutput.add(1); - | ${consume(ctx, resultVars)} - | } + |${ev.code} + |if (!${ev.isNull} && ${ev.value}) { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} """.stripMargin } else { s""" @@ -184,36 +255,110 @@ case class BroadcastHashJoin( if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashedRelation - | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); - | if ($matched != null) { - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - """.stripMargin + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched != null) { + | ${buildVars.map(_.code).mkString("\n")} + | $outputCode + |} + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + | ${buildVars.map(_.code).mkString("\n")} + | $outputCode + | } + |} + """.stripMargin + } + } + + + /** + * Generates the code for left or right outer join. + */ + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + val numOutput = metricTerm(ctx, "numOutputRows") + + // filter the output via condition + val conditionPassed = ctx.freshName("conditionPassed") + val checkCondition = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + |boolean $conditionPassed = true; + |if ($matched != null) { + | ${ev.code} + | $conditionPassed = !${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $conditionPassed = true;" + } + + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |${buildVars.map(_.code).mkString("\n")} + |${checkCondition.trim} + |if (!$conditionPassed) { + | // reset to null + | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")} + |} + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin } else { val matches = ctx.freshName("matches") val bufferType = classOf[CompactBuffer[UnsafeRow]].getName val i = ctx.freshName("i") val size = ctx.freshName("size") + val found = ctx.freshName("found") s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashRelation - | $bufferType $matches = ${anyNull} ? null : - | ($bufferType) $relationTerm.get(${keyVal.value}); - | if ($matches != null) { - | int $size = $matches.size(); - | for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - | } - """.stripMargin + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |int $size = $matches != null ? $matches.size() : 0; + |boolean $found = false; + |// the last iteration of this loop is to emit an empty row if there is no matched rows. + |for (int $i = 0; $i <= $size; $i++) { + | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; + | ${buildVars.map(_.code).mkString("\n")} + | ${checkCondition.trim} + | if ($conditionPassed && ($i < $size || !$found)) { + | $found = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala deleted file mode 100644 index 5e8c8ca043629..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Performs a outer hash join for two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -case class BroadcastHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - val timeout = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - row.copy() - }.collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture - } - - override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastRelation = Await.result(broadcastFuture, timeout) - - streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - val keyGenerator = streamedKeyGenerator - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) - - val resultProj = resultProjection - joinType match { - case LeftOuter => - streamedIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - streamedIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 332a748d3bfc0..2fe9c06cc9537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -21,20 +21,38 @@ import java.util.NoSuchElementException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.sql.types.{IntegralType, LongType} +import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} +import org.apache.spark.util.collection.CompactBuffer trait HashJoin { self: SparkPlan => val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] + val joinType: JoinType val buildSide: BuildSide val condition: Option[Expression] val left: SparkPlan val right: SparkPlan + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") + } + } + protected lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) @@ -45,8 +63,6 @@ trait HashJoin { case BuildRight => (rightKeys, leftKeys) } - override def output: Seq[Attribute] = left.output ++ right.output - /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. * @@ -67,8 +83,17 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 + // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same + // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys + // with two same ints have hash code 0, we rotate the bits of second one. + val rotated = if (e.dataType == IntegerType) { + // (e >>> 15) | (e << 17) + BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) + } else { + e + } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -97,11 +122,13 @@ trait HashJoin { (r: InternalRow) => true } + protected def createResultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(self.schema) + protected def hashJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = - { + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ private[this] var currentHashMatches: Seq[InternalRow] = _ @@ -109,8 +136,7 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(self.schema) + private[this] val resultProjection = createResultProjection private[this] val joinKeys = streamSideKeyGenerator @@ -163,4 +189,73 @@ trait HashJoin { } } } + + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() + + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) + + protected[this] def leftOuterIterator( + key: InternalRow, + joinedRow: JoinedRow, + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { + if (!key.anyNull) { + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } + } + } else { + List.empty + } + if (temp.isEmpty) { + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } else { + temp + } + } else { + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } + } + ret.iterator + } + + protected[this] def rightOuterIterator( + key: InternalRow, + leftIter: Iterable[InternalRow], + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { + if (!key.anyNull) { + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } + } + } else { + List.empty + } + if (temp.isEmpty) { + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } else { + temp + } + } else { + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } + } + ret.iterator + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala deleted file mode 100644 index 9e614309de129..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.util.collection.CompactBuffer - - -trait HashOuterJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val joinType: JoinType - val condition: Option[Expression] - val left: SparkPlan - val right: SparkPlan - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - - protected[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected def buildKeyGenerator: Projection = - UnsafeProjection.create(buildKeys, buildPlan.output) - - protected[this] def streamedKeyGenerator: Projection = - UnsafeProjection.create(streamedKeys, streamedPlan.output) - - protected[this] def resultProjection: InternalRow => InternalRow = - UnsafeProjection.create(output, output) - - @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) - @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - } else { - (row: InternalRow) => true - } - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } - } - ret.iterator - } - - protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } - } - ret.iterator - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 9a3c262e9485d..92ff7e73fad88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), + classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 4a151179bf6f2..bcac660a35a65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.HashMap + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext @@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { ignore("broadcast hash join") { val N = 100 << 20 - val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + val M = 1 << 16 + val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long", N) { - sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() + sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X - Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X + Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X + Join w long codegen=true 735 / 853 142.7 7.0 7.8X */ - val dim2 = broadcast(sqlContext.range(1 << 16) + val dim2 = broadcast(sqlContext.range(M) .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) runBenchmark("Join w 2 ints", N) { sqlContext.range(N).join(dim2, - (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1") - && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count() + (col("id") bitwiseAND M).cast(IntegerType) === col("k1") + && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count() } /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X - Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X + Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X + Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X */ + val dim3 = broadcast(sqlContext.range(M) + .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 longs", N) { + sqlContext.range(N).join(dim3, + (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + .count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 longs codegen=false 7877 / 8358 13.3 75.1 1.0X + Join w 2 longs codegen=true 3877 / 3937 27.0 37.0 2.0X + */ + runBenchmark("outer join w long", N) { + sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X + outer join w long codegen=true 769 / 796 136.3 7.3 19.9X + */ } ignore("rube") { @@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 50 << 20 + val N = 10 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) @@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("Java HashMap (Long)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + map.put(i.toLong, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.get(i % 100000) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (two ints) ") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + val key = (i.toLong << 32) + Integer.rotateRight(i, 15) + map.put(key, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (UnsafeRow)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[UnsafeRow, UnsafeRow]() + while (i < 65536) { + key.setInt(0, i) + value.setInt(0, i) + map.put(key, value.copy()) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + key.setInt(0, i % 100000) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { hash 651 / 678 80.0 12.5 1.0X fast hash 336 / 343 155.9 6.4 1.9X arrayEqual 417 / 428 125.0 8.0 1.6X + Java HashMap (Long) 145 / 168 72.2 13.8 0.8X + Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X + Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index aee8e84db56e2..e25b5e0610ea1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { } test("unsafe broadcast hash outer join updates peak execution memory") { - testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer") } test("unsafe broadcast left semi join updates peak execution memory") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 149f34dbd748f..e22a810a6b42f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan) + joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan) } def makeSortMergeJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 3d3e9a7b90928..f4b01fbad0585 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + test(s"$testName using BroadcastHashJoin") { + val buildSide = joinType match { + case LeftOuter => BuildRight + case RightOuter => BuildLeft + } extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + BroadcastHashJoin( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index f4bc9e501c21c..46bb699b780a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("BroadcastHashOuterJoin metrics") { + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") // Assume the execution plan is - // ... -> BroadcastHashOuterJoin(nodeId = 0) + // ... -> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( + 0L -> ("BroadcastHashJoin", Map( "number of output rows" -> 5L))) ) val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") testSparkPlanMetrics(df3, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( + 0L -> ("BroadcastHashJoin", Map( "number of output rows" -> 6L))) ) } From c776fce99b496a789ffcf2cfab78cf51eeea032b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 18 Feb 2016 21:19:36 -0800 Subject: [PATCH 1930/2089] [SPARK-13380][SQL][DOCUMENT] Document Rand(seed) and Randn(seed) Return Indeterministic Results When Data Partitions are not fixed. `rand` and `randn` functions with a `seed` argument are commonly used. Based on the common sense, the results of `rand` and `randn` should be deterministic if the `seed` parameter value is provided. For example, in MS SQL Server, it also has a function `rand`. Regarding the parameter `seed`, the description is like: ```Seed is an integer expression (tinyint, smallint, or int) that gives the seed value. If seed is not specified, the SQL Server Database Engine assigns a seed value at random. For a specified seed value, the result returned is always the same.``` Update: the current implementation is unable to generate deterministic results when the partitions are not fixed. This PR documents this issue in the function descriptions. jkbradley hit an issue and provided an example in the following JIRA: https://issues.apache.org/jira/browse/SPARK-13333 Author: gatorsmile Closes #11232 from gatorsmile/randSeed. --- .../spark/sql/catalyst/expressions/randomExpressions.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 2e703671fcd66..6be3cbcae6295 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -85,7 +85,7 @@ case class Randn(seed: Long) extends RDG { def this(seed: Expression) = this(seed match { case IntegerLiteral(s) => s - case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + case _ => throw new AnalysisException("Input argument to randn must be an integer literal.") }) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e4ab6b4f23486..97c6992e18753 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1052,6 +1052,8 @@ object functions extends LegacyFunctions { /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. * + * Note that this is indeterministic when data partitions are not fixed. + * * @group normal_funcs * @since 1.4.0 */ @@ -1068,6 +1070,8 @@ object functions extends LegacyFunctions { /** * Generate a column with i.i.d. samples from the standard normal distribution. * + * Note that this is indeterministic when data partitions are not fixed. + * * @group normal_funcs * @since 1.4.0 */ From fb7e21797ed618d9754545a44f8f95f75b66757a Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 19 Feb 2016 10:26:38 +0000 Subject: [PATCH 1931/2089] [SPARK-13339][DOCS] Clarify commutative / associative operator requirements for reduce, fold Clarify that reduce functions need to be commutative, and fold functions do not See https://github.com/apache/spark/pull/11091 Author: Sean Owen Closes #11217 from srowen/SPARK-13339. --- R/pkg/R/pairRDD.R | 10 +++--- .../scala/org/apache/spark/Accumulator.scala | 6 ++-- .../apache/spark/api/java/JavaPairRDD.scala | 32 +++++++++---------- .../apache/spark/api/java/JavaRDDLike.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 24 +++++++------- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- docs/programming-guide.md | 2 +- docs/streaming-programming-guide.md | 4 +-- python/pyspark/rdd.py | 7 ++-- python/pyspark/streaming/dstream.py | 4 +-- .../streaming/api/java/JavaDStreamLike.scala | 6 ++-- .../streaming/api/java/JavaPairDStream.scala | 18 +++++------ .../spark/streaming/dstream/DStream.scala | 4 +-- .../dstream/PairDStreamFunctions.scala | 16 +++++----- 14 files changed, 68 insertions(+), 69 deletions(-) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index f7131140feafb..4075ef4377acf 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -305,11 +305,11 @@ setMethod("groupByKey", #' Merge values by key #' #' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -#' and merges the values for each key using an associative reduce function. +#' and merges the values for each key using an associative and commutative reduce function. #' #' @param x The RDD to reduce by key. Should be an RDD where each element is #' list(K, V) or c(K, V). -#' @param combineFunc The associative reduce function to use. +#' @param combineFunc The associative and commutative reduce function to use. #' @param numPartitions Number of partitions to create. #' @return An RDD where each element is list(K, V') where V' is the merged #' value @@ -347,12 +347,12 @@ setMethod("reduceByKey", #' Merge values by key locally #' #' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -#' and merges the values for each key using an associative reduce function, but return the -#' results immediately to the driver as an R list. +#' and merges the values for each key using an associative and commutative reduce function, but +#' return the results immediately to the driver as an R list. #' #' @param x The RDD to reduce by key. Should be an RDD where each element is #' list(K, V) or c(K, V). -#' @param combineFunc The associative reduce function to use. +#' @param combineFunc The associative and commutative reduce function to use. #' @return A list of elements of type list(K, V') where V' is the merged value for each key #' @seealso reduceByKey #' @examples diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 5e8f1d4a705c3..0e4bcc33cba57 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -29,9 +29,9 @@ import org.apache.spark.storage.{BlockId, BlockStatus} /** * A simpler value of [[Accumulable]] where the result type being accumulated is the same * as the types of elements being merged, i.e. variables that are only "added" to through an - * associative operation and can therefore be efficiently supported in parallel. They can be used - * to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric - * value types, and programmers can add support for new types. + * associative and commutative operation and can therefore be efficiently supported in parallel. + * They can be used to implement counters (as in MapReduce) or sums. Spark natively supports + * accumulators of numeric value types, and programmers can add support for new types. * * An accumulator is created from an initial value `v` by calling [[SparkContext#accumulator]]. * Tasks running on the cluster can then add to it using the [[Accumulable#+=]] operator. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 94d103588b696..e080f91f50ad4 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -278,17 +278,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = fromRDD(rdd.reduceByKey(partitioner, func)) /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. + * Merge the values for each key using an associative and commutative reduce function, but return + * the result immediately to the master as a Map. This will also perform the merging locally on + * each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) @@ -381,9 +381,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.foldByKey(zeroValue)(func)) /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.reduceByKey(func, numPartitions)) @@ -461,10 +461,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.partitionBy(partitioner)) /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. + */ def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = fromRDD(rdd.join(other, partitioner)) @@ -520,9 +520,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ * parallelism level. */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 37c211fe706e1..4212027122544 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -373,7 +373,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". The function + * given associative function and a neutral "zero value". The function * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object * allocation; however, it should not modify t2. * diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 61905a8421124..e00b9f6cfda0e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -300,27 +300,27 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = self.withScope { reduceByKey(new HashPartitioner(numPartitions), func) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ * parallelism level. */ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope { @@ -328,9 +328,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. + * Merge the values for each key using an associative and commutative reduce function, but return + * the results immediately to the master as a Map. This will also perform the merging locally on + * each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { val cleanedF = self.sparkContext.clean(func) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a81a98b526b5a..6a6ad2d75a897 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -973,7 +973,7 @@ abstract class RDD[T: ClassTag]( /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". The function + * given associative function and a neutral "zero value". The function * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object * allocation; however, it should not modify t2. * diff --git a/docs/programming-guide.md b/docs/programming-guide.md index e45081464af1c..2d6f7767d93ef 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1343,7 +1343,7 @@ value of the broadcast variable (e.g. if the variable is shipped to a new node l ## Accumulators -Accumulators are variables that are only "added" to through an associative operation and can +Accumulators are variables that are only "added" to through an associative and commutative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 677f5ff7bea8b..4d1932bc8c3e1 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -798,7 +798,7 @@ Some of the common ones are as follows. reduce(func) Return a new DStream of single-element RDDs by aggregating the elements in each RDD of the source DStream using a function func (which takes two arguments and returns one). - The function should be associative so that it can be computed in parallel. + The function should be associative and commutative so that it can be computed in parallel. countByValue() @@ -1072,7 +1072,7 @@ said two parameters - windowLength and slideInterval. reduceByWindow(func, windowLength, slideInterval) Return a new single-element stream, created by aggregating elements in the stream over a - sliding interval using func. The function should be associative so that it can be computed + sliding interval using func. The function should be associative and commutative so that it can be computed correctly in parallel. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fe2264a63cf30..4eaf589ad5d46 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -844,8 +844,7 @@ def op(x, y): def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative and commutative function and - a neutral "zero value." + the partitions, using a given associative function and a neutral "zero value." The function C{op(t1, t2)} is allowed to modify C{t1} and return it as its result value to avoid object allocation; however, it should not @@ -1558,7 +1557,7 @@ def values(self): def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ - Merge the values for each key using an associative reduce function. + Merge the values for each key using an associative and commutative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. @@ -1576,7 +1575,7 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): def reduceByKeyLocally(self, func): """ - Merge the values for each key using an associative reduce function, but + Merge the values for each key using an associative and commutative reduce function, but return the results immediately to the master as a dictionary. This will also perform the merging locally on each mapper before diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 86447f5e58ecb..2056663872198 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -453,7 +453,7 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) This is more efficient than `invReduceFunc` is None. - @param reduceFunc: associative reduce function + @param reduceFunc: associative and commutative reduce function @param invReduceFunc: inverse reduce function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @@ -524,7 +524,7 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param func: associative reduce function + @param func: associative and commutative reduce function @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index f10de485d0f75..9931a46d33a9f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -214,7 +214,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -234,7 +234,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -257,7 +257,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient than reduceByWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index d718f1d6fc43e..aad9a12c15246 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -138,8 +138,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the associative reduce function. Hash partitioning is used to generate the RDDs - * with Spark's default number of partitions. + * merged using the associative and commutative reduce function. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = dstream.reduceByKey(func) @@ -257,7 +257,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate * the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ @@ -270,7 +270,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -289,7 +289,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -309,7 +309,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. - * @param reduceFunc associative reduce function + * @param reduceFunc associative rand commutative educe function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -335,7 +335,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -360,7 +360,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -397,7 +397,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index db79eeab9c0cc..70e1d8abded4c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -791,7 +791,7 @@ abstract class DStream[T: ClassTag] ( /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -814,7 +814,7 @@ abstract class DStream[T: ClassTag] ( * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient than reduceByWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index babc722709325..1dcdb64e289bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -75,8 +75,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the associative reduce function. Hash partitioning is used to generate the RDDs - * with Spark's default number of partitions. + * merged using the associative and commutative reduce function. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. */ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner()) @@ -204,7 +204,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate * the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ @@ -219,7 +219,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -238,7 +238,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -259,7 +259,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -289,7 +289,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -320,7 +320,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval From 6915cc23b3166088eab3f7cff2b6adc2a0e6739d Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 19 Feb 2016 11:47:36 -0800 Subject: [PATCH 1932/2089] [MINOR][DOCS][MESOS] Clarify that Mesos version is a lower bound. ## What changes were proposed in this pull request? Clarify that 0.21 is only a **minimum** requirement. ## How was the this patch tested? It's a doc change, so no tests. Author: Iulian Dragos Closes #11271 from dragos/patch-1. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 0df476d9b4dd7..35f6caab1700c 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -32,7 +32,7 @@ To get started, follow the steps below to install Mesos and deploy Spark jobs vi # Installing Mesos -Spark {{site.SPARK_VERSION}} is designed for use with Mesos {{site.MESOS_VERSION}} and does not +Spark {{site.SPARK_VERSION}} is designed for use with Mesos {{site.MESOS_VERSION}} or newer and does not require any special patches of Mesos. If you already have a Mesos cluster running, you can skip this Mesos installation step. From c7c55637bfc523237f5cc5c5b61837b1e3d5fdfc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Feb 2016 12:22:22 -0800 Subject: [PATCH 1933/2089] [SPARK-13384][SQL] Keep attribute qualifiers after dedup in Analyzer JIRA: https://issues.apache.org/jira/browse/SPARK-13384 ## What changes were proposed in this pull request? When we de-duplicate attributes in Analyzer, we create new attributes. However, we don't keep original qualifiers. Some plans will be failed to analysed. We should keep original qualifiers in new attributes. ## How was the this patch tested? Unit test is added. Author: Liang-Chi Hsieh Closes #11261 from viirya/keep-attr-qualifiers. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 ++- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 004c1caaffec6..8368657b2ba2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -432,7 +432,8 @@ class Analyzer( case r if r == oldRelation => newRelation } transformUp { case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) + case a: Attribute => + attributeRewrites.get(a).getOrElse(a).withQualifiers(a.qualifiers) } } newRight diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f85ae24e0459b..92db02944c6ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -336,4 +337,17 @@ class AnalysisSuite extends AnalysisTest { val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) assertAnalysisSuccess(plan) } + + test("Keep attribute qualifiers after dedup") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + Project(Seq($"x.key"), Subquery("x", input)), + Project(Seq($"y.key"), Subquery("y", input)), + Inner, None)) + + assertAnalysisSuccess(query) + } } From dbb08cdd5ae320082cdbcc9cfb8155f5a9da8b8c Mon Sep 17 00:00:00 2001 From: Brandon Bradley Date: Fri, 19 Feb 2016 14:43:21 -0800 Subject: [PATCH 1934/2089] [SPARK-12966][SQL] ArrayType(DecimalType) support in Postgres JDBC Fixes error `org.postgresql.util.PSQLException: Unable to find server array type for provided name decimal(38,18)`. * Passes scale metadata to JDBC dialect for usage in type conversions. * Removes unused length/scale/precision parameters from `createArrayOf` parameter `typeName` (for writing). * Adds configurable precision and scale to Postgres `DecimalType` (for reading). * Adds a new kind of test that verifies the schema written by `DataFrame.write.jdbc`. Author: Brandon Bradley Closes #10928 from blbradley/spark-12966. --- .../sql/jdbc/PostgresIntegrationSuite.scala | 16 +++++++++++----- .../execution/datasources/jdbc/JDBCRDD.scala | 4 +++- .../execution/datasources/jdbc/JdbcUtils.scala | 5 ++++- .../spark/sql/jdbc/PostgresDialect.scala | 18 ++++++++++++------ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 72bda8fe1ef10..d55cdcf28b238 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -22,6 +22,7 @@ import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.types.{ArrayType, DecimalType} import org.apache.spark.tags.DockerTest @DockerTest @@ -42,10 +43,10 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate() conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " - + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate() + + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type)").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " - + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate() + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1')""").executeUpdate() } test("Type mapping for various types") { @@ -53,7 +54,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) - assert(types.length == 14) + assert(types.length == 15) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) @@ -67,7 +68,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(classOf[Seq[Int]].isAssignableFrom(types(10))) assert(classOf[Seq[String]].isAssignableFrom(types(11))) assert(classOf[Seq[Double]].isAssignableFrom(types(12))) - assert(classOf[String].isAssignableFrom(types(13))) + assert(classOf[Seq[BigDecimal]].isAssignableFrom(types(13))) + assert(classOf[String].isAssignableFrom(types(14))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) @@ -84,13 +86,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getSeq(10) == Seq(1, 2)) assert(rows(0).getSeq(11) == Seq("a", null, "b")) assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) - assert(rows(0).getString(13) == "d1") + assert(rows(0).getSeq(13) == Seq("0.11", "0.22").map(BigDecimal(_).bigDecimal)) + assert(rows(0).getString(14) == "d1") } test("Basic write test") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) // Test only that it doesn't crash. df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test that written numeric type has same DataType as input + assert(sqlContext.read.jdbc(jdbcUrl, "public.barcopy", new Properties).schema(13).dataType == + ArrayType(DecimalType(2, 2), true)) // Test write null values. df.select(df.queryExecution.analyzed.output.map { a => Column(Literal.create(null, a.dataType)).as(a.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index befba867bc460..ed02b3f95f841 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -137,7 +137,9 @@ private[sql] object JDBCRDD extends Logging { val fieldScale = rsmd.getScale(i + 1) val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) + val metadata = new MetadataBuilder() + .putString("name", columnName) + .putLong("scale", fieldScale) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( getCatalystType(dataType, fieldSize, fieldScale, isSigned)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 69ba84646f088..e295722cacf15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -194,8 +194,11 @@ object JdbcUtils extends Logging { case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) case ArrayType(et, _) => + // remove type length parameters from end of type name + val typeName = getJdbcType(et, dialect).databaseTypeDefinition + .toLowerCase.split("\\(")(0) val array = conn.createArrayOf( - getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase, + typeName, row.getSeq[AnyRef](i).toArray) stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 8d43966480a65..2d6c3974a833e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -32,14 +32,18 @@ private object PostgresDialect extends JdbcDialect { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { Some(BinaryType) } else if (sqlType == Types.OTHER) { - toCatalystType(typeName).filter(_ == StringType) - } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') { - toCatalystType(typeName.drop(1)).map(ArrayType(_)) + Some(StringType) + } else if (sqlType == Types.ARRAY) { + val scale = md.build.getLong("scale").toInt + // postgres array type names start with underscore + toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_)) } else None } - // TODO: support more type names. - private def toCatalystType(typeName: String): Option[DataType] = typeName match { + private def toCatalystType( + typeName: String, + precision: Int, + scale: Int): Option[DataType] = typeName match { case "bool" => Some(BooleanType) case "bit" => Some(BinaryType) case "int2" => Some(ShortType) @@ -52,7 +56,7 @@ private object PostgresDialect extends JdbcDialect { case "bytea" => Some(BinaryType) case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) case "date" => Some(DateType) - case "numeric" => Some(DecimalType.SYSTEM_DEFAULT) + case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale)) case _ => None } @@ -62,6 +66,8 @@ private object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) + case t: DecimalType => Some( + JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) case ArrayType(et, _) if et.isInstanceOf[AtomicType] => getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) From 14844118b596a93dbc28b442a7ea2b58fa4df648 Mon Sep 17 00:00:00 2001 From: Hossein Date: Fri, 19 Feb 2016 14:46:56 -0800 Subject: [PATCH 1935/2089] [SPARK-13261][SQL] Expose maxCharactersPerColumn as a user configurable option This patch expose `maxCharactersPerColumn` and `maxColumns` to user in CSV data source. Author: Hossein Closes #11147 from falaki/SPARK-13261. --- .../execution/datasources/csv/CSVOptions.scala | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 709daccbbef58..bea8e97a9a94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -36,6 +36,19 @@ private[sql] class CSVOptions( } } + private def getInt(paramName: String, default: Int): Int = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(value) => try { + value.toInt + } catch { + case e: NumberFormatException => + throw new RuntimeException(s"$paramName should be an integer. Found $value") + } + } + } + private def getBool(paramName: String, default: Boolean = false): Boolean = { val param = parameters.getOrElse(paramName, default.toString) if (param.toLowerCase == "true") { @@ -81,9 +94,9 @@ private[sql] class CSVOptions( name.map(CompressionCodecs.getCodecClassName) } - val maxColumns = 20480 + val maxColumns = getInt("maxColumns", 20480) - val maxCharsPerColumn = 100000 + val maxCharsPerColumn = getInt("maxCharsPerColumn", 1000000) val inputBufferSize = 128 From 091f6a7830bbee01fa580fbb0336b9f4fcac0dfa Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 19 Feb 2016 14:48:34 -0800 Subject: [PATCH 1936/2089] [SPARK-13091][SQL] Rewrite/Propagate constraints for Aliases This PR adds support for rewriting constraints if there are aliases in the query plan. For e.g., if there is a query of form `SELECT a, a AS b`, any constraints on `a` now also apply to `b`. JIRA: https://issues.apache.org/jira/browse/SPARK-13091 cc marmbrus Author: Sameer Agarwal Closes #11144 from sameeragarwal/alias. --- .../plans/logical/basicOperators.scala | 20 +++++++++++++++++++ .../plans/ConstraintPropagationSuite.scala | 20 ++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 502d898fea86c..7d155ac183d58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -50,6 +50,26 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } + + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + private def getAliasedConstraints: Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + + override def validConstraints: Set[Expression] = { + child.constraints.union(getAliasedConstraints) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index b5cf91394d910..373b1ffa83d23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -27,7 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ class ConstraintPropagationSuite extends SparkFunSuite { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = - tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + resolveColumn(tr.analyze, columnName) + + private def resolveColumn(plan: LogicalPlan, columnName: String): Expression = + plan.resolveQuoted(columnName, caseInsensitiveResolution).get private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) @@ -69,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aliases") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z)) + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "x") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))) + } + test("propagating constraints in union") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('d.int, 'e.int, 'f.int) From 983fa2d62029e7334fb661cb65c8cadaa4b86d1c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 19 Feb 2016 15:57:23 -0800 Subject: [PATCH 1937/2089] [SPARK-13407] Guard against garbage-collected accumulators in TaskMetrics.fromAccumulatorUpdates `TaskMetrics.fromAccumulatorUpdates()` can fail if accumulators have been garbage-collected on the driver. To guard against this, this patch introduces `ListenerTaskMetrics`, a subclass of `TaskMetrics` which is used only in `TaskMetrics.fromAccumulatorUpdates()` and which eliminates the need to access the original accumulators on the driver. Author: Josh Rosen Closes #11276 from JoshRosen/accum-updates-fix. --- .../apache/spark/executor/TaskMetrics.scala | 55 ++++++++++--------- .../spark/executor/TaskMetricsSuite.scala | 10 ++-- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 8ff0620f837c9..9da9cb594058b 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -364,6 +364,27 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se } +/** + * Internal subclass of [[TaskMetrics]] which is used only for posting events to listeners. + * Its purpose is to obviate the need for the driver to reconstruct the original accumulators, + * which might have been garbage-collected. See SPARK-13407 for more details. + * + * Instances of this class should be considered read-only and users should not call `inc*()` or + * `set*()` methods. While we could override the setter methods to throw + * UnsupportedOperationException, we choose not to do so because the overrides would quickly become + * out-of-date when new metrics are added. + */ +private[spark] class ListenerTaskMetrics( + initialAccums: Seq[Accumulator[_]], + accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics(initialAccums) { + + override def accumulatorUpdates(): Seq[AccumulableInfo] = accumUpdates + + override private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = { + throw new UnsupportedOperationException("This TaskMetrics is read-only") + } +} + private[spark] object TaskMetrics extends Logging { def empty: TaskMetrics = new TaskMetrics @@ -397,33 +418,15 @@ private[spark] object TaskMetrics extends Logging { // Initial accumulators are passed into the TaskMetrics constructor first because these // are required to be uniquely named. The rest of the accumulators from this task are // registered later because they need not satisfy this requirement. - val (initialAccumInfos, otherAccumInfos) = accumUpdates - .filter { info => info.update.isDefined } - .partition { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) } - val initialAccums = initialAccumInfos.map { info => - val accum = InternalAccumulator.create(info.name.get) - accum.setValueAny(info.update.get) - accum - } - // We don't know the types of the rest of the accumulators, so we try to find the same ones - // that were previously registered here on the driver and make copies of them. It is important - // that we copy the accumulators here since they are used across many tasks and we want to - // maintain a snapshot of their local task values when we post them to listeners downstream. - val otherAccums = otherAccumInfos.flatMap { info => - val id = info.id - val acc = Accumulators.get(id).map { a => - val newAcc = a.copy() - newAcc.setValueAny(info.update.get) - newAcc + val definedAccumUpdates = accumUpdates.filter { info => info.update.isDefined } + val initialAccums = definedAccumUpdates + .filter { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) } + .map { info => + val accum = InternalAccumulator.create(info.name.get) + accum.setValueAny(info.update.get) + accum } - if (acc.isEmpty) { - logWarning(s"encountered unregistered accumulator $id when reconstructing task metrics.") - } - acc - } - val metrics = new TaskMetrics(initialAccums) - otherAccums.foreach(metrics.registerAccumulator) - metrics + new ListenerTaskMetrics(initialAccums, definedAccumUpdates) } } diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 3a1a67cdc001a..d91f50f18f431 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -475,10 +475,9 @@ class TaskMetricsSuite extends SparkFunSuite { } val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1) assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1) - // Test this with additional accumulators. Only the ones registered with `Accumulators` - // will show up in the reconstructed TaskMetrics. In practice, all accumulators created + // Test this with additional accumulators to ensure that we do not crash when handling + // updates from unregistered accumulators. In practice, all accumulators created // on the driver, internal or not, should be registered with `Accumulators` at some point. - // Here we show that reconstruction will succeed even if there are unregistered accumulators. val param = IntAccumulatorParam val registeredAccums = Seq( new Accumulator(0, param, Some("a"), internal = true, countFailedValues = true), @@ -497,9 +496,8 @@ class TaskMetricsSuite extends SparkFunSuite { val registeredAccumInfos = registeredAccums.map(makeInfo) val unregisteredAccumInfos = unregisteredAccums.map(makeInfo) val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos - val metrics2 = TaskMetrics.fromAccumulatorUpdates(accumUpdates2) - // accumulators that were not registered with `Accumulators` will not show up - assertUpdatesEquals(metrics2.accumulatorUpdates(), accumUpdates1 ++ registeredAccumInfos) + // Simply checking that this does not crash: + TaskMetrics.fromAccumulatorUpdates(accumUpdates2) } } From ec7a1d6e425509f2472c3ae9497c7da796ce8129 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 19 Feb 2016 22:27:10 -0800 Subject: [PATCH 1938/2089] [SPARK-12594] [SQL] Outer Join Elimination by Filter Conditions Conversion of outer joins, if the predicates in filter conditions can restrict the result sets so that all null-supplying rows are eliminated. - `full outer` -> `inner` if both sides have such predicates - `left outer` -> `inner` if the right side has such predicates - `right outer` -> `inner` if the left side has such predicates - `full outer` -> `left outer` if only the left side has such predicates - `full outer` -> `right outer` if only the right side has such predicates If applicable, this can greatly improve the performance, since outer join is much slower than inner join, full outer join is much slower than left/right outer join. The original PR is https://github.com/apache/spark/pull/10542 Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10567 from gatorsmile/outerJoinEliminationByFilterCond. --- .../sql/catalyst/optimizer/Optimizer.scala | 59 +++++- .../optimizer/OuterJoinEliminationSuite.scala | 195 ++++++++++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 48 +++++ .../org/apache/spark/sql/JoinSuite.scala | 2 +- 4 files changed, 302 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 55c168d552a45..b7d8d932edfcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,8 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -62,6 +62,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { SetOperationPushDown, SamplePushDown, ReorderJoin, + OuterJoinElimination, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -931,6 +932,62 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Elimination of outer joins, if the predicates can restrict the result sets so that + * all null-supplying rows are eliminated + * + * - full outer -> inner if both sides have such predicates + * - left outer -> inner if the right side has such predicates + * - right outer -> inner if the left side has such predicates + * - full outer -> left outer if only the left side has such predicates + * - full outer -> right outer if only the right side has such predicates + * + * This rule should be executed before pushing down the Filter + */ +object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Returns whether the expression returns null or false when all inputs are nulls. + */ + private def canFilterOutNull(e: Expression): Boolean = { + if (!e.deterministic) return false + val attributes = e.references.toSeq + val emptyRow = new GenericInternalRow(attributes.length) + val v = BindReferences.bindReference(e, attributes).eval(emptyRow) + v == null || v == false + } + + private def buildNewJoinType(filter: Filter, join: Join): JoinType = { + val splitConjunctiveConditions: Seq[Expression] = splitConjunctivePredicates(filter.condition) + val leftConditions = splitConjunctiveConditions + .filter(_.references.subsetOf(join.left.outputSet)) + val rightConditions = splitConjunctiveConditions + .filter(_.references.subsetOf(join.right.outputSet)) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) || + filter.constraints.filter(_.isInstanceOf[IsNotNull]) + .exists(expr => join.left.outputSet.intersect(expr.references).nonEmpty) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) || + filter.constraints.filter(_.isInstanceOf[IsNotNull]) + .exists(expr => join.right.outputSet.intersect(expr.references).nonEmpty) + + join.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + val newJoinType = buildNewJoinType(f, j) + if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala new file mode 100644 index 0000000000000..a1dc836a5fd1c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class OuterJoinEliminationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Outer Join Elimination", Once, + OuterJoinElimination, + PushPredicateThroughJoin) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int, 'e.int, 'f.int) + + test("joins: full outer to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr >= 1 && "y.d".attr >= 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b >= 1) + val right = testRelation1.where('d >= 2) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: full outer to right") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("y.d".attr > 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('d > 2) + val correctAnswer = + left.join(right, RightOuter, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: full outer to left") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("x.a".attr <=> 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a <=> 2) + val right = testRelation1 + val correctAnswer = + left.join(right, LeftOuter, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: right to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, RightOuter, Option("x.a".attr === "y.d".attr)).where("x.b".attr > 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b > 2) + val right = testRelation1 + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: left to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where("y.e".attr.isNotNull) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('e.isNotNull) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // evaluating if mixed OR and NOT expressions can eliminate all null-supplying rows + test("joins: left to inner with complicated filter predicates #1") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // eval(emptyRow) of 'e.in(1, 2) will return null instead of false + test("joins: left to inner with complicated filter predicates #2") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where('e.in(1, 2)) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('e.in(1, 2)) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // evaluating if mixed OR and AND expressions can eliminate all null-supplying rows + test("joins: left to inner with complicated filter predicates #3") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where((!'e.isNull || ('d.isNotNull && 'f.isNull)) && 'e.isNull) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where((!'e.isNull || ('d.isNotNull && 'f.isNull)) && 'e.isNull) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // evaluating if the expressions that have both left and right attributes + // can eliminate all null-supplying rows + // FULL OUTER => INNER + test("joins: left to inner with complicated filter predicates #4") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr + 3 === "y.e".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1 + val correctAnswer = + left.join(right, Inner, Option("b".attr + 3 === "e".attr && "a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index a5e5f156423cc..067a62d011ec4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -156,4 +158,50 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { assert(df1.join(broadcast(pf1)).count() === 4) } } + + test("join - outer join conversion") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + + // outer -> left + val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) + assert(outerJoin2Left.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, LeftOuter, _) => j }.size === 1) + checkAnswer( + outerJoin2Left, + Row(3, 4, "3", null, null, null) :: Nil) + + // outer -> right + val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) + assert(outerJoin2Right.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, RightOuter, _) => j }.size === 1) + checkAnswer( + outerJoin2Right, + Row(null, null, null, 5, 6, "5") :: Nil) + + // outer -> inner + val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer"). + where($"a.int" === 1 && $"b.int2" === 3) + assert(outerJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + outerJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + + // right -> inner + val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) + assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + rightJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + + // left -> inner + val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) + assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + leftJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 92ff7e73fad88..8f2a0c0351361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -81,7 +81,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeOuterJoin]), + classOf[SortMergeJoin]), // converted from Right Outer to Inner ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", From 4f9a66481849dc867cf6592d53e0e9782361d20a Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Fri, 19 Feb 2016 22:28:47 -0800 Subject: [PATCH 1939/2089] [SPARK-12567] [SQL] Add aes_{encrypt,decrypt} UDFs Author: Kai Jiang Closes #10527 from vectorijk/spark-12567. --- python/pyspark/sql/functions.py | 37 ++++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../spark/sql/catalyst/expressions/misc.scala | 89 +++++++++++++++++++ .../expressions/MiscFunctionsSuite.scala | 84 +++++++++++++++++ .../org/apache/spark/sql/functions.scala | 36 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++ 6 files changed, 272 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5fc1cc2cae10a..7d038b8d9e89b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1125,6 +1125,43 @@ def hash(*cols): return Column(jc) +@ignore_unicode_prefix +@since(2.0) +def aes_encrypt(input, key): + """ + Encrypts input of given column using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + tion Policy Files are installed. If input is invalid, key length is not one of the permitted + values or using 192/256 bits key before installing JCE, an exception will be thrown. + + >>> df = sqlContext.createDataFrame([('ABC','1234567890123456')], ['input','key']) + >>> df.select(base64(aes_encrypt(df.input, df.key)).alias('aes')).collect() + [Row(aes=u'y6Ss+zCYObpCbgfWfyNWTw==')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.aes_encrypt(_to_java_column(input), _to_java_column(key)) + return Column(jc) + + +@ignore_unicode_prefix +@since(2.0) +def aes_decrypt(input, key): + """ + Decrypts input of given column using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + tion Policy Files are installed. If input is invalid, key length is not one of the permitted + values or using 192/256 bits key before installing JCE, an exception will be thrown. + + >>> df = sqlContext.createDataFrame([(u'y6Ss+zCYObpCbgfWfyNWTw==','1234567890123456')], \ + ['input','key']) + >>> df.select(aes_decrypt(unbase64(df.input), df.key).alias('aes')).collect() + [Row(aes=u'ABC')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.aes_decrypt(_to_java_column(input), _to_java_column(key)) + return Column(jc) + + # ---------------------- String/Binary functions ------------------------------ _string_functions = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1be97c7b81197..ae09c3d71f883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -278,6 +278,8 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), // misc functions + expression[AesEncrypt]("aes_encrypt"), + expression[AesDecrypt]("aes_decrypt"), expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Murmur3Hash]("hash"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index dcbb594afd86e..3b66f5797be64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 +import javax.crypto.Cipher +import javax.crypto.spec.SecretKeySpec import org.apache.commons.codec.digest.DigestUtils @@ -441,3 +443,90 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { """.stripMargin) } } + +/** + * A function that encrypts input using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + * and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + * tion Policy Files are installed. If either argument is NULL, the result will also be null. If + * input is invalid, key length is not one of the permitted values or using 192/256 bits key before + * installing JCE, an exception will be thrown. + */ +@ExpressionDescription( + usage = "_FUNC_(input, key) - Encrypts input using AES.", + extended = "> SELECT Base64(_FUNC_('ABC', '1234567890123456'));\n 'y6Ss+zCYObpCbgfWfyNWTw=='") +case class AesEncrypt(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType) + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val cipher = Cipher.getInstance("AES") + val secretKey: SecretKeySpec = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0, + input2.asInstanceOf[Array[Byte]].length, "AES") + cipher.init(Cipher.ENCRYPT_MODE, secretKey) + cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, input1.asInstanceOf[Array[Byte]].length) + } + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, (str, key) => { + val Cipher = "javax.crypto.Cipher" + val SecretKeySpec = "javax.crypto.spec.SecretKeySpec" + s""" + try { + $Cipher cipher = $Cipher.getInstance("AES"); + $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES"); + cipher.init($Cipher.ENCRYPT_MODE, secret); + ${ev.value} = cipher.doFinal($str, 0, $str.length); + } catch (java.security.GeneralSecurityException e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + }) + } +} + +/** + * A function that decrypts input using AES. Key lengths of 128, 192 or 256 bits can be used. 192 + * and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- + * tion Policy Files are installed. If either argument is NULL, the result will also be null. If + * input is invalid, key length is not one of the permitted values or using 192/256 bits key before + * installing JCE, an exception will be thrown. + */ +@ExpressionDescription( + usage = "_FUNC_(input, key) - Decrypts input using AES.", + extended = "> SELECT _FUNC_(UnBase64('y6Ss+zCYObpCbgfWfyNWTw=='),'1234567890123456');\n 'ABC'") +case class AesDecrypt(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType) + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val cipher = Cipher.getInstance("AES") + val secretKey = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0, + input2.asInstanceOf[Array[Byte]].length, "AES") + + cipher.init(Cipher.DECRYPT_MODE, secretKey) + UTF8String.fromBytes( + cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, + input1.asInstanceOf[Array[Byte]].length)) + } + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, (str, key) => { + val Cipher = "javax.crypto.Cipher" + val SecretKeySpec = "javax.crypto.spec.SecretKeySpec" + s""" + try { + $Cipher cipher = $Cipher.getInstance("AES"); + $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES"); + cipher.init($Cipher.DECRYPT_MODE, secret); + ${ev.value} = UTF8String.fromBytes(cipher.doFinal($str, 0, $str.length)); + } catch (java.security.GeneralSecurityException e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75131a6170222..67f2dc457d333 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -132,4 +132,88 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + test("aesEncrypt") { + val expr1 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890123456".getBytes)) + val expr2 = AesEncrypt(Literal("".getBytes), Literal("1234567890123456".getBytes)) + + checkEvaluation(Base64(expr1), "y6Ss+zCYObpCbgfWfyNWTw==") + checkEvaluation(Base64(expr2), "BQGHoM3lqYcsurCRq3PlUw==") + + // input is null + checkEvaluation(AesEncrypt(Literal.create(null, BinaryType), + Literal("1234567890123456".getBytes)), null) + // key is null + checkEvaluation(AesEncrypt(Literal("ABC".getBytes), + Literal.create(null, BinaryType)), null) + // both are null + checkEvaluation(AesEncrypt(Literal.create(null, BinaryType), + Literal.create(null, BinaryType)), null) + + val expr3 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890".getBytes)) + // key length (80 bits) is not one of the permitted values (128, 192 or 256 bits) + intercept[java.security.InvalidKeyException] { + evaluate(expr3) + } + intercept[java.security.InvalidKeyException] { + UnsafeProjection.create(expr3 :: Nil).apply(null) + } + } + + test("aesDecrypt") { + val expr1 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), + Literal("1234567890123456".getBytes)) + val expr2 = AesDecrypt(UnBase64(Literal("BQGHoM3lqYcsurCRq3PlUw==")), + Literal("1234567890123456".getBytes)) + + checkEvaluation(expr1, "ABC") + checkEvaluation(expr2, "") + + // input is null + checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)), + Literal("1234567890123456".getBytes)), null) + // key is null + checkEvaluation(AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), + Literal.create(null, BinaryType)), null) + // both are null + checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)), + Literal.create(null, BinaryType)), null) + + val expr3 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), + Literal("1234567890".getBytes)) + val expr4 = AesDecrypt(UnBase64(Literal("y6Ss+zCsdYObpCbgfWfyNW3Twewr")), + Literal("1234567890123456".getBytes)) + val expr5 = AesDecrypt(UnBase64(Literal("t6Ss+zCYObpCbgfWfyNWTw==")), + Literal("1234567890123456".getBytes)) + + // key length (80 bits) is not one of the permitted values (128, 192 or 256 bits) + intercept[java.security.InvalidKeyException] { + evaluate(expr3) + } + intercept[java.security.InvalidKeyException] { + UnsafeProjection.create(expr3 :: Nil).apply(null) + } + // input can not be decrypted + intercept[javax.crypto.IllegalBlockSizeException] { + evaluate(expr4) + } + intercept[javax.crypto.IllegalBlockSizeException] { + UnsafeProjection.create(expr4 :: Nil).apply(null) + } + // input can not be decrypted + intercept[javax.crypto.BadPaddingException] { + evaluate(expr5) + } + intercept[javax.crypto.BadPaddingException] { + UnsafeProjection.create(expr5 :: Nil).apply(null) + } + } + + ignore("aesEncryptWith256bitsKey") { + // Before testing this, installing Java Cryptography Extension (JCE) Unlimited Strength Juris- + // diction Policy Files first. Otherwise `java.security.InvalidKeyException` will be thrown. + // Because Oracle JDK does not support 192 and 256 bits key out of box. + checkEvaluation(Base64(AesEncrypt(Literal("ABC".getBytes), + Literal("12345678901234561234567890123456".getBytes))), "nYfCuJeRd5eD60yXDw7WEA==") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c6992e18753..8da50bedfc8cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1982,6 +1982,42 @@ object functions extends LegacyFunctions { new Murmur3Hash(cols.map(_.expr)) } + /** + * Encrypts input using AES and Returns the result as a binary column. + * Key lengths of 128, 192 or 256 bits can be used. 192 and 256 bits keys can be used if Java + * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files are installed. If + * either argument is NULL, the result will also be null. If input is invalid, key length is not + * one of the permitted values or using 192/256 bits key before installing JCE, an exception will + * be thrown. + * + * @param input binary column to encrypt input + * @param key binary column of 128, 192 or 256 bits key + * + * @group misc_funcs + * @since 2.0.0 + */ + def aes_encrypt(input: Column, key: Column): Column = withExpr { + AesEncrypt(input.expr, key.expr) + } + + /** + * Decrypts input using AES and Returns the result as a string column. + * Key lengths of 128, 192 or 256 bits can be used. 192 and 256 bits keys can be used if Java + * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files are installed. If + * either argument is NULL, the result will also be null. If input is invalid, key length is not + * one of the permitted values or using 192/256 bits key before installing JCE, an exception will + * be thrown. + * + * @param input binary column to decrypt input + * @param key binary column of 128, 192 or 256 bits key + * + * @group misc_funcs + * @since 2.0.0 + */ + def aes_decrypt(input: Column, key: Column): Column = withExpr { + AesDecrypt(input.expr, key.expr) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index aff9efe4b2b16..0381d5728077b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -206,6 +206,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(2743272264L, 2180413220L)) } + test("misc aes encrypt function") { + val df = Seq(("ABC", "1234567890123456")).toDF("input", "key") + checkAnswer( + df.select(base64(aes_encrypt($"input", $"key"))), + Row("y6Ss+zCYObpCbgfWfyNWTw==") + ) + checkAnswer( + sql("SELECT base64(aes_encrypt('', '1234567890123456'))"), + Row("BQGHoM3lqYcsurCRq3PlUw==") + ) + } + + test("misc aes decrypt function") { + val df = Seq(("y6Ss+zCYObpCbgfWfyNWTw==", "1234567890123456")).toDF("input", "key") + checkAnswer( + df.select((aes_decrypt(unbase64($"input"), $"key"))), + Row("ABC") + ) + checkAnswer( + sql("SELECT aes_decrypt(unbase64('BQGHoM3lqYcsurCRq3PlUw=='), '1234567890123456')"), + Row("") + ) + } + test("string function find_in_set") { val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") From 6624a588c1b3b6c05fb39285bc6215102dd109c6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 19 Feb 2016 22:44:20 -0800 Subject: [PATCH 1940/2089] Revert "[SPARK-12567] [SQL] Add aes_{encrypt,decrypt} UDFs" This reverts commit 4f9a66481849dc867cf6592d53e0e9782361d20a. --- python/pyspark/sql/functions.py | 37 -------- .../catalyst/analysis/FunctionRegistry.scala | 2 - .../spark/sql/catalyst/expressions/misc.scala | 89 ------------------- .../expressions/MiscFunctionsSuite.scala | 84 ----------------- .../org/apache/spark/sql/functions.scala | 36 -------- .../spark/sql/DataFrameFunctionsSuite.scala | 24 ----- 6 files changed, 272 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7d038b8d9e89b..5fc1cc2cae10a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1125,43 +1125,6 @@ def hash(*cols): return Column(jc) -@ignore_unicode_prefix -@since(2.0) -def aes_encrypt(input, key): - """ - Encrypts input of given column using AES. Key lengths of 128, 192 or 256 bits can be used. 192 - and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- - tion Policy Files are installed. If input is invalid, key length is not one of the permitted - values or using 192/256 bits key before installing JCE, an exception will be thrown. - - >>> df = sqlContext.createDataFrame([('ABC','1234567890123456')], ['input','key']) - >>> df.select(base64(aes_encrypt(df.input, df.key)).alias('aes')).collect() - [Row(aes=u'y6Ss+zCYObpCbgfWfyNWTw==')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.aes_encrypt(_to_java_column(input), _to_java_column(key)) - return Column(jc) - - -@ignore_unicode_prefix -@since(2.0) -def aes_decrypt(input, key): - """ - Decrypts input of given column using AES. Key lengths of 128, 192 or 256 bits can be used. 192 - and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- - tion Policy Files are installed. If input is invalid, key length is not one of the permitted - values or using 192/256 bits key before installing JCE, an exception will be thrown. - - >>> df = sqlContext.createDataFrame([(u'y6Ss+zCYObpCbgfWfyNWTw==','1234567890123456')], \ - ['input','key']) - >>> df.select(aes_decrypt(unbase64(df.input), df.key).alias('aes')).collect() - [Row(aes=u'ABC')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.aes_decrypt(_to_java_column(input), _to_java_column(key)) - return Column(jc) - - # ---------------------- String/Binary functions ------------------------------ _string_functions = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ae09c3d71f883..1be97c7b81197 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -278,8 +278,6 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), // misc functions - expression[AesEncrypt]("aes_encrypt"), - expression[AesDecrypt]("aes_decrypt"), expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Murmur3Hash]("hash"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 3b66f5797be64..dcbb594afd86e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 -import javax.crypto.Cipher -import javax.crypto.spec.SecretKeySpec import org.apache.commons.codec.digest.DigestUtils @@ -443,90 +441,3 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { """.stripMargin) } } - -/** - * A function that encrypts input using AES. Key lengths of 128, 192 or 256 bits can be used. 192 - * and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- - * tion Policy Files are installed. If either argument is NULL, the result will also be null. If - * input is invalid, key length is not one of the permitted values or using 192/256 bits key before - * installing JCE, an exception will be thrown. - */ -@ExpressionDescription( - usage = "_FUNC_(input, key) - Encrypts input using AES.", - extended = "> SELECT Base64(_FUNC_('ABC', '1234567890123456'));\n 'y6Ss+zCYObpCbgfWfyNWTw=='") -case class AesEncrypt(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = BinaryType - override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType) - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val cipher = Cipher.getInstance("AES") - val secretKey: SecretKeySpec = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0, - input2.asInstanceOf[Array[Byte]].length, "AES") - cipher.init(Cipher.ENCRYPT_MODE, secretKey) - cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, input1.asInstanceOf[Array[Byte]].length) - } - - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - nullSafeCodeGen(ctx, ev, (str, key) => { - val Cipher = "javax.crypto.Cipher" - val SecretKeySpec = "javax.crypto.spec.SecretKeySpec" - s""" - try { - $Cipher cipher = $Cipher.getInstance("AES"); - $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES"); - cipher.init($Cipher.ENCRYPT_MODE, secret); - ${ev.value} = cipher.doFinal($str, 0, $str.length); - } catch (java.security.GeneralSecurityException e) { - org.apache.spark.unsafe.Platform.throwException(e); - } - """ - }) - } -} - -/** - * A function that decrypts input using AES. Key lengths of 128, 192 or 256 bits can be used. 192 - * and 256 bits keys can be used if Java Cryptography Extension (JCE) Unlimited Strength Jurisdic- - * tion Policy Files are installed. If either argument is NULL, the result will also be null. If - * input is invalid, key length is not one of the permitted values or using 192/256 bits key before - * installing JCE, an exception will be thrown. - */ -@ExpressionDescription( - usage = "_FUNC_(input, key) - Decrypts input using AES.", - extended = "> SELECT _FUNC_(UnBase64('y6Ss+zCYObpCbgfWfyNWTw=='),'1234567890123456');\n 'ABC'") -case class AesDecrypt(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType) - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val cipher = Cipher.getInstance("AES") - val secretKey = new SecretKeySpec(input2.asInstanceOf[Array[Byte]], 0, - input2.asInstanceOf[Array[Byte]].length, "AES") - - cipher.init(Cipher.DECRYPT_MODE, secretKey) - UTF8String.fromBytes( - cipher.doFinal(input1.asInstanceOf[Array[Byte]], 0, - input1.asInstanceOf[Array[Byte]].length)) - } - - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - nullSafeCodeGen(ctx, ev, (str, key) => { - val Cipher = "javax.crypto.Cipher" - val SecretKeySpec = "javax.crypto.spec.SecretKeySpec" - s""" - try { - $Cipher cipher = $Cipher.getInstance("AES"); - $SecretKeySpec secret = new $SecretKeySpec($key, 0, $key.length, "AES"); - cipher.init($Cipher.DECRYPT_MODE, secret); - ${ev.value} = UTF8String.fromBytes(cipher.doFinal($str, 0, $str.length)); - } catch (java.security.GeneralSecurityException e) { - org.apache.spark.unsafe.Platform.throwException(e); - } - """ - }) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 67f2dc457d333..75131a6170222 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -132,88 +132,4 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - - test("aesEncrypt") { - val expr1 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890123456".getBytes)) - val expr2 = AesEncrypt(Literal("".getBytes), Literal("1234567890123456".getBytes)) - - checkEvaluation(Base64(expr1), "y6Ss+zCYObpCbgfWfyNWTw==") - checkEvaluation(Base64(expr2), "BQGHoM3lqYcsurCRq3PlUw==") - - // input is null - checkEvaluation(AesEncrypt(Literal.create(null, BinaryType), - Literal("1234567890123456".getBytes)), null) - // key is null - checkEvaluation(AesEncrypt(Literal("ABC".getBytes), - Literal.create(null, BinaryType)), null) - // both are null - checkEvaluation(AesEncrypt(Literal.create(null, BinaryType), - Literal.create(null, BinaryType)), null) - - val expr3 = AesEncrypt(Literal("ABC".getBytes), Literal("1234567890".getBytes)) - // key length (80 bits) is not one of the permitted values (128, 192 or 256 bits) - intercept[java.security.InvalidKeyException] { - evaluate(expr3) - } - intercept[java.security.InvalidKeyException] { - UnsafeProjection.create(expr3 :: Nil).apply(null) - } - } - - test("aesDecrypt") { - val expr1 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), - Literal("1234567890123456".getBytes)) - val expr2 = AesDecrypt(UnBase64(Literal("BQGHoM3lqYcsurCRq3PlUw==")), - Literal("1234567890123456".getBytes)) - - checkEvaluation(expr1, "ABC") - checkEvaluation(expr2, "") - - // input is null - checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)), - Literal("1234567890123456".getBytes)), null) - // key is null - checkEvaluation(AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), - Literal.create(null, BinaryType)), null) - // both are null - checkEvaluation(AesDecrypt(UnBase64(Literal.create(null, StringType)), - Literal.create(null, BinaryType)), null) - - val expr3 = AesDecrypt(UnBase64(Literal("y6Ss+zCYObpCbgfWfyNWTw==")), - Literal("1234567890".getBytes)) - val expr4 = AesDecrypt(UnBase64(Literal("y6Ss+zCsdYObpCbgfWfyNW3Twewr")), - Literal("1234567890123456".getBytes)) - val expr5 = AesDecrypt(UnBase64(Literal("t6Ss+zCYObpCbgfWfyNWTw==")), - Literal("1234567890123456".getBytes)) - - // key length (80 bits) is not one of the permitted values (128, 192 or 256 bits) - intercept[java.security.InvalidKeyException] { - evaluate(expr3) - } - intercept[java.security.InvalidKeyException] { - UnsafeProjection.create(expr3 :: Nil).apply(null) - } - // input can not be decrypted - intercept[javax.crypto.IllegalBlockSizeException] { - evaluate(expr4) - } - intercept[javax.crypto.IllegalBlockSizeException] { - UnsafeProjection.create(expr4 :: Nil).apply(null) - } - // input can not be decrypted - intercept[javax.crypto.BadPaddingException] { - evaluate(expr5) - } - intercept[javax.crypto.BadPaddingException] { - UnsafeProjection.create(expr5 :: Nil).apply(null) - } - } - - ignore("aesEncryptWith256bitsKey") { - // Before testing this, installing Java Cryptography Extension (JCE) Unlimited Strength Juris- - // diction Policy Files first. Otherwise `java.security.InvalidKeyException` will be thrown. - // Because Oracle JDK does not support 192 and 256 bits key out of box. - checkEvaluation(Base64(AesEncrypt(Literal("ABC".getBytes), - Literal("12345678901234561234567890123456".getBytes))), "nYfCuJeRd5eD60yXDw7WEA==") - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8da50bedfc8cb..97c6992e18753 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1982,42 +1982,6 @@ object functions extends LegacyFunctions { new Murmur3Hash(cols.map(_.expr)) } - /** - * Encrypts input using AES and Returns the result as a binary column. - * Key lengths of 128, 192 or 256 bits can be used. 192 and 256 bits keys can be used if Java - * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files are installed. If - * either argument is NULL, the result will also be null. If input is invalid, key length is not - * one of the permitted values or using 192/256 bits key before installing JCE, an exception will - * be thrown. - * - * @param input binary column to encrypt input - * @param key binary column of 128, 192 or 256 bits key - * - * @group misc_funcs - * @since 2.0.0 - */ - def aes_encrypt(input: Column, key: Column): Column = withExpr { - AesEncrypt(input.expr, key.expr) - } - - /** - * Decrypts input using AES and Returns the result as a string column. - * Key lengths of 128, 192 or 256 bits can be used. 192 and 256 bits keys can be used if Java - * Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files are installed. If - * either argument is NULL, the result will also be null. If input is invalid, key length is not - * one of the permitted values or using 192/256 bits key before installing JCE, an exception will - * be thrown. - * - * @param input binary column to decrypt input - * @param key binary column of 128, 192 or 256 bits key - * - * @group misc_funcs - * @since 2.0.0 - */ - def aes_decrypt(input: Column, key: Column): Column = withExpr { - AesDecrypt(input.expr, key.expr) - } - ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0381d5728077b..aff9efe4b2b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -206,30 +206,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(2743272264L, 2180413220L)) } - test("misc aes encrypt function") { - val df = Seq(("ABC", "1234567890123456")).toDF("input", "key") - checkAnswer( - df.select(base64(aes_encrypt($"input", $"key"))), - Row("y6Ss+zCYObpCbgfWfyNWTw==") - ) - checkAnswer( - sql("SELECT base64(aes_encrypt('', '1234567890123456'))"), - Row("BQGHoM3lqYcsurCRq3PlUw==") - ) - } - - test("misc aes decrypt function") { - val df = Seq(("y6Ss+zCYObpCbgfWfyNWTw==", "1234567890123456")).toDF("input", "key") - checkAnswer( - df.select((aes_decrypt(unbase64($"input"), $"key"))), - Row("ABC") - ) - checkAnswer( - sql("SELECT aes_decrypt(unbase64('BQGHoM3lqYcsurCRq3PlUw=='), '1234567890123456')"), - Row("") - ) - } - test("string function find_in_set") { val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") From dfb2ae2f141960c10200a870ed21583e6af5c536 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 19 Feb 2016 23:00:08 -0800 Subject: [PATCH 1941/2089] [SPARK-13408] [CORE] Ignore errors when it's already reported in JobWaiter ## What changes were proposed in this pull request? `JobWaiter.taskSucceeded` will be called for each task. When `resultHandler` throws an exception, `taskSucceeded` will also throw it for each task. DAGScheduler just catches it and reports it like this: ```Scala try { job.listener.taskSucceeded(rt.outputId, event.result) } catch { case e: Exception => // TODO: Perhaps we want to mark the resultStage as failed? job.listener.jobFailed(new SparkDriverExecutionException(e)) } ``` Therefore `JobWaiter.jobFailed` may be called multiple times. So `JobWaiter.jobFailed` should use `Promise.tryFailure` instead of `Promise.failure` because the latter one doesn't support calling multiple times. ## How was the this patch tested? Jenkins tests. Author: Shixiong Zhu Closes #11280 from zsxwing/SPARK-13408. --- .../apache/spark/scheduler/JobWaiter.scala | 11 +++-- .../spark/scheduler/JobWaiterSuite.scala | 41 +++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 4326135186a73..ac8229a3c1034 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.{Future, Promise} +import org.apache.spark.Logging + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -30,7 +32,7 @@ private[spark] class JobWaiter[T]( val jobId: Int, totalTasks: Int, resultHandler: (Int, T) => Unit) - extends JobListener { + extends JobListener with Logging { private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero @@ -61,7 +63,10 @@ private[spark] class JobWaiter[T]( } } - override def jobFailed(exception: Exception): Unit = - jobPromise.failure(exception) + override def jobFailed(exception: Exception): Unit = { + if (!jobPromise.tryFailure(exception)) { + logWarning("Ignore failure", exception) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala new file mode 100644 index 0000000000000..bc8e513fe5bc8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Failure + +import org.apache.spark.SparkFunSuite + +class JobWaiterSuite extends SparkFunSuite { + + test("call jobFailed multiple times") { + val waiter = new JobWaiter[Int](null, 0, totalTasks = 2, null) + + // Should not throw exception if calling jobFailed multiple times + waiter.jobFailed(new RuntimeException("Oops 1")) + waiter.jobFailed(new RuntimeException("Oops 2")) + waiter.jobFailed(new RuntimeException("Oops 3")) + + waiter.completionFuture.value match { + case Some(Failure(e)) => + // We should receive the first exception + assert("Oops 1" === e.getMessage) + case other => fail("Should receiver the first exception but it was " + other) + } + } +} From 9ca79c1ece5ad139719e4eea9f7d1b59aed01b20 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 20 Feb 2016 09:07:19 +0000 Subject: [PATCH 1942/2089] [SPARK-13302][PYSPARK][TESTS] Move the temp file creation and cleanup outside of the doctests Some of the new doctests in ml/clustering.py have a lot of setup code, move the setup code to the general test init to keep the doctest more example-style looking. In part this is a follow up to https://github.com/apache/spark/pull/10999 Note that the same pattern is followed in regression & recommendation - might as well clean up all three at the same time. Author: Holden Karau Closes #11197 from holdenk/SPARK-13302-cleanup-doctests-in-ml-clustering. --- python/pyspark/ml/clustering.py | 25 ++++++++++++++----------- python/pyspark/ml/recommendation.py | 25 ++++++++++++++----------- python/pyspark/ml/regression.py | 25 ++++++++++++++----------- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 91278d570a642..611b9190491c1 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> kmeans_path = path + "/kmeans" + >>> kmeans_path = temp_path + "/kmeans" >>> kmeans.save(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path) >>> kmeans2.getK() 2 - >>> model_path = path + "/kmeans_model" + >>> model_path = temp_path + "/kmeans_model" >>> model.save(model_path) >>> model2 = KMeansModel.load(model_path) >>> model.clusterCenters()[0] == model2.clusterCenters()[0] array([ True, True], dtype=bool) >>> model.clusterCenters()[1] == model2.clusterCenters()[1] array([ True, True], dtype=bool) - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.5.0 """ @@ -310,7 +303,17 @@ def _create_model(self, java_model): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index ef9448855e175..2b605e5c5078b 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,14 +82,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> als_path = path + "/als" + >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) >>> als.getMaxIter() 5 - >>> model_path = path + "/als_model" + >>> model_path = temp_path + "/als_model" >>> model.save(model_path) >>> model2 = ALSModel.load(model_path) >>> model.rank == model2.rank @@ -98,11 +96,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha True >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) True - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.4.0 """ @@ -340,7 +333,17 @@ def itemFactors(self): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 20dc6c2db91f3..de4a751a54796 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> lr_path = path + "/lr" + >>> lr_path = temp_path + "/lr" >>> lr.save(lr_path) >>> lr2 = LinearRegression.load(lr_path) >>> lr2.getMaxIter() 5 - >>> model_path = path + "/lr_model" + >>> model_path = temp_path + "/lr_model" >>> model.save(model_path) >>> model2 = LinearRegressionModel.load(model_path) >>> model.coefficients[0] == model2.coefficients[0] True >>> model.intercept == model2.intercept True - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.4.0 """ @@ -850,7 +843,17 @@ def predict(self, features): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) From 6ce7c481dc3a94af503d0f3f86e2be7ba82b3bbc Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Sat, 20 Feb 2016 12:24:10 -0800 Subject: [PATCH 1943/2089] [SPARK-13386][GRAPHX] ConnectedComponents should support maxIteration option JIRA: https://issues.apache.org/jira/browse/SPARK-13386 ## What changes were proposed in this pull request? add maxIteration option for ConnectedComponents algorithm ## How was the this patch tested? unit tests passed Author: Zheng RuiFeng Closes #11268 from zhengruifeng/ccwithmax. --- .../org/apache/spark/graphx/GraphOps.scala | 14 +++++++++-- .../graphx/lib/ConnectedComponents.scala | 24 +++++++++++++++---- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d048fb5d561f3..97a82239a97ee 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -405,14 +405,24 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali PageRank.run(graph, numIter, resetProb) } + /** + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] + */ + def connectedComponents(): Graph[VertexId, ED] = { + ConnectedComponents.run(graph) + } + /** * Compute the connected component membership of each vertex and return a graph with the vertex * value containing the lowest vertex id in the connected component containing that vertex. * * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] */ - def connectedComponents(): Graph[VertexId, ED] = { - ConnectedComponents.run(graph) + def connectedComponents(maxIterations: Int): Graph[VertexId, ED] = { + ConnectedComponents.run(graph, maxIterations) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index f72cbb15242ec..40cf0735e274b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -29,13 +29,14 @@ object ConnectedComponents { * * @tparam VD the vertex attribute type (discarded in the computation) * @tparam ED the edge attribute type (preserved in the computation) - * * @param graph the graph for which to compute the connected components - * + * @param maxIterations the maximum number of iterations to run for * @return a graph with vertex attributes containing the smallest vertex in each * connected component */ - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], + maxIterations: Int): Graph[VertexId, ED] = { + require(maxIterations > 0) val ccGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { @@ -47,11 +48,26 @@ object ConnectedComponents { } } val initialMessage = Long.MaxValue - val pregelGraph = Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)( + val pregelGraph = Pregel(ccGraph, initialMessage, + maxIterations, EdgeDirection.Either)( vprog = (id, attr, msg) => math.min(attr, msg), sendMsg = sendMessage, mergeMsg = (a, b) => math.min(a, b)) ccGraph.unpersist() pregelGraph } // end of connectedComponents + + /** + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @tparam VD the vertex attribute type (discarded in the computation) + * @tparam ED the edge attribute type (preserved in the computation) + * @param graph the graph for which to compute the connected components + * @return a graph with vertex attributes containing the smallest vertex in each + * connected component + */ + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { + run(graph, Int.MaxValue) + } } From a4a081d1df043cce6db0284ef552e5174ebb0d02 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Sat, 20 Feb 2016 12:58:47 -0800 Subject: [PATCH 1944/2089] [SPARK-13414][MESOS] Allow multiple dispatchers to be launched. ## What changes were proposed in this pull request? Users might want to start multiple mesos dispatchers, as each dispatcher can potentially be part of different roles and used for multi-tenancy. To allow multiple Mesos dispatchers to be launched, we need to be able to specify a instance number when starting the dispatcher daemon. ## How was the this patch tested? Manual testing Author: Timothy Chen Closes #11281 from tnachen/multiple_cluster_dispatchers. --- sbin/start-mesos-dispatcher.sh | 5 ++++- sbin/stop-mesos-dispatcher.sh | 7 ++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index 4777e1668c703..06a966d1c20b4 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -37,5 +37,8 @@ if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then SPARK_MESOS_DISPATCHER_HOST=`hostname` fi +if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then + SPARK_MESOS_DISPATCHER_NUM=1 +fi -"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" +"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher $SPARK_MESOS_DISPATCHER_NUM --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh index 5c0b4e051db38..b13e018c7d41e 100755 --- a/sbin/stop-mesos-dispatcher.sh +++ b/sbin/stop-mesos-dispatcher.sh @@ -24,5 +24,10 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" -"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 +if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then + SPARK_MESOS_DISPATCHER_NUM=1 +fi + +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher \ + $SPARK_MESOS_DISPATCHER_NUM From f88c641bc8afa836ae597801ae7520da90e8472a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 20 Feb 2016 13:53:23 -0800 Subject: [PATCH 1945/2089] [SPARK-13310] [SQL] Resolve Missing Sorting Columns in Generate ```scala // case 1: missing sort columns are resolvable if join is true sql("SELECT explode(a) AS val, b FROM data WHERE b < 2 order by val, c") // case 2: missing sort columns are not resolvable if join is false. Thus, issue an error message in this case sql("SELECT explode(a) AS val FROM data order by val, c") ``` When sort columns are not in `Generate`, we can resolve them when `join` is equal to `true`. Still trying to add more test cases for the other `UnaryNode` types. Could you review the changes? davies cloud-fan Thanks! Author: gatorsmile Closes #11198 from gatorsmile/missingInSort. --- .../sql/catalyst/analysis/Analyzer.scala | 34 +++++++++++++------ .../analysis/AnalysisErrorSuite.scala | 5 +++ .../sql/hive/execution/SQLQuerySuite.scala | 32 +++++++++++++++-- 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8368657b2ba2b..54b91adad11a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -609,17 +609,24 @@ class Analyzer( case sa @ Sort(_, _, child: Aggregate) => sa case s @ Sort(order, _, child) if !s.resolved && child.resolved => - val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) - val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(child.output, - Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) - } else if (newOrder != order) { - s.copy(order = newOrder) - } else { - s + try { + val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) + val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { + // Add missing attributes and then project them away after the sort. + Project(child.output, + Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) + } else if (newOrder != order) { + s.copy(order = newOrder) + } else { + s + } + } catch { + // Attempting to resolve it might fail. When this happens, return the original plan. + // Users will see an AnalysisException for resolution failure of missing attributes + // in Sort + case ae: AnalysisException => s } } @@ -649,6 +656,11 @@ class Analyzer( } val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs a.copy(aggregateExpressions = newAggregateExpressions) + case g: Generate => + // If join is false, we will convert it to true for getting from the child the missing + // attributes that its child might have or could have. + val missing = missingAttrs -- g.child.outputSet + g.copy(join = true, child = addMissingAttr(g.child, missing)) case u: UnaryNode => u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) case other => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index e0cec09742eba..a46006a0c622b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -193,6 +193,11 @@ class AnalysisErrorSuite extends AnalysisTest { mapRelation.orderBy('map.asc), "sort" :: "type" :: "map" :: Nil) + errorTest( + "sorting by attributes are not from grouping expressions", + testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc), + "cannot resolve" :: "'b'" :: "given input columns" :: "[a, c, a3]" :: Nil) + errorTest( "non-boolean filters", testRelation.where(Literal(1)), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c2b0daa66b013..75fa09db69ea0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,11 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry} -import org.apache.spark.sql.catalyst.parser.ParserConf -import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.hive.{HiveContext, HiveQl, MetastoreRelation} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -927,6 +926,33 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("Sorting columns are not in Generate") { + withTempTable("data") { + sqlContext.range(1, 5) + .select(array($"id", $"id" + 1).as("a"), $"id".as("b"), (lit(10) - $"id").as("c")) + .registerTempTable("data") + + // case 1: missing sort columns are resolvable if join is true + checkAnswer( + sql("SELECT explode(a) AS val, b FROM data WHERE b < 2 order by val, c"), + Row(1, 1) :: Row(2, 1) :: Nil) + + // case 2: missing sort columns are resolvable if join is false + checkAnswer( + sql("SELECT explode(a) AS val FROM data order by val, c"), + Seq(1, 2, 2, 3, 3, 4, 4, 5).map(i => Row(i))) + + // case 3: missing sort columns are resolvable if join is true and outer is true + checkAnswer( + sql( + """ + |SELECT C.val, b FROM data LATERAL VIEW OUTER explode(a) C as val + |where b < 2 order by c, val, b + """.stripMargin), + Row(1, 1) :: Row(2, 1) :: Nil) + } + } + test("window function: Sorting columns are not in Project") { val data = Seq( WindowData(1, "d", 10), From 7925071280bfa1570435bde3e93492eaf2167d56 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 20 Feb 2016 21:01:51 -0800 Subject: [PATCH 1946/2089] [SPARK-13306] [SQL] uncorrelated scalar subquery A scalar subquery is a subquery that only generate single row and single column, could be used as part of expression. Uncorrelated scalar subquery means it does not has a reference to external table. All the uncorrelated scalar subqueries will be executed during prepare() of SparkPlan. The plans for query ```sql select 1 + (select 2 + (select 3)) ``` looks like this ``` == Parsed Logical Plan == 'Project [unresolvedalias((1 + subquery#1),None)] :- OneRowRelation$ +- 'Subquery subquery#1 +- 'Project [unresolvedalias((2 + subquery#0),None)] :- OneRowRelation$ +- 'Subquery subquery#0 +- 'Project [unresolvedalias(3,None)] +- OneRowRelation$ == Analyzed Logical Plan == _c0: int Project [(1 + subquery#1) AS _c0#4] :- OneRowRelation$ +- Subquery subquery#1 +- Project [(2 + subquery#0) AS _c0#3] :- OneRowRelation$ +- Subquery subquery#0 +- Project [3 AS _c0#2] +- OneRowRelation$ == Optimized Logical Plan == Project [(1 + subquery#1) AS _c0#4] :- OneRowRelation$ +- Subquery subquery#1 +- Project [(2 + subquery#0) AS _c0#3] :- OneRowRelation$ +- Subquery subquery#0 +- Project [3 AS _c0#2] +- OneRowRelation$ == Physical Plan == WholeStageCodegen : +- Project [(1 + subquery#1) AS _c0#4] : :- INPUT : +- Subquery subquery#1 : +- WholeStageCodegen : : +- Project [(2 + subquery#0) AS _c0#3] : : :- INPUT : : +- Subquery subquery#0 : : +- WholeStageCodegen : : : +- Project [3 AS _c0#2] : : : +- INPUT : : +- Scan OneRowRelation[] : +- Scan OneRowRelation[] +- Scan OneRowRelation[] ``` Author: Davies Liu Closes #11190 from davies/scalar_subquery. --- .../sql/catalyst/parser/ExpressionParser.g | 2 + .../spark/sql/catalyst/CatalystQl.scala | 2 + .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++ .../sql/catalyst/expressions/subquery.scala | 82 ++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 14 +++- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++ .../spark/sql/catalyst/trees/TreeNode.scala | 11 ++- .../spark/sql/catalyst/CatalystQlSuite.scala | 7 ++ .../analysis/AnalysisErrorSuite.scala | 11 +++ .../org/apache/spark/sql/SQLContext.scala | 1 + .../spark/sql/execution/SparkPlan.scala | 50 +++++++++++ .../sql/execution/WholeStageCodegen.scala | 5 +- .../spark/sql/execution/basicOperators.scala | 15 ++++ .../apache/spark/sql/execution/subquery.scala | 74 ++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 7 +- .../org/apache/spark/sql/SubquerySuite.scala | 84 +++++++++++++++++++ 16 files changed, 391 insertions(+), 12 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index c162c1a0c5789..10f2e2416bb64 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -205,6 +205,8 @@ atomExpression | whenExpression | (functionName LPAREN) => function | tableOrColumn + | (LPAREN KW_SELECT) => subQueryExpression + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression) | LPAREN! expression RPAREN! ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 8099751900a42..b16025a17e2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -667,6 +667,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) } + case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) => + ScalarSubquery(nodeToPlan(subquery)) /* Stars (*) */ case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 54b91adad11a9..ba306f8b32b7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -80,6 +80,7 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveSubquery :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalJoin :: @@ -120,7 +121,14 @@ class Analyzer( withAlias.getOrElse(relation) } substituted.getOrElse(u) + case other => + // This can't be done in ResolveSubquery because that does not know the CTE. + other transformExpressions { + case e: SubqueryExpression => + e.withNewPlan(substituteCTE(e.query, cteRelations)) + } } + } } @@ -716,6 +724,30 @@ class Analyzer( } } + /** + * This rule resolve subqueries inside expressions. + * + * Note: CTE are handled in CTESubstitution. + */ + object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { + + private def hasSubquery(e: Expression): Boolean = { + e.find(_.isInstanceOf[SubqueryExpression]).isDefined + } + + private def hasSubquery(q: LogicalPlan): Boolean = { + q.expressions.exists(hasSubquery) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case q: LogicalPlan if q.childrenResolved && hasSubquery(q) => + q transformExpressions { + case e: SubqueryExpression if !e.query.resolved => + e.withNewPlan(execute(e.query)) + } + } + } + /** * Turns projections that contain aggregate expressions into aggregations. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala new file mode 100644 index 0000000000000..a8f5e1f63d4c7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} +import org.apache.spark.sql.types.DataType + +/** + * An interface for subquery that is used in expressions. + */ +abstract class SubqueryExpression extends LeafExpression { + + /** + * The logical plan of the query. + */ + def query: LogicalPlan + + /** + * Either a logical plan or a physical plan. The generated tree string (explain output) uses this + * field to explain the subquery. + */ + def plan: QueryPlan[_] + + /** + * Updates the query with new logical plan. + */ + def withNewPlan(plan: LogicalPlan): SubqueryExpression +} + +/** + * A subquery that will return only one row and one column. + * + * This will be converted into [[execution.ScalarSubquery]] during physical planning. + * + * Note: `exprId` is used to have unique name in explain string output. + */ +case class ScalarSubquery( + query: LogicalPlan, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression with Unevaluable { + + override def plan: LogicalPlan = Subquery(toString, query) + + override lazy val resolved: Boolean = query.resolved + + override def dataType: DataType = query.schema.fields.head.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (query.schema.length != 1) { + TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + + query.schema.length.toString) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def foldable: Boolean = false + override def nullable: Boolean = true + + override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId) + + override def toString: String = s"subquery#${exprId.id}" + + // TODO: support sql() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b7d8d932edfcf..202f6b5e07fce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -90,7 +90,19 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), - ConvertToLocalRelation) :: Nil + ConvertToLocalRelation) :: + Batch("Subquery", Once, + OptimizeSubqueries) :: Nil + } + + /** + * Optimize all the subqueries inside expression. + */ + object OptimizeSubqueries extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case subquery: SubqueryExpression => + subquery.withNewPlan(Optimizer.this.execute(subquery.query)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 05f5bdbfc0769..86bd33f526209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Subquery import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -226,4 +227,9 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" override def simpleString: String = statePrefix + super.simpleString + + override def treeChildren: Seq[PlanType] = { + val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e}) + children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType]) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 30df2a84f62c4..e46ce1cee7c6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -448,6 +448,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } + /** + * All the nodes that will be used to generate tree string. + */ + protected def treeChildren: Seq[BaseType] = children + /** * Appends the string represent of this node and its children to the given StringBuilder. * @@ -470,9 +475,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(simpleString) builder.append("\n") - if (children.nonEmpty) { - children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) - children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + if (treeChildren.nonEmpty) { + treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder) } builder diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 8d7d6b5bf52ea..ed7121831ac29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.BooleanType import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { @@ -201,4 +202,10 @@ class CatalystQlSuite extends PlanTest { parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + "from windowData") } + + test("subquery") { + parser.parsePlan("select (select max(b) from s) ss from t") + parser.parsePlan("select * from t where a = (select b from s)") + parser.parsePlan("select * from t group by g having a > (select b from s)") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a46006a0c622b..69f78a097e15c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -113,6 +113,17 @@ class AnalysisErrorSuite extends AnalysisTest { val dateLit = Literal.create(null, DateType) + errorTest( + "scalar subquery with 2 columns", + testRelation.select( + (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), + "Scalar subquery must return only one column, but got 2" :: Nil) + + errorTest( + "scalar subquery with no column", + testRelation.select(ScalarSubquery(LocalRelation()).as('a)), + "Scalar subquery must return only one column, but got 0" :: Nil) + errorTest( "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c7d1096a1384c..932df36b859b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -884,6 +884,7 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( + Batch("Subquery", Once, PlanSubqueries(self)), Batch("Add exchange", Once, EnsureRequirements(self)), Batch("Whole stage codegen", Once, CollapseCodegenStages(self)) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index c72b8dc70708f..872ccde883060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import org.apache.spark.Logging import org.apache.spark.rdd.{RDD, RDDOperationScope} @@ -31,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ThreadUtils /** * The base class for physical operators. @@ -112,16 +115,58 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ final def execute(): RDD[InternalRow] = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() + waitForSubqueries() doExecute() } } + // All the subqueries and their Future of results. + @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]() + + /** + * Collects all the subqueries and create a Future to take the first two rows of them. + */ + protected def prepareSubqueries(): Unit = { + val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e}) + allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e => + val futureResult = Future { + // We only need the first row, try to take two rows so we can throw an exception if there + // are more than one rows returned. + e.executedPlan.executeTake(2) + }(SparkPlan.subqueryExecutionContext) + queryResults += e -> futureResult + } + } + + /** + * Waits for all the subqueries to finish and updates the results. + */ + protected def waitForSubqueries(): Unit = { + // fill in the result of subqueries + queryResults.foreach { + case (e, futureResult) => + val rows = Await.result(futureResult, Duration.Inf) + if (rows.length > 1) { + sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") + } + if (rows.length == 1) { + assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column") + e.updateResult(rows(0).get(0, e.dataType)) + } else { + // There is no rows returned, the result should be null. + e.updateResult(null) + } + } + queryResults.clear() + } + /** * Prepare a SparkPlan for execution. It's idempotent. */ final def prepare(): Unit = { if (prepareCalled.compareAndSet(false, true)) { doPrepare() + prepareSubqueries() children.foreach(_.prepare()) } } @@ -231,6 +276,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } +object SparkPlan { + private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) +} + private[sql] trait LeafNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 8626f54eb413c..ad8564e96aa6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -73,9 +73,10 @@ trait CodegenSupport extends SparkPlan { /** * Returns Java source code to process the rows from upstream. */ - def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent ctx.freshNamePrefix = variablePrefix + waitForSubqueries() doProduce(ctx) } @@ -101,7 +102,7 @@ trait CodegenSupport extends SparkPlan { /** * Consume the columns generated from current SparkPlan, call it's parent. */ - def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { if (input != null) { assert(input.length == output.length) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4b82d5563460b..55bddd196ec46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -343,3 +343,18 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl protected override def doExecute(): RDD[InternalRow] = child.execute() } + +/** + * A plan as subquery. + * + * This is used to generate tree string for SparkScalarSubquery. + */ +case class Subquery(name: String, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala new file mode 100644 index 0000000000000..9c645c78e8732 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.{expressions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType + +/** + * A subquery that will return only one row and one column. + * + * This is the physical copy of ScalarSubquery to be used inside SparkPlan. + */ +case class ScalarSubquery( + @transient executedPlan: SparkPlan, + exprId: ExprId) + extends SubqueryExpression { + + override def query: LogicalPlan = throw new UnsupportedOperationException + override def withNewPlan(plan: LogicalPlan): SubqueryExpression = { + throw new UnsupportedOperationException + } + override def plan: SparkPlan = Subquery(simpleString, executedPlan) + + override def dataType: DataType = executedPlan.schema.fields.head.dataType + override def nullable: Boolean = true + override def toString: String = s"subquery#${exprId.id}" + + // the first column in first row from `query`. + private var result: Any = null + + def updateResult(v: Any): Unit = { + result = v + } + + override def eval(input: InternalRow): Any = result + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + Literal.create(result, dataType).genCode(ctx, ev) + } +} + +/** + * Convert the subquery from logical plan into executed plan. + */ +private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = { + plan.transformAllExpressions { + case subquery: expressions.ScalarSubquery => + val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next() + val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan) + ScalarSubquery(executedPlan, subquery.exprId) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b3e179755a19b..b4d6f4ecddc6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,18 +21,13 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.CatalystQl -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.parser.ParserConf -import org.apache.spark.sql.execution.{aggregate, SparkQl} +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} -import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ - class SQLQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala new file mode 100644 index 0000000000000..e851eb02f01b3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class SubquerySuite extends QueryTest with SharedSQLContext { + + test("simple uncorrelated scalar subquery") { + assertResult(Array(Row(1))) { + sql("select (select 1 as b) as b").collect() + } + + assertResult(Array(Row(1))) { + sql("with t2 as (select 1 as b, 2 as c) " + + "select a from (select 1 as a union all select 2 as a) t " + + "where a = (select max(b) from t2) ").collect() + } + + assertResult(Array(Row(3))) { + sql("select (select (select 1) + 1) + 1").collect() + } + + // more than one columns + val error = intercept[AnalysisException] { + sql("select (select 1, 2) as b").collect() + } + assert(error.message contains "Scalar subquery must return only one column, but got 2") + + // more than one rows + val error2 = intercept[RuntimeException] { + sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() + } + assert(error2.getMessage contains + "more than one row returned by a subquery used as an expression") + + // string type + assertResult(Array(Row("s"))) { + sql("select (select 's' as s) as b").collect() + } + + // zero rows + assertResult(Array(Row(null))) { + sql("select (select 's' as s limit 0) as b").collect() + } + } + + test("uncorrelated scalar subquery on testData") { + // initialize test Data + testData + + assertResult(Array(Row(5))) { + sql("select (select key from testData where key > 3 limit 1) + 1").collect() + } + + assertResult(Array(Row(-100))) { + sql("select -(select max(key) from testData)").collect() + } + + assertResult(Array(Row(null))) { + sql("select (select value from testData limit 0)").collect() + } + + assertResult(Array(Row("99"))) { + sql("select (select min(value) from testData" + + " where key = (select max(key) from testData) - 1)").collect() + } + } +} From d806ed34365aa27895547297fff4cc48ecbeacdf Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Sun, 21 Feb 2016 00:53:15 -0800 Subject: [PATCH 1947/2089] [SPARK-13416][GraphX] Add positive check for option 'numIter' in StronglyConnectedComponents JIRA: https://issues.apache.org/jira/browse/SPARK-13416 ## What changes were proposed in this pull request? The output of StronglyConnectedComponents with numIter no greater than 1 may make no sense. So I just add require check in it. ## How was the this patch tested? unit tests passed Author: Zheng RuiFeng Closes #11284 from zhengruifeng/scccheck. --- .../apache/spark/graphx/lib/StronglyConnectedComponents.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index 8dd958033b338..7063137d47c08 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -36,7 +36,7 @@ object StronglyConnectedComponents { * @return a graph with vertex attributes containing the smallest vertex id in each SCC */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexId, ED] = { - + require(numIter > 0, s"Number of iterations ${numIter} must be greater than 0.") // the graph we update with final SCC ids, and the graph we return at the end var sccGraph = graph.mapVertices { case (vid, _) => vid } // graph we are going to work with in our iterations From d9efe63ecdc60a9955f1924de0e8a00bcb6a559d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 21 Feb 2016 22:53:15 +0800 Subject: [PATCH 1948/2089] [SPARK-12799] Simplify various string output for expressions This PR introduces several major changes: 1. Replacing `Expression.prettyString` with `Expression.sql` The `prettyString` method is mostly an internal, developer faced facility for debugging purposes, and shouldn't be exposed to users. 1. Using SQL-like representation as column names for selected fields that are not named expression (back-ticks and double quotes should be removed) Before, we were using `prettyString` as column names when possible, and sometimes the result column names can be weird. Here are several examples: Expression | `prettyString` | `sql` | Note ------------------ | -------------- | ---------- | --------------- `a && b` | `a && b` | `a AND b` | `a.getField("f")` | `a[f]` | `a.f` | `a` is a struct 1. Adding trait `NonSQLExpression` extending from `Expression` for expressions that don't have a SQL representation (e.g. Scala UDF/UDAF and Java/Scala object expressions used for encoders) `NonSQLExpression.sql` may return an arbitrary user facing string representation of the expression. Author: Cheng Lian Closes #10757 from liancheng/spark-12799.simplify-expression-string-methods. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 +- python/pyspark/sql/column.py | 32 ++++----- python/pyspark/sql/context.py | 10 +-- python/pyspark/sql/functions.py | 30 ++++---- .../sql/catalyst/analysis/Analyzer.scala | 8 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 26 +++---- .../DistinctAggregationRewriter.scala | 6 +- .../sql/catalyst/analysis/unresolved.scala | 17 ++--- .../expressions/ExpectsInputTypes.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 71 ++++++++----------- .../sql/catalyst/expressions/ScalaUDF.scala | 4 +- .../expressions/aggregate/interfaces.scala | 5 +- .../sql/catalyst/expressions/arithmetic.scala | 8 ++- .../expressions/bitwiseExpressions.scala | 2 + .../expressions/codegen/CodegenFallback.scala | 5 +- .../expressions/complexTypeExtractors.scala | 22 +++--- .../expressions/conditionalExpressions.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 39 +++------- .../expressions/mathExpressions.scala | 6 +- .../expressions/namedExpressions.scala | 20 ++++-- .../expressions/nullExpressions.scala | 2 - .../sql/catalyst/expressions/objects.scala | 25 +++---- .../sql/catalyst/expressions/predicates.scala | 8 +-- .../expressions/stringExpressions.scala | 5 -- .../expressions/windowExpressions.scala | 6 +- .../plans/logical/basicOperators.scala | 2 +- .../spark/sql/catalyst/util/package.scala | 37 +++++++--- .../apache/spark/sql/types/DecimalType.scala | 2 + .../apache/spark/sql/types/StructType.scala | 4 +- .../analysis/AnalysisErrorSuite.scala | 26 +++---- .../sql/catalyst/analysis/AnalysisTest.scala | 15 +++- .../ExpressionTypeCheckingSuite.scala | 4 +- .../encoders/EncoderResolutionSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 17 ++--- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/GroupedData.scala | 3 +- .../sql/execution/WholeStageCodegen.scala | 3 +- .../spark/sql/execution/aggregate/udaf.scala | 4 +- .../parquet/CatalystSchemaConverter.scala | 2 +- .../sql/execution/python/PythonUDF.scala | 9 ++- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 4 +- .../apache/spark/sql/hive/SQLBuilder.scala | 22 ++++-- .../org/apache/spark/sql/hive/hiveUDFs.scala | 43 +++++++---- .../sql/hive/ExpressionSQLBuilderSuite.scala | 15 ++-- .../sql/hive/LogicalPlanToSQLSuite.scala | 64 +++++++++++++++++ .../sql/hive/execution/HivePlanTest.scala | 4 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 22 +++--- .../apache/spark/sql/hive/parquetSuites.scala | 5 +- 49 files changed, 402 insertions(+), 278 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7b5713720df87..cc118108f61cc 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1047,13 +1047,13 @@ test_that("column functions", { schema = c("a", "b", "c")) result <- collect(select(df, struct("a", "c"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) result <- collect(select(df, struct(df$a, df$b))) expected <- data.frame(row.names = 1:2) - expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 320451c52c706..3866a49c0b5f6 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -219,17 +219,17 @@ def getField(self, name): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ + +---+ + |r.b| + +---+ + | b| + +---+ >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ + +---+ + |r.a| + +---+ + | 1| + +---+ """ return self[name] @@ -346,12 +346,12 @@ def between(self, lowerBound, upperBound): expression is between the given columns. >>> df.select(df.name, df.age.between(2, 4)).show() - +-----+--------------------------+ - | name|((age >= 2) && (age <= 4))| - +-----+--------------------------+ - |Alice| true| - | Bob| false| - +-----+--------------------------+ + +-----+---------------------------+ + | name|((age >= 2) AND (age <= 4))| + +-----+---------------------------+ + |Alice| true| + | Bob| false| + +-----+---------------------------+ """ return (self >= lowerBound) & (self <= upperBound) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7f6fb410abe81..89bf1443a68a1 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -92,8 +92,8 @@ def __init__(self, sparkContext, sqlContext=None): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() - [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ - time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ + dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -210,17 +210,17 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(_c0=u'4')] + [Row(stringLengthString(test)=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(_c0=4)] + [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(_c0=4)] + [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._ssql_ctx.udf().registerPython(name, udf._judf) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5fc1cc2cae10a..fdae05d98ceb6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -223,22 +223,22 @@ def coalesce(*cols): +----+----+ >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() - +-------------+ - |coalesce(a,b)| - +-------------+ - | null| - | 1| - | 2| - +-------------+ + +--------------+ + |coalesce(a, b)| + +--------------+ + | null| + | 1| + | 2| + +--------------+ >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() - +----+----+---------------+ - | a| b|coalesce(a,0.0)| - +----+----+---------------+ - |null|null| 0.0| - | 1|null| 1.0| - |null| 2| 0.0| - +----+----+---------------+ + +----+----+----------------+ + | a| b|coalesce(a, 0.0)| + +----+----+----------------+ + |null|null| 0.0| + | 1|null| 1.0| + |null| 2| 0.0| + +----+----+----------------+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column)) @@ -1528,7 +1528,7 @@ def array_contains(col, value): >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(array_contains(df.data, "a")).collect() - [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba306f8b32b7d..e153f4dd2b0fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.types._ /** @@ -165,7 +166,8 @@ class Analyzer( case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))() + case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)() + case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -328,7 +330,7 @@ class Analyzer( throw new AnalysisException( s"Aggregate expression required for pivot, found '$aggregate'") } - val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString + val name = if (singleAgg) value.toString else value + "_" + aggregate.sql Alias(filteredAggregate, name)() } } @@ -1456,7 +1458,7 @@ object CleanupAliases extends Rule[LogicalPlan] { */ object ResolveUpCast extends Rule[LogicalPlan] { private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + throw new AnalysisException(s"Cannot up cast ${from.sql} from " + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fe053b9a0b27f..1e430c1fbbdf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -57,13 +57,13 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]") + a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + s"cannot resolve '${e.sql}' due to data type mismatch: $message") } case c: Cast if !c.resolved => @@ -106,12 +106,12 @@ trait CheckAnalysis { operator match { case f: Filter if f.condition.dataType != BooleanType => failAnalysis( - s"filter expression '${f.condition.prettyString}' " + + s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( - s"join condition '${condition.prettyString}' " + + s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") case j @ Join(_, _, _, Some(condition)) => @@ -119,10 +119,10 @@ trait CheckAnalysis { case p: Predicate => p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) case e if e.dataType.isInstanceOf[BinaryType] => - failAnalysis(s"binary type expression ${e.prettyString} cannot be used " + + failAnalysis(s"binary type expression ${e.sql} cannot be used " + "in join conditions") case e if e.dataType.isInstanceOf[MapType] => - failAnalysis(s"map type expression ${e.prettyString} cannot be used " + + failAnalysis(s"map type expression ${e.sql} cannot be used " + "in join conditions") case _ => // OK } @@ -144,13 +144,13 @@ trait CheckAnalysis { if (!child.deterministic) { failAnalysis( - s"nondeterministic expression ${expr.prettyString} should not " + + s"nondeterministic expression ${expr.sql} should not " + s"appear in the arguments of an aggregate function.") } } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + + s"expression '${e.sql}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() (or first_value) if you don't care " + "which value you get.") @@ -163,7 +163,7 @@ trait CheckAnalysis { // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( - s"expression ${expr.prettyString} cannot be used as a grouping expression " + + s"expression ${expr.sql} cannot be used as a grouping expression " + s"because its data type ${expr.dataType.simpleString} is not a orderable " + s"data type.") } @@ -172,7 +172,7 @@ trait CheckAnalysis { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in // a Project node. - failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + failAnalysis(s"nondeterministic expression ${expr.sql} should not " + s"appear in grouping expression.") } } @@ -217,7 +217,7 @@ trait CheckAnalysis { case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( s"""Only a single table generating function is allowed in a SELECT clause, found: - | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + | ${exprs.map(_.sql).mkString(",")}""".stripMargin) case j: Join if !j.duplicateResolved => val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) @@ -248,9 +248,9 @@ trait CheckAnalysis { failAnalysis( s"""nondeterministic expressions are only allowed in |Project, Filter, Aggregate or Window, found: - | ${o.expressions.map(_.prettyString).mkString(",")} + | ${o.expressions.map(_.sql).mkString(",")} |in operator ${operator.simpleString} - """.stripMargin) + """.stripMargin) case _ => // Analysis successful! } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 5dfce89bd68a6..b49885d469c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -130,7 +130,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP new AttributeReference("gid", IntegerType, false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() + case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -184,7 +184,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) - val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() + val operator = Alias(e.copy(aggregateFunction = af), e.sql)() // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression( @@ -269,5 +269,5 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.prettyString, e.dataType, true)() + e -> new AttributeReference(e.sql, e.dataType, true)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 79eebbf9b1ec4..01afa01ae95c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types.{DataType, StructType} /** @@ -67,6 +68,8 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def toString: String = s"'$name" + + override def sql: String = quoteIdentifier(name) } object UnresolvedAttribute { @@ -141,11 +144,8 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def prettyString: String = { - s"${name}(${children.map(_.prettyString).mkString(",")})" - } - - override def toString: String = s"'$name(${children.mkString(",")})" + override def prettyName: String = name + override def toString: String = s"'$name(${children.mkString(", ")})" } /** @@ -208,10 +208,9 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu Alias(extract, f.name)() } - case _ => { + case _ => throw new AnalysisException("Can only star expand struct data types. Attribute: `" + target.get + "`") - } } } else { val from = input.inputSet.map(_.name).mkString(", ") @@ -228,6 +227,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented * as follows: * MultiAlias(stack_function, Seq(a, b)) + * * @param child the computation being performed * @param names the names to be associated with each output of computing [[child]]. @@ -284,13 +284,14 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override lazy val resolved = false override def toString: String = s"$child[$extraction]" + override def sql: String = s"${child.sql}[${extraction.sql}]" } /** * Holds the expression that has yet to be aliased. * * @param child The computation that is needs to be resolved during analysis. - * @param aliasName The name if specified to be asoosicated with the result of computing [[child]] + * @param aliasName The name if specified to be associated with the result of computing [[child]] * */ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 04650d85dec03..c7be8e886ca78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -45,7 +45,7 @@ trait ExpectsInputTypes extends Expression { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type." + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c73b2f8f2a316..119496c7eee5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -96,7 +96,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. - val code = s"/* ${this.toCommentSafeString} */" + val code = s"/* ${toCommentSafeString(this.toString)} */" ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") @@ -105,7 +105,7 @@ abstract class Expression extends TreeNode[Expression] { ve.code = genCode(ctx, ve) if (ve.code != "") { // Add `this` in the comment. - ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim) } else { ve } @@ -201,17 +201,6 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyName: String = getClass.getSimpleName.toLowerCase - /** - * Returns a user-facing string representation of this expression, i.e. does not have developer - * centric debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name, a.dataType) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil @@ -219,24 +208,16 @@ abstract class Expression extends TreeNode[Expression] { override def simpleString: String = toString - override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")") /** - * Returns the string representation of this expression that is safe to be put in - * code comments of generated code. + * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]], + * this method may return an arbitrary user facing string. */ - protected def toCommentSafeString: String = this.toString - .replace("*/", "\\*\\/") - .replace("\\u", "\\\\u") - - /** - * Returns SQL representation of this expression. For expressions that don't have a SQL - * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`. - */ - @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") - def sql: String = throw new UnsupportedOperationException( - s"Cannot map expression $this to its SQL representation" - ) + def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } } @@ -254,6 +235,19 @@ trait Unevaluable extends Expression { } +/** + * Expressions that don't have SQL representation should extend this trait. Examples are + * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`. + */ +trait NonSQLExpression extends Expression { + override def sql: String = { + transform { + case a: Attribute => new PrettyAttribute(a) + }.toString + } +} + + /** * An expression that is nondeterministic. */ @@ -373,8 +367,6 @@ abstract class UnaryExpression extends Expression { """ } } - - override def sql: String = s"($prettyName(${child.sql}))" } @@ -477,8 +469,6 @@ abstract class BinaryExpression extends Expression { """ } } - - override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -499,6 +489,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { def symbol: String + def sqlOperator: String = symbol + override def toString: String = s"($left $symbol $right)" override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) @@ -506,17 +498,17 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { - TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," + + TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," + s" not ${left.dataType.simpleString}") } else { TypeCheckResult.TypeCheckSuccess } } - override def sql: String = s"(${left.sql} $symbol ${right.sql})" + override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})" } @@ -623,9 +615,4 @@ abstract class TernaryExpression extends Expression { """ } } - - override def sql: String = { - val childrenSQL = children.map(_.sql).mkString(", ") - s"$prettyName($childrenSQL)" - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 681694746b84c..22184f1ddfbb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -39,11 +39,11 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputTypes: Seq[DataType] = Nil) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with NonSQLExpression { override def nullable: Boolean = true - override def toString: String = s"UDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(", ")})" // scalastyle:off line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index f88a57a254b36..ff3064ac66d5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ @@ -92,8 +91,6 @@ private[sql] case class AggregateExpression( AttributeSet(childReferences) } - override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" override def sql: String = aggregateFunction.sql(isDistinct) @@ -168,7 +165,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu } def sql(isDistinct: Boolean): String = { - val distinct = if (isDistinct) "DISTINCT " else " " + val distinct = if (isDistinct) "DISTINCT " else "" s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1cacd3f76aa36..5af234609ddd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -319,7 +319,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } -case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { +case class MaxOf(left: Expression, right: Expression) + extends BinaryArithmetic with NonSQLExpression { + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. override def inputType: AbstractDataType = TypeCollection.Ordered @@ -373,7 +375,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "max" } -case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { +case class MinOf(left: Expression, right: Expression) + extends BinaryArithmetic with NonSQLExpression { + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. override def inputType: AbstractDataType = TypeCollection.Ordered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a97bd9edcef84..4c90b3f7d33ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -123,4 +123,6 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } protected override def nullSafeEval(input: Any): Any = not(input) + + override def sql: String = s"~${child.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index f58a2daf902f7..1365ee4b55634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.util.toCommentSafeString /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -37,7 +38,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") if (nullable) { s""" - /* expression: ${this.toCommentSafeString} */ + /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; @@ -48,7 +49,7 @@ trait CodegenFallback extends Expression { } else { ev.isNull = "false" s""" - /* expression: ${this.toCommentSafeString} */ + /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6b24fae9f3f1c..44cdc8d8812e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -93,6 +93,8 @@ object ExtractValue { } } +trait ExtractValue extends Expression + /** * Returns the value of fields in the Struct `child`. * @@ -102,13 +104,15 @@ object ExtractValue { * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression { + extends UnaryExpression with ExtractValue { private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" + override def sql: String = + child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) @@ -130,12 +134,11 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } - - override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** - * Returns the array of value of fields in the Array of Struct `child`. + * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array + * elements, and returns them as a new array. * * No need to do type checking since it is handled by [[ExtractValue]]. */ @@ -144,10 +147,11 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression { + containsNull: Boolean) extends UnaryExpression with ExtractValue { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" + override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -204,12 +208,13 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with ExtractValue { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) override def toString: String = s"$child[$ordinal]" + override def sql: String = s"${child.sql}[${ordinal.sql}]" override def left: Expression = child override def right: Expression = ordinal @@ -250,7 +255,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with ExtractValue { private def keyType = child.dataType.asInstanceOf[MapType].keyType @@ -258,6 +263,7 @@ case class GetMapValue(child: Expression, key: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" override def left: Expression = child override def right: Expression = key diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1eff2c4dd0086..200c6a05df7e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -35,7 +35,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ca0892eb42158..d7d768babc115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -238,37 +238,20 @@ case class Literal protected (value: Any, dataType: DataType) } override def sql: String = (value, dataType) match { - case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => - "NULL" - - case _ if value == null => - s"CAST(NULL AS ${dataType.sql})" - + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" + case _ if value == null => s"CAST(NULL AS ${dataType.sql})" case (v: UTF8String, StringType) => // Escapes all backslashes and double quotes. "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" - - case (v: Byte, ByteType) => - s"CAST($v AS ${ByteType.simpleString.toUpperCase})" - - case (v: Short, ShortType) => - s"CAST($v AS ${ShortType.simpleString.toUpperCase})" - - case (v: Long, LongType) => - s"CAST($v AS ${LongType.simpleString.toUpperCase})" - - case (v: Float, FloatType) => - s"CAST($v AS ${FloatType.simpleString.toUpperCase})" - - case (v: Decimal, DecimalType.Fixed(precision, scale)) => - s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" - - case (v: Int, DateType) => - s"DATE '${DateTimeUtils.toJavaDate(v)}'" - - case (v: Long, TimestampType) => - s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" - + case (v: Byte, ByteType) => v + "Y" + case (v: Short, ShortType) => v + "S" + case (v: Long, LongType) => v + "L" + // Float type doesn't have a suffix + case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})" + case (v: Double, DoubleType) => v + "D" + case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})" + case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'" + case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" case _ => value.toString } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 8b9a60f97ce6e..bc2df0fb4adf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -42,6 +42,7 @@ abstract class LeafMathExpression(c: Double, name: String) override def foldable: Boolean = true override def nullable: Boolean = false override def toString: String = s"$name()" + override def prettyName: String = name override def eval(input: InternalRow): Any = c } @@ -59,6 +60,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" + override def prettyName: String = name protected override def nullSafeEval(input: Any): Any = { f(input.asInstanceOf[Double]) @@ -70,8 +72,6 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } - - override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) @@ -113,6 +113,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def toString: String = s"$name($left, $right)" + override def prettyName: String = name + override def dataType: DataType = DoubleType protected override def nullSafeEval(input1: Any, input2: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 207b8a0a88556..1af543764776e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types._ object NamedExpression { @@ -183,9 +184,9 @@ case class Alias(child: Expression, name: String)( override def sql: String = { val qualifiersString = - if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".") val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name" - s"${child.sql} AS $qualifiersString`$aliasName`" + s"${child.sql} AS $qualifiersString${quoteIdentifier(aliasName)}" } } @@ -300,9 +301,9 @@ case class AttributeReference( override def sql: String = { val qualifiersString = - if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".") val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name" - s"$qualifiersString`$attrRefName`" + s"$qualifiersString${quoteIdentifier(attrRefName)}" } } @@ -310,10 +311,19 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String, dataType: DataType = NullType) +case class PrettyAttribute( + name: String, + dataType: DataType = NullType) extends Attribute with Unevaluable { + def this(attribute: Attribute) = this(attribute.name, attribute match { + case a: AttributeReference => a.dataType + case a: PrettyAttribute => a.dataType + case _ => NullType + }) + override def toString: String = name + override def sql: String = toString override def withNullability(newNullability: Boolean): Attribute = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 667d3513d32b9..e22026d584654 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,8 +83,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index fef6825b2db5e..737346dc793fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -46,7 +46,7 @@ case class StaticInvoke( dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression { + propagateNull: Boolean = true) extends Expression with NonSQLExpression { val objectName = staticObject.getName.stripSuffix("$") @@ -108,7 +108,7 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil) extends Expression { + arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { override def nullable: Boolean = true override def children: Seq[Expression] = arguments.+:(targetObject) @@ -204,7 +204,7 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[Literal]) extends Expression { + outerPointer: Option[Literal]) extends Expression with NonSQLExpression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -268,7 +268,7 @@ case class NewInstance( */ case class UnwrapOption( dataType: DataType, - child: Expression) extends UnaryExpression with ExpectsInputTypes { + child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def nullable: Boolean = true @@ -298,7 +298,7 @@ case class UnwrapOption( * @param optType The type of this option. */ case class WrapOption(child: Expression, optType: DataType) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -328,7 +328,7 @@ case class WrapOption(child: Expression, optType: DataType) * manually, but will instead be passed into the provided lambda function. */ case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression - with Unevaluable { + with Unevaluable with NonSQLExpression { override def nullable: Boolean = true @@ -368,7 +368,7 @@ object MapObjects { case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, - inputData: Expression) extends Expression { + inputData: Expression) extends Expression with NonSQLExpression { private def itemAccessorMethod(dataType: DataType): String => String = dataType match { case NullType => @@ -483,7 +483,7 @@ case class MapObjects private( * * @param children A list of expression to use as content of the external row. */ -case class CreateExternalRow(children: Seq[Expression]) extends Expression { +case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression { override def dataType: DataType = ObjectType(classOf[Row]) override def nullable: Boolean = false @@ -516,7 +516,8 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) + extends UnaryExpression with NonSQLExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") @@ -558,7 +559,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression { + extends UnaryExpression with NonSQLExpression { override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Code to initialize the serializer. @@ -596,7 +597,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B * Initialize a Java Bean instance by setting its field values via setters. */ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) - extends Expression { + extends Expression with NonSQLExpression { override def nullable: Boolean = beanInstance.nullable override def children: Seq[Expression] = beanInstance +: setters.values.toSeq @@ -638,7 +639,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * non-null `s`, `s.i` can't be null. */ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) - extends UnaryExpression { + extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c290aa8825adc..20818bfb1a514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -249,6 +249,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with override def symbol: String = "&&" + override def sqlOperator: String = "AND" + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == false) { @@ -289,8 +291,6 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } - - override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -300,6 +300,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P override def symbol: String = "||" + override def sqlOperator: String = "OR" + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == true) { @@ -340,8 +342,6 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } - - override def sql: String = s"(${left.sql} OR ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b965212f27777..4be065b30a21f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,7 +23,6 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -62,8 +61,6 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -156,8 +153,6 @@ case class ConcatWs(children: Seq[Expression]) """ } } - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index afe122f6a0e85..9e6bd0ee460f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -556,8 +556,8 @@ abstract class RankLike extends AggregateWindowFunction { override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) /** Store the values of the window 'order' expressions. */ - protected val orderAttrs = children.map{ expr => - AttributeReference(expr.prettyString, expr.dataType)() + protected val orderAttrs = children.map { expr => + AttributeReference(expr.sql, expr.dataType)() } /** Predicate that detects if the order attributes have changed. */ @@ -636,7 +636,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { /** * The PercentRank function computes the percentage ranking of a value in a group of values. The - * result the rank of the minus one divided by the total number of rows in the partitiion minus one: + * result the rank of the minus one divided by the total number of rows in the partition minus one: * (r - 1) / (n - 1). If a partition only contains one row, the function will return 0. * * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7d155ac183d58..a19ba38ba061c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -575,7 +575,7 @@ case class Pivot( override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 7a0d0de6328a5..43f707f444b07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst import java.io._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{NumericType, StringType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils package object util { @@ -130,19 +133,31 @@ package object util { ret } + // Replaces attributes, string literals, complex type extractors with their pretty form so that + // generated column names don't contain back-ticks or double-quotes. + def usePrettyExpression(e: Expression): Expression = e transform { + case a: Attribute => new PrettyAttribute(a) + case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType) + case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t) + case e: GetStructField => + val name = e.name.getOrElse(e.childSchema(e.ordinal).name) + PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType) + case e: GetArrayStructFields => + PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType) + } + + def quoteIdentifier(name: String): String = { + // Escapes back-ticks within the identifier name with double-back-ticks, and then quote the + // identifier with back-ticks. + "`" + name.replace("`", "``") + "`" + } + /** - * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. */ - def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { - case xs if xs.isEmpty => - Option(Seq.empty[T]) - - case xs => - for { - head <- xs.head - tail <- sequenceOption(xs.tail) - } yield head +: tail - } + def toCommentSafeString(str: String): String = + str.replace("*/", "\\*\\/").replace("\\u", "\\\\u") /* FIX ME implicit class debugLogging(a: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 5dd661ee6b339..71ea5b8841e1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -64,6 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" + override def sql: String = typeName.toUpperCase + /** * Returns whether this DecimalType is wider than `other`. If yes, it means `other` * can be casted into `this` safely without losing any precision or range. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e797d83cb05be..5ff5435d5ab9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DataTypeParser, LegacyTypeStringParser} /** * :: DeveloperApi :: @@ -280,7 +280,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } override def sql: String = { - val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}") s"STRUCT<${fieldTypes.mkString(", ")}>" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 69f78a097e15c..de9a56dc9c064 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -127,22 +127,24 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" :: - "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE))" :: "argument 1" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" :: - "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" :: + "argument 2" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "requires int type" :: "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" :: + "argument 1" :: "argument 2" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "invalid window function", @@ -207,7 +209,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "sorting by attributes are not from grouping expressions", testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc), - "cannot resolve" :: "'b'" :: "given input columns" :: "[a, c, a3]" :: Nil) + "cannot resolve" :: "'`b`'" :: "given input columns" :: "[a, c, a3]" :: Nil) errorTest( "non-boolean filters", @@ -222,7 +224,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "missing group by", testRelation2.groupBy('a)('b), - "'b'" :: "group by" :: Nil + "'`b`'" :: "group by" :: Nil ) errorTest( @@ -270,7 +272,7 @@ class AnalysisErrorSuite extends AnalysisTest { "SPARK-9955: correct error message for aggregate", // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), - "cannot resolve 'bad_column'" :: Nil) + "cannot resolve '`bad_column`'" :: Nil) test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) @@ -311,7 +313,7 @@ class AnalysisErrorSuite extends AnalysisTest { case true => assertAnalysisSuccess(plan, true) case false => - assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil) + assertAnalysisError(plan, "expression `a` cannot be used as a grouping expression" :: Nil) } } @@ -372,7 +374,7 @@ class AnalysisErrorSuite extends AnalysisTest { Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil) + assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil) val plan2 = Join( @@ -386,6 +388,6 @@ class AnalysisErrorSuite extends AnalysisTest { Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil) + assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index af214b7af0624..2756c463cf00e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -71,8 +71,17 @@ trait AnalysisTest extends PlanTest { val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), - s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " + - s"actually we get ${e.getMessage}") + + if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + fail( + s"""Exception message should contain the following substrings: + | + | ${expectedErrors.mkString("\n ")} + | + |Actual exception message: + | + | ${e.getMessage} + """.stripMargin) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 59549e3998e7e..92c8496fde716 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -41,7 +41,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(expr) } assert(e.getMessage.contains( - s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + s"cannot resolve '${expr.sql}' due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } @@ -52,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in '${expr.prettyString}'") + s"differing types in '${expr.sql}'") } test("check types for unary arithmetic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 8b02b63c6cf3a..3ad0dae767be3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -142,7 +142,7 @@ class EncoderResolutionSuite extends PlanTest { }.message assert(msg2 == s""" - |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate |The type path of the target object is: |- field (class: "scala.Long", name: "b") |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 2ab091e40a073..6c7929c362270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DataTypeParser} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ @@ -104,10 +104,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => { + case _ if name.endsWith(".*") => val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) UnresolvedStar(Some(parts)) - } case _ => UnresolvedAttribute.quotedString(name) }) @@ -123,6 +122,8 @@ class Column(protected[sql] val expr: Expression) extends Logging { // make it a NamedExpression. case u: UnresolvedAttribute => UnresolvedAlias(u) + case u: UnresolvedExtractValue => UnresolvedAlias(u) + case expr: NamedExpression => expr // Leave an unaliased generator with an empty list of names since the analyzer will generate @@ -131,7 +132,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { case jt: JsonTuple => MultiAlias(jt, Nil) - case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString)) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql)) // If we have a top level Cast, there is a chance to give it a better alias, if there is a // NamedExpression under this Cast. @@ -139,13 +140,13 @@ class Column(protected[sql] val expr: Expression) extends Logging { case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) } match { case ne: NamedExpression => ne - case other => Alias(expr, expr.prettyString)() + case other => Alias(expr, usePrettyExpression(expr).sql)() } - case expr: Expression => Alias(expr, expr.prettyString)() + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } - override def toString: String = expr.prettyString + override def toString: String = usePrettyExpression(expr).sql override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -987,7 +988,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { if (extended) { println(expr) } else { - println(expr.prettyString) + println(expr.sql) } // scalastyle:on println } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 76c09a285dc40..9674450118f5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator @@ -1359,7 +1360,8 @@ class DataFrame private[sql]( "min" -> ((child: Expression) => Min(child).toAggregateExpression()), "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + val outputCols = + (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c74ef2c03541e..66ec0e7338fc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.types.NumericType /** @@ -74,7 +75,7 @@ class GroupedData protected[sql]( private[this] def alias(expr: Expression): NamedExpression = expr match { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ad8564e96aa6e..990eeb22b6aa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} import org.apache.spark.sql.execution.metric.LongSQLMetricValue @@ -252,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } /** Codegened pipeline for: - * ${plan.treeString.trim} + * ${toCommentSafeString(plan.treeString.trim)} */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 812e696338362..ab178443dd123 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -324,7 +324,7 @@ private[sql] case class ScalaUDAF( udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with Logging { + extends ImperativeAggregate with NonSQLExpression with Logging { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 1c0d53fc77854..54dda0c3915b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -544,7 +544,7 @@ private[parquet] object CatalystSchemaConverter { !name.matches(".*[ ,;{}()\n\t=].*"), s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) + """.stripMargin.split("\n").mkString(" ").trim) } def checkFieldNames(schema: StructType): StructType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 0e53a0c4737bb..9aff0be716d4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} import org.apache.spark.sql.types.DataType /** @@ -36,9 +36,12 @@ case class PythonUDF( broadcastVars: java.util.List[Broadcast[PythonBroadcast]], accumulator: Accumulator[java.util.List[Array[Byte]]], dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable with Logging { + children: Seq[Expression]) + extends Expression with Unevaluable with NonSQLExpression with Logging { - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" + override def toString: String = s"PythonUDF#$name(${children.mkString(", ")})" override def nullable: Boolean = true + + override def sql: String = s"$name(${children.mkString(", ")})" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f9ba60770022d..498f007081918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -525,7 +525,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { ds.as[ClassData2] } - assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) + assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage) } test("runtime nullability check") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index bd87449f920bb..41a9404b00795 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -455,7 +455,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -479,7 +479,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index bf5edb4759fbd..1dda39d44ea50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -24,10 +24,11 @@ import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation /** @@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation * supported by this builder (yet). */ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + require(logicalPlan.resolved, "SQLBuilder only supports resloved logical query plans") + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) try { + canonicalizedPlan.transformAllExpressions { + case e: NonSQLExpression => + throw new UnsupportedOperationException( + s"Expression $e doesn't have a SQL representation" + ) + case e => e + } + val generatedSQL = toSQL(canonicalizedPlan) logDebug( s"""Built SQL query string successfully from given logical plan: @@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi p.child match { // Persisted data source relation case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => - s"`$database`.`$table`" + s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" // Parentheses is not used for persisted data source relations // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => @@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: MetastoreRelation => build( - s"`${p.databaseName}`.`${p.tableName}`", - p.alias.map(a => s" AS `$a`").getOrElse("") + s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}", + p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("") ) case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) @@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi * The segments are trimmed so only a single space appears in the separation. * For example, `build("a", " b ", " c")` becomes "a b c". */ - private def build(segments: String*): String = segments.map(_.trim).mkString(" ") + private def build(segments: String*): String = + segments.map(_.trim).filter(_.nonEmpty).mkString(" ") private def projectToSQL(plan: Project, isDistinct: Boolean): String = { build( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d5ed838ca4b1a..bcafa045e00b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ @@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry( // catch the exception and throw AnalysisException instead. try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF( + val udf = HiveGenericUDF( name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + udf.dataType // Force it to check input data types. + udf } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) + udf.dataType // Force it to check input data types. + udf } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) + udf.dataType // Force it to check input data types. + udf } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) + udaf.dataType // Force it to check input data types. + udaf } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction( + val udaf = HiveUDAFFunction( name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + udaf } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. @@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - override val dataType = javaClassToDataType(method.getReturnType) + override lazy val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( @@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + override def prettyName: String = name + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } @@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF( } @transient - private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType) }.toArray[DeferredObject] - override val dataType: DataType = inspectorToDataType(returnInspector) + override lazy val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. @@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF( var i = 0 while (i < children.length) { val idx = i - deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( + deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set( () => { children(idx).eval(input) }) i += 1 } - unwrap(function.evaluate(deferedObjects), returnInspector) + unwrap(function.evaluate(deferredObjects), returnInspector) } + override def prettyName: String = name + override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - - override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" + override def prettyName: String = name } /** @@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false - override val dataType: DataType = inspectorToDataType(returnInspector) + override lazy val dataType: DataType = inspectorToDataType(returnInspector) + + override def prettyName: String = name override def sql(isDistinct: Boolean): String = { val distinct = if (isDistinct) "DISTINCT " else " " diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala index 3fb6543b1a1d9..e4b4d1861a149 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { checkSQL(Literal("foo"), "\"foo\"") checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") - checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") - checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(1: Byte), "1Y") + checkSQL(Literal(2: Short), "2S") checkSQL(Literal(4: Int), "4") - checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(8: Long), "8L") checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") - checkSQL(Literal(2.5D), "2.5") + checkSQL(Literal(2.5D), "2.5D") checkSQL( Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals } + test("attributes") { + checkSQL('a.int, "`a`") + checkSQL(Symbol("foo bar").int, "`foo bar`") + // Keyword + checkSQL('int.int, "`int`") + } + test("binary comparisons") { checkSQL('a.int === 'b.int, "(`a` = `b`)") checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index dc8ac7e47ffec..5255b150aa6be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { sql("DROP TABLE IF EXISTS t0") sql("DROP TABLE IF EXISTS t1") sql("DROP TABLE IF EXISTS t2") + sqlContext.range(10).write.saveAsTable("t0") sqlContext @@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } } } + + test("plans with non-SQL expressions") { + sqlContext.udf.register("foo", (_: Int) * 2) + intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL) + } + + test("named expression in column names shouldn't be quoted") { + def checkColumnNames(query: String, expectedColNames: String*): Unit = { + checkHiveQl(query) + assert(sql(query).columns === expectedColNames) + } + + // Attributes + checkColumnNames( + """SELECT * FROM ( + | SELECT 1 AS a, 2 AS b, 3 AS `we``ird` + |) s + """.stripMargin, + "a", "b", "we`ird" + ) + + checkColumnNames( + """SELECT x.a, y.a, x.b, y.b + |FROM (SELECT 1 AS a, 2 AS b) x + |INNER JOIN (SELECT 1 AS a, 2 AS b) y + |ON x.a = y.a + """.stripMargin, + "a", "a", "b", "b" + ) + + // String literal + checkColumnNames( + "SELECT 'foo', '\"bar\\''", + "foo", "\"bar\'" + ) + + // Numeric literals (should have CAST or suffixes in column names) + checkColumnNames( + "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D", + "1", "2", "3", "4", "5.1", "6.1" + ) + + // Aliases + checkColumnNames( + "SELECT 1 AS a", + "a" + ) + + // Complex type extractors + checkColumnNames( + """SELECT + | a.f1, b[0].f1, b.f1, c["foo"], d[0] + |FROM ( + | SELECT + | NAMED_STRUCT("f1", 1, "f2", "foo") AS a, + | ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b, + | MAP("foo", 1) AS c, + | ARRAY(1) AS d + |) s + """.stripMargin, + "f1", "b[0].f1", "f1", "c[foo]", "d[0]" + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index cd055f9eca37e..d8d3448adde0b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton { test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") - val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan + val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 27ea3e8041651..fe446774effbb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA val df = sql( """ |SELECT - | CAST(null as TINYINT), - | CAST(null as SMALLINT), - | CAST(null as INT), - | CAST(null as BIGINT), - | CAST(null as FLOAT), - | CAST(null as DOUBLE), - | CAST(null as DECIMAL(7,2)), - | CAST(null as TIMESTAMP), - | CAST(null as DATE), - | CAST(null as STRING), - | CAST(null as VARCHAR(10)) + | CAST(null as TINYINT) as c0, + | CAST(null as SMALLINT) as c1, + | CAST(null as INT) as c2, + | CAST(null as BIGINT) as c3, + | CAST(null as FLOAT) as c4, + | CAST(null as DOUBLE) as c5, + | CAST(null as DECIMAL(7,2)) as c6, + | CAST(null as TIMESTAMP) as c7, + | CAST(null as DATE) as c8, + | CAST(null as STRING) as c9, + | CAST(null as VARCHAR(10)) as c10 |FROM orc_temp_table limit 1 """.stripMargin) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c997453803b09..a6ca7d0386b22 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest { sql( s"""CREATE TABLE array_of_struct |STORED AS PARQUET LOCATION '$path' - |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + |AS SELECT + | '1st' AS a, + | '2nd' AS b, + | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c """.stripMargin) checkAnswer( From 1a340da8d7590d831b040c74f5a6eb560e14d585 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Sun, 21 Feb 2016 16:27:56 +0000 Subject: [PATCH 1949/2089] [SPARK-13248][STREAMING] Remove deprecated Streaming APIs. Remove deprecated Streaming APIs and adjust sample applications. Author: Luciano Resende Closes #11139 from lresende/streaming-deprecated-apis. --- .../JavaRecoverableNetworkWordCount.java | 18 ++-- project/MimaExcludes.scala | 13 +++ .../spark/streaming/StreamingContext.scala | 38 -------- .../streaming/api/java/JavaDStreamLike.scala | 64 ------------- .../api/java/JavaStreamingContext.scala | 96 ------------------- .../spark/streaming/dstream/DStream.scala | 22 ----- .../spark/streaming/scheduler/BatchInfo.scala | 3 - .../spark/streaming/DStreamClosureSuite.scala | 7 -- 8 files changed, 22 insertions(+), 239 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index bc8cbcdef7272..f9929fc86dc7f 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -26,17 +26,14 @@ import java.util.regex.Pattern; import scala.Tuple2; + import com.google.common.io.Files; import org.apache.spark.Accumulator; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.api.java.function.VoidFunction2; +import org.apache.spark.api.java.function.*; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; @@ -44,7 +41,6 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; /** * Use this singleton to get or register a Broadcast variable. @@ -204,13 +200,17 @@ public static void main(String[] args) { final int port = Integer.parseInt(args[1]); final String checkpointDirectory = args[2]; final String outputPath = args[3]; - JavaStreamingContextFactory factory = new JavaStreamingContextFactory() { + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + Function0 createContextFunc = new Function0() { @Override - public JavaStreamingContext create() { + public JavaStreamingContext call() { return createContext(ip, port, checkpointDirectory, outputPath); } }; - JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, factory); + + JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, createContextFunc); ssc.start(); ssc.awaitTermination(); } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 65375a3ea787b..8f31a81ed01e7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -219,6 +219,19 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") + ) ++ Seq( + // SPARK-12348 Remove deprecated Streaming APIs. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate") ) ++ Seq( // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index a1b25c9f7da7f..25e61578a1860 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -269,20 +269,6 @@ class StreamingContext private[streaming] ( RDDOperationScope.withScope(sc, name, allowNesting = false, ignoreParent = false)(body) } - /** - * Create an input stream with any arbitrary user implemented receiver. - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * @param receiver Custom implementation of Receiver - * - * @deprecated As of 1.0.0 replaced by `receiverStream`. - */ - @deprecated("Use receiverStream", "1.0.0") - def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { - withNamedScope("network stream") { - receiverStream(receiver) - } - } - /** * Create an input stream with any arbitrary user implemented receiver. * Find more details at http://spark.apache.org/docs/latest/streaming-custom-receivers.html @@ -621,18 +607,6 @@ class StreamingContext private[streaming] ( waiter.waitForStopOrError() } - /** - * Wait for the execution to stop. Any exceptions that occurs during the execution - * will be thrown in this thread. - * @param timeout time to wait in milliseconds - * - * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. - */ - @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") - def awaitTermination(timeout: Long) { - waiter.waitForStopOrError(timeout) - } - /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. @@ -777,18 +751,6 @@ object StreamingContext extends Logging { } } - /** - * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. - * This is kept here only for backward compatibility. - */ - @deprecated("Replaced by implicit functions in the DStream companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) - : PairDStreamFunctions[K, V] = { - DStream.toPairDStreamFunctions(stream)(kt, vt, ord) - } - /** * :: Experimental :: * diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 9931a46d33a9f..65aab2fac162a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -211,26 +211,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = dstream.reduce(f) - /** - * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a sliding window over this DStream. - * @param reduceFunc associative and commutative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @deprecated As this API is not Java compatible. - */ - @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0") - def reduceByWindow( - reduceFunc: (T, T) => T, - windowDuration: Duration, - slideDuration: Duration - ): DStream[T] = { - dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) - } - /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. @@ -281,50 +261,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.slice(fromTime, toTime).map(wrapRDD).asJava } - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of release 0.9.0, replaced by foreachRDD - */ - @deprecated("Use foreachRDD", "0.9.0") - def foreach(foreachFunc: JFunction[R, Void]) { - foreachRDD(foreachFunc) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of release 0.9.0, replaced by foreachRDD - */ - @deprecated("Use foreachRDD", "0.9.0") - def foreach(foreachFunc: JFunction2[R, Time, Void]) { - foreachRDD(foreachFunc) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) - */ - @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") - def foreachRDD(foreachFunc: JFunction[R, Void]) { - dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) - */ - @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") - def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { - dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) - } - /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 7a25ce54b6ff0..f8f1336693b85 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -154,12 +154,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) - /** - * @deprecated As of 0.9.0, replaced by `sparkContext` - */ - @deprecated("use sparkContext", "0.9.0") - val sc: JavaSparkContext = sparkContext - /** * Create an input stream from network source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited @@ -568,17 +562,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.awaitTermination() } - /** - * Wait for the execution to stop. Any exceptions that occurs during the execution - * will be thrown in this thread. - * @param timeout time to wait in milliseconds - * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. - */ - @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") - def awaitTermination(timeout: Long): Unit = { - ssc.awaitTermination(timeout) - } - /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. @@ -623,78 +606,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { */ object JavaStreamingContext { - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program - * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. - */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") - def getOrCreate( - checkpointPath: String, - factory: JavaStreamingContextFactory - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - factory.create.ssc - }) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. - */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") - def getOrCreate( - checkpointPath: String, - hadoopConf: Configuration, - factory: JavaStreamingContextFactory - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - factory.create.ssc - }, hadoopConf) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. - */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") - def getOrCreate( - checkpointPath: String, - hadoopConf: Configuration, - factory: JavaStreamingContextFactory, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - factory.create.ssc - }, hadoopConf, createOnError) - new JavaStreamingContext(ssc) - } - /** * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be @@ -767,10 +678,3 @@ object JavaStreamingContext { */ def jarOfClass(cls: Class[_]): Array[String] = SparkContext.jarOfClass(cls).toArray } - -/** - * Factory interface for creating a new JavaStreamingContext - */ -trait JavaStreamingContextFactory { - def create(): JavaStreamingContext -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 70e1d8abded4c..102a0308185ba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -618,28 +618,6 @@ abstract class DStream[T: ClassTag] ( this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of 0.9.0, replaced by `foreachRDD`. - */ - @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { - this.foreachRDD(foreachFunc) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of 0.9.0, replaced by `foreachRDD`. - */ - @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { - this.foreachRDD(foreachFunc) - } - /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 436eb0a566141..5b2b959f8138d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -41,9 +41,6 @@ case class BatchInfo( outputOperationInfos: Map[Int, OutputOperationInfo] ) { - @deprecated("Use streamIdToInputInfo instead", "1.5.0") - def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) - /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index e897de3cba6d2..1fc34f569f9f4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -56,7 +56,6 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { testFilter(dstream) testMapPartitions(dstream) testReduce(dstream) - testForeach(dstream) testForeachRDD(dstream) testTransform(dstream) testTransformWith(dstream) @@ -106,12 +105,6 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private def testReduce(ds: DStream[Int]): Unit = expectCorrectException { ds.reduce { case (_, _) => return; 1 } } - private def testForeach(ds: DStream[Int]): Unit = { - val foreachF1 = (rdd: RDD[Int], t: Time) => return - val foreachF2 = (rdd: RDD[Int]) => return - expectCorrectException { ds.foreach(foreachF1) } - expectCorrectException { ds.foreach(foreachF2) } - } private def testForeachRDD(ds: DStream[Int]): Unit = { val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return val foreachRDDF2 = (rdd: RDD[Int]) => return From 0947f0989b136841fa0601295fc09b36f16cb933 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 21 Feb 2016 11:31:46 -0800 Subject: [PATCH 1950/2089] [SPARK-13420][SQL] Rename Subquery logical plan to SubqueryAlias ## What changes were proposed in this pull request? This patch renames logical.Subquery to logical.SubqueryAlias, which is a more appropriate name for this operator (versus subqueries as expressions). ## How was the this patch tested? Unit tests. Author: Reynold Xin Closes #11288 from rxin/SPARK-13420. --- .../spark/sql/catalyst/CatalystQl.scala | 4 ++-- .../sql/catalyst/analysis/Analyzer.scala | 12 ++++++------ .../spark/sql/catalyst/analysis/Catalog.scala | 10 +++++----- .../spark/sql/catalyst/dsl/package.scala | 7 ++++--- .../sql/catalyst/expressions/subquery.scala | 4 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 4 ++-- .../plans/logical/basicOperators.scala | 4 ++-- .../spark/sql/catalyst/CatalystQlSuite.scala | 4 ++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 ++-- .../sql/catalyst/analysis/AnalysisTest.scala | 4 ++-- .../BooleanSimplificationSuite.scala | 2 +- .../optimizer/CollapseProjectSuite.scala | 4 ++-- .../optimizer/ConstantFoldingSuite.scala | 5 ++--- ...la => EliminateSubqueryAliasesSuite.scala} | 14 +++++++------- .../optimizer/FilterPushdownSuite.scala | 19 ++++++++++--------- .../catalyst/optimizer/JoinOrderSuite.scala | 6 +++--- .../optimizer/LimitPushdownSuite.scala | 4 ++-- .../catalyst/optimizer/OptimizeInSuite.scala | 5 ++--- .../optimizer/OuterJoinEliminationSuite.scala | 4 ++-- .../optimizer/SetOperationSuite.scala | 4 ++-- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../sql/execution/datasources/rules.scala | 4 ++-- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 10 +++++----- .../apache/spark/sql/hive/SQLBuilder.scala | 8 ++++---- .../spark/sql/hive/execution/commands.scala | 4 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 4 ++-- 30 files changed, 83 insertions(+), 83 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{EliminateSubQueriesSuite.scala => EliminateSubqueryAliasesSuite.scala} (83%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index b16025a17e2e8..069c665a39680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -243,7 +243,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C queryArgs match { case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => val cteRelations = ctes.map { node => - val relation = nodeToRelation(node).asInstanceOf[Subquery] + val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias] relation.alias -> relation } (Some(from.head), inserts, Some(cteRelations.toMap)) @@ -454,7 +454,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C protected def nodeToRelation(node: ASTNode): LogicalPlan = { node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query)) + SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => nodeToGenerate( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e153f4dd2b0fa..562711a1b990d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -118,7 +118,7 @@ class Analyzer( // see https://github.com/apache/spark/pull/4929#discussion_r27186638 for more info case u : UnresolvedRelation => val substituted = cteRelations.get(u.tableIdentifier.table).map { relation => - val withAlias = u.alias.map(Subquery(_, relation)) + val withAlias = u.alias.map(SubqueryAlias(_, relation)) withAlias.getOrElse(relation) } substituted.getOrElse(u) @@ -357,7 +357,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => - i.copy(table = EliminateSubQueries(getTable(u))) + i.copy(table = EliminateSubqueryAliases(getTable(u))) case u: UnresolvedRelation => try { getTable(u) @@ -688,7 +688,7 @@ class Analyzer( resolved } else { plan match { - case u: UnaryNode if !u.isInstanceOf[Subquery] => + case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => resolveExpressionRecursively(resolved, u.child) case other => resolved } @@ -1372,12 +1372,12 @@ class Analyzer( } /** - * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * Removes [[SubqueryAlias]] operators from the plan. Subqueries are only required to provide * scoping information for attributes and can be removed once analysis is complete. */ -object EliminateSubQueries extends Rule[LogicalPlan] { +object EliminateSubqueryAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Subquery(_, child) => child + case SubqueryAlias(_, child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index f2f9ec59417ef..67edab55dba8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, EmptyConf, TableIdentifier} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} /** * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception @@ -110,12 +110,12 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { if (table == null) { throw new AnalysisException("Table not found: " + tableName) } - val tableWithQualifiers = Subquery(tableName, table) + val tableWithQualifiers = SubqueryAlias(tableName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. alias - .map(a => Subquery(a, tableWithQualifiers)) + .map(a => SubqueryAlias(a, tableWithQualifiers)) .getOrElse(tableWithQualifiers) } @@ -158,11 +158,11 @@ trait OverrideCatalog extends Catalog { getOverriddenTable(tableIdent) match { case Some(table) => val tableName = getTableName(tableIdent) - val tableWithQualifiers = Subquery(tableName, table) + val tableWithQualifiers = SubqueryAlias(tableName, table) // If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes // are properly qualified with this alias. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) case None => super.lookupRelation(tableIdent, alias) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5ac1984043d87..1a2ec7ed931ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -268,7 +268,7 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) @@ -290,7 +290,8 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) + def analyze: LogicalPlan = + EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index a8f5e1f63d4c7..d0c44b032866b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.types.DataType /** @@ -56,7 +56,7 @@ case class ScalarSubquery( exprId: ExprId = NamedExpression.newExprId) extends SubqueryExpression with Unevaluable { - override def plan: LogicalPlan = Subquery(toString, query) + override def plan: LogicalPlan = SubqueryAlias(toString, query) override lazy val resolved: Boolean = query.resolved diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 202f6b5e07fce..1554382840a48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubqueryAliases} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -40,7 +40,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // However, because we also use the analyzer to canonicalized queries (for view definition), // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, - EliminateSubQueries, + EliminateSubqueryAliases, ComputeCurrentTime) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 86bd33f526209..5e7d144ae4107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.Subquery +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 35e0f5d5639a9..35df2429dbe46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -128,8 +128,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * can do better should override this function. */ def sameResult(plan: LogicalPlan): Boolean = { - val cleanLeft = EliminateSubQueries(this) - val cleanRight = EliminateSubQueries(plan) + val cleanLeft = EliminateSubqueryAliases(this) + val cleanRight = EliminateSubqueryAliases(plan) cleanLeft.getClass == cleanRight.getClass && cleanLeft.children.size == cleanRight.children.size && { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a19ba38ba061c..70ecbce82927f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -350,7 +350,7 @@ case class InsertIntoTable( * key is the alias of the CTE definition, * value is the CTE definition. */ -case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends UnaryNode { +case class With(child: LogicalPlan, cteRelations: Map[String, SubqueryAlias]) extends UnaryNode { override def output: Seq[Attribute] = child.output } @@ -623,7 +623,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } } -case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { +case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index ed7121831ac29..42147f516f964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -51,7 +51,7 @@ class CatalystQlSuite extends PlanTest { val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") val expected = Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - Subquery("u_1", + SubqueryAlias("u_1", Distinct( Union( Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, @@ -66,7 +66,7 @@ class CatalystQlSuite extends PlanTest { val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") val expected = Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - Subquery("u_1", + SubqueryAlias("u_1", Union( Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, UnresolvedRelation(TableIdentifier("t0"), None)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 92db02944c6ac..aa1d2b08613dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -344,8 +344,8 @@ class AnalysisSuite extends AnalysisTest { val query = Project(Seq($"x.key", $"y.key"), Join( - Project(Seq($"x.key"), Subquery("x", input)), - Project(Seq($"y.key"), Subquery("y", input)), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + Project(Seq($"y.key"), SubqueryAlias("y", input)), Inner, None)) assertAnalysisSuccess(query) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 2756c463cf00e..e0a95ba8bbd5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,10 +35,10 @@ trait AnalysisTest extends PlanTest { caseInsensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil + override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } -> new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil + override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 6932f185b9d62..3e52441519ae2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -31,7 +31,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index f5fd5ca6beb15..833f054659bee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Rand @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) :: Batch("CollapseProject", Once, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 48f9ac77b74c3..641c89873dcc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue} -// For implicit conversions +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -32,7 +31,7 @@ class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, OptimizeIn, ConstantFolding, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala similarity index 83% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala index e0d430052fb55..9b6d68aee803a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubQueriesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -28,10 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class EliminateSubQueriesSuite extends PlanTest with PredicateHelper { +class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("EliminateSubQueries", Once, EliminateSubQueries) :: Nil + val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil } private def assertEquivalent(e1: Expression, e2: Expression): Unit = { @@ -46,13 +46,13 @@ class EliminateSubQueriesSuite extends PlanTest with PredicateHelper { test("eliminate top level subquery") { val input = LocalRelation('a.int, 'b.int) - val query = Subquery("a", input) + val query = SubqueryAlias("a", input) comparePlans(afterOptimization(query), input) } test("eliminate mid-tree subquery") { val input = LocalRelation('a.int, 'b.int) - val query = Filter(TrueLiteral, Subquery("a", input)) + val query = Filter(TrueLiteral, SubqueryAlias("a", input)) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) @@ -60,10 +60,10 @@ class EliminateSubQueriesSuite extends PlanTest with PredicateHelper { test("eliminate multiple subqueries") { val input = LocalRelation('a.int, 'b.int) - val query = Filter(TrueLiteral, Subquery("c", Subquery("b", Subquery("a", input)))) + val query = Filter(TrueLiteral, + SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input)))) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b49ca928b6292..7805723ec86e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -32,7 +32,7 @@ class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Filter Pushdown", Once, SamplePushDown, CombineFilters, @@ -481,7 +481,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) + comparePlans(analysis.EliminateSubqueryAliases(originalQuery.analyze), optimized) } test("joins: conjunctive predicates") { @@ -500,7 +500,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } test("joins: conjunctive predicates #2") { @@ -519,7 +519,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } test("joins: conjunctive predicates #3") { @@ -543,7 +543,7 @@ class FilterPushdownSuite extends PlanTest { condition = Some("z.a".attr === "x.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) @@ -619,7 +619,7 @@ class FilterPushdownSuite extends PlanTest { x.select('a) .sortBy(SortOrder('a, Ascending)).analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) // push down invalid val originalQuery1 = { @@ -634,7 +634,7 @@ class FilterPushdownSuite extends PlanTest { .sortBy(SortOrder('a, Ascending)) .select('b).analyze - comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) } test("push project and filter down into sample") { @@ -642,7 +642,8 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = Sample(0.0, 0.6, false, 11L, x).select('a) - val originalQueryAnalyzed = EliminateSubQueries(analysis.SimpleAnalyzer.execute(originalQuery)) + val originalQueryAnalyzed = + EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery)) val optimized = Optimize.execute(originalQueryAnalyzed) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala index 858a0d8fde3ea..a5b487bcc822f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Expression @@ -33,7 +33,7 @@ class JoinOrderSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, @@ -90,6 +90,6 @@ class JoinOrderSuite extends PlanTest { .join(y, condition = Some("y.d".attr === "z.a".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index fc1e99458179c..dcbc79365c3aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Add @@ -30,7 +30,7 @@ class LimitPushdownSuite extends PlanTest { private object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Limit pushdown", FixedPoint(100), LimitPushDown, CombineLimits, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 3e384e473e5f7..17255ecfe870e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} -// For implicit conversions +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -34,7 +33,7 @@ class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), NullPropagation, ConstantFolding, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index a1dc836a5fd1c..5e6e54dc741f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans._ @@ -28,7 +28,7 @@ class OuterJoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, OuterJoinElimination, PushPredicateThroughJoin) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index b8ea32b4dfe01..50f3b512d94ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -28,7 +28,7 @@ class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Union Pushdown", Once, CombineUnions, SetOperationPushDown, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 9674450118f5b..e3412f7a2ea4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -671,7 +671,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def as(alias: String): DataFrame = withPlan { - Subquery(alias, logicalPlan) + SubqueryAlias(alias, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 378763268acc6..ea7e7255abde6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -142,7 +142,7 @@ class Dataset[T] private[sql]( * the same name after two Datasets have been joined. * @since 1.6.0 */ - def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) + def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _)) /** * Converts this strongly typed collection of data to generic Dataframe. In contrast to the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 9358c9c37bf67..2e41e88392600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -40,7 +40,7 @@ private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[Logica provider = u.tableIdentifier.database.get, options = Map("path" -> u.tableIdentifier.table)) val plan = LogicalRelation(resolved.relation) - u.alias.map(a => Subquery(u.alias.get, plan)).getOrElse(plan) + u.alias.map(a => SubqueryAlias(u.alias.get, plan)).getOrElse(plan) } catch { case e: ClassNotFoundException => u case e: Exception => @@ -171,7 +171,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // the query. If so, we will throw an AnalysisException to let users know it is not allowed. if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match { + EliminateSubqueryAliases(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _, _) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ac174aa6bfa6d..88fda16ca0ef8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -353,7 +353,7 @@ class HiveContext private[hive]( */ def analyze(tableName: String) { val tableIdent = sqlParser.parseTableIdentifier(tableName) - val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent)) + val relation = EliminateSubqueryAliases(catalog.lookupRelation(tableIdent)) relation match { case relation: MetastoreRelation => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 61d0d6759ff72..c222b006a0f32 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -416,17 +416,17 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte if (table.properties.get("spark.sql.sources.provider").isDefined) { val dataSourceTable = cachedDataSourceTables(qualifiedTableName) - val tableWithQualifiers = Subquery(qualifiedTableName.name, dataSourceTable) + val tableWithQualifiers = SubqueryAlias(qualifiedTableName.name, dataSourceTable) // Then, if alias is specified, wrap the table with a Subquery using the alias. // Otherwise, wrap the table with a Subquery using the table name. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } else if (table.tableType == VirtualView) { val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." - case None => Subquery(table.name, hive.parseSql(viewText)) - case Some(aliasText) => Subquery(aliasText, hive.parseSql(viewText)) + case None => SubqueryAlias(table.name, hive.parseSql(viewText)) + case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText)) } } else { MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) @@ -564,7 +564,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte case relation: MetastoreRelation if hive.convertMetastoreParquet && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) - Subquery(relation.alias.getOrElse(relation.tableName), parquetRelation) + SubqueryAlias(relation.alias.getOrElse(relation.tableName), parquetRelation) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1dda39d44ea50..32f17f41c950c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -102,14 +102,14 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val childrenSql = p.children.map(toSQL(_)) childrenSql.mkString(" UNION ALL ") - case p: Subquery => + case p: SubqueryAlias => p.child match { // Persisted data source relation case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" // Parentheses is not used for persisted data source relations // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 - case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => + case SubqueryAlias(_, _: LogicalRelation | _: MetastoreRelation) => build(toSQL(p.child), "AS", p.alias) case _ => build("(" + toSQL(p.child) + ")", "AS", p.alias) @@ -215,7 +215,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi wrapChildWithSubquery(plan) case plan @ Project(_, - _: Subquery + _: SubqueryAlias | _: Filter | _: Join | _: MetastoreRelation @@ -237,7 +237,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi a.withQualifiers(alias :: Nil) }.asInstanceOf[NamedExpression]) - Project(aliasedProjectList, Subquery(alias, child)) + Project(aliasedProjectList, SubqueryAlias(alias, child)) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index e703ac016496d..246b52a3b01d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ @@ -220,7 +220,7 @@ case class CreateMetastoreDataSourceAsSelect( provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { + EliminateSubqueryAliases(sqlContext.catalog.lookupRelation(tableIdent)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => if (l.relation != createdRelation.relation) { val errorDescription = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 75fa09db69ea0..6624494e08ea8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.functions._ @@ -264,7 +264,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS without serde") { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubQueries(catalog.lookupRelation(TableIdentifier(tableName))) + val relation = EliminateSubqueryAliases(catalog.lookupRelation(TableIdentifier(tableName))) relation match { case LogicalRelation(r: ParquetRelation, _, _) => if (!isDataSourceParquet) { From af441ddbd13f48c09b458c451d7bba3965a878d1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 21 Feb 2016 12:27:02 -0800 Subject: [PATCH 1951/2089] [SPARK-13306][SQL] Addendum to uncorrelated scalar subquery ## What changes were proposed in this pull request? This pull request fixes some minor issues (documentation, test flakiness, test organization) with #11190, which was merged earlier tonight. ## How was the this patch tested? unit tests. Author: Reynold Xin Closes #11285 from rxin/subquery. --- .../sql/catalyst/analysis/Analyzer.scala | 3 +- .../sql/catalyst/expressions/literals.scala | 11 ---- .../sql/catalyst/expressions/subquery.scala | 5 +- .../spark/sql/execution/SparkPlan.scala | 49 ++++++++------- .../apache/spark/sql/execution/subquery.scala | 2 +- .../org/apache/spark/sql/SubquerySuite.scala | 61 +++++++++---------- 6 files changed, 61 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 562711a1b990d..23e4709bbd882 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -123,13 +123,12 @@ class Analyzer( } substituted.getOrElse(u) case other => - // This can't be done in ResolveSubquery because that does not know the CTE. + // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. other transformExpressions { case e: SubqueryExpression => e.withNewPlan(substituteCTE(e.query, cteRelations)) } } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d7d768babc115..37bfe98d3ab24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -255,14 +255,3 @@ case class Literal protected (value: Any, dataType: DataType) case _ => value.toString } } - -// TODO: Specialize -case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) - extends LeafExpression with CodegenFallback { - - def update(expression: Expression, input: InternalRow): Unit = { - value = expression.eval(input) - } - - override def eval(input: InternalRow): Any = value -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index d0c44b032866b..ddf214a4b30ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -45,9 +45,8 @@ abstract class SubqueryExpression extends LeafExpression { } /** - * A subquery that will return only one row and one column. - * - * This will be converted into [[execution.ScalarSubquery]] during physical planning. + * A subquery that will return only one row and one column. This will be converted into a physical + * scalar subquery during planning. * * Note: `exprId` is used to have unique name in explain string output. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 872ccde883060..477a9460d7dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null) + protected[spark] final val sqlContext = SQLContext.getActive().orNull protected def sparkContext = sqlContext.sparkContext @@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } - // All the subqueries and their Future of results. - @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]() + /** + * List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node. + * This list is populated by [[prepareSubqueries]], which is called in [[prepare]]. + */ + @transient + private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])] /** - * Collects all the subqueries and create a Future to take the first two rows of them. + * Finds scalar subquery expressions in this plan node and starts evaluating them. + * The list of subqueries are added to [[subqueryResults]]. */ protected def prepareSubqueries(): Unit = { val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e}) allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e => val futureResult = Future { - // We only need the first row, try to take two rows so we can throw an exception if there - // are more than one rows returned. + // Each subquery should return only one row (and one column). We take two here and throws + // an exception later if the number of rows is greater than one. e.executedPlan.executeTake(2) }(SparkPlan.subqueryExecutionContext) - queryResults += e -> futureResult + subqueryResults += e -> futureResult } } /** - * Waits for all the subqueries to finish and updates the results. + * Blocks the thread until all subqueries finish evaluation and update the results. */ protected def waitForSubqueries(): Unit = { // fill in the result of subqueries - queryResults.foreach { - case (e, futureResult) => - val rows = Await.result(futureResult, Duration.Inf) - if (rows.length > 1) { - sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") - } - if (rows.length == 1) { - assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column") - e.updateResult(rows(0).get(0, e.dataType)) - } else { - // There is no rows returned, the result should be null. - e.updateResult(null) - } + subqueryResults.foreach { case (e, futureResult) => + val rows = Await.result(futureResult, Duration.Inf) + if (rows.length > 1) { + sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") + } + if (rows.length == 1) { + assert(rows(0).numFields == 1, + s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis") + e.updateResult(rows(0).get(0, e.dataType)) + } else { + // If there is no rows returned, the result should be null. + e.updateResult(null) + } } - queryResults.clear() + subqueryResults.clear() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 9c645c78e8732..e6d7480b0422c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -62,7 +62,7 @@ case class ScalarSubquery( /** * Convert the subquery from logical plan into executed plan. */ -private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index e851eb02f01b3..21b19fe7df8b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -20,65 +20,64 @@ package org.apache.spark.sql import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("simple uncorrelated scalar subquery") { assertResult(Array(Row(1))) { sql("select (select 1 as b) as b").collect() } - assertResult(Array(Row(1))) { - sql("with t2 as (select 1 as b, 2 as c) " + - "select a from (select 1 as a union all select 2 as a) t " + - "where a = (select max(b) from t2) ").collect() - } - assertResult(Array(Row(3))) { sql("select (select (select 1) + 1) + 1").collect() } - // more than one columns - val error = intercept[AnalysisException] { - sql("select (select 1, 2) as b").collect() - } - assert(error.message contains "Scalar subquery must return only one column, but got 2") - - // more than one rows - val error2 = intercept[RuntimeException] { - sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() - } - assert(error2.getMessage contains - "more than one row returned by a subquery used as an expression") - // string type assertResult(Array(Row("s"))) { sql("select (select 's' as s) as b").collect() } + } - // zero rows + test("uncorrelated scalar subquery in CTE") { + assertResult(Array(Row(1))) { + sql("with t2 as (select 1 as b, 2 as c) " + + "select a from (select 1 as a union all select 2 as a) t " + + "where a = (select max(b) from t2) ").collect() + } + } + + test("uncorrelated scalar subquery should return null if there is 0 rows") { assertResult(Array(Row(null))) { sql("select (select 's' as s limit 0) as b").collect() } } - test("uncorrelated scalar subquery on testData") { - // initialize test Data - testData + test("runtime error when the number of rows is greater than 1") { + val error2 = intercept[RuntimeException] { + sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() + } + assert(error2.getMessage.contains( + "more than one row returned by a subquery used as an expression")) + } + + test("uncorrelated scalar subquery on a DataFrame generated query") { + val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value") + df.registerTempTable("subqueryData") - assertResult(Array(Row(5))) { - sql("select (select key from testData where key > 3 limit 1) + 1").collect() + assertResult(Array(Row(4))) { + sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect() } - assertResult(Array(Row(-100))) { - sql("select -(select max(key) from testData)").collect() + assertResult(Array(Row(-3))) { + sql("select -(select max(key) from subqueryData)").collect() } assertResult(Array(Row(null))) { - sql("select (select value from testData limit 0)").collect() + sql("select (select value from subqueryData limit 0)").collect() } - assertResult(Array(Row("99"))) { - sql("select (select min(value) from testData" + - " where key = (select max(key) from testData) - 1)").collect() + assertResult(Array(Row("two"))) { + sql("select (select min(value) from subqueryData" + + " where key = (select max(key) from subqueryData) - 1)").collect() } } } From b6a873d6d4682796f55dbafadd0b5cad881f96ea Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 21 Feb 2016 12:32:31 -0800 Subject: [PATCH 1952/2089] [SPARK-13136][SQL] Create a dedicated Broadcast exchange operator Quite a few Spark SQL join operators broadcast one side of the join to all nodes. The are a few problems with this: - This conflates broadcasting (a data exchange) with joining. Data exchanges should be managed by a different operator. - All these nodes implement their own (duplicate) broadcasting logic. - Re-use of indices is quite hard. This PR defines both a ```BroadcastDistribution``` and ```BroadcastPartitioning```, these contain a `BroadcastMode`. The `BroadcastMode` defines the way in which we transform the Array of `InternalRow`'s into an index. We currently support the following `BroadcastMode`'s: - IdentityBroadcastMode: This broadcasts the rows in their original form. - HashSetBroadcastMode: This applies a projection to the input rows, deduplicates these rows and broadcasts the resulting `Set`. - HashedRelationBroadcastMode: This transforms the input rows into a `HashedRelation`, and broadcasts this index. To match this distribution we implement a ```BroadcastExchange``` operator which will perform the broadcast for us, and have ```EnsureRequirements``` plan this operator. The old Exchange operator has been renamed into ShuffleExchange in order to clearly separate between Shuffled and Broadcasted exchanges. Finally the classes in Exchange.scala have been moved to a dedicated package. cc rxin davies Author: Herman van Hovell Closes #11083 from hvanhovell/SPARK-13136. --- .../plans/physical/broadcastMode.scala | 35 +++ .../plans/physical/partitioning.scala | 30 +- .../org/apache/spark/sql/SQLContext.scala | 12 +- .../spark/sql/execution/SparkPlan.scala | 34 ++- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../sql/execution/WholeStageCodegen.scala | 5 + .../exchange/BroadcastExchange.scala | 89 ++++++ .../exchange/EnsureRequirements.scala | 261 ++++++++++++++++++ .../{ => exchange}/ExchangeCoordinator.scala | 46 +-- .../ShuffleExchange.scala} | 255 +---------------- .../execution/joins/BroadcastHashJoin.scala | 75 +---- .../joins/BroadcastLeftSemiJoinHash.scala | 25 +- .../joins/BroadcastNestedLoopJoin.scala | 25 +- .../sql/execution/joins/HashSemiJoin.scala | 51 ++-- .../sql/execution/joins/HashedRelation.scala | 22 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 23 +- .../apache/spark/sql/execution/limit.scala | 7 +- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../execution/ExchangeCoordinatorSuite.scala | 21 +- .../spark/sql/execution/ExchangeSuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 21 +- .../execution/joins/BroadcastJoinSuite.scala | 5 +- .../sql/execution/joins/InnerJoinSuite.scala | 11 +- .../sql/execution/joins/OuterJoinSuite.scala | 3 +- .../sql/execution/joins/SemiJoinSuite.scala | 3 +- .../spark/sql/sources/BucketedReadSuite.scala | 13 +- 27 files changed, 658 insertions(+), 433 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/{ => exchange}/ExchangeCoordinator.scala (85%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{Exchange.scala => exchange/ShuffleExchange.scala} (50%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala new file mode 100644 index 0000000000000..c646dcfa11811 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.physical + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are + * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index). + */ +trait BroadcastMode { + def transform(rows: Array[InternalRow]): Any +} + +/** + * IdentityBroadcastMode requires that rows are broadcasted in their original form. + */ +case object IdentityBroadcastMode extends BroadcastMode { + override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d6e10c412ca1c..45e2841ec9dd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -75,6 +76,12 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Represents data where tuples are broadcasted to every node. It is quite common that the + * entire set of tuples is transformed into different data structure. + */ +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution + /** * Describes how an operator's output is split across partitions. The `compatibleWith`, * `guarantees`, and `satisfies` methods describe relationships between child partitionings, @@ -213,7 +220,10 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = true + override def satisfies(required: Distribution): Boolean = required match { + case _: BroadcastDistribution => false + case _ => true + } override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 @@ -351,3 +361,21 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) partitionings.map(_.toString).mkString("(", " or ", ")") } } + +/** + * Represents a partitioning where rows are collected, transformed and broadcasted to each + * node in the cluster. + */ +case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { + override val numPartitions: Int = 1 + + override def satisfies(required: Distribution): Boolean = required match { + case BroadcastDistribution(m) if m == mode => true + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case BroadcastPartitioning(m) if m == mode => true + case _ => false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 932df36b859b7..a2f386850c1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ @@ -59,7 +60,6 @@ import org.apache.spark.util.Utils * @groupname config Configuration * @groupname dataframes Custom DataFrame Creation * @groupname Ungrouped Support functions for language integrated queries - * * @since 1.0.0 */ class SQLContext private[sql]( @@ -313,10 +313,10 @@ class SQLContext private[sql]( } /** - * Returns true if the [[Queryable]] is currently cached in-memory. - * @group cachemgmt - * @since 1.3.0 - */ + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ private[sql] def isCached(qName: Queryable): Boolean = { cacheManager.lookupCachedData(qName).nonEmpty } @@ -364,6 +364,7 @@ class SQLContext private[sql]( /** * Converts $"col name" into an [[Column]]. + * * @since 1.3.0 */ // This must live here to preserve binary compatibility with Spark < 1.5. @@ -728,7 +729,6 @@ class SQLContext private[sql]( * cached/persisted before, it's also unpersisted. * * @param tableName the name of the table to be unregistered. - * * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 477a9460d7dd8..3be4cce045fea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -24,6 +24,7 @@ import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ import org.apache.spark.Logging +import org.apache.spark.broadcast import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -108,15 +109,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) /** - * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute - * after adding query plan information to created RDDs for visualization. - * Concrete implementations of SparkPlan should override doExecute instead. + * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after + * preparations. Concrete implementations of SparkPlan should override doExecute. */ - final def execute(): RDD[InternalRow] = { + final def execute(): RDD[InternalRow] = executeQuery { + doExecute() + } + + /** + * Returns the result of this query as a broadcast variable by delegating to doBroadcast after + * preparations. Concrete implementations of SparkPlan should override doBroadcast. + */ + final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery { + doExecuteBroadcast() + } + + /** + * Execute a query after preparing the query and adding query plan information to created RDDs + * for visualization. + */ + private final def executeQuery[T](query: => T): T = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() waitForSubqueries() - doExecute() + query } } @@ -192,6 +208,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected def doExecute(): RDD[InternalRow] + /** + * Overridden by concrete implementations of SparkPlan. + * Produces the result of the query as a broadcast variable. + */ + protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast") + } + /** * Runs this query returning the result as an array. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 382654afacb89..7347156398674 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{execution, Strategy} +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -25,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} @@ -328,7 +330,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil } else { execution.Coalesce(numPartitions, planLater(child)) :: Nil } @@ -367,7 +369,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r @ logical.Range(start, end, step, numSlices, output) => execution.Range(start, step, numSlices, r.numElements, output) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => - execution.Exchange(HashPartitioning( + exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ python.EvaluatePython(udf, child, _) => python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 990eeb22b6aa8..d79b547137c0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer +import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -172,6 +173,10 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { child.execute() } + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.doExecuteBroadcast() + } + override def supportCodegen: Boolean = false override def upstream(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala new file mode 100644 index 0000000000000..40cad4b1a7645 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryNode} +import org.apache.spark.util.ThreadUtils + +/** + * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed + * SparkPlan. + */ +case class BroadcastExchange( + mode: BroadcastMode, + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + + @transient + private val timeout: Duration = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + @transient + private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + Future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. + SQLExecution.withExecutionId(sparkContext, executionId) { + // Note that we use .executeCollect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = child.executeCollect() + + // Construct and broadcast the relation. + sparkContext.broadcast(mode.transform(input)) + } + }(BroadcastExchange.executionContext) + } + + override protected def doPrepare(): Unit = { + // Materialize the future. + relationFuture + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastExchange does not support the execute() code path.") + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + val result = Await.result(relationFuture, timeout) + result.asInstanceOf[broadcast.Broadcast[T]] + } +} + +object BroadcastExchange { + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala new file mode 100644 index 0000000000000..709a4246365dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ + +/** + * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] + * of input data meets the + * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for + * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the + * input partition ordering requirements are met. + */ +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { + private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions + + private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + + private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + + private def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } + + /** + * Given a required distribution, returns a partitioning that satisfies that distribution. + */ + private def createPartitioning( + requiredDistribution: Distribution, + numPartitions: Int): Partitioning = { + requiredDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) + case dist => sys.error(s"Do not know how to satisfy distribution $dist") + } + } + + /** + * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled + * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]]. + */ + private def withExchangeCoordinator( + children: Seq[SparkPlan], + requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { + val supportsCoordinator = + if (children.exists(_.isInstanceOf[ShuffleExchange])) { + // Right now, ExchangeCoordinator only support HashPartitionings. + children.forall { + case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withCoordinator = + if (adaptiveExecutionEnabled && supportsCoordinator) { + val coordinator = + new ExchangeCoordinator( + children.length, + targetPostShuffleInputSize, + minNumPostShufflePartitions) + children.zip(requiredChildDistributions).map { + case (e: ShuffleExchange, _) => + // This child is an Exchange, we need to add the coordinator. + e.copy(coordinator = Some(coordinator)) + case (child, distribution) => + // If this child is not an Exchange, we need to add an Exchange for now. + // Ideally, we can try to avoid this Exchange. However, when we reach here, + // there are at least two children operators (because if there is a single child + // and we can avoid Exchange, supportsCoordinator will be false and we + // will not reach here.). Although we can make two children have the same number of + // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. + // For example, let's say we have the following plan + // Join + // / \ + // Agg Exchange + // / \ + // Exchange t2 + // / + // t1 + // In this case, because a post-shuffle partition can include multiple pre-shuffle + // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes + // after shuffle. So, even we can use the child Exchange operator of the Join to + // have a number of post-shuffle partitions that matches the number of partitions of + // Agg, we cannot say these two children are partitioned in the same way. + // Here is another case + // Join + // / \ + // Agg1 Agg2 + // / \ + // Exchange1 Exchange2 + // / \ + // t1 t2 + // In this case, two Aggs shuffle data with the same column of the join condition. + // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same + // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 + // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle + // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its + // pre-shuffle partitions by using another partitionStartIndices [0, 4]. + // So, Agg1 and Agg2 are actually not co-partitioned. + // + // It will be great to introduce a new Partitioning to represent the post-shuffle + // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. + val targetPartitioning = + createPartitioning(distribution, defaultNumPreShufflePartitions) + assert(targetPartitioning.isInstanceOf[HashPartitioning]) + ShuffleExchange(targetPartitioning, child, Some(coordinator)) + } + } else { + // If we do not need ExchangeCoordinator, the original children are returned. + children + } + + withCoordinator + } + + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + var children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) + + // Ensure that the operator's children satisfy their output distribution requirements: + children = children.zip(requiredChildDistributions).map { + case (child, distribution) if child.outputPartitioning.satisfies(distribution) => + child + case (child, BroadcastDistribution(mode)) => + BroadcastExchange(mode, child) + case (child, distribution) => + ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + } + + // If the operator has multiple children and specifies child output distributions (e.g. join), + // then the children's output partitionings must be compatible: + def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { + case UnspecifiedDistribution => false + case BroadcastDistribution(_) => false + case _ => true + } + if (children.length > 1 + && requiredChildDistributions.exists(requireCompatiblePartitioning) + && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + + // First check if the existing partitions of the children all match. This means they are + // partitioned by the same partitioning into the same number of partitions. In that case, + // don't try to make them match `defaultPartitions`, just use the existing partitioning. + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max + val useExistingPartitioning = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + + children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. + children + } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. + val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions + } + + children.zip(requiredChildDistributions).map { + case (child, distribution) => + val targetPartitioning = createPartitioning(distribution, numPartitions) + if (child.outputPartitioning.guarantees(targetPartitioning)) { + child + } else { + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning. + case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c) + case _ => ShuffleExchange(targetPartitioning, child) + } + } + } + } + } + + // Now, we need to add ExchangeCoordinator if necessary. + // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. + // However, with the way that we plan the query, we do not have a place where we have a + // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator + // at here for now. + // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, + // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. + children = withExchangeCoordinator(children, requiredChildDistributions) + + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + if (requiredOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. + if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + Sort(requiredOrdering, global = false, child = child) + } else { + child + } + } else { + child + } + } + + operator.withNewChildren(children) + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator @ ShuffleExchange(partitioning, child, _) => + child.children match { + case ShuffleExchange(childPartitioning, baseChild, _)::Nil => + if (childPartitioning.guarantees(partitioning)) child else operator + case _ => operator + } + case operator: SparkPlan => ensureDistributionAndOrdering(operator) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala similarity index 85% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 07015e5a5aaef..6f3bb0ad2bac1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.exchange import java.util.{HashMap => JHashMap, Map => JMap} import javax.annotation.concurrent.GuardedBy @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, MapOutputStatistics, ShuffleDependency, SimpleFutureAction} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} /** * A coordinator used to determines how we shuffle data between stages generated by Spark SQL. @@ -33,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow * * A coordinator is constructed with three parameters, `numExchanges`, * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[Exchange]]s that will be registered to - * this coordinator. So, when we start to do any actual work, we have a way to make sure that - * we have got expected number of [[Exchange]]s. + * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered + * to this coordinator. So, when we start to do any actual work, we have a way to make sure that + * we have got expected number of [[ShuffleExchange]]s. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through @@ -45,26 +46,27 @@ import org.apache.spark.sql.catalyst.InternalRow * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator, + * - Before the execution of a [[SparkPlan]], for an [[ShuffleExchange]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will - * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will - * immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. + * - Once we start to execute a physical plan, an [[ShuffleExchange]] registered to this + * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle + * [[ShuffledRowRDD]]. + * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] + * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the the size - * statistics of pre-shuffle partitions, this coordinator will determine the number of + * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the the + * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can - * lookup the corresponding [[RDD]]. + * [[ShuffleExchange]]s. So, when an [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator + * can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. * To determine the number of post-shuffle partitions, we have a target input size for a * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[Exchange]]s, we will do a pass of those statistics and + * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until * the size of a post-shuffle partition is equal or greater than the target size. * For example, we have two stages with the following pre-shuffle partition size statistics: @@ -83,11 +85,11 @@ private[sql] class ExchangeCoordinator( extends Logging { // The registered Exchange operators. - private[this] val exchanges = ArrayBuffer[Exchange]() + private[this] val exchanges = ArrayBuffer[ShuffleExchange]() // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] = - new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] = + new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. // This variable will only be updated by doEstimationIfNecessary, which is protected by @@ -95,11 +97,11 @@ private[sql] class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be - * called in the `doPrepare` method of an [[Exchange]] operator. + * Registers an [[ShuffleExchange]] operator to this coordinator. This method is only allowed to + * be called in the `doPrepare` method of an [[ShuffleExchange]] operator. */ @GuardedBy("this") - def registerExchange(exchange: Exchange): Unit = synchronized { + def registerExchange(exchange: ShuffleExchange): Unit = synchronized { exchanges += exchange } @@ -199,7 +201,7 @@ private[sql] class ExchangeCoordinator( // Make sure we have the expected number of registered Exchange operators. assert(exchanges.length == numExchanges) - val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) // Submit all map stages val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() @@ -254,7 +256,7 @@ private[sql] class ExchangeCoordinator( } } - def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = { + def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = { doEstimationIfNecessary() if (!postShuffleRDDs.containsKey(exchange)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index e30adefc69ece..de21d7705e137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.exchange import java.util.Random @@ -24,19 +24,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ import org.apache.spark.util.MutablePair /** * Performs a shuffle that will result in the desired `newPartitioning`. */ -case class Exchange( +case class ShuffleExchange( var newPartitioning: Partitioning, child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { @@ -81,7 +80,8 @@ case class Exchange( * the returned ShuffleDependency will be the input of shuffle. */ private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { - Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer) + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, newPartitioning, serializer) } /** @@ -116,9 +116,9 @@ case class Exchange( } } -object Exchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) +object ShuffleExchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { + ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) } /** @@ -259,238 +259,3 @@ object Exchange { dependency } } - -/** - * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] - * of input data meets the - * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. Also ensure that the - * input partition ordering requirements are met. - */ -private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions - - private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize - - private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled - - private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions - if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None - } - - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - - /** - * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled - * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]]. - */ - private def withExchangeCoordinator( - children: Seq[SparkPlan], - requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { - val supportsCoordinator = - if (children.exists(_.isInstanceOf[Exchange])) { - // Right now, ExchangeCoordinator only support HashPartitionings. - children.forall { - case e @ Exchange(hash: HashPartitioning, _, _) => true - case child => - child.outputPartitioning match { - case hash: HashPartitioning => true - case collection: PartitioningCollection => - collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) - case _ => false - } - } - } else { - // In this case, although we do not have Exchange operators, we may still need to - // shuffle data when we have more than one children because data generated by - // these children may not be partitioned in the same way. - // Please see the comment in withCoordinator for more details. - val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) - children.length > 1 && supportsDistribution - } - - val withCoordinator = - if (adaptiveExecutionEnabled && supportsCoordinator) { - val coordinator = - new ExchangeCoordinator( - children.length, - targetPostShuffleInputSize, - minNumPostShufflePartitions) - children.zip(requiredChildDistributions).map { - case (e: Exchange, _) => - // This child is an Exchange, we need to add the coordinator. - e.copy(coordinator = Some(coordinator)) - case (child, distribution) => - // If this child is not an Exchange, we need to add an Exchange for now. - // Ideally, we can try to avoid this Exchange. However, when we reach here, - // there are at least two children operators (because if there is a single child - // and we can avoid Exchange, supportsCoordinator will be false and we - // will not reach here.). Although we can make two children have the same number of - // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. - // For example, let's say we have the following plan - // Join - // / \ - // Agg Exchange - // / \ - // Exchange t2 - // / - // t1 - // In this case, because a post-shuffle partition can include multiple pre-shuffle - // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes - // after shuffle. So, even we can use the child Exchange operator of the Join to - // have a number of post-shuffle partitions that matches the number of partitions of - // Agg, we cannot say these two children are partitioned in the same way. - // Here is another case - // Join - // / \ - // Agg1 Agg2 - // / \ - // Exchange1 Exchange2 - // / \ - // t1 t2 - // In this case, two Aggs shuffle data with the same column of the join condition. - // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same - // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 - // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle - // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its - // pre-shuffle partitions by using another partitionStartIndices [0, 4]. - // So, Agg1 and Agg2 are actually not co-partitioned. - // - // It will be great to introduce a new Partitioning to represent the post-shuffle - // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) - assert(targetPartitioning.isInstanceOf[HashPartitioning]) - Exchange(targetPartitioning, child, Some(coordinator)) - } - } else { - // If we do not need ExchangeCoordinator, the original children are returned. - children - } - - withCoordinator - } - - private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { - val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution - val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering - var children: Seq[SparkPlan] = operator.children - assert(requiredChildDistributions.length == children.length) - assert(requiredChildOrderings.length == children.length) - - // Ensure that the operator's children satisfy their output distribution requirements: - children = children.zip(requiredChildDistributions).map { case (child, distribution) => - if (child.outputPartitioning.satisfies(distribution)) { - child - } else { - Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) - } - } - - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - if (children.length > 1 - && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { - - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => { - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - } - - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. - val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => { - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } - - children.zip(requiredChildDistributions).map { - case (child, distribution) => { - val targetPartitioning = - createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case Exchange(_, c, _) => Exchange(targetPartitioning, c) - case _ => Exchange(targetPartitioning, child) - } - } - } - } - } - } - - // Now, we need to add ExchangeCoordinator if necessary. - // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. - // However, with the way that we plan the query, we do not have a place where we have a - // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator - // at here for now. - // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, - // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. - children = withExchangeCoordinator(children, requiredChildDistributions) - - // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: - children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => - if (requiredOrdering.nonEmpty) { - // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - Sort(requiredOrdering, global = false, child = child) - } else { - child - } - } else { - child - } - } - - operator.withNewChildren(children) - } - - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ Exchange(partitioning, child, _) => - child.children match { - case Exchange(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator - case _ => operator - } - case operator: SparkPlan => ensureDistributionAndOrdering(operator) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index a64da225800a3..ddc08822f3e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution.joins -import scala.concurrent._ -import scala.concurrent.duration._ - import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -27,10 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.ThreadUtils import org.apache.spark.util.collection.CompactBuffer /** @@ -52,60 +48,25 @@ case class BroadcastHashJoin( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - val timeout: Duration = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - row.copy() - }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. - // TODO: move this check into HashedRelation - val hashed = if (canJoinKeyFitWithinLong) { - LongHashedRelation( - input.iterator, buildSideKeyGenerator, input.size) - } else { - HashedRelation( - input.iterator, buildSideKeyGenerator, input.size) - } - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture + override def requiredChildDistribution: Seq[Distribution] = { + val mode = HashedRelationBroadcastMode( + canJoinKeyFitWithinLong, + rewriteKeyExpr(buildKeys), + buildPlan.output) + buildSide match { + case BuildLeft => + BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastRelation = Await.result(broadcastFuture, timeout) - + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value @@ -160,7 +121,7 @@ case class BroadcastHashJoin( */ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation - val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName @@ -362,9 +323,3 @@ case class BroadcastHashJoin( } } } - -object BroadcastHashJoin { - - private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 4f1cfd2e8171b..1f99fbedde411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -38,25 +39,25 @@ case class BroadcastLeftSemiJoinHash( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = { + val mode = if (condition.isEmpty) { + HashSetBroadcastMode(rightKeys, right.output) + } else { + HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output) + } + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val input = right.execute().map { row => - row.copy() - }.collect() - if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) - val broadcastedRelation = sparkContext.broadcast(hashSet) - + val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]() left.execute().mapPartitionsInternal { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = - HashedRelation(input.toIterator, rightKeyGenerator, input.size) - val broadcastedRelation = sparkContext.broadcast(hashRelation) - + val broadcastedRelation = right.executeBroadcast[HashedRelation]() left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4585cbda929ab..e8bd7f69dbab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} @@ -33,7 +33,6 @@ case class BroadcastNestedLoopJoin( buildSide: BuildSide, joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -44,8 +43,15 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def requiredChildDistribution: Seq[Distribution] = buildSide match { + case BuildLeft => + BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) + UnsafeProjection.create(schema) } override def outputPartitioning: Partitioning = streamed.outputPartitioning @@ -73,15 +79,14 @@ case class BroadcastNestedLoopJoin( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - row.copy() - }.collect().toIndexedSeq) + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val relation = broadcastedRelation.value + val matchedRows = new CompactBuffer[InternalRow] - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(relation.length) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) @@ -92,8 +97,8 @@ case class BroadcastNestedLoopJoin( var i = 0 var streamRowMatched = false - while (i < broadcastedRelation.value.size) { - val broadcastedRow = broadcastedRelation.value(i) + while (i < relation.length) { + val broadcastedRow = relation(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 0220e0b8a7c2b..1cb6a00617c5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric @@ -44,22 +45,7 @@ trait HashSemiJoin { protected def buildKeyHashSet( buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { - val hashSet = new java.util.HashSet[InternalRow]() - - // Create a Hash set of buildKeys - val rightKey = rightKeyGenerator - while (buildIter.hasNext) { - val currentRow = buildIter.next() - val rowKey = rightKey(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey.copy()) - } - } - } - - hashSet + HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter) } protected def hashSemiJoin( @@ -92,3 +78,36 @@ trait HashSemiJoin { } } } + +private[execution] object HashSemiJoin { + def buildKeyHashSet( + keys: Seq[Expression], + attributes: Seq[Attribute], + rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = { + val hashSet = new java.util.HashSet[InternalRow]() + + // Create a Hash set of buildKeys + val key = UnsafeProjection.create(keys, attributes) + while (rows.hasNext) { + val currentRow = rows.next() + val rowKey = key(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey.copy()) + } + } + } + hashSet + } +} + +/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */ +private[execution] case class HashSetBroadcastMode( + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { + + override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = { + HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0978570d429f9..606269bf25ab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,12 +25,11 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} import org.apache.spark.sql.execution.local.LocalNode -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.MemoryLocation import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} import org.apache.spark.util.collection.CompactBuffer @@ -675,3 +674,20 @@ private[joins] object LongHashedRelation { } } } + +/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ +private[execution] case class HashedRelationBroadcastMode( + canJoinKeyFitWithinLong: Boolean, + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { + + def transform(rows: Array[InternalRow]): HashedRelation = { + val generator = UnsafeProjection.create(keys, attributes) + if (canJoinKeyFitWithinLong) { + LongHashedRelation(rows.iterator, generator, rows.length) + } else { + HashedRelation(rows.iterator, generator, rows.length) + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index ce758d63b36f4..df6dac88187cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -29,9 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * for hash join. */ case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -46,27 +44,28 @@ case class LeftSemiJoinBNL( /** The Broadcast relation */ override def right: SparkPlan = broadcast + override def requiredChildDistribution: Seq[Distribution] = { + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - row.copy() - }.collect().toIndexedSeq) + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow + val relation = broadcastedRelation.value streamedIter.filter(streamedRow => { var i = 0 var matched = false - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + while (i < relation.length && !matched) { + if (boundCondition(joinedRow(streamedRow, relation(i)))) { matched = true } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ef76847bcb2d2..cd543d4195286 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange /** @@ -38,7 +39,8 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) protected override def doExecute(): RDD[InternalRow] = { val shuffled = new ShuffledRowRDD( - Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer)) + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -110,7 +112,8 @@ case class TakeOrderedAndProject( } } val shuffled = new ShuffledRowRDD( - Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer)) + ShuffleExchange.prepareShuffleDependency( + localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList.isDefined) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e8d0678989d88..83d7953aaf700 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -23,9 +23,9 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -357,7 +357,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected) + assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 99ba2e2061258..50a246489ee55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} -import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.test.SQLTestData.TestData2 @@ -1119,7 +1119,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } atFirstAgg = true } - case e: Exchange => atFirstAgg = false + case e: ShuffleExchange => atFirstAgg = false case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 35ff1c40fe6ca..b1c588a63d030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql._ +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext @@ -297,13 +298,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -311,7 +312,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -348,13 +349,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -362,7 +363,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -404,13 +405,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -456,13 +457,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 87bff3295f5be..d4f22de90c523 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -28,7 +29,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => Exchange(SinglePartition, plan), + plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 250ce8f86698f..4de56783fabce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -212,7 +213,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length assert(numExchanges === 5) } @@ -227,7 +228,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length assert(numExchanges === 5) } @@ -295,7 +296,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -333,7 +334,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -353,7 +354,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -376,7 +377,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -435,7 +436,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = Exchange(finalPartitioning, + val inputPlan = ShuffleExchange(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -444,7 +445,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -455,7 +456,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = Exchange(finalPartitioning, + val inputPlan = ShuffleExchange(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -464,7 +465,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index e25b5e0610ea1..a256ee95a153c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext} +import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ /** @@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = df3.queryExecution.sparkPlan + val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index e22a810a6b42f..6dfff3770b882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -88,7 +89,15 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan) + val broadcastJoin = joins.BroadcastHashJoin( + leftKeys, + rightKeys, + Inner, + side, + boundCondition, + leftPlan, + rightPlan) + EnsureRequirements(sqlContext).apply(broadcastJoin) } def makeSortMergeJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index f4b01fbad0585..cd6b6fcbb18af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 9c86084f9b8a9..f3ad8409e5a3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9ba645626fe72..a05a57c0f5072 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,8 +22,9 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{Exchange, PhysicalRDD} +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -252,8 +253,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin]) val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin] - assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft) - assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight) + assert(joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft) + assert(joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight) } } } @@ -312,7 +313,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) } } @@ -326,7 +327,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) } } @@ -339,7 +340,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet val agged = hiveContext.table("bucketed_table").groupBy("i").count() // make sure we fall back to non-bucketing mode and can't avoid shuffle - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isDefined) checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i")) } } From 7eb83fefd19e137d80a23b5174b66b14831c291a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 21 Feb 2016 13:21:59 -0800 Subject: [PATCH 1953/2089] [SPARK-13137][SQL] NullPoingException in schema inference for CSV when the first line is empty https://issues.apache.org/jira/browse/SPARK-13137 This PR adds a filter in schema inference so that it does not emit NullPointException. Also, I removed `MAX_COMMENT_LINES_IN_HEADER `but instead used a monad chaining with `filter()` and `first()`. Lastly, I simply added a newline rather than adding a new file for this so that this is covered with the original tests. Author: hyukjinkwon Closes #11023 from HyukjinKwon/SPARK-13137. --- .../sql/execution/datasources/csv/CSVOptions.scala | 3 --- .../sql/execution/datasources/csv/CSVRelation.scala | 12 +++++++----- sql/core/src/test/resources/cars.csv | 1 + 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index bea8e97a9a94f..38aa2dd80a4f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -75,9 +75,6 @@ private[sql] class CSVOptions( val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") - // Limit the number of lines we'll search for a header row that isn't comment-prefixed - val MAX_COMMENT_LINES_IN_HEADER = 10 - // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index f8e3a1b6d46ae..471ed0d5604b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -154,12 +154,14 @@ private[csv] class CSVRelation( */ private def findFirstLine(rdd: RDD[String]): String = { if (params.isCommentSet) { - rdd.take(params.MAX_COMMENT_LINES_IN_HEADER) - .find(!_.startsWith(params.comment.toString)) - .getOrElse(sys.error(s"No uncommented header line in " + - s"first ${params.MAX_COMMENT_LINES_IN_HEADER} lines")) + val comment = params.comment.toString + rdd.filter { line => + line.trim.nonEmpty && !line.startsWith(comment) + }.first() } else { - rdd.first() + rdd.filter { line => + line.trim.nonEmpty + }.first() } } } diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/cars.csv index 2b9d74ca607ad..40ded573ade5c 100644 --- a/sql/core/src/test/resources/cars.csv +++ b/sql/core/src/test/resources/cars.csv @@ -1,3 +1,4 @@ + year,make,model,comment,blank "2012","Tesla","S","No comment", From 6c3832b26e119626205732b8fd03c8f5ba986896 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sun, 21 Feb 2016 15:00:24 -0800 Subject: [PATCH 1954/2089] [SPARK-13080][SQL] Implement new Catalog API using Hive ## What changes were proposed in this pull request? This is a step towards merging `SQLContext` and `HiveContext`. A new internal Catalog API was introduced in #10982 and extended in #11069. This patch introduces an implementation of this API using `HiveClient`, an existing interface to Hive. It also extends `HiveClient` with additional calls to Hive that are needed to complete the catalog implementation. *Where should I start reviewing?* The new catalog introduced is `HiveCatalog`. This class is relatively simple because it just calls `HiveClientImpl`, where most of the new logic is. I would not start with `HiveClient`, `HiveQl`, or `HiveMetastoreCatalog`, which are modified mainly because of a refactor. *Why is this patch so big?* I had to refactor HiveClient to remove an intermediate representation of databases, tables, partitions etc. After this refactor `CatalogTable` convert directly to and from `HiveTable` (etc.). Otherwise we would have to first convert `CatalogTable` to the intermediate representation and then convert that to HiveTable, which is messy. The new class hierarchy is as follows: ``` org.apache.spark.sql.catalyst.catalog.Catalog - org.apache.spark.sql.catalyst.catalog.InMemoryCatalog - org.apache.spark.sql.hive.HiveCatalog ``` Note that, as of this patch, none of these classes are currently used anywhere yet. This will come in the future before the Spark 2.0 release. ## How was the this patch tested? All existing unit tests, and HiveCatalogSuite that extends CatalogTestCases. Author: Andrew Or Author: Reynold Xin Closes #11293 from rxin/hive-catalog. --- .../apache/spark/sql/AnalysisException.scala | 3 + .../spark/sql/catalyst/analysis/Catalog.scala | 9 - .../analysis/NoSuchItemException.scala | 52 +++ .../catalyst/catalog/InMemoryCatalog.scala | 154 +++---- .../sql/catalyst/catalog/interface.scala | 190 ++++++--- .../catalyst/catalog/CatalogTestCases.scala | 284 ++++++++----- .../apache/spark/sql/hive/HiveCatalog.scala | 293 ++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 171 ++++---- .../org/apache/spark/sql/hive/HiveQl.scala | 145 +++---- .../spark/sql/hive/client/HiveClient.scala | 210 ++++++---- .../sql/hive/client/HiveClientImpl.scala | 379 ++++++++++++------ .../hive/execution/CreateTableAsSelect.scala | 20 +- .../hive/execution/CreateViewAsSelect.scala | 20 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../spark/sql/hive/HiveCatalogSuite.scala | 49 +++ .../sql/hive/HiveMetastoreCatalogSuite.scala | 40 +- .../apache/spark/sql/hive/HiveQlSuite.scala | 94 ++--- .../sql/hive/MetastoreDataSourcesSuite.scala | 25 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 4 +- .../spark/sql/hive/client/VersionsSuite.scala | 37 +- .../sql/hive/execution/PruningSuite.scala | 2 +- 21 files changed, 1483 insertions(+), 700 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index f9992185a4563..97f28fad62e43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql import org.apache.spark.annotation.DeveloperApi + +// TODO: don't swallow original stack trace if it exists + /** * :: DeveloperApi :: * Thrown when a query fails to analyze, usually because the query itself is invalid. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 67edab55dba8b..52b284b757df5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -20,20 +20,11 @@ package org.apache.spark.sql.catalyst.analysis import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, EmptyConf, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} -/** - * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception - * as an AnalysisException with the correct position information. - */ -class NoSuchTableException extends Exception - -class NoSuchDatabaseException extends Exception /** * An interface for looking up relations by name. Used by an [[Analyzer]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala new file mode 100644 index 0000000000000..81399db9bc070 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.catalog.Catalog.TablePartitionSpec + + +/** + * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception + * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + */ +abstract class NoSuchItemException extends Exception { + override def getMessage: String +} + +class NoSuchDatabaseException(db: String) extends NoSuchItemException { + override def getMessage: String = s"Database $db not found" +} + +class NoSuchTableException(db: String, table: String) extends NoSuchItemException { + override def getMessage: String = s"Table $table not found in database $db" +} + +class NoSuchPartitionException( + db: String, + table: String, + spec: TablePartitionSpec) + extends NoSuchItemException { + + override def getMessage: String = { + s"Partition not found in table $table database $db:\n" + spec.mkString("\n") + } +} + +class NoSuchFunctionException(db: String, func: String) extends NoSuchItemException { + override def getMessage: String = s"Function $func not found in database $db" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 38be61c52a95e..cba4de34f2b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -30,15 +30,16 @@ import org.apache.spark.sql.AnalysisException class InMemoryCatalog extends Catalog { import Catalog._ - private class TableDesc(var table: Table) { - val partitions = new mutable.HashMap[PartitionSpec, TablePartition] + private class TableDesc(var table: CatalogTable) { + val partitions = new mutable.HashMap[TablePartitionSpec, CatalogTablePartition] } - private class DatabaseDesc(var db: Database) { + private class DatabaseDesc(var db: CatalogDatabase) { val tables = new mutable.HashMap[String, TableDesc] - val functions = new mutable.HashMap[String, Function] + val functions = new mutable.HashMap[String, CatalogFunction] } + // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { @@ -47,39 +48,33 @@ class InMemoryCatalog extends Catalog { } private def existsFunction(db: String, funcName: String): Boolean = { - assertDbExists(db) + requireDbExists(db) catalog(db).functions.contains(funcName) } private def existsTable(db: String, table: String): Boolean = { - assertDbExists(db) + requireDbExists(db) catalog(db).tables.contains(table) } - private def existsPartition(db: String, table: String, spec: PartitionSpec): Boolean = { - assertTableExists(db, table) + private def existsPartition(db: String, table: String, spec: TablePartitionSpec): Boolean = { + requireTableExists(db, table) catalog(db).tables(table).partitions.contains(spec) } - private def assertDbExists(db: String): Unit = { - if (!catalog.contains(db)) { - throw new AnalysisException(s"Database $db does not exist") - } - } - - private def assertFunctionExists(db: String, funcName: String): Unit = { + private def requireFunctionExists(db: String, funcName: String): Unit = { if (!existsFunction(db, funcName)) { throw new AnalysisException(s"Function $funcName does not exist in $db database") } } - private def assertTableExists(db: String, table: String): Unit = { + private def requireTableExists(db: String, table: String): Unit = { if (!existsTable(db, table)) { throw new AnalysisException(s"Table $table does not exist in $db database") } } - private def assertPartitionExists(db: String, table: String, spec: PartitionSpec): Unit = { + private def requirePartitionExists(db: String, table: String, spec: TablePartitionSpec): Unit = { if (!existsPartition(db, table, spec)) { throw new AnalysisException(s"Partition does not exist in database $db table $table: $spec") } @@ -90,7 +85,7 @@ class InMemoryCatalog extends Catalog { // -------------------------------------------------------------------------- override def createDatabase( - dbDefinition: Database, + dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { if (!ignoreIfExists) { @@ -124,17 +119,20 @@ class InMemoryCatalog extends Catalog { } } - override def alterDatabase(db: String, dbDefinition: Database): Unit = synchronized { - assertDbExists(db) - assert(db == dbDefinition.name) - catalog(db).db = dbDefinition + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + requireDbExists(dbDefinition.name) + catalog(dbDefinition.name).db = dbDefinition } - override def getDatabase(db: String): Database = synchronized { - assertDbExists(db) + override def getDatabase(db: String): CatalogDatabase = synchronized { + requireDbExists(db) catalog(db).db } + override def databaseExists(db: String): Boolean = synchronized { + catalog.contains(db) + } + override def listDatabases(): Seq[String] = synchronized { catalog.keySet.toSeq } @@ -143,15 +141,17 @@ class InMemoryCatalog extends Catalog { filterPattern(listDatabases(), pattern) } + override def setCurrentDatabase(db: String): Unit = { /* no-op */ } + // -------------------------------------------------------------------------- // Tables // -------------------------------------------------------------------------- override def createTable( db: String, - tableDefinition: Table, + tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { - assertDbExists(db) + requireDbExists(db) if (existsTable(db, tableDefinition.name)) { if (!ignoreIfExists) { throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database") @@ -165,7 +165,7 @@ class InMemoryCatalog extends Catalog { db: String, table: String, ignoreIfNotExists: Boolean): Unit = synchronized { - assertDbExists(db) + requireDbExists(db) if (existsTable(db, table)) { catalog(db).tables.remove(table) } else { @@ -176,31 +176,30 @@ class InMemoryCatalog extends Catalog { } override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { - assertTableExists(db, oldName) + requireTableExists(db, oldName) val oldDesc = catalog(db).tables(oldName) oldDesc.table = oldDesc.table.copy(name = newName) catalog(db).tables.put(newName, oldDesc) catalog(db).tables.remove(oldName) } - override def alterTable(db: String, table: String, tableDefinition: Table): Unit = synchronized { - assertTableExists(db, table) - assert(table == tableDefinition.name) - catalog(db).tables(table).table = tableDefinition + override def alterTable(db: String, tableDefinition: CatalogTable): Unit = synchronized { + requireTableExists(db, tableDefinition.name) + catalog(db).tables(tableDefinition.name).table = tableDefinition } - override def getTable(db: String, table: String): Table = synchronized { - assertTableExists(db, table) + override def getTable(db: String, table: String): CatalogTable = synchronized { + requireTableExists(db, table) catalog(db).tables(table).table } override def listTables(db: String): Seq[String] = synchronized { - assertDbExists(db) + requireDbExists(db) catalog(db).tables.keySet.toSeq } override def listTables(db: String, pattern: String): Seq[String] = synchronized { - assertDbExists(db) + requireDbExists(db) filterPattern(listTables(db), pattern) } @@ -211,9 +210,9 @@ class InMemoryCatalog extends Catalog { override def createPartitions( db: String, table: String, - parts: Seq[TablePartition], + parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = synchronized { - assertTableExists(db, table) + requireTableExists(db, table) val existingParts = catalog(db).tables(table).partitions if (!ignoreIfExists) { val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec } @@ -229,9 +228,9 @@ class InMemoryCatalog extends Catalog { override def dropPartitions( db: String, table: String, - partSpecs: Seq[PartitionSpec], + partSpecs: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = synchronized { - assertTableExists(db, table) + requireTableExists(db, table) val existingParts = catalog(db).tables(table).partitions if (!ignoreIfNotExists) { val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s } @@ -244,30 +243,42 @@ class InMemoryCatalog extends Catalog { partSpecs.foreach(existingParts.remove) } - override def alterPartition( + override def renamePartitions( db: String, table: String, - spec: Map[String, String], - newPart: TablePartition): Unit = synchronized { - assertPartitionExists(db, table, spec) - val existingParts = catalog(db).tables(table).partitions - if (spec != newPart.spec) { - // Also a change in specs; remove the old one and add the new one back - existingParts.remove(spec) + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = synchronized { + require(specs.size == newSpecs.size, "number of old and new partition specs differ") + specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => + val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec) + val existingParts = catalog(db).tables(table).partitions + existingParts.remove(oldSpec) + existingParts.put(newSpec, newPart) + } + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = synchronized { + parts.foreach { p => + requirePartitionExists(db, table, p.spec) + catalog(db).tables(table).partitions.put(p.spec, p) } - existingParts.put(newPart.spec, newPart) } override def getPartition( db: String, table: String, - spec: Map[String, String]): TablePartition = synchronized { - assertPartitionExists(db, table, spec) + spec: TablePartitionSpec): CatalogTablePartition = synchronized { + requirePartitionExists(db, table, spec) catalog(db).tables(table).partitions(spec) } - override def listPartitions(db: String, table: String): Seq[TablePartition] = synchronized { - assertTableExists(db, table) + override def listPartitions( + db: String, + table: String): Seq[CatalogTablePartition] = synchronized { + requireTableExists(db, table) catalog(db).tables(table).partitions.values.toSeq } @@ -275,44 +286,39 @@ class InMemoryCatalog extends Catalog { // Functions // -------------------------------------------------------------------------- - override def createFunction( - db: String, - func: Function, - ignoreIfExists: Boolean): Unit = synchronized { - assertDbExists(db) + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + requireDbExists(db) if (existsFunction(db, func.name)) { - if (!ignoreIfExists) { - throw new AnalysisException(s"Function $func already exists in $db database") - } + throw new AnalysisException(s"Function $func already exists in $db database") } else { catalog(db).functions.put(func.name, func) } } override def dropFunction(db: String, funcName: String): Unit = synchronized { - assertFunctionExists(db, funcName) + requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override def alterFunction( - db: String, - funcName: String, - funcDefinition: Function): Unit = synchronized { - assertFunctionExists(db, funcName) - if (funcName != funcDefinition.name) { - // Also a rename; remove the old one and add the new one back - catalog(db).functions.remove(funcName) - } + override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + requireFunctionExists(db, oldName) + val newFunc = getFunction(db, oldName).copy(name = newName) + catalog(db).functions.remove(oldName) + catalog(db).functions.put(newName, newFunc) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized { + requireFunctionExists(db, funcDefinition.name) catalog(db).functions.put(funcDefinition.name, funcDefinition) } - override def getFunction(db: String, funcName: String): Function = synchronized { - assertFunctionExists(db, funcName) + override def getFunction(db: String, funcName: String): CatalogFunction = synchronized { + requireFunctionExists(db, funcName) catalog(db).functions(funcName) } override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { - assertDbExists(db) + requireDbExists(db) filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 56aaa6bc6c2e9..dac5f023d1f58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import javax.annotation.Nullable + import org.apache.spark.sql.AnalysisException @@ -31,41 +33,59 @@ import org.apache.spark.sql.AnalysisException abstract class Catalog { import Catalog._ + protected def requireDbExists(db: String): Unit = { + if (!databaseExists(db)) { + throw new AnalysisException(s"Database $db does not exist") + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: Database, ignoreIfExists: Boolean): Unit + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** - * Alter an existing database. This operation does not support renaming. + * Alter a database whose name matches the one specified in `dbDefinition`, + * assuming the database exists. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. */ - def alterDatabase(db: String, dbDefinition: Database): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit - def getDatabase(db: String): Database + def getDatabase(db: String): CatalogDatabase + + def databaseExists(db: String): Boolean def listDatabases(): Seq[String] def listDatabases(pattern: String): Seq[String] + def setCurrentDatabase(db: String): Unit + // -------------------------------------------------------------------------- // Tables // -------------------------------------------------------------------------- - def createTable(db: String, tableDefinition: Table, ignoreIfExists: Boolean): Unit + def createTable(db: String, tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit def dropTable(db: String, table: String, ignoreIfNotExists: Boolean): Unit def renameTable(db: String, oldName: String, newName: String): Unit /** - * Alter an existing table. This operation does not support renaming. + * Alter a table whose name that matches the one specified in `tableDefinition`, + * assuming the table exists. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. */ - def alterTable(db: String, table: String, tableDefinition: Table): Unit + def alterTable(db: String, tableDefinition: CatalogTable): Unit - def getTable(db: String, table: String): Table + def getTable(db: String, table: String): CatalogTable def listTables(db: String): Seq[String] @@ -78,43 +98,62 @@ abstract class Catalog { def createPartitions( db: String, table: String, - parts: Seq[TablePartition], + parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit def dropPartitions( db: String, table: String, - parts: Seq[PartitionSpec], + parts: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit /** - * Alter an existing table partition and optionally override its spec. + * Override the specs of one or many existing table partitions, assuming they exist. + * This assumes index i of `specs` corresponds to index i of `newSpecs`. + */ + def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit + + /** + * Alter one or many table partitions whose specs that match those specified in `parts`, + * assuming the partitions exist. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. */ - def alterPartition( + def alterPartitions( db: String, table: String, - spec: PartitionSpec, - newPart: TablePartition): Unit + parts: Seq[CatalogTablePartition]): Unit - def getPartition(db: String, table: String, spec: PartitionSpec): TablePartition + def getPartition(db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition // TODO: support listing by pattern - def listPartitions(db: String, table: String): Seq[TablePartition] + def listPartitions(db: String, table: String): Seq[CatalogTablePartition] // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- - def createFunction(db: String, funcDefinition: Function, ignoreIfExists: Boolean): Unit + def createFunction(db: String, funcDefinition: CatalogFunction): Unit def dropFunction(db: String, funcName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit + /** - * Alter an existing function and optionally override its name. + * Alter a function whose name that matches the one specified in `funcDefinition`, + * assuming the function exists. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. */ - def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - def getFunction(db: String, funcName: String): Function + def getFunction(db: String, funcName: String): CatalogFunction def listFunctions(db: String, pattern: String): Seq[String] @@ -127,33 +166,30 @@ abstract class Catalog { * @param name name of the function * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" */ -case class Function( - name: String, - className: String -) +case class CatalogFunction(name: String, className: String) /** * Storage format, used to describe how a partition or a table is stored. */ -case class StorageFormat( - locationUri: String, - inputFormat: String, - outputFormat: String, - serde: String, - serdeProperties: Map[String, String] -) +case class CatalogStorageFormat( + locationUri: Option[String], + inputFormat: Option[String], + outputFormat: Option[String], + serde: Option[String], + serdeProperties: Map[String, String]) /** * A column in a table. */ -case class Column( - name: String, - dataType: String, - nullable: Boolean, - comment: String -) +case class CatalogColumn( + name: String, + // This may be null when used to create views. TODO: make this type-safe; this is left + // as a string due to issues in converting Hive varchars to and from SparkSQL strings. + @Nullable dataType: String, + nullable: Boolean = true, + comment: Option[String] = None) /** @@ -162,10 +198,7 @@ case class Column( * @param spec partition spec values indexed by column name * @param storage storage format of the partition */ -case class TablePartition( - spec: Catalog.PartitionSpec, - storage: StorageFormat -) +case class CatalogTablePartition(spec: Catalog.TablePartitionSpec, storage: CatalogStorageFormat) /** @@ -174,40 +207,65 @@ case class TablePartition( * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the * future once we have a better understanding of how we want to handle skewed columns. */ -case class Table( - name: String, - description: String, - schema: Seq[Column], - partitionColumns: Seq[Column], - sortColumns: Seq[Column], - storage: StorageFormat, - numBuckets: Int, - properties: Map[String, String], - tableType: String, - createTime: Long, - lastAccessTime: Long, - viewOriginalText: Option[String], - viewText: Option[String]) { - - require(tableType == "EXTERNAL_TABLE" || tableType == "INDEX_TABLE" || - tableType == "MANAGED_TABLE" || tableType == "VIRTUAL_VIEW") +case class CatalogTable( + specifiedDatabase: Option[String], + name: String, + tableType: CatalogTableType, + storage: CatalogStorageFormat, + schema: Seq[CatalogColumn], + partitionColumns: Seq[CatalogColumn] = Seq.empty, + sortColumns: Seq[CatalogColumn] = Seq.empty, + numBuckets: Int = 0, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = System.currentTimeMillis, + properties: Map[String, String] = Map.empty, + viewOriginalText: Option[String] = None, + viewText: Option[String] = None) { + + /** Return the database this table was specified to belong to, assuming it exists. */ + def database: String = specifiedDatabase.getOrElse { + throw new AnalysisException(s"table $name did not specify database") + } + + /** Return the fully qualified name of this table, assuming the database was specified. */ + def qualifiedName: String = s"$database.$name" + + /** Syntactic sugar to update a field in `storage`. */ + def withNewStorage( + locationUri: Option[String] = storage.locationUri, + inputFormat: Option[String] = storage.inputFormat, + outputFormat: Option[String] = storage.outputFormat, + serde: Option[String] = storage.serde, + serdeProperties: Map[String, String] = storage.serdeProperties): CatalogTable = { + copy(storage = CatalogStorageFormat( + locationUri, inputFormat, outputFormat, serde, serdeProperties)) + } + +} + + +case class CatalogTableType private(name: String) +object CatalogTableType { + val EXTERNAL_TABLE = new CatalogTableType("EXTERNAL_TABLE") + val MANAGED_TABLE = new CatalogTableType("MANAGED_TABLE") + val INDEX_TABLE = new CatalogTableType("INDEX_TABLE") + val VIRTUAL_VIEW = new CatalogTableType("VIRTUAL_VIEW") } /** * A database defined in the catalog. */ -case class Database( - name: String, - description: String, - locationUri: String, - properties: Map[String, String] -) +case class CatalogDatabase( + name: String, + description: String, + locationUri: String, + properties: Map[String, String]) object Catalog { /** - * Specifications of a table partition indexed by column name. + * Specifications of a table partition. Mapping column name to column value. */ - type PartitionSpec = Map[String, String] + type TablePartitionSpec = Map[String, String] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 45c5ceecb0eef..e0d1220d13e7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException @@ -26,18 +28,38 @@ import org.apache.spark.sql.AnalysisException * * Implementations of the [[Catalog]] interface can create test suites by extending this. */ -abstract class CatalogTestCases extends SparkFunSuite { - private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map()) - private val part1 = TablePartition(Map("a" -> "1"), storageFormat) - private val part2 = TablePartition(Map("b" -> "2"), storageFormat) - private val part3 = TablePartition(Map("c" -> "3"), storageFormat) +abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { + private lazy val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(tableInputFormat), + outputFormat = Some(tableOutputFormat), + serde = None, + serdeProperties = Map.empty) + private lazy val part1 = CatalogTablePartition(Map("a" -> "1", "b" -> "2"), storageFormat) + private lazy val part2 = CatalogTablePartition(Map("a" -> "3", "b" -> "4"), storageFormat) + private lazy val part3 = CatalogTablePartition(Map("a" -> "5", "b" -> "6"), storageFormat) private val funcClass = "org.apache.spark.myFunc" + // Things subclasses should override + protected val tableInputFormat: String = "org.apache.park.serde.MyInputFormat" + protected val tableOutputFormat: String = "org.apache.park.serde.MyOutputFormat" + protected def newUriForDatabase(): String = "uri" + protected def resetState(): Unit = { } protected def newEmptyCatalog(): Catalog + // Clear all state after each test + override def afterEach(): Unit = { + try { + resetState() + } finally { + super.afterEach() + } + } + /** * Creates a basic catalog, with the following structure: * + * default * db1 * db2 * - tbl1 @@ -48,37 +70,65 @@ abstract class CatalogTestCases extends SparkFunSuite { */ private def newBasicCatalog(): Catalog = { val catalog = newEmptyCatalog() + // When testing against a real catalog, the default database may already exist + catalog.createDatabase(newDb("default"), ignoreIfExists = true) catalog.createDatabase(newDb("db1"), ignoreIfExists = false) catalog.createDatabase(newDb("db2"), ignoreIfExists = false) - catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false) - catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl1", "db2"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl2", "db2"), ignoreIfExists = false) catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) - catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1")) catalog } - private def newFunc(): Function = Function("funcname", funcClass) + private def newFunc(): CatalogFunction = CatalogFunction("funcname", funcClass) + + private def newDb(name: String): CatalogDatabase = { + CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) + } + + private def newTable(name: String, db: String): CatalogTable = { + CatalogTable( + specifiedDatabase = Some(db), + name = name, + tableType = CatalogTableType.EXTERNAL_TABLE, + storage = storageFormat, + schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")), + partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string"))) + } - private def newDb(name: String = "default"): Database = - Database(name, name + " description", "uri", Map.empty) + private def newFunc(name: String): CatalogFunction = CatalogFunction(name, funcClass) - private def newTable(name: String): Table = - Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0, - None, None) + /** + * Whether the catalog's table partitions equal the ones given. + * Note: Hive sets some random serde things, so we just compare the specs here. + */ + private def catalogPartitionsEqual( + catalog: Catalog, + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Boolean = { + catalog.listPartitions(db, table).map(_.spec).toSet == parts.map(_.spec).toSet + } - private def newFunc(name: String): Function = Function(name, funcClass) // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - test("basic create, drop and list databases") { + test("basic create and list databases") { val catalog = newEmptyCatalog() - catalog.createDatabase(newDb(), ignoreIfExists = false) - assert(catalog.listDatabases().toSet == Set("default")) - - catalog.createDatabase(newDb("default2"), ignoreIfExists = false) - assert(catalog.listDatabases().toSet == Set("default", "default2")) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + assert(catalog.databaseExists("default")) + assert(!catalog.databaseExists("testing")) + assert(!catalog.databaseExists("testing2")) + catalog.createDatabase(newDb("testing"), ignoreIfExists = false) + assert(catalog.databaseExists("testing")) + assert(catalog.listDatabases().toSet == Set("default", "testing")) + catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) + assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) + assert(catalog.databaseExists("testing2")) + assert(!catalog.databaseExists("does_not_exist")) } test("get database when a database exists") { @@ -93,7 +143,7 @@ abstract class CatalogTestCases extends SparkFunSuite { test("list databases without pattern") { val catalog = newBasicCatalog() - assert(catalog.listDatabases().toSet == Set("db1", "db2")) + assert(catalog.listDatabases().toSet == Set("default", "db1", "db2")) } test("list databases with pattern") { @@ -107,7 +157,7 @@ abstract class CatalogTestCases extends SparkFunSuite { test("drop database") { val catalog = newBasicCatalog() catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) - assert(catalog.listDatabases().toSet == Set("db2")) + assert(catalog.listDatabases().toSet == Set("default", "db2")) } test("drop database when the database is not empty") { @@ -118,6 +168,7 @@ abstract class CatalogTestCases extends SparkFunSuite { intercept[AnalysisException] { catalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) } + resetState() // Throw exception if there are tables left val catalog2 = newBasicCatalog() @@ -125,11 +176,12 @@ abstract class CatalogTestCases extends SparkFunSuite { intercept[AnalysisException] { catalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) } + resetState() // When cascade is true, it should drop them val catalog3 = newBasicCatalog() catalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) - assert(catalog3.listDatabases().toSet == Set("db1")) + assert(catalog3.listDatabases().toSet == Set("default", "db1")) } test("drop database when the database does not exist") { @@ -144,13 +196,19 @@ abstract class CatalogTestCases extends SparkFunSuite { test("alter database") { val catalog = newBasicCatalog() - catalog.alterDatabase("db1", Database("db1", "new description", "lll", Map.empty)) - assert(catalog.getDatabase("db1").description == "new description") + val db1 = catalog.getDatabase("db1") + // Note: alter properties here because Hive does not support altering other fields + catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) + val newDb1 = catalog.getDatabase("db1") + assert(db1.properties.isEmpty) + assert(newDb1.properties.size == 2) + assert(newDb1.properties.get("k") == Some("v3")) + assert(newDb1.properties.get("good") == Some("true")) } test("alter database should throw exception when the database does not exist") { intercept[AnalysisException] { - newBasicCatalog().alterDatabase("no_db", Database("no_db", "ddd", "lll", Map.empty)) + newBasicCatalog().alterDatabase(newDb("does_not_exist")) } } @@ -165,61 +223,56 @@ abstract class CatalogTestCases extends SparkFunSuite { assert(catalog.listTables("db2").toSet == Set("tbl2")) } - test("drop table when database / table does not exist") { + test("drop table when database/table does not exist") { val catalog = newBasicCatalog() - // Should always throw exception when the database does not exist intercept[AnalysisException] { catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = false) } - intercept[AnalysisException] { catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = true) } - // Should throw exception when the table does not exist, if ignoreIfNotExists is false intercept[AnalysisException] { catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = false) } - catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = true) } test("rename table") { val catalog = newBasicCatalog() - assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) catalog.renameTable("db2", "tbl1", "tblone") assert(catalog.listTables("db2").toSet == Set("tblone", "tbl2")) } - test("rename table when database / table does not exist") { + test("rename table when database/table does not exist") { val catalog = newBasicCatalog() - - intercept[AnalysisException] { // Throw exception when the database does not exist + intercept[AnalysisException] { catalog.renameTable("unknown_db", "unknown_table", "unknown_table") } - - intercept[AnalysisException] { // Throw exception when the table does not exist + intercept[AnalysisException] { catalog.renameTable("db2", "unknown_table", "unknown_table") } } test("alter table") { val catalog = newBasicCatalog() - catalog.alterTable("db2", "tbl1", newTable("tbl1").copy(createTime = 10)) - assert(catalog.getTable("db2", "tbl1").createTime == 10) + val tbl1 = catalog.getTable("db2", "tbl1") + catalog.alterTable("db2", tbl1.copy(properties = Map("toh" -> "frem"))) + val newTbl1 = catalog.getTable("db2", "tbl1") + assert(!tbl1.properties.contains("toh")) + assert(newTbl1.properties.size == tbl1.properties.size + 1) + assert(newTbl1.properties.get("toh") == Some("frem")) } - test("alter table when database / table does not exist") { + test("alter table when database/table does not exist") { val catalog = newBasicCatalog() - - intercept[AnalysisException] { // Throw exception when the database does not exist - catalog.alterTable("unknown_db", "unknown_table", newTable("unknown_table")) + intercept[AnalysisException] { + catalog.alterTable("unknown_db", newTable("tbl1", "unknown_db")) } - - intercept[AnalysisException] { // Throw exception when the table does not exist - catalog.alterTable("db2", "unknown_table", newTable("unknown_table")) + intercept[AnalysisException] { + catalog.alterTable("db2", newTable("unknown_table", "db2")) } } @@ -227,12 +280,11 @@ abstract class CatalogTestCases extends SparkFunSuite { assert(newBasicCatalog().getTable("db2", "tbl1").name == "tbl1") } - test("get table when database / table does not exist") { + test("get table when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { catalog.getTable("unknown_db", "unknown_table") } - intercept[AnalysisException] { catalog.getTable("db2", "unknown_table") } @@ -246,10 +298,7 @@ abstract class CatalogTestCases extends SparkFunSuite { test("list tables with pattern") { val catalog = newBasicCatalog() - - // Test when database does not exist intercept[AnalysisException] { catalog.listTables("unknown_db") } - assert(catalog.listTables("db1", "*").toSet == Set.empty) assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2")) assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2")) @@ -263,12 +312,12 @@ abstract class CatalogTestCases extends SparkFunSuite { test("basic create and list partitions") { val catalog = newEmptyCatalog() catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - catalog.createTable("mydb", newTable("mytbl"), ignoreIfExists = false) - catalog.createPartitions("mydb", "mytbl", Seq(part1, part2), ignoreIfExists = false) - assert(catalog.listPartitions("mydb", "mytbl").toSet == Set(part1, part2)) + catalog.createTable("mydb", newTable("tbl", "mydb"), ignoreIfExists = false) + catalog.createPartitions("mydb", "tbl", Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual(catalog, "mydb", "tbl", Seq(part1, part2))) } - test("create partitions when database / table does not exist") { + test("create partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { catalog.createPartitions("does_not_exist", "tbl1", Seq(), ignoreIfExists = false) @@ -288,16 +337,17 @@ abstract class CatalogTestCases extends SparkFunSuite { test("drop partitions") { val catalog = newBasicCatalog() - assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) - assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part2)) + assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2))) + resetState() val catalog2 = newBasicCatalog() - assert(catalog2.listPartitions("db2", "tbl2").toSet == Set(part1, part2)) + assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2))) catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) assert(catalog2.listPartitions("db2", "tbl2").isEmpty) } - test("drop partitions when database / table does not exist") { + test("drop partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) @@ -317,14 +367,14 @@ abstract class CatalogTestCases extends SparkFunSuite { test("get partition") { val catalog = newBasicCatalog() - assert(catalog.getPartition("db2", "tbl2", part1.spec) == part1) - assert(catalog.getPartition("db2", "tbl2", part2.spec) == part2) + assert(catalog.getPartition("db2", "tbl2", part1.spec).spec == part1.spec) + assert(catalog.getPartition("db2", "tbl2", part2.spec).spec == part2.spec) intercept[AnalysisException] { catalog.getPartition("db2", "tbl1", part3.spec) } } - test("get partition when database / table does not exist") { + test("get partition when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { catalog.getPartition("does_not_exist", "tbl1", part1.spec) @@ -334,28 +384,69 @@ abstract class CatalogTestCases extends SparkFunSuite { } } - test("alter partitions") { + test("rename partitions") { + val catalog = newBasicCatalog() + val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) + val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) + val newSpecs = Seq(newPart1.spec, newPart2.spec) + catalog.renamePartitions("db2", "tbl2", Seq(part1.spec, part2.spec), newSpecs) + assert(catalog.getPartition("db2", "tbl2", newPart1.spec).spec === newPart1.spec) + assert(catalog.getPartition("db2", "tbl2", newPart2.spec).spec === newPart2.spec) + // The old partitions should no longer exist + intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part1.spec) } + intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part2.spec) } + } + + test("rename partitions when database/table does not exist") { val catalog = newBasicCatalog() - val partSameSpec = part1.copy(storage = storageFormat.copy(serde = "myserde")) - val partNewSpec = part1.copy(spec = Map("x" -> "10")) - // alter but keep spec the same - catalog.alterPartition("db2", "tbl2", part1.spec, partSameSpec) - assert(catalog.getPartition("db2", "tbl2", part1.spec) == partSameSpec) - // alter and change spec - catalog.alterPartition("db2", "tbl2", part1.spec, partNewSpec) intercept[AnalysisException] { - catalog.getPartition("db2", "tbl2", part1.spec) + catalog.renamePartitions("does_not_exist", "tbl1", Seq(part1.spec), Seq(part2.spec)) + } + intercept[AnalysisException] { + catalog.renamePartitions("db2", "does_not_exist", Seq(part1.spec), Seq(part2.spec)) } - assert(catalog.getPartition("db2", "tbl2", partNewSpec.spec) == partNewSpec) } - test("alter partition when database / table does not exist") { + test("alter partitions") { + val catalog = newBasicCatalog() + try{ + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + catalog.setCurrentDatabase("db2") + val newLocation = newUriForDatabase() + // alter but keep spec the same + val oldPart1 = catalog.getPartition("db2", "tbl2", part1.spec) + val oldPart2 = catalog.getPartition("db2", "tbl2", part2.spec) + catalog.alterPartitions("db2", "tbl2", Seq( + oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), + oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) + val newPart1 = catalog.getPartition("db2", "tbl2", part1.spec) + val newPart2 = catalog.getPartition("db2", "tbl2", part2.spec) + assert(newPart1.storage.locationUri == Some(newLocation)) + assert(newPart2.storage.locationUri == Some(newLocation)) + assert(oldPart1.storage.locationUri != Some(newLocation)) + assert(oldPart2.storage.locationUri != Some(newLocation)) + // alter but change spec, should fail because new partition specs do not exist yet + val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) + val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) + intercept[AnalysisException] { + catalog.alterPartitions("db2", "tbl2", Seq(badPart1, badPart2)) + } + } finally { + // Remember to restore the original current database, which we assume to be "default" + catalog.setCurrentDatabase("default") + } + } + + test("alter partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.alterPartition("does_not_exist", "tbl1", part1.spec, part1) + catalog.alterPartitions("does_not_exist", "tbl1", Seq(part1)) } intercept[AnalysisException] { - catalog.alterPartition("db2", "does_not_exist", part1.spec, part1) + catalog.alterPartitions("db2", "does_not_exist", Seq(part1)) } } @@ -366,23 +457,22 @@ abstract class CatalogTestCases extends SparkFunSuite { test("basic create and list functions") { val catalog = newEmptyCatalog() catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - catalog.createFunction("mydb", newFunc("myfunc"), ignoreIfExists = false) + catalog.createFunction("mydb", newFunc("myfunc")) assert(catalog.listFunctions("mydb", "*").toSet == Set("myfunc")) } test("create function when database does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.createFunction("does_not_exist", newFunc(), ignoreIfExists = false) + catalog.createFunction("does_not_exist", newFunc()) } } test("create function that already exists") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1")) } - catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = true) } test("drop function") { @@ -421,31 +511,43 @@ abstract class CatalogTestCases extends SparkFunSuite { } } - test("alter function") { + test("rename function") { val catalog = newBasicCatalog() + val newName = "funcky" assert(catalog.getFunction("db2", "func1").className == funcClass) - // alter func but keep name - catalog.alterFunction("db2", "func1", newFunc("func1").copy(className = "muhaha")) - assert(catalog.getFunction("db2", "func1").className == "muhaha") - // alter func and change name - catalog.alterFunction("db2", "func1", newFunc("funcky")) + catalog.renameFunction("db2", "func1", newName) + intercept[AnalysisException] { catalog.getFunction("db2", "func1") } + assert(catalog.getFunction("db2", newName).name == newName) + assert(catalog.getFunction("db2", newName).className == funcClass) + intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") } + } + + test("rename function when database does not exist") { + val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.getFunction("db2", "func1") + catalog.renameFunction("does_not_exist", "func1", "func5") } - assert(catalog.getFunction("db2", "funcky").className == funcClass) + } + + test("alter function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1").className == funcClass) + catalog.alterFunction("db2", newFunc("func1").copy(className = "muhaha")) + assert(catalog.getFunction("db2", "func1").className == "muhaha") + intercept[AnalysisException] { catalog.alterFunction("db2", newFunc("funcky")) } } test("alter function when database does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.alterFunction("does_not_exist", "func1", newFunc()) + catalog.alterFunction("does_not_exist", newFunc()) } } test("list functions") { val catalog = newBasicCatalog() - catalog.createFunction("db2", newFunc("func2"), ignoreIfExists = false) - catalog.createFunction("db2", newFunc("not_me"), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func2")) + catalog.createFunction("db2", newFunc("not_me")) assert(catalog.listFunctions("db2", "*").toSet == Set("func1", "func2", "not_me")) assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2")) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala new file mode 100644 index 0000000000000..21b9cfb820eaa --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.thrift.TException + +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NoSuchItemException +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.client.HiveClient + + +/** + * A persistent implementation of the system catalog using Hive. + * All public methods must be synchronized for thread-safety. + */ +private[spark] class HiveCatalog(client: HiveClient) extends Catalog with Logging { + import Catalog._ + + // Exceptions thrown by the hive client that we would like to wrap + private val clientExceptions = Set( + classOf[HiveException].getCanonicalName, + classOf[TException].getCanonicalName) + + /** + * Whether this is an exception thrown by the hive client that should be wrapped. + * + * Due to classloader isolation issues, pattern matching won't work here so we need + * to compare the canonical names of the exceptions, which we assume to be stable. + */ + private def isClientException(e: Throwable): Boolean = { + var temp: Class[_] = e.getClass + var found = false + while (temp != null && !found) { + found = clientExceptions.contains(temp.getCanonicalName) + temp = temp.getSuperclass + } + found + } + + /** + * Run some code involving `client` in a [[synchronized]] block and wrap certain + * exceptions thrown in the process in [[AnalysisException]]. + */ + private def withClient[T](body: => T): T = synchronized { + try { + body + } catch { + case e: NoSuchItemException => + throw new AnalysisException(e.getMessage) + case NonFatal(e) if isClientException(e) => + throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) + } + } + + private def requireDbMatches(db: String, table: CatalogTable): Unit = { + if (table.specifiedDatabase != Some(db)) { + throw new AnalysisException( + s"Provided database $db does not much the one specified in the " + + s"table definition (${table.specifiedDatabase.getOrElse("n/a")})") + } + } + + private def requireTableExists(db: String, table: String): Unit = { + withClient { getTable(db, table) } + } + + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase( + dbDefinition: CatalogDatabase, + ignoreIfExists: Boolean): Unit = withClient { + client.createDatabase(dbDefinition, ignoreIfExists) + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = withClient { + client.dropDatabase(db, ignoreIfNotExists, cascade) + } + + /** + * Alter a database whose name matches the one specified in `dbDefinition`, + * assuming the database exists. + * + * Note: As of now, this only supports altering database properties! + */ + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + val existingDb = getDatabase(dbDefinition.name) + if (existingDb.properties == dbDefinition.properties) { + logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + + s"the provided database properties are the same as the old ones. Hive does not " + + s"currently support altering other database fields.") + } + client.alterDatabase(dbDefinition) + } + + override def getDatabase(db: String): CatalogDatabase = withClient { + client.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = withClient { + client.getDatabaseOption(db).isDefined + } + + override def listDatabases(): Seq[String] = withClient { + client.listDatabases("*") + } + + override def listDatabases(pattern: String): Seq[String] = withClient { + client.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = withClient { + client.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable( + db: String, + tableDefinition: CatalogTable, + ignoreIfExists: Boolean): Unit = withClient { + requireDbExists(db) + requireDbMatches(db, tableDefinition) + client.createTable(tableDefinition, ignoreIfExists) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = withClient { + requireDbExists(db) + client.dropTable(db, table, ignoreIfNotExists) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + val newTable = client.getTable(db, oldName).copy(name = newName) + client.alterTable(oldName, newTable) + } + + /** + * Alter a table whose name that matches the one specified in `tableDefinition`, + * assuming the table exists. + * + * Note: As of now, this only supports altering table properties, serde properties, + * and num buckets! + */ + override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient { + requireDbMatches(db, tableDefinition) + requireTableExists(db, tableDefinition.name) + client.alterTable(tableDefinition) + } + + override def getTable(db: String, table: String): CatalogTable = withClient { + client.getTable(db, table) + } + + override def listTables(db: String): Seq[String] = withClient { + requireDbExists(db) + client.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = withClient { + requireDbExists(db) + client.listTables(db, pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = withClient { + requireTableExists(db, table) + client.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + parts: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withClient { + requireTableExists(db, table) + // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we + // need to implement it here ourselves. This is currently somewhat expensive because + // we make multiple synchronous calls to Hive for each partition we want to drop. + val partsToDrop = + if (ignoreIfNotExists) { + parts.filter { spec => + try { + getPartition(db, table, spec) + true + } catch { + // Filter out the partitions that do not actually exist + case _: AnalysisException => false + } + } + } else { + parts + } + if (partsToDrop.nonEmpty) { + client.dropPartitions(db, table, partsToDrop) + } + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = withClient { + client.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit = withClient { + client.alterPartitions(db, table, newParts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = withClient { + client.getPartition(db, table, spec) + } + + override def listPartitions( + db: String, + table: String): Seq[CatalogTablePartition] = withClient { + client.getAllPartitions(db, table) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction( + db: String, + funcDefinition: CatalogFunction): Unit = withClient { + client.createFunction(db, funcDefinition) + } + + override def dropFunction(db: String, name: String): Unit = withClient { + client.dropFunction(db, name) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + client.renameFunction(db, oldName, newName) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient { + client.alterFunction(db, funcDefinition) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = withClient { + client.getFunction(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = withClient { + client.listFunctions(db, pattern) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c222b006a0f32..3788736fd13c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -25,15 +25,16 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.Warehouse +import org.apache.hadoop.hive.metastore.{TableType => HiveTableType, Warehouse} import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.ql.metadata._ +import org.apache.hadoop.hive.ql.metadata.{Table => HiveTable, _} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ @@ -96,6 +97,8 @@ private[hive] object HiveSerDe { } } + +// TODO: replace this with o.a.s.sql.hive.HiveCatalog once we merge SQLContext and HiveContext private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) extends Catalog with Logging { @@ -107,16 +110,16 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) - private def getQualifiedTableName(tableIdent: TableIdentifier) = { + private def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { QualifiedTableName( tableIdent.database.getOrElse(client.currentDatabase).toLowerCase, tableIdent.table.toLowerCase) } - private def getQualifiedTableName(hiveTable: HiveTable) = { + private def getQualifiedTableName(t: CatalogTable): QualifiedTableName = { QualifiedTableName( - hiveTable.specifiedDatabase.getOrElse(client.currentDatabase).toLowerCase, - hiveTable.name.toLowerCase) + t.specifiedDatabase.getOrElse(client.currentDatabase).toLowerCase, + t.name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -175,7 +178,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... - val options = table.serdeProperties + val options = table.storage.serdeProperties val resolvedRelation = ResolvedDataSource( @@ -276,53 +279,54 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val tableType = if (isExternal) { tableProperties.put("EXTERNAL", "TRUE") - ExternalTable + CatalogTableType.EXTERNAL_TABLE } else { tableProperties.put("EXTERNAL", "FALSE") - ManagedTable + CatalogTableType.MANAGED_TABLE } val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) val dataSource = ResolvedDataSource( hive, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options) - def newSparkSQLSpecificMetastoreTable(): HiveTable = { - HiveTable( + def newSparkSQLSpecificMetastoreTable(): CatalogTable = { + CatalogTable( specifiedDatabase = Option(dbName), name = tblName, - schema = Nil, - partitionColumns = Nil, tableType = tableType, - properties = tableProperties.toMap, - serdeProperties = options) + schema = Nil, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + serdeProperties = options + ), + properties = tableProperties.toMap) } - def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { - def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { - schema.map { field => - HiveColumn( - name = field.name, - hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), - comment = "") - } - } - + def newHiveCompatibleMetastoreTable( + relation: HadoopFsRelation, + serde: HiveSerDe): CatalogTable = { assert(partitionColumns.isEmpty) assert(relation.partitionColumns.isEmpty) - HiveTable( + CatalogTable( specifiedDatabase = Option(dbName), name = tblName, - schema = schemaToHiveColumn(relation.schema), - partitionColumns = Nil, tableType = tableType, + storage = CatalogStorageFormat( + locationUri = Some(relation.paths.head), + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + serdeProperties = options + ), + schema = relation.schema.map { f => + CatalogColumn(f.name, HiveMetastoreTypes.toMetastoreType(f.dataType)) + }, properties = tableProperties.toMap, - serdeProperties = options, - location = Some(relation.paths.head), - viewText = None, // TODO We need to place the SQL string here. - inputFormat = serde.inputFormat, - outputFormat = serde.outputFormat, - serde = serde.serde) + viewText = None) // TODO: We need to place the SQL string here } // TODO: Support persisting partitioned data source relations in Hive compatible format @@ -379,7 +383,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte // specific way. try { logInfo(message) - client.createTable(table) + client.createTable(table, ignoreIfExists = false) } catch { case throwable: Throwable => val warningMessage = @@ -387,20 +391,20 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte s"it into Hive metastore in Spark SQL specific format." logWarning(warningMessage, throwable) val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() - client.createTable(sparkSqlSpecificTable) + client.createTable(sparkSqlSpecificTable, ignoreIfExists = false) } case (None, message) => logWarning(message) val hiveTable = newSparkSQLSpecificMetastoreTable() - client.createTable(hiveTable) + client.createTable(hiveTable, ignoreIfExists = false) } } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - new Path(new Path(client.getDatabase(dbName).location), tblName).toString + new Path(new Path(client.getDatabase(dbName).locationUri), tblName).toString } override def tableExists(tableIdent: TableIdentifier): Boolean = { @@ -420,7 +424,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte // Then, if alias is specified, wrap the table with a Subquery using the alias. // Otherwise, wrap the table with a Subquery using the table name. alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) - } else if (table.tableType == VirtualView) { + } else if (table.tableType == CatalogTableType.VIRTUAL_VIEW) { val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) alias match { // because hive use things like `_c0` to build the expanded text @@ -429,7 +433,8 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText)) } } else { - MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) + MetastoreRelation( + qualifiedTableName.database, qualifiedTableName.name, alias)(table, client, hive) } } @@ -602,16 +607,14 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val schema = if (table.schema.nonEmpty) { table.schema } else { - child.output.map { - attr => new HiveColumn( - attr.name, - HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + child.output.map { a => + CatalogColumn(a.name, HiveMetastoreTypes.toMetastoreType(a.dataType), a.nullable) } } val desc = table.copy(schema = schema) - if (hive.convertCTAS && table.serde.isEmpty) { + if (hive.convertCTAS && table.storage.serde.isEmpty) { // Do the conversion when spark.sql.hive.convertCTAS is true and the query // does not specify any storage format (file format and storage handler). if (table.specifiedDatabase.isDefined) { @@ -632,9 +635,9 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte child ) } else { - val desc = if (table.serde.isEmpty) { + val desc = if (table.storage.serde.isEmpty) { // add default serde - table.copy( + table.withNewStorage( serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) } else { table @@ -744,10 +747,13 @@ private[hive] case class InsertIntoHiveTable( } } -private[hive] case class MetastoreRelation - (databaseName: String, tableName: String, alias: Option[String]) - (val table: HiveTable) - (@transient private val sqlContext: SQLContext) +private[hive] case class MetastoreRelation( + databaseName: String, + tableName: String, + alias: Option[String]) + (val table: CatalogTable, + @transient private val client: HiveClient, + @transient private val sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { @@ -765,7 +771,12 @@ private[hive] case class MetastoreRelation override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil - @transient val hiveQlTable: Table = { + private def toHiveColumn(c: CatalogColumn): FieldSchema = { + new FieldSchema(c.name, c.dataType, c.comment.orNull) + } + + // TODO: merge this with HiveClientImpl#toHiveTable + @transient val hiveQlTable: HiveTable = { // We start by constructing an API table as Hive performs several important transformations // internally when converting an API table to a QL table. val tTable = new org.apache.hadoop.hive.metastore.api.Table() @@ -776,27 +787,31 @@ private[hive] case class MetastoreRelation tTable.setParameters(tableParameters) table.properties.foreach { case (k, v) => tableParameters.put(k, v) } - tTable.setTableType(table.tableType.name) + tTable.setTableType(table.tableType match { + case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE.toString + case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE.toString + case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE.toString + case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW.toString + }) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - tTable.setPartitionKeys( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + sd.setCols(table.schema.map(toHiveColumn).asJava) + tTable.setPartitionKeys(table.partitionColumns.map(toHiveColumn).asJava) - table.location.foreach(sd.setLocation) - table.inputFormat.foreach(sd.setInputFormat) - table.outputFormat.foreach(sd.setOutputFormat) + table.storage.locationUri.foreach(sd.setLocation) + table.storage.inputFormat.foreach(sd.setInputFormat) + table.storage.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - table.serde.foreach(serdeInfo.setSerializationLib) + table.storage.serde.foreach(serdeInfo.setSerializationLib) sd.setSerdeInfo(serdeInfo) val serdeParameters = new java.util.HashMap[String, String]() - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + table.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } serdeInfo.setParameters(serdeParameters) - new Table(tTable) + new HiveTable(tTable) } @transient override lazy val statistics: Statistics = Statistics( @@ -821,11 +836,11 @@ private[hive] case class MetastoreRelation // When metastore partition pruning is turned off, we cache the list of all partitions to // mimic the behavior of Spark < 1.5 - lazy val allPartitions = table.getAllPartitions + private lazy val allPartitions: Seq[CatalogTablePartition] = client.getAllPartitions(table) def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { - table.getPartitions(predicates) + client.getPartitionsByFilter(table, predicates) } else { allPartitions } @@ -834,23 +849,22 @@ private[hive] case class MetastoreRelation val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) - tPartition.setValues(p.values.asJava) + tPartition.setValues(p.spec.values.toList.asJava) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) + sd.setCols(table.schema.map(toHiveColumn).asJava) + p.storage.locationUri.foreach(sd.setLocation) + p.storage.inputFormat.foreach(sd.setInputFormat) + p.storage.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) // maps and lists should be set only after all elements are ready (see HIVE-7975) - serdeInfo.setSerializationLib(p.storage.serde) + p.storage.serde.foreach(serdeInfo.setSerializationLib) val serdeParameters = new java.util.HashMap[String, String]() - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + table.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } serdeInfo.setParameters(serdeParameters) @@ -877,10 +891,10 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: HiveColumn) { + implicit class SchemaAttribute(f: CatalogColumn) { def toAttribute: AttributeReference = AttributeReference( f.name, - HiveMetastoreTypes.toDataType(f.hiveType), + HiveMetastoreTypes.toDataType(f.dataType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) @@ -901,19 +915,22 @@ private[hive] case class MetastoreRelation val columnOrdinals = AttributeMap(attributes.zipWithIndex) override def inputFiles: Array[String] = { - val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + val partLocations = client + .getPartitionsByFilter(table, Nil) + .flatMap(_.storage.locationUri) + .toArray if (partLocations.nonEmpty) { partLocations } else { Array( - table.location.getOrElse( + table.storage.locationUri.getOrElse( sys.error(s"Could not get the location of ${table.qualifiedName}."))) } } override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) + MetastoreRelation(databaseName, tableName, alias)(table, client, sqlContext) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 752c037a842a8..58010513538f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.ParseUtils._ @@ -39,7 +40,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.types._ import org.apache.spark.sql.AnalysisException @@ -55,7 +55,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan { } private[hive] case class CreateTableAsSelect( - tableDesc: HiveTable, + tableDesc: CatalogTable, child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { @@ -63,14 +63,14 @@ private[hive] case class CreateTableAsSelect( override lazy val resolved: Boolean = tableDesc.specifiedDatabase.isDefined && tableDesc.schema.nonEmpty && - tableDesc.serde.isDefined && - tableDesc.inputFormat.isDefined && - tableDesc.outputFormat.isDefined && + tableDesc.storage.serde.isDefined && + tableDesc.storage.inputFormat.isDefined && + tableDesc.storage.outputFormat.isDefined && childrenResolved } private[hive] case class CreateViewAsSelect( - tableDesc: HiveTable, + tableDesc: CatalogTable, child: LogicalPlan, allowExisting: Boolean, replace: Boolean, @@ -193,7 +193,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging view: ASTNode, viewNameParts: ASTNode, query: ASTNode, - schema: Seq[HiveColumn], + schema: Seq[CatalogColumn], properties: Map[String, String], allowExist: Boolean, replace: Boolean): CreateViewAsSelect = { @@ -201,18 +201,20 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging val originalText = query.source - val tableDesc = HiveTable( + val tableDesc = CatalogTable( specifiedDatabase = dbName, name = viewName, + tableType = CatalogTableType.VIRTUAL_VIEW, schema = schema, - partitionColumns = Seq.empty[HiveColumn], + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + serdeProperties = Map.empty[String, String] + ), properties = properties, - serdeProperties = Map[String, String](), - tableType = VirtualView, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, + viewOriginalText = Some(originalText), viewText = Some(originalText)) // We need to keep the original SQL string so that if `spark.sql.nativeView` is @@ -314,8 +316,8 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging val schema = maybeColumns.map { cols => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. - nodeToColumns(cols, lowerCase = true).map(_.copy(hiveType = null)) - }.getOrElse(Seq.empty[HiveColumn]) + nodeToColumns(cols, lowerCase = true).map(_.copy(dataType = null)) + }.getOrElse(Seq.empty[CatalogColumn]) val properties = scala.collection.mutable.Map.empty[String, String] @@ -369,19 +371,23 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) // TODO add bucket support - var tableDesc: HiveTable = HiveTable( + var tableDesc: CatalogTable = CatalogTable( specifiedDatabase = dbName, name = tblName, - schema = Seq.empty[HiveColumn], - partitionColumns = Seq.empty[HiveColumn], - properties = Map[String, String](), - serdeProperties = Map[String, String](), - tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, - viewText = None) + tableType = + if (externalTable.isDefined) { + CatalogTableType.EXTERNAL_TABLE + } else { + CatalogTableType.MANAGED_TABLE + }, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + serdeProperties = Map.empty[String, String] + ), + schema = Seq.empty[CatalogColumn]) // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) @@ -392,9 +398,10 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) } - hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) - hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) - hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) + tableDesc = tableDesc.withNewStorage( + inputFormat = hiveSerDe.inputFormat.orElse(tableDesc.storage.inputFormat), + outputFormat = hiveSerDe.outputFormat.orElse(tableDesc.storage.outputFormat), + serde = hiveSerDe.serde.orElse(tableDesc.storage.serde)) children.collect { case list @ Token("TOK_TABCOLLIST", _) => @@ -440,13 +447,13 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging // TODO support the nullFormat case _ => assert(false) } - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) + tableDesc = tableDesc.withNewStorage( + serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text)) - tableDesc = tableDesc.copy(location = Option(location)) + tableDesc = tableDesc.withNewStorage(locationUri = Option(location)) case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.copy( + tableDesc = tableDesc.withNewStorage( serde = Option(unescapeSQLString(child.children.head.text))) if (child.numChildren == 2) { // This is based on the readProps(..) method in @@ -459,59 +466,59 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging .orNull (unescapeSQLString(prop), value) }.toMap - tableDesc = tableDesc.copy(serdeProperties = tableDesc.serdeProperties ++ serdeParams) + tableDesc = tableDesc.withNewStorage( + serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => child.text.toLowerCase(Locale.ENGLISH) match { case "orc" => - tableDesc = tableDesc.copy( + tableDesc = tableDesc.withNewStorage( inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( + if (tableDesc.storage.serde.isEmpty) { + tableDesc = tableDesc.withNewStorage( serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) } case "parquet" => - tableDesc = tableDesc.copy( + tableDesc = tableDesc.withNewStorage( inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( + if (tableDesc.storage.serde.isEmpty) { + tableDesc = tableDesc.withNewStorage( serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) } case "rcfile" => - tableDesc = tableDesc.copy( + tableDesc = tableDesc.withNewStorage( inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy(serde = - Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + if (tableDesc.storage.serde.isEmpty) { + tableDesc = tableDesc.withNewStorage( + serde = + Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) } case "textfile" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + tableDesc = tableDesc.withNewStorage( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) case "sequencefile" => - tableDesc = tableDesc.copy( + tableDesc = tableDesc.withNewStorage( inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) case "avro" => - tableDesc = tableDesc.copy( + tableDesc = tableDesc.withNewStorage( inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( + if (tableDesc.storage.serde.isEmpty) { + tableDesc = tableDesc.withNewStorage( serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) } @@ -522,23 +529,21 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging case Token("TOK_TABLESERIALIZER", Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) + tableDesc = tableDesc.withNewStorage(serde = Option(unquoteString(serdeName))) otherProps match { case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) + tableDesc = tableDesc.withNewStorage( + serdeProperties = tableDesc.storage.serdeProperties ++ getProperties(list)) case _ => } case Token("TOK_TABLEPROPERTIES", list :: Nil) => tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) case list @ Token("TOK_TABLEFILEFORMAT", _) => - tableDesc = tableDesc.copy( - inputFormat = - Option(unescapeSQLString(list.children.head.text)), - outputFormat = - Option(unescapeSQLString(list.children(1).text))) + tableDesc = tableDesc.withNewStorage( + inputFormat = Option(unescapeSQLString(list.children.head.text)), + outputFormat = Option(unescapeSQLString(list.children(1).text))) case Token("TOK_STORAGEHANDLER", _) => throw new AnalysisException( "CREATE TABLE AS SELECT cannot be used for a non-native table") @@ -678,15 +683,15 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging // This is based the getColumns methods in // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[HiveColumn] = { + protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[CatalogColumn] = { node.children.map(_.children).collect { case Token(rawColName, Nil) :: colTypeNode :: comment => - val colName = if (!lowerCase) rawColName - else rawColName.toLowerCase - HiveColumn( - cleanIdentifier(colName), - nodeToTypeString(colTypeNode), - comment.headOption.map(n => unescapeSQLString(n.text)).orNull) + val colName = if (!lowerCase) rawColName else rawColName.toLowerCase + CatalogColumn( + name = cleanIdentifier(colName), + dataType = nodeToTypeString(colTypeNode), + nullable = true, + comment.headOption.map(n => unescapeSQLString(n.text))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index f681cc67041a1..6a0a089fd1f44 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -18,67 +18,11 @@ package org.apache.spark.sql.hive.client import java.io.PrintStream -import java.util.{Map => JMap} -import javax.annotation.Nullable -import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression -private[hive] case class HiveDatabase(name: String, location: String) - -private[hive] abstract class TableType { val name: String } -private[hive] case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" } -private[hive] case object IndexTable extends TableType { override val name = "INDEX_TABLE" } -private[hive] case object ManagedTable extends TableType { override val name = "MANAGED_TABLE" } -private[hive] case object VirtualView extends TableType { override val name = "VIRTUAL_VIEW" } - -// TODO: Use this for Tables and Partitions -private[hive] case class HiveStorageDescriptor( - location: String, - inputFormat: String, - outputFormat: String, - serde: String, - serdeProperties: Map[String, String]) - -private[hive] case class HivePartition( - values: Seq[String], - storage: HiveStorageDescriptor) - -private[hive] case class HiveColumn(name: String, @Nullable hiveType: String, comment: String) -private[hive] case class HiveTable( - specifiedDatabase: Option[String], - name: String, - schema: Seq[HiveColumn], - partitionColumns: Seq[HiveColumn], - properties: Map[String, String], - serdeProperties: Map[String, String], - tableType: TableType, - location: Option[String] = None, - inputFormat: Option[String] = None, - outputFormat: Option[String] = None, - serde: Option[String] = None, - viewText: Option[String] = None) { - - @transient - private[client] var client: HiveClient = _ - - private[client] def withClient(ci: HiveClient): this.type = { - client = ci - this - } - - def database: String = specifiedDatabase.getOrElse(sys.error("database not resolved")) - - def isPartitioned: Boolean = partitionColumns.nonEmpty - - def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) - - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = - client.getPartitionsByFilter(this, predicates) - - // Hive does not support backticks when passing names to the client. - def qualifiedName: String = s"$database.$name" -} /** * An externally visible interface to the Hive client. This interface is shared across both the @@ -106,6 +50,9 @@ private[hive] trait HiveClient { /** Returns the names of all tables in the given database. */ def listTables(dbName: String): Seq[String] + /** Returns the names of tables in the given database that matches the given pattern. */ + def listTables(dbName: String, pattern: String): Seq[String] + /** Returns the name of the active database. */ def currentDatabase: String @@ -113,46 +60,133 @@ private[hive] trait HiveClient { def setCurrentDatabase(databaseName: String): Unit /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ - def getDatabase(name: String): HiveDatabase = { - getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) + final def getDatabase(name: String): CatalogDatabase = { + getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException(name)) } /** Returns the metadata for a given database, or None if it doesn't exist. */ - def getDatabaseOption(name: String): Option[HiveDatabase] + def getDatabaseOption(name: String): Option[CatalogDatabase] + + /** List the names of all the databases that match the specified pattern. */ + def listDatabases(pattern: String): Seq[String] /** Returns the specified table, or throws [[NoSuchTableException]]. */ - def getTable(dbName: String, tableName: String): HiveTable = { - getTableOption(dbName, tableName).getOrElse(throw new NoSuchTableException) + final def getTable(dbName: String, tableName: String): CatalogTable = { + getTableOption(dbName, tableName).getOrElse(throw new NoSuchTableException(dbName, tableName)) } - /** Returns the metadata for the specified table or None if it doens't exist. */ - def getTableOption(dbName: String, tableName: String): Option[HiveTable] + /** Returns the metadata for the specified table or None if it doesn't exist. */ + def getTableOption(dbName: String, tableName: String): Option[CatalogTable] /** Creates a view with the given metadata. */ - def createView(view: HiveTable): Unit + def createView(view: CatalogTable): Unit /** Updates the given view with new metadata. */ - def alertView(view: HiveTable): Unit + def alertView(view: CatalogTable): Unit /** Creates a table with the given metadata. */ - def createTable(table: HiveTable): Unit + def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit - /** Updates the given table with new metadata. */ - def alterTable(table: HiveTable): Unit + /** Drop the specified table. */ + def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean): Unit + + /** Alter a table whose name matches the one specified in `table`, assuming it exists. */ + final def alterTable(table: CatalogTable): Unit = alterTable(table.name, table) + + /** Updates the given table with new metadata, optionally renaming the table. */ + def alterTable(tableName: String, table: CatalogTable): Unit /** Creates a new database with the given name. */ - def createDatabase(database: HiveDatabase): Unit + def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit + + /** + * Drop the specified database, if it exists. + * + * @param name database to drop + * @param ignoreIfNotExists if true, do not throw error if the database does not exist + * @param cascade whether to remove all associated objects such as tables and functions + */ + def dropDatabase(name: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + + /** + * Alter a database whose name matches the one specified in `database`, assuming it exists. + */ + def alterDatabase(database: CatalogDatabase): Unit + + /** + * Create one or many partitions in the given table. + */ + def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit + + /** + * Drop one or many partitions in the given table. + * + * Note: Unfortunately, Hive does not currently provide a way to ignore this call if the + * partitions do not already exist. The seemingly relevant flag `ifExists` in + * [[org.apache.hadoop.hive.metastore.PartitionDropOptions]] is not read anywhere. + */ + def dropPartitions( + db: String, + table: String, + specs: Seq[Catalog.TablePartitionSpec]): Unit - /** Returns the specified paritition or None if it does not exist. */ + /** + * Rename one or many existing table partitions, assuming they exist. + */ + def renamePartitions( + db: String, + table: String, + specs: Seq[Catalog.TablePartitionSpec], + newSpecs: Seq[Catalog.TablePartitionSpec]): Unit + + /** + * Alter one or more table partitions whose specs match the ones specified in `newParts`, + * assuming the partitions exist. + */ + def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit + + /** Returns the specified partition, or throws [[NoSuchPartitionException]]. */ + final def getPartition( + dbName: String, + tableName: String, + spec: Catalog.TablePartitionSpec): CatalogTablePartition = { + getPartitionOption(dbName, tableName, spec).getOrElse { + throw new NoSuchPartitionException(dbName, tableName, spec) + } + } + + /** Returns the specified partition or None if it does not exist. */ + final def getPartitionOption( + db: String, + table: String, + spec: Catalog.TablePartitionSpec): Option[CatalogTablePartition] = { + getPartitionOption(getTable(db, table), spec) + } + + /** Returns the specified partition or None if it does not exist. */ def getPartitionOption( - hTable: HiveTable, - partitionSpec: JMap[String, String]): Option[HivePartition] + table: CatalogTable, + spec: Catalog.TablePartitionSpec): Option[CatalogTablePartition] + + /** Returns all partitions for the given table. */ + final def getAllPartitions(db: String, table: String): Seq[CatalogTablePartition] = { + getAllPartitions(getTable(db, table)) + } /** Returns all partitions for the given table. */ - def getAllPartitions(hTable: HiveTable): Seq[HivePartition] + def getAllPartitions(table: CatalogTable): Seq[CatalogTablePartition] /** Returns partitions filtered by predicates for the given table. */ - def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] + def getPartitionsByFilter( + table: CatalogTable, + predicates: Seq[Expression]): Seq[CatalogTablePartition] /** Loads a static partition into an existing table. */ def loadPartition( @@ -181,6 +215,29 @@ private[hive] trait HiveClient { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + /** Create a function in an existing database. */ + def createFunction(db: String, func: CatalogFunction): Unit + + /** Drop an existing function an the database. */ + def dropFunction(db: String, name: String): Unit + + /** Rename an existing function in the database. */ + def renameFunction(db: String, oldName: String, newName: String): Unit + + /** Alter a function whose name matches the one specified in `func`, assuming it exists. */ + def alterFunction(db: String, func: CatalogFunction): Unit + + /** Return an existing function in the database, assuming it exists. */ + final def getFunction(db: String, name: String): CatalogFunction = { + getFunctionOption(db, name).getOrElse(throw new NoSuchFunctionException(db, name)) + } + + /** Return an existing function in the database, or None if it doesn't exist. */ + def getFunctionOption(db: String, name: String): Option[CatalogFunction] + + /** Return the names of all functions that match the given pattern in the database. */ + def listFunctions(db: String, pattern: String): Seq[String] + /** Add a jar into class loader */ def addJar(path: String): Unit @@ -192,4 +249,5 @@ private[hive] trait HiveClient { /** Used for testing only. Removes all metadata from this instance of Hive. */ def reset(): Unit + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index cf1ff55c96fc9..7a007d2acc29c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,24 +18,25 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} -import java.util.{Map => JMap} import scala.collection.JavaConverters._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.{TableType => HTableType} -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} -import org.apache.hadoop.hive.ql.{metadata, Driver} -import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceUri} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkConf, SparkException} -import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException +import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -234,167 +235,184 @@ private[hive] class HiveClientImpl( if (getDatabaseOption(databaseName).isDefined) { state.setCurrentDatabase(databaseName) } else { - throw new NoSuchDatabaseException + throw new NoSuchDatabaseException(databaseName) } } - override def createDatabase(database: HiveDatabase): Unit = withHiveState { + override def createDatabase( + database: CatalogDatabase, + ignoreIfExists: Boolean): Unit = withHiveState { client.createDatabase( - new Database( + new HiveDatabase( database.name, - "", - new File(database.location).toURI.toString, - new java.util.HashMap), - true) + database.description, + database.locationUri, + database.properties.asJava), + ignoreIfExists) } - override def getDatabaseOption(name: String): Option[HiveDatabase] = withHiveState { + override def dropDatabase( + name: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = withHiveState { + client.dropDatabase(name, true, ignoreIfNotExists, cascade) + } + + override def alterDatabase(database: CatalogDatabase): Unit = withHiveState { + client.alterDatabase( + database.name, + new HiveDatabase( + database.name, + database.description, + database.locationUri, + database.properties.asJava)) + } + + override def getDatabaseOption(name: String): Option[CatalogDatabase] = withHiveState { Option(client.getDatabase(name)).map { d => - HiveDatabase( + CatalogDatabase( name = d.getName, - location = d.getLocationUri) + description = d.getDescription, + locationUri = d.getLocationUri, + properties = d.getParameters.asScala.toMap) } } + override def listDatabases(pattern: String): Seq[String] = withHiveState { + client.getDatabasesByPattern(pattern).asScala.toSeq + } + override def getTableOption( dbName: String, - tableName: String): Option[HiveTable] = withHiveState { - + tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") - - val hiveTable = Option(client.getTable(dbName, tableName, false)) - val converted = hiveTable.map { h => - - HiveTable( - name = h.getTableName, + Option(client.getTable(dbName, tableName, false)).map { h => + CatalogTable( specifiedDatabase = Option(h.getDbName), - schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - partitionColumns = h.getPartCols.asScala.map(f => - HiveColumn(f.getName, f.getType, f.getComment)), - properties = h.getParameters.asScala.toMap, - serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, + name = h.getTableName, tableType = h.getTableType match { - case HTableType.MANAGED_TABLE => ManagedTable - case HTableType.EXTERNAL_TABLE => ExternalTable - case HTableType.VIRTUAL_VIEW => VirtualView - case HTableType.INDEX_TABLE => IndexTable + case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL_TABLE + case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED_TABLE + case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX_TABLE + case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIRTUAL_VIEW }, - location = shim.getDataLocation(h), - inputFormat = Option(h.getInputFormatClass).map(_.getName), - outputFormat = Option(h.getOutputFormatClass).map(_.getName), - serde = Option(h.getSerializationLib), - viewText = Option(h.getViewExpandedText)).withClient(this) + schema = h.getCols.asScala.map(fromHiveColumn), + partitionColumns = h.getPartCols.asScala.map(fromHiveColumn), + sortColumns = Seq(), + numBuckets = h.getNumBuckets, + createTime = h.getTTable.getCreateTime.toLong * 1000, + lastAccessTime = h.getLastAccessTime.toLong * 1000, + storage = CatalogStorageFormat( + locationUri = shim.getDataLocation(h), + inputFormat = Option(h.getInputFormatClass).map(_.getName), + outputFormat = Option(h.getOutputFormatClass).map(_.getName), + serde = Option(h.getSerializationLib), + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap + ), + properties = h.getParameters.asScala.toMap, + viewOriginalText = Option(h.getViewOriginalText), + viewText = Option(h.getViewExpandedText)) } - converted } - private def toInputFormat(name: String) = - Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] - - private def toOutputFormat(name: String) = - Utils.classForName(name) - .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] - - private def toQlTable(table: HiveTable): metadata.Table = { - val qlTable = new metadata.Table(table.database, table.name) - - qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - qlTable.setPartCols( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } - table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } - - // set owner - qlTable.setOwner(conf.getUser) - // set create time - qlTable.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) - - table.location.foreach { loc => shim.setDataLocation(qlTable, loc) } - table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass) - table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass) - table.serde.foreach(qlTable.setSerializationLib) - - qlTable + override def createView(view: CatalogTable): Unit = withHiveState { + client.createTable(toHiveViewTable(view)) } - private def toViewTable(view: HiveTable): metadata.Table = { - // TODO: this is duplicated with `toQlTable` except the table type stuff. - val tbl = new metadata.Table(view.database, view.name) - tbl.setTableType(HTableType.VIRTUAL_VIEW) - tbl.setSerializationLib(null) - tbl.clearSerDeInfo() - - // TODO: we will save the same SQL string to original and expanded text, which is different - // from Hive. - tbl.setViewOriginalText(view.viewText.get) - tbl.setViewExpandedText(view.viewText.get) - - tbl.setFields(view.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - view.properties.foreach { case (k, v) => tbl.setProperty(k, v) } - - // set owner - tbl.setOwner(conf.getUser) - // set create time - tbl.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) - - tbl + override def alertView(view: CatalogTable): Unit = withHiveState { + client.alterTable(view.qualifiedName, toHiveViewTable(view)) } - override def createView(view: HiveTable): Unit = withHiveState { - client.createTable(toViewTable(view)) + override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState { + client.createTable(toHiveTable(table), ignoreIfExists) } - override def alertView(view: HiveTable): Unit = withHiveState { - client.alterTable(view.qualifiedName, toViewTable(view)) + override def dropTable( + dbName: String, + tableName: String, + ignoreIfNotExists: Boolean): Unit = withHiveState { + client.dropTable(dbName, tableName, true, ignoreIfNotExists) } - override def createTable(table: HiveTable): Unit = withHiveState { - val qlTable = toQlTable(table) - client.createTable(qlTable) + override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { + val hiveTable = toHiveTable(table) + // Do not use `table.qualifiedName` here because this may be a rename + val qualifiedTableName = s"${table.database}.$tableName" + client.alterTable(qualifiedTableName, hiveTable) } - override def alterTable(table: HiveTable): Unit = withHiveState { - val qlTable = toQlTable(table) - client.alterTable(table.qualifiedName, qlTable) + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = withHiveState { + val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists) + parts.foreach { s => + addPartitionDesc.addPartition(s.spec.asJava, s.storage.locationUri.orNull) + } + client.createPartitions(addPartitionDesc) + } + + override def dropPartitions( + db: String, + table: String, + specs: Seq[Catalog.TablePartitionSpec]): Unit = withHiveState { + // TODO: figure out how to drop multiple partitions in one call + specs.foreach { s => client.dropPartition(db, table, s.values.toList.asJava, true) } + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[Catalog.TablePartitionSpec], + newSpecs: Seq[Catalog.TablePartitionSpec]): Unit = withHiveState { + require(specs.size == newSpecs.size, "number of old and new partition specs differ") + val catalogTable = getTable(db, table) + val hiveTable = toHiveTable(catalogTable) + specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => + val hivePart = getPartitionOption(catalogTable, oldSpec) + .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) } + .getOrElse { throw new NoSuchPartitionException(db, table, oldSpec) } + client.renamePartition(hiveTable, oldSpec.asJava, hivePart) + } } - private def toHivePartition(partition: metadata.Partition): HivePartition = { - val apiPartition = partition.getTPartition - HivePartition( - values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), - storage = HiveStorageDescriptor( - location = apiPartition.getSd.getLocation, - inputFormat = apiPartition.getSd.getInputFormat, - outputFormat = apiPartition.getSd.getOutputFormat, - serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, - serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) + override def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit = withHiveState { + val hiveTable = toHiveTable(getTable(db, table)) + client.alterPartitions(table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) } override def getPartitionOption( - table: HiveTable, - partitionSpec: JMap[String, String]): Option[HivePartition] = withHiveState { - - val qlTable = toQlTable(table) - val qlPartition = client.getPartition(qlTable, partitionSpec, false) - Option(qlPartition).map(toHivePartition) + table: CatalogTable, + spec: Catalog.TablePartitionSpec): Option[CatalogTablePartition] = withHiveState { + val hiveTable = toHiveTable(table) + val hivePartition = client.getPartition(hiveTable, spec.asJava, false) + Option(hivePartition).map(fromHivePartition) } - override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getAllPartitions(client, qlTable).map(toHivePartition) + override def getAllPartitions(table: CatalogTable): Seq[CatalogTablePartition] = withHiveState { + val hiveTable = toHiveTable(table) + shim.getAllPartitions(client, hiveTable).map(fromHivePartition) } override def getPartitionsByFilter( - hTable: HiveTable, - predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) + table: CatalogTable, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { + val hiveTable = toHiveTable(table) + shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) } override def listTables(dbName: String): Seq[String] = withHiveState { client.getAllTables(dbName).asScala } + override def listTables(dbName: String, pattern: String): Seq[String] = withHiveState { + client.getTablesByPattern(dbName, pattern).asScala + } + /** * Runs the specified SQL query using Hive. */ @@ -508,6 +526,34 @@ private[hive] class HiveClientImpl( listBucketingEnabled) } + override def createFunction(db: String, func: CatalogFunction): Unit = withHiveState { + client.createFunction(toHiveFunction(func, db)) + } + + override def dropFunction(db: String, name: String): Unit = withHiveState { + client.dropFunction(db, name) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState { + val catalogFunc = getFunction(db, oldName).copy(name = newName) + val hiveFunc = toHiveFunction(catalogFunc, db) + client.alterFunction(db, oldName, hiveFunc) + } + + override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState { + client.alterFunction(db, func.name, toHiveFunction(func, db)) + } + + override def getFunctionOption( + db: String, + name: String): Option[CatalogFunction] = withHiveState { + Option(client.getFunction(db, name)).map(fromHiveFunction) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = withHiveState { + client.getFunctions(db, pattern).asScala + } + def addJar(path: String): Unit = { val uri = new Path(path).toUri val jarURL = if (uri.getScheme == null) { @@ -541,4 +587,97 @@ private[hive] class HiveClientImpl( client.dropDatabase(db, true, false, true) } } + + + /* -------------------------------------------------------- * + | Helper methods for converting to and from Hive classes | + * -------------------------------------------------------- */ + + private def toInputFormat(name: String) = + Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + + private def toOutputFormat(name: String) = + Utils.classForName(name) + .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] + + private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { + new HiveFunction( + f.name, + db, + f.className, + null, + PrincipalType.USER, + (System.currentTimeMillis / 1000).toInt, + FunctionType.JAVA, + List.empty[ResourceUri].asJava) + } + + private def fromHiveFunction(hf: HiveFunction): CatalogFunction = { + new CatalogFunction(hf.getFunctionName, hf.getClassName) + } + + private def toHiveColumn(c: CatalogColumn): FieldSchema = { + new FieldSchema(c.name, c.dataType, c.comment.orNull) + } + + private def fromHiveColumn(hc: FieldSchema): CatalogColumn = { + new CatalogColumn( + name = hc.getName, + dataType = hc.getType, + nullable = true, + comment = Option(hc.getComment)) + } + + private def toHiveTable(table: CatalogTable): HiveTable = { + val hiveTable = new HiveTable(table.database, table.name) + hiveTable.setTableType(table.tableType match { + case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE + case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE + case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE + case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW + }) + hiveTable.setFields(table.schema.map(toHiveColumn).asJava) + hiveTable.setPartCols(table.partitionColumns.map(toHiveColumn).asJava) + // TODO: set sort columns here too + hiveTable.setOwner(conf.getUser) + hiveTable.setNumBuckets(table.numBuckets) + hiveTable.setCreateTime((table.createTime / 1000).toInt) + hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) + table.storage.locationUri.foreach { loc => shim.setDataLocation(hiveTable, loc) } + table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) + table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) + table.storage.serde.foreach(hiveTable.setSerializationLib) + table.storage.serdeProperties.foreach { case (k, v) => hiveTable.setSerdeParam(k, v) } + table.properties.foreach { case (k, v) => hiveTable.setProperty(k, v) } + table.viewOriginalText.foreach { t => hiveTable.setViewOriginalText(t) } + table.viewText.foreach { t => hiveTable.setViewExpandedText(t) } + hiveTable + } + + private def toHiveViewTable(view: CatalogTable): HiveTable = { + val tbl = toHiveTable(view) + tbl.setTableType(HiveTableType.VIRTUAL_VIEW) + tbl.setSerializationLib(null) + tbl.clearSerDeInfo() + tbl + } + + private def toHivePartition( + p: CatalogTablePartition, + ht: HiveTable): HivePartition = { + new HivePartition(ht, p.spec.asJava, p.storage.locationUri.map { l => new Path(l) }.orNull) + } + + private def fromHivePartition(hp: HivePartition): CatalogTablePartition = { + val apiPartition = hp.getTPartition + CatalogTablePartition( + spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty), + storage = CatalogStorageFormat( + locationUri = Option(apiPartition.getSd.getLocation), + inputFormat = Option(apiPartition.getSd.getInputFormat), + outputFormat = Option(apiPartition.getSd.getOutputFormat), + serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 4c0aae6c04bd7..3f81c99c41e14 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} -import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** * Create table and insert the query result into it. @@ -33,7 +33,7 @@ import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} */ private[hive] case class CreateTableAsSelect( - tableDesc: HiveTable, + tableDesc: CatalogTable, query: LogicalPlan, allowExisting: Boolean) extends RunnableCommand { @@ -51,25 +51,25 @@ case class CreateTableAsSelect( import org.apache.hadoop.mapred.TextInputFormat val withFormat = - tableDesc.copy( + tableDesc.withNewStorage( inputFormat = - tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), + tableDesc.storage.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), outputFormat = - tableDesc.outputFormat + tableDesc.storage.outputFormat .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)), - serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName()))) + serde = tableDesc.storage.serde.orElse(Some(classOf[LazySimpleSerDe].getName))) val withSchema = if (withFormat.schema.isEmpty) { // Hive doesn't support specifying the column list for target table in CTAS // However we don't think SparkSQL should follow that. - tableDesc.copy(schema = - query.output.map(c => - HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null))) + tableDesc.copy(schema = query.output.map { c => + CatalogColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType)) + }) } else { withFormat } - hiveContext.catalog.client.createTable(withSchema) + hiveContext.catalog.client.createTable(withSchema, ignoreIfExists = false) // Get the Metastore Relation hiveContext.catalog.lookupRelation(tableIdentifier, None) match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 5da58a73e1e36..2914d03749321 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -21,11 +21,11 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, SQLBuilder} -import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{ HiveContext, HiveMetastoreTypes, SQLBuilder} /** * Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of @@ -34,7 +34,7 @@ import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} // TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is different // from Hive and may not work for some cases like create view on self join. private[hive] case class CreateViewAsSelect( - tableDesc: HiveTable, + tableDesc: CatalogTable, child: LogicalPlan, allowExisting: Boolean, orReplace: Boolean) extends RunnableCommand { @@ -72,7 +72,7 @@ private[hive] case class CreateViewAsSelect( Seq.empty[Row] } - private def prepareTable(sqlContext: SQLContext): HiveTable = { + private def prepareTable(sqlContext: SQLContext): CatalogTable = { val expandedText = if (sqlContext.conf.canonicalView) { try rebuildViewQueryString(sqlContext) catch { case NonFatal(e) => wrapViewTextWithSelect @@ -83,12 +83,16 @@ private[hive] case class CreateViewAsSelect( val viewSchema = { if (tableDesc.schema.isEmpty) { - childSchema.map { attr => - HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + childSchema.map { a => + CatalogColumn(a.name, HiveMetastoreTypes.toMetastoreType(a.dataType)) } } else { - childSchema.zip(tableDesc.schema).map { case (attr, col) => - HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + childSchema.zip(tableDesc.schema).map { case (a, col) => + CatalogColumn( + col.name, + HiveMetastoreTypes.toMetastoreType(a.dataType), + nullable = true, + col.comment) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index feb133d44898a..d316664241f1a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -205,7 +205,7 @@ case class InsertIntoHiveTable( val oldPart = catalog.client.getPartitionOption( catalog.client.getTable(table.databaseName, table.tableName), - partitionSpec.asJava) + partitionSpec) if (oldPart.isEmpty || !ifNotExists) { catalog.client.loadPartition( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala new file mode 100644 index 0000000000000..f73e7e2351447 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader} +import org.apache.spark.util.Utils + + +/** + * Test suite for the [[HiveCatalog]]. + */ +class HiveCatalogSuite extends CatalogTestCases { + + private val client: HiveClient = { + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion).createClient() + } + + protected override val tableInputFormat: String = + "org.apache.hadoop.mapred.SequenceFileInputFormat" + protected override val tableOutputFormat: String = + "org.apache.hadoop.mapred.SequenceFileOutputFormat" + + protected override def newUriForDatabase(): String = Utils.createTempDir().getAbsolutePath + + protected override def resetState(): Unit = client.reset() + + protected override def newEmptyCatalog(): Catalog = new HiveCatalog(client) + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 14a83d53904a6..f8764d4725f63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{QueryTest, Row, SaveMode, SQLConf} -import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} @@ -83,16 +83,16 @@ class DataSourceWithHiveMetastoreCatalogSuite } val hiveTable = catalog.client.getTable("default", "t") - assert(hiveTable.inputFormat === Some(inputFormat)) - assert(hiveTable.outputFormat === Some(outputFormat)) - assert(hiveTable.serde === Some(serde)) + assert(hiveTable.storage.inputFormat === Some(inputFormat)) + assert(hiveTable.storage.outputFormat === Some(outputFormat)) + assert(hiveTable.storage.serde === Some(serde)) - assert(!hiveTable.isPartitioned) - assert(hiveTable.tableType === ManagedTable) + assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.tableType === CatalogTableType.MANAGED_TABLE) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) @@ -114,16 +114,17 @@ class DataSourceWithHiveMetastoreCatalogSuite } val hiveTable = catalog.client.getTable("default", "t") - assert(hiveTable.inputFormat === Some(inputFormat)) - assert(hiveTable.outputFormat === Some(outputFormat)) - assert(hiveTable.serde === Some(serde)) + assert(hiveTable.storage.inputFormat === Some(inputFormat)) + assert(hiveTable.storage.outputFormat === Some(outputFormat)) + assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.tableType === ExternalTable) - assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) + assert(hiveTable.storage.locationUri === + Some(path.toURI.toString.stripSuffix(File.separator))) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) @@ -143,17 +144,16 @@ class DataSourceWithHiveMetastoreCatalogSuite """.stripMargin) val hiveTable = catalog.client.getTable("default", "t") - assert(hiveTable.inputFormat === Some(inputFormat)) - assert(hiveTable.outputFormat === Some(outputFormat)) - assert(hiveTable.serde === Some(serde)) + assert(hiveTable.storage.inputFormat === Some(inputFormat)) + assert(hiveTable.storage.outputFormat === Some(outputFormat)) + assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.isPartitioned === false) - assert(hiveTable.tableType === ExternalTable) - assert(hiveTable.partitionColumns.length === 0) + assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.dataType) === Seq("int", "string")) checkAnswer(table("t"), Row(1, "val_1")) assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 137dadd6c6bb3..e869c0e2bdb71 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -22,15 +22,15 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.catalyst.plans.logical.Generate -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable} class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { val parser = new HiveQl(SimpleParserConf()) - private def extractTableDesc(sql: String): (HiveTable, Boolean) = { + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) }.head @@ -53,28 +53,29 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { |AS SELECT * FROM src""".stripMargin val (desc, exists) = extractTableDesc(s1) - assert(exists == true) + assert(exists) assert(desc.specifiedDatabase == Some("mydb")) assert(desc.name == "page_view") - assert(desc.tableType == ExternalTable) - assert(desc.location == Some("/user/external/page_view")) + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.storage.locationUri == Some("/user/external/page_view")) assert(desc.schema == - HiveColumn("viewtime", "int", null) :: - HiveColumn("userid", "bigint", null) :: - HiveColumn("page_url", "string", null) :: - HiveColumn("referrer_url", "string", null) :: - HiveColumn("ip", "string", "IP Address of the User") :: - HiveColumn("country", "string", "country of origination") :: Nil) + CatalogColumn("viewtime", "int") :: + CatalogColumn("userid", "bigint") :: + CatalogColumn("page_url", "string") :: + CatalogColumn("referrer_url", "string") :: + CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: + CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) // TODO will be SQLText assert(desc.viewText == Option("This is the staging page view table")) assert(desc.partitionColumns == - HiveColumn("dt", "string", "date type") :: - HiveColumn("hour", "string", "hour of the day") :: Nil) - assert(desc.serdeProperties == + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.storage.serdeProperties == Map((serdeConstants.SERIALIZATION_FORMAT, "\054"), (serdeConstants.FIELD_DELIM, "\054"))) - assert(desc.inputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) } @@ -98,27 +99,27 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { |AS SELECT * FROM src""".stripMargin val (desc, exists) = extractTableDesc(s2) - assert(exists == true) + assert(exists) assert(desc.specifiedDatabase == Some("mydb")) assert(desc.name == "page_view") - assert(desc.tableType == ExternalTable) - assert(desc.location == Some("/user/external/page_view")) + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.storage.locationUri == Some("/user/external/page_view")) assert(desc.schema == - HiveColumn("viewtime", "int", null) :: - HiveColumn("userid", "bigint", null) :: - HiveColumn("page_url", "string", null) :: - HiveColumn("referrer_url", "string", null) :: - HiveColumn("ip", "string", "IP Address of the User") :: - HiveColumn("country", "string", "country of origination") :: Nil) + CatalogColumn("viewtime", "int") :: + CatalogColumn("userid", "bigint") :: + CatalogColumn("page_url", "string") :: + CatalogColumn("referrer_url", "string") :: + CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: + CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) // TODO will be SQLText assert(desc.viewText == Option("This is the staging page view table")) assert(desc.partitionColumns == - HiveColumn("dt", "string", "date type") :: - HiveColumn("hour", "string", "hour of the day") :: Nil) - assert(desc.serdeProperties == Map()) - assert(desc.inputFormat == Option("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.outputFormat == Option("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.serde == Option("parquet.hive.serde.ParquetHiveSerDe")) + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) } @@ -128,14 +129,15 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(exists == false) assert(desc.specifiedDatabase == None) assert(desc.name == "page_view") - assert(desc.tableType == ManagedTable) - assert(desc.location == None) - assert(desc.schema == Seq.empty[HiveColumn]) + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.storage.locationUri == None) + assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText - assert(desc.serdeProperties == Map()) - assert(desc.inputFormat == Option("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - assert(desc.serde.isEmpty) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) assert(desc.properties == Map()) } @@ -162,14 +164,14 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(exists == false) assert(desc.specifiedDatabase == None) assert(desc.name == "ctas2") - assert(desc.tableType == ManagedTable) - assert(desc.location == None) - assert(desc.schema == Seq.empty[HiveColumn]) + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.storage.locationUri == None) + assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText - assert(desc.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) - assert(desc.inputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d9e4b020fdfcc..0c288bdf8a681 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -25,9 +25,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -724,20 +724,25 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val tableName = "spark6655" withTable(tableName) { val schema = StructType(StructField("int", IntegerType, true) :: Nil) - val hiveTable = HiveTable( + val hiveTable = CatalogTable( specifiedDatabase = Some("default"), name = tableName, + tableType = CatalogTableType.MANAGED_TABLE, schema = Seq.empty, - partitionColumns = Seq.empty, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + serdeProperties = Map( + "path" -> catalog.hiveDefaultTableFilePath(TableIdentifier(tableName))) + ), properties = Map( "spark.sql.sources.provider" -> "json", "spark.sql.sources.schema" -> schema.json, - "EXTERNAL" -> "FALSE"), - tableType = ManagedTable, - serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(TableIdentifier(tableName)))) + "EXTERNAL" -> "FALSE")) - catalog.client.createTable(hiveTable) + catalog.client.createTable(hiveTable, ignoreIfExists = false) invalidateTable(tableName) val actualSchema = table(tableName).schema @@ -916,7 +921,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // As a proxy for verifying that the table was stored in Hive compatible format, we verify that // each column of the table is of native type StringType. assert(catalog.client.getTable("default", "not_skip_hive_metadata").schema - .forall(column => HiveMetastoreTypes.toDataType(column.hiveType) == StringType)) + .forall(column => HiveMetastoreTypes.toDataType(column.dataType) == StringType)) catalog.createDataSourceTable( tableIdent = TableIdentifier("skip_hive_metadata"), @@ -930,6 +935,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // As a proxy for verifying that the table was stored in SparkSQL format, we verify that // the table has a column type as array of StringType. assert(catalog.client.getTable("default", "skip_hive_metadata").schema - .forall(column => HiveMetastoreTypes.toDataType(column.hiveType) == ArrayType(StringType))) + .forall(column => HiveMetastoreTypes.toDataType(column.dataType) == ArrayType(StringType))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index c2c896e5f61bb..488f298981c86 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -26,9 +26,9 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle private def checkTablePath(dbName: String, tableName: String): Unit = { val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName) - val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName + val expectedPath = hiveContext.catalog.client.getDatabase(dbName).locationUri + "/" + tableName - assert(metastoreTable.serdeProperties("path") === expectedPath) + assert(metastoreTable.storage.serdeProperties("path") === expectedPath) } test(s"saveAsTable() to non-default database - with USE - Overwrite") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 1344a2cc4bd37..d850d522be297 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.hadoop.util.VersionInfo import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext @@ -60,8 +61,8 @@ class VersionsSuite extends SparkFunSuite with Logging { hadoopVersion = VersionInfo.getVersion, config = buildConf(), ivyPath = ivyPath).createClient() - val db = new HiveDatabase("default", "") - badClient.createDatabase(db) + val db = new CatalogDatabase("default", "desc", "loc", Map()) + badClient.createDatabase(db, ignoreIfExists = true) } private def getNestedMessages(e: Throwable): String = { @@ -116,29 +117,27 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: createDatabase") { - val db = HiveDatabase("default", "") - client.createDatabase(db) + val db = CatalogDatabase("default", "desc", "loc", Map()) + client.createDatabase(db, ignoreIfExists = true) } test(s"$version: createTable") { val table = - HiveTable( + CatalogTable( specifiedDatabase = Option("default"), name = "src", - schema = Seq(HiveColumn("key", "int", "")), - partitionColumns = Seq.empty, - properties = Map.empty, - serdeProperties = Map.empty, - tableType = ManagedTable, - location = None, - inputFormat = - Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), - outputFormat = - Some(classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), - serde = - Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName())) - - client.createTable(table) + tableType = CatalogTableType.MANAGED_TABLE, + schema = Seq(CatalogColumn("key", "int")), + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), + outputFormat = Some( + classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName()), + serdeProperties = Map.empty + )) + + client.createTable(table, ignoreIfExists = false) } test(s"$version: getTable") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index b91248bfb3fc0..37c01792d9c3f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -149,7 +149,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = if (relation.table.isPartitioned) { + val partValues = if (relation.table.partitionColumns.nonEmpty) { p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) } else { Seq.empty From 03e62aa3f6e16a271262c786be3d1542af79d3e4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 21 Feb 2016 15:27:07 -0800 Subject: [PATCH 1955/2089] [MINOR][DOCS] Fix typos in `configuration.md` and `hardware-provisioning.md` ## What changes were proposed in this pull request? This PR fixes some typos in the following documentation files. * `NOTICE`, `configuration.md`, and `hardware-provisioning.md`. ## How was the this patch tested? manual tests Author: Dongjoon Hyun Author: Dongjoon Hyun Closes #11289 from dongjoon-hyun/minor_fix_typos_notice_and_confdoc. --- docs/configuration.md | 10 +++++----- docs/hardware-provisioning.md | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index f2443e98574cf..568eca9b5fa95 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -249,7 +249,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading - classes in the the driver. This feature can be used to mitigate conflicts between Spark's + classes in the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. This is used in cluster mode only. @@ -373,7 +373,7 @@ Apart from these, the following properties are also available, and may be useful Reuse Python worker or not. If yes, it will use a fixed number of Python workers, does not need to fork() a Python process for every tasks. It will be very useful - if there is large broadcast, then the broadcast will not be needed to transfered + if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -1266,7 +1266,7 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. Putting a "*" in the list means any user can have the priviledge + help debug when things work. Putting a "*" in the list means any user can have the privilege of admin. @@ -1604,7 +1604,7 @@ Apart from these, the following properties are also available, and may be useful #### Deploy - + @@ -1693,7 +1693,7 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config # Overriding configuration directory To specify a different configuration directory other than the default "SPARK_HOME/conf", -you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) +you can set SPARK_CONF_DIR. Spark will use the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. # Inheriting Hadoop Cluster Configuration diff --git a/docs/hardware-provisioning.md b/docs/hardware-provisioning.md index 790220500a1b3..60ecb4f483afa 100644 --- a/docs/hardware-provisioning.md +++ b/docs/hardware-provisioning.md @@ -63,7 +63,7 @@ from the application's monitoring UI (`http://:4040`). # CPU Cores -Spark scales well to tens of CPU cores per machine because it performes minimal sharing between +Spark scales well to tens of CPU cores per machine because it performs minimal sharing between threads. You should likely provision at least **8-16 cores** per machine. Depending on the CPU cost of your workload, you may also need more: once data is in memory, most applications are either CPU- or network-bound. From 76bd98d914ecdf6ff0d73d9d8d221fe3403f8627 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sun, 21 Feb 2016 15:32:49 -0800 Subject: [PATCH 1956/2089] [SPARK-13405][STREAMING][TESTS] Make sure no messages leak to the next test ## What changes were proposed in this pull request? Fixed the test failure `org.apache.spark.sql.util.ContinuousQueryListenerSuite.event ordering`: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-maven-hadoop-2.6/202/testReport/junit/org.apache.spark.sql.util/ContinuousQueryListenerSuite/event_ordering/ ``` org.scalatest.exceptions.TestFailedException: Assert failed: : null equaled null onQueryTerminated called before onQueryStarted org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500) org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555) org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:466) org.apache.spark.sql.util.ContinuousQueryListenerSuite$QueryStatusCollector$$anonfun$onQueryTerminated$1.apply$mcV$sp(ContinuousQueryListenerSuite.scala:204) org.scalatest.concurrent.AsyncAssertions$Waiter.apply(AsyncAssertions.scala:349) org.apache.spark.sql.util.ContinuousQueryListenerSuite$QueryStatusCollector.onQueryTerminated(ContinuousQueryListenerSuite.scala:203) org.apache.spark.sql.execution.streaming.ContinuousQueryListenerBus.doPostEvent(ContinuousQueryListenerBus.scala:67) org.apache.spark.sql.execution.streaming.ContinuousQueryListenerBus.doPostEvent(ContinuousQueryListenerBus.scala:32) org.apache.spark.util.ListenerBus$class.postToAll(ListenerBus.scala:63) org.apache.spark.sql.execution.streaming.ContinuousQueryListenerBus.postToAll(ContinuousQueryListenerBus.scala:32) ``` In the previous codes, when the test `adding and removing listener` finishes, there may be still some QueryTerminated events in the listener bus queue. Then when `event ordering` starts to run, it may see these events and throw the above exception. This PR just added `waitUntilEmpty` in `after` to make sure all events be consumed after each test. ## How was the this patch tested? Jenkins tests. Author: Shixiong Zhu Closes #11275 from zsxwing/SPARK-13405. --- .../apache/spark/sql/util/ContinuousQueryListenerSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala index d6cc6ad86bf61..52783281abb00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -41,6 +41,8 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with sqlContext.streams.active.foreach(_.stop()) assert(sqlContext.streams.active.isEmpty) assert(addedListeners.isEmpty) + // Make sure we don't leak any events to the next test + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) } test("single listener") { From 0cbadf28c99721ba1ac22ac57beef9df998ea685 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sun, 21 Feb 2016 15:34:39 -0800 Subject: [PATCH 1957/2089] [SPARK-13271][SQL] Better error message if 'path' is not specified Improved the error message as per discussion in https://github.com/apache/spark/pull/11034#discussion_r52111238. Also made `path` and `metadataPath` in FileStreamSource case insensitive. Author: Shixiong Zhu Closes #11154 from zsxwing/path. --- .../execution/datasources/ResolvedDataSource.scala | 12 +++++++++--- .../org/apache/spark/sql/sources/interfaces.scala | 5 +++-- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index cefa8be0c6007..eec9070beed65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -156,7 +156,9 @@ object ResolvedDataSource extends Logging { } caseInsensitiveOptions.get("paths") .map(_.split("(? val hdfsPath = new Path(pathString) val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -197,7 +199,9 @@ object ResolvedDataSource extends Logging { } caseInsensitiveOptions.get("paths") .map(_.split("(? val hdfsPath = new Path(pathString) val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -260,7 +264,9 @@ object ResolvedDataSource extends Logging { // 3. It's OK that the output path doesn't exist yet; val caseInsensitiveOptions = new CaseInsensitiveMap(options) val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) + val path = new Path(caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + })) val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) path.makeQualified(fs.getUri, fs.getWorkingDirectory) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 428a313ca9dc2..87ea7f510e631 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -203,10 +203,11 @@ trait HadoopFsRelationProvider extends StreamSourceProvider { schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { - val path = parameters.getOrElse("path", { + val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) - val metadataPath = parameters.getOrElse("metadataPath", s"$path/_metadata") + val metadataPath = caseInsensitiveOptions.getOrElse("metadataPath", s"$path/_metadata") def dataFrameBuilder(files: Array[String]): DataFrame = { val relation = createRelation( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 0c288bdf8a681..cd85e23d8b960 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -548,7 +548,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv "org.apache.spark.sql.json", schema, Map.empty[String, String]) - }.getMessage.contains("key not found: path"), + }.getMessage.contains("'path' is not specified"), "We should complain that path is not specified.") } } From 0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459 Mon Sep 17 00:00:00 2001 From: Franklyn D'souza Date: Sun, 21 Feb 2016 16:58:17 -0800 Subject: [PATCH 1958/2089] [SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns. ## What changes were proposed in this pull request? This PR adds equality operators to UDT classes so that they can be correctly tested for dataType equality during union operations. This was previously causing `"AnalysisException: u"unresolved operator 'Union;""` when trying to unionAll two dataframes with UDT columns as below. ``` from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT from pyspark.sql import types schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)]) a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema) b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema) c = a.unionAll(b) ``` ## How was the this patch tested? Tested using two unit tests in sql/test.py and the DataFrameSuite. Additional information here : https://issues.apache.org/jira/browse/SPARK-13410 Author: Franklyn D'souza Closes #11279 from damnMeddlingKid/udt-union-all. --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ .../spark/sql/types/UserDefinedType.scala | 10 ++++++++++ .../spark/sql/test/ExamplePointUDT.scala | 7 ++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 16 ++++++++++++++++ 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e30aa0a796924..cc11c0f35cdc9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -601,6 +601,24 @@ def test_parquet_with_udt(self): point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_unionAll_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.sqlCtx.createDataFrame([row1], schema) + df2 = self.sqlCtx.createDataFrame([row2], schema) + + result = df1.unionAll(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index d7a2c23be8a9a..7664c30ee7650 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -86,6 +86,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { this.getClass == dataType.getClass override def sql: String = sqlType.sql + + override def equals(other: Any): Boolean = other match { + case that: UserDefinedType[_] => this.acceptsType(that) + case _ => false + } } /** @@ -112,4 +117,9 @@ private[sql] class PythonUserDefinedType( ("serializedClass" -> serializedPyClass) ~ ("sqlType" -> sqlType.jsonValue) } + + override def equals(other: Any): Boolean = other match { + case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT) + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 20a17ba82be9d..e2c9fc421b83f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -26,7 +26,12 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def equals(other: Any): Boolean = other match { + case that: ExamplePoint => this.x == that.x && this.y == that.y + case _ => false + } +} /** * User-defined type for [[ExamplePoint]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50a246489ee55..4930c485da83f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -112,6 +112,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("unionAll should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.unionAll(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) From 3d79f1065cd02133ad9dd4423c09b8c8b52b38e2 Mon Sep 17 00:00:00 2001 From: Robin East Date: Sun, 21 Feb 2016 17:07:17 -0800 Subject: [PATCH 1959/2089] [SPARK-3650][GRAPHX] Triangle Count handles reverse edges incorrectly jegonzal ankurdave please could you review ## What changes were proposed in this pull request? Reworking of jegonzal PR #2495 to address the issue identified in SPARK-3650. Code amended to use the convertToCanonicalEdges method. ## How was the this patch tested? Patch was tested using the unit tests created in PR #2495 Author: Robin East Author: Joseph E. Gonzalez Closes #11290 from insidedctm/spark-3650. --- .../org/apache/spark/graphx/GraphOps.scala | 14 +++++ .../spark/graphx/lib/TriangleCount.scala | 63 +++++++++++++------ .../apache/spark/graphx/GraphOpsSuite.scala | 15 +++++ .../spark/graphx/lib/TriangleCountSuite.scala | 7 ++- 4 files changed, 76 insertions(+), 23 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 97a82239a97ee..3e8c38530252b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -17,10 +17,15 @@ package org.apache.spark.graphx +import scala.reflect.{classTag, ClassTag} import scala.reflect.ClassTag import scala.util.Random +import org.apache.spark.HashPartitioner +import org.apache.spark.SparkContext._ import org.apache.spark.SparkException +import org.apache.spark.graphx.impl.EdgePartitionBuilder +import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.graphx.lib._ import org.apache.spark.rdd.RDD @@ -183,6 +188,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali } } + /** + * Remove self edges. + * + * @return a graph with all self edges removed + */ + def removeSelfEdges(): Graph[VD, ED] = { + graph.subgraph(epred = e => e.srcId != e.dstId) + } + /** * Join the vertices with an RDD and then apply a function from the * vertex and RDD entry to a new vertex value. The input table diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index a5d598053f9ca..51bcdf20dec46 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag import org.apache.spark.graphx._ +import org.apache.spark.graphx.PartitionStrategy.EdgePartition2D /** * Compute the number of triangles passing through each vertex. @@ -27,25 +28,47 @@ import org.apache.spark.graphx._ * The algorithm is relatively straightforward and can be computed in three steps: * *
          - *
        • Compute the set of neighbors for each vertex - *
        • For each edge compute the intersection of the sets and send the count to both vertices. - *
        • Compute the sum at each vertex and divide by two since each triangle is counted twice. + *
        • Compute the set of neighbors for each vertex
        • + *
        • For each edge compute the intersection of the sets and send the count to both vertices.
        • + *
        • Compute the sum at each vertex and divide by two since each triangle is counted twice.
        • *
        * - * Note that the input graph should have its edges in canonical direction - * (i.e. the `sourceId` less than `destId`). Also the graph must have been partitioned - * using [[org.apache.spark.graphx.Graph#partitionBy]]. + * There are two implementations. The default `TriangleCount.run` implementation first removes + * self cycles and canonicalizes the graph to ensure that the following conditions hold: + *
          + *
        • There are no self edges
        • + *
        • All edges are oriented src > dst
        • + *
        • There are no duplicate edges
        • + *
        + * However, the canonicalization procedure is costly as it requires repartitioning the graph. + * If the input data is already in "canonical form" with self cycles removed then the + * `TriangleCount.runPreCanonicalized` should be used instead. + * + * {{{ + * val canonicalGraph = graph.mapEdges(e => 1).removeSelfEdges().canonicalizeEdges() + * val counts = TriangleCount.runPreCanonicalized(canonicalGraph).vertices + * }}} + * */ object TriangleCount { def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { - // Remove redundant edges - val g = graph.groupEdges((a, b) => a).cache() + // Transform the edge data something cheap to shuffle and then canonicalize + val canonicalGraph = graph.mapEdges(e => true).removeSelfEdges().convertToCanonicalEdges() + // Get the triangle counts + val counters = runPreCanonicalized(canonicalGraph).vertices + // Join them bath with the original graph + graph.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) => + optCounter.getOrElse(0) + } + } + + def runPreCanonicalized[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Construct set representations of the neighborhoods val nbrSets: VertexRDD[VertexSet] = - g.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) => - val set = new VertexSet(4) + graph.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) => + val set = new VertexSet(nbrs.length) var i = 0 while (i < nbrs.size) { // prevent self cycle @@ -56,14 +79,14 @@ object TriangleCount { } set } + // join the sets with the graph - val setGraph: Graph[VertexSet, ED] = g.outerJoinVertices(nbrSets) { + val setGraph: Graph[VertexSet, ED] = graph.outerJoinVertices(nbrSets) { (vid, _, optSet) => optSet.getOrElse(null) } + // Edge function computes intersection of smaller vertex with larger vertex def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { - assert(ctx.srcAttr != null) - assert(ctx.dstAttr != null) val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { (ctx.srcAttr, ctx.dstAttr) } else { @@ -80,15 +103,15 @@ object TriangleCount { ctx.sendToSrc(counter) ctx.sendToDst(counter) } + // compute the intersection along edges val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _) // Merge counters with the graph and divide by two since each triangle is counted twice - g.outerJoinVertices(counters) { - (vid, _, optCounter: Option[Int]) => - val dblCount = optCounter.getOrElse(0) - // double count should be even (divisible by two) - assert((dblCount & 1) == 0) - dblCount / 2 + graph.outerJoinVertices(counters) { (_, _, optCounter: Option[Int]) => + val dblCount = optCounter.getOrElse(0) + // This algorithm double counts each triangle so the final count should be even + require(dblCount % 2 == 0, "Triangle count resulted in an invalid number of triangles.") + dblCount / 2 } - } // end of TriangleCount + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 57a8b95dd12e9..3967f6683de71 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -55,6 +55,21 @@ class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { } } + test("removeSelfEdges") { + withSpark { sc => + val edgeArray = Array((1 -> 2), (2 -> 3), (3 -> 3), (4 -> 3), (1 -> 1)) + .map { + case (a, b) => (a.toLong, b.toLong) + } + val correctEdges = edgeArray.filter { case (a, b) => a != b }.toSet + val graph = Graph.fromEdgeTuples(sc.parallelize(edgeArray), 1) + val canonicalizedEdges = graph.removeSelfEdges().edges.map(e => (e.srcId, e.dstId)) + .collect + assert(canonicalizedEdges.toSet.size === canonicalizedEdges.size) + assert(canonicalizedEdges.toSet === correctEdges) + } + } + test ("filter") { withSpark { sc => val n = 5 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 608e43cf3ff53..f19c3acdc85cf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -64,9 +64,9 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { val verts = triangleCount.vertices verts.collect().foreach { case (vid, count) => if (vid == 0) { - assert(count === 4) - } else { assert(count === 2) + } else { + assert(count === 1) } } } @@ -75,7 +75,8 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle with duplicate edges") { withSpark { sc => val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ - Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2) + Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ + Array(1L -> 0L, 1L -> 1L), 2) val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices From 55d6fdf22d1d6379180ac09f364c38982897d9ff Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 21 Feb 2016 19:10:17 -0800 Subject: [PATCH 1960/2089] [SPARK-13321][SQL] Support nested UNION in parser JIRA: https://issues.apache.org/jira/browse/SPARK-13321 The following SQL can not be parsed with current parser: SELECT `u_1`.`id` FROM (((SELECT `t0`.`id` FROM `default`.`t0`) UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 We should fix it. Author: Liang-Chi Hsieh Closes #11204 from viirya/nested-union. --- .../sql/catalyst/parser/SparkSqlParser.g | 20 ++++++ .../spark/sql/catalyst/CatalystQlSuite.scala | 64 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index e1908a8e03b30..9aeea69cd2d9b 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -2320,6 +2320,26 @@ regularBody[boolean topLevel] ) | selectStatement[topLevel] + | + (LPAREN selectStatement0[true]) => nestedSetOpSelectStatement[topLevel] + ; + +nestedSetOpSelectStatement[boolean topLevel] + : + ( + LPAREN s=selectStatement0[topLevel] RPAREN -> {$s.tree} + ) + (set=setOpSelectStatement[$nestedSetOpSelectStatement.tree, topLevel]) + -> {set == null}? + {$nestedSetOpSelectStatement.tree} + -> {$set.tree} + ; + +selectStatement0[boolean topLevel] + : + (selectStatement[true]) => selectStatement[topLevel] + | + (nestedSetOpSelectStatement[true]) => nestedSetOpSelectStatement[topLevel] ; selectStatement[boolean topLevel] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 42147f516f964..1d1c0cb216ed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -203,6 +203,70 @@ class CatalystQlSuite extends PlanTest { "from windowData") } + test("nesting UNION") { + val parsed = parser.parsePlan( + """ + |SELECT `u_1`.`id` FROM (((SELECT `t0`.`id` FROM `default`.`t0`) + |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL + |(SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 + """.stripMargin) + + val expected = Project( + UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, + Subquery("u_1", + Union( + Union( + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None)), + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None))))) + + comparePlans(parsed, expected) + + val parsedSame = parser.parsePlan( + """ + |SELECT `u_1`.`id` FROM ((SELECT `t0`.`id` FROM `default`.`t0`) + |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`) UNION ALL + |(SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 + """.stripMargin) + + comparePlans(parsedSame, expected) + + val parsed2 = parser.parsePlan( + """ + |SELECT `u_1`.`id` FROM ((((SELECT `t0`.`id` FROM `default`.`t0`) + |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL + |(SELECT `t0`.`id` FROM `default`.`t0`)) + |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 + """.stripMargin) + + val expected2 = Project( + UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, + Subquery("u_1", + Union( + Union( + Union( + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None)), + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), + Project( + UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, + UnresolvedRelation(TableIdentifier("t0", Some("default")), None))))) + + comparePlans(parsed2, expected2) + } + test("subquery") { parser.parsePlan("select (select max(b) from s) ss from t") parser.parsePlan("select * from t where a = (select b from s)") From 819b0ea029d6ef51e661a59e7672480f57322cbb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 21 Feb 2016 19:11:03 -0800 Subject: [PATCH 1961/2089] [SPARK-13381][SQL] Support for loading CSV with a single function call https://issues.apache.org/jira/browse/SPARK-13381 This PR adds the support to load CSV data directly by a single call with given paths. Also, I corrected this to refer all paths rather than the first path in schema inference, which JSON datasource dose. Several unitests were added for each functionality. Author: hyukjinkwon Closes #11262 from HyukjinKwon/SPARK-13381. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 11 +++++++++++ .../sql/execution/datasources/csv/CSVRelation.scala | 6 +++--- .../sql/execution/datasources/csv/CSVSuite.scala | 9 +++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 962fdadf1431d..20c861de23778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -344,6 +344,17 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { ) } + /** + * Loads a CSV file and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. To avoid going + * through the entire data once, specify the schema explicitly using [[schema]]. + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def csv(paths: String*): DataFrame = format("csv").load(paths : _*) + /** * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty * [[DataFrame]] if no paths are passed in. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 471ed0d5604b2..da945c44cde1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -private[csv] class CSVRelation( +private[sql] class CSVRelation( private val inputRDD: Option[RDD[String]], - override val paths: Array[String], + override val paths: Array[String] = Array.empty[String], private val maybeDataSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], private val parameters: Map[String, String]) @@ -127,7 +127,7 @@ private[csv] class CSVRelation( } private def inferSchema(paths: Array[String]): StructType = { - val rdd = baseRdd(Array(paths.head)) + val rdd = baseRdd(paths) val firstLine = findFirstLine(rdd) val firstRow = new LineCsvReader(params).parseLine(firstLine) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 9d1f4569ad5e9..7671bc106610a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -91,6 +91,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = false, checkTypes = false) } + test("simple csv test with calling another function to load") { + val cars = sqlContext + .read + .option("header", "false") + .csv(testFile(carsFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + test("simple csv test with type inference") { val cars = sqlContext .read From 9bf6a926a1071fa59f76ab12b23df8de618f1cae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 21 Feb 2016 19:37:35 -0800 Subject: [PATCH 1962/2089] [HOTFIX] Fix compilation break --- .../org/apache/spark/sql/catalyst/CatalystQlSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 1d1c0cb216ed0..812aa5acd8e56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.BooleanType import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { @@ -213,7 +212,7 @@ class CatalystQlSuite extends PlanTest { val expected = Project( UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, - Subquery("u_1", + SubqueryAlias("u_1", Union( Union( Project( @@ -247,7 +246,7 @@ class CatalystQlSuite extends PlanTest { val expected2 = Project( UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, - Subquery("u_1", + SubqueryAlias("u_1", Union( Union( Union( From 8a4ed78869e99c7de7062c3baa0ddb9d28c8e9b1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 21 Feb 2016 20:20:41 -0800 Subject: [PATCH 1963/2089] [SPARK-13379][MLLIB] Fix MLlib LogisticRegressionWithLBFGS set regularization incorrectly ## What changes were proposed in this pull request? Fix MLlib LogisticRegressionWithLBFGS regularization map as: ```SquaredL2Updater``` -> ```elasticNetParam = 0.0``` ```L1Updater``` -> ```elasticNetParam = 1.0``` cc dbtsai ## How was the this patch tested? unit tests Author: Yanbo Liang Closes #11258 from yanboliang/spark-13379. --- .../classification/LogisticRegression.scala | 4 +- .../LogisticRegressionSuite.scala | 348 ++++++++++++++++++ 2 files changed, 350 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index bf68e3edd7ed6..c3882606d7dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -444,8 +444,8 @@ class LogisticRegressionWithLBFGS createModel(weights, mlLogisticRegresionModel.intercept) } optimizer.getUpdater() match { - case x: SquaredL2Updater => runWithMlLogisitcRegression(1.0) - case x: L1Updater => runWithMlLogisitcRegression(0.0) + case x: SquaredL2Updater => runWithMlLogisitcRegression(0.0) + case x: L1Updater => runWithMlLogisitcRegression(1.0) case _ => super.run(input, initialWeights) } } else { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 8fef1316cd216..d140545e377f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -171,6 +172,37 @@ object LogisticRegressionSuite { class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { + + @transient var binaryDataset: RDD[LabeledPoint] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + coefficients, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + */ + binaryDataset = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( + coefficients, xMean, xVariance, true, nPoints, 42) + + sc.parallelize(testData, 2) + } + } + def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -555,6 +587,322 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w } } + /** + * From Spark 2.0, MLlib LogisticRegressionWithLBFGS will call the LogisticRegression + * implementation in ML to train model. We copies test cases from ML to guarantee + * they produce the same result. + */ + test("binary logistic regression with intercept without regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 + */ + val interceptR = 2.8366423 + val coefficientsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= coefficientsR relTol 1E-3) + + // Without regularization, with or without feature scaling will converge to the same solution. + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= coefficientsR relTol 1E-3) + } + + test("binary logistic regression without intercept without regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 + */ + val interceptR = 0.0 + val coefficientsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= coefficientsR relTol 1E-2) + + // Without regularization, with or without feature scaling should converge to the same solution. + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= coefficientsR relTol 1E-2) + } + + test("binary logistic regression with intercept with L1 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 + */ + val interceptR1 = -0.05627428 + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.weights ~= coefficientsR1 absTol 2E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.3722152 + data.V2 . + data.V3 . + data.V4 -0.1665453 + data.V5 . + */ + val interceptR2 = 0.3722152 + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.weights ~= coefficientsR2 absTol 1E-3) + } + + test("binary logistic regression without intercept with L1 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= coefficientsR1 absTol 1E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE, standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.08420782 + data.V5 . + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= coefficientsR2 absTol 1E-3) + } + + test("binary logistic regression with intercept with L2 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 + */ + val interceptR1 = 0.15021751 + val coefficientsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= coefficientsR1 relTol 1E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.48657516 + data.V2 -0.05155371 + data.V3 0.02301057 + data.V4 -0.11482896 + data.V5 -0.06266838 + */ + val interceptR2 = 0.48657516 + val coefficientsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= coefficientsR2 relTol 1E-3) + } + + test("binary logistic regression without intercept with L2 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= coefficientsR1 relTol 1E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE, standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.005679651 + data.V3 0.048967094 + data.V4 -0.093714016 + data.V5 -0.053314311 + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= coefficientsR2 relTol 1E-2) + } + } class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { From 39ff15457026767a4d9ff191174fc85e7907f489 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 22 Feb 2016 00:57:10 -0800 Subject: [PATCH 1964/2089] [SPARK-13426][CORE] Remove the support of SIMR ## What changes were proposed in this pull request? This PR removes the support of SIMR, since SIMR is not actively used and maintained for a long time, also is not supported from `SparkSubmit`, so here propose to remove it. ## How was the this patch tested? This patch is tested locally by running unit tests. Author: jerryshao Closes #11296 from jerryshao/SPARK-13426. --- .../scala/org/apache/spark/SparkContext.scala | 10 +-- .../cluster/SimrSchedulerBackend.scala | 75 ------------------- .../SparkContextSchedulerCreationSuite.scala | 9 +-- project/MimaExcludes.scala | 3 + 4 files changed, 5 insertions(+), 92 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index fa8c0f524bf3a..c001df31aac8e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -56,7 +56,7 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SimrSchedulerBackend, +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend @@ -2453,12 +2453,6 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case SIMR_REGEX(simrUrl) => - val scheduler = new TaskSchedulerImpl(sc) - val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) - scheduler.initialize(backend) - (backend, scheduler) - case zkUrl if zkUrl.startsWith("zk://") => logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + "in the form mesos://zk://host:port. Current Master URL will stop working in Spark 2.0.") @@ -2484,8 +2478,6 @@ private object SparkMasterRegex { val SPARK_REGEX = """spark://(.*)""".r // Regular expression for connection to Mesos cluster by mesos:// or mesos://zk:// url val MESOS_REGEX = """mesos://(.*)""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala deleted file mode 100644 index a298cf5ef9a6d..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rpc.RpcEndpointAddress -import org.apache.spark.scheduler.TaskSchedulerImpl - -private[spark] class SimrSchedulerBackend( - scheduler: TaskSchedulerImpl, - sc: SparkContext, - driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) - with Logging { - - val tmpPath = new Path(driverFilePath + "_tmp") - val filePath = new Path(driverFilePath) - - val maxCores = conf.getInt("spark.simr.executor.cores", 1) - - override def start() { - super.start() - - val driverUrl = RpcEndpointAddress( - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port").toInt, - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - - val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) - val fs = FileSystem.get(conf) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - - logInfo("Writing to HDFS file: " + driverFilePath) - logInfo("Writing Driver address: " + driverUrl) - logInfo("Writing Spark UI Address: " + appUIAddress) - - // Create temporary file to prevent race condition where executors get empty driverUrl file - val temp = fs.create(tmpPath, true) - temp.writeUTF(driverUrl) - temp.writeInt(maxCores) - temp.writeUTF(appUIAddress) - temp.close() - - // "Atomic" rename - fs.rename(tmpPath, filePath) - } - - override def stop() { - val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) - val fs = FileSystem.get(conf) - if (!fs.delete(new Path(driverFilePath), false)) { - logWarning(s"error deleting ${driverFilePath}") - } - super.stop() - } - -} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 52919c1ec0b1e..b96c937f02ecf 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import org.scalatest.PrivateMethodTester import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.cluster.SparkDeploySchedulerBackend import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.util.Utils @@ -115,13 +115,6 @@ class SparkContextSchedulerCreationSuite } } - test("simr") { - createTaskScheduler("simr://uri").backend match { - case s: SimrSchedulerBackend => // OK - case _ => fail() - } - } - test("local-cluster") { createTaskScheduler("local-cluster[3, 14, 1024]").backend match { case s: SparkDeploySchedulerBackend => // OK diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8f31a81ed01e7..97a1e8b433281 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -261,6 +261,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") + ) ++Seq( + // SPARK-13426 Remove the support of SIMR + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") ) case v if v.startsWith("1.6") => Seq( From 8f35d3eac9268127512851e52864e64b0bae2f33 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 22 Feb 2016 09:44:32 +0000 Subject: [PATCH 1965/2089] [SPARK-13186][STREAMING] migrate away from SynchronizedMap trait SynchronizedMap in package mutable is deprecated: Synchronization via traits is deprecated as it is inherently unreliable. Change to java.util.concurrent.ConcurrentHashMap instead. Author: Huaxin Gao Closes #11250 from huaxingao/spark__13186. --- .../streaming/kafka/KafkaStreamSuite.scala | 13 ++++--- .../kinesis/KinesisStreamSuite.scala | 38 ++++++++++--------- .../streaming/dstream/FileInputDStream.scala | 30 ++++++++------- .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingListenerSuite.scala | 12 ++++-- 5 files changed, 55 insertions(+), 41 deletions(-) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 797b07f80d8ee..6a35ac14a8f6f 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -65,19 +65,20 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] + val result = new mutable.HashMap[String, Long]() stream.map(_._2).countByValue().foreachRDD { r => - val ret = r.collect() - ret.toMap.foreach { kv => - val count = result.getOrElseUpdate(kv._1, 0) + kv._2 - result.put(kv._1, count) + r.collect().foreach { kv => + result.synchronized { + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } } } ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert(sent === result) + assert(result.synchronized { sent === result }) } } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ee6a5f0390d04..ca5d13da46e99 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -230,7 +230,6 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val awsCredentials = KinesisTestUtils.getAWSCredentials() val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] - with mutable.SynchronizedMap[Time, (Array[SequenceNumberRanges], Seq[Int])] val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, @@ -241,13 +240,16 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq - collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + collectedData.synchronized { + collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + } }) ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint ssc.start() - def numBatchesWithData: Int = collectedData.count(_._2._2.nonEmpty) + def numBatchesWithData: Int = + collectedData.synchronized { collectedData.count(_._2._2.nonEmpty) } def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty @@ -268,21 +270,23 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges // and return the same data - val times = collectedData.keySet - times.foreach { time => - val (arrayOfSeqNumRanges, data) = collectedData(time) - val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] - rdd shouldBe a [KinesisBackedBlockRDD[_]] - - // Verify the recovered sequence ranges - val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] - assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) - arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => - assert(expected.ranges.toSeq === found.ranges.toSeq) + collectedData.synchronized { + val times = collectedData.keySet + times.foreach { time => + val (arrayOfSeqNumRanges, data) = collectedData(time) + val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] + rdd shouldBe a[KinesisBackedBlockRDD[_]] + + // Verify the recovered sequence ranges + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) + arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => + assert(expected.ranges.toSeq === found.ranges.toSeq) + } + + // Verify the recovered data + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) } - - // Verify the recovered data - assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) } ssc.stop() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 1c2325409b53e..a25dada5ea4f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -117,7 +117,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Map of batch-time to selected file info for the remembered batches // This is a concurrent map because it's also accessed in unit tests @transient private[streaming] var batchTimeToSelectedFiles = - new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] + new mutable.HashMap[Time, Array[String]] // Set of files that were selected in the remembered batches @transient private var recentlySelectedFiles = new mutable.HashSet[String]() @@ -148,7 +148,9 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Find new files val newFiles = findNewFiles(validTime.milliseconds) logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) - batchTimeToSelectedFiles += ((validTime, newFiles)) + batchTimeToSelectedFiles.synchronized { + batchTimeToSelectedFiles += ((validTime, newFiles)) + } recentlySelectedFiles ++= newFiles val rdds = Some(filesToRDD(newFiles)) // Copy newFiles to immutable.List to prevent from being modified by the user @@ -162,14 +164,15 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( /** Clear the old time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) - batchTimeToSelectedFiles --= oldFiles.keys - recentlySelectedFiles --= oldFiles.values.flatten - logInfo("Cleared " + oldFiles.size + " old files that were older than " + - (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) - logDebug("Cleared files are:\n" + - oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + batchTimeToSelectedFiles.synchronized { + val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) + batchTimeToSelectedFiles --= oldFiles.keys + recentlySelectedFiles --= oldFiles.values.flatten + logInfo("Cleared " + oldFiles.size + " old files that were older than " + + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) + logDebug("Cleared files are:\n" + + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + } // Delete file mod times that weren't accessed in the last round of getting new files fileToModTime.clearOldValues(lastNewFileFindingTime - 1) } @@ -307,8 +310,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() - batchTimeToSelectedFiles = - new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] + batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() fileToModTime = new TimeStampedHashMap[String, Long](true) } @@ -324,7 +326,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( override def update(time: Time) { hadoopFiles.clear() - hadoopFiles ++= batchTimeToSelectedFiles + batchTimeToSelectedFiles.synchronized { hadoopFiles ++= batchTimeToSelectedFiles } } override def cleanup(time: Time) { } @@ -335,7 +337,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + f.mkString("[", ", ", "]") ) - batchTimeToSelectedFiles += ((t, f)) + batchTimeToSelectedFiles.synchronized { batchTimeToSelectedFiles += ((t, f)) } recentlySelectedFiles ++= f generatedRDDs += ((t, filesToRDD(f))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 1f0245a397570..dada495843ef8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -613,7 +613,8 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester def recordedFiles(ssc: StreamingContext): Seq[Int] = { val fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - val filenames = fileInputDStream.batchTimeToSelectedFiles.values.flatten + val filenames = fileInputDStream.batchTimeToSelectedFiles.synchronized + { fileInputDStream.batchTimeToSelectedFiles.values.flatten } filenames.map(_.split(File.separator).last.toInt).toSeq.sorted } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 66f47394c7ac1..6c60652cd69fd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -270,7 +270,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } _ssc.stop() - failureReasonsCollector.failureReasons.toMap + failureReasonsCollector.failureReasons.synchronized + { + failureReasonsCollector.failureReasons.toMap + } } /** Check if a sequence of numbers is in increasing order */ @@ -354,12 +357,15 @@ class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_O */ class FailureReasonsCollector extends StreamingListener { - val failureReasons = new HashMap[Int, String] with SynchronizedMap[Int, String] + val failureReasons = new HashMap[Int, String] override def onOutputOperationCompleted( outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { outputOperationCompleted.outputOperationInfo.failureReason.foreach { f => - failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + failureReasons.synchronized + { + failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + } } } } From ef1047fca789e5470b7b12974f0435d6d1c4f2d5 Mon Sep 17 00:00:00 2001 From: Yong Gang Cao Date: Mon, 22 Feb 2016 09:47:36 +0000 Subject: [PATCH 1966/2089] [SPARK-12153][SPARK-7617][MLLIB] add support of arbitrary length sentence and other tuning for Word2Vec add support of arbitrary length sentence by using the nature representation of sentences in the input. add new similarity functions and add normalization option for distances in synonym finding add new accessor for internal structure(the vocabulary and wordindex) for convenience need instructions about how to set value for the Since annotation for newly added public functions. 1.5.3? jira link: https://issues.apache.org/jira/browse/SPARK-12153 Author: Yong Gang Cao Author: Yong-Gang Cao Closes #10152 from ygcao/improvementForSentenceBoundary. --- .../apache/spark/mllib/feature/Word2Vec.scala | 74 +++++++++---------- .../spark/ml/feature/Word2VecSuite.scala | 2 +- python/pyspark/ml/feature.py | 12 +-- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index dee898827f30f..3241ebeb22c42 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -76,6 +76,18 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 + private var maxSentenceLength = 1000 + + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size (default: 1000) + */ + @Since("2.0.0") + def setMaxSentenceLength(maxSentenceLength: Int): this.type = { + this.maxSentenceLength = maxSentenceLength + this + } /** * Sets vector size (default: 100). @@ -146,7 +158,6 @@ class Word2Vec extends Serializable with Logging { private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 - private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ private var window = 5 @@ -156,7 +167,9 @@ class Word2Vec extends Serializable with Logging { @transient private var vocab: Array[VocabWord] = null @transient private var vocabHash = mutable.HashMap.empty[String, Int] - private def learnVocab(words: RDD[String]): Unit = { + private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = { + val words = dataset.flatMap(x => x) + vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) .filter(_._2 >= minCount) @@ -272,15 +285,14 @@ class Word2Vec extends Serializable with Logging { /** * Computes the vector representation of each word in vocabulary. - * @param dataset an RDD of words + * @param dataset an RDD of sentences, + * each sentence is expressed as an iterable collection of words * @return a Word2VecModel */ @Since("1.1.0") def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { - val words = dataset.flatMap(x => x) - - learnVocab(words) + learnVocab(dataset) createBinaryTree() @@ -289,25 +301,15 @@ class Word2Vec extends Serializable with Logging { val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - - val sentences: RDD[Array[Int]] = words.mapPartitions { iter => - new Iterator[Array[Int]] { - def hasNext: Boolean = iter.hasNext - - def next(): Array[Int] = { - val sentence = ArrayBuilder.make[Int] - var sentenceLength = 0 - while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { - val word = bcVocabHash.value.get(iter.next()) - word match { - case Some(w) => - sentence += w - sentenceLength += 1 - case None => - } - } - sentence.result() - } + // each partition is a collection of sentences, + // will be translated into arrays of Index integer + val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => + // Each sentence will map to 0 or more Array[Int] + sentenceIter.flatMap { sentence => + // Sentence of words, some of which map to a word index + val wordIndexes = sentence.flatMap(bcVocabHash.value.get) + // break wordIndexes into trunks of maxSentenceLength when has more + wordIndexes.grouped(maxSentenceLength).map(_.toArray) } } @@ -477,15 +479,6 @@ class Word2VecModel private[spark] ( this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } - private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { - require(v1.length == v2.length, "Vectors should have the same length") - val n = v1.length - val norm1 = blas.snrm2(n, v1, 1) - val norm2 = blas.snrm2(n, v2, 1) - if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 - } - override protected def formatVersion = "1.0" @Since("1.4.0") @@ -542,6 +535,7 @@ class Word2VecModel private[spark] ( // Need not divide with the norm of the given vector since it is constant. val cosVec = cosineVec.map(_.toDouble) var ind = 0 + val vecNorm = blas.snrm2(vectorSize, fVector, 1) while (ind < numWords) { val norm = wordVecNorms(ind) if (norm == 0.0) { @@ -551,12 +545,17 @@ class Word2VecModel private[spark] ( } ind += 1 } - wordList.zip(cosVec) + var topResults = wordList.zip(cosVec) .toSeq - .sortBy(- _._2) + .sortBy(-_._2) .take(num + 1) .tail - .toArray + if (vecNorm != 0.0f) { + topResults = topResults.map { case (word, cosVal) => + (word, cosVal / vecNorm) + } + } + topResults.toArray } /** @@ -568,6 +567,7 @@ class Word2VecModel private[spark] ( (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) } } + } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a73b565125668..f094c550e545a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -133,7 +133,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823) + val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) val (synonyms, similarity) = model.findSynonyms("a", 2).map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index d017a231886cb..464c9446f2f39 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1836,12 +1836,12 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has +----+--------------------+ ... >>> model.findSynonyms("a", 2).show() - +----+--------------------+ - |word| similarity| - +----+--------------------+ - | b| 0.16782984556103436| - | c|-0.46761559092107646| - +----+--------------------+ + +----+-------------------+ + |word| similarity| + +----+-------------------+ + | b| 0.2505344027513247| + | c|-0.6980510075367647| + +----+-------------------+ ... >>> model.transform(doc).head().model DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461]) From 1b144455b620861d8cc790d3fc69902717f14524 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 22 Feb 2016 09:50:51 +0000 Subject: [PATCH 1967/2089] [SPARK-13399][STREAMING] Fix checkpointsuite type erasure warnings ## What changes were proposed in this pull request? Change the checkpointsuite getting the outputstreams to explicitly be unchecked on the generic type so as to avoid the warnings. This only impacts test code. Alternatively we could encode the type tag in the TestOutputStreamWithPartitions and filter the type tag as well - but this is unnecessary since multiple testoutputstreams are not registered and the previous code was not actually checking this type. ## How was the this patch tested? unit tests (streaming/testOnly org.apache.spark.streaming.CheckpointSuite) Author: Holden Karau Closes #11286 from holdenk/SPARK-13399-checkpointsuite-type-erasure. --- .../spark/streaming/CheckpointSuite.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index dada495843ef8..ca716cf4e6f11 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -133,6 +133,17 @@ trait DStreamCheckpointTester { self: SparkFunSuite => new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) } + /** + * Get the first TestOutputStreamWithPartitions, does not check the provided generic type. + */ + protected def getTestOutputStream[V: ClassTag](streams: Array[DStream[_]]): + TestOutputStreamWithPartitions[V] = { + streams.collect { + case ds: TestOutputStreamWithPartitions[V @unchecked] => ds + }.head + } + + protected def generateOutput[V: ClassTag]( ssc: StreamingContext, targetBatchTime: Time, @@ -150,9 +161,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => clock.setTime(targetBatchTime.milliseconds) logInfo("Manual clock after advancing = " + clock.getTimeMillis()) - val outputStream = ssc.graph.getOutputStreams().filter { dstream => - dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] - }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams()) eventually(timeout(10 seconds)) { ssc.awaitTerminationOrTimeout(10) @@ -908,9 +917,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams().filter { dstream => - dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] - }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams()) outputStream.output.asScala.map(_.flatten) } } From 024482bf51e8158eed08a7dc0758f585baf86e1f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 22 Feb 2016 09:52:07 +0000 Subject: [PATCH 1968/2089] [MINOR][DOCS] Fix all typos in markdown files of `doc` and similar patterns in other comments ## What changes were proposed in this pull request? This PR tries to fix all typos in all markdown files under `docs` module, and fixes similar typos in other comments, too. ## How was the this patch tested? manual tests. Author: Dongjoon Hyun Closes #11300 from dongjoon-hyun/minor_fix_typos. --- R/pkg/R/functions.R | 6 +++--- R/pkg/R/sparkR.R | 2 +- .../java/org/apache/spark/util/sketch/CountMinSketch.java | 2 +- core/src/main/scala/org/apache/spark/CacheManager.scala | 2 +- .../scala/org/apache/spark/rdd/OrderedRDDFunctions.scala | 2 +- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/deploy/history/HistoryServerSuite.scala | 6 +++--- docs/ml-classification-regression.md | 2 +- docs/ml-features.md | 6 +++--- docs/ml-guide.md | 2 +- docs/mllib-clustering.md | 6 +++--- docs/mllib-evaluation-metrics.md | 6 +++--- docs/mllib-frequent-pattern-mining.md | 2 +- docs/monitoring.md | 2 +- docs/programming-guide.md | 2 +- docs/running-on-mesos.md | 4 ++-- docs/spark-standalone.md | 2 +- docs/sql-programming-guide.md | 2 +- docs/streaming-flume-integration.md | 2 +- docs/streaming-kinesis-integration.md | 4 ++-- docs/streaming-programming-guide.md | 2 +- .../main/scala/org/apache/spark/graphx/impl/GraphImpl.scala | 2 +- .../org/apache/spark/graphx/util/GraphGenerators.scala | 2 +- .../main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2 +- python/pyspark/sql/functions.py | 6 +++--- .../sql/catalyst/analysis/DistinctAggregationRewriter.scala | 2 +- .../expressions/aggregate/HyperLogLogPlusPlus.scala | 2 +- .../expressions/codegen/GenerateMutableProjection.scala | 2 +- .../scala/org/apache/spark/sql/RandomDataGenerator.scala | 4 ++-- .../spark/sql/execution/datasources/WriterContainer.scala | 2 +- .../spark/sql/execution/exchange/ExchangeCoordinator.scala | 2 +- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/StreamTest.scala | 4 ++-- .../sql/hive/thriftserver/HiveThriftServer2Suites.scala | 4 ++-- .../scala/org/apache/spark/streaming/util/StateMap.scala | 2 +- .../apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala | 2 +- 36 files changed, 55 insertions(+), 55 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 8f8651c295eec..e5521f3cffadf 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1962,7 +1962,7 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' shiftLeft #' -#' Shift the the given value numBits left. If the given value is a long value, this function +#' Shift the given value numBits left. If the given value is a long value, this function #' will return a long value else it will return an integer value. #' #' @family math_funcs @@ -1980,7 +1980,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the the given value numBits right. If the given value is a long value, it will return +#' Shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @family math_funcs @@ -1998,7 +1998,7 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' shiftRightUnsigned #' -#' Unsigned shift the the given value numBits right. If the given value is a long value, +#' Unsigned shift the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. #' #' @family math_funcs diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d2bfad553104f..3e9eafc7f5b90 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -299,7 +299,7 @@ sparkRHive.init <- function(jsc = NULL) { #' #' @param sc existing spark context #' @param groupid the ID to be assigned to job groups -#' @param description description for the the job group ID +#' @param description description for the job group ID #' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation #' @examples #'\dontrun{ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 48f98680f48ca..2c9aa93582e6f 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -39,7 +39,7 @@ * Suppose you want to estimate the number of times an element {@code x} has appeared in a data * stream so far. With probability {@code delta}, the estimate of this frequency is within the * range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the - * total count of items have appeared the the data stream so far. + * total count of items have appeared the data stream so far. * * Under the cover, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array * with depth {@code d} and width {@code w}, where diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 923ff411ce252..1ec9ba7755725 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -120,7 +120,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { * The effective storage level refers to the level that actually specifies BlockManager put * behavior, not the level originally specified by the user. This is mainly for forcing a * MEMORY_AND_DISK partition to disk if there is not enough room to unroll the partition, - * while preserving the the original semantics of the RDD as specified by the application. + * while preserving the original semantics of the RDD as specified by the application. */ private def putInBlockManager[T]( key: BlockId, diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index d71bb63000904..2096a37de9dc7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -76,7 +76,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, } /** - * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`. + * Returns an RDD containing only the elements in the inclusive range `lower` to `upper`. * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be * performed efficiently by only scanning the partitions that might contain matching elements. * Otherwise, a standard `filter` is applied to all partitions. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 379dc14ad7cd1..ba773e1e7b5a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -655,7 +655,7 @@ class DAGScheduler( /** * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter - * can be used to block until the the job finishes executing or can be used to cancel the job. + * can be used to block until the job finishes executing or can be used to cancel the job. * This method is used for adaptive query planning, to run map stages and look at statistics * about their outputs before submitting downstream stages. * diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4b05469c42055..e5cd2eddba1e8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json * metrics api to a set of known "golden files". If new endpoints / parameters are added, - * cases should be added to this test suite. The expected outcomes can be genered by running + * cases should be added to this test suite. The expected outcomes can be generated by running * the HistoryServerSuite.main. Note that this will blindly generate new expectation files matching * the current behavior -- the developer must verify that behavior is correct. * @@ -274,12 +274,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers implicit val webDriver: WebDriver = new HtmlUnitDriver implicit val formats = org.json4s.DefaultFormats - // this test dir is explictly deleted on successful runs; retained for diagnostics when + // this test dir is explicitly deleted on successful runs; retained for diagnostics when // not val logDir = Utils.createDirectory(System.getProperty("java.io.tmpdir", "logs")) // a new conf is used with the background thread set and running at its fastest - // alllowed refresh rate (1Hz) + // allowed refresh rate (1Hz) val myConf = new SparkConf() .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) .set("spark.eventLog.dir", logDir.getAbsolutePath) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 9569a06472cbf..45155c8ad1f12 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -252,7 +252,7 @@ Nodes in the output layer use softmax function: \]` The number of nodes `$N$` in the output layer corresponds to the number of classes. -MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. +MLPC employs backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. **Example** diff --git a/docs/ml-features.md b/docs/ml-features.md index 5809f65d637e4..68d3ea2971172 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -185,7 +185,7 @@ for more details on the API.
        Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Tokenizer) and -the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) +the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. {% include_example python/ml/tokenizer_example.py %} @@ -459,7 +459,7 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. -Additionaly, there are two strategies regarding how `StringIndexer` will handle +Additionally, there are two strategies regarding how `StringIndexer` will handle unseen labels when you have fit a `StringIndexer` on one dataset and then use it to transform another: @@ -779,7 +779,7 @@ for more details on the API. * `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; Otherwise, values outside the splits specified will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`. -Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potenial out of Bucketizer bounds exception. +Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potential out of Bucketizer bounds exception. Note also that the splits that you provided have to be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`. diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 1770aabf6f5da..8eee2fb674881 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -628,7 +628,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetricName` +for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` method in each of these evaluators. The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d0be03286849a..8e724fbf068b0 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -300,7 +300,7 @@ for i in range(2): ## Power iteration clustering (PIC) Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a -graph given pairwise similarties as edge properties, +graph given pairwise similarities as edge properties, described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. @@ -786,7 +786,7 @@ This example shows how to estimate clusters on streaming data.
        Refer to the [`StreamingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.StreamingKMeans) for details on the API. -First we import the neccessary classes. +First we import the necessary classes. {% highlight scala %} @@ -837,7 +837,7 @@ ssc.awaitTermination()
        Refer to the [`StreamingKMeans` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.StreamingKMeans) for more details on the API. -First we import the neccessary classes. +First we import the necessary classes. {% highlight python %} from pyspark.mllib.linalg import Vectors diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 774826c2703f8..a269dbf030e7c 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -67,7 +67,7 @@ plots (recall, false positive rate) points.
        - + @@ -360,7 +360,7 @@ $$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{ca **Examples** -The following code snippets illustrate how to evaluate the performance of a multilabel classifer. The examples +The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples use the fake prediction and label data for multilabel classification that is shown below. Document predictions: @@ -558,7 +558,7 @@ variable from a number of independent variables. - + diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 2c8a8f236163f..a7b55dc5e5668 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -135,7 +135,7 @@ pattern mining problem. included in the results. * `maxLocalProjDBSize`: the maximum number of items allowed in a prefix-projected database before local iterative processing of the - projected databse begins. This parameter should be tuned with respect + projected database begins. This parameter should be tuned with respect to the size of your executors. **Examples** diff --git a/docs/monitoring.md b/docs/monitoring.md index c37f6fb20dd6e..c139e1cb5ac5f 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -108,7 +108,7 @@ The history server can be configured as follows: diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 3de72bc016dd4..fd94c34d1638d 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -335,7 +335,7 @@ By default, standalone scheduling clusters are resilient to Worker failures (ins **Overview** -Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. +Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/trunk/zookeeperStarted.html). diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d246100f3e6aa..c4d277f9bf958 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1372,7 +1372,7 @@ Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation ru 1. The reconciled schema contains exactly those fields defined in Hive metastore schema. - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. - - Any fileds that only appear in the Hive metastore schema are added as nullable field in the + - Any fields that only appear in the Hive metastore schema are added as nullable field in the reconciled schema. #### Metadata Refreshing diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index e2d589b843f99..8eeeee75dbf40 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -30,7 +30,7 @@ See the [Flume's documentation](https://flume.apache.org/documentation.html) for configuring Flume agents. #### Configuring Spark Streaming Application -1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 5f5e2b9087cc9..2a868e8bca268 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -95,7 +95,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + - `streamingContext`: StreamingContext containing an application name used by Kinesis to tie this Kinesis application to the Kinesis stream - `[Kinesis app name]`: The application name that will be used to checkpoint the Kinesis sequence numbers in DynamoDB table. @@ -216,6 +216,6 @@ de-aggregate records during consumption. - Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. -- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable. +- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPositionInStream.LATEST). This is configurable. - InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). - InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 4d1932bc8c3e1..5d67a0a9a986a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -158,7 +158,7 @@ JavaReceiverInputDStream lines = jssc.socketTextStream("localhost", 9999 {% endhighlight %} This `lines` DStream represents the stream of data that will be received from the data -server. Each record in this stream is a line of text. Then, we want to split the the lines by +server. Each record in this stream is a line of text. Then, we want to split the lines by space into words. {% highlight java %} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index c5cb533b13a07..699731b360b59 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -266,7 +266,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } } - /** Test whether the closure accesses the the attribute with name `attrName`. */ + /** Test whether the closure accesses the attribute with name `attrName`. */ private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = { try { BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 280b6c5578fe5..95522299f02bb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -166,7 +166,7 @@ object GraphGenerators extends Logging { } /** - * This method recursively subdivides the the adjacency matrix into quadrants + * This method recursively subdivides the adjacency matrix into quadrants * until it picks a single cell. The naming conventions in this paper match * those of the R-MAT paper. There are a power of 2 number of nodes in the graph. * The adjacency matrix looks like: diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 8f49423af89ee..4be4d6abed6b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -1301,7 +1301,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { /** * Partitioner used by ALS. We requires that getPartition is a projection. That is, for any key k, - * we have getPartition(getPartition(k)) = getPartition(k). Since the the default HashPartitioner + * we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner * satisfies this requirement, we simply use a type alias here. */ private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fdae05d98ceb6..6894c2733828c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -479,7 +479,7 @@ def round(col, scale=0): @since(1.5) def shiftLeft(col, numBits): - """Shift the the given value numBits left. + """Shift the given value numBits left. >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() [Row(r=42)] @@ -490,7 +490,7 @@ def shiftLeft(col, numBits): @since(1.5) def shiftRight(col, numBits): - """Shift the the given value numBits right. + """Shift the given value numBits right. >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() [Row(r=21)] @@ -502,7 +502,7 @@ def shiftRight(col, numBits): @since(1.5) def shiftRightUnsigned(col, numBits): - """Unsigned shift the the given value numBits right. + """Unsigned shift the given value numBits right. >>> df = sqlContext.createDataFrame([(-42,)], ['a']) >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index b49885d469c0f..7518946a943f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -88,7 +88,7 @@ import org.apache.spark.sql.types.IntegerType * this aggregate consists of the original group by clause, all the requested distinct columns * and the group id. Both de-duplication of distinct column and the aggregation of the * non-distinct group take advantage of the fact that we group by the group id (gid) and that we - * have nulled out all non-relevant columns for the the given group. + * have nulled out all non-relevant columns the given group. * 3. Aggregating the distinct groups and combining this with the results of the non-distinct * aggregation. In this step we use the group id to filter the inputs for the aggregate * functions. The result of the non-distinct group are 'aggregated' by using the first operator, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index ec833d6789e85..a474017221721 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -238,7 +238,7 @@ case class HyperLogLogPlusPlus( diff * diff } - // Keep moving bounds as long as the the (exclusive) high bound is closer to the estimate than + // Keep moving bounds as long as the (exclusive) high bound is closer to the estimate than // the lower (inclusive) bound. var low = math.max(nearestEstimateIndex - K + 1, 0) var high = math.min(low + K, numEstimates) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5b4dc8df8622b..9abe92b1e7f82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } } - // Evaluate all the the subexpressions. + // Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") val updates = validExpr.zip(index).map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 7c173cbceefed..8207d64798bdd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -148,7 +148,7 @@ object RandomDataGenerator { // for "0001-01-01 00:00:00.000000". We need to find a // number that is greater or equals to this number as a valid timestamp value. while (milliseconds < -62135740800000L) { - // 253402329599999L is the the number of milliseconds since + // 253402329599999L is the number of milliseconds since // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". milliseconds = rand.nextLong() % 253402329599999L } @@ -163,7 +163,7 @@ object RandomDataGenerator { // for "0001-01-01 00:00:00.000000". We need to find a // number that is greater or equals to this number as a valid timestamp value. while (milliseconds < -62135740800000L) { - // 253402329599999L is the the number of milliseconds since + // 253402329599999L is the number of milliseconds since // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". milliseconds = rand.nextLong() % 253402329599999L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 6340229dbb78b..7e5c8f2f48d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -145,7 +145,7 @@ private[sql] abstract class BaseWriterContainer( // If we are appending data to an existing dir, we will only use the output committer // associated with the file output format since it is not safe to use a custom // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the the appending job fails. + // leave partial data in the destination dir when the appending job fails. // // See SPARK-8578 for more details logInfo( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 6f3bb0ad2bac1..7f54ea97cdd66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the the + * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c6992e18753..510894afac7be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1782,7 +1782,7 @@ object functions extends LegacyFunctions { def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } /** - * Shift the the given value numBits left. If the given value is a long value, this function + * Shift the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. * * @group math_funcs @@ -1791,7 +1791,7 @@ object functions extends LegacyFunctions { def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } /** - * Shift the the given value numBits right. If the given value is a long value, it will return + * Shift the given value numBits right. If the given value is a long value, it will return * a long value else it will return an integer value. * * @group math_funcs @@ -1802,7 +1802,7 @@ object functions extends LegacyFunctions { } /** - * Unsigned shift the the given value numBits right. If the given value is a long value, + * Unsigned shift the given value numBits right. If the given value is a long value, * it will return a long value else it will return an integer value. * * @group math_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 62710e72fba8b..bb5135826e2f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -173,10 +173,10 @@ trait StreamTest extends QueryTest with Timeouts { testStream(stream.toDF())(actions: _*) /** - * Executes the specified actions on the the given streaming DataFrame and provides helpful + * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. * - * Note that if the stream is not explictly started before an action that requires it to be + * Note that if the stream is not explicitly started before an action that requires it to be * running then it will be automatically started before performing any other actions. */ def testStream(stream: DataFrame)(actions: StreamAction*): Unit = { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 865197e24caf8..5f9952a90a57d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -721,13 +721,13 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } /** - * String to scan for when looking for the the thrift binary endpoint running. + * String to scan for when looking for the thrift binary endpoint running. * This can change across Hive versions. */ val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" /** - * String to scan for when looking for the the thrift HTTP endpoint running. + * String to scan for when looking for the thrift HTTP endpoint running. * This can change across Hive versions. */ val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 4ccc905b275d9..2be1d6df86f89 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -364,7 +364,7 @@ private[streaming] object OpenHashMapBasedStateMap { } /** - * Internal class to represent a marker the demarkate the the end of all state data in the + * Internal class to represent a marker the demarkate the end of all state data in the * serialized bytes. */ class LimitMarker(val num: Int) extends Serializable diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index b8daa501af7f0..2ac9e33873ce7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -151,7 +151,7 @@ private[yarn] class AMDelegationTokenRenewer( // passed in already has tokens for that FS even if the tokens are expired (it really only // checks if there are tokens for the service, and not if they are valid). So the only real // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the the current user's Credentials. + // get new tokens and then the new tokens are copied over the current user's Credentials. // So: // - we login as a different user and get the UGI // - use that UGI to get the tokens (see doAs block below) From e298ac91e3f6177c6da83e2d8ee994d9037466da Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 22 Feb 2016 12:48:37 +0200 Subject: [PATCH 1969/2089] [SPARK-12632][PYSPARK][DOC] PySpark fpm and als parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the fpm and recommendation modules. Closes #10602 Closes #10897 Author: Bryan Cutler Author: somideshmukh Closes #11186 from BryanCutler/param-desc-consistent-fpmrecc-SPARK-12632. --- docs/mllib-collaborative-filtering.md | 3 +- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../apache/spark/mllib/fpm/PrefixSpan.scala | 6 +- .../spark/mllib/recommendation/ALS.scala | 114 +++++++++--------- .../MatrixFactorizationModel.scala | 4 +- python/pyspark/mllib/fpm.py | 47 +++++--- python/pyspark/mllib/recommendation.py | 89 +++++++++++--- 7 files changed, 164 insertions(+), 101 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index b8f0566d8735e..5c33292aaf086 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -21,7 +21,8 @@ following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). * *rank* is the number of latent factors in the model. -* *iterations* is the number of iterations to run. +* *iterations* is the number of iterations of ALS to run. ALS typically converges to a reasonable + solution in 20 iterations or less. * *lambda* specifies the regularization parameter in ALS. * *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for *implicit feedback* data. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 1250bc1a07cb4..85d609386ff94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -152,7 +152,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * [[http://dx.doi.org/10.1145/335191.335372 Han et al., Mining frequent patterns without candidate * generation]]. * - * @param minSupport the minimal support level of the frequent pattern, any pattern appears + * @param minSupport the minimal support level of the frequent pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output * @param numPartitions number of partitions used by parallel FP-growth * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index ed49c9492fdcd..94a24b527ba3e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -38,9 +38,9 @@ import org.apache.spark.storage.StorageLevel * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]). * - * @param minSupport the minimal support level of the sequential pattern, any pattern appears - * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern appears + * @param minSupport the minimal support level of the sequential pattern, any pattern that appears + * more than (minSupport * size-of-the-dataset) times will be output + * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears * less than maxPatternLength will be output * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 33aaf853e599d..3e619c4264f20 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -218,7 +218,7 @@ class ALS private ( } /** - * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. + * Run ALS with the configured parameters on an input RDD of [[Rating]] objects. * Returns a MatrixFactorizationModel with feature vectors for each user and product. */ @Since("0.8.0") @@ -279,18 +279,17 @@ class ALS private ( @Since("0.8.0") object ALS { /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. This is done using a level of - * parallelism given by `blocks`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a configurable + * level of parallelism. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into - * @param seed random seed + * @param seed random seed for initial matrix factorization model */ @Since("0.9.1") def train( @@ -305,16 +304,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. This is done using a level of - * parallelism given by `blocks`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a configurable + * level of parallelism. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into */ @Since("0.8.0") @@ -329,16 +327,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. The level of parallelism is determined - * automatically based on the number of partitions in `ratings`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a level of + * parallelism automatically based on the number of partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter */ @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) @@ -347,15 +344,14 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. The level of parallelism is determined - * automatically based on the number of partitions in `ratings`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a level of + * parallelism automatically based on the number of partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) + * @param iterations number of iterations of ALS */ @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int) @@ -372,11 +368,11 @@ object ALS { * * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into * @param alpha confidence parameter - * @param seed random seed + * @param seed random seed for initial matrix factorization model */ @Since("0.8.1") def trainImplicit( @@ -392,16 +388,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of 'implicit preferences' given by users - * to some products, in the form of (userID, productID, preference) pairs. We approximate the - * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). - * To solve for these features, we run a given number of iterations of ALS. This is done using - * a level of parallelism given by `blocks`. + * Train a matrix factorization model given an RDD of 'implicit preferences' of users for a + * subset of products. The ratings matrix is approximated as the product of two lower-rank + * matrices of a given rank (number of features). To solve for these features, ALS is run + * iteratively with a configurable level of parallelism. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into * @param alpha confidence parameter */ @@ -418,16 +413,16 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of 'implicit preferences' given by users to - * some products, in the form of (userID, productID, preference) pairs. We approximate the - * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). - * To solve for these features, we run a given number of iterations of ALS. The level of - * parallelism is determined automatically based on the number of partitions in `ratings`. + * Train a matrix factorization model given an RDD of 'implicit preferences' of users for a + * subset of products. The ratings matrix is approximated as the product of two lower-rank + * matrices of a given rank (number of features). To solve for these features, ALS is run + * iteratively with a level of parallelism determined automatically based on the number of + * partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param alpha confidence parameter */ @Since("0.8.1") @@ -437,16 +432,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of 'implicit preferences' ratings given by - * users to some products, in the form of (userID, productID, rating) pairs. We approximate the - * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). - * To solve for these features, we run a given number of iterations of ALS. The level of - * parallelism is determined automatically based on the number of partitions in `ratings`. - * Model parameters `alpha` and `lambda` are set to reasonable default values + * Train a matrix factorization model given an RDD of 'implicit preferences' of users for a + * subset of products. The ratings matrix is approximated as the product of two lower-rank + * matrices of a given rank (number of features). To solve for these features, ALS is run + * iteratively with a level of parallelism determined automatically based on the number of + * partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) + * @param iterations number of iterations of ALS */ @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 0dc40483dd0ff..628cf1dd57280 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -206,7 +206,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( } /** - * Recommends topK products for all users. + * Recommends top products for all users. * * @param num how many products to return for every user. * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of @@ -224,7 +224,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** - * Recommends topK users for all products. + * Recommends top users for all products. * * @param num how many users to return for every product. * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 2039decc0cb3c..7a2d77a4dad13 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -29,7 +29,6 @@ @inherit_doc @ignore_unicode_prefix class FPGrowthModel(JavaModelWrapper): - """ .. note:: Experimental @@ -68,11 +67,15 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. - :param data: The input data set, each element contains a - transaction. - :param minSupport: The minimal support level (default: `0.3`). - :param numPartitions: The number of partitions used by - parallel FP-growth (default: same as input data). + :param data: + The input data set, each element contains a transaction. + :param minSupport: + The minimal support level. + (default: 0.3) + :param numPartitions: + The number of partitions used by parallel FP-growth. A value + of -1 will use the same number as input data. + (default: -1) """ model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) @@ -128,17 +131,27 @@ class PrefixSpan(object): @since("1.6.0") def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000): """ - Finds the complete set of frequent sequential patterns in the input sequences of itemsets. - - :param data: The input data set, each element contains a sequnce of itemsets. - :param minSupport: the minimal support level of the sequential pattern, any pattern appears - more than (minSupport * size-of-the-dataset) times will be output (default: `0.1`) - :param maxPatternLength: the maximal length of the sequential pattern, any pattern appears - less than maxPatternLength will be output. (default: `10`) - :param maxLocalProjDBSize: The maximum number of items (including delimiters used in - the internal storage format) allowed in a projected database before local - processing. If a projected database exceeds this size, another - iteration of distributed prefix growth is run. (default: `32000000`) + Finds the complete set of frequent sequential patterns in the + input sequences of itemsets. + + :param data: + The input data set, each element contains a sequence of + itemsets. + :param minSupport: + The minimal support level of the sequential pattern, any + pattern that appears more than (minSupport * + size-of-the-dataset) times will be output. + (default: 0.1) + :param maxPatternLength: + The maximal length of the sequential pattern, any pattern + that appears less than maxPatternLength will be output. + (default: 10) + :param maxLocalProjDBSize: + The maximum number of items (including delimiters used in the + internal storage format) allowed in a projected database before + local processing. If a projected database exceeds this size, + another iteration of distributed prefix growth is run. + (default: 32000000) """ model = callMLlibFunc("trainPrefixSpanModel", data, minSupport, maxPatternLength, maxLocalProjDBSize) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 93e47a797f490..7e60255d43ead 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -138,7 +138,8 @@ def predict(self, user, product): @since("0.9.0") def predictAll(self, user_product): """ - Returns a list of predicted ratings for input user and product pairs. + Returns a list of predicted ratings for input user and product + pairs. """ assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() @@ -165,28 +166,33 @@ def productFeatures(self): @since("1.4.0") def recommendUsers(self, product, num): """ - Recommends the top "num" number of users for a given product and returns a list - of Rating objects sorted by the predicted rating in descending order. + Recommends the top "num" number of users for a given product and + returns a list of Rating objects sorted by the predicted rating in + descending order. """ return list(self.call("recommendUsers", product, num)) @since("1.4.0") def recommendProducts(self, user, num): """ - Recommends the top "num" number of products for a given user and returns a list - of Rating objects sorted by the predicted rating in descending order. + Recommends the top "num" number of products for a given user and + returns a list of Rating objects sorted by the predicted rating in + descending order. """ return list(self.call("recommendProducts", user, num)) def recommendProductsForUsers(self, num): """ - Recommends top "num" products for all users. The number returned may be less than this. + Recommends the top "num" number of products for all users. The + number of recommendations returned per user may be less than "num". """ return self.call("wrappedRecommendProductsForUsers", num) def recommendUsersForProducts(self, num): """ - Recommends top "num" users for all products. The number returned may be less than this. + Recommends the top "num" number of users for all products. The + number of recommendations returned per product may be less than + "num". """ return self.call("wrappedRecommendUsersForProducts", num) @@ -234,11 +240,34 @@ def _prepare(cls, ratings): def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): """ - Train a matrix factorization model given an RDD of ratings given by users to some products, - in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - product of two lower-rank matrices of a given rank (number of features). To solve for these - features, we run a given number of iterations of ALS. This is done using a level of - parallelism given by `blocks`. + Train a matrix factorization model given an RDD of ratings by users + for a subset of products. The ratings matrix is approximated as the + product of two lower-rank matrices of a given rank (number of + features). To solve for these features, ALS is run iteratively with + a configurable level of parallelism. + + :param ratings: + RDD of `Rating` or (userID, productID, rating) tuple. + :param rank: + Rank of the feature matrices computed (number of features). + :param iterations: + Number of iterations of ALS. + (default: 5) + :param lambda_: + Regularization parameter. + (default: 0.01) + :param blocks: + Number of blocks used to parallelize the computation. A value + of -1 will use an auto-configured number of blocks. + (default: -1) + :param nonnegative: + A value of True will solve least-squares with nonnegativity + constraints. + (default: False) + :param seed: + Random seed for initial matrix factorization model. A value + of None will use system time as the seed. + (default: None) """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) @@ -249,11 +278,37 @@ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): """ - Train a matrix factorization model given an RDD of 'implicit preferences' given by users - to some products, in the form of (userID, productID, preference) pairs. We approximate the - ratings matrix as the product of two lower-rank matrices of a given rank (number of - features). To solve for these features, we run a given number of iterations of ALS. - This is done using a level of parallelism given by `blocks`. + Train a matrix factorization model given an RDD of 'implicit + preferences' of users for a subset of products. The ratings matrix + is approximated as the product of two lower-rank matrices of a + given rank (number of features). To solve for these features, ALS + is run iteratively with a configurable level of parallelism. + + :param ratings: + RDD of `Rating` or (userID, productID, rating) tuple. + :param rank: + Rank of the feature matrices computed (number of features). + :param iterations: + Number of iterations of ALS. + (default: 5) + :param lambda_: + Regularization parameter. + (default: 0.01) + :param blocks: + Number of blocks used to parallelize the computation. A value + of -1 will use an auto-configured number of blocks. + (default: -1) + :param alpha: + A constant used in computing confidence. + (default: 0.01) + :param nonnegative: + A value of True will solve least-squares with nonnegativity + constraints. + (default: False) + :param seed: + Random seed for initial matrix factorization model. A value + of None will use system time as the seed. + (default: None) """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) From 40e6d40fe79ce45d511e049133d2f30a2963740b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 22 Feb 2016 12:59:50 +0200 Subject: [PATCH 1970/2089] [SPARK-13334][ML] ML KMeansModel / BisectingKMeansModel / QuantileDiscretizer should set parent ML ```KMeansModel / BisectingKMeansModel / QuantileDiscretizer``` should set parent. cc mengxr Author: Yanbo Liang Closes #11214 from yanboliang/spark-13334. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 2 +- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 2 +- .../org/apache/spark/ml/feature/QuantileDiscretizer.scala | 2 +- .../org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 1 + .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 1 + .../apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 4 +++- 6 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 0b47cbbac8fe5..45d293bc69618 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -185,7 +185,7 @@ class BisectingKMeans @Since("2.0.0") ( .setSeed($(seed)) val parentModel = bkm.run(rdd) val model = new BisectingKMeansModel(uid, parentModel) - copyValues(model) + copyValues(model.setParent(this)) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc6d5d9280970..b2292e20e2124 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -250,7 +250,7 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) - copyValues(model) + copyValues(model.setParent(this)) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2a294d3881829..1f4cca123310a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -95,7 +95,7 @@ final class QuantileDiscretizer(override val uid: String) val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) val splits = QuantileDiscretizer.getSplits(candidates) val bucketizer = new Bucketizer(uid).setSplits(splits) - copyValues(bucketizer) + copyValues(bucketizer.setParent(this)) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index b26571eb9f35a..fc4a4add5d755 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -81,5 +81,6 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) + assert(model.hasParent) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2724e51f31aa4..e5357ba8e220c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -97,6 +97,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) + assert(model.hasParent) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 4fde42972f01b..6a2c601bbed18 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -94,7 +94,9 @@ private object QuantileDiscretizerSuite extends SparkFunSuite { val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") .setNumBuckets(numBucket).setSeed(1) - val result = discretizer.fit(df).transform(df) + val model = discretizer.fit(df) + assert(model.hasParent) + val result = model.transform(df) val transformedFeatures = result.select("result").collect() .map { case Row(transformedFeature: Double) => transformedFeature } From 00461bb911c31aff9c945a14e23df2af4c280c23 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Mon, 22 Feb 2016 11:11:33 -0800 Subject: [PATCH 1971/2089] [SPARK-10749][MESOS] Support multiple roles with mesos cluster mode. Currently the Mesos cluster dispatcher is not using offers from multiple roles correctly, as it simply aggregates all the offers resource values into one, but doesn't apply them correctly before calling the driver as Mesos needs the resources from the offers to be specified which role it originally belongs to. Multiple roles is already supported with fine/coarse grain scheduler, so porting that logic here to the cluster scheduler. https://issues.apache.org/jira/browse/SPARK-10749 Author: Timothy Chen Closes #8872 from tnachen/cluster_multi_roles. --- .../cluster/mesos/MesosClusterScheduler.scala | 55 ++++--- .../mesos/MesosClusterSchedulerSuite.scala | 139 ++++++++++++++++++ .../mesos/MesosClusterSchedulerSuite.scala | 74 ---------- 3 files changed, 170 insertions(+), 98 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 8cda4ff0eb3b3..2df7b1120b1c3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -357,9 +357,10 @@ private[spark] class MesosClusterScheduler( val appJar = CommandInfo.URI.newBuilder() .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() val builder = CommandInfo.newBuilder().addUris(appJar) - val entries = - (conf.getOption("spark.executor.extraLibraryPath").toList ++ - desc.command.libraryPathEntries) + val entries = conf.getOption("spark.executor.extraLibraryPath") + .map(path => Seq(path) ++ desc.command.libraryPathEntries) + .getOrElse(desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { Utils.libraryPathEnvPrefix(entries) } else { @@ -442,9 +443,12 @@ private[spark] class MesosClusterScheduler( options } - private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + private class ResourceOffer( + val offerId: OfferID, + val slaveId: SlaveID, + var resources: JList[Resource]) { override def toString(): String = { - s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + s"Offer id: ${offerId}, resources: ${resources}" } } @@ -463,27 +467,29 @@ private[spark] class MesosClusterScheduler( val driverMem = submission.mem logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") val offerOption = currentOffers.find { o => - o.cpu >= driverCpu && o.mem >= driverMem + getResource(o.resources, "cpus") >= driverCpu && + getResource(o.resources, "mem") >= driverMem } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + s"cpu: $driverCpu, mem: $driverMem") } else { val offer = offerOption.get - offer.cpu -= driverCpu - offer.mem -= driverMem val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() - val cpuResource = createResource("cpus", driverCpu) - val memResource = createResource("mem", driverMem) + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.resources, "cpus", driverCpu) + val (finalResources, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", driverMem) val commandInfo = buildDriverCommand(submission) val appName = submission.schedulerProperties("spark.app.name") val taskInfo = TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for $appName") - .setSlaveId(offer.offer.getSlaveId) + .setSlaveId(offer.slaveId) .setCommand(commandInfo) - .addResources(cpuResource) - .addResources(memResource) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + offer.resources = finalResources.asJava submission.schedulerProperties.get("spark.mesos.executor.docker.image").foreach { image => val container = taskInfo.getContainerBuilder() val volumes = submission.schedulerProperties @@ -496,11 +502,11 @@ private[spark] class MesosClusterScheduler( container, image, volumes = volumes, portmaps = portmaps) taskInfo.setContainer(container.build()) } - val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo]) queuedTasks += taskInfo.build() - logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " + submission.submissionId) - val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + val newState = new MesosClusterSubmissionState(submission, taskId, offer.slaveId, None, new Date(), None) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, newState) @@ -510,14 +516,14 @@ private[spark] class MesosClusterScheduler( } override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { - val currentOffers = offers.asScala.map(o => - new ResourceOffer( - o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) - ).toList - logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + logTrace(s"Received offers from Mesos: \n${offers.asScala.mkString("\n")}") val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() val currentTime = new Date() + val currentOffers = offers.asScala.map { + o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList) + }.toList + stateLock.synchronized { // We first schedule all the supervised drivers that are ready to retry. // This list will be empty if none of the drivers are marked as supervise. @@ -541,9 +547,10 @@ private[spark] class MesosClusterScheduler( tasks.foreach { case (offerId, taskInfos) => driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - offers.asScala - .filter(o => !tasks.keySet.contains(o.getId)) - .foreach(o => driver.declineOffer(o.getId)) + + for (o <- currentOffers if !tasks.contains(o.offerId)) { + driver.declineOffer(o.offerId) + } } private def copyBuffer( diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..dbef6868f20e2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.{Collection, Collections, Date} + +import scala.collection.JavaConverters._ + +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.{Scalar, Type} +import org.apache.mesos.SchedulerDriver +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription + + +class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), Map(), Seq(), Seq(), Seq()) + private var scheduler: MesosClusterScheduler = _ + + override def beforeEach(): Unit = { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + } + + test("can queue drivers") { + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } + + test("can handle multiple roles") { + val driver = mock[SchedulerDriver] + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1200, 1.5, true, + command, + Map(("spark.mesos.executor.home", "test"), ("spark.app.name", "test")), + "s1", + new Date())) + assert(response.success) + val offer = Offer.newBuilder() + .addResources( + Resource.newBuilder().setRole("*") + .setScalar(Scalar.newBuilder().setValue(1).build()).setName("cpus").setType(Type.SCALAR)) + .addResources( + Resource.newBuilder().setRole("*") + .setScalar(Scalar.newBuilder().setValue(1000).build()) + .setName("mem") + .setType(Type.SCALAR)) + .addResources( + Resource.newBuilder().setRole("role2") + .setScalar(Scalar.newBuilder().setValue(1).build()).setName("cpus").setType(Type.SCALAR)) + .addResources( + Resource.newBuilder().setRole("role2") + .setScalar(Scalar.newBuilder().setValue(500).build()).setName("mem").setType(Type.SCALAR)) + .setId(OfferID.newBuilder().setValue("o1").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setHostname("host1") + .build() + + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) + + when( + driver.launchTasks( + Matchers.eq(Collections.singleton(offer.getId)), + capture.capture()) + ).thenReturn(Status.valueOf(1)) + + scheduler.resourceOffers(driver, Collections.singletonList(offer)) + + val taskInfos = capture.getValue + assert(taskInfos.size() == 1) + val taskInfo = taskInfos.iterator().next() + val resources = taskInfo.getResourcesList + assert(scheduler.getResource(resources, "cpus") == 1.5) + assert(scheduler.getResource(resources, "mem") == 1200) + val resourcesSeq: Seq[Resource] = resources.asScala + val cpus = resourcesSeq.filter(_.getName.equals("cpus")).toList + assert(cpus.size == 2) + assert(cpus.exists(_.getRole().equals("role2"))) + assert(cpus.exists(_.getRole().equals("*"))) + val mem = resourcesSeq.filter(_.getName.equals("mem")).toList + assert(mem.size == 2) + assert(mem.exists(_.getRole().equals("role2"))) + assert(mem.exists(_.getRole().equals("*"))) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer.getId)), + capture.capture() + ) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala deleted file mode 100644 index 98fdc58786ec8..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.mesos - -import java.util.Date - -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.Command -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.scheduler.cluster.mesos._ - -class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { - - private val command = new Command("mainClass", Seq("arg"), null, null, null, null) - - test("can queue drivers") { - val conf = new SparkConf() - conf.setMaster("mesos://localhost:5050") - conf.setAppName("spark mesos") - val scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { ready = true } - } - scheduler.start() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) - assert(response.success) - val response2 = - scheduler.submitDriver(new MesosDriverDescription( - "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) - assert(response2.success) - val state = scheduler.getSchedulerState() - val queuedDrivers = state.queuedDrivers.toList - assert(queuedDrivers(0).submissionId == response.submissionId) - assert(queuedDrivers(1).submissionId == response2.submissionId) - } - - test("can kill queued drivers") { - val conf = new SparkConf() - conf.setMaster("mesos://localhost:5050") - conf.setAppName("spark mesos") - val scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { ready = true } - } - scheduler.start() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) - assert(response.success) - val killResponse = scheduler.killDriver(response.submissionId) - assert(killResponse.success) - val state = scheduler.getSchedulerState() - assert(state.queuedDrivers.isEmpty) - } -} From 4a91806a45a48432c3ea4c2aaa553177952673e9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 22 Feb 2016 14:01:35 -0800 Subject: [PATCH 1972/2089] [SPARK-13413] Remove SparkContext.metricsSystem ## What changes were proposed in this pull request? This patch removes SparkContext.metricsSystem. SparkContext.metricsSystem returns MetricsSystem, which is a private class. I think it was added by accident. In addition, I also removed an unused private[spark] method schedulerBackend setter. ## How was the this patch tested? N/A. Author: Reynold Xin This patch had conflicts when merged, resolved by Committer: Josh Rosen Closes #11282 from rxin/SPARK-13413. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 9 ++------- project/MimaExcludes.scala | 6 +++++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c001df31aac8e..cd7eed382e346 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -297,9 +297,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val sparkUser = Utils.getCurrentUserName() private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend - private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { - _schedulerBackend = sb - } private[spark] def taskScheduler: TaskScheduler = _taskScheduler private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { @@ -322,8 +319,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId - def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null - private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = @@ -514,9 +509,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // The metrics system for Driver need to be set spark.app.id to app ID. // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() + _env.metricsSystem.start() // Attach the driver metrics servlet handler to the web ui after the metrics system is started. - metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _env.metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) _eventLogger = if (isEventLogEnabled) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 97a1e8b433281..746223f39e4f6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -261,9 +261,13 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") - ) ++Seq( + ) ++ Seq( // SPARK-13426 Remove the support of SIMR ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") + ) ++ Seq( + // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") ) case v if v.startsWith("1.6") => Seq( From 173aa949c309ff7a7a03e9d762b9108542219a95 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 22 Feb 2016 15:27:29 -0800 Subject: [PATCH 1973/2089] [SPARK-12546][SQL] Change default number of open parquet files A common problem that users encounter with Spark 1.6.0 is that writing to a partitioned parquet table OOMs. The root cause is that parquet allocates a significant amount of memory that is not accounted for by our own mechanisms. As a workaround, we can ensure that only a single file is open per task unless the user explicitly asks for more. Author: Michael Armbrust Closes #11308 from marmbrus/parquetWriteOOM. --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 61a7b9935af0b..a601c87fc930e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -430,7 +430,7 @@ private[spark] object SQLConf { val PARTITION_MAX_FILES = intConf("spark.sql.sources.maxConcurrentWrites", - defaultValue = Some(5), + defaultValue = Some(1), doc = "The maximum number of concurrent files to open before falling back on sorting when " + "writing out files using dynamic partitioning.") From 2063781840831469b394313694bfd25cbde2bb1e Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 22 Feb 2016 16:34:02 -0800 Subject: [PATCH 1974/2089] [SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins Use the HashedRelation which is a more optimized datastructure and reduce code complexity Author: Xiu Guo Closes #11291 from xguo27/SPARK-13422. --- .../joins/BroadcastLeftSemiJoinHash.scala | 27 +++------ .../sql/execution/joins/HashSemiJoin.scala | 55 +------------------ .../execution/joins/LeftSemiJoinHash.scala | 13 ++--- 3 files changed, 14 insertions(+), 81 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 1f99fbedde411..d3bcfad7c3de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. + * Build the right table's join keys into a HashedRelation, and iteratively go through the left + * table, to find if the join keys are in the HashedRelation. */ case class BroadcastLeftSemiJoinHash( leftKeys: Seq[Expression], @@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def requiredChildDistribution: Seq[Distribution] = { - val mode = if (condition.isEmpty) { - HashSetBroadcastMode(rightKeys, right.output) - } else { - HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output) - } + val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output) UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - if (condition.isEmpty) { - val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]() - left.execute().mapPartitionsInternal { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows) - } - } else { - val broadcastedRelation = right.executeBroadcast[HashedRelation]() - left.execute().mapPartitionsInternal { streamIter => - val hashedRelation = broadcastedRelation.value - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) - hashSemiJoin(streamIter, hashedRelation, numOutputRows) - } + val broadcastedRelation = right.executeBroadcast[HashedRelation]() + left.execute().mapPartitionsInternal { streamIter => + val hashedRelation = broadcastedRelation.value + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) + hashSemiJoin(streamIter, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 1cb6a00617c5e..3eed6e3e11131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -43,24 +43,6 @@ trait HashSemiJoin { @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet( - buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { - HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter) - } - - protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow], - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator - streamIter.filter(current => { - val key = joinKeys(current) - val r = !key.anyNull && hashSet.contains(key) - if (r) numOutputRows += 1 - r - }) - } - protected def hashSemiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, @@ -70,44 +52,11 @@ trait HashSemiJoin { streamIter.filter { current => val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) - val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { + val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) - } + }) if (r) numOutputRows += 1 r } } } - -private[execution] object HashSemiJoin { - def buildKeyHashSet( - keys: Seq[Expression], - attributes: Seq[Attribute], - rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = { - val hashSet = new java.util.HashSet[InternalRow]() - - // Create a Hash set of buildKeys - val key = UnsafeProjection.create(keys, attributes) - while (rows.hasNext) { - val currentRow = rows.next() - val rowKey = key(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey.copy()) - } - } - } - hashSet - } -} - -/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */ -private[execution] case class HashSetBroadcastMode( - keys: Seq[Expression], - attributes: Seq[Attribute]) extends BroadcastMode { - - override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = { - HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index d8d3045ccf5c3..242ed612327cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. + * Build the right table's join keys into a HashedRelation, and iteratively go through the left + * table, to find if the join keys are in the HashedRelation. */ case class LeftSemiJoinHash( leftKeys: Seq[Expression], @@ -47,13 +47,8 @@ case class LeftSemiJoinHash( val numOutputRows = longMetric("numOutputRows") right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => - if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) - hashSemiJoin(streamIter, hashSet, numOutputRows) - } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - hashSemiJoin(streamIter, hashRelation, numOutputRows) - } + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + hashSemiJoin(streamIter, hashRelation, numOutputRows) } } } From 9f410871ca03f4c04bd965b2e4f80167ce543139 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Mon, 22 Feb 2016 17:16:56 -0800 Subject: [PATCH 1975/2089] [SPARK-13016][DOCUMENTATION] Replace example code in mllib-dimensionality-reduction.md using include_example Replaced example example code in mllib-dimensionality-reduction.md using include_example Author: Devaraj K Closes #11132 from devaraj-kavali/SPARK-13016. --- docs/mllib-dimensionality-reduction.md | 113 +----------------- .../spark/examples/mllib/JavaPCAExample.java | 65 ++++++++++ .../spark/examples/mllib/JavaSVDExample.java | 70 +++++++++++ .../mllib/PCAOnRowMatrixExample.scala | 58 +++++++++ .../mllib/PCAOnSourceVectorExample.scala | 57 +++++++++ .../spark/examples/mllib/SVDExample.scala | 61 ++++++++++ 6 files changed, 316 insertions(+), 108 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 11d8e0bd1d23d..cceddce9f79a6 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -64,19 +64,7 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver.
        Refer to the [`SingularValueDecomposition` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.SingularValueDecomposition) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.SingularValueDecomposition - -val mat: RowMatrix = ... - -// Compute the top 20 singular values and corresponding singular vectors. -val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(20, computeU = true) -val U: RowMatrix = svd.U // The U factor is a RowMatrix. -val s: Vector = svd.s // The singular values are stored in a local dense vector. -val V: Matrix = svd.V // The V factor is a local dense matrix. -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SVDExample.scala %} The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. @@ -84,43 +72,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an
        Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/mllib/linalg/SingularValueDecomposition.html) for details on the API. -{% highlight java %} -import java.util.LinkedList; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.SingularValueDecomposition; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.rdd.RDD; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class SVD { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVD Example"); - SparkContext sc = new SparkContext(conf); - - double[][] array = ... - LinkedList rowsList = new LinkedList(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); - - // Create a RowMatrix from JavaRDD. - RowMatrix mat = new RowMatrix(rows.rdd()); - - // Compute the top 4 singular values and corresponding singular vectors. - SingularValueDecomposition svd = mat.computeSVD(4, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSVDExample.java %} The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. @@ -151,36 +103,14 @@ and use them to project the vectors into a low-dimensional space. Refer to the [`RowMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix - -val mat: RowMatrix = ... - -// Compute the top 10 principal components. -val pc: Matrix = mat.computePrincipalComponents(10) // Principal components are stored in a local dense matrix. - -// Project the rows to the linear space spanned by the top 10 principal components. -val projected: RowMatrix = mat.multiply(pc) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala %} The following code demonstrates how to compute principal components on source vectors and use them to project the vectors into a low-dimensional space while keeping associated labels: Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.feature.PCA - -val data: RDD[LabeledPoint] = ... - -// Compute the top 10 principal components. -val pca = new PCA(10).fit(data.map(_.features)) - -// Project vectors to the linear space spanned by the top 10 principal components, keeping the label -val projected = data.map(p => p.copy(features = pca.transform(p.features))) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala %}
        @@ -192,40 +122,7 @@ The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. -{% highlight java %} -import java.util.LinkedList; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.rdd.RDD; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class PCA { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("PCA Example"); - SparkContext sc = new SparkContext(conf); - - double[][] array = ... - LinkedList rowsList = new LinkedList(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); - - // Create a RowMatrix from JavaRDD. - RowMatrix mat = new RowMatrix(rows.rdd()); - - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); - RowMatrix projected = mat.multiply(pc); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %}
        diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java new file mode 100644 index 0000000000000..faf76a9540e77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.LinkedList; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +// $example off$ + +/** + * Example for compute principal components on a 'RowMatrix'. + */ +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("PCA Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; + LinkedList rowsList = new LinkedList(); + for (int i = 0; i < array.length; i++) { + Vector currentRow = Vectors.dense(array[i]); + rowsList.add(currentRow); + } + JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + + // Create a RowMatrix from JavaRDD. + RowMatrix mat = new RowMatrix(rows.rdd()); + + // Compute the top 3 principal components. + Matrix pc = mat.computePrincipalComponents(3); + RowMatrix projected = mat.multiply(pc); + // $example off$ + Vector[] collectPartitions = (Vector[])projected.rows().collect(); + System.out.println("Projected vector of principal component:"); + for (Vector vector : collectPartitions) { + System.out.println("\t" + vector); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java new file mode 100644 index 0000000000000..f3685db9f2fb2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.LinkedList; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.SingularValueDecomposition; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +// $example off$ + +/** + * Example for SingularValueDecomposition. + */ +public class JavaSVDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVD Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; + LinkedList rowsList = new LinkedList(); + for (int i = 0; i < array.length; i++) { + Vector currentRow = Vectors.dense(array[i]); + rowsList.add(currentRow); + } + JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + + // Create a RowMatrix from JavaRDD. + RowMatrix mat = new RowMatrix(rows.rdd()); + + // Compute the top 3 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); + RowMatrix U = svd.U(); + Vector s = svd.s(); + Matrix V = svd.V(); + // $example off$ + Vector[] collectPartitions = (Vector[]) U.rows().collect(); + System.out.println("U factor is:"); + for (Vector vector : collectPartitions) { + System.out.println("\t" + vector); + } + System.out.println("Singular values are: " + s); + System.out.println("V factor is:\n" + V); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala new file mode 100644 index 0000000000000..234de230eb201 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +// $example off$ + +object PCAOnRowMatrixExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAOnRowMatrixExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + + val dataRDD = sc.parallelize(data, 2) + + val mat: RowMatrix = new RowMatrix(dataRDD) + + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + val pc: Matrix = mat.computePrincipalComponents(4) + + // Project the rows to the linear space spanned by the top 4 principal components. + val projected: RowMatrix = mat.multiply(pc) + // $example off$ + val collect = projected.rows.collect() + println("Projected Row Matrix of principal component:") + collect.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala new file mode 100644 index 0000000000000..f7694879dfbdb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +// $example off$ + +object PCAOnSourceVectorExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAOnSourceVectorExample") + val sc = new SparkContext(conf) + + // $example on$ + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)))) + + // Compute the top 5 principal components. + val pca = new PCA(5).fit(data.map(_.features)) + + // Project vectors to the linear space spanned by the top 5 principal + // components, keeping the label + val projected = data.map(p => p.copy(features = pca.transform(p.features))) + // $example off$ + val collect = projected.collect() + println("Projected vector of principal component:") + collect.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala new file mode 100644 index 0000000000000..c26580d4c1960 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +// $example off$ + +object SVDExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("SVDExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + + val dataRDD = sc.parallelize(data, 2) + + val mat: RowMatrix = new RowMatrix(dataRDD) + + // Compute the top 5 singular values and corresponding singular vectors. + val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) + val U: RowMatrix = svd.U // The U factor is a RowMatrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. + // $example off$ + val collect = U.rows.collect() + println("U factor is:") + collect.foreach { vector => println(vector) } + println(s"Singular values are: $s") + println(s"V factor is:\n$V") + } +} +// scalastyle:on println From 02b1fefffb00d50c1076a26f2f3f41f3c1fa0001 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Mon, 22 Feb 2016 17:21:37 -0800 Subject: [PATCH 1976/2089] [SPARK-13012][DOCUMENTATION] Replace example code in ml-guide.md using include_example Replaced example code in ml-guide.md using include_example Author: Devaraj K Closes #11053 from devaraj-kavali/SPARK-13012. --- docs/ml-guide.md | 660 +----------------- .../spark/examples/ml/JavaDocument.java | 43 ++ .../JavaEstimatorTransformerParamExample.java | 111 +++ .../examples/ml/JavaLabeledDocument.java | 38 + ...delSelectionViaCrossValidationExample.java | 122 ++++ ...lectionViaTrainValidationSplitExample.java | 83 +++ .../examples/ml/JavaPipelineExample.java | 91 +++ .../ml/estimator_transformer_param_example.py | 87 +++ .../src/main/python/ml/pipeline_example.py | 64 ++ .../ml/EstimatorTransformerParamExample.scala | 100 +++ ...elSelectionViaCrossValidationExample.scala | 111 +++ ...ectionViaTrainValidationSplitExample.scala | 72 ++ .../spark/examples/ml/PipelineExample.scala | 93 +++ 13 files changed, 1025 insertions(+), 650 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java create mode 100644 examples/src/main/python/ml/estimator_transformer_param_example.py create mode 100644 examples/src/main/python/ml/pipeline_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 8eee2fb674881..5900d665b37f0 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -213,209 +213,15 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
        -{% highlight scala %} -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.sql.Row - -// Prepare training data from a list of (label, features) tuples. -val training = sqlContext.createDataFrame(Seq( - (1.0, Vectors.dense(0.0, 1.1, 0.1)), - (0.0, Vectors.dense(2.0, 1.0, -1.0)), - (0.0, Vectors.dense(2.0, 1.3, 1.0)), - (1.0, Vectors.dense(0.0, 1.2, -0.5)) -)).toDF("label", "features") - -// Create a LogisticRegression instance. This instance is an Estimator. -val lr = new LogisticRegression() -// Print out the parameters, documentation, and any default values. -println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") - -// We may set parameters using setter methods. -lr.setMaxIter(10) - .setRegParam(0.01) - -// Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training) -// Since model1 is a Model (i.e., a Transformer produced by an Estimator), -// we can view the parameters it used during fit(). -// This prints the parameter (name: value) pairs, where names are unique IDs for this -// LogisticRegression instance. -println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) - -// We may alternatively specify parameters using a ParamMap, -// which supports several methods for specifying parameters. -val paramMap = ParamMap(lr.maxIter -> 20) - .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. - -// One can also combine ParamMaps. -val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name -val paramMapCombined = paramMap ++ paramMap2 - -// Now learn a new model using the paramMapCombined parameters. -// paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training, paramMapCombined) -println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) - -// Prepare test data. -val test = sqlContext.createDataFrame(Seq( - (1.0, Vectors.dense(-1.0, 1.5, 1.3)), - (0.0, Vectors.dense(3.0, 2.0, -0.1)), - (1.0, Vectors.dense(0.0, 2.2, -1.5)) -)).toDF("label", "features") - -// Make predictions on test data using the Transformer.transform() method. -// LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'myProbability' column instead of the usual -// 'probability' column since we renamed the lr.probabilityCol parameter previously. -model2.transform(test) - .select("features", "label", "myProbability", "prediction") - .collect() - .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => - println(s"($features, $label) -> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
        -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Prepare training data. -// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans -// into DataFrames, where it uses the bean metadata to infer the schema. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) -), LabeledPoint.class); - -// Create a LogisticRegression instance. This instance is an Estimator. -LogisticRegression lr = new LogisticRegression(); -// Print out the parameters, documentation, and any default values. -System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); - -// We may set parameters using setter methods. -lr.setMaxIter(10) - .setRegParam(0.01); - -// Learn a LogisticRegression model. This uses the parameters stored in lr. -LogisticRegressionModel model1 = lr.fit(training); -// Since model1 is a Model (i.e., a Transformer produced by an Estimator), -// we can view the parameters it used during fit(). -// This prints the parameter (name: value) pairs, where names are unique IDs for this -// LogisticRegression instance. -System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); - -// We may alternatively specify parameters using a ParamMap. -ParamMap paramMap = new ParamMap() - .put(lr.maxIter().w(20)) // Specify 1 Param. - .put(lr.maxIter(), 30) // This overwrites the original maxIter. - .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. - -// One can also combine ParamMaps. -ParamMap paramMap2 = new ParamMap() - .put(lr.probabilityCol().w("myProbability")); // Change output column name -ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); - -// Now learn a new model using the paramMapCombined parameters. -// paramMapCombined overrides all parameters set earlier via lr.set* methods. -LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); -System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); - -// Prepare test documents. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) -), LabeledPoint.class); - -// Make predictions on test documents using the Transformer.transform() method. -// LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'myProbability' column instead of the usual -// 'probability' column since we renamed the lr.probabilityCol parameter previously. -DataFrame results = model2.transform(test); -for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
        -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.param import Param, Params - -# Prepare training data from a list of (label, features) tuples. -training = sqlContext.createDataFrame([ - (1.0, Vectors.dense([0.0, 1.1, 0.1])), - (0.0, Vectors.dense([2.0, 1.0, -1.0])), - (0.0, Vectors.dense([2.0, 1.3, 1.0])), - (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) - -# Create a LogisticRegression instance. This instance is an Estimator. -lr = LogisticRegression(maxIter=10, regParam=0.01) -# Print out the parameters, documentation, and any default values. -print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" - -# Learn a LogisticRegression model. This uses the parameters stored in lr. -model1 = lr.fit(training) - -# Since model1 is a Model (i.e., a transformer produced by an Estimator), -# we can view the parameters it used during fit(). -# This prints the parameter (name: value) pairs, where names are unique IDs for this -# LogisticRegression instance. -print "Model 1 was fit using parameters: " -print model1.extractParamMap() - -# We may alternatively specify parameters using a Python dictionary as a paramMap -paramMap = {lr.maxIter: 20} -paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. -paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. - -# You can combine paramMaps, which are python dictionaries. -paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name -paramMapCombined = paramMap.copy() -paramMapCombined.update(paramMap2) - -# Now learn a new model using the paramMapCombined parameters. -# paramMapCombined overrides all parameters set earlier via lr.set* methods. -model2 = lr.fit(training, paramMapCombined) -print "Model 2 was fit using parameters: " -print model2.extractParamMap() - -# Prepare test data -test = sqlContext.createDataFrame([ - (1.0, Vectors.dense([-1.0, 1.5, 1.3])), - (0.0, Vectors.dense([3.0, 2.0, -0.1])), - (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) - -# Make predictions on test data using the Transformer.transform() method. -# LogisticRegression.transform will only use the 'features' column. -# Note that model2.transform() outputs a "myProbability" column instead of the usual -# 'probability' column since we renamed the lr.probabilityCol parameter previously. -prediction = model2.transform(test) -selected = prediction.select("features", "label", "myProbability", "prediction") -for row in selected.collect(): - print row - -{% endhighlight %} +{% include_example python/ml/estimator_transformer_param_example.py %}
        @@ -427,191 +233,15 @@ This example follows the simple text document `Pipeline` illustrated in the figu
        -{% highlight scala %} -import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training documents from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// Fit the pipeline to training documents. -val model = pipeline.fit(training) - -// now we can optionally save the fitted pipeline to disk -model.save("/tmp/spark-logistic-regression-model") - -// we can also save this unfit pipeline to disk -pipeline.save("/tmp/unfit-lr-model") - -// and load it back in during production -val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. -model.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
        -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// Fit the pipeline to training documents. -PipelineModel model = pipeline.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. -DataFrame predictions = model.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
        -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row - -# Prepare training documents from a list of (id, text, label) tuples. -LabeledDocument = Row("id", "text", "label") -training = sqlContext.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) - -# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. -tokenizer = Tokenizer(inputCol="text", outputCol="words") -hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") -lr = LogisticRegression(maxIter=10, regParam=0.01) -pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) - -# Fit the pipeline to training documents. -model = pipeline.fit(training) - -# Prepare test documents, which are unlabeled (id, text) tuples. -test = sqlContext.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")], ["id", "text"]) - -# Make predictions on test documents and print columns of interest. -prediction = model.transform(test) -selected = prediction.select("id", "text", "prediction") -for row in selected.collect(): - print(row) - -{% endhighlight %} +{% include_example python/ml/pipeline_example.py %}
        @@ -646,201 +276,11 @@ However, it is also a well-established method for choosing parameters which is m
        -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training data from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0), - (4L, "b spark who", 1.0), - (5L, "g d a y", 0.0), - (6L, "spark fly", 1.0), - (7L, "was mapreduce", 0.0), - (8L, "e spark program", 1.0), - (9L, "a e c l", 0.0), - (10L, "spark compile", 1.0), - (11L, "hadoop software", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -val cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2) // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -val cvModel = cv.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
        -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -CrossValidator cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2); // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = cv.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -DataFrame predictions = cvModel.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
        @@ -864,91 +304,11 @@ The `ParamMap` which produces the best evaluation metric is selected as the best
        -{% highlight scala %} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} - -// Prepare training and test data. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") -val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - -val lr = new LinearRegression() - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - // 80% of the data will be used for training and the remaining 20% for validation. - .setTrainRatio(0.8) - -// Run train validation split, and choose the best set of parameters. -val model = trainValidationSplit.fit(training) - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
        -{% highlight java %} -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.DataFrame; - -DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_linear_regression_data.txt"); - -// Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); -DataFrame training = splits[0]; -DataFrame test = splits[1]; - -LinearRegression lr = new LinearRegression(); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation - -// Run train validation split, and choose the best set of parameters. -TrainValidationSplitModel model = trainValidationSplit.fit(training); - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
        diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java new file mode 100644 index 0000000000000..6459dabc0698b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.io.Serializable; + +/** + * Unlabeled instance type, Spark SQL can infer schema from Java Beans. + */ +@SuppressWarnings("serial") +public class JavaDocument implements Serializable { + + private long id; + private String text; + + public JavaDocument(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { + return this.id; + } + + public String getText() { + return this.text; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java new file mode 100644 index 0000000000000..44cf3507f3743 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Estimator, Transformer, and Param. + */ +public class JavaEstimatorTransformerParamExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaEstimatorTransformerParamExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training data. + // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into + // DataFrames, where it uses the bean metadata to infer the schema. + DataFrame training = sqlContext.createDataFrame( + Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) + ), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + LogisticRegression lr = new LogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10).setRegParam(0.01); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + LogisticRegressionModel model1 = lr.fit(training); + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + + // We may alternatively specify parameters using a ParamMap. + ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + + // One can also combine ParamMaps. + ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name + ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); + System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + + // Prepare test documents. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) + ), LabeledPoint.class); + + // Make predictions on test documents using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r : results.select("features", "label", "myProbability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java new file mode 100644 index 0000000000000..68d1caf6ad3f3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.io.Serializable; + +/** + * Labeled instance type, Spark SQL can infer schema from Java Beans. + */ +@SuppressWarnings("serial") +public class JavaLabeledDocument extends JavaDocument implements Serializable { + + private double label; + + public JavaLabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { + return this.label; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java new file mode 100644 index 0000000000000..87ad119491e9a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Cross Validation. + */ +public class JavaModelSelectionViaCrossValidationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaCrossValidationExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L,"spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0), + new JavaLabeledDocument(4L, "b spark who", 1.0), + new JavaLabeledDocument(5L, "g d a y", 0.0), + new JavaLabeledDocument(6L, "spark fly", 1.0), + new JavaLabeledDocument(7L, "was mapreduce", 0.0), + new JavaLabeledDocument(8L, "e spark program", 1.0), + new JavaLabeledDocument(9L, "a e c l", 0.0), + new JavaLabeledDocument(10L, "spark compile", 1.0), + new JavaLabeledDocument(11L, "hadoop software", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000}) + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .build(); + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + CrossValidatorModel cvModel = cv.fit(training); + + // Prepare test documents, which are unlabeled. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + DataFrame predictions = cvModel.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java new file mode 100644 index 0000000000000..77adb02dfd9a8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.ml.tuning.TrainValidationSplit; +import org.apache.spark.ml.tuning.TrainValidationSplitModel; +import org.apache.spark.sql.DataFrame; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Train Validation Split. + */ +public class JavaModelSelectionViaTrainValidationSplitExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaTrainValidationSplitExample"); + SparkContext sc = new SparkContext(conf); + SQLContext jsql = new SQLContext(sc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + // Prepare training and test data. + DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java new file mode 100644 index 0000000000000..3407c25c83c37 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for simple text document 'Pipeline'. + */ +public class JavaPipelineExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPipelineExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L, "spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. + DataFrame predictions = model.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/python/ml/estimator_transformer_param_example.py b/examples/src/main/python/ml/estimator_transformer_param_example.py new file mode 100644 index 0000000000000..9a8993dac4f65 --- /dev/null +++ b/examples/src/main/python/ml/estimator_transformer_param_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Estimator Transformer Param Example. +""" +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="EstimatorTransformerParamExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Prepare training data from a list of (label, features) tuples. + training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + + # Create a LogisticRegression instance. This instance is an Estimator. + lr = LogisticRegression(maxIter=10, regParam=0.01) + # Print out the parameters, documentation, and any default values. + print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print "Model 1 was fit using parameters: " + print model1.extractParamMap() + + # We may alternatively specify parameters using a Python dictionary as a paramMap + paramMap = {lr.maxIter: 20} + paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. + paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + + # You can combine paramMaps, which are python dictionaries. + paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name + paramMapCombined = paramMap.copy() + paramMapCombined.update(paramMap2) + + # Now learn a new model using the paramMapCombined parameters. + # paramMapCombined overrides all parameters set earlier via lr.set* methods. + model2 = lr.fit(training, paramMapCombined) + print "Model 2 was fit using parameters: " + print model2.extractParamMap() + + # Prepare test data + test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegression.transform will only use the 'features' column. + # Note that model2.transform() outputs a "myProbability" column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + prediction = model2.transform(test) + selected = prediction.select("features", "label", "myProbability", "prediction") + for row in selected.collect(): + print row + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py new file mode 100644 index 0000000000000..3288568f0c287 --- /dev/null +++ b/examples/src/main/python/ml/pipeline_example.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Pipeline Example. +""" +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PipelineExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Prepare training documents from a list of (id, text, label) tuples. + training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) + + # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled (id, text) tuples. + test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + selected = prediction.select("id", "text", "prediction") + for row in selected.collect(): + print(row) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala new file mode 100644 index 0000000000000..65e3c365abb3f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object EstimatorTransformerParamExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("EstimatorTransformerParamExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training data from a list of (label, features) tuples. + val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) + )).toDF("label", "features") + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new LogisticRegression() + // Print out the parameters, documentation, and any default values. + println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + .setRegParam(0.01) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model1 = lr.fit(training) + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + + // We may alternatively specify parameters using a ParamMap, + // which supports several methods for specifying parameters. + val paramMap = ParamMap(lr.maxIter -> 20) + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + + // One can also combine ParamMaps. + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name + val paramMapCombined = paramMap ++ paramMap2 + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + val model2 = lr.fit(training, paramMapCombined) + println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + + // Prepare test data. + val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) + )).toDF("label", "features") + + // Make predictions on test data using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test) + .select("features", "label", "myProbability", "prediction") + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala new file mode 100644 index 0000000000000..0331d6e7b35df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object ModelSelectionViaCrossValidationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ModelSelectionViaCrossValidationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training data from a list of (id, text, label) tuples. + val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) + )).toDF("id", "text", "label") + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + val cvModel = cv.fit(training) + + // Prepare test documents, which are unlabeled (id, text) tuples. + val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + )).toDF("id", "text") + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + cvModel.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala new file mode 100644 index 0000000000000..5a95344f223df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +// $example off$ +import org.apache.spark.sql.SQLContext + +object ModelSelectionViaTrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ModelSelectionViaTrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training and test data. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show() + // $example off$ + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala new file mode 100644 index 0000000000000..6c29063626bac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object PipelineExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PipelineExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training documents from a list of (id, text, label) tuples. + val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) + )).toDF("id", "text", "label") + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. + val model = pipeline.fit(training) + + // Now we can optionally save the fitted pipeline to disk + model.write.overwrite().save("/tmp/spark-logistic-regression-model") + + // We can also save this unfit pipeline to disk + pipeline.write.overwrite().save("/tmp/unfit-lr-model") + + // And load it back in during production + val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") + + // Prepare test documents, which are unlabeled (id, text) tuples. + val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + )).toDF("id", "text") + + // Make predictions on test documents. + model.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println From 33ef3aa7eabbe323620eb77fa94a53996ed0251d Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Mon, 22 Feb 2016 17:26:32 -0800 Subject: [PATCH 1977/2089] [SPARK-13295][ ML, MLLIB ] AFTSurvivalRegression.AFTAggregator improvements - avoid creating new instances of arrays/vectors for each record As also mentioned/marked by TODO in AFTAggregator.AFTAggregator.add(data: AFTPoint) method a new array is being created for intercept value and it is being concatenated with another array which contains the betas, the resulted Array is being converted into a Dense vector which in its turn is being converted into breeze vector. This is expensive and not necessarily beautiful. I've tried to solve above mentioned problem by simple algebraic decompositions - keeping and treating intercept independently. Please let me know what do you think and if you have any questions. Thanks, Narine Author: Narine Kokhlikyan Closes #11179 from NarineK/survivaloptim. --- .../ml/regression/AFTSurvivalRegression.scala | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e8a1ff2278a92..1e5b4cb83c652 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) extends Serializable { - // beta is the intercept and regression coefficients to the covariates - private val beta = parameters.slice(1, parameters.length) + // the regression coefficients to the covariates + private val coefficients = parameters.slice(2, parameters.length) + private val intercept = parameters.valueAt(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) + private var gradientInterceptSum = 0.0 private var gradientLogSigmaSum = 0.0 def count: Long = totalCnt def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - // Here we optimize loss function over beta and log(sigma) + // Here we optimize loss function over coefficients, intercept and log(sigma) def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - gradientBetaSum/totalCnt.toDouble) + BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) */ def add(data: AFTPoint): this.type = { - // TODO: Don't create a new xi vector each time. - val xi = if (fitIntercept) { - Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze - } else { - Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze - } + val interceptFlag = if (fitIntercept) 1.0 else 0.0 + + val xi = data.features.toBreeze val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma lossSum += math.log(sigma) * delta lossSum += (math.exp(epsilon) - delta * epsilon) @@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) assert(!lossSum.isInfinity, s"AFTAggregator loss sum is infinity. Error for unknown reason.") - gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma - gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + val deltaMinusExpEps = delta - math.exp(epsilon) + gradientCoefficientSum += xi * deltaMinusExpEps / sigma + gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma + gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon totalCnt += 1 this @@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientBetaSum += other.gradientBetaSum + gradientCoefficientSum += other.gradientCoefficientSum + gradientInterceptSum += other.gradientInterceptSum gradientLogSigmaSum += other.gradientLogSigmaSum } this From a11b3995190cb4a983adcc8667f7b316cce18d24 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 22 Feb 2016 17:42:30 -0800 Subject: [PATCH 1978/2089] [SPARK-13298][CORE][UI] Escape "label" to avoid DAG being broken by some special character ## What changes were proposed in this pull request? When there are some special characters (e.g., `"`, `\`) in `label`, DAG will be broken. This patch just escapes `label` to avoid DAG being broken by some special characters ## How was the this patch tested? Jenkins tests Author: Shixiong Zhu Closes #11309 from zsxwing/SPARK-13298. --- .../org/apache/spark/ui/scope/RDDOperationGraph.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 003c218aada9c..cb2827199853a 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -20,10 +20,11 @@ package org.apache.spark.ui.scope import scala.collection.mutable import scala.collection.mutable.{ListBuffer, StringBuilder} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.CallSite /** * A representation of a generic cluster graph used for storing information on RDD operations. @@ -183,7 +184,7 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { val label = s"${node.name} [${node.id}]\n${node.callsite}" - s"""${node.id} [label="$label"]""" + s"""${node.id} [label="${StringEscapeUtils.escapeJava(label)}"]""" } /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. */ @@ -192,7 +193,7 @@ private[ui] object RDDOperationGraph extends Logging { cluster: RDDOperationCluster, indent: String): Unit = { subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent).append(s""" label="${cluster.name}";\n""") + .append(indent).append(s""" label="${StringEscapeUtils.escapeJava(cluster.name)}";\n""") cluster.childNodes.foreach { node => subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } From 5d80fac58f837933b5359a8057676f45539e53af Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 22 Feb 2016 18:13:32 -0800 Subject: [PATCH 1979/2089] [SPARK-11624][SPARK-11972][SQL] fix commands that need hive to exec In SparkSQLCLI, we have created a `CliSessionState`, but then we call `SparkSQLEnv.init()`, which will start another `SessionState`. This would lead to exception because `processCmd` need to get the `CliSessionState` instance by calling `SessionState.get()`, but the return value would be a instance of `SessionState`. See the exception below. spark-sql> !echo "test"; Exception in thread "main" java.lang.ClassCastException: org.apache.hadoop.hive.ql.session.SessionState cannot be cast to org.apache.hadoop.hive.cli.CliSessionState at org.apache.hadoop.hive.cli.CliDriver.processCmd(CliDriver.java:112) at org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver.processCmd(SparkSQLCLIDriver.scala:301) at org.apache.hadoop.hive.cli.CliDriver.processLine(CliDriver.java:376) at org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver$.main(SparkSQLCLIDriver.scala:242) at org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver.main(SparkSQLCLIDriver.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:691) at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:180) at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:205) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:120) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Author: Daoyuan Wang Closes #9589 from adrian-wang/clicommand. --- .../hive/thriftserver/SparkSQLCLIDriver.scala | 7 ++- .../sql/hive/thriftserver/CliSuite.scala | 5 ++ sql/hive/pom.xml | 6 +++ .../sql/hive/client/HiveClientImpl.scala | 54 +++++++++++-------- 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index f279b78f47c7d..fb31119a9e1dd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -288,8 +288,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || - cmd_lower.equals("exit") || - tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + cmd_lower.equals("exit")) { + sessionState.close() + System.exit(0) + } + if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || isRemoteMode) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 72da266da4d01..81508e134695a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -234,4 +234,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { -> "Error in query: Table not found: nonexistent_table;" ) } + + test("SPARK-11624 Spark SQL CLI should set sessionState only once") { + runCliWithin(2.minute, Seq("-e", "!echo \"This is a test for Spark-11624\";"))( + "" -> "This is a test for Spark-11624") + } } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 14cf9acf09d5b..22bad93e6dd58 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -72,6 +72,12 @@ protobuf-java ${protobuf.version} +--> + + ${hive.group} + hive-cli + + upstream() -------> upstream() ------> execute() + * doExecute() ---------> upstreams() -------> upstreams() ------> execute() * | * -----------------> produce() * | @@ -267,6 +274,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) public GeneratedIterator(Object[] references) { this.references = references; + } + + public void init(scala.collection.Iterator inputs[]) { ${ctx.initMutableStates()} } @@ -283,19 +293,33 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) // println(s"${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) - plan.upstream().mapPartitions { iter => - - val clazz = CodeGenerator.compile(source) - val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.setInput(iter) - new Iterator[InternalRow] { - override def hasNext: Boolean = buffer.hasNext - override def next: InternalRow = buffer.next() + val rdds = plan.upstreams() + assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") + if (rdds.length == 1) { + rdds.head.mapPartitions { iter => + val clazz = CodeGenerator.compile(cleanedSource) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.init(Array(iter)) + new Iterator[InternalRow] { + override def hasNext: Boolean = buffer.hasNext + override def next: InternalRow = buffer.next() + } + } + } else { + // Right now, we support up to two upstreams. + rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => + val clazz = CodeGenerator.compile(cleanedSource) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.init(Array(leftIter, rightIter)) + new Iterator[InternalRow] { + override def hasNext: Boolean = buffer.hasNext + override def next: InternalRow = buffer.next() + } } } } - override def upstream(): RDD[InternalRow] = { + override def upstreams(): Seq[RDD[InternalRow]] = { throw new UnsupportedOperationException } @@ -312,7 +336,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) if (row != null) { // There is an UnsafeRow already s""" - | currentRows.add($row.copy()); + |append($row.copy()); """.stripMargin } else { assert(input != null) @@ -324,13 +348,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) ctx.currentVars = input val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" - | ${code.code.trim} - | currentRows.add(${code.value}.copy()); + |${code.code.trim} + |append(${code.value}.copy()); """.stripMargin } else { // There is no columns s""" - | currentRows.add(unsafeRow); + |append(unsafeRow); """.stripMargin } } @@ -402,6 +426,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru b.copy(left = apply(left)) case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => b.copy(right = apply(right)) + case j @ SortMergeJoin(_, _, _, left, right) => + // The children of SortMergeJoin should do codegen separately. + j.copy(left = apply(left), right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively inputs += input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 852203f3743dc..a46722963a6e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -121,8 +121,8 @@ case class TungstenAggregate( !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 55bddd196ec46..b2f443c0e9ae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { @@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() } protected override def doProduce(ctx: CodegenContext): String = { @@ -156,8 +156,9 @@ case class Range( private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + override def upstreams(): Seq[RDD[InternalRow]] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) :: Nil } protected override def doProduce(ctx: CodegenContext): String = { @@ -213,12 +214,15 @@ case class Range( | } """.stripMargin) + val input = ctx.freshName("input") + // Right now, Range is only used when there is one upstream. + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; - | if (input.hasNext()) { - | initRange(((InternalRow) input.next()).getInt(0)); + | if ($input.hasNext()) { + | initRange(((InternalRow) $input.next()).getInt(0)); | } else { | return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index ddc08822f3e17..6699dbafe7e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -99,8 +99,8 @@ case class BroadcastHashJoin( } } - override def upstream(): RDD[InternalRow] = { - streamedPlan.asInstanceOf[CodegenSupport].upstream() + override def upstreams(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].upstreams() } override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index e417079b61b4e..fabd2fbe1e0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter - /** * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, * will be much faster than building the right partition for every row in left RDD, it also diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index cd8a5670e2301..7ec4027188f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Performs an sort merge join of two child relations. @@ -34,7 +35,7 @@ case class SortMergeJoin( rightKeys: Seq[Expression], condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan) extends BinaryNode with CodegenSupport { override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -125,6 +126,246 @@ case class SortMergeJoin( }.toScala } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + private def createJoinKey( + ctx: CodegenContext, + row: String, + keys: Seq[Expression], + input: Seq[Attribute]): Seq[ExprCode] = { + ctx.INPUT_ROW = row + keys.map(BindReferences.bindReference(_, input).gen(ctx)) + } + + private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { + vars.zipWithIndex.map { case (ev, i) => + val value = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "") + val code = + s""" + |$value = ${ev.value}; + """.stripMargin + ExprCode(code, "false", value) + } + } + + private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => + s""" + |if (comp == 0) { + | comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)}; + |} + """.stripMargin.trim + } + s""" + |comp = 0; + |${comparisons.mkString("\n")} + """.stripMargin + } + + /** + * Generate a function to scan both left and right to find a match, returns the term for + * matched one row from left side and buffered rows from right side. + */ + private def genScanner(ctx: CodegenContext): (String, String) = { + // Create class member for next row from both sides. + val leftRow = ctx.freshName("leftRow") + ctx.addMutableState("InternalRow", leftRow, "") + val rightRow = ctx.freshName("rightRow") + ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the right key as class members so they could be used in next function call. + val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + + // A list to hold all matched rows from right side. + val matches = ctx.freshName("matches") + val clsName = classOf[java.util.ArrayList[InternalRow]].getName + ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + // Copy the left keys as class members so they could be used in next function call. + val matchedKeyVars = copyKeys(ctx, leftKeyVars) + + ctx.addNewFunction("findNextInnerJoinRows", + s""" + |private boolean findNextInnerJoinRows( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) { + | $leftRow = null; + | int comp = 0; + | while ($leftRow == null) { + | if (!leftIter.hasNext()) return false; + | $leftRow = (InternalRow) leftIter.next(); + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | $leftRow = null; + | continue; + | } + | if (!$matches.isEmpty()) { + | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | if (comp == 0) { + | return true; + | } + | $matches.clear(); + | } + | + | do { + | if ($rightRow == null) { + | if (!rightIter.hasNext()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return !$matches.isEmpty(); + | } + | $rightRow = (InternalRow) rightIter.next(); + | ${rightKeyTmpVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | $rightRow = null; + | continue; + | } + | ${rightKeyVars.map(_.code).mkString("\n")} + | } + | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | if (comp > 0) { + | $rightRow = null; + | } else if (comp < 0) { + | if (!$matches.isEmpty()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return true; + | } + | $leftRow = null; + | } else { + | $matches.add($rightRow.copy()); + | $rightRow = null;; + | } + | } while ($leftRow != null); + | } + | return false; // unreachable + |} + """.stripMargin) + + (leftRow, matches) + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. + */ + private def splitVarsByCondition( + attributes: Seq[Attribute], + variables: Seq[ExprCode]): (String, String) = { + if (condition.isDefined) { + val condRefs = condition.get.references + val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => + condRefs.contains(a) + } + val beforeCond = used.map(_._2.code).mkString("\n") + val afterCond = notUsed.map(_._2.code).mkString("\n") + (beforeCond, afterCond) + } else { + (variables.map(_.code).mkString("\n"), "") + } + } + + override def doProduce(ctx: CodegenContext): String = { + val leftInput = ctx.freshName("leftInput") + ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") + val rightInput = ctx.freshName("rightInput") + ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + + val (leftRow, matches) = genScanner(ctx) + + // Create variables for row from both sides. + val leftVars = createLeftVars(ctx, leftRow) + val rightRow = ctx.freshName("rightRow") + val rightVars = createRightVar(ctx, rightRow) + val resultVars = leftVars ++ rightVars + + // Check condition + ctx.currentVars = resultVars + val cond = if (condition.isDefined) { + BindReferences.bindReference(condition.get, output).gen(ctx) + } else { + ExprCode("", "false", "true") + } + // Split the code of creating variables based on whether it's used by condition or not. + val loaded = ctx.freshName("loaded") + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + + + val size = ctx.freshName("size") + val i = ctx.freshName("i") + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | int $size = $matches.size(); + | boolean $loaded = false; + | $leftBefore + | for (int $i = 0; $i < $size; $i ++) { + | InternalRow $rightRow = (InternalRow) $matches.get($i); + | $rightBefore + | ${cond.code} + | if (${cond.isNull} || !${cond.value}) continue; + | if (!$loaded) { + | $loaded = true; + | $leftAfter + | } + | $rightAfter + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | if (shouldStop()) return; + |} + """.stripMargin + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index bcac660a35a65..6d6cc0186a962 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -38,6 +38,7 @@ import org.apache.spark.util.Benchmark class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") .set("spark.sql.shuffle.partitions", "1") + .set("spark.sql.autoBroadcastJoinThreshold", "0") lazy val sc = SparkContext.getOrCreate(conf) lazy val sqlContext = SQLContext.getOrCreate(sc) @@ -187,6 +188,39 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("sort merge join") { + val N = 2 << 20 + runBenchmark("merge join", N) { + val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1") + val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + merge join codegen=false 1588 / 1880 1.3 757.1 1.0X + merge join codegen=true 1477 / 1531 1.4 704.2 1.1X + */ + + runBenchmark("sort merge join", N) { + val df1 = sqlContext.range(N) + .selectExpr(s"(id * 15485863) % ${N*10} as k1") + val df2 = sqlContext.range(N) + .selectExpr(s"(id * 15485867) % ${N*10} as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X + sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X + */ + } + ignore("rube") { val N = 5 << 20 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index a05a57c0f5072..0ef42f45e3b78 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -240,7 +240,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1") withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2") - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { val t1 = hiveContext.table("bucketed_table1") val t2 = hiveContext.table("bucketed_table2") val joined = t1.join(t2, joinCondition(t1, t2, joinColumns)) From 15e30155631d52e35ab8522584027ab350e5acb3 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 23 Feb 2016 15:31:17 -0800 Subject: [PATCH 1995/2089] [SPARK-6761][SQL][ML] Fixes to API and documentation of approximate quantiles ## What changes were proposed in this pull request? This continues thunterdb 's work on `approxQuantile` API. It changes the signature of `approxQuantile` from `(col: String, quantile: Double, epsilon: Double): Double` to `(col: String, probabilities: Array[Double], relativeError: Double): Array[Double]` and update API doc. It also improves the error message in tests and simplifies the merge algorithm for summaries. ## How was the this patch tested? Use the same unit tests as before. Closes #11325 Author: Timothy Hunter Author: Xiangrui Meng Closes #11332 from mengxr/SPARK-6761. --- .../spark/sql/DataFrameStatFunctions.scala | 36 +++- .../sql/execution/stat/StatFunctions.scala | 180 +++++++++--------- .../apache/spark/sql/DataFrameStatSuite.scala | 42 ++-- .../execution/stat/ApproxQuantileSuite.scala | 12 +- 4 files changed, 150 insertions(+), 120 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 7f110c4e7f34b..39a31ab02898f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -37,13 +37,37 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} final class DataFrameStatFunctions private[sql](df: DataFrame) { /** - * Calculate the approximate quantile of numerical column of a DataFrame. - * @param col the name of the column - * @param quantile the quantile number - * @return the approximate quantile + * Calculates the approximate quantiles of a numerical column of a DataFrame. + * + * The result of this algorithm has the following deterministic bound: + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient + * Online Computation of Quantile Summaries]] by Greenwald and Khanna. + * + * @param col the name of the numerical column + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * @return the approximate quantiles at the given probabilities + * + * @since 2.0.0 */ - def approxQuantile(col: String, quantile: Double, epsilon: Double): Double = { - StatFunctions.approxQuantile(df, col, quantile, epsilon) + def approxQuantile( + col: String, + probabilities: Array[Double], + relativeError: Double): Array[Double] = { + StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index eb056d555bbdb..26e4eda542d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.stat -import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -33,59 +32,37 @@ private[sql] object StatFunctions extends Logging { import QuantileSummaries.Stats /** - * Calculates the approximate quantile for the given column. - * - * If you need to compute multiple quantiles at once, you should use [[multipleApproxQuantiles]] - * - * Note on the target error. + * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass. * * The result of this algorithm has the following deterministic bound: - * if the DataFrame has N elements and if we request the quantile `phi` up to error `epsi`, - * then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank - * of `x` close to (phi * N). More precisely: - * - * floor((phi - epsi) * N) <= rank(x) <= ceil((phi + epsi) * N) - * - * Note on the algorithm used. + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, * - * This method implements a variation of the Greenwald-Khanna algorithm - * (with some speed optimizations). The algorithm was first present in the following article: - * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael - * and Khanna, Sanjeev. (http://dl.acm.org/citation.cfm?id=375670) + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). * - * The performance optimizations are detailed in the comments of the implementation. - * - * @param df the dataframe to estimate quantiles on - * @param col the name of the column - * @param quantile the target quantile of interest - * @param epsilon the target error. Should be >= 0. - * */ - def approxQuantile( - df: DataFrame, - col: String, - quantile: Double, - epsilon: Double = QuantileSummaries.defaultEpsilon): Double = { - require(quantile >= 0.0 && quantile <= 1.0, "Quantile must be in the range of (0.0, 1.0).") - val Seq(Seq(res)) = multipleApproxQuantiles(df, Seq(col), Seq(quantile), epsilon) - res - } - - /** - * Runs multiple quantile computations in a single pass, with the same target error. - * - * See [[approxQuantile)]] for more details on the approximation guarantees. + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient + * Online Computation of Quantile Summaries]] by Greenwald and Khanna. * * @param df the dataframe - * @param cols columns of the dataframe - * @param quantiles target quantiles to compute - * @param epsilon the precision to achieve + * @param cols numerical columns of the dataframe + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * * @return for each column, returns the requested approximations */ def multipleApproxQuantiles( df: DataFrame, cols: Seq[String], - quantiles: Seq[Double], - epsilon: Double): Seq[Seq[Double]] = { + probabilities: Seq[Double], + relativeError: Double): Seq[Seq[Double]] = { val columns: Seq[Column] = cols.map { colName => val field = df.schema(colName) require(field.dataType.isInstanceOf[NumericType], @@ -94,7 +71,7 @@ private[sql] object StatFunctions extends Logging { Column(Cast(Column(colName).expr, DoubleType)) } val emptySummaries = Array.fill(cols.size)( - new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, epsilon)) + new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError)) // Note that it works more or less by accident as `rdd.aggregate` is not a pure function: // this function returns the same array as given in the input (because `aggregate` reuses @@ -115,40 +92,49 @@ private[sql] object StatFunctions extends Logging { } val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) - summaries.map { summary => quantiles.map(summary.query) } + summaries.map { summary => probabilities.map(summary.query) } } /** * Helper class to compute approximate quantile summary. * This implementation is based on the algorithm proposed in the paper: * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael - * and Khanna, Sanjeev. (http://dl.acm.org/citation.cfm?id=375670) + * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670) * * In order to optimize for speed, it maintains an internal buffer of the last seen samples, * and only inserts them after crossing a certain size threshold. This guarantees a near-constant * runtime complexity compared to the original algorithm. * - * @param compressThreshold the compression threshold: after the internal buffer of statistics - * crosses this size, it attempts to compress the statistics together - * @param epsilon the target precision - * @param sampled a buffer of quantile statistics. See the G-K article for more details + * @param compressThreshold the compression threshold. + * After the internal buffer of statistics crosses this size, it attempts to compress the + * statistics together. + * @param relativeError the target relative error. + * It is uniform across the complete range of values. + * @param sampled a buffer of quantile statistics. + * See the G-K article for more details. * @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) * @param headSampled a buffer of latest samples seen so far */ class QuantileSummaries( val compressThreshold: Int, - val epsilon: Double, + val relativeError: Double, val sampled: ArrayBuffer[Stats] = ArrayBuffer.empty, private[stat] var count: Long = 0L, val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty) extends Serializable { import QuantileSummaries._ + /** + * Returns a summary with the given observation inserted into the summary. + * This method may either modify in place the current summary (and return the same summary, + * modified in place), or it may create a new summary from scratch it necessary. + * @param x the new observation to insert into the summary + */ def insert(x: Double): QuantileSummaries = { headSampled.append(x) if (headSampled.size >= defaultHeadSize) { - this.withHeadInserted + this.withHeadBufferInserted } else { this } @@ -162,7 +148,7 @@ private[sql] object StatFunctions extends Logging { * * @return a new quantile summary object. */ - private def withHeadInserted: QuantileSummaries = { + private def withHeadBufferInserted: QuantileSummaries = { if (headSampled.isEmpty) { return this } @@ -187,7 +173,7 @@ private[sql] object StatFunctions extends Logging { if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * epsilon * currentCount).toInt + math.floor(2 * relativeError * currentCount).toInt } val tuple = Stats(currentSample, 1, delta) @@ -200,67 +186,80 @@ private[sql] object StatFunctions extends Logging { newSamples.append(sampled(sampleIdx)) sampleIdx += 1 } - new QuantileSummaries(compressThreshold, epsilon, newSamples, currentCount) + new QuantileSummaries(compressThreshold, relativeError, newSamples, currentCount) } + /** + * Returns a new summary that compresses the summary statistics and the head buffer. + * + * This implements the COMPRESS function of the GK algorithm. It does not modify the object. + * + * @return a new summary object with compressed statistics + */ def compress(): QuantileSummaries = { // Inserts all the elements first - val inserted = this.withHeadInserted + val inserted = this.withHeadBufferInserted assert(inserted.headSampled.isEmpty) assert(inserted.count == count + headSampled.size) val compressed = - compressImmut(inserted.sampled, mergeThreshold = 2 * epsilon * inserted.count) - new QuantileSummaries(compressThreshold, epsilon, compressed, inserted.count) + compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) } + private def shallowCopy: QuantileSummaries = { + new QuantileSummaries(compressThreshold, relativeError, sampled, count, headSampled) + } + + /** + * Merges two (compressed) summaries together. + * + * Returns a new summary. + */ def merge(other: QuantileSummaries): QuantileSummaries = { + require(headSampled.isEmpty, "Current buffer needs to be compressed before merge") + require(other.headSampled.isEmpty, "Other buffer needs to be compressed before merge") if (other.count == 0) { - this + this.shallowCopy } else if (count == 0) { - other + other.shallowCopy } else { - // We rely on the fact that they are ordered to efficiently interleave them. - val thisSampled = sampled.toList - val otherSampled = other.sampled.toList - val res: ArrayBuffer[Stats] = ArrayBuffer.empty - - @tailrec - def mergeCurrent( - thisList: List[Stats], - otherList: List[Stats]): Unit = (thisList, otherList) match { - case (Nil, l) => - res.appendAll(l) - case (l, Nil) => - res.appendAll(l) - case (h1 :: t1, h2 :: t2) if h1.value > h2.value => - mergeCurrent(otherList, thisList) - case (h1 :: t1, l) => - // We know that h1.value <= all values in l - // TODO(thunterdb) do we need to adjust g and delta? - res.append(h1) - mergeCurrent(t1, l) - } - - mergeCurrent(thisSampled, otherSampled) - val comp = compressImmut(res, mergeThreshold = 2 * epsilon * count) - new QuantileSummaries(other.compressThreshold, other.epsilon, comp, other.count + count) + // Merge the two buffers. + // The GK algorithm is a bit unclear about it, but it seems there is no need to adjust the + // statistics during the merging: the invariants are still respected after the merge. + // TODO: could replace full sort by ordered merge, the two lists are known to be sorted + // already. + val res = (sampled ++ other.sampled).sortBy(_.value) + val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) + new QuantileSummaries( + other.compressThreshold, other.relativeError, comp, other.count + count) } } + /** + * Runs a query for a given quantile. + * The result follows the approximation guarantees detailed above. + * The query can only be run on a compressed summary: you need to call compress() before using + * it. + * + * @param quantile the target quantile + * @return + */ def query(quantile: Double): Double = { require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") + require(headSampled.isEmpty, + "Cannot operate on an uncompressed summary, call compress() first") - if (quantile <= epsilon) { + if (quantile <= relativeError) { return sampled.head.value } - if (quantile >= 1 - epsilon) { + if (quantile >= 1 - relativeError) { return sampled.last.value } // Target rank val rank = math.ceil(quantile * count).toInt - val targetError = math.ceil(epsilon * count) + val targetError = math.ceil(relativeError * count) // Minimum rank at current sample var minRank = 0 var i = 1 @@ -291,9 +290,10 @@ private[sql] object StatFunctions extends Logging { val defaultHeadSize: Int = 50000 /** - * The default value for epsilon. + * The default value for the relative error (1%). + * With this value, the best extreme percentiles that can be approximated are 1% and 99%. */ - val defaultEpsilon: Double = 0.01 + val defaultRelativeError: Double = 0.01 /** * Statisttics from the Greenwald-Khanna paper. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7f9229244b583..e865dbe6b5063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -126,22 +126,29 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("approximate quantile") { - val df = Seq.tabulate(1000)(i => (i, 2.0 * i)).toDF("singles", "doubles") - - val expected_1 = 500.0 - val expected_2 = 1600.0 + val n = 1000 + val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + val q1 = 0.5 + val q2 = 0.8 val epsilons = List(0.1, 0.05, 0.001) for (epsilon <- epsilons) { - val result1 = df.stat.approxQuantile("singles", 0.5, epsilon) - val result2 = df.stat.approxQuantile("doubles", 0.8, epsilon) - - val error_1 = 2 * 1000 * epsilon - val error_2 = 2 * 2000 * epsilon - - assert(math.abs(result1 - expected_1) < error_1) - assert(math.abs(result2 - expected_2) < error_2) + val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon) + val Array(double2) = df.stat.approxQuantile("doubles", Array(q2), epsilon) + // Also make sure there is no regression by computing multiple quantiles at once. + val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) + val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) + + val error_single = 2 * 1000 * epsilon + val error_double = 2 * 2000 * epsilon + + assert(math.abs(single1 - q1 * n) < error_single) + assert(math.abs(double2 - 2 * q2 * n) < error_double) + assert(math.abs(s1 - q1 * n) < error_single) + assert(math.abs(s2 - q2 * n) < error_single) + assert(math.abs(d1 - 2 * q1 * n) < error_double) + assert(math.abs(d2 - 2 * q2 * n) < error_double) } } @@ -296,9 +303,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Logging { // Turn on this test if you want to test the performance of approximate quantiles. - ignore("describe() should not be slowed down too much by quantiles") { + ignore("computing quantiles should not take much longer than describe()") { val df = sqlContext.range(5000000L).toDF("col1").cache() - def millis(f: => Any): Double = { + def seconds(f: => Any): Double = { // Do some warmup logDebug("warmup...") for (i <- 1 to 10) { @@ -314,15 +321,14 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin (end - start) / 1e9 } logDebug("execute done") - times.sum.toDouble / times.length.toDouble - + times.sum / times.length.toDouble } logDebug("*** Normal describe ***") - val t1 = millis { df.describe() } + val t1 = seconds { df.describe() } logDebug(s"T1 = $t1") logDebug("*** Just quantiles ***") - val t2 = millis { + val t2 = seconds { StatFunctions.multipleApproxQuantiles(df, Seq("col1"), Seq(0.1, 0.25, 0.5, 0.75, 0.9), 0.01) } logDebug(s"T1 = $t1, T2 = $t2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala index 6992b4c7235ad..0a989d026ce1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala @@ -46,12 +46,12 @@ class ApproxQuantileSuite extends SparkFunSuite { val approx = summary.query(quant) // The rank of the approximation. val rank = data.count(_ < approx) // has to be <, not <= to be exact - val lower = math.floor((quant - summary.epsilon) * data.size) - assert(rank >= lower, - s"approx_rank: $rank ! >= $lower, requested quantile = $quant") - val upper = math.ceil((quant + summary.epsilon) * data.size) - assert(rank <= upper, - s"approx_rank: $rank ! <= $upper, requested quantile = $quant") + val lower = math.floor((quant - summary.relativeError) * data.size) + val upper = math.ceil((quant + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) } for { From 8d29001dec5c3695721a76df3f70da50512ef28f Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 23 Feb 2016 15:42:58 -0800 Subject: [PATCH 1996/2089] [SPARK-13011] K-means wrapper in SparkR https://issues.apache.org/jira/browse/SPARK-13011 Author: Xusen Yin Closes #11124 from yinxusen/SPARK-13011. --- R/pkg/NAMESPACE | 4 +- R/pkg/R/generics.R | 8 ++ R/pkg/R/mllib.R | 74 ++++++++++++++++++- R/pkg/inst/tests/testthat/test_mllib.R | 28 +++++++ .../apache/spark/ml/clustering/KMeans.scala | 45 ++++++++++- .../apache/spark/ml/r/SparkRWrappers.scala | 52 ++++++++++++- 6 files changed, 203 insertions(+), 8 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f194a46303e0d..6a3d63f43f785 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -13,7 +13,9 @@ export("print.jobj") # MLlib integration exportMethods("glm", "predict", - "summary") + "summary", + "kmeans", + "fitted") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 2dba71abec689..ab61bce03df23 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind #' @export setGeneric("rbind", signature = "...") + +#' @rdname kmeans +#' @export +setGeneric("kmeans") + +#' @rdname fitted +#' @export +setGeneric("fitted") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 8d3b4388ae575..346f33d7dab2c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"), setMethod("summary", signature(object = "PipelineModel"), function(object, ...) { modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", object@model) + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", object@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", object@model) + "getModelCoefficients", object@model) if (modelName == "LinearRegressionModel") { devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelDevianceResiduals", object@model) @@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") rownames(coefficients) <- unlist(features) return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) - } else { + } else if (modelName == "LogisticRegressionModel") { coefficients <- as.matrix(unlist(coefficients)) colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) + } else if (modelName == "KMeansModel") { + modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansModelSize", object@model) + cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansCluster", object@model, "classes") + k <- unlist(modelSize)[1] + size <- unlist(modelSize)[-1] + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) + } else { + stop(paste("Unsupported model", modelName, sep = " ")) + } + }) + +#' Fit a k-means model +#' +#' Fit a k-means model, similarly to R's kmeans(). +#' +#' @param x DataFrame for training +#' @param centers Number of centers +#' @param iter.max Maximum iteration number +#' @param algorithm Algorithm choosen to fit the model +#' @return A fitted k-means model +#' @rdname kmeans +#' @export +#' @examples +#'\dontrun{ +#' model <- kmeans(x, centers = 2, algorithm="random") +#'} +setMethod("kmeans", signature(x = "DataFrame"), + function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) { + columnNames <- as.array(colnames(x)) + algorithm <- match.arg(algorithm) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf, + algorithm, iter.max, centers, columnNames) + return(new("PipelineModel", model = model)) + }) + +#' Get fitted result from a model +#' +#' Get fitted result from a model, similarly to R's fitted(). +#' +#' @param object A fitted MLlib model +#' @return DataFrame containing fitted values +#' @rdname fitted +#' @export +#' @examples +#'\dontrun{ +#' model <- kmeans(trainingData, 2) +#' fitted.model <- fitted(model) +#' showDF(fitted.model) +#'} +setMethod("fitted", signature(object = "PipelineModel"), + function(object, method = c("centers", "classes"), ...) { + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", object@model) + + if (modelName == "KMeansModel") { + method <- match.arg(method) + fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansCluster", object@model, method) + return(dataFrame(fittedResult)) + } else { + stop(paste("Unsupported model", modelName, sep = " ")) } }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 08099dd96a87b..595512e0e0d3a 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -113,3 +113,31 @@ test_that("summary works on base GLM models", { baseSummary <- summary(baseModel) expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) }) + +test_that("kmeans", { + newIris <- iris + newIris$Species <- NULL + training <- suppressWarnings(createDataFrame(sqlContext, newIris)) + + # Cache the DataFrame here to work around the bug SPARK-13178. + cache(training) + take(training, 1) + + model <- kmeans(x = training, centers = 2) + sample <- take(select(predict(model, training), "prediction"), 1) + expect_equal(typeof(sample$prediction), "integer") + expect_equal(sample$prediction, 1) + + # Test stats::kmeans is working + statsModel <- kmeans(x = newIris, centers = 2) + expect_equal(unique(statsModel$cluster), c(1, 2)) + + # Test fitted works on KMeans + fitted.model <- fitted(model) + expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) + + # Test summary works on KMeans + summary.model <- summary(model) + cluster <- summary.model$cluster + expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index b2292e20e2124..c6a3eac5877f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} @@ -135,6 +136,26 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + + private var trainingSummary: Option[KMeansSummary] = None + + private[clustering] def setSummary(summary: KMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: KMeansSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}", + new NullPointerException()) + } } @Since("1.6.0") @@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") ( .setSeed($(seed)) .setEpsilon($(tol)) val parentModel = algo.run(rdd) - val model = new KMeansModel(uid, parentModel) - copyValues(model.setParent(this)) + val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) + val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol)) + model.setSummary(summary) } @Since("1.5.0") @@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] { override def load(path: String): KMeans = super.load(path) } +class KMeansSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val featuresCol: String) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of each cluster. + */ + @Since("2.0.0") + lazy val size: Array[Int] = cluster.map { + case Row(clusterIdx: Int) => (clusterIdx, 1) + }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 551e75dc0a02d..d23e4fc9d1f57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.clustering.{KMeans, KMeansModel} +import org.apache.spark.ml.feature.{RFormula, VectorAssembler} import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.sql.DataFrame @@ -51,6 +52,22 @@ private[r] object SparkRWrappers { pipeline.fit(df) } + def fitKMeans( + df: DataFrame, + initMode: String, + maxIter: Double, + k: Double, + columns: Array[String]): PipelineModel = { + val assembler = new VectorAssembler().setInputCols(columns) + val kMeans = new KMeans() + .setInitMode(initMode) + .setMaxIter(maxIter.toInt) + .setK(k.toInt) + .setFeaturesCol(assembler.getOutputCol) + val pipeline = new Pipeline().setStages(Array(assembler, kMeans)) + pipeline.fit(df) + } + def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => { @@ -72,6 +89,8 @@ private[r] object SparkRWrappers { m.coefficients.toArray } } + case m: KMeansModel => + m.clusterCenters.flatMap(_.toArray) } } @@ -85,6 +104,31 @@ private[r] object SparkRWrappers { } } + def getKMeansModelSize(model: PipelineModel): Array[Int] = { + model.stages.last match { + case m: KMeansModel => Array(m.getK) ++ m.summary.size + case other => throw new UnsupportedOperationException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") + } + } + + def getKMeansCluster(model: PipelineModel, method: String): DataFrame = { + model.stages.last match { + case m: KMeansModel => + if (method == "centers") { + // Drop the assembled vector for easy-print to R side. + m.summary.predictions.drop(m.summary.featuresCol) + } else if (method == "classes") { + m.summary.cluster + } else { + throw new UnsupportedOperationException( + s"Method (centers or classes) required but $method found.") + } + case other => throw new UnsupportedOperationException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") + } + } + def getModelFeatures(model: PipelineModel): Array[String] = { model.stages.last match { case m: LinearRegressionModel => @@ -103,6 +147,10 @@ private[r] object SparkRWrappers { } else { attrs.attributes.get.map(_.name.get) } + case m: KMeansModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + attrs.attributes.get.map(_.name.get) } } @@ -112,6 +160,8 @@ private[r] object SparkRWrappers { "LinearRegressionModel" case m: LogisticRegressionModel => "LogisticRegressionModel" + case m: KMeansModel => + "KMeansModel" } } } From 230bbeaa614ed0ee87ecceece42355dd9a4bacb3 Mon Sep 17 00:00:00 2001 From: JeremyNixon Date: Tue, 23 Feb 2016 15:57:29 -0800 Subject: [PATCH 1997/2089] [SPARK-10759][ML] update cross validator with include_example This pull request uses {%include_example%} to add an example for the python cross validator to ml-guide. Author: JeremyNixon Closes #11240 from JeremyNixon/pipeline_include_example. --- docs/ml-guide.md | 5 +++++ examples/src/main/python/ml/cross_validator.py | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5900d665b37f0..a5a825f64e66e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -283,6 +283,11 @@ However, it is also a well-established method for choosing parameters which is m {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %} +
        + +{% include_example python/ml/cross_validator.py %} +
        + ## Example: model selection via train validation split diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index f0ca97c724940..5f0ef20218c4a 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -18,12 +18,14 @@ from __future__ import print_function from pyspark import SparkContext +# $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml.tuning import CrossValidator, ParamGridBuilder from pyspark.sql import Row, SQLContext +# $example off$ """ A simple example demonstrating model selection using CrossValidator. @@ -36,7 +38,7 @@ if __name__ == "__main__": sc = SparkContext(appName="CrossValidatorExample") sqlContext = SQLContext(sc) - + # $example on$ # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") training = sc.parallelize([(0, "a b c d e spark", 1.0), @@ -92,5 +94,6 @@ selected = prediction.select("id", "text", "probability", "prediction") for row in selected.collect(): print(row) + # $example off$ sc.stop() From e9533b419e3a87589313350310890ce0caf73dbb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 23 Feb 2016 18:19:22 -0800 Subject: [PATCH 1998/2089] [SPARK-13376] [SQL] improve column pruning ## What changes were proposed in this pull request? This PR mostly rewrite the ColumnPruning rule to support most of the SQL logical plans (except those for Dataset). ## How was the this patch tested? This is test by unit tests, also manually test with TPCDS Q78, which could prune all unused columns successfully, improved the performance by 78% (from 22s to 12s). Author: Davies Liu Closes #11256 from davies/fix_column_pruning. --- .../sql/catalyst/optimizer/Optimizer.scala | 128 ++++++++---------- .../optimizer/ColumnPruningSuite.scala | 128 +++++++++++++++++- .../optimizer/FilterPushdownSuite.scala | 80 ----------- .../columnar/InMemoryColumnarTableScan.scala | 7 +- 4 files changed, 187 insertions(+), 156 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1f05f2065c949..2b804976f3f11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(projects, output, child)) - if (e.outputSet -- a.references).nonEmpty => - val newOutput = output.filter(a.references.contains(_)) - val newProjects = projects.map { proj => - proj.zip(output).filter { case (e, a) => + // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) + case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => + p.copy( + child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + projectList = w.projectList.filter(p.references.contains), + windowExpressions = w.windowExpressions.filter(p.references.contains))) + case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + val newOutput = e.output.filter(a.references.contains(_)) + val newProjects = e.projections.map { proj => + proj.zip(e.output).filter { case (e, a) => newOutput.contains(a) }.unzip._1 } - a.copy(child = Expand(newProjects, newOutput, child)) + a.copy(child = Expand(newProjects, newOutput, grandChild)) + // TODO: support some logical plan for Dataset - case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- e.references -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) - - // Eliminate attributes that are not needed to calculate the specified aggregates. + // Prunes the unused columns from child of Aggregate/Window/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = Project(a.references.toSeq, child)) - - // Eliminate attributes that are not needed to calculate the Generate. + a.copy(child = prunedChild(child, a.references)) + case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => + w.copy(child = prunedChild(child, w.references)) + case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = Project(g.references.toSeq, g.child)) + g.copy(child = prunedChild(g.child, g.references)) + // Turn off `join` for Generate if no column from it's child is used case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - case p @ Project(projectList, g: Generate) if g.join => - val neededChildOutput = p.references -- g.generatorOutput ++ g.references - if (neededChildOutput == g.child.outputSet) { - p + // Eliminate unneeded attributes from right side of a LeftSemiJoin. + case j @ Join(left, right, LeftSemi, condition) => + j.copy(right = prunedChild(right, j.references)) + + // all the columns will be used to compare, so we can't prune them + case p @ Project(_, _: SetOperation) => p + case p @ Project(_, _: Distinct) => p + // Eliminate unneeded attributes from children of Union. + case p @ Project(_, u: Union) => + if ((u.outputSet -- p.references).nonEmpty) { + val firstChild = u.children.head + val newOutput = prunedChild(firstChild, p.references).output + // pruning the columns of all children based on the pruned first child. + val newChildren = u.children.map { p => + val selected = p.output.zipWithIndex.filter { case (a, i) => + newOutput.contains(firstChild.output(i)) + }.map(_._1) + Project(selected, p) + } + p.copy(child = u.withNewChildren(newChildren)) } else { - Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + p } - case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) - if (a.outputSet -- p.references).nonEmpty => - Project( - projectList, - Aggregate( - groupingExpressions, - aggregateExpressions.filter(e => p.references.contains(e)), - child)) - - // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => - // Collect the list of all references required either above or to evaluate the condition. - val allReferences: AttributeSet = - AttributeSet( - projectList.flatMap(_.references.iterator)) ++ - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) + // Can't prune the columns on LeafNode + case p @ Project(_, l: LeafNode) => p - Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) - - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case Join(left, right, LeftSemi, condition) => - // Collect the list of all references required to evaluate the condition. - val allReferences: AttributeSet = - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - Join(left, prunedChild(right, allReferences), LeftSemi, condition) - - // Push down project through limit, so that we may have chance to push it further. - case Project(projectList, Limit(exp, child)) => - Limit(exp, Project(projectList, child)) - - // Push down project if possible when the child is sort. - case p @ Project(projectList, s @ Sort(_, _, grandChild)) => - if (s.references.subsetOf(p.outputSet)) { - s.copy(child = Project(projectList, grandChild)) + // Eliminate no-op Projects + case p @ Project(projectList, child) if child.output == p.output => child + + // for all other logical plans that inherits the output from it's children + case p @ Project(_, child) => + val required = child.references ++ p.references + if ((child.inputSet -- required).nonEmpty) { + val newChildren = child.children.map(c => prunedChild(c, required)) + p.copy(child = child.withNewChildren(newChildren)) } else { - val neededReferences = s.references ++ p.references - if (neededReferences == grandChild.outputSet) { - // No column we can prune, return the original plan. - p - } else { - // Do not use neededReferences.toSeq directly, should respect grandChild's output order. - val newProjectList = grandChild.output.filter(neededReferences.contains) - p.copy(child = s.copy(child = Project(newProjectList, grandChild))) - } + p } - - // Eliminate no-op Projects - case Project(projectList, child) if child.output == projectList => child } /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + Project(c.output.filter(allReferences.contains), c) } else { c } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index c890fffc40167..715d01a3cd876 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Explode, Literal} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest { Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'aa.int, 'gid.int), - Project(Seq('c, 'a), + Project(Seq('a, 'c), input))).analyze comparePlans(optimized, expected) } + test("Column pruning on Filter") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze + val expected = + Project('a :: Nil, + Filter('c > Literal(0.0), + Project(Seq('a, 'c), input))).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("Column pruning on except/intersect/distinct") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Except(input, input)).analyze + comparePlans(Optimize.execute(query), query) + + val query2 = Project('a :: Nil, Intersect(input, input)).analyze + comparePlans(Optimize.execute(query2), query2) + val query3 = Project('a :: Nil, Distinct(input)).analyze + comparePlans(Optimize.execute(query3), query3) + } + + test("Column pruning on Project") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze + val expected = Project(Seq('a), input).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("column pruning for group") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val originalQuery = + testRelation + .groupBy('a)('a, count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .groupBy('a)('a as 'c, count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for Project(ne, Limit)") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("push down project past sort") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + } + + test("Column pruning on Union") { + val input1 = LocalRelation('a.int, 'b.string, 'c.double) + val input2 = LocalRelation('c.int, 'd.string, 'e.double) + val query = Project('b :: Nil, + Union(input1 :: input2 :: Nil)).analyze + val expected = Project('b :: Nil, + Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + comparePlans(Optimize.execute(query), expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 70b34cbb246ab..7d60862f5ae70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughJoin, PushPredicateThroughGenerate, PushPredicateThroughAggregate, - ColumnPruning, CollapseProject) :: Nil } @@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("column pruning for group") { - val originalQuery = - testRelation - .groupBy('a)('a, count('b)) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for group with alias") { - val originalQuery = - testRelation - .groupBy('a)('a as 'c, count('b)) - .select('c) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a as 'c).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for Project(ne, Limit)") { - val originalQuery = - testRelation - .select('a, 'b) - .limit(2) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .limit(2).analyze - - comparePlans(optimized, correctAnswer) - } - // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push down project past sort") { - val x = testRelation.subquery('x) - - // push down valid - val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze - - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) - - // push down invalid - val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) - } - - val optimized1 = Optimize.execute(originalQuery1.analyze) - val correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze - - comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) - } - test("push project and filter down into sample") { val x = testRelation.subquery('x) val originalQuery = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 4858140229d45..22d427808593e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation( @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, @transient private[sql] var _statistics: Statistics = null, private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends LogicalPlan with MultiInstanceRelation { + extends logical.LeafNode with MultiInstanceRelation { override def producedAttributes: AttributeSet = outputSet @@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated, batchStats) } - override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), From 86c852cf2ecdb2d348dad089640b46d1d97d7ead Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 23 Feb 2016 21:22:00 -0800 Subject: [PATCH 1999/2089] [SPARK-13431] [SQL] [test-maven] split keywords from ExpressionParser.g ## What changes were proposed in this pull request? This PR pull all the keywords (and some others) from ExpressionParser.g as KeywordParser.g, because ExpressionParser is too large to compile. ## How was the this patch tested? unit test, maven build Closes #11329 Author: Davies Liu Closes #11331 from davies/split_expr. --- .../sql/catalyst/parser/ExpressionParser.g | 195 -------------- .../spark/sql/catalyst/parser/KeywordParser.g | 244 ++++++++++++++++++ .../sql/catalyst/parser/SparkSqlParser.g | 2 +- 3 files changed, 245 insertions(+), 196 deletions(-) create mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index 10f2e2416bb64..13a6a2d276a57 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -398,198 +398,3 @@ precedenceOrExpression : precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* ; - - -booleanValue - : - KW_TRUE^ | KW_FALSE^ - ; - -booleanValueTok - : - KW_TRUE -> TOK_TRUE - | KW_FALSE -> TOK_FALSE - ; - -tableOrPartition - : - tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) - ; - -partitionSpec - : - KW_PARTITION - LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) - ; - -partitionVal - : - identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) - ; - -dropPartitionSpec - : - KW_PARTITION - LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) - ; - -dropPartitionVal - : - identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) - ; - -dropPartitionOperator - : - EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -sysFuncNames - : - KW_AND - | KW_OR - | KW_NOT - | KW_LIKE - | KW_IF - | KW_CASE - | KW_WHEN - | KW_TINYINT - | KW_SMALLINT - | KW_INT - | KW_BIGINT - | KW_FLOAT - | KW_DOUBLE - | KW_BOOLEAN - | KW_STRING - | KW_BINARY - | KW_ARRAY - | KW_MAP - | KW_STRUCT - | KW_UNIONTYPE - | EQUAL - | EQUAL_NS - | NOTEQUAL - | LESSTHANOREQUALTO - | LESSTHAN - | GREATERTHANOREQUALTO - | GREATERTHAN - | DIVIDE - | PLUS - | MINUS - | STAR - | MOD - | DIV - | AMPERSAND - | TILDE - | BITWISEOR - | BITWISEXOR - | KW_RLIKE - | KW_REGEXP - | KW_IN - | KW_BETWEEN - ; - -descFuncNames - : - (sysFuncNames) => sysFuncNames - | StringLiteral - | functionIdentifier - ; - -//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. -looseIdentifier - : - Identifier - | looseNonReserved -> Identifier[$looseNonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -identifier - : - Identifier - | nonReserved -> Identifier[$nonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -functionIdentifier -@init { gParent.pushMsg("function identifier", state); } -@after { gParent.popMsg(state); } - : - identifier (DOT identifier)? -> identifier+ - ; - -principalIdentifier -@init { gParent.pushMsg("identifier for principal spec", state); } -@after { gParent.popMsg(state); } - : identifier - | QuotedIdentifier - ; - -looseNonReserved - : nonReserved | KW_FROM | KW_TO - ; - -//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved -//Non reserved keywords are basically the keywords that can be used as identifiers. -//All the KW_* are automatically not only keywords, but also reserved keywords. -//That means, they can NOT be used as identifiers. -//If you would like to use them as identifiers, put them in the nonReserved list below. -//If you are not sure, please refer to the SQL2011 column in -//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html -nonReserved - : - KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS - | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS - | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY - | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY - | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE - | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT - | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE - | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR - | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG - | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE - | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY - | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER - | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE - | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED - | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED - | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED - | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET - | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR - | KW_WORK - | KW_TRANSACTION - | KW_WRITE - | KW_ISOLATION - | KW_LEVEL - | KW_SNAPSHOT - | KW_AUTOCOMMIT - | KW_ANTI - | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND - | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS -; - -//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. -sql11ReservedKeywordsUsedAsCastFunctionName - : - KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP - ; - -//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. -//We are planning to remove the following whole list after several releases. -//Thus, please do not change the following list unless you know what to do. -sql11ReservedKeywordsUsedAsIdentifier - : - KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN - | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE - | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT - | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL - | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION - | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT - | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE - | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH -//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. - | KW_REGEXP | KW_RLIKE - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g new file mode 100644 index 0000000000000..12cd5f54a0297 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g @@ -0,0 +1,244 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. +*/ + +parser grammar KeywordParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) + ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. +looseIdentifier + : + Identifier + | looseNonReserved -> Identifier[$looseNonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : + identifier (DOT identifier)? -> identifier+ + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +looseNonReserved + : nonReserved | KW_FROM | KW_TO + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. +//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI + | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND + | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. +sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 9aeea69cd2d9b..3010d3ae0dcac 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -26,7 +26,7 @@ ASTLabelType=CommonTree; backtrack=false; k=3; } -import SelectClauseParser, FromClauseParser, IdentifiersParser, ExpressionParser; +import SelectClauseParser, FromClauseParser, IdentifiersParser, KeywordParser, ExpressionParser; tokens { TOK_INSERT; From 831429170f39743e9f9dca92ce6f064eda67392b Mon Sep 17 00:00:00 2001 From: Rahul Tanwani Date: Wed, 24 Feb 2016 00:45:31 -0800 Subject: [PATCH 2000/2089] [MINOR][MAINTENANCE] Fix typo for the pull request template. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was the this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: Rahul Tanwani Closes #11343 from tanwanirahul/pull_request_template. --- .github/PULL_REQUEST_TEMPLATE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 39bcc84225b4a..989e95ccd0135 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -3,7 +3,7 @@ (Please fill in changes proposed in this fix) -## How was the this patch tested? +## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) From bcfd55fa982b24184c07fcd4ccdd55dcf6465bf4 Mon Sep 17 00:00:00 2001 From: Daniel Jalova Date: Wed, 24 Feb 2016 12:15:11 +0000 Subject: [PATCH 2001/2089] [SPARK-12759][Core][Spark should fail fast if --executor-memory is too small for spark to start] Added an exception to be thrown in UnifiedMemoryManager.scala if the configuration given for executor memory is too low. Also modified the exception message thrown when driver memory is too low. This patch was tested manually by passing in config options to Spark shell. I also added a test in UnifiedMemoryManagerSuite.scala Author: Daniel Jalova Closes #11255 from djalova/SPARK-12759. --- .../spark/memory/UnifiedMemoryManager.scala | 12 ++++++++++- .../memory/UnifiedMemoryManagerSuite.scala | 20 ++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index a3321e3f179f6..6c57c98ea5c59 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -183,7 +183,17 @@ object UnifiedMemoryManager { val minSystemMemory = reservedMemory * 1.5 if (systemMemory < minSystemMemory) { throw new IllegalArgumentException(s"System memory $systemMemory must " + - s"be at least $minSystemMemory. Please use a larger heap size.") + s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " + + s"option or spark.driver.memory in Spark configuration.") + } + // SPARK-12759 Check executor memory to fail fast if memory is insufficient + if (conf.contains("spark.executor.memory")) { + val executorMemory = conf.getSizeAsBytes("spark.executor.memory") + if (executorMemory < minSystemMemory) { + throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " + + s"$minSystemMemory. Please increase executor memory using the " + + s"--executor-memory option or spark.executor.memory in Spark configuration.") + } } val usableMemory = systemMemory - reservedMemory val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 0c4359c3c2cd5..9686c6621b465 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -227,7 +227,25 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val exception = intercept[IllegalArgumentException] { UnifiedMemoryManager(conf2, numCores = 1) } - assert(exception.getMessage.contains("larger heap size")) + assert(exception.getMessage.contains("increase heap size")) + } + + test("insufficient executor memory") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + + // Try using an executor memory that's too small + val conf2 = conf.clone().set("spark.executor.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("increase executor memory")) } test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { From 89301818334185fd4f9881e5c0b123be94018e76 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Feb 2016 07:05:20 -0800 Subject: [PATCH 2002/2089] [SPARK-13472] [SPARKR] Fix unstable Kmeans test in R JIRA: https://issues.apache.org/jira/browse/SPARK-13472 ## What changes were proposed in this pull request? One Kmeans test in R is unstable and sometimes fails. We should fix it. ## How was this patch tested? Unit test is modified in this PR. Author: Liang-Chi Hsieh Closes #11345 from viirya/fix-kmeans-r-test and squashes the following commits: f959f61 [Liang-Chi Hsieh] Sort resulted clusters. --- R/pkg/inst/tests/testthat/test_mllib.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 595512e0e0d3a..af84a0abcf94d 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -130,7 +130,7 @@ test_that("kmeans", { # Test stats::kmeans is working statsModel <- kmeans(x = newIris, centers = 2) - expect_equal(unique(statsModel$cluster), c(1, 2)) + expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) # Test fitted works on KMeans fitted.model <- fitted(model) From f3739869973ba4285196a61775d891292b8e282b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Feb 2016 10:22:40 -0800 Subject: [PATCH 2003/2089] [SPARK-13383][SQL] Keep broadcast hint after column pruning JIRA: https://issues.apache.org/jira/browse/SPARK-13383 ## What changes were proposed in this pull request? When we do column pruning in Optimizer, we put additional Project on top of a logical plan. However, when we already wrap a BroadcastHint on a logical plan, the added Project will hide BroadcastHint after later execution. We should take care of BroadcastHint when we do column pruning. ## How was the this patch tested? Unit test is added. Author: Liang-Chi Hsieh Closes #11260 from viirya/keep-broadcasthint. --- .../plans/logical/basicOperators.scala | 4 +++ ...uite.scala => JoinOptimizationSuite.scala} | 35 ++++++++++++++++--- .../spark/sql/execution/SparkStrategies.scala | 12 ++++--- 3 files changed, 42 insertions(+), 9 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{JoinOrderSuite.scala => JoinOptimizationSuite.scala} (77%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index af43cb3786f8c..5d2a65b716b24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -332,6 +332,10 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + // We manually set statistics of BroadcastHint to smallest value to make sure + // the plan wrapped by BroadcastHint will be considered to broadcast later. + override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala similarity index 77% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a5b487bcc822f..d482519827cc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -class JoinOrderSuite extends PlanTest { +class JoinOptimizationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, BooleanSimplification, @@ -92,4 +92,31 @@ class JoinOrderSuite extends PlanTest { comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } + + test("broadcasthint sets relation statistics to smallest value") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + + val optimized = Optimize.execute(query) + + val expected = + Project(Seq($"x.key", $"y.key"), + Join( + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint( + Project(Seq($"y.key"), SubqueryAlias("y", input))), + Inner, None)).analyze + + comparePlans(optimized, expected) + + val broadcastChildren = optimized.collect { + case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r + } + assert(broadcastChildren.size == 1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7347156398674..247eb054a8beb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -81,11 +81,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ object CanBroadcast { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { - case BroadcastHint(p) => Some(p) - case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) - case _ => None + def unapply(plan: LogicalPlan): Option[LogicalPlan] = { + if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && + plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + Some(plan) + } else { + None + } } } From 382b27babf7771b724f7abff78195a858631d138 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Feb 2016 11:58:12 -0800 Subject: [PATCH 2004/2089] Revert "[SPARK-13383][SQL] Keep broadcast hint after column pruning" This reverts commit f3739869973ba4285196a61775d891292b8e282b. --- .../plans/logical/basicOperators.scala | 4 --- ...zationSuite.scala => JoinOrderSuite.scala} | 35 +++---------------- .../spark/sql/execution/SparkStrategies.scala | 12 +++---- 3 files changed, 9 insertions(+), 42 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{JoinOptimizationSuite.scala => JoinOrderSuite.scala} (77%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5d2a65b716b24..af43cb3786f8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -332,10 +332,6 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - - // We manually set statistics of BroadcastHint to smallest value to make sure - // the plan wrapped by BroadcastHint will be considered to broadcast later. - override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala similarity index 77% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala index d482519827cc1..a5b487bcc822f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins -import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class JoinOptimizationSuite extends PlanTest { +class JoinOrderSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", FixedPoint(100), + Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, BooleanSimplification, @@ -92,31 +92,4 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } - - test("broadcasthint sets relation statistics to smallest value") { - val input = LocalRelation('key.int, 'value.string) - - val query = - Project(Seq($"x.key", $"y.key"), - Join( - SubqueryAlias("x", input), - BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze - - val optimized = Optimize.execute(query) - - val expected = - Project(Seq($"x.key", $"y.key"), - Join( - Project(Seq($"x.key"), SubqueryAlias("x", input)), - BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input))), - Inner, None)).analyze - - comparePlans(optimized, expected) - - val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r - } - assert(broadcastChildren.size == 1) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 247eb054a8beb..7347156398674 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -81,13 +81,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ object CanBroadcast { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = { - if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && - plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { - Some(plan) - } else { - None - } + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { + case BroadcastHint(p) => Some(p) + case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) + case _ => None } } From d563c8fa01cfaebb5899ff7970115d0f2e64e8d5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Feb 2016 11:58:32 -0800 Subject: [PATCH 2005/2089] Revert "[SPARK-13376] [SQL] improve column pruning" This reverts commit e9533b419e3a87589313350310890ce0caf73dbb. --- .../sql/catalyst/optimizer/Optimizer.scala | 128 ++++++++++-------- .../optimizer/ColumnPruningSuite.scala | 128 +----------------- .../optimizer/FilterPushdownSuite.scala | 80 +++++++++++ .../columnar/InMemoryColumnarTableScan.scala | 7 +- 4 files changed, 156 insertions(+), 187 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2b804976f3f11..1f05f2065c949 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -313,85 +313,97 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Window/Expand - case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => - p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) - case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => - p.copy( - child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => - p.copy(child = w.copy( - projectList = w.projectList.filter(p.references.contains), - windowExpressions = w.windowExpressions.filter(p.references.contains))) - case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => - val newOutput = e.output.filter(a.references.contains(_)) - val newProjects = e.projections.map { proj => - proj.zip(e.output).filter { case (e, a) => + case a @ Aggregate(_, _, e @ Expand(projects, output, child)) + if (e.outputSet -- a.references).nonEmpty => + val newOutput = output.filter(a.references.contains(_)) + val newProjects = projects.map { proj => + proj.zip(output).filter { case (e, a) => newOutput.contains(a) }.unzip._1 } - a.copy(child = Expand(newProjects, newOutput, grandChild)) - // TODO: support some logical plan for Dataset + a.copy(child = Expand(newProjects, newOutput, child)) - // Prunes the unused columns from child of Aggregate/Window/Expand/Generate + case a @ Aggregate(_, _, e @ Expand(_, _, child)) + if (child.outputSet -- e.references -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) + + // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = prunedChild(child, a.references)) - case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => - w.copy(child = prunedChild(child, w.references)) - case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => - e.copy(child = prunedChild(child, e.references)) + a.copy(child = Project(a.references.toSeq, child)) + + // Eliminate attributes that are not needed to calculate the Generate. case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = prunedChild(g.child, g.references)) + g.copy(child = Project(g.references.toSeq, g.child)) - // Turn off `join` for Generate if no column from it's child is used case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case j @ Join(left, right, LeftSemi, condition) => - j.copy(right = prunedChild(right, j.references)) - - // all the columns will be used to compare, so we can't prune them - case p @ Project(_, _: SetOperation) => p - case p @ Project(_, _: Distinct) => p - // Eliminate unneeded attributes from children of Union. - case p @ Project(_, u: Union) => - if ((u.outputSet -- p.references).nonEmpty) { - val firstChild = u.children.head - val newOutput = prunedChild(firstChild, p.references).output - // pruning the columns of all children based on the pruned first child. - val newChildren = u.children.map { p => - val selected = p.output.zipWithIndex.filter { case (a, i) => - newOutput.contains(firstChild.output(i)) - }.map(_._1) - Project(selected, p) - } - p.copy(child = u.withNewChildren(newChildren)) - } else { + case p @ Project(projectList, g: Generate) if g.join => + val neededChildOutput = p.references -- g.generatorOutput ++ g.references + if (neededChildOutput == g.child.outputSet) { p + } else { + Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) } - // Can't prune the columns on LeafNode - case p @ Project(_, l: LeafNode) => p + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) + if (a.outputSet -- p.references).nonEmpty => + Project( + projectList, + Aggregate( + groupingExpressions, + aggregateExpressions.filter(e => p.references.contains(e)), + child)) - // Eliminate no-op Projects - case p @ Project(projectList, child) if child.output == p.output => child - - // for all other logical plans that inherits the output from it's children - case p @ Project(_, child) => - val required = child.references ++ p.references - if ((child.inputSet -- required).nonEmpty) { - val newChildren = child.children.map(c => prunedChild(c, required)) - p.copy(child = child.withNewChildren(newChildren)) + // Eliminate unneeded attributes from either side of a Join. + case Project(projectList, Join(left, right, joinType, condition)) => + // Collect the list of all references required either above or to evaluate the condition. + val allReferences: AttributeSet = + AttributeSet( + projectList.flatMap(_.references.iterator)) ++ + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) + + /** Applies a projection only when the child is producing unnecessary attributes */ + def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) + + Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) + + // Eliminate unneeded attributes from right side of a LeftSemiJoin. + case Join(left, right, LeftSemi, condition) => + // Collect the list of all references required to evaluate the condition. + val allReferences: AttributeSet = + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) + + Join(left, prunedChild(right, allReferences), LeftSemi, condition) + + // Push down project through limit, so that we may have chance to push it further. + case Project(projectList, Limit(exp, child)) => + Limit(exp, Project(projectList, child)) + + // Push down project if possible when the child is sort. + case p @ Project(projectList, s @ Sort(_, _, grandChild)) => + if (s.references.subsetOf(p.outputSet)) { + s.copy(child = Project(projectList, grandChild)) } else { - p + val neededReferences = s.references ++ p.references + if (neededReferences == grandChild.outputSet) { + // No column we can prune, return the original plan. + p + } else { + // Do not use neededReferences.toSeq directly, should respect grandChild's output order. + val newProjectList = grandChild.output.filter(neededReferences.contains) + p.copy(child = s.copy(child = Project(newProjectList, grandChild))) + } } + + // Eliminate no-op Projects + case Project(projectList, child) if child.output == projectList => child } /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(c.output.filter(allReferences.contains), c) + Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { c } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 715d01a3cd876..c890fffc40167 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Explode, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -120,134 +119,11 @@ class ColumnPruningSuite extends PlanTest { Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'aa.int, 'gid.int), - Project(Seq('a, 'c), + Project(Seq('c, 'a), input))).analyze comparePlans(optimized, expected) } - test("Column pruning on Filter") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze - val expected = - Project('a :: Nil, - Filter('c > Literal(0.0), - Project(Seq('a, 'c), input))).analyze - comparePlans(Optimize.execute(query), expected) - } - - test("Column pruning on except/intersect/distinct") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Except(input, input)).analyze - comparePlans(Optimize.execute(query), query) - - val query2 = Project('a :: Nil, Intersect(input, input)).analyze - comparePlans(Optimize.execute(query2), query2) - val query3 = Project('a :: Nil, Distinct(input)).analyze - comparePlans(Optimize.execute(query3), query3) - } - - test("Column pruning on Project") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze - val expected = Project(Seq('a), input).analyze - comparePlans(Optimize.execute(query), expected) - } - - test("column pruning for group") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val originalQuery = - testRelation - .groupBy('a)('a, count('b)) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for group with alias") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - - val originalQuery = - testRelation - .groupBy('a)('a as 'c, count('b)) - .select('c) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a as 'c).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for Project(ne, Limit)") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - - val originalQuery = - testRelation - .select('a, 'b) - .limit(2) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .limit(2).analyze - - comparePlans(optimized, correctAnswer) - } - - test("push down project past sort") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val x = testRelation.subquery('x) - - // push down valid - val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze - - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) - - // push down invalid - val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) - } - - val optimized1 = Optimize.execute(originalQuery1.analyze) - val correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze - - comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) - } - - test("Column pruning on Union") { - val input1 = LocalRelation('a.int, 'b.string, 'c.double) - val input2 = LocalRelation('c.int, 'd.string, 'e.double) - val query = Project('b :: Nil, - Union(input1 :: input2 :: Nil)).analyze - val expected = Project('b :: Nil, - Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze - comparePlans(Optimize.execute(query), expected) - } - // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 7d60862f5ae70..70b34cbb246ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -41,6 +41,7 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughJoin, PushPredicateThroughGenerate, PushPredicateThroughAggregate, + ColumnPruning, CollapseProject) :: Nil } @@ -64,6 +65,52 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("column pruning for group") { + val originalQuery = + testRelation + .groupBy('a)('a, count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val originalQuery = + testRelation + .groupBy('a)('a as 'c, count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for Project(ne, Limit)") { + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -557,6 +604,39 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + test("push down project past sort") { + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + } + test("push project and filter down into sample") { val x = testRelation.subquery('x) val originalQuery = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 22d427808593e..4858140229d45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -64,7 +63,7 @@ private[sql] case class InMemoryRelation( @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, @transient private[sql] var _statistics: Statistics = null, private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends logical.LeafNode with MultiInstanceRelation { + extends LogicalPlan with MultiInstanceRelation { override def producedAttributes: AttributeSet = outputSet @@ -185,6 +184,8 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated, batchStats) } + override def children: Seq[LogicalPlan] = Seq.empty + override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), From 65805ab6eaa6b0ed1dc7f6cbd3c3368eb750b375 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Feb 2016 12:03:45 -0800 Subject: [PATCH 2006/2089] Revert "Revert "[SPARK-13383][SQL] Keep broadcast hint after column pruning"" This reverts commit 382b27babf7771b724f7abff78195a858631d138. --- .../plans/logical/basicOperators.scala | 4 +++ ...uite.scala => JoinOptimizationSuite.scala} | 35 ++++++++++++++++--- .../spark/sql/execution/SparkStrategies.scala | 12 ++++--- 3 files changed, 42 insertions(+), 9 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{JoinOrderSuite.scala => JoinOptimizationSuite.scala} (77%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index af43cb3786f8c..5d2a65b716b24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -332,6 +332,10 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + // We manually set statistics of BroadcastHint to smallest value to make sure + // the plan wrapped by BroadcastHint will be considered to broadcast later. + override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala similarity index 77% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a5b487bcc822f..d482519827cc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -class JoinOrderSuite extends PlanTest { +class JoinOptimizationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, BooleanSimplification, @@ -92,4 +92,31 @@ class JoinOrderSuite extends PlanTest { comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } + + test("broadcasthint sets relation statistics to smallest value") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + + val optimized = Optimize.execute(query) + + val expected = + Project(Seq($"x.key", $"y.key"), + Join( + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint( + Project(Seq($"y.key"), SubqueryAlias("y", input))), + Inner, None)).analyze + + comparePlans(optimized, expected) + + val broadcastChildren = optimized.collect { + case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r + } + assert(broadcastChildren.size == 1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7347156398674..247eb054a8beb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -81,11 +81,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ object CanBroadcast { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { - case BroadcastHint(p) => Some(p) - case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) - case _ => None + def unapply(plan: LogicalPlan): Option[LogicalPlan] = { + if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && + plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + Some(plan) + } else { + None + } } } From f92f53faeea020d80638a06752d69ca7a949cdeb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Feb 2016 12:25:02 -0800 Subject: [PATCH 2007/2089] Revert "[SPARK-13321][SQL] Support nested UNION in parser" This reverts commit 55d6fdf22d1d6379180ac09f364c38982897d9ff. --- .../sql/catalyst/parser/SparkSqlParser.g | 20 ------ .../spark/sql/catalyst/CatalystQlSuite.scala | 64 ------------------- 2 files changed, 84 deletions(-) diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 3010d3ae0dcac..1db3aed65815d 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -2320,26 +2320,6 @@ regularBody[boolean topLevel] ) | selectStatement[topLevel] - | - (LPAREN selectStatement0[true]) => nestedSetOpSelectStatement[topLevel] - ; - -nestedSetOpSelectStatement[boolean topLevel] - : - ( - LPAREN s=selectStatement0[topLevel] RPAREN -> {$s.tree} - ) - (set=setOpSelectStatement[$nestedSetOpSelectStatement.tree, topLevel]) - -> {set == null}? - {$nestedSetOpSelectStatement.tree} - -> {$set.tree} - ; - -selectStatement0[boolean topLevel] - : - (selectStatement[true]) => selectStatement[topLevel] - | - (nestedSetOpSelectStatement[true]) => nestedSetOpSelectStatement[topLevel] ; selectStatement[boolean topLevel] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 812aa5acd8e56..53a8d6e53e38a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -202,70 +202,6 @@ class CatalystQlSuite extends PlanTest { "from windowData") } - test("nesting UNION") { - val parsed = parser.parsePlan( - """ - |SELECT `u_1`.`id` FROM (((SELECT `t0`.`id` FROM `default`.`t0`) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL - |(SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - """.stripMargin) - - val expected = Project( - UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, - SubqueryAlias("u_1", - Union( - Union( - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None)), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))))) - - comparePlans(parsed, expected) - - val parsedSame = parser.parsePlan( - """ - |SELECT `u_1`.`id` FROM ((SELECT `t0`.`id` FROM `default`.`t0`) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`) UNION ALL - |(SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - """.stripMargin) - - comparePlans(parsedSame, expected) - - val parsed2 = parser.parsePlan( - """ - |SELECT `u_1`.`id` FROM ((((SELECT `t0`.`id` FROM `default`.`t0`) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) UNION ALL - |(SELECT `t0`.`id` FROM `default`.`t0`)) - |UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - """.stripMargin) - - val expected2 = Project( - UnresolvedAlias(UnresolvedAttribute("u_1.id"), None) :: Nil, - SubqueryAlias("u_1", - Union( - Union( - Union( - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None)), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))), - Project( - UnresolvedAlias(UnresolvedAttribute("t0.id"), None) :: Nil, - UnresolvedRelation(TableIdentifier("t0", Some("default")), None))))) - - comparePlans(parsed2, expected2) - } - test("subquery") { parser.parsePlan("select (select max(b) from s) ss from t") parser.parsePlan("select * from t where a = (select b from s)") From a60f91284ceee64de13f04559ec19c13a820a133 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Feb 2016 12:44:54 -0800 Subject: [PATCH 2008/2089] [SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code ## What changes were proposed in this pull request? When we pass a Python function to JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it's annoying to pass around so many parameters at many places. This PR abstract python function along with its context, to simplify some pyspark code and make the logic more clear. ## How was the this patch tested? by existing unit tests. Author: Wenchen Fan Closes #11342 from cloud-fan/python-clean. --- .../apache/spark/api/python/PythonRDD.scala | 37 +++++++++++-------- python/pyspark/rdd.py | 23 +++++++----- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/functions.py | 8 ++-- .../apache/spark/sql/UDFRegistration.scala | 8 ++-- .../python/BatchPythonEvaluation.scala | 8 +--- .../sql/execution/python/PythonUDF.scala | 13 ++----- .../python/UserDefinedPythonFunction.scala | 15 ++------ 8 files changed, 51 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f12e2dfafa19d..05d1c31a08f22 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} private[spark] class PythonRDD( parent: RDD[_], - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]]) + func: PythonFunction, + preservePartitoning: Boolean) extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) @@ -64,29 +58,37 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = new PythonRunner( - command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, - bufferSize, reuse_worker) + val runner = new PythonRunner(func, bufferSize, reuse_worker) runner.compute(firstParent.iterator(split, context), split.index, context) } } - /** - * A helper class to run Python UDFs in Spark. + * A wrapper for a Python function, contains all necessary context to run the function in Python + * runner. */ -private[spark] class PythonRunner( +private[spark] case class PythonFunction( command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + func: PythonFunction, bufferSize: Int, reuse_worker: Boolean) extends Logging { + private val envVars = func.envVars + private val pythonExec = func.pythonExec + private val accumulator = func.accumulator + def compute( inputIterator: Iterator[_], partitionIndex: Int, @@ -225,6 +227,11 @@ private[spark] class PythonRunner( @volatile private var _exception: Exception = null + private val pythonVer = func.pythonVer + private val pythonIncludes = func.pythonIncludes + private val broadcastVars = func.broadcastVars + private val command = func.command + setDaemon(true) /** Contains the exception thrown while writing the parent iterator to the Python process. */ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4eaf589ad5d46..37574cea0b687 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2309,7 +2309,7 @@ def toLocalIterator(self): yield row -def _prepare_for_python_RDD(sc, command, obj=None): +def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) @@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes +def _wrap_function(sc, func, deserializer, serializer, profiler=None): + assert deserializer, "deserializer should not be empty" + assert serializer, "serializer should not be empty" + command = (func, profiler, deserializer, serializer) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class PipelinedRDD(RDD): """ @@ -2390,14 +2399,10 @@ def _jrdd(self): else: profiler = None - command = (self.func, profiler, self._prev_jrdd_deserializer, - self._jrdd_deserializer) - pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) - python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_cmd), - env, includes, self.preservesPartitioning, - self.ctx.pythonExec, self.ctx.pythonVer, - bvars, self.ctx._javaAccumulator) + wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, + self._jrdd_deserializer, profiler) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, + self.preservesPartitioning) self._jrdd_val = python_rdd.asJavaRDD() if profiler: diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 89bf1443a68a1..87e32c04eac1c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -29,7 +29,7 @@ from py4j.protocol import Py4JError from pyspark import since -from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6894c2733828c..b30cc6799eb97 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,7 +25,7 @@ from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import _wrap_function, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq @@ -1645,16 +1645,14 @@ def _create_judf(self, name): f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) sc = SparkContext.getOrCreate() - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + wrapped_func = _wrap_function(sc, func, ser, ser) ctx = SQLContext.getOrCreate(sc) jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer, - broadcast_vars, sc._javaAccumulator, jdt) + name, wrapped_func, jdt) return judf def __del__(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ecfc170bee3ad..de01cbcb0e1b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { s""" | Registering new PythonUDF: | name: $name - | command: ${udf.command.toSeq} - | envVars: ${udf.envVars} - | pythonIncludes: ${udf.pythonIncludes} - | pythonExec: ${udf.pythonExec} + | command: ${udf.func.command.toSeq} + | envVars: ${udf.func.envVars} + | pythonIncludes: ${udf.func.pythonIncludes} + | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} """.stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index 00df0195279c3..c65a7bcff8503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: // Output iterator for results from Python. val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, + udf.func, bufferSize, reuseWorker ).compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9aff0be716d4b..0aa2785cb6b9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.{Accumulator, Logging} -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.Logging +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} import org.apache.spark.sql.types.DataType @@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType */ case class PythonUDF( name: String, - command: Array[Byte], - envVars: java.util.Map[String, String], - pythonIncludes: java.util.List[String], - pythonExec: String, - pythonVer: String, - broadcastVars: java.util.List[Broadcast[PythonBroadcast]], - accumulator: Accumulator[java.util.List[Array[Byte]]], + func: PythonFunction, dataType: DataType, children: Seq[Expression]) extends Expression with Unevaluable with NonSQLExpression with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 79ac1c85c0be5..d301874c223d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.Accumulator -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.Column import org.apache.spark.sql.types.DataType @@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType */ case class UserDefinedPythonFunction( name: String, - command: Array[Byte], - envVars: java.util.Map[String, String], - pythonIncludes: java.util.List[String], - pythonExec: String, - pythonVer: String, - broadcastVars: java.util.List[Broadcast[PythonBroadcast]], - accumulator: Accumulator[java.util.List[Array[Byte]]], + func: PythonFunction, dataType: DataType) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) + PythonUDF(name, func, dataType, e) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ From 5289837a72169f3f53be4bd576490b555828e03e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 24 Feb 2016 13:30:23 -0800 Subject: [PATCH 2009/2089] [HOT][TEST] Disable a Test that Requires Nested Union Support. ## What changes were proposed in this pull request? Since "[SPARK-13321][SQL] Support nested UNION in parser" is reverted, we need to disable the test case that requires this PR. Thanks! rxin yhuai marmbrus ## How was this patch tested? N/A Author: gatorsmile Closes #11352 from gatorsmile/disableTestCase. --- .../scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index fa78f5a425de3..f9a5cf4d4781e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -130,7 +130,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // FROM (((SELECT `t0`.`id` FROM `default`.`t0`) // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - test("three-child union") { + ignore("three-child union") { checkHiveQl( """ |SELECT id FROM parquet_t0 From bc353805bd98243872d520e05fa6659da57170bf Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 24 Feb 2016 13:34:53 -0800 Subject: [PATCH 2010/2089] [SPARK-13475][TESTS][SQL] HiveCompatibilitySuite should still run in PR builder even if a PR only changes sql/core ## What changes were proposed in this pull request? `HiveCompatibilitySuite` should still run in PR build even if a PR only changes sql/core. So, I am going to remove `ExtendedHiveTest` annotation from `HiveCompatibilitySuite`. https://issues.apache.org/jira/browse/SPARK-13475 Author: Yin Huai Closes #11351 from yhuai/SPARK-13475. --- .../spark/sql/hive/execution/HiveCompatibilitySuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9097c1a1d3117..973ba6bf38b3e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -25,12 +25,10 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.tags.ExtendedHiveTest /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( From cbb0b65ad53f2642fab8aad3ea115375c53b6eed Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 24 Feb 2016 16:13:55 -0800 Subject: [PATCH 2011/2089] [SPARK-13383][SQL] Fix test ## What changes were proposed in this pull request? Reverting SPARK-13376 (https://github.com/apache/spark/commit/d563c8fa01cfaebb5899ff7970115d0f2e64e8d5) affects the test added by SPARK-13383. So, I am fixing the test. Author: Yin Huai Closes #11355 from yhuai/SPARK-13383-fix-test. --- .../spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index d482519827cc1..1ab53a12572a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -108,8 +108,7 @@ class JoinOptimizationSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input))), + Project(Seq($"y.key"), BroadcastHint(SubqueryAlias("y", input))), Inner, None)).analyze comparePlans(optimized, expected) From 5a7af9e7ac85e04aa4a420bc2887207bfa18f792 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 24 Feb 2016 17:16:45 -0800 Subject: [PATCH 2012/2089] [SPARK-13250] [SQL] Update PhysicallRDD to convert to UnsafeRow if using the vectorized scanner. Some parts of the engine rely on UnsafeRow which the vectorized parquet scanner does not want to produce. This add a conversion in Physical RDD. In the case where codegen is used (and the scan is the start of the pipeline), there is no requirement to use UnsafeRow. This patch adds update PhysicallRDD to support codegen, which eliminates the need for the UnsafeRow conversion in all cases. The result of these changes for TPCDS-Q19 at the 10gb sf reduces the query time from 9.5 seconds to 6.5 seconds. Author: Nong Li Closes #11141 from nongli/spark-13250. --- python/pyspark/sql/dataframe.py | 3 +- .../spark/sql/execution/ExistingRDD.scala | 45 ++++++++++++++-- .../sql/execution/WholeStageCodegen.scala | 1 + .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 ++- .../spark/sql/sources/FilteredScanSuite.scala | 54 +++++++++++-------- .../spark/sql/sources/PrunedScanSuite.scala | 45 +++++++++------- .../spark/sql/sources/BucketedReadSuite.scala | 46 ++++++++-------- 7 files changed, 129 insertions(+), 71 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bf43452e08c3b..7275e69353485 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -173,7 +173,8 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan ExistingRDD[age#0,name#1] + WholeStageCodegen + : +- Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index cad7e25a32788..8649d2d69b62e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType @@ -102,7 +104,7 @@ private[sql] case class PhysicalRDD( override val metadata: Map[String, String] = Map.empty, isUnsafeRow: Boolean = false, override val outputPartitioning: Partitioning = UnknownPartitioning(0)) - extends LeafNode { + extends LeafNode with CodegenSupport { private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -128,6 +130,36 @@ private[sql] case class PhysicalRDD( val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" } + + override def upstreams(): Seq[RDD[InternalRow]] = { + rdd :: Nil + } + + // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen + // never requires UnsafeRow as input. + override protected def doProduce(ctx: CodegenContext): String = { + val input = ctx.freshName("input") + // PhysicalRDD always just has one input + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + val row = ctx.freshName("row") + val numOutputRows = metricTerm(ctx, "numOutputRows") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columns = exprs.map(_.gen(ctx)) + s""" + | while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} + | if (shouldStop()) { + | return; + | } + | } + """.stripMargin + } } private[sql] object PhysicalRDD { @@ -140,8 +172,13 @@ private[sql] object PhysicalRDD { rdd: RDD[InternalRow], relation: BaseRelation, metadata: Map[String, String] = Map.empty): PhysicalRDD = { - // All HadoopFsRelations output UnsafeRows - val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] + val outputUnsafeRows = if (relation.isInstanceOf[ParquetRelation]) { + // The vectorized parquet reader does not produce unsafe rows. + !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + } else { + // All HadoopFsRelations output UnsafeRows + relation.isInstanceOf[HadoopFsRelation] + } val bucketSpec = relation match { case r: HadoopFsRelation => r.getBucketSpec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 08c52e5f4399f..afaddcf35775c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -42,6 +42,7 @@ trait CodegenSupport extends SparkPlan { case _: TungstenAggregate => "agg" case _: BroadcastHashJoin => "bhj" case _: SortMergeJoin => "smj" + case _: PhysicalRDD => "rdd" case _ => nodeName.toLowerCase } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7a0f7abaa1baf..f8a9a95c873ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -187,8 +187,10 @@ class JDBCSuite extends SparkFunSuite val parentPlan = df.queryExecution.executedPlan // Check if SparkPlan Filter is removed in a physical plan and // the plan only has PhysicalRDD to scan JDBCRelation. - assert(parentPlan.isInstanceOf[PhysicalRDD]) - assert(parentPlan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] + assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD]) + assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) df } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 7196b6dc13394..4b578a6c10ec4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -304,30 +304,38 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic expectedCount: Int, requiredColumnNames: Set[String], expectedUnhandledFilters: Set[Filter]): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { - val queryExecution = sql(sqlString).queryExecution - val rawPlan = queryExecution.executedPlan.collect { - case p: execution.PhysicalRDD => p - } match { - case Seq(p) => p - case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") - } - val rawCount = rawPlan.execute().count() - assert(ColumnsRequired.set === requiredColumnNames) - - val table = caseInsensitiveContext.table("oneToTenFiltered") - val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _, _) => r - }.get - - assert( - relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) - - if (rawCount != expectedCount) { - fail( - s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) + // These tests check a particular plan, disable whole stage codegen. + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + try { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + assert(ColumnsRequired.set === requiredColumnNames) + + val table = caseInsensitiveContext.table("oneToTenFiltered") + val relation = table.queryExecution.logical.collectFirst { + case LogicalRelation(r, _, _) => r + }.get + + assert( + relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } finally { + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index a89c5f8007e78..02e8ab1c51746 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -117,28 +117,35 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = sql(sqlString).queryExecution - val rawPlan = queryExecution.executedPlan.collect { - case p: execution.PhysicalRDD => p - } match { - case Seq(p) => p - case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") - } - val rawColumns = rawPlan.output.map(_.name) - val rawOutput = rawPlan.execute().first() - - if (rawColumns != expectedColumns) { - fail( - s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) - } - if (rawOutput.numFields != expectedColumns.size) { - fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + // These tests check a particular plan, disable whole stage codegen. + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + try { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.numFields != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } finally { + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 0ef42f45e3b78..62f6b998e6f2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -74,32 +74,34 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet bucketValues: Seq[Integer], filterCondition: Column, originalDataFrame: DataFrame): Unit = { + // This test verifies parts of the plan. Disable whole stage codegen. + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k") + val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec + // Limit: bucket pruning only works when the bucket column has one and only one column + assert(bucketColumnNames.length == 1) + val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) + val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + } - val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k") - val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec - // Limit: bucket pruning only works when the bucket column has one and only one column - assert(bucketColumnNames.length == 1) - val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) - val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) - val matchedBuckets = new BitSet(numBuckets) - bucketValues.foreach { value => - matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) - } + // Filter could hide the bug in bucket pruning. Thus, skipping all the filters + val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[PhysicalRDD]) + assert(rdd.isDefined, plan) - // Filter could hide the bug in bucket pruning. Thus, skipping all the filters - val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan - .find(_.isInstanceOf[PhysicalRDD]) - assert(rdd.isDefined) + val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty) + } + // checking if all the pruned buckets are empty + assert(checkedResult.collect().forall(_ == true)) - val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => - if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty) + checkAnswer( + bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), + originalDataFrame.filter(filterCondition).orderBy("i", "j", "k")) } - // checking if all the pruned buckets are empty - assert(checkedResult.collect().forall(_ == true)) - - checkAnswer( - bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), - originalDataFrame.filter(filterCondition).orderBy("i", "j", "k")) } test("read partitioning bucketed tables with bucket pruning filters") { From 2b042577fb077865c3fce69c9d4eda22fde92673 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 24 Feb 2016 19:43:00 -0800 Subject: [PATCH 2013/2089] [SPARK-13092][SQL] Add ExpressionSet for constraint tracking This PR adds a new abstraction called an `ExpressionSet` which attempts to canonicalize expressions to remove cosmetic differences. Deterministic expressions that are in the set after canonicalization will always return the same answer given the same input (i.e. false positives should not be possible). However, it is possible that two canonical expressions that are not equal will in fact return the same answer given any input (i.e. false negatives are possible). ```scala val set = AttributeSet('a + 1 :: 1 + 'a :: Nil) set.iterator => Iterator('a + 1) set.contains('a + 1) => true set.contains(1 + 'a) => true set.contains('a + 2) => false ``` Other relevant changes include: - Since this concept overlaps with the existing `semanticEquals` and `semanticHash`, those functions are also ported to this new infrastructure. - A memoized `canonicalized` version of the expression is added as a `lazy val` to `Expression` and is used by both `semanticEquals` and `ExpressionSet`. - A set of unit tests for `ExpressionSet` are added - Tests which expect `semanticEquals` to be less intelligent than it now is are updated. As a followup, we should consider auditing the places where we do `O(n)` `semanticEquals` operations and replace them with `ExpressionSet`. We should also consider consolidating `AttributeSet` as a specialized factory for an `ExpressionSet.` Author: Michael Armbrust Closes #11338 from marmbrus/expressionSet. --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../catalyst/expressions/Canonicalize.scala | 81 +++++++++++++++++ .../sql/catalyst/expressions/Expression.scala | 56 ++++-------- .../catalyst/expressions/ExpressionSet.scala | 87 ++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 12 +-- .../expressions/ExpressionSetSuite.scala | 89 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 3 +- 7 files changed, 285 insertions(+), 45 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1a2ec7ed931ea..a12f7396fe819 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -235,7 +235,7 @@ package object dsl { implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) - def nullable: AttributeReference = a.withNullability(true) + def canBeNull: AttributeReference = a.withNullability(true) def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala new file mode 100644 index 0000000000000..b58a5273041e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.rules._ + +/** + * Rewrites an expression using rules that are guaranteed preserve the result while attempting + * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization + * will always return the same answer given the same input (i.e. false positives should not be + * possible). However, it is possible that two canonical expressions that are not equal will in fact + * return the same answer given any input (i.e. false negatives are possible). + * + * The following rules are applied: + * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. + * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered + * by `hashCode`. +* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. + * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + */ +object Canonicalize extends RuleExecutor[Expression] { + override protected def batches: Seq[Batch] = + Batch( + "Expression Canonicalization", FixedPoint(100), + IgnoreNamesTypes, + Reorder) :: Nil + + /** Remove names and nullability from types. */ + protected object IgnoreNamesTypes extends Rule[Expression] { + override def apply(e: Expression): Expression = e transformUp { + case a: AttributeReference => + AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + } + } + + /** Collects adjacent commutative operations. */ + protected def gatherCommutative( + e: Expression, + f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match { + case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) + case other => other :: Nil + } + + /** Orders a set of commutative operations by their hash code. */ + protected def orderCommutative( + e: Expression, + f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = + gatherCommutative(e, f).sortBy(_.hashCode()) + + /** Rearrange expressions that are commutative or associative. */ + protected object Reorder extends Rule[Expression] { + override def apply(e: Expression): Expression = e transformUp { + case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) + case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + + case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) + case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) + + case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) + case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + + case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 119496c7eee5e..692c16092fe3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -144,49 +144,32 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + /** + * Returns an expression where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, etc.) See [[Canonicalize]] for more details. + * + * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always + * evaluate to the same result. + */ + lazy val canonicalized: Expression = Canonicalize.execute(this) + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). + * + * See [[Canonicalize]] for more details. */ - def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { - def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 - case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) - case (i1, i2) => i1 == i2 - } - } - // Non-deterministic expressions cannot be semantic equal - if (!deterministic || !other.deterministic) return false - val elements1 = this.productIterator.toSeq - val elements2 = other.asInstanceOf[Product].productIterator.toSeq - checkSemantic(elements1, elements2) - } + def semanticEquals(other: Expression): Boolean = + deterministic && other.deterministic && canonicalized == other.canonicalized /** - * Returns the hash for this expression. Expressions that compute the same result, even if - * they differ cosmetically should return the same hash. + * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + * + * See [[Canonicalize]] for more details. */ - def semanticHash() : Int = { - def computeHash(e: Seq[Any]): Int = { - // See http://stackoverflow.com/questions/113511/hash-code-implementation - var hash: Int = 17 - e.foreach(i => { - val h: Int = i match { - case e: Expression => e.semanticHash() - case Some(e: Expression) => e.semanticHash() - case t: Traversable[_] => computeHash(t.toSeq) - case null => 0 - case other => other.hashCode() - } - hash = hash * 37 + h - }) - hash - } - - computeHash(this.productIterator.toSeq) - } + def semanticHash(): Int = canonicalized.hashCode() /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, @@ -369,7 +352,6 @@ abstract class UnaryExpression extends Expression { } } - /** * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala new file mode 100644 index 0000000000000..acea049adca3d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +object ExpressionSet { + /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */ + def apply(expressions: TraversableOnce[Expression]): ExpressionSet = { + val set = new ExpressionSet() + expressions.foreach(set.add) + set + } +} + +/** + * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] + * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * + * Internally this set uses the canonical representation, but keeps also track of the original + * expressions to ease debugging. Since different expressions can share the same canonical + * representation, this means that operations that extract expressions from this set are only + * guranteed to see at least one such expression. For example: + * + * {{{ + * val set = AttributeSet(a + 1, 1 + a) + * + * set.iterator => Iterator(a + 1) + * set.contains(a + 1) => true + * set.contains(1 + a) => true + * set.contains(a + 2) => false + * }}} + */ +class ExpressionSet protected( + protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, + protected val originals: mutable.Buffer[Expression] = new ArrayBuffer) + extends Set[Expression] { + + protected def add(e: Expression): Unit = { + if (!baseSet.contains(e.canonicalized)) { + baseSet.add(e.canonicalized) + originals.append(e) + } + } + + override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized) + + override def +(elem: Expression): ExpressionSet = { + val newSet = new ExpressionSet(baseSet.clone(), originals.clone()) + newSet.add(elem) + newSet + } + + override def -(elem: Expression): ExpressionSet = { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } + + override def iterator: Iterator[Expression] = originals.iterator + + /** + * Returns a string containing both the post [[Canonicalize]] expressions and the original + * expressions in this set. + */ + def toDebugString: String = + s""" + |baseSet: ${baseSet.mkString(", ")} + |originals: ${originals.mkString(", ")} + """.stripMargin +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5e7d144ae4107..a74b288cb22ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -63,17 +63,19 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } /** - * A sequence of expressions that describes the data property of the output rows of this - * operator. For example, if the output of this operator is column `a`, an example `constraints` - * can be `Set(a > 10, a < 20)`. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then * canonicalized and filtered automatically to contain only those attributes that appear in the - * [[outputSet]] + * [[outputSet]]. + * + * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala new file mode 100644 index 0000000000000..ce42e5784ccd2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionSetSuite extends SparkFunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + def setTest(size: Int, exprs: Expression*): Unit = { + test(s"expect $size: ${exprs.mkString(", ")}") { + val set = ExpressionSet(exprs) + if (set.size != size) { + fail(set.toDebugString) + } + } + } + + def setTestIgnore(size: Int, exprs: Expression*): Unit = + ignore(s"expect $size: ${exprs.mkString(", ")}") {} + + // Commutative + setTest(1, aUpper + 1, aLower + 1) + setTest(2, aUpper + 1, aLower + 2) + setTest(2, aUpper + 1, fakeA + 1) + setTest(2, aUpper + 1, bUpper + 1) + + setTest(1, aUpper + aLower, aLower + aUpper) + setTest(1, aUpper + bUpper, bUpper + aUpper) + setTest(1, + aUpper + bUpper + 3, + bUpper + 3 + aUpper, + bUpper + aUpper + 3, + Literal(3) + aUpper + bUpper) + setTest(1, + aUpper * bUpper * 3, + bUpper * 3 * aUpper, + bUpper * aUpper * 3, + Literal(3) * aUpper * bUpper) + setTest(1, aUpper === bUpper, bUpper === aUpper) + + setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper) + + + // Not commutative + setTest(2, aUpper - bUpper, bUpper - aUpper) + + // Reversable + setTest(1, aUpper > bUpper, bUpper < aUpper) + setTest(1, aUpper >= bUpper, bUpper <= aUpper) + + test("add to / remove from set") { + val initialSet = ExpressionSet(aUpper + 1 :: Nil) + + assert((initialSet + (aUpper + 1)).size == 1) + assert((initialSet + (aUpper + 2)).size == 2) + assert((initialSet - (aUpper + 1)).size == 0) + assert((initialSet - (aUpper + 2)).size == 1) + + assert((initialSet + (aLower + 1)).size == 1) + assert((initialSet - (aLower + 1)).size == 0) + + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1d9db27e097ee..13ff4a2c419e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1980,9 +1980,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - // Would be nice if semantic equals for `+` understood commutative verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") From 13ce10e95401b21fa40ca0bb27ebf9a0bfffe70f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 24 Feb 2016 23:15:36 -0800 Subject: [PATCH 2014/2089] [SPARK-13479][SQL][PYTHON] Added Python API for approxQuantile ## What changes were proposed in this pull request? * Scala DataFrameStatFunctions: Added version of approxQuantile taking a List instead of an Array, for Python compatbility * Python DataFrame and DataFrameStatFunctions: Added approxQuantile ## How was this patch tested? * unit test in sql/tests.py Documentation was copied from the existing approxQuantile exactly. Author: Joseph K. Bradley Closes #11356 from jkbradley/approx-quantile-python. --- python/pyspark/sql/dataframe.py | 54 +++++++++++++++++++ python/pyspark/sql/tests.py | 7 +++ .../spark/sql/DataFrameStatFunctions.scala | 10 ++++ 3 files changed, 71 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7275e69353485..76fbb0c9aa4c9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1178,6 +1178,55 @@ def replace(self, to_replace, value, subset=None): return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + @since(2.0) + def approxQuantile(self, col, probabilities, relativeError): + """ + Calculates the approximate quantiles of a numerical column of a + DataFrame. + + The result of this algorithm has the following deterministic bound: + If the DataFrame has N elements and if we request the quantile at + probability `p` up to error `err`, then the algorithm will return + a sample `x` from the DataFrame so that the *exact* rank of `x` is + close to (p * N). More precisely, + + floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + + This method implements a variation of the Greenwald-Khanna + algorithm (with some speed optimizations). The algorithm was first + present in [[http://dx.doi.org/10.1145/375663.375670 + Space-efficient Online Computation of Quantile Summaries]] + by Greenwald and Khanna. + + :param col: the name of the numerical column + :param probabilities: a list of quantile probabilities + Each number must belong to [0, 1]. + For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + :param relativeError: The relative target precision to achieve + (>= 0). If set to zero, the exact quantiles are computed, which + could be very expensive. Note that values greater than 1 are + accepted but give the same result as 1. + :return: the approximate quantiles at the given probabilities + """ + if not isinstance(col, str): + raise ValueError("col should be a string.") + + if not isinstance(probabilities, (list, tuple)): + raise ValueError("probabilities should be a list or tuple") + if isinstance(probabilities, tuple): + probabilities = list(probabilities) + for p in probabilities: + if not isinstance(p, (float, int, long)) or p < 0 or p > 1: + raise ValueError("probabilities should be numerical (float, int, long) in [0,1].") + probabilities = _to_list(self._sc, probabilities) + + if not isinstance(relativeError, (float, int, long)) or relativeError < 0: + raise ValueError("relativeError should be numerical (float, int, long) >= 0.") + relativeError = float(relativeError) + + jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) + return list(jaq) + @since(1.4) def corr(self, col1, col2, method=None): """ @@ -1396,6 +1445,11 @@ class DataFrameStatFunctions(object): def __init__(self, df): self.df = df + def approxQuantile(self, col, probabilities, relativeError): + return self.df.approxQuantile(col, probabilities, relativeError) + + approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__ + def corr(self, col1, col2, method=None): return self.df.corr(col1, col2, method) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc11c0f35cdc9..90fd7696910ed 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -669,6 +669,13 @@ def test_first_last_ignorenulls(self): functions.last(df2.id, True).alias('d')) self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_approxQuantile(self): + df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF() + aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) + self.assertTrue(all(isinstance(q, float) for q in aq)) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 39a31ab02898f..3eb1f0f0d58ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -70,6 +70,16 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray } + /** + * Python-friendly version of [[approxQuantile()]] + */ + private[spark] def approxQuantile( + col: String, + probabilities: List[Double], + relativeError: Double): java.util.List[Double] = { + approxQuantile(col, probabilities.toArray, relativeError).toList.asJava + } + /** * Calculate the sample covariance of two numerical columns of a DataFrame. * @param col1 the name of the first column From 4d2864b2d7cac027cf6bb78fc3cac9bc37534b07 Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Wed, 24 Feb 2016 23:22:14 -0800 Subject: [PATCH 2015/2089] [SPARK-7106][MLLIB][PYSPARK] Support model save/load in Python's FPGrowth ## What changes were proposed in this pull request? Python API supports mode save/load in FPGrowth JIRA: [https://issues.apache.org/jira/browse/SPARK-7106](https://issues.apache.org/jira/browse/SPARK-7106) ## How was the this patch tested? The patch is tested with Python doctest. Author: Kai Jiang Closes #11321 from vectorijk/spark-7106. --- python/pyspark/mllib/fpm.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 7a2d77a4dad13..5c9706cb8cb29 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -21,14 +21,15 @@ from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc __all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel'] @inherit_doc @ignore_unicode_prefix -class FPGrowthModel(JavaModelWrapper): +class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + >>> model_path = temp_path + "/fpm" + >>> model.save(sc, model_path) + >>> sameModel = FPGrowthModel.load(sc, model_path) + >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect()) + True .. versionadded:: 1.4.0 """ @@ -51,6 +57,16 @@ def freqItemsets(self): """ return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) + @classmethod + @since("2.0.0") + def load(cls, sc, path): + """ + Load a model from the given path. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.FPGrowthModelWrapper(model) + return FPGrowthModel(wrapper) + class FPGrowth(object): """ @@ -170,8 +186,19 @@ def _test(): import pyspark.mllib.fpm globs = pyspark.mllib.fpm.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + import tempfile + + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) From 264533b553be806b6c45457201952e83c028ec78 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Wed, 24 Feb 2016 23:52:17 -0800 Subject: [PATCH 2016/2089] [SPARK-13482][MINOR][CONFIGURATION] Make consistency of the configuraiton named in TransportConf. `spark.storage.memoryMapThreshold` has two kind of the value, one is 2*1024*1024 as integer and the other one is '2m' as string. "2m" is recommanded in document but it will go wrong if the code goes into `TransportConf#memoryMapBytes`. [Jira](https://issues.apache.org/jira/browse/SPARK-13482) Author: huangzhaowei Closes #11360 from SaintBacchus/SPARK-13482. --- .../main/java/org/apache/spark/network/util/TransportConf.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 115135d44adbd..9f030da2b3cec 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -132,7 +132,8 @@ public int ioRetryWaitTimeMs() { * memory mapping has high overhead for blocks close to or below the page size of the OS. */ public int memoryMapBytes() { - return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.storage.memoryMapThreshold", "2m"))); } /** From 07f92ef1fa090821bef9c60689bf41909d781ee7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 25 Feb 2016 00:13:07 -0800 Subject: [PATCH 2017/2089] [SPARK-13376] [SPARK-13476] [SQL] improve column pruning ## What changes were proposed in this pull request? This PR mostly rewrite the ColumnPruning rule to support most of the SQL logical plans (except those for Dataset). This PR also fix a bug in Generate, it should always output UnsafeRow, added an regression test for that. ## How was this patch tested? This is test by unit tests, also manually test with TPCDS Q78, which could prune all unused columns successfully, improved the performance by 78% (from 22s to 12s). Author: Davies Liu Closes #11354 from davies/fix_column_pruning. --- .../sql/catalyst/optimizer/Optimizer.scala | 128 ++++++++---------- .../optimizer/ColumnPruningSuite.scala | 128 +++++++++++++++++- .../optimizer/FilterPushdownSuite.scala | 80 ----------- .../optimizer/JoinOptimizationSuite.scala | 2 +- .../apache/spark/sql/execution/Generate.scala | 28 ++-- .../columnar/InMemoryColumnarTableScan.scala | 7 +- .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++ 7 files changed, 215 insertions(+), 166 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1f05f2065c949..2b804976f3f11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(projects, output, child)) - if (e.outputSet -- a.references).nonEmpty => - val newOutput = output.filter(a.references.contains(_)) - val newProjects = projects.map { proj => - proj.zip(output).filter { case (e, a) => + // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) + case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => + p.copy( + child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + projectList = w.projectList.filter(p.references.contains), + windowExpressions = w.windowExpressions.filter(p.references.contains))) + case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + val newOutput = e.output.filter(a.references.contains(_)) + val newProjects = e.projections.map { proj => + proj.zip(e.output).filter { case (e, a) => newOutput.contains(a) }.unzip._1 } - a.copy(child = Expand(newProjects, newOutput, child)) + a.copy(child = Expand(newProjects, newOutput, grandChild)) + // TODO: support some logical plan for Dataset - case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- e.references -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) - - // Eliminate attributes that are not needed to calculate the specified aggregates. + // Prunes the unused columns from child of Aggregate/Window/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = Project(a.references.toSeq, child)) - - // Eliminate attributes that are not needed to calculate the Generate. + a.copy(child = prunedChild(child, a.references)) + case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => + w.copy(child = prunedChild(child, w.references)) + case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = Project(g.references.toSeq, g.child)) + g.copy(child = prunedChild(g.child, g.references)) + // Turn off `join` for Generate if no column from it's child is used case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - case p @ Project(projectList, g: Generate) if g.join => - val neededChildOutput = p.references -- g.generatorOutput ++ g.references - if (neededChildOutput == g.child.outputSet) { - p + // Eliminate unneeded attributes from right side of a LeftSemiJoin. + case j @ Join(left, right, LeftSemi, condition) => + j.copy(right = prunedChild(right, j.references)) + + // all the columns will be used to compare, so we can't prune them + case p @ Project(_, _: SetOperation) => p + case p @ Project(_, _: Distinct) => p + // Eliminate unneeded attributes from children of Union. + case p @ Project(_, u: Union) => + if ((u.outputSet -- p.references).nonEmpty) { + val firstChild = u.children.head + val newOutput = prunedChild(firstChild, p.references).output + // pruning the columns of all children based on the pruned first child. + val newChildren = u.children.map { p => + val selected = p.output.zipWithIndex.filter { case (a, i) => + newOutput.contains(firstChild.output(i)) + }.map(_._1) + Project(selected, p) + } + p.copy(child = u.withNewChildren(newChildren)) } else { - Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + p } - case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) - if (a.outputSet -- p.references).nonEmpty => - Project( - projectList, - Aggregate( - groupingExpressions, - aggregateExpressions.filter(e => p.references.contains(e)), - child)) - - // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => - // Collect the list of all references required either above or to evaluate the condition. - val allReferences: AttributeSet = - AttributeSet( - projectList.flatMap(_.references.iterator)) ++ - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) + // Can't prune the columns on LeafNode + case p @ Project(_, l: LeafNode) => p - Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) - - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case Join(left, right, LeftSemi, condition) => - // Collect the list of all references required to evaluate the condition. - val allReferences: AttributeSet = - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - Join(left, prunedChild(right, allReferences), LeftSemi, condition) - - // Push down project through limit, so that we may have chance to push it further. - case Project(projectList, Limit(exp, child)) => - Limit(exp, Project(projectList, child)) - - // Push down project if possible when the child is sort. - case p @ Project(projectList, s @ Sort(_, _, grandChild)) => - if (s.references.subsetOf(p.outputSet)) { - s.copy(child = Project(projectList, grandChild)) + // Eliminate no-op Projects + case p @ Project(projectList, child) if child.output == p.output => child + + // for all other logical plans that inherits the output from it's children + case p @ Project(_, child) => + val required = child.references ++ p.references + if ((child.inputSet -- required).nonEmpty) { + val newChildren = child.children.map(c => prunedChild(c, required)) + p.copy(child = child.withNewChildren(newChildren)) } else { - val neededReferences = s.references ++ p.references - if (neededReferences == grandChild.outputSet) { - // No column we can prune, return the original plan. - p - } else { - // Do not use neededReferences.toSeq directly, should respect grandChild's output order. - val newProjectList = grandChild.output.filter(neededReferences.contains) - p.copy(child = s.copy(child = Project(newProjectList, grandChild))) - } + p } - - // Eliminate no-op Projects - case Project(projectList, child) if child.output == projectList => child } /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + Project(c.output.filter(allReferences.contains), c) } else { c } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index c890fffc40167..715d01a3cd876 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Explode, Literal} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest { Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'aa.int, 'gid.int), - Project(Seq('c, 'a), + Project(Seq('a, 'c), input))).analyze comparePlans(optimized, expected) } + test("Column pruning on Filter") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze + val expected = + Project('a :: Nil, + Filter('c > Literal(0.0), + Project(Seq('a, 'c), input))).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("Column pruning on except/intersect/distinct") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Except(input, input)).analyze + comparePlans(Optimize.execute(query), query) + + val query2 = Project('a :: Nil, Intersect(input, input)).analyze + comparePlans(Optimize.execute(query2), query2) + val query3 = Project('a :: Nil, Distinct(input)).analyze + comparePlans(Optimize.execute(query3), query3) + } + + test("Column pruning on Project") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze + val expected = Project(Seq('a), input).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("column pruning for group") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val originalQuery = + testRelation + .groupBy('a)('a, count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .groupBy('a)('a as 'c, count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for Project(ne, Limit)") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("push down project past sort") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + } + + test("Column pruning on Union") { + val input1 = LocalRelation('a.int, 'b.string, 'c.double) + val input2 = LocalRelation('c.int, 'd.string, 'e.double) + val query = Project('b :: Nil, + Union(input1 :: input2 :: Nil)).analyze + val expected = Project('b :: Nil, + Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + comparePlans(Optimize.execute(query), expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 70b34cbb246ab..7d60862f5ae70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughJoin, PushPredicateThroughGenerate, PushPredicateThroughAggregate, - ColumnPruning, CollapseProject) :: Nil } @@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("column pruning for group") { - val originalQuery = - testRelation - .groupBy('a)('a, count('b)) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for group with alias") { - val originalQuery = - testRelation - .groupBy('a)('a as 'c, count('b)) - .select('c) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a as 'c).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for Project(ne, Limit)") { - val originalQuery = - testRelation - .select('a, 'b) - .limit(2) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .limit(2).analyze - - comparePlans(optimized, correctAnswer) - } - // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push down project past sort") { - val x = testRelation.subquery('x) - - // push down valid - val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze - - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) - - // push down invalid - val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) - } - - val optimized1 = Optimize.execute(originalQuery1.analyze) - val correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze - - comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) - } - test("push project and filter down into sample") { val x = testRelation.subquery('x) val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 1ab53a12572a1..2f382bbda0c58 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -108,7 +108,7 @@ class JoinOptimizationSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - Project(Seq($"y.key"), BroadcastHint(SubqueryAlias("y", input))), + BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Inner, None)).analyze comparePlans(optimized, expected) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 4db88a09d8152..6bc4649d432ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -54,17 +55,19 @@ case class Generate( child: SparkPlan) extends UnaryNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def expressions: Seq[Expression] = generator :: Nil val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - if (join) { + val rows = if (join) { child.execute().mapPartitionsInternal { iter => - val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) + val generatorNullRow = new GenericInternalRow(generator.elementTypes.size) val joinedRow = new JoinedRow - val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -73,19 +76,26 @@ case class Generate( if (outer && outputRows.isEmpty) { joinedRow.withRight(generatorNullRow) :: Nil } else { - outputRows.map(or => joinedRow.withRight(or)) + outputRows.map(joinedRow.withRight) } - } ++ LazyIterator(() => boundGenerator.terminate()).map { row => + } ++ LazyIterator(boundGenerator.terminate).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - proj(joinedRow.withRight(row)) + joinedRow.withRight(row) } } } else { child.execute().mapPartitionsInternal { iter => - val proj = UnsafeProjection.create(output, output) - (iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate())).map(proj) + iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) + } + } + + val numOutputRows = longMetric("numOutputRows") + rows.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(output, output) + iter.map { r => + numOutputRows += 1 + proj(r) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 4858140229d45..22d427808593e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation( @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, @transient private[sql] var _statistics: Statistics = null, private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends LogicalPlan with MultiInstanceRelation { + extends logical.LeafNode with MultiInstanceRelation { override def producedAttributes: AttributeSet = outputSet @@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated, batchStats) } - override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4930c485da83f..b8d1b5a6ae136 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -194,6 +194,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("a", Seq("a"), 1) :: Nil) } + test("sort after generate with join=true") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select($"*", explode($"a").as("c")).sortWithinPartitions("b", "c"), + Row(Seq("a"), 1, "a") :: Nil) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), From 2b2c8c33236677c916541f956f7b94bba014a9ce Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 25 Feb 2016 17:49:50 +0800 Subject: [PATCH 2018/2089] [SPARK-13486][SQL] Move SQLConf into an internal package ## What changes were proposed in this pull request? This patch moves SQLConf into org.apache.spark.sql.internal package to make it very explicit that it is internal. Soon I will also submit more API work that creates implementations of interfaces in this internal package. ## How was this patch tested? If it compiles, then the refactoring should work. Author: Reynold Xin Closes #11363 from rxin/SPARK-13486. --- project/MimaExcludes.scala | 6 ++ .../org/apache/spark/sql/DataFrame.scala | 1 + .../org/apache/spark/sql/GroupedData.scala | 1 + .../org/apache/spark/sql/SQLContext.scala | 3 +- .../spark/sql/execution/ExistingRDD.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../apache/spark/sql/execution/commands.scala | 3 +- .../InsertIntoHadoopFsRelation.scala | 1 + .../datasources/SqlNewHadoopRDD.scala | 3 +- .../datasources/WriterContainer.scala | 1 + .../parquet/CatalystSchemaConverter.scala | 4 +- .../parquet/CatalystWriteSupport.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 1 + .../spark/sql/execution/debug/package.scala | 1 + .../execution/local/BinaryHashJoinNode.scala | 2 +- .../local/BroadcastHashJoinNode.scala | 2 +- .../execution/local/ConvertToSafeNode.scala | 2 +- .../execution/local/ConvertToUnsafeNode.scala | 2 +- .../sql/execution/local/ExpandNode.scala | 2 +- .../sql/execution/local/FilterNode.scala | 2 +- .../sql/execution/local/IntersectNode.scala | 4 +- .../spark/sql/execution/local/LimitNode.scala | 2 +- .../spark/sql/execution/local/LocalNode.scala | 3 +- .../execution/local/NestedLoopJoinNode.scala | 2 +- .../sql/execution/local/ProjectNode.scala | 2 +- .../sql/execution/local/SampleNode.scala | 2 +- .../sql/execution/local/SeqScanNode.scala | 2 +- .../local/TakeOrderedAndProjectNode.scala | 2 +- .../spark/sql/execution/local/UnionNode.scala | 2 +- .../spark/sql/{ => internal}/SQLConf.scala | 86 +++++++++---------- .../spark/sql/internal/package-info.java | 22 +++++ .../apache/spark/sql/internal/package.scala | 24 ++++++ .../spark/sql/DataFrameAggregateSuite.scala | 1 + .../spark/sql/DataFramePivotSuite.scala | 1 + .../org/apache/spark/sql/DataFrameSuite.scala | 1 + .../org/apache/spark/sql/JoinSuite.scala | 1 + .../spark/sql/MultiSQLContextsSuite.scala | 1 + .../org/apache/spark/sql/QueryTest.scala | 1 + .../apache/spark/sql/SQLContextSuite.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 1 + .../execution/ExchangeCoordinatorSuite.scala | 1 + .../spark/sql/execution/PlannerSuite.scala | 3 +- .../columnar/PartitionBatchPruningSuite.scala | 1 + .../datasources/json/JsonSuite.scala | 1 + .../parquet/ParquetFilterSuite.scala | 1 + .../datasources/parquet/ParquetIOSuite.scala | 1 + .../ParquetPartitionDiscoverySuite.scala | 1 + .../parquet/ParquetQuerySuite.scala | 1 + .../parquet/ParquetReadBenchmark.scala | 3 +- .../datasources/parquet/ParquetTest.scala | 3 +- .../sql/execution/joins/InnerJoinSuite.scala | 3 +- .../sql/execution/joins/OuterJoinSuite.scala | 3 +- .../sql/execution/joins/SemiJoinSuite.scala | 3 +- .../spark/sql/execution/local/DummyNode.scala | 2 +- .../execution/local/HashJoinNodeSuite.scala | 2 +- .../sql/execution/local/LocalNodeTest.scala | 2 +- .../local/NestedLoopJoinNodeSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 1 + .../{ => internal}/SQLConfEntrySuite.scala | 4 +- .../sql/{ => internal}/SQLConfSuite.scala | 3 +- .../spark/sql/sources/DataSourceTest.scala | 1 + .../spark/sql/sources/FilteredScanSuite.scala | 1 + .../spark/sql/sources/PrunedScanSuite.scala | 1 + .../spark/sql/sources/SaveLoadSuite.scala | 3 +- .../spark/sql/test/TestSQLContext.scala | 6 +- .../hive/thriftserver/HiveThriftServer2.scala | 2 +- .../SparkExecuteStatementOperation.scala | 3 +- .../HiveThriftServer2Suites.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 5 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 3 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 1 + .../hive/ParquetHiveCompatibilitySuite.scala | 3 +- .../spark/sql/hive/QueryPartitionSuite.scala | 1 + .../spark/sql/hive/StatisticsSuite.scala | 3 +- .../execution/AggregationQuerySuite.scala | 1 + .../sql/hive/execution/SQLQuerySuite.scala | 1 + .../hive/orc/OrcHadoopFsRelationSuite.scala | 3 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 1 + .../apache/spark/sql/hive/parquetSuites.scala | 2 + .../spark/sql/sources/BucketedReadSuite.scala | 1 + .../sql/sources/BucketedWriteSuite.scala | 3 +- .../ParquetHadoopFsRelationSuite.scala | 1 + .../sql/sources/hadoopFsRelationSuites.scala | 1 + 86 files changed, 206 insertions(+), 97 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => internal}/SQLConf.scala (90%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala rename sql/core/src/test/scala/org/apache/spark/sql/{ => internal}/SQLConfEntrySuite.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => internal}/SQLConfSuite.scala (98%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b12aefcf836ab..14e3c90f1b0d2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -275,6 +275,12 @@ object MimaExcludes { ) ++ Seq ( // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ Seq( + // [SPARK-13486][SQL] Move SQLConf into an internal package + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") ) case v if v.startsWith("1.6") => Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f590ac0114ec4..abb8fe552b7cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 66ec0e7338fc4..f06d16116ec8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a2f386850c1b5..1c24d9e4aeb0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,6 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.{execution => sparkexecution} -import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.{InternalRow, _} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor @@ -43,6 +42,8 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.SQLConfEntry import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 8649d2d69b62e..2cbe3f2c94202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 247eb054a8beb..5fdf38c733efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeComman import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.internal.SQLConf private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => @@ -98,7 +99,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Join implementations are chosen with the following precedence: * * - Broadcast: if one side of the join has an estimated physical size that is smaller than the - * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold * or if that side has an explicit broadcast hint (e.g. the user applied the * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side * of the join will be broadcasted and the other side will be streamed, with no shuffling diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index c6adb583f931b..5574645741823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,12 +21,13 @@ import java.util.NoSuchElementException import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 2d3e1714d2b7b..c8b020d55a3cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 25911334a674f..f4271d165c9bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -32,7 +32,8 @@ import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7e5c8f2f48d6b..c3db2a0af4bd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 54dda0c3915b4..ab4250d0adbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -25,8 +25,9 @@ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ -import org.apache.spark.sql.{AnalysisException, SQLConf} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -37,7 +38,6 @@ import org.apache.spark.sql.types._ * [[MessageType]] schemas. * * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * * @constructor * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index e78afa5ae6d0b..3508220c9541f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -30,11 +30,11 @@ import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.io.api.{Binary, RecordConsumer} import org.apache.spark.Logging -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 1e686d41f41db..184cbb2f296b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dbb6b654b1a38..95d033bc57548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.internal.SQLConf /** * Contains methods for debugging query execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index 59345046da495..8f063e24fbf8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} +import org.apache.spark.sql.internal.SQLConf /** * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala index cd1c86516ec5f..9ffa272d2189a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.local import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} +import org.apache.spark.sql.internal.SQLConf /** * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala index b31c5a863832e..f79d795a904d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} +import org.apache.spark.sql.internal.SQLConf case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala index de2f4e661ab44..f3fa474b0f7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} +import org.apache.spark.sql.internal.SQLConf case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 85111bd6d1c98..6ccd6db0e6ca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf case class ExpandNode( conf: SQLConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index dd1113b6726cf..c5eb33cef4420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.internal.SQLConf case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala index 740d485f8d9e6..e594e132dea79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.local import scala.collection.mutable -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) extends BinaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala index 401b10a5ed307..9af45ac0aac9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 8726e4878106d..a5d09691dc46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.local import org.apache.spark.Logging -import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index b93bde58a55e5..b5ea08325c58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class NestedLoopJoinNode( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index bd73b08263f87..5fe068a13c8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, UnsafeProjection} +import org.apache.spark.sql.internal.SQLConf case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index 793700803f216..078fb50deb16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index b8467f6ae58e0..8ebfe3a68b3a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf /** * An operator that scans some local data collection in the form of Scala Seq. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index ca68b7677ce83..f52f5f7bb59b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.BoundedPriorityQueue case class TakeOrderedAndProjectNode( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala index 0f2b8303e7372..e53bc220d8d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a601c87fc930e..c1e3f386b2562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import java.util.Properties -import scala.collection.immutable import scala.collection.JavaConverters._ +import scala.collection.immutable import org.apache.parquet.hadoop.ParquetOutputCommitter @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// -private[spark] object SQLConf { +object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( new java.util.HashMap[String, SQLConfEntry[_]]()) @@ -55,7 +55,7 @@ private[spark] object SQLConf { * configuration is only used internally and we should not expose it to the user. * @tparam T the value type */ - private[sql] class SQLConfEntry[T] private( + class SQLConfEntry[T] private( val key: String, val defaultValue: Option[T], val valueConverter: String => T, @@ -70,7 +70,7 @@ private[spark] object SQLConf { } } - private[sql] object SQLConfEntry { + object SQLConfEntry { private def apply[T]( key: String, @@ -528,7 +528,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { +class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -537,85 +537,85 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon /** ************************ Spark SQL Params/Hints ******************* */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) + def useCompression: Boolean = getConf(COMPRESS_CACHED) - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) - private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) + def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) - private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) - private[spark] def targetPostShuffleInputSize: Long = + def targetPostShuffleInputSize: Long = getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) - private[spark] def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) - private[spark] def minNumPostShufflePartitions: Int = + def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) - private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) - private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) + def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) - private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) + def nativeView: Boolean = getConf(NATIVE_VIEW) - private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) - private[spark] def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) + def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def subexpressionEliminationEnabled: Boolean = + def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) - private[spark] def defaultSizeInBytes: Long = + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) - private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) + def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) - private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - private[spark] def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) - private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) + def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) - private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) + def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) + def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) - private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) + def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) - private[spark] def partitionDiscoveryEnabled(): Boolean = + def partitionDiscoveryEnabled(): Boolean = getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) - private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = + def partitionColumnTypeInferenceEnabled(): Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) - private[spark] def parallelPartitionDiscoveryThreshold: Int = + def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) - private[spark] def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) + def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) + def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) - private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) - private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) + def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) - private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID) @@ -715,15 +715,15 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon settings.put(key, value) } - private[spark] def unsetConf(key: String): Unit = { + def unsetConf(key: String): Unit = { settings.remove(key) } - private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { + def unsetConf(entry: SQLConfEntry[_]): Unit = { settings.remove(entry.key) } - private[spark] def clear(): Unit = { + def clear(): Unit = { settings.clear() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java b/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java new file mode 100644 index 0000000000000..1e801cb6ee2a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * All classes in this package are considered an internal API to Spark and + * are subject to change between minor releases. + */ +package org.apache.spark.sql.internal; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala new file mode 100644 index 0000000000000..c2394f42e552d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * All classes in this package are considered an internal API to Spark and + * are subject to change between minor releases. + */ +package object internal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 78bf6c1bcebf2..f54bff9f18bc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index bc1a336ea4fd0..368aa5cd141f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class DataFramePivotSuite extends QueryTest with SharedSQLContext{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b8d1b5a6ae136..84f30c0aaf862 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a1211e43808fe..41e27ec46648f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 6a375a33bfcf6..0b5a92c256e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.apache.spark._ +import org.apache.spark.sql.internal.SQLConf class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 5401212428d6f..c05aa5486ab15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{LogicalRDD, Queryable} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SQLConf abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 14b9448d260f4..ec19d97d8cec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 13ff4a2c419e9..16e769feca487 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b1c588a63d030..4f01e46633c8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunS import org.apache.spark.sql._ import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.TestSQLContext class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4de56783fabce..f66e08e6ca5c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row, SQLConf} +import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMem import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 86c2c25c2c7e1..d19fec6140acb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index dd83a0e36f6f7..c7f33e17465b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 3ded32c450541..fbffe867e4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 41a9404b00795..3c74464d57012 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 3d1677bed4770..8bc5c89959803 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b123d2b31efcf..acfc1a518a0a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index e8893073e305a..dafc589999c7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -22,7 +22,8 @@ import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{Benchmark, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 5cbcccbd862dc..e8c524e9e550d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -30,7 +30,8 @@ import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.schema.MessageType -import org.apache.spark.sql.{DataFrame, SaveMode, SQLConf} +import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 6dfff3770b882..b748229e402f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index cd6b6fcbb18af..22fe8caff265e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index f3ad8409e5a3d..5c982885d652f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala index efc3227dd60d8..cd9277d3bcf1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf /** * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index eb70747926fe5..268f2aac87195 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.local import org.mockito.Mockito.{mock, when} import org.apache.spark.broadcast.TorrentBroadcast -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedMutableProjection, UnsafeProjection} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} +import org.apache.spark.sql.internal.SQLConf class HashJoinNodeSuite extends LocalNodeTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 1a485f967dd38..cd67a66ebf576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.local import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} class LocalNodeTest extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index 45df2ea6552d8..bcc87a9175517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.internal.SQLConf class NestedLoopJoinNodeSuite extends LocalNodeTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 46bb699b780a9..c49f2439fce49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{JsonProtocol, Utils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 2e33777f14adc..2b89fa9f23815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -15,10 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf._ +import org.apache.spark.sql.internal.SQLConf._ class SQLConfEntrySuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index cf0701eca29ea..e944d328a3ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -15,8 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal +import org.apache.spark.sql.{QueryTest, SQLContext} import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index af04079ec895a..92061133cd49b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf private[sql] abstract class DataSourceTest extends QueryTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 4b578a6c10ec4..2ff79a2316bdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 02e8ab1c51746..db722975379a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index e055da9e8a39a..588f6e268f31c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,7 +21,8 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLConf} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index c89a1516503e0..28ad7ae64a517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} - +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.SQLConf /** * A special [[SQLContext]] prepared for testing. @@ -39,7 +39,7 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel super.clear() // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.map { + TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 66eaa3ebcd737..f32ba5fe68a63 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -32,10 +32,10 @@ import org.apache.hive.service.server.{HiveServer2, HiveServerServerOptionsProce import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 8fef22cf777f6..458d4f2c3c959 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -33,9 +33,10 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.sql.{DataFrame, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.{Utils => SparkUtils} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 5f9952a90a57d..c05527b519daa 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -202,7 +202,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } test("test multiple session") { - import org.apache.spark.sql.SQLConf + import org.apache.spark.sql.internal.SQLConf var defaultV1: String = null var defaultV2: String = null var data: ArrayBuffer[Int] = null diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 973ba6bf38b3e..d15f883138938 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,8 +23,8 @@ import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.internal.SQLConf /** * Runs the test cases that are included in the hive distribution. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 88fda16ca0ef8..d511dd685ce75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -40,8 +40,6 @@ import org.apache.hadoop.util.VersionInfo import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.SQLConf.SQLConfEntry -import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.{InternalRow, ParserInterface} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} @@ -52,6 +50,9 @@ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsert import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.SQLConfEntry +import org.apache.spark.sql.internal.SQLConf.SQLConfEntry._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index d316664241f1a..145b5f7cc2dc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -27,13 +27,13 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.SerializableJobConf private[hive] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 246108e0d0e11..9d0622bea92c7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index f8764d4725f63..90d65d9e9b8c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{QueryTest, Row, SaveMode, SQLConf} +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index cd85e23d8b960..cb23959c2dd57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 4a73153a80ee8..a9823ae26278d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -21,9 +21,10 @@ import java.sql.Timestamp import org.apache.hadoop.hive.conf.HiveConf -import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index f49ee690ac041..78569c58085cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -21,6 +21,7 @@ import com.google.common.io.Files import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 91bedf9c5af5a..3952e716d38a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf class StatisticsSuite extends QueryTest with TestHiveSingleton { import hiveContext.sql diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index caf1db9ad0855..45634a4475a6f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAg import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6624494e08ea8..ff1719eaf6efc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index e8a61123d18b1..528f40b002be0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.Row +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 2156806d21f96..b11d1d9de092e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.internal.SQLConf case class AllDataTypesWithNonPrimitiveType( stringField: String, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index a6ca7d0386b22..68d5c7da1f974 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertI import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -755,6 +756,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with /** * Drop named tables if they exist + * * @param tableNames tables to drop */ def dropTables(tableNames: String*): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 62f6b998e6f2a..9a52276fcdc6a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index a32f8fb4c5a1d..c37b21bed3ab0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.sources import java.io.File import java.net.URI -import org.apache.spark.sql.{AnalysisException, QueryTest, SQLConf} +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index ba2a483bba534..5ce58e898ebec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 1a4b3ece72a66..2a921a061f358 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ From 2e44031fafdb8cf486573b98e4faa6b31ffb90a4 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Thu, 25 Feb 2016 12:18:43 +0000 Subject: [PATCH 2019/2089] [SPARK-13117][WEB UI] WebUI should use the local ip not 0.0.0.0 Fixed the HTTP Server Host Name/IP issue i.e. HTTP Server to take the configured host name/IP and not '0.0.0.0' always. Author: Devaraj K Closes #11133 from devaraj-kavali/SPARK-13117. --- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index fe4949b9f6fee..e515916f31420 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -134,7 +134,7 @@ private[spark] abstract class WebUI( def bind() { assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) try { - serverInfo = Some(startJettyServer("0.0.0.0", port, sslOptions, handlers, conf, name)) + serverInfo = Some(startJettyServer(publicHostName, port, sslOptions, handlers, conf, name)) logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) } catch { case e: Exception => diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index f416ace5c2b71..972b552f7a5b0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.SparkConfWithEnv +import org.apache.spark.util.Utils class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { @@ -53,7 +54,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { } test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { - val SPARK_PUBLIC_DNS = "public_dns" + val SPARK_PUBLIC_DNS = Utils.localHostNameForURI() val conf = new SparkConfWithEnv(Map("SPARK_PUBLIC_DNS" -> SPARK_PUBLIC_DNS)).set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) From 3fa6491be66dad690ca5329dd32e7c82037ae8c1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 25 Feb 2016 20:43:03 +0800 Subject: [PATCH 2020/2089] [SPARK-13473][SQL] Don't push predicate through project with nondeterministic field(s) ## What changes were proposed in this pull request? Predicates shouldn't be pushed through project with nondeterministic field(s). See https://github.com/graphframes/graphframes/pull/23 and SPARK-13473 for more details. This PR targets master, branch-1.6, and branch-1.5. ## How was this patch tested? A test case is added in `FilterPushdownSuite`. It constructs a query plan where a filter is over a project with a nondeterministic field. Optimized query plan shouldn't change in this case. Author: Cheng Lian Closes #11348 from liancheng/spark-13473-no-ppd-through-nondeterministic-project-field. --- .../sql/catalyst/optimizer/Optimizer.scala | 9 ++++++- .../optimizer/FilterPushdownSuite.scala | 27 +++---------------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2b804976f3f11..2aeb9575f1dd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -792,7 +792,14 @@ object SimplifyFilters extends Rule[LogicalPlan] { */ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, project @ Project(fields, grandChild)) => + // SPARK-13473: We can't push the predicate down when the underlying projection output non- + // deterministic field(s). Non-deterministic expressions are essentially stateful. This + // implies that, for a given input row, the output are determined by the expression's initial + // state and all the input rows processed before. In another word, the order of input rows + // matters for non-deterministic expressions, while pushing down predicates changes the order. + case filter @ Filter(condition, project @ Project(fields, grandChild)) + if fields.forall(_.deterministic) => + // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). val aliasMap = AttributeMap(fields.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 7d60862f5ae70..1292aa0003dd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -98,7 +98,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter through project") { + test("nondeterministic: can't push down filter with nondeterministic condition through project") { val originalQuery = testRelation .select(Rand(10).as('rand), 'a) .where('rand > 5 || 'a > 5) @@ -109,36 +109,15 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("nondeterministic: push down part of filter through project") { + test("nondeterministic: can't push down filter through project with nondeterministic field") { val originalQuery = testRelation .select(Rand(10).as('rand), 'a) - .where('rand > 5 && 'a > 5) - .analyze - - val optimized = Optimize.execute(originalQuery) - - val correctAnswer = testRelation .where('a > 5) - .select(Rand(10).as('rand), 'a) - .where('rand > 5) - .analyze - - comparePlans(optimized, correctAnswer) - } - - test("nondeterministic: push down filter through project") { - val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('a > 5 && 'a < 10) .analyze val optimized = Optimize.execute(originalQuery) - val correctAnswer = testRelation - .where('a > 5 && 'a < 10) - .select(Rand(10).as('rand), 'a) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, originalQuery) } test("filters: combines filters") { From 6f8e835c68dff6fcf97326dc617132a41ff9d043 Mon Sep 17 00:00:00 2001 From: Oliver Pierson Date: Thu, 25 Feb 2016 13:24:46 +0000 Subject: [PATCH 2021/2089] [SPARK-13444][MLLIB] QuantileDiscretizer chooses bad splits on large DataFrames ## What changes were proposed in this pull request? Change line 113 of QuantileDiscretizer.scala to `val requiredSamples = math.max(numBins * numBins, 10000.0)` so that `requiredSamples` is a `Double`. This will fix the division in line 114 which currently results in zero if `requiredSamples < dataset.count` ## How was the this patch tested? Manual tests. I was having a problems using QuantileDiscretizer with my a dataset and after making this change QuantileDiscretizer behaves as expected. Author: Oliver Pierson Author: Oliver Pierson Closes #11319 from oliverpierson/SPARK-13444. --- .../ml/feature/QuantileDiscretizer.scala | 11 ++++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1f4cca123310a..769f4406e2dff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -103,6 +103,13 @@ final class QuantileDiscretizer(override val uid: String) @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + + /** + * Minimum number of samples required for finding splits, regardless of number of bins. If + * the dataset has fewer rows than this value, the entire dataset will be used. + */ + private[spark] val minSamplesRequired: Int = 10000 + /** * Sampling from the given dataset to collect quantile statistics. */ @@ -110,8 +117,8 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") - val requiredSamples = math.max(numBins * numBins, 10000) - val fraction = math.min(requiredSamples / dataset.count(), 1.0) + val requiredSamples = math.max(numBins * numBins, minSamplesRequired) + val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0) dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6a2c601bbed18..25fabf64d5594 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -71,6 +71,26 @@ class QuantileDiscretizerSuite } } + test("Test splits on dataset larger than minSamplesRequired") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ + + val datasetSize = QuantileDiscretizer.minSamplesRequired + 1 + val numBuckets = 5 + val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + .setSeed(1) + + val result = discretizer.fit(df).transform(df) + val observedNumBuckets = result.select("result").distinct.count + + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + } + test("read/write") { val t = new QuantileDiscretizer() .setInputCol("myInputCol") From fae88af18445c5a88212b4644e121de4b30ce027 Mon Sep 17 00:00:00 2001 From: Terence Yim Date: Thu, 25 Feb 2016 13:29:30 +0000 Subject: [PATCH 2022/2089] [SPARK-13441][YARN] Fix NPE in yarn Client.createConfArchive method ## What changes were proposed in this pull request? Instead of using result of File.listFiles() directly, which may throw NPE, check for null first. If it is null, log a warning instead ## How was the this patch tested? Ran the ./dev/run-tests locally Tested manually on a cluster Author: Terence Yim Closes #11337 from chtyim/fixes/SPARK-13441-null-check. --- .../scala/org/apache/spark/deploy/yarn/Client.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d4ca255953a48..530f1d753b131 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -537,9 +537,14 @@ private[spark] class Client( sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { - dir.listFiles().foreach { file => - if (file.isFile && !hadoopConfFiles.contains(file.getName())) { - hadoopConfFiles(file.getName()) = file + val files = dir.listFiles() + if (files == null) { + logWarning("Failed to list files under directory " + dir) + } else { + files.foreach { file => + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { + hadoopConfFiles(file.getName()) = file + } } } } From c98a93ded36db5da2f3ebd519aa391de90927688 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Thu, 25 Feb 2016 13:32:09 +0000 Subject: [PATCH 2023/2089] [SPARK-13439][MESOS] Document that spark.mesos.uris is comma-separated Author: Michael Gummelt Closes #11311 from mgummelt/document_csv. --- docs/running-on-mesos.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index b9f64c7ed146e..9816d030e90ac 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -349,8 +349,9 @@ See the [configuration page](configuration.html) for information on Spark config
        From 4460113d419b5da47ba3c956b8430fd00eb03217 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 25 Feb 2016 13:34:29 +0000 Subject: [PATCH 2024/2089] [SPARK-13490][ML] ML LinearRegression should cache standardization param value ## What changes were proposed in this pull request? Like #11027 for ```LogisticRegression```, ```LinearRegression``` with L1 regularization should also cache the value of the ```standardization``` rather than re-fetching it from the ```ParamMap``` for every OWLQN iteration. cc srowen ## How was this patch tested? No extra tests are added. It should pass all existing tests. Author: Yanbo Liang Closes #11367 from yanboliang/spark-13490. --- .../org/apache/spark/ml/regression/LinearRegression.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e253f25c0ea65..ccfb5c4b9d697 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -277,8 +277,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def effectiveL1RegFun = (index: Int) => { - if ($(standardization)) { + if (standardizationParam) { effectiveL1RegParam } else { // If `standardization` is false, we still standardize the data From 157fe64f3ecbd13b7286560286e50235eecfe30e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 25 Feb 2016 23:07:59 +0800 Subject: [PATCH 2025/2089] [SPARK-13457][SQL] Removes DataFrame RDD operations ## What changes were proposed in this pull request? This PR removes DataFrame RDD operations. Original calls are now replaced by calls to methods of `DataFrame.rdd`. ## How was the this patch tested? No extra tests are added. Existing tests should do the work. Author: Cheng Lian Closes #11323 from liancheng/remove-df-rdd-ops. --- .../spark/examples/ml/DataFrameExample.scala | 2 +- .../examples/ml/DecisionTreeExample.scala | 8 ++-- .../spark/examples/ml/OneVsRestExample.scala | 2 +- .../spark/examples/mllib/LDAExample.scala | 1 + .../spark/examples/sql/RDDRelation.scala | 2 +- .../examples/sql/hive/HiveFromSpark.scala | 2 +- .../scala/org/apache/spark/ml/Predictor.scala | 6 ++- .../classification/LogisticRegression.scala | 13 +++--- .../spark/ml/clustering/BisectingKMeans.scala | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 6 +-- .../org/apache/spark/ml/clustering/LDA.scala | 1 + .../BinaryClassificationEvaluator.scala | 9 ++-- .../MulticlassClassificationEvaluator.scala | 6 +-- .../ml/evaluation/RegressionEvaluator.scala | 3 +- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/OneHotEncoder.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 1 + .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 1 + .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 6 +-- .../ml/regression/LinearRegression.scala | 16 ++++--- .../mllib/api/python/PythonMLLibAPI.scala | 8 ++-- .../spark/mllib/clustering/KMeansModel.scala | 2 +- .../spark/mllib/clustering/LDAModel.scala | 4 +- .../clustering/PowerIterationClustering.scala | 2 +- .../BinaryClassificationMetrics.scala | 2 +- .../mllib/evaluation/MulticlassMetrics.scala | 2 +- .../mllib/evaluation/MultilabelMetrics.scala | 4 +- .../mllib/evaluation/RegressionMetrics.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 2 +- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../MatrixFactorizationModel.scala | 12 +++--- .../mllib/tree/model/DecisionTreeModel.scala | 2 +- .../mllib/tree/model/treeEnsembleModels.scala | 2 +- .../LogisticRegressionSuite.scala | 2 +- .../MultilayerPerceptronClassifierSuite.scala | 5 ++- .../ml/classification/OneVsRestSuite.scala | 6 +-- .../ml/clustering/BisectingKMeansSuite.scala | 3 +- .../spark/ml/clustering/KMeansSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 2 +- .../spark/ml/feature/OneHotEncoderSuite.scala | 4 +- .../spark/ml/feature/StringIndexerSuite.scala | 6 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 5 ++- .../spark/ml/feature/Word2VecSuite.scala | 8 ++-- .../spark/ml/recommendation/ALSSuite.scala | 7 ++-- .../ml/regression/GBTRegressorSuite.scala | 2 +- .../regression/IsotonicRegressionSuite.scala | 6 +-- .../ml/regression/LinearRegressionSuite.scala | 17 ++++---- .../org/apache/spark/sql/DataFrame.scala | 42 ------------------- .../org/apache/spark/sql/GroupedData.scala | 1 + .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 2 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 2 + .../spark/sql/hive/orc/OrcQuerySuite.scala | 5 ++- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- 67 files changed, 140 insertions(+), 159 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0a477abae5679..7e608a281203e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -79,7 +79,7 @@ object DataFrameExample { labelSummary.show() // Convert features column to an RDD of vectors. - val features = df.select("features").map { case Row(v: Vector) => v } + val features = df.select("features").rdd.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index a37d12aa636cd..d2560cc00ba07 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -310,8 +310,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) // Print number of classes for reference val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { case Some(n) => n @@ -335,8 +335,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError println(s" Root mean squared error (RMSE): $RMSE") } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index ccee3b2aef980..a0bb5dabf4574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -155,7 +155,7 @@ object OneVsRestExample { // evaluate the model val predictionsAndLabels = predictions.select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) + .rdd.map(row => (row.getDouble(0), row.getDouble(1))) val metrics = new MulticlassMetrics(predictionsAndLabels) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index d28323555b990..038b2fe6112e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -221,6 +221,7 @@ object LDAExample { val model = pipeline.fit(df) val documents = model.transform(df) .select("features") + .rdd .map { case Row(features: Vector) => features } .zipWithIndex() .map(_.swap) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index a2f0fcd0e486e..620ff07631c36 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -52,7 +52,7 @@ object RDDRelation { val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") - rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) + rddFromSql.rdd.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 4e427f54daa52..b654a2c8d4a40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -63,7 +63,7 @@ object HiveFromSpark { val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") - val rddAsStrings = rddFromSql.map { + val rddAsStrings = rddFromSql.rdd.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d1388b5e2eb5a..4b27ee6c5a414 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -122,8 +122,10 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { - dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } + dataset.select($(labelCol), $(featuresCol)).rdd.map { + case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ac0124513f283..0d329d2c08d50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -263,10 +263,11 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -790,6 +791,7 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * :: Experimental :: * Logistic regression training results. + * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. @@ -813,6 +815,7 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( /** * :: Experimental :: * Binary Logistic regression results for a given model. + * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. @@ -837,7 +840,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).map { + predictions.select(probabilityCol, labelCol).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 45d293bc69618..f014a1d572387 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") def computeCost(dataset: DataFrame): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } } @@ -176,7 +176,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: DataFrame): BisectingKMeansModel = { - val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val bkm = new MLlibBisectingKMeans() .setK($(k)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index c6a3eac5877f2..79332b0d02157 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -130,7 +130,7 @@ class KMeansModel private[ml] ( @Since("1.6.0") def computeCost(dataset: DataFrame): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } @@ -260,7 +260,7 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { - val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibKMeans() .setK($(k)) @@ -303,7 +303,7 @@ class KMeansSummary private[clustering] ( * Size of each cluster. */ @Since("2.0.0") - lazy val size: Array[Int] = cluster.map { + lazy val size: Array[Int] = cluster.rdd.map { case Row(clusterIdx: Int) => (clusterIdx, 1) }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 99383e77f7eb0..6304b20d544ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -803,6 +803,7 @@ private[clustering] object LDA extends DefaultParamsReadable[LDA] { dataset .withColumn("docId", monotonicallyIncreasingId()) .select("docId", featuresCol) + .rdd .map { case Row(docId: Long, features: Vector) => (docId, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index a1d36c4becfa2..00f31255845ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -84,11 +84,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. - val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) - .map { - case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) - case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) - } + val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map { + case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) + case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) + } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index a921153b9474f..55ff44323a790 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -74,9 +74,9 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) - .map { case Row(prediction: Double, label: Double) => - (prediction, label) + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map { + case Row(prediction: Double, label: Double) => + (prediction, label) } val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index b6b25ecd01b3d..adee61e297081 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -85,7 +85,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) - .map { case Row(prediction: Double, label: Double) => + .rdd. + map { case Row(prediction: Double, label: Double) => (prediction, label) } val metrics = new RegressionMetrics(predictionAndLabels) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 7b565ef3ed922..4abc459f5369a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -79,7 +79,7 @@ final class ChiSqSelector(override val uid: String) override def fit(dataset: DataFrame): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).map { + val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index a6dfe58e56f0e..cf151458f0917 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -133,7 +133,7 @@ class CountVectorizer(override val uid: String) override def fit(dataset: DataFrame): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) - val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 9e7eee4f29988..cebbe5c162f79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -79,7 +79,7 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) copyValues(new IDFModel(uid, idf).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ad0458d0d0e1a..18be5c0701fb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -108,7 +108,7 @@ class MinMaxScaler(override val uid: String) override def fit(dataset: DataFrame): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 342540418fa80..e9df161c00b83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -130,7 +130,7 @@ class OneHotEncoder(override val uid: String) extends Transformer transformSchema(dataset.schema)(outputColName)) if (outputAttrGroup.size < 0) { // If the number of attributes is unknown, we check the values from the input column. - val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) .aggregate(0.0)( (m, x) => { assert(x >=0.0 && x == x.toInt, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 0e07dfabfeaab..80b124f74716d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -70,7 +70,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams */ override def fit(dataset: DataFrame): PCAModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 6a0b6c240ec60..9952d3bc9f1a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -87,7 +87,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 912bd95a2ec70..e8b617d9c8df9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -83,6 +83,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) + .rdd .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 2a5268406ddf2..5c11760fab9b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -113,7 +113,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size - val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val maxCats = $(maxCategories) val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 2b6b3c3a0fc58..a4c3d2751f50d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -138,7 +138,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def fit(dataset: DataFrame): Word2VecModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() .setLearningRate($(stepSize)) .setMinCount($(minCount)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4be4d6abed6b3..dacdac9a1df16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -392,6 +392,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) + .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 1e5b4cb83c652..e4339d67b928d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -184,7 +184,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. */ protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { - dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map { case Row(features: Vector, label: Double, censor: Double) => AFTPoint(features, label, censor) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 1573bb4c1b745..36b006c10e1fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -90,9 +90,9 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w) - .map { case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + dataset.select(col($(labelCol)), f, w).rdd.map { + case Row(label: Double, feature: Double, weight: Double) => + (label, feature, weight) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ccfb5c4b9d697..8f78fd122f343 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. - val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) @@ -170,7 +170,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -196,10 +196,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -513,6 +514,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the * training coefficients except for the objective trace. + * * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -537,6 +539,7 @@ class LinearRegressionTrainingSummary private[regression] ( /** * :: Experimental :: * Linear regression results evaluated on a dataset. + * * @param predictions predictions outputted by the model's `transform` method. */ @Since("1.5.0") @@ -551,6 +554,7 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions .select(predictionCol, labelCol) + .rdd .map { case Row(pred: Double, label: Double) => (pred, label) }, !model.getFitIntercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 93cf16e6f0c2a..ca0ed95a483f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1052,7 +1052,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for the constructor of Python mllib RankingMetrics */ def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = { - new RankingMetrics(predictionAndLabels.map( + new RankingMetrics(predictionAndLabels.rdd.map( r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } @@ -1135,7 +1135,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = { // We use DataFrames for serialization of IndexedRows from Python, // so map each Row in the DataFrame back to an IndexedRow. - val indexedRows = rows.map { + val indexedRows = rows.rdd.map { case Row(index: Long, vector: Vector) => IndexedRow(index, vector) } new IndexedRowMatrix(indexedRows, numRows, numCols) @@ -1147,7 +1147,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = { // We use DataFrames for serialization of MatrixEntry entries from // Python, so map each Row in the DataFrame back to a MatrixEntry. - val entries = rows.map { + val entries = rows.rdd.map { case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value) } new CoordinateMatrix(entries, numRows, numCols) @@ -1161,7 +1161,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of sub-matrix blocks from // Python, so map each Row in the DataFrame back to a // ((blockRowIndex, blockColIndex), sub-matrix) tuple. - val blockTuples = blocks.map { + val blockTuples = blocks.rdd.map { case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 26c6235fe5907..3b91fe8643dd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -143,7 +143,7 @@ object KMeansModel extends Loader[KMeansModel] { val k = (metadata \ "k").extract[Int] val centroids = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centroids.schema) - val localCentroids = centroids.map(Cluster.apply).collect() + val localCentroids = centroids.rdd.map(Cluster.apply).collect() assert(k == localCentroids.size) new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b30ecb80209d9..25d67a3756f6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -896,11 +896,11 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { Loader.checkSchema[EdgeData](edgeDataFrame.schema) val globalTopicTotals: LDA.TopicCounts = dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector - val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map { + val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.rdd.map { case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) } - val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map { + val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.rdd.map { case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop) } val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index feacafec7930f..9732dfa1744f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -93,7 +93,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) - val assignmentsRDD = assignments.map { + val assignmentsRDD = assignments.rdd.map { case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 12cf22095720a..319c54724dac1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -58,7 +58,7 @@ class BinaryClassificationMetrics @Since("1.3.0") ( * @param scoreAndLabels a DataFrame with two double columns: score and label */ private[mllib] def this(scoreAndLabels: DataFrame) = - this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Unpersist intermediate RDDs used in the computation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index c5104960cfcb6..3029b15f588a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -38,7 +38,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param predictionAndLabels a DataFrame with two double columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() private lazy val labelCount: Long = labelCountByClass.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index c100b3c9ec14a..daf6ff4db4ed0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -35,7 +35,9 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] * @param predictionAndLabels a DataFrame with two double array columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) + this(predictionAndLabels.rdd.map { r => + (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray) + }) private lazy val numDocs: Long = predictionAndLabels.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 18c90b204a26a..0f4c97ec20c00 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -46,7 +46,7 @@ class RegressionMetrics @Since("2.0.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 33728bf5d77e5..4f0e13feae086 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) - val features = dataArray.map { + val features = dataArray.rdd.map { case Row(feature: Int) => (feature) }.collect() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 85d609386ff94..b35d7217d693e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -134,7 +134,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { } def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { - val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => + val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x => val items = x.getAs[Seq[Item]](0).toArray val freq = x.getLong(1) new FreqItemset(items, freq) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 628cf1dd57280..c91729a9fd495 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -369,13 +369,13 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.read.parquet(userPath(path)) - .map { case Row(id: Int, features: Seq[_]) => + val userFeatures = sqlContext.read.parquet(userPath(path)).rdd.map { + case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + val productFeatures = sqlContext.read.parquet(productPath(path)).rdd.map { + case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) - } - val productFeatures = sqlContext.read.parquet(productPath(path)) - .map { case Row(id: Int, features: Seq[_]) => - (id, features.asInstanceOf[Seq[Double]].toArray) } new MatrixFactorizationModel(rank, userFeatures, productFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 89c470d573431..ec5d7b91892c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -247,7 +247,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) - val nodes = dataRDD.map(NodeData.apply) + val nodes = dataRDD.rdd.map(NodeData.apply) // Build node data into a tree. val trees = constructTrees(nodes) assert(trees.size == 1, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index feabcee24fa2c..59713c382e58b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -473,7 +473,7 @@ private[tree] object TreeEnsembleModel extends Logging { treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) - val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) + val nodes = sqlContext.read.parquet(datapath).rdd.map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 972c0868a454a..cfb9bbfd41ee7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -735,7 +735,7 @@ class LogisticRegressionSuite val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) - val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label } + val histogram = binaryDataset.rdd.map { case Row(label: Double, features: Vector) => label } .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index a326432d017fc..602b5a8116998 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,8 +76,9 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") - .map { case Row(p: Double, l: Double) => (p, l) } + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { + case Row(p: Double, l: Double) => (p, l) + } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 445e50d867e15..2ae74a2090ecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -81,9 +81,9 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset - .select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) + val ovaResults = transformedDataset.select("prediction", "label").rdd.map { + row => (row.getDouble(0), row.getDouble(1)) + } val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fc4a4add5d755..b719a8c7e7340 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -77,7 +77,8 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e5357ba8e220c..c684bc11cccff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -93,7 +93,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 97dbfd9a4314a..03270401ad2bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -199,7 +199,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead // describeTopics val topics = model.describeTopics(3) assert(topics.count() === k) - assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + assert(topics.select("topic").rdd.map(_.getInt(0)).collect().toSet === Range(0, k).toSet) topics.select("termIndices").collect().foreach { case r: Row => val termIndices = r.getAs[Seq[Int]](0) assert(termIndices.length === 3 && termIndices.toSet.size === 3) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 76d12050f9677..e238b33ed8c64 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -51,7 +51,7 @@ class OneHotEncoderSuite .setDropLast(false) val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").map { r => + val output = encoded.select("id", "labelVec").rdd.map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet @@ -68,7 +68,7 @@ class OneHotEncoderSuite .setOutputCol("labelVec") val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").map { r => + val output = encoded.select("id", "labelVec").rdd.map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5d199ca9b51b1..dea61648f2f5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -52,7 +52,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -83,7 +83,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 @@ -102,7 +102,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 67817fa4baf56..d4f836ef33ad9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -159,7 +159,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) assert(featureAttrs.name === "indexed") assert(featureAttrs.attributes.get.length === model.numFeatures) @@ -216,7 +216,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + val indexedPoints = + model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() points.zip(indexedPoints).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index f094c550e545a..1671fb6f3a85e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -100,7 +100,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val realVectors = model.getVectors.sort("word").select("vector").map { + val realVectors = model.getVectors.sort("word").select("vector").rdd.map { case Row(v: Vector) => v }.collect() // These expectations are just magic values, characterizing the current @@ -134,7 +134,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(docDF) val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) - val (synonyms, similarity) = model.findSynonyms("a", 2).map { + val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -161,7 +161,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val (synonyms, similarity) = model.findSynonyms("a", 6).map { + val (synonyms, similarity) = model.findSynonyms("a", 6).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -174,7 +174,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setWindowSize(10) .fit(docDF) - val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip // The similarity score should be very different with the larger window diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ff0d8f5568279..2bedd70ce93e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -342,11 +342,10 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()) - .select("rating", "prediction") - .map { case Row(rating: Float, prediction: Float) => + val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { + case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) - } + } val rmse = if (implicitPrefs) { // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 09326600e620f..244db8637bea0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -87,7 +87,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(model) val preds = model.transform(df) - val predictions = preds.select("prediction").map(_.getDouble(0)) + val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) assert(predictions.max() > 2) assert(predictions.min() < -1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index f067c29d27a7d..b8874b4cd32a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -46,7 +46,7 @@ class IsotonicRegressionSuite val predictions = model .transform(dataset) - .select("prediction").map { case Row(pred) => + .select("prediction").rdd.map { case Row(pred) => pred }.collect() @@ -66,7 +66,7 @@ class IsotonicRegressionSuite val predictions = model .transform(features) - .select("prediction").map { + .select("prediction").rdd.map { case Row(pred) => pred }.collect() @@ -160,7 +160,7 @@ class IsotonicRegressionSuite val predictions = model .transform(features) - .select("prediction").map { + .select("prediction").rdd.map { case Row(pred) => pred }.collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 3ae108d822de7..9dee04c8776db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -686,17 +686,18 @@ class LinearRegressionSuite // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = datasetWithDenseFeature.select("features", "label") + .rdd .map { case Row(features: DenseVector, label: Double) => - val prediction = - features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + - model.intercept - label - prediction - } - .zip(model.summary.residuals.map(_.getDouble(0))) + val prediction = + features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + + model.intercept + label - prediction + } + .zip(model.summary.residuals.rdd.map(_.getDouble(0))) .collect() .foreach { case (manualResidual: Double, resultResidual: Double) => - assert(manualResidual ~== resultResidual relTol 1E-5) - } + assert(manualResidual ~== resultResidual relTol 1E-5) + } /* # Use the following R code to generate model training results. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index abb8fe552b7cd..be902d688e519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1426,48 +1426,6 @@ class DataFrame private[sql]( */ def transform[U](t: DataFrame => DataFrame): DataFrame = t(this) - /** - * Returns a new RDD by applying a function to all rows of this DataFrame. - * @group rdd - * @since 1.3.0 - */ - def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) - - /** - * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], - * and then flattening the results. - * @group rdd - * @since 1.3.0 - */ - def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) - - /** - * Returns a new RDD by applying a function to each partition of this DataFrame. - * @group rdd - * @since 1.3.0 - */ - def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { - rdd.mapPartitions(f) - } - - /** - * Applies a function `f` to all rows. - * @group rdd - * @since 1.3.0 - */ - def foreach(f: Row => Unit): Unit = withNewExecutionId { - rdd.foreach(f) - } - - /** - * Applies a function f to each partition of this [[DataFrame]]. - * @group rdd - * @since 1.3.0 - */ - def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId { - rdd.foreachPartition(f) - } - /** * Returns the first `n` rows in the [[DataFrame]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f06d16116ec8a..a7258d742aa96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -306,6 +306,7 @@ class GroupedData protected[sql]( val values = df.select(pivotColumn) .distinct() .sort(pivotColumn) // ensure that the output columns are in a consistent logical order + .rdd .map(_.get(0)) .take(maxValues + 1) .toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d912aeb70d517..68a251757c596 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -100,7 +100,7 @@ private[r] object SQLUtils { } def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { - df.map(r => rowToRBytes(r)) + df.rdd.map(r => rowToRBytes(r)) } private[this] def doConversion(data: Object, dataType: DataType): Object = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e295722cacf15..d7111a6a1c47e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -276,7 +276,7 @@ object JdbcUtils extends Logging { val rddSchema = df.schema val getConnection: () => Connection = createConnectionFactory(url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt - df.foreachPartition { iterator => + df.rdd.foreachPartition { iterator => savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f54bff9f18bc0..7d96ef6fe0a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -257,7 +257,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + assert(testData2.count() === testData2.rdd.map(_ => 1).count()) checkAnswer( testData2.agg(count('a), sumDistinct('a)), // non-partial diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index fbffe867e4b78..bd51154c58aa6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -101,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { - df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3c74464d57012..c85eeddc2c6d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -599,7 +599,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("null and non-null strings") { // Create a dataset where the first values are NULL and then some non-null values. The // number of non-nulls needs to be bigger than the ParquetReader batch size. - val data = sqlContext.range(200).map { i => + val data = sqlContext.range(200).rdd.map { i => if (i.getLong(0) < 150) Row(None) else Row("a") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 085e4a49a57e6..39920d8cc659f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -330,7 +330,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { // listener should ignore the non SQL stage assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) - sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + sqlContext.sparkContext.parallelize(1 to 10).toDF().rdd.foreach(i => ()) sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) @@ -398,7 +398,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { ).toDF() df.collect() try { - df.foreach(_ => throw new RuntimeException("Oops")) + df.rdd.foreach(_ => throw new RuntimeException("Oops")) } catch { case e: SparkException => // This is expected for a failed job } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index f141a9bd0ff81..12a5542bd4bb2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -210,7 +210,7 @@ object SparkSubmitClassLoaderTest extends Logging { } // Second, we load classes at the executor side. logInfo("Testing load classes at the executor side.") - val result = df.mapPartitions { x => + val result = df.rdd.mapPartitions { x => var exception: String = null try { Utils.classForName(args(0)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 3208ebc9ffb2e..10024874472f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -664,11 +664,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + .rdd .map { case Row(i: Int) => i } .collect() .toSet val expected = sql("SELECT key FROM src") + .rdd .map { case Row(i: Int) => i } .collect() .toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b11d1d9de092e..68249517f5c02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -119,6 +119,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // expr = (not leaf-0) assertResult(10) { sql("SELECT name, contacts FROM t where age > 5") + .rdd .flatMap(_.getAs[Seq[_]]("contacts")) .count() } @@ -131,7 +132,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") assert(df.count() === 2) assertResult(4) { - df.flatMap(_.getAs[Seq[_]]("contacts")).count() + df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count() } } @@ -143,7 +144,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") assert(df.count() === 3) assertResult(6) { - df.flatMap(_.getAs[Seq[_]]("contacts")).count() + df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count() } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 68d5c7da1f974..a127cf6e4b7d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -854,7 +854,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with test(s"hive udfs $table") { checkAnswer( sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { + sql(s"SELECT stringField FROM $table").rdd.map { case Row(s: String) => Row(s + s) }.collect().toSeq) } From 5fcf4c2bfce4b7e3543815c8e49ffdec8072c9a2 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Thu, 25 Feb 2016 09:14:19 -0600 Subject: [PATCH 2026/2089] [SPARK-12316] Wait a minutes to avoid cycle calling. When application end, AM will clean the staging dir. But if the driver trigger to update the delegation token, it will can't find the right token file and then it will endless cycle call the method 'updateCredentialsIfRequired'. Then it lead driver StackOverflowError. https://issues.apache.org/jira/browse/SPARK-12316 Author: huangzhaowei Closes #10475 from SaintBacchus/SPARK-12316. --- .../spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 9d99c0d93fd1e..6474acc3dc9d2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -76,7 +76,10 @@ private[spark] class ExecutorDelegationTokenUpdater( SparkHadoopUtil.get.getTimeFromNowToRenewal( sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials) if (timeFromNowToRenewal <= 0) { - executorUpdaterRunnable.run() + // We just checked for new credentials but none were there, wait a minute and retry. + // This handles the shutdown case where the staging directory may have been removed(see + // SPARK-12316 for more details). + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.MINUTES) } else { logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.") delegationTokenRenewer.schedule( From 46f6e79316b72afea0c9b1559ea662dd3e95e57b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 25 Feb 2016 11:39:26 -0800 Subject: [PATCH 2027/2089] Revert "[SPARK-13117][WEB UI] WebUI should use the local ip not 0.0.0.0" This reverts commit 2e44031fafdb8cf486573b98e4faa6b31ffb90a4. --- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index e515916f31420..fe4949b9f6fee 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -134,7 +134,7 @@ private[spark] abstract class WebUI( def bind() { assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) try { - serverInfo = Some(startJettyServer(publicHostName, port, sslOptions, handlers, conf, name)) + serverInfo = Some(startJettyServer("0.0.0.0", port, sslOptions, handlers, conf, name)) logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) } catch { case e: Exception => diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 972b552f7a5b0..f416ace5c2b71 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.SparkConfWithEnv -import org.apache.spark.util.Utils class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { @@ -54,7 +53,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { } test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { - val SPARK_PUBLIC_DNS = Utils.localHostNameForURI() + val SPARK_PUBLIC_DNS = "public_dns" val conf = new SparkConfWithEnv(Map("SPARK_PUBLIC_DNS" -> SPARK_PUBLIC_DNS)).set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) From 751724b1320d38fd94186df3d8f1ca887f21947a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 25 Feb 2016 11:53:48 -0800 Subject: [PATCH 2028/2089] Revert "[SPARK-13457][SQL] Removes DataFrame RDD operations" This reverts commit 157fe64f3ecbd13b7286560286e50235eecfe30e. --- .../spark/examples/ml/DataFrameExample.scala | 2 +- .../examples/ml/DecisionTreeExample.scala | 8 ++-- .../spark/examples/ml/OneVsRestExample.scala | 2 +- .../spark/examples/mllib/LDAExample.scala | 1 - .../spark/examples/sql/RDDRelation.scala | 2 +- .../examples/sql/hive/HiveFromSpark.scala | 2 +- .../scala/org/apache/spark/ml/Predictor.scala | 6 +-- .../classification/LogisticRegression.scala | 13 +++--- .../spark/ml/clustering/BisectingKMeans.scala | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 6 +-- .../org/apache/spark/ml/clustering/LDA.scala | 1 - .../BinaryClassificationEvaluator.scala | 9 ++-- .../MulticlassClassificationEvaluator.scala | 6 +-- .../ml/evaluation/RegressionEvaluator.scala | 3 +- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/OneHotEncoder.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 1 - .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 1 - .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 6 +-- .../ml/regression/LinearRegression.scala | 16 +++---- .../mllib/api/python/PythonMLLibAPI.scala | 8 ++-- .../spark/mllib/clustering/KMeansModel.scala | 2 +- .../spark/mllib/clustering/LDAModel.scala | 4 +- .../clustering/PowerIterationClustering.scala | 2 +- .../BinaryClassificationMetrics.scala | 2 +- .../mllib/evaluation/MulticlassMetrics.scala | 2 +- .../mllib/evaluation/MultilabelMetrics.scala | 4 +- .../mllib/evaluation/RegressionMetrics.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 2 +- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../MatrixFactorizationModel.scala | 12 +++--- .../mllib/tree/model/DecisionTreeModel.scala | 2 +- .../mllib/tree/model/treeEnsembleModels.scala | 2 +- .../LogisticRegressionSuite.scala | 2 +- .../MultilayerPerceptronClassifierSuite.scala | 5 +-- .../ml/classification/OneVsRestSuite.scala | 6 +-- .../ml/clustering/BisectingKMeansSuite.scala | 3 +- .../spark/ml/clustering/KMeansSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 2 +- .../spark/ml/feature/OneHotEncoderSuite.scala | 4 +- .../spark/ml/feature/StringIndexerSuite.scala | 6 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 5 +-- .../spark/ml/feature/Word2VecSuite.scala | 8 ++-- .../spark/ml/recommendation/ALSSuite.scala | 7 ++-- .../ml/regression/GBTRegressorSuite.scala | 2 +- .../regression/IsotonicRegressionSuite.scala | 6 +-- .../ml/regression/LinearRegressionSuite.scala | 17 ++++---- .../org/apache/spark/sql/DataFrame.scala | 42 +++++++++++++++++++ .../org/apache/spark/sql/GroupedData.scala | 1 - .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 2 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 2 - .../spark/sql/hive/orc/OrcQuerySuite.scala | 5 +-- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- 67 files changed, 159 insertions(+), 140 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 7e608a281203e..0a477abae5679 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -79,7 +79,7 @@ object DataFrameExample { labelSummary.show() // Convert features column to an RDD of vectors. - val features = df.select("features").rdd.map { case Row(v: Vector) => v } + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index d2560cc00ba07..a37d12aa636cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -310,8 +310,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) // Print number of classes for reference val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { case Some(n) => n @@ -335,8 +335,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError println(s" Root mean squared error (RMSE): $RMSE") } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index a0bb5dabf4574..ccee3b2aef980 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -155,7 +155,7 @@ object OneVsRestExample { // evaluate the model val predictionsAndLabels = predictions.select("prediction", "label") - .rdd.map(row => (row.getDouble(0), row.getDouble(1))) + .map(row => (row.getDouble(0), row.getDouble(1))) val metrics = new MulticlassMetrics(predictionsAndLabels) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 038b2fe6112e8..d28323555b990 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -221,7 +221,6 @@ object LDAExample { val model = pipeline.fit(df) val documents = model.transform(df) .select("features") - .rdd .map { case Row(features: Vector) => features } .zipWithIndex() .map(_.swap) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 620ff07631c36..a2f0fcd0e486e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -52,7 +52,7 @@ object RDDRelation { val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") - rddFromSql.rdd.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) + rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b654a2c8d4a40..4e427f54daa52 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -63,7 +63,7 @@ object HiveFromSpark { val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") - val rddAsStrings = rddFromSql.rdd.map { + val rddAsStrings = rddFromSql.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 4b27ee6c5a414..d1388b5e2eb5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -122,10 +122,8 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { - dataset.select($(labelCol), $(featuresCol)).rdd.map { - case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - } + dataset.select($(labelCol), $(featuresCol)) + .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0d329d2c08d50..ac0124513f283 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -263,11 +263,10 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -791,7 +790,6 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * :: Experimental :: * Logistic regression training results. - * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. @@ -815,7 +813,6 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( /** * :: Experimental :: * Binary Logistic regression results for a given model. - * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. @@ -840,7 +837,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).rdd.map { + predictions.select(probabilityCol, labelCol).map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f014a1d572387..45d293bc69618 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") def computeCost(dataset: DataFrame): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } } @@ -176,7 +176,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: DataFrame): BisectingKMeansModel = { - val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } val bkm = new MLlibBisectingKMeans() .setK($(k)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 79332b0d02157..c6a3eac5877f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -130,7 +130,7 @@ class KMeansModel private[ml] ( @Since("1.6.0") def computeCost(dataset: DataFrame): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } @@ -260,7 +260,7 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { - val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } val algo = new MLlibKMeans() .setK($(k)) @@ -303,7 +303,7 @@ class KMeansSummary private[clustering] ( * Size of each cluster. */ @Since("2.0.0") - lazy val size: Array[Int] = cluster.rdd.map { + lazy val size: Array[Int] = cluster.map { case Row(clusterIdx: Int) => (clusterIdx, 1) }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 6304b20d544ad..99383e77f7eb0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -803,7 +803,6 @@ private[clustering] object LDA extends DefaultParamsReadable[LDA] { dataset .withColumn("docId", monotonicallyIncreasingId()) .select("docId", featuresCol) - .rdd .map { case Row(docId: Long, features: Vector) => (docId, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 00f31255845ab..a1d36c4becfa2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -84,10 +84,11 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. - val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map { - case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) - case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) - } + val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) + .map { + case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) + case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) + } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 55ff44323a790..a921153b9474f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -74,9 +74,9 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map { - case Row(prediction: Double, label: Double) => - (prediction, label) + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) } val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index adee61e297081..b6b25ecd01b3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -85,8 +85,7 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) - .rdd. - map { case Row(prediction: Double, label: Double) => + .map { case Row(prediction: Double, label: Double) => (prediction, label) } val metrics = new RegressionMetrics(predictionAndLabels) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 4abc459f5369a..7b565ef3ed922 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -79,7 +79,7 @@ final class ChiSqSelector(override val uid: String) override def fit(dataset: DataFrame): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { + val input = dataset.select($(labelCol), $(featuresCol)).map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index cf151458f0917..a6dfe58e56f0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -133,7 +133,7 @@ class CountVectorizer(override val uid: String) override def fit(dataset: DataFrame): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) - val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index cebbe5c162f79..9e7eee4f29988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -79,7 +79,7 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) copyValues(new IDFModel(uid, idf).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 18be5c0701fb1..ad0458d0d0e1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -108,7 +108,7 @@ class MinMaxScaler(override val uid: String) override def fit(dataset: DataFrame): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e9df161c00b83..342540418fa80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -130,7 +130,7 @@ class OneHotEncoder(override val uid: String) extends Transformer transformSchema(dataset.schema)(outputColName)) if (outputAttrGroup.size < 0) { // If the number of attributes is unknown, we check the values from the input column. - val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) .aggregate(0.0)( (m, x) => { assert(x >=0.0 && x == x.toInt, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 80b124f74716d..0e07dfabfeaab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -70,7 +70,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams */ override def fit(dataset: DataFrame): PCAModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 9952d3bc9f1a5..6a0b6c240ec60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -87,7 +87,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index e8b617d9c8df9..912bd95a2ec70 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -83,7 +83,6 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) - .rdd .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 5c11760fab9b2..2a5268406ddf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -113,7 +113,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size - val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val maxCats = $(maxCategories) val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index a4c3d2751f50d..2b6b3c3a0fc58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -138,7 +138,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def fit(dataset: DataFrame): Word2VecModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() .setLearningRate($(stepSize)) .setMinCount($(minCount)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index dacdac9a1df16..4be4d6abed6b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -392,7 +392,6 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) - .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e4339d67b928d..1e5b4cb83c652 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -184,7 +184,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. */ protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { - dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { case Row(features: Vector, label: Double, censor: Double) => AFTPoint(features, label, censor) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 36b006c10e1fb..1573bb4c1b745 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -90,9 +90,9 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w).rdd.map { - case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + dataset.select(col($(labelCol)), f, w) + .map { case Row(label: Double, feature: Double, weight: Double) => + (label, feature, weight) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8f78fd122f343..ccfb5c4b9d697 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. - val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { + val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { case Row(features: Vector) => features.size }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) @@ -170,7 +170,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -196,11 +196,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -514,7 +513,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the * training coefficients except for the objective trace. - * * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -539,7 +537,6 @@ class LinearRegressionTrainingSummary private[regression] ( /** * :: Experimental :: * Linear regression results evaluated on a dataset. - * * @param predictions predictions outputted by the model's `transform` method. */ @Since("1.5.0") @@ -554,7 +551,6 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions .select(predictionCol, labelCol) - .rdd .map { case Row(pred: Double, label: Double) => (pred, label) }, !model.getFitIntercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ca0ed95a483f4..93cf16e6f0c2a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1052,7 +1052,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for the constructor of Python mllib RankingMetrics */ def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = { - new RankingMetrics(predictionAndLabels.rdd.map( + new RankingMetrics(predictionAndLabels.map( r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } @@ -1135,7 +1135,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = { // We use DataFrames for serialization of IndexedRows from Python, // so map each Row in the DataFrame back to an IndexedRow. - val indexedRows = rows.rdd.map { + val indexedRows = rows.map { case Row(index: Long, vector: Vector) => IndexedRow(index, vector) } new IndexedRowMatrix(indexedRows, numRows, numCols) @@ -1147,7 +1147,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = { // We use DataFrames for serialization of MatrixEntry entries from // Python, so map each Row in the DataFrame back to a MatrixEntry. - val entries = rows.rdd.map { + val entries = rows.map { case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value) } new CoordinateMatrix(entries, numRows, numCols) @@ -1161,7 +1161,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of sub-matrix blocks from // Python, so map each Row in the DataFrame back to a // ((blockRowIndex, blockColIndex), sub-matrix) tuple. - val blockTuples = blocks.rdd.map { + val blockTuples = blocks.map { case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3b91fe8643dd8..26c6235fe5907 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -143,7 +143,7 @@ object KMeansModel extends Loader[KMeansModel] { val k = (metadata \ "k").extract[Int] val centroids = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centroids.schema) - val localCentroids = centroids.rdd.map(Cluster.apply).collect() + val localCentroids = centroids.map(Cluster.apply).collect() assert(k == localCentroids.size) new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 25d67a3756f6c..b30ecb80209d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -896,11 +896,11 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { Loader.checkSchema[EdgeData](edgeDataFrame.schema) val globalTopicTotals: LDA.TopicCounts = dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector - val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.rdd.map { + val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map { case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) } - val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.rdd.map { + val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map { case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop) } val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 9732dfa1744f5..feacafec7930f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -93,7 +93,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) - val assignmentsRDD = assignments.rdd.map { + val assignmentsRDD = assignments.map { case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 319c54724dac1..12cf22095720a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -58,7 +58,7 @@ class BinaryClassificationMetrics @Since("1.3.0") ( * @param scoreAndLabels a DataFrame with two double columns: score and label */ private[mllib] def this(scoreAndLabels: DataFrame) = - this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) + this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Unpersist intermediate RDDs used in the computation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 3029b15f588a4..c5104960cfcb6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -38,7 +38,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param predictionAndLabels a DataFrame with two double columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() private lazy val labelCount: Long = labelCountByClass.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index daf6ff4db4ed0..c100b3c9ec14a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -35,9 +35,7 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] * @param predictionAndLabels a DataFrame with two double array columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.rdd.map { r => - (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray) - }) + this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) private lazy val numDocs: Long = predictionAndLabels.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 0f4c97ec20c00..18c90b204a26a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -46,7 +46,7 @@ class RegressionMetrics @Since("2.0.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 4f0e13feae086..33728bf5d77e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) - val features = dataArray.rdd.map { + val features = dataArray.map { case Row(feature: Int) => (feature) }.collect() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index b35d7217d693e..85d609386ff94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -134,7 +134,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { } def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { - val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x => + val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => val items = x.getAs[Seq[Item]](0).toArray val freq = x.getLong(1) new FreqItemset(items, freq) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index c91729a9fd495..628cf1dd57280 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -369,13 +369,13 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.read.parquet(userPath(path)).rdd.map { - case Row(id: Int, features: Seq[_]) => - (id, features.asInstanceOf[Seq[Double]].toArray) - } - val productFeatures = sqlContext.read.parquet(productPath(path)).rdd.map { - case Row(id: Int, features: Seq[_]) => + val userFeatures = sqlContext.read.parquet(userPath(path)) + .map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) + } + val productFeatures = sqlContext.read.parquet(productPath(path)) + .map { case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) } new MatrixFactorizationModel(rank, userFeatures, productFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec5d7b91892c9..89c470d573431 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -247,7 +247,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) - val nodes = dataRDD.rdd.map(NodeData.apply) + val nodes = dataRDD.map(NodeData.apply) // Build node data into a tree. val trees = constructTrees(nodes) assert(trees.size == 1, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 59713c382e58b..feabcee24fa2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -473,7 +473,7 @@ private[tree] object TreeEnsembleModel extends Logging { treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) - val nodes = sqlContext.read.parquet(datapath).rdd.map(NodeData.apply) + val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cfb9bbfd41ee7..972c0868a454a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -735,7 +735,7 @@ class LogisticRegressionSuite val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) - val histogram = binaryDataset.rdd.map { case Row(label: Double, features: Vector) => label } + val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label } .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 602b5a8116998..a326432d017fc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,9 +76,8 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { - case Row(p: Double, l: Double) => (p, l) - } + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 2ae74a2090ecf..445e50d867e15 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -81,9 +81,9 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset.select("prediction", "label").rdd.map { - row => (row.getDouble(0), row.getDouble(1)) - } + val ovaResults = transformedDataset + .select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index b719a8c7e7340..fc4a4add5d755 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -77,8 +77,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c684bc11cccff..e5357ba8e220c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -93,8 +93,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 03270401ad2bb..97dbfd9a4314a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -199,7 +199,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead // describeTopics val topics = model.describeTopics(3) assert(topics.count() === k) - assert(topics.select("topic").rdd.map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet) topics.select("termIndices").collect().foreach { case r: Row => val termIndices = r.getAs[Seq[Int]](0) assert(termIndices.length === 3 && termIndices.toSet.size === 3) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index e238b33ed8c64..76d12050f9677 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -51,7 +51,7 @@ class OneHotEncoderSuite .setDropLast(false) val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").rdd.map { r => + val output = encoded.select("id", "labelVec").map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet @@ -68,7 +68,7 @@ class OneHotEncoderSuite .setOutputCol("labelVec") val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").rdd.map { r => + val output = encoded.select("id", "labelVec").map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index dea61648f2f5a..5d199ca9b51b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -52,7 +52,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + val output = transformed.select("id", "labelIndex").map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -83,7 +83,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + val output = transformed.select("id", "labelIndex").map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 @@ -102,7 +102,7 @@ class StringIndexerSuite val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + val output = transformed.select("id", "labelIndex").map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index d4f836ef33ad9..67817fa4baf56 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -159,7 +159,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) assert(featureAttrs.name === "indexed") assert(featureAttrs.attributes.get.length === model.numFeatures) @@ -216,8 +216,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = - model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() points.zip(indexedPoints).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 1671fb6f3a85e..f094c550e545a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -100,7 +100,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val realVectors = model.getVectors.sort("word").select("vector").rdd.map { + val realVectors = model.getVectors.sort("word").select("vector").map { case Row(v: Vector) => v }.collect() // These expectations are just magic values, characterizing the current @@ -134,7 +134,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(docDF) val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) - val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { + val (synonyms, similarity) = model.findSynonyms("a", 2).map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -161,7 +161,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val (synonyms, similarity) = model.findSynonyms("a", 6).rdd.map { + val (synonyms, similarity) = model.findSynonyms("a", 6).map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -174,7 +174,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setWindowSize(10) .fit(docDF) - val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).rdd.map { + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip // The similarity score should be very different with the larger window diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2bedd70ce93e7..ff0d8f5568279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -342,10 +342,11 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => + val predictions = model.transform(test.toDF()) + .select("rating", "prediction") + .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) - } + } val rmse = if (implicitPrefs) { // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 244db8637bea0..09326600e620f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -87,7 +87,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(model) val preds = model.transform(df) - val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) + val predictions = preds.select("prediction").map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) assert(predictions.max() > 2) assert(predictions.min() < -1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index b8874b4cd32a4..f067c29d27a7d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -46,7 +46,7 @@ class IsotonicRegressionSuite val predictions = model .transform(dataset) - .select("prediction").rdd.map { case Row(pred) => + .select("prediction").map { case Row(pred) => pred }.collect() @@ -66,7 +66,7 @@ class IsotonicRegressionSuite val predictions = model .transform(features) - .select("prediction").rdd.map { + .select("prediction").map { case Row(pred) => pred }.collect() @@ -160,7 +160,7 @@ class IsotonicRegressionSuite val predictions = model .transform(features) - .select("prediction").rdd.map { + .select("prediction").map { case Row(pred) => pred }.collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 9dee04c8776db..3ae108d822de7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -686,18 +686,17 @@ class LinearRegressionSuite // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = datasetWithDenseFeature.select("features", "label") - .rdd .map { case Row(features: DenseVector, label: Double) => - val prediction = - features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + - model.intercept - label - prediction - } - .zip(model.summary.residuals.rdd.map(_.getDouble(0))) + val prediction = + features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + + model.intercept + label - prediction + } + .zip(model.summary.residuals.map(_.getDouble(0))) .collect() .foreach { case (manualResidual: Double, resultResidual: Double) => - assert(manualResidual ~== resultResidual relTol 1E-5) - } + assert(manualResidual ~== resultResidual relTol 1E-5) + } /* # Use the following R code to generate model training results. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index be902d688e519..abb8fe552b7cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1426,6 +1426,48 @@ class DataFrame private[sql]( */ def transform[U](t: DataFrame => DataFrame): DataFrame = t(this) + /** + * Returns a new RDD by applying a function to all rows of this DataFrame. + * @group rdd + * @since 1.3.0 + */ + def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) + + /** + * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], + * and then flattening the results. + * @group rdd + * @since 1.3.0 + */ + def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + + /** + * Returns a new RDD by applying a function to each partition of this DataFrame. + * @group rdd + * @since 1.3.0 + */ + def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + /** + * Applies a function `f` to all rows. + * @group rdd + * @since 1.3.0 + */ + def foreach(f: Row => Unit): Unit = withNewExecutionId { + rdd.foreach(f) + } + + /** + * Applies a function f to each partition of this [[DataFrame]]. + * @group rdd + * @since 1.3.0 + */ + def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId { + rdd.foreachPartition(f) + } + /** * Returns the first `n` rows in the [[DataFrame]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a7258d742aa96..f06d16116ec8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -306,7 +306,6 @@ class GroupedData protected[sql]( val values = df.select(pivotColumn) .distinct() .sort(pivotColumn) // ensure that the output columns are in a consistent logical order - .rdd .map(_.get(0)) .take(maxValues + 1) .toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 68a251757c596..d912aeb70d517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -100,7 +100,7 @@ private[r] object SQLUtils { } def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { - df.rdd.map(r => rowToRBytes(r)) + df.map(r => rowToRBytes(r)) } private[this] def doConversion(data: Object, dataType: DataType): Object = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d7111a6a1c47e..e295722cacf15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -276,7 +276,7 @@ object JdbcUtils extends Logging { val rddSchema = df.schema val getConnection: () => Connection = createConnectionFactory(url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt - df.rdd.foreachPartition { iterator => + df.foreachPartition { iterator => savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7d96ef6fe0a10..f54bff9f18bc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -257,7 +257,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("count") { - assert(testData2.count() === testData2.rdd.map(_ => 1).count()) + assert(testData2.count() === testData2.map(_ => 1).count()) checkAnswer( testData2.agg(count('a), sumDistinct('a)), // non-partial diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index bd51154c58aa6..fbffe867e4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -101,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { - df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index c85eeddc2c6d9..3c74464d57012 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -599,7 +599,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("null and non-null strings") { // Create a dataset where the first values are NULL and then some non-null values. The // number of non-nulls needs to be bigger than the ParquetReader batch size. - val data = sqlContext.range(200).rdd.map { i => + val data = sqlContext.range(200).map { i => if (i.getLong(0) < 150) Row(None) else Row("a") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 39920d8cc659f..085e4a49a57e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -330,7 +330,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { // listener should ignore the non SQL stage assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) - sqlContext.sparkContext.parallelize(1 to 10).toDF().rdd.foreach(i => ()) + sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) @@ -398,7 +398,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { ).toDF() df.collect() try { - df.rdd.foreach(_ => throw new RuntimeException("Oops")) + df.foreach(_ => throw new RuntimeException("Oops")) } catch { case e: SparkException => // This is expected for a failed job } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 12a5542bd4bb2..f141a9bd0ff81 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -210,7 +210,7 @@ object SparkSubmitClassLoaderTest extends Logging { } // Second, we load classes at the executor side. logInfo("Testing load classes at the executor side.") - val result = df.rdd.mapPartitions { x => + val result = df.mapPartitions { x => var exception: String = null try { Utils.classForName(args(0)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 10024874472f2..3208ebc9ffb2e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -664,13 +664,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") - .rdd .map { case Row(i: Int) => i } .collect() .toSet val expected = sql("SELECT key FROM src") - .rdd .map { case Row(i: Int) => i } .collect() .toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 68249517f5c02..b11d1d9de092e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -119,7 +119,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // expr = (not leaf-0) assertResult(10) { sql("SELECT name, contacts FROM t where age > 5") - .rdd .flatMap(_.getAs[Seq[_]]("contacts")) .count() } @@ -132,7 +131,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") assert(df.count() === 2) assertResult(4) { - df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count() + df.flatMap(_.getAs[Seq[_]]("contacts")).count() } } @@ -144,7 +143,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") assert(df.count() === 3) assertResult(6) { - df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count() + df.flatMap(_.getAs[Seq[_]]("contacts")).count() } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index a127cf6e4b7d4..68d5c7da1f974 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -854,7 +854,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with test(s"hive udfs $table") { checkAnswer( sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").rdd.map { + sql(s"SELECT stringField FROM $table").map { case Row(s: String) => Row(s + s) }.collect().toSeq) } From fb8bb04766005e8935607069c0155d639f407e8a Mon Sep 17 00:00:00 2001 From: Lin Zhao Date: Thu, 25 Feb 2016 12:32:17 -0800 Subject: [PATCH 2029/2089] [SPARK-13069][STREAMING] Add "ask" style store() to ActorReciever Introduces a "ask" style ```store``` in ```ActorReceiver``` as a way to allow actor receiver blocked by back pressure or maxRate. Author: Lin Zhao Closes #11176 from lin-zhao/SPARK-13069. --- .../spark/streaming/akka/ActorReceiver.scala | 39 ++++++++++++++++++- .../streaming/akka/JavaAkkaUtilsSuite.java | 2 + .../spark/streaming/akka/AkkaUtilsSuite.scala | 3 ++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala index c75dc92445b64..33415c15be2ef 100644 --- a/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala +++ b/external/akka/src/main/scala/org/apache/spark/streaming/akka/ActorReceiver.scala @@ -20,12 +20,15 @@ package org.apache.spark.streaming.akka import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} +import akka.pattern.ask +import akka.util.Timeout import com.typesafe.config.ConfigFactory import org.apache.spark.{Logging, TaskContext} @@ -105,13 +108,26 @@ abstract class ActorReceiver extends Actor { } /** - * Store a single item of received data to Spark's memory. + * Store a single item of received data to Spark's memory asynchronously. * These single items will be aggregated together into data blocks before * being pushed into Spark's memory. */ def store[T](item: T) { context.parent ! SingleItemData(item) } + + /** + * Store a single item of received data to Spark's memory and returns a `Future`. + * The `Future` will be completed when the operator finishes, or with an + * `akka.pattern.AskTimeoutException` after the given timeout has expired. + * These single items will be aggregated together into data blocks before + * being pushed into Spark's memory. + * + * This method allows the user to control the flow speed using `Future` + */ + def store[T](item: T, timeout: Timeout): Future[Unit] = { + context.parent.ask(AskStoreSingleItemData(item))(timeout).map(_ => ())(context.dispatcher) + } } /** @@ -162,6 +178,19 @@ abstract class JavaActorReceiver extends UntypedActor { def store[T](item: T) { context.parent ! SingleItemData(item) } + + /** + * Store a single item of received data to Spark's memory and returns a `Future`. + * The `Future` will be completed when the operator finishes, or with an + * `akka.pattern.AskTimeoutException` after the given timeout has expired. + * These single items will be aggregated together into data blocks before + * being pushed into Spark's memory. + * + * This method allows the user to control the flow speed using `Future` + */ + def store[T](item: T, timeout: Timeout): Future[Unit] = { + context.parent.ask(AskStoreSingleItemData(item))(timeout).map(_ => ())(context.dispatcher) + } } /** @@ -179,8 +208,10 @@ case class Statistics(numberOfMsgs: Int, /** Case class to receive data sent by child actors */ private[akka] sealed trait ActorReceiverData private[akka] case class SingleItemData[T](item: T) extends ActorReceiverData +private[akka] case class AskStoreSingleItemData[T](item: T) extends ActorReceiverData private[akka] case class IteratorData[T](iterator: Iterator[T]) extends ActorReceiverData private[akka] case class ByteBufferData(bytes: ByteBuffer) extends ActorReceiverData +private[akka] object Ack extends ActorReceiverData /** * Provides Actors as receivers for receiving stream. @@ -233,6 +264,12 @@ private[akka] class ActorReceiverSupervisor[T: ClassTag]( store(msg.asInstanceOf[T]) n.incrementAndGet + case AskStoreSingleItemData(msg) => + logDebug("received single sync") + store(msg.asInstanceOf[T]) + n.incrementAndGet + sender() ! Ack + case ByteBufferData(bytes) => logDebug("received bytes") store(bytes) diff --git a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java index b732506767154..ac5ef31c8b355 100644 --- a/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java +++ b/external/akka/src/test/java/org/apache/spark/streaming/akka/JavaAkkaUtilsSuite.java @@ -20,6 +20,7 @@ import akka.actor.ActorSystem; import akka.actor.Props; import akka.actor.SupervisorStrategy; +import akka.util.Timeout; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.Test; @@ -62,5 +63,6 @@ class JavaTestActor extends JavaActorReceiver { @Override public void onReceive(Object message) throws Exception { store((String) message); + store((String) message, new Timeout(1000)); } } diff --git a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala index f437585a98e4f..ce95d9dd72f90 100644 --- a/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala +++ b/external/akka/src/test/scala/org/apache/spark/streaming/akka/AkkaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.akka +import scala.concurrent.duration._ + import akka.actor.{Props, SupervisorStrategy} import org.apache.spark.SparkFunSuite @@ -60,5 +62,6 @@ class AkkaUtilsSuite extends SparkFunSuite { class TestActor extends ActorReceiver { override def receive: Receive = { case m: String => store(m) + case m => store(m, 10.seconds) } } From 14e2700de29d06460179a94cc9816bcd37344cf7 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 25 Feb 2016 13:21:33 -0800 Subject: [PATCH 2030/2089] [SPARK-12874][ML] ML StringIndexer does not protect itself from column name duplication ## What changes were proposed in this pull request? ML StringIndexer does not protect itself from column name duplication. We should still improve a way to validate a schema of `StringIndexer` and `StringIndexerModel`. However, it would be great to fix at another issue. ## How was this patch tested? unit test Author: Yu ISHIKAWA Closes #11370 from yu-iskw/SPARK-12874. --- .../org/apache/spark/ml/feature/StringIndexer.scala | 1 + .../apache/spark/ml/feature/StringIndexerSuite.scala | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 912bd95a2ec70..555f1130e46a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -150,6 +150,7 @@ class StringIndexerModel ( "Skip StringIndexerModel.") return dataset } + validateAndTransformSchema(dataset.schema) val indexer = udf { label: String => if (labelToIndex.contains(label)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5d199ca9b51b1..0dbaed2522957 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,17 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexerModel can't overwrite output column") { + val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val indexer = new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + intercept[IllegalArgumentException] { + indexer.transform(df) + } + } + test("StringIndexer read/write") { val t = new StringIndexer() .setInputCol("myInputCol") From 35316cb0b744bef9bcb390411ddc321167f953be Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 25 Feb 2016 13:29:10 -0800 Subject: [PATCH 2031/2089] [SPARK-13292] [ML] [PYTHON] QuantileDiscretizer should take random seed in PySpark ## What changes were proposed in this pull request? QuantileDiscretizer in Python should also specify a random seed. ## How was this patch tested? unit tests Author: Yu ISHIKAWA Closes #11362 from yu-iskw/SPARK-13292 and squashes the following commits: 02ffa76 [Yu ISHIKAWA] [SPARK-13292][ML][PYTHON] QuantileDiscretizer should take random seed in PySpark --- python/pyspark/ml/feature.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 464c9446f2f39..67bccfae7a26d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -939,7 +939,7 @@ def getDegree(self): @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed): """ .. note:: Experimental @@ -951,7 +951,9 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) >>> qds = QuantileDiscretizer(numBuckets=2, - ... inputCol="values", outputCol="buckets") + ... inputCol="values", outputCol="buckets", seed=123) + >>> qds.getSeed() + 123 >>> bucketizer = qds.fit(df) >>> splits = bucketizer.getSplits() >>> splits[0] @@ -971,9 +973,9 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol): "categories) into which data points are grouped. Must be >= 2. Default 2.") @keyword_only - def __init__(self, numBuckets=2, inputCol=None, outputCol=None): + def __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): """ - __init__(self, numBuckets=2, inputCol=None, outputCol=None) + __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) """ super(QuantileDiscretizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", @@ -987,9 +989,9 @@ def __init__(self, numBuckets=2, inputCol=None, outputCol=None): @keyword_only @since("2.0.0") - def setParams(self, numBuckets=2, inputCol=None, outputCol=None): + def setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): """ - setParams(self, numBuckets=2, inputCol=None, outputCol=None) + setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) Set the params for the QuantileDiscretizer """ kwargs = self.setParams._input_kwargs From dc6c5ea4c91c387deb87764c86c4f40ea71657b7 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Thu, 25 Feb 2016 15:36:25 -0800 Subject: [PATCH 2032/2089] [SPARK-13468][WEB UI] Fix a corner case where the Stage UI page should show DAG but it doesn't show MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When uses clicks more than one time on any stage in the DAG graph on the *Job* web UI page, many new *Stage* web UI pages are opened, but only half of their DAG graphs are expanded. After this PR's fix, every newly opened *Stage* page's DAG graph is expanded. Before: ![](https://cloud.githubusercontent.com/assets/15843379/13279144/74808e86-db10-11e5-8514-cecf31af8908.png) After: ![](https://cloud.githubusercontent.com/assets/15843379/13279145/77ca5dec-db10-11e5-9457-8e1985461328.png) ## What changes were proposed in this pull request? - Removed the `expandDagViz` parameter for _Stage_ page and related codes - Added a `onclick` function setting `expandDagVizArrowKey(false)` as `true` ## How was this patch tested? Manual tests (with this fix) to verified this fix work: - clicked many times on _Job_ Page's DAG Graph → each newly opened Stage page's DAG graph is expanded Manual tests (with this fix) to verified this fix do not break features we already had: - refreshed many times for a same _Stage_ page (whose DAG already expanded) → DAG remained expanded upon every refresh - refreshed many times for a same _Stage_ page (whose DAG unexpanded) → DAG remained unexpanded upon every refresh - refreshed many times for a same _Job_ page (whose DAG already expanded) → DAG remained expanded upon every refresh - refreshed many times for a same _Job_ page (whose DAG unexpanded) → DAG remained unexpanded upon every refresh Author: Liwei Lin Closes #11368 from proflin/SPARK-13468. --- .../org/apache/spark/ui/static/spark-dag-viz.js | 3 ++- .../src/main/scala/org/apache/spark/ui/UIUtils.scala | 7 ------- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 12 ------------ 3 files changed, 2 insertions(+), 20 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 4337c42087e79..1b0d4692d9cd0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -222,10 +222,11 @@ function renderDagVizForJob(svgContainer) { var attemptId = 0 var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) .select("a.name-link") - .attr("href") + "&expandDagViz=true"; + .attr("href"); container = svgContainer .append("a") .attr("xlink:href", stageLink) + .attr("onclick", "window.localStorage.setItem(expandDagVizArrowKey(false), true)") .append("g") .attr("id", containerId); } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ddd7f713fe417..0493513d667c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -408,13 +408,6 @@ private[spark] object UIUtils extends Logging { } - /** Return a script element that automatically expands the DAG visualization on page load. */ - def expandDagVizOnLoad(forJob: Boolean): Seq[Node] = { - - } - /** * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML * and make sure that it only contains anchors with root-relative links. Otherwise, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0b68b88566b70..689ab7dd5ed62 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -107,10 +107,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) - // If this is set, expand the dag visualization by default - val expandDagVizParam = request.getParameter("expandDagViz") - val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean - val stageId = parameterId.toInt val stageAttemptId = parameterAttempt.toInt val stageDataOption = progressListener.stageIdToData.get((stageId, stageAttemptId)) @@ -263,13 +259,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val dagViz = UIUtils.showDagVizForStage( stageId, operationGraphListener.getOperationGraphForStage(stageId)) - val maybeExpandDagViz: Seq[Node] = - if (expandDagViz) { - UIUtils.expandDagVizOnLoad(forJob = false) - } else { - Seq.empty - } - val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { (acc.name, acc.value) match { @@ -578,7 +567,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val content = summary ++ dagViz ++ - maybeExpandDagViz ++ showAdditionalMetrics ++ makeTimeline( // Only show the tasks in the table From 7a6ee8a8fe0fad78416ed7e1ac694959de5c5314 Mon Sep 17 00:00:00 2001 From: hushan Date: Thu, 25 Feb 2016 16:57:41 -0800 Subject: [PATCH 2033/2089] [SPARK-12009][YARN] Avoid to re-allocating yarn container while driver want to stop all Executors Author: hushan Closes #9992 from suyanNone/tricky. --- .../apache/spark/scheduler/cluster/YarnSchedulerBackend.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 1431bceb256a7..ca26277b50610 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -83,6 +83,9 @@ private[spark] abstract class YarnSchedulerBackend( override def stop(): Unit = { try { + // SPARK-12009: To prevent Yarn allocator from requesting backup for the executors which + // was Stopped by SchedulerBackend. + requestTotalExecutors(0, 0, Map.empty) super.stop() } finally { services.stop() From f2cfafdfe0f4b18f31bc63969e2abced1a66e896 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 25 Feb 2016 17:04:43 -0800 Subject: [PATCH 2034/2089] [SPARK-13501] Remove use of Guava Stopwatch Our nightly doc snapshot builds are failing due to some issue involving the Guava Stopwatch constructor: ``` [error] /home/jenkins/workspace/spark-master-docs/spark/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala:496: constructor Stopwatch in class Stopwatch cannot be accessed in class CoarseMesosSchedulerBackend [error] val stopwatch = new Stopwatch() [error] ^ ``` This Stopwatch constructor was deprecated in newer versions of Guava (https://github.com/google/guava/commit/fd0cbc2c5c90e85fb22c8e86ea19630032090943) and it's possible that some classpath issues affecting Unidoc could be causing this to trigger compilation failures. In order to work around this issue, this patch removes this use of Stopwatch since we don't use it anywhere else in the Spark codebase. Author: Josh Rosen Closes #11376 from JoshRosen/remove-stopwatch. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index f803cc7a36a9a..622f361ec2a3c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -19,14 +19,12 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{Collections, List => JList} -import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.{Buffer, HashMap, HashSet} -import com.google.common.base.Stopwatch import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} @@ -493,12 +491,11 @@ private[spark] class CoarseMesosSchedulerBackend( // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. // See SPARK-12330 - val stopwatch = new Stopwatch() - stopwatch.start() + val startTime = System.nanoTime() // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent while (numExecutors() > 0 && - stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) { + System.nanoTime() - startTime < shutdownTimeoutMS * 1000L * 1000L) { Thread.sleep(100) } From 712995757c22a0bd76e4ccb552446372acf2cc2e Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Thu, 25 Feb 2016 17:07:58 -0800 Subject: [PATCH 2035/2089] [SPARK-13387][MESOS] Add support for SPARK_DAEMON_JAVA_OPTS with MesosClusterDispatcher. ## What changes were proposed in this pull request? Add support for SPARK_DAEMON_JAVA_OPTS with MesosClusterDispatcher. ## How was the this patch tested? Manual testing by launching dispatcher with SPARK_DAEMON_JAVA_OPTS Author: Timothy Chen Closes #11277 from tnachen/cluster_dispatcher_opts. --- .../org/apache/spark/launcher/SparkClassCommandBuilder.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 931a24cfd4b1d..e575fd33080a2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -48,8 +48,8 @@ public List buildCommand(Map env) throws IOException { String memKey = null; String extraClassPath = null; - // Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + - // SPARK_DAEMON_MEMORY. + // Master, Worker, HistoryServer, ExternalShuffleService, MesosClusterDispatcher use + // SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. if (className.equals("org.apache.spark.deploy.master.Master")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); javaOptsKeys.add("SPARK_MASTER_OPTS"); @@ -69,6 +69,8 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.mesos.MesosClusterDispatcher")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") || className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); From 633d63a48ad98754dc7c56f9ac150fc2aa4e42c5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 25 Feb 2016 17:17:56 -0800 Subject: [PATCH 2036/2089] [SPARK-12757] Add block-level read/write locks to BlockManager ## Motivation As a pre-requisite to off-heap caching of blocks, we need a mechanism to prevent pages / blocks from being evicted while they are being read. With on-heap objects, evicting a block while it is being read merely leads to memory-accounting problems (because we assume that an evicted block is a candidate for garbage-collection, which will not be true during a read), but with off-heap memory this will lead to either data corruption or segmentation faults. ## Changes ### BlockInfoManager and reader/writer locks This patch adds block-level read/write locks to the BlockManager. It introduces a new `BlockInfoManager` component, which is contained within the `BlockManager`, holds the `BlockInfo` objects that the `BlockManager` uses for tracking block metadata, and exposes APIs for locking blocks in either shared read or exclusive write modes. `BlockManager`'s `get*()` and `put*()` methods now implicitly acquire the necessary locks. After a `get()` call successfully retrieves a block, that block is locked in a shared read mode. A `put()` call will block until it acquires an exclusive write lock. If the write succeeds, the write lock will be downgraded to a shared read lock before returning to the caller. This `put()` locking behavior allows us store a block and then immediately turn around and read it without having to worry about it having been evicted between the write and the read, which will allow us to significantly simplify `CacheManager` in the future (see #10748). See `BlockInfoManagerSuite`'s test cases for a more detailed specification of the locking semantics. ### Auto-release of locks at the end of tasks Our locking APIs support explicit release of locks (by calling `unlock()`), but it's not always possible to guarantee that locks will be released prior to the end of the task. One reason for this is our iterator interface: since our iterators don't support an explicit `close()` operator to signal that no more records will be consumed, operations like `take()` or `limit()` don't have a good means to release locks on their input iterators' blocks. Another example is broadcast variables, whose block locks can only be released at the end of the task. To address this, `BlockInfoManager` uses a pair of maps to track the set of locks acquired by each task. Lock acquisitions automatically record the current task attempt id by obtaining it from `TaskContext`. When a task finishes, code in `Executor` calls `BlockInfoManager.unlockAllLocksForTask(taskAttemptId)` to free locks. ### Locking and the MemoryStore In order to prevent in-memory blocks from being evicted while they are being read, the `MemoryStore`'s `evictBlocksToFreeSpace()` method acquires write locks on blocks which it is considering as candidates for eviction. These lock acquisitions are non-blocking, so a block which is being read will not be evicted. By holding write locks until the eviction is performed or skipped (in case evicting the blocks would not free enough memory), we avoid a race where a new reader starts to read a block after the block has been marked as an eviction candidate but before it has been removed. ### Locking and remote block transfer This patch makes small changes to to block transfer and network layer code so that locks acquired by the BlockTransferService are released as soon as block transfer messages are consumed and released by Netty. This builds on top of #11193, a bug fix related to freeing of network layer ManagedBuffers. ## FAQ - **Why not use Java's built-in [`ReadWriteLock`](https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html)?** Our locks operate on a per-task rather than per-thread level. Under certain circumstances a task may consist of multiple threads, so using `ReadWriteLock` would mean that we might call `unlock()` from a thread which didn't hold the lock in question, an operation which has undefined semantics. If we could rely on Java 8 classes, we might be able to use [`StampedLock`](https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/locks/StampedLock.html) to work around this issue. - **Why not detect "leaked" locks in tests?**: See above notes about `take()` and `limit`. Author: Josh Rosen Closes #10705 from JoshRosen/pin-pages. --- .../scala/org/apache/spark/CacheManager.scala | 6 +- .../spark/broadcast/TorrentBroadcast.scala | 68 ++- .../org/apache/spark/executor/Executor.scala | 18 +- .../spark/network/BlockDataManager.scala | 10 +- .../network/netty/NettyBlockRpcServer.scala | 6 +- .../org/apache/spark/scheduler/Task.scala | 1 + .../org/apache/spark/storage/BlockInfo.scala | 83 --- .../spark/storage/BlockInfoManager.scala | 445 +++++++++++++++ .../apache/spark/storage/BlockManager.scala | 261 ++++----- .../storage/BlockManagerManagedBuffer.scala | 48 ++ .../org/apache/spark/storage/BlockStore.scala | 2 + .../apache/spark/storage/MemoryStore.scala | 47 +- .../org/apache/spark/CacheManagerSuite.scala | 1 + .../apache/spark/ContextCleanerSuite.scala | 3 +- .../KryoSerializerDistributedSuite.scala | 6 +- .../spark/storage/BlockInfoManagerSuite.scala | 287 ++++++++++ .../BlockManagerReplicationSuite.scala | 6 + .../spark/storage/BlockManagerSuite.scala | 510 +++++++++++------- .../network/buffer/NioManagedBuffer.java | 2 +- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../receiver/ReceivedBlockHandler.scala | 4 + 22 files changed, 1384 insertions(+), 438 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 1ec9ba7755725..2b456facd9439 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.rdd.RDD import org.apache.spark.storage._ +import org.apache.spark.util.CompletionIterator /** * Spark class responsible for passing RDDs partition contents to the BlockManager and making @@ -47,6 +48,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { existingMetrics.incBytesReadInternal(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] + new InterruptibleIterator[T](context, iter) { override def next(): T = { existingMetrics.incRecordsReadInternal(1) @@ -156,7 +158,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Left(arr) => // We have successfully unrolled the entire partition, so cache it in memory blockManager.putArray(key, arr, level, tellMaster = true, effectiveStorageLevel) - arr.iterator.asInstanceOf[Iterator[T]] + CompletionIterator[T, Iterator[T]]( + arr.iterator.asInstanceOf[Iterator[T]], + blockManager.releaseLock(key)) case Right(it) => // There is not enough space to cache this partition in memory val returnValues = it.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 9bd69727f6086..c08f87a8b45c1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -24,10 +24,10 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random -import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ByteArrayChunkOutputStream @@ -90,22 +90,29 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** * Divide the object into multiple blocks and put those blocks in the block manager. + * * @param value the object to divide * @return number of blocks this broadcast variable is divided into */ private def writeBlocks(value: T): Int = { + import StorageLevel._ // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. - SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK, - tellMaster = false) + val blockManager = SparkEnv.get.blockManager + if (blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { + blockManager.releaseLock(broadcastId) + } else { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => - SparkEnv.get.blockManager.putBytes( - BroadcastBlockId(id, "piece" + i), - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) + val pieceId = BroadcastBlockId(id, "piece" + i) + if (blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) { + blockManager.releaseLock(pieceId) + } else { + throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager") + } } blocks.length } @@ -127,16 +134,18 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => // If we found the block from remote executors/driver's BlockManager, put the block // in this executor's BlockManager. - SparkEnv.get.blockManager.putBytes( - pieceId, - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) + if (!bm.putBytes(pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException( + s"Failed to store $pieceId of $broadcastId in local BlockManager") + } block } val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( throw new SparkException(s"Failed to get $pieceId of $broadcastId")) + // At this point we are guaranteed to hold a read lock, since we either got the block locally + // or stored the remotely-fetched block and automatically downgraded the write lock. blocks(pid) = block + releaseLock(pieceId) } blocks } @@ -165,8 +174,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) - SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { + val blockManager = SparkEnv.get.blockManager + blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => + releaseLock(broadcastId) x.asInstanceOf[T] case None => @@ -179,13 +190,36 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. - SparkEnv.get.blockManager.putSingle( - broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + releaseLock(broadcastId) + } else { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } obj } } } + /** + * If running in a task, register the given block's locks for release upon task completion. + * Otherwise, if not running in a task then immediately release the lock. + */ + private def releaseLock(blockId: BlockId): Unit = { + val blockManager = SparkEnv.get.blockManager + Option(TaskContext.get()) match { + case Some(taskContext) => + taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + case None => + // This should only happen on the driver, where broadcast variables may be accessed + // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow + // broadcast variables to be garbage collected we need to free the reference here + // which is slightly unsafe but is technically okay because broadcast variables aren't + // stored off-heap. + blockManager.releaseLock(blockId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 00be3a240dbac..a602fcac68a6b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -218,7 +218,9 @@ private[spark] class Executor( threwException = false res } finally { + val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { @@ -227,6 +229,17 @@ private[spark] class Executor( logError(errMsg) } } + + if (releasedLocks.nonEmpty) { + val errMsg = + s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" + + releasedLocks.mkString("[", ", ", "]") + if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } } val taskFinish = System.currentTimeMillis() @@ -266,8 +279,11 @@ private[spark] class Executor( ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) } else if (resultSize >= maxRpcMessageSize) { val blockId = TaskResultBlockId(taskId) - env.blockManager.putBytes( + val putSucceeded = env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) + if (putSucceeded) { + env.blockManager.releaseLock(blockId) + } logInfo( s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 1745d52c81923..cc5e851c29b32 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -31,6 +31,14 @@ trait BlockDataManager { /** * Put the block locally, using the given storage level. + * + * Returns true if the block was stored and false if the put operation failed or the block + * already existed. */ - def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean + + /** + * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. + */ + def releaseLock(blockId: BlockId): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index df8c21fb837ed..e4246df83a6ec 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -65,7 +65,11 @@ class NettyBlockRpcServer( val level: StorageLevel = serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) - blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) + val blockId = BlockId(uploadBlock.blockId) + val putSucceeded = blockManager.putBlockData(blockId, data, level) + if (putSucceeded) { + blockManager.releaseLock(blockId) + } responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index a49f3716e2702..5c68d001f280f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -64,6 +64,7 @@ private[spark] abstract class Task[T]( taskAttemptId: Long, attemptNumber: Int, metricsSystem: MetricsSystem): T = { + SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, partitionId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala deleted file mode 100644 index 22fdf73e9d1f4..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.ConcurrentHashMap - -private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - // To save space, 'pending' and 'failed' are encoded as special sizes: - @volatile var size: Long = BlockInfo.BLOCK_PENDING - private def pending: Boolean = size == BlockInfo.BLOCK_PENDING - private def failed: Boolean = size == BlockInfo.BLOCK_FAILED - private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this) - - setInitThread() - - private def setInitThread() { - /* Set current thread as init thread - waitForReady will not block this thread - * (in case there is non trivial initialization which ends up calling waitForReady - * as part of initialization itself) */ - BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (pending && initThread != Thread.currentThread()) { - synchronized { - while (pending) { - this.wait() - } - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. block is finished writing) */ - def markReady(sizeInBytes: Long) { - require(sizeInBytes >= 0, s"sizeInBytes was negative: $sizeInBytes") - assert(pending) - size = sizeInBytes - BlockInfo.blockInfoInitThreads.remove(this) - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert(pending) - size = BlockInfo.BLOCK_FAILED - BlockInfo.blockInfoInitThreads.remove(this) - synchronized { - this.notifyAll() - } - } -} - -private object BlockInfo { - /* initThread is logically a BlockInfo field, but we store it here because - * it's only needed while this block is in the 'pending' state and we want - * to minimize BlockInfo's memory footprint. */ - private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] - - private val BLOCK_PENDING: Long = -1L - private val BLOCK_FAILED: Long = -2L -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala new file mode 100644 index 0000000000000..0eda97e58d451 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.google.common.collect.ConcurrentHashMultiset + +import org.apache.spark.{Logging, SparkException, TaskContext} + + +/** + * Tracks metadata for an individual block. + * + * Instances of this class are _not_ thread-safe and are protected by locks in the + * [[BlockInfoManager]]. + * + * @param level the block's storage level. This is the requested persistence level, not the + * effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this + * does not imply that the block is actually resident in memory). + * @param tellMaster whether state changes for this block should be reported to the master. This + * is true for most blocks, but is false for broadcast blocks. + */ +private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + + /** + * The size of the block (in bytes) + */ + def size: Long = _size + def size_=(s: Long): Unit = { + _size = s + checkInvariants() + } + private[this] var _size: Long = 0 + + /** + * The number of times that this block has been locked for reading. + */ + def readerCount: Int = _readerCount + def readerCount_=(c: Int): Unit = { + _readerCount = c + checkInvariants() + } + private[this] var _readerCount: Int = 0 + + /** + * The task attempt id of the task which currently holds the write lock for this block, or + * [[BlockInfo.NON_TASK_WRITER]] if the write lock is held by non-task code, or + * [[BlockInfo.NO_WRITER]] if this block is not locked for writing. + */ + def writerTask: Long = _writerTask + def writerTask_=(t: Long): Unit = { + _writerTask = t + checkInvariants() + } + private[this] var _writerTask: Long = 0 + + /** + * True if this block has been removed from the BlockManager and false otherwise. + * This field is used to communicate block deletion to blocked readers / writers (see its usage + * in [[BlockInfoManager]]). + */ + def removed: Boolean = _removed + def removed_=(r: Boolean): Unit = { + _removed = r + checkInvariants() + } + private[this] var _removed: Boolean = false + + private def checkInvariants(): Unit = { + // A block's reader count must be non-negative: + assert(_readerCount >= 0) + // A block is either locked for reading or for writing, but not for both at the same time: + assert(_readerCount == 0 || _writerTask == BlockInfo.NO_WRITER) + // If a block is removed then it is not locked: + assert(!_removed || (_readerCount == 0 && _writerTask == BlockInfo.NO_WRITER)) + } + + checkInvariants() +} + +private[storage] object BlockInfo { + + /** + * Special task attempt id constant used to mark a block's write lock as being unlocked. + */ + val NO_WRITER: Long = -1 + + /** + * Special task attempt id constant used to mark a block's write lock as being held by + * a non-task thread (e.g. by a driver thread or by unit test code). + */ + val NON_TASK_WRITER: Long = -1024 +} + +/** + * Component of the [[BlockManager]] which tracks metadata for blocks and manages block locking. + * + * The locking interface exposed by this class is readers-writer lock. Every lock acquisition is + * automatically associated with a running task and locks are automatically released upon task + * completion or failure. + * + * This class is thread-safe. + */ +private[storage] class BlockInfoManager extends Logging { + + private type TaskAttemptId = Long + + /** + * Used to look up metadata for individual blocks. Entries are added to this map via an atomic + * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed + * by [[removeBlock()]]. + */ + @GuardedBy("this") + private[this] val infos = new mutable.HashMap[BlockId, BlockInfo] + + /** + * Tracks the set of blocks that each task has locked for writing. + */ + @GuardedBy("this") + private[this] val writeLocksByTask = + new mutable.HashMap[TaskAttemptId, mutable.Set[BlockId]] + with mutable.MultiMap[TaskAttemptId, BlockId] + + /** + * Tracks the set of blocks that each task has locked for reading, along with the number of times + * that a block has been locked (since our read locks are re-entrant). + */ + @GuardedBy("this") + private[this] val readLocksByTask = + new mutable.HashMap[TaskAttemptId, ConcurrentHashMultiset[BlockId]] + + // ---------------------------------------------------------------------------------------------- + + // Initialization for special task attempt ids: + registerTask(BlockInfo.NON_TASK_WRITER) + + // ---------------------------------------------------------------------------------------------- + + /** + * Called at the start of a task in order to register that task with this [[BlockInfoManager]]. + * This must be called prior to calling any other BlockInfoManager methods from that task. + */ + def registerTask(taskAttemptId: TaskAttemptId): Unit = synchronized { + require(!readLocksByTask.contains(taskAttemptId), + s"Task attempt $taskAttemptId is already registered") + readLocksByTask(taskAttemptId) = ConcurrentHashMultiset.create() + } + + /** + * Returns the current task's task attempt id (which uniquely identifies the task), or + * [[BlockInfo.NON_TASK_WRITER]] if called by a non-task thread. + */ + private def currentTaskAttemptId: TaskAttemptId = { + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(BlockInfo.NON_TASK_WRITER) + } + + /** + * Lock a block for reading and return its metadata. + * + * If another task has already locked this block for reading, then the read lock will be + * immediately granted to the calling task and its lock count will be incremented. + * + * If another task has locked this block for writing, then this call will block until the write + * lock is released or will return immediately if `blocking = false`. + * + * A single task can lock a block multiple times for reading, in which case each lock will need + * to be released separately. + * + * @param blockId the block to lock. + * @param blocking if true (default), this call will block until the lock is acquired. If false, + * this call will return immediately if the lock acquisition fails. + * @return None if the block did not exist or was removed (in which case no lock is held), or + * Some(BlockInfo) (in which case the block is locked for reading). + */ + def lockForReading( + blockId: BlockId, + blocking: Boolean = true): Option[BlockInfo] = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to acquire read lock for $blockId") + infos.get(blockId).map { info => + while (info.writerTask != BlockInfo.NO_WRITER) { + if (blocking) wait() else return None + } + if (info.removed) return None + info.readerCount += 1 + readLocksByTask(currentTaskAttemptId).add(blockId) + logTrace(s"Task $currentTaskAttemptId acquired read lock for $blockId") + info + } + } + + /** + * Lock a block for writing and return its metadata. + * + * If another task has already locked this block for either reading or writing, then this call + * will block until the other locks are released or will return immediately if `blocking = false`. + * + * If this is called by a task which already holds the block's exclusive write lock, then this + * method will throw an exception. + * + * @param blockId the block to lock. + * @param blocking if true (default), this call will block until the lock is acquired. If false, + * this call will return immediately if the lock acquisition fails. + * @return None if the block did not exist or was removed (in which case no lock is held), or + * Some(BlockInfo) (in which case the block is locked for writing). + */ + def lockForWriting( + blockId: BlockId, + blocking: Boolean = true): Option[BlockInfo] = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to acquire write lock for $blockId") + infos.get(blockId).map { info => + if (info.writerTask == currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId has already locked $blockId for writing") + } else { + while (info.writerTask != BlockInfo.NO_WRITER || info.readerCount != 0) { + if (blocking) wait() else return None + } + if (info.removed) return None + } + info.writerTask = currentTaskAttemptId + writeLocksByTask.addBinding(currentTaskAttemptId, blockId) + logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId") + info + } + } + + /** + * Throws an exception if the current task does not hold a write lock on the given block. + * Otherwise, returns the block's BlockInfo. + */ + def assertBlockIsLockedForWriting(blockId: BlockId): BlockInfo = synchronized { + infos.get(blockId) match { + case Some(info) => + if (info.writerTask != currentTaskAttemptId) { + throw new SparkException( + s"Task $currentTaskAttemptId has not locked block $blockId for writing") + } else { + info + } + case None => + throw new SparkException(s"Block $blockId does not exist") + } + } + + /** + * Get a block's metadata without acquiring any locks. This method is only exposed for use by + * [[BlockManager.getStatus()]] and should not be called by other code outside of this class. + */ + private[storage] def get(blockId: BlockId): Option[BlockInfo] = synchronized { + infos.get(blockId) + } + + /** + * Downgrades an exclusive write lock to a shared read lock. + */ + def downgradeLock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId downgrading write lock for $blockId") + val info = get(blockId).get + require(info.writerTask == currentTaskAttemptId, + s"Task $currentTaskAttemptId tried to downgrade a write lock that it does not hold on" + + s" block $blockId") + unlock(blockId) + val lockOutcome = lockForReading(blockId, blocking = false) + assert(lockOutcome.isDefined) + } + + /** + * Release a lock on the given block. + */ + def unlock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId") + val info = get(blockId).getOrElse { + throw new IllegalStateException(s"Block $blockId not found") + } + if (info.writerTask != BlockInfo.NO_WRITER) { + info.writerTask = BlockInfo.NO_WRITER + writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + } else { + assert(info.readerCount > 0, s"Block $blockId is not locked for reading") + info.readerCount -= 1 + val countsForTask = readLocksByTask(currentTaskAttemptId) + val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1 + assert(newPinCountForTask >= 0, + s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it") + } + notifyAll() + } + + /** + * Atomically create metadata for a block and acquire a write lock for it, if it doesn't already + * exist. + * + * @param blockId the block id. + * @param newBlockInfo the block info for the new block. + * @return true if the block did not already exist, false otherwise. If this returns false, then + * no new locks are acquired. If this returns true, a write lock on the new block will + * be held. + */ + def lockNewBlockForWriting( + blockId: BlockId, + newBlockInfo: BlockInfo): Boolean = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to put $blockId") + if (!infos.contains(blockId)) { + infos(blockId) = newBlockInfo + newBlockInfo.writerTask = currentTaskAttemptId + writeLocksByTask.addBinding(currentTaskAttemptId, blockId) + logTrace(s"Task $currentTaskAttemptId successfully locked new block $blockId") + true + } else { + logTrace(s"Task $currentTaskAttemptId did not create and lock block $blockId " + + s"because that block already exists") + false + } + } + + /** + * Release all lock held by the given task, clearing that task's pin bookkeeping + * structures and updating the global pin counts. This method should be called at the + * end of a task (either by a task completion handler or in `TaskRunner.run()`). + * + * @return the ids of blocks whose pins were released + */ + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() + + val readLocks = synchronized { + readLocksByTask.remove(taskAttemptId).get + } + val writeLocks = synchronized { + writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) + } + + for (blockId <- writeLocks) { + infos.get(blockId).foreach { info => + assert(info.writerTask == taskAttemptId) + info.writerTask = BlockInfo.NO_WRITER + } + blocksWithReleasedLocks += blockId + } + readLocks.entrySet().iterator().asScala.foreach { entry => + val blockId = entry.getElement + val lockCount = entry.getCount + blocksWithReleasedLocks += blockId + synchronized { + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) + } + } + } + + synchronized { + notifyAll() + } + blocksWithReleasedLocks + } + + /** + * Returns the number of blocks tracked. + */ + def size: Int = synchronized { + infos.size + } + + /** + * Return the number of map entries in this pin counter's internal data structures. + * This is used in unit tests in order to detect memory leaks. + */ + private[storage] def getNumberOfMapEntries: Long = synchronized { + size + + readLocksByTask.size + + readLocksByTask.map(_._2.size()).sum + + writeLocksByTask.size + + writeLocksByTask.map(_._2.size).sum + } + + /** + * Returns an iterator over a snapshot of all blocks' metadata. Note that the individual entries + * in this iterator are mutable and thus may reflect blocks that are deleted while the iterator + * is being traversed. + */ + def entries: Iterator[(BlockId, BlockInfo)] = synchronized { + infos.toArray.toIterator + } + + /** + * Removes the given block and releases the write lock on it. + * + * This can only be called while holding a write lock on the given block. + */ + def removeBlock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId") + infos.get(blockId) match { + case Some(blockInfo) => + if (blockInfo.writerTask != currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId called remove() on block $blockId without a write lock") + } else { + infos.remove(blockId) + blockInfo.readerCount = 0 + blockInfo.writerTask = BlockInfo.NO_WRITER + blockInfo.removed = true + } + case None => + throw new IllegalArgumentException( + s"Task $currentTaskAttemptId called remove() on non-existent block $blockId") + } + notifyAll() + } + + /** + * Delete all state. Called during shutdown. + */ + def clear(): Unit = synchronized { + infos.valuesIterator.foreach { blockInfo => + blockInfo.readerCount = 0 + blockInfo.writerTask = BlockInfo.NO_WRITER + blockInfo.removed = true + } + infos.clear() + readLocksByTask.clear() + writeLocksByTask.clear() + notifyAll() + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 77fd03a6bcfc5..29124b368e405 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,9 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ @@ -77,7 +75,7 @@ private[spark] class BlockManager( val diskBlockManager = new DiskBlockManager(this, conf) - private val blockInfo = new ConcurrentHashMap[BlockId, BlockInfo] + private[storage] val blockInfoManager = new BlockInfoManager private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) @@ -223,8 +221,8 @@ private[spark] class BlockManager( * will be made then. */ private def reportAllBlocks(): Unit = { - logInfo(s"Reporting ${blockInfo.size} blocks to the master.") - for ((blockId, info) <- blockInfo.asScala) { + logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.") + for ((blockId, info) <- blockInfoManager.entries) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { logError(s"Failed to report $blockId to master; giving up.") @@ -286,7 +284,7 @@ private[spark] class BlockManager( .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - new NioManagedBuffer(buffer) + new BlockManagerManagedBuffer(this, blockId, buffer) } else { throw new BlockNotFoundException(blockId.toString) } @@ -296,7 +294,7 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean = { putBytes(blockId, data.nioByteBuffer(), level) } @@ -305,7 +303,7 @@ private[spark] class BlockManager( * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { - blockInfo.asScala.get(blockId).map { info => + blockInfoManager.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L BlockStatus(info.level, memSize = memSize, diskSize = diskSize) @@ -318,7 +316,12 @@ private[spark] class BlockManager( * may not know of). */ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { - (blockInfo.asScala.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + // The `toArray` is necessary here in order to force the list to be materialized so that we + // don't try to serialize a lazy iterator when responding to client requests. + (blockInfoManager.entries.map(_._1) ++ diskBlockManager.getAllBlocks()) + .filter(filter) + .toArray + .toSeq } /** @@ -425,26 +428,11 @@ private[spark] class BlockManager( } private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { - val info = blockInfo.get(blockId) - if (info != null) { - info.synchronized { - // Double check to make sure the block is still there. There is a small chance that the - // block has been removed by removeBlock (which also synchronizes on the blockInfo object). - // Note that this only checks metadata tracking. If user intentionally deleted the block - // on disk or from off heap storage without using removeBlock, this conditional check will - // still pass but eventually we will get an exception because we can't find the block. - if (blockInfo.asScala.get(blockId).isEmpty) { - logWarning(s"Block $blockId had been removed") - return None - } - - // If another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure.") - return None - } - + blockInfoManager.lockForReading(blockId) match { + case None => + logDebug(s"Block $blockId was not found") + None + case Some(info) => val level = info.level logDebug(s"Level for block $blockId is $level") @@ -452,7 +440,10 @@ private[spark] class BlockManager( if (level.useMemory) { logDebug(s"Getting block $blockId from memory") val result = if (asBlockResult) { - memoryStore.getValues(blockId).map(new BlockResult(_, DataReadMethod.Memory, info.size)) + memoryStore.getValues(blockId).map { iter => + val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) + new BlockResult(ci, DataReadMethod.Memory, info.size) + } } else { memoryStore.getBytes(blockId) } @@ -470,6 +461,7 @@ private[spark] class BlockManager( val bytes: ByteBuffer = diskStore.getBytes(blockId) match { case Some(b) => b case None => + releaseLock(blockId) throw new BlockException( blockId, s"Block $blockId not found on disk, though it should be") } @@ -478,8 +470,9 @@ private[spark] class BlockManager( if (!level.useMemory) { // If the block shouldn't be stored in memory, we can just return it if (asBlockResult) { - return Some(new BlockResult(dataDeserialize(blockId, bytes), DataReadMethod.Disk, - info.size)) + val iter = CompletionIterator[Any, Iterator[Any]]( + dataDeserialize(blockId, bytes), releaseLock(blockId)) + return Some(new BlockResult(iter, DataReadMethod.Disk, info.size)) } else { return Some(bytes) } @@ -511,26 +504,34 @@ private[spark] class BlockManager( // space to unroll the block. Either way, the put here should return an iterator. putResult.data match { case Left(it) => - return Some(new BlockResult(it, DataReadMethod.Disk, info.size)) + val ci = CompletionIterator[Any, Iterator[Any]](it, releaseLock(blockId)) + return Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) case _ => // This only happens if we dropped the values back to disk (which is never) throw new SparkException("Memory store did not return an iterator!") } } else { - return Some(new BlockResult(values, DataReadMethod.Disk, info.size)) + val ci = CompletionIterator[Any, Iterator[Any]](values, releaseLock(blockId)) + return Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } } } + } else { + // This branch represents a case where the BlockInfoManager contained an entry for + // the block but the block could not be found in any of the block stores. This case + // should never occur, but for completeness's sake we address it here. + logError( + s"Block $blockId is supposedly stored locally but was not found in any block store") + releaseLock(blockId) + None } - } - } else { - logDebug(s"Block $blockId not registered locally") } - None } /** * Get block from remote block managers. + * + * This does not acquire a lock on this block in this JVM. */ def getRemote(blockId: BlockId): Option[BlockResult] = { logDebug(s"Getting remote block $blockId") @@ -597,6 +598,10 @@ private[spark] class BlockManager( /** * Get a block from the block manager (either local or remote). + * + * This acquires a read lock on the block if the block was stored locally and does not acquire + * any locks if the block was fetched from a remote block manager. The read lock will + * automatically be freed once the result's `data` iterator is fully consumed. */ def get(blockId: BlockId): Option[BlockResult] = { val local = getLocal(blockId) @@ -612,6 +617,36 @@ private[spark] class BlockManager( None } + /** + * Downgrades an exclusive write lock to a shared read lock. + */ + def downgradeLock(blockId: BlockId): Unit = { + blockInfoManager.downgradeLock(blockId) + } + + /** + * Release a lock on the given block. + */ + def releaseLock(blockId: BlockId): Unit = { + blockInfoManager.unlock(blockId) + } + + /** + * Registers a task with the BlockManager in order to initialize per-task bookkeeping structures. + */ + def registerTask(taskAttemptId: Long): Unit = { + blockInfoManager.registerTask(taskAttemptId) + } + + /** + * Release all locks for the given task. + * + * @return the blocks whose locks were released. + */ + def releaseAllLocksForTask(taskAttemptId: Long): Seq[BlockId] = { + blockInfoManager.releaseAllLocksForTask(taskAttemptId) + } + /** * @return true if the block was stored or false if the block was already stored or an * error occurred. @@ -703,19 +738,12 @@ private[spark] class BlockManager( * to be dropped right after it got put into memory. Note, however, that other threads will * not be able to get() this block until we call markReady on its BlockInfo. */ val putBlockInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = Option(blockInfo.putIfAbsent(blockId, tinfo)) - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning(s"Block $blockId already exists on this machine; not re-adding it") - return false - } - // TODO: So the block info exists - but previous attempt to load it (?) failed. - // What do we do now ? Retry on it ? - oldBlockOpt.get + val newInfo = new BlockInfo(level, tellMaster) + if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) { + newInfo } else { - tinfo + logWarning(s"Block $blockId already exists on this machine; not re-adding it") + return false } } @@ -750,7 +778,7 @@ private[spark] class BlockManager( case _ => null } - var marked = false + var blockWasSuccessfullyStored = false putBlockInfo.synchronized { logTrace("Put for block %s took %s to get into synchronized block" @@ -792,11 +820,11 @@ private[spark] class BlockManager( } val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) - if (putBlockStatus.storageLevel != StorageLevel.NONE) { + blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid + if (blockWasSuccessfullyStored) { // Now that the block is in either the memory, externalBlockStore, or disk store, // let other threads read it, and tell the master about it. - marked = true - putBlockInfo.markReady(size) + putBlockInfo.size = size if (tellMaster) { reportBlockStatus(blockId, putBlockInfo, putBlockStatus) } @@ -805,13 +833,10 @@ private[spark] class BlockManager( } } } finally { - // If we failed in putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (!marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - putBlockInfo.markFailure() + if (blockWasSuccessfullyStored) { + blockInfoManager.downgradeLock(blockId) + } else { + blockInfoManager.removeBlock(blockId) logWarning(s"Putting block $blockId failed") } } @@ -852,7 +877,7 @@ private[spark] class BlockManager( .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } - marked + blockWasSuccessfullyStored } /** @@ -989,77 +1014,63 @@ private[spark] class BlockManager( * store reaches its limit and needs to free up space. * * If `data` is not put on disk, it won't be created. + * + * The caller of this method must hold a write lock on the block before calling this method. + * This method does not release the write lock. + * + * @return the block's new effective StorageLevel. */ def dropFromMemory( blockId: BlockId, - data: () => Either[Array[Any], ByteBuffer]): Unit = { - + data: () => Either[Array[Any], ByteBuffer]): StorageLevel = { logInfo(s"Dropping block $blockId from memory") - val info = blockInfo.get(blockId) - - // If the block has not already been dropped - if (info != null) { - info.synchronized { - // required ? As of now, this will be invoked only for blocks which are ready - // But in case this changes in future, adding for consistency sake. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure. Nothing to drop") - return - } else if (blockInfo.asScala.get(blockId).isEmpty) { - logWarning(s"Block $blockId was already dropped.") - return - } - var blockIsUpdated = false - val level = info.level - - // Drop to disk, if storage level requires - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo(s"Writing block $blockId to disk") - data() match { - case Left(elements) => - diskStore.putArray(blockId, elements, level, returnValues = false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) - } - blockIsUpdated = true - } + val info = blockInfoManager.assertBlockIsLockedForWriting(blockId) + var blockIsUpdated = false + val level = info.level + + // Drop to disk, if storage level requires + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo(s"Writing block $blockId to disk") + data() match { + case Left(elements) => + diskStore.putArray(blockId, elements, level, returnValues = false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + blockIsUpdated = true + } - // Actually drop from memory store - val droppedMemorySize = - if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val blockIsRemoved = memoryStore.remove(blockId) - if (blockIsRemoved) { - blockIsUpdated = true - } else { - logWarning(s"Block $blockId could not be dropped from memory as it does not exist") - } + // Actually drop from memory store + val droppedMemorySize = + if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val blockIsRemoved = memoryStore.remove(blockId) + if (blockIsRemoved) { + blockIsUpdated = true + } else { + logWarning(s"Block $blockId could not be dropped from memory as it does not exist") + } - val status = getCurrentBlockStatus(blockId, info) - if (info.tellMaster) { - reportBlockStatus(blockId, info, status, droppedMemorySize) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } - if (blockIsUpdated) { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status))) - } - } + val status = getCurrentBlockStatus(blockId, info) + if (info.tellMaster) { + reportBlockStatus(blockId, info, status, droppedMemorySize) + } + if (blockIsUpdated) { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status))) } } + status.storageLevel } /** * Remove all blocks belonging to the given RDD. + * * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo(s"Removing RDD $rddId") - val blocksToRemove = blockInfo.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } @@ -1069,7 +1080,7 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { logDebug(s"Removing broadcast $broadcastId") - val blocksToRemove = blockInfo.asScala.keys.collect { + val blocksToRemove = blockInfoManager.entries.map(_._1).collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } @@ -1081,9 +1092,11 @@ private[spark] class BlockManager( */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { logDebug(s"Removing block $blockId") - val info = blockInfo.get(blockId) - if (info != null) { - info.synchronized { + blockInfoManager.lockForWriting(blockId) match { + case None => + // The block has already been removed; do nothing. + logWarning(s"Asked to remove block $blockId, which does not exist") + case Some(info) => // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) @@ -1091,15 +1104,11 @@ private[spark] class BlockManager( logWarning(s"Block $blockId could not be removed as it was not found in either " + "the disk, memory, or external block store") } - blockInfo.remove(blockId) + blockInfoManager.removeBlock(blockId) if (tellMaster && info.tellMaster) { val status = getCurrentBlockStatus(blockId, info) reportBlockStatus(blockId, info, status) } - } - } else { - // The block has already been removed; do nothing. - logWarning(s"Asked to remove block $blockId, which does not exist") } } @@ -1174,7 +1183,7 @@ private[spark] class BlockManager( } diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) - blockInfo.clear() + blockInfoManager.clear() memoryStore.clear() diskStore.clear() futureExecutionContext.shutdownNow() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala new file mode 100644 index 0000000000000..5886b9c00b557 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} + +/** + * This [[ManagedBuffer]] wraps a [[ByteBuffer]] which was retrieved from the [[BlockManager]] + * so that the corresponding block's read lock can be released once this buffer's references + * are released. + * + * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks + * to the network layer's notion of retain / release counts. + */ +private[storage] class BlockManagerManagedBuffer( + blockManager: BlockManager, + blockId: BlockId, + buf: ByteBuffer) extends NioManagedBuffer(buf) { + + override def retain(): ManagedBuffer = { + super.retain() + val locked = blockManager.blockInfoManager.lockForReading(blockId, blocking = false) + assert(locked.isDefined) + this + } + + override def release(): ManagedBuffer = { + blockManager.releaseLock(blockId) + super.release() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 69985c9759e2d..6f6a6773ba4fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -60,8 +60,10 @@ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends /** * Remove a block, if it exists. + * * @param blockId the block to remove. * @return True if the block was found and removed, False otherwise. + * @throws IllegalStateException if the block is pinned by a task. */ def remove(blockId: BlockId): Boolean diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 024b660ce6a7b..2f16c8f3d8bad 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -95,7 +95,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - tryToPut(blockId, bytes, bytes.limit, deserialized = false) + tryToPut(blockId, () => bytes, bytes.limit, deserialized = false) PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -127,11 +127,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo returnValues: Boolean): PutResult = { if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, values, sizeEstimate, deserialized = true) + tryToPut(blockId, () => values, sizeEstimate, deserialized = true) PutResult(sizeEstimate, Left(values.iterator)) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, bytes, bytes.limit, deserialized = false) + tryToPut(blockId, () => bytes, bytes.limit, deserialized = false) PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -208,7 +208,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { - val entry = entries.synchronized { entries.remove(blockId) } + val entry = entries.synchronized { + entries.remove(blockId) + } if (entry != null) { memoryManager.releaseStorageMemory(entry.size) logDebug(s"Block $blockId of size ${entry.size} dropped " + @@ -327,14 +329,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo blockId.asRDDId.map(_.rddId) } - private def tryToPut( - blockId: BlockId, - value: Any, - size: Long, - deserialized: Boolean): Boolean = { - tryToPut(blockId, () => value, size, deserialized) - } - /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size @@ -410,6 +404,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo var freedMemory = 0L val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] + def blockIsEvictable(blockId: BlockId): Boolean = { + rddToAdd.isEmpty || rddToAdd != getRddId(blockId) + } // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. @@ -418,9 +415,14 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { - selectedBlocks += blockId - freedMemory += pair.getValue.size + if (blockIsEvictable(blockId)) { + // We don't want to evict blocks which are currently being read, so we need to obtain + // an exclusive write lock on blocks which are candidates for eviction. We perform a + // non-blocking "tryLock" here in order to ignore blocks which are locked for reading: + if (blockManager.blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { + selectedBlocks += blockId + freedMemory += pair.getValue.size + } } } } @@ -438,7 +440,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } else { Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) } - blockManager.dropFromMemory(blockId, () => data) + val newEffectiveStorageLevel = blockManager.dropFromMemory(blockId, () => data) + if (newEffectiveStorageLevel.isValid) { + // The block is still present in at least one store, so release the lock + // but don't delete the block info + blockManager.releaseLock(blockId) + } else { + // The block isn't present in any store, so delete the block info so that the + // block can be stored again + blockManager.blockInfoManager.removeBlock(blockId) + } } } freedMemory @@ -447,6 +458,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo logInfo(s"Will not store $id as it would require dropping another block " + "from the same RDD") } + selectedBlocks.foreach { id => + blockManager.releaseLock(id) + } 0L } } @@ -463,6 +477,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo /** * Reserve memory for unrolling the given block for this task. + * * @return whether the request is granted. */ def reserveUnrollMemoryForThisTask(blockId: BlockId, memory: Long): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 48a0282b30cf0..ffc02bcb011f3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -87,6 +87,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val context = TaskContext.empty() try { TaskContext.setTaskContext(context) + sc.env.blockManager.registerTask(0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlockStatuses.size === 2) } finally { diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 7b0238091730d..d1e806b2eb80a 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -485,7 +485,8 @@ class CleanerTester( def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout) { try { eventually(waitTimeout, interval(100 millis)) { - assert(isAllCleanedUp) + assert(isAllCleanedUp, + "The following resources were not cleaned up:\n" + uncleanedResourcesToString) } postCleanupValidate() } finally { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index a0483f6483889..c1484b0afa85f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ import org.apache.spark.util.Utils -class KryoSerializerDistributedSuite extends SparkFunSuite { +class KryoSerializerDistributedSuite extends SparkFunSuite with LocalSparkContext { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) @@ -34,7 +34,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) @@ -47,8 +47,6 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { // Join the two RDDs, and force evaluation assert(shuffledRDD.join(cachedRDD).collect().size == 1) - - LocalSparkContext.stop(sc) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala new file mode 100644 index 0000000000000..662b18f667b0d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkException, SparkFunSuite, TaskContext, TaskContextImpl} + + +class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { + + private implicit val ec = ExecutionContext.global + private var blockInfoManager: BlockInfoManager = _ + + override protected def beforeEach(): Unit = { + super.beforeEach() + blockInfoManager = new BlockInfoManager() + for (t <- 0 to 4) { + blockInfoManager.registerTask(t) + } + } + + override protected def afterEach(): Unit = { + try { + blockInfoManager = null + } finally { + super.afterEach() + } + } + + private implicit def stringToBlockId(str: String): BlockId = { + TestBlockId(str) + } + + private def newBlockInfo(): BlockInfo = { + new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false) + } + + private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { + try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null)) + block + } finally { + TaskContext.unset() + } + } + + test("initial memory usage") { + assert(blockInfoManager.size === 0) + } + + test("get non-existent block") { + assert(blockInfoManager.get("non-existent-block").isEmpty) + assert(blockInfoManager.lockForReading("non-existent-block").isEmpty) + assert(blockInfoManager.lockForWriting("non-existent-block").isEmpty) + } + + test("basic lockNewBlockForWriting") { + val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries + val blockInfo = newBlockInfo() + withTaskId(1) { + assert(blockInfoManager.lockNewBlockForWriting("block", blockInfo)) + assert(blockInfoManager.get("block").get eq blockInfo) + assert(!blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + assert(blockInfoManager.get("block").get eq blockInfo) + assert(blockInfo.readerCount === 0) + assert(blockInfo.writerTask === 1) + blockInfoManager.unlock("block") + assert(blockInfo.readerCount === 0) + assert(blockInfo.writerTask === BlockInfo.NO_WRITER) + } + assert(blockInfoManager.size === 1) + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 1) + } + + test("read locks are reentrant") { + withTaskId(1) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + assert(blockInfoManager.lockForReading("block").isDefined) + assert(blockInfoManager.lockForReading("block").isDefined) + assert(blockInfoManager.get("block").get.readerCount === 2) + assert(blockInfoManager.get("block").get.writerTask === BlockInfo.NO_WRITER) + blockInfoManager.unlock("block") + assert(blockInfoManager.get("block").get.readerCount === 1) + blockInfoManager.unlock("block") + assert(blockInfoManager.get("block").get.readerCount === 0) + } + } + + test("multiple tasks can hold read locks") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(2) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(3) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(4) { assert(blockInfoManager.lockForReading("block").isDefined) } + assert(blockInfoManager.get("block").get.readerCount === 4) + } + + test("single task can hold write lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForWriting("block").isDefined) + assert(blockInfoManager.get("block").get.writerTask === 1) + } + withTaskId(2) { + assert(blockInfoManager.lockForWriting("block", blocking = false).isEmpty) + assert(blockInfoManager.get("block").get.writerTask === 1) + } + } + + test("cannot call lockForWriting while already holding a write lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForWriting("block").isDefined) + intercept[IllegalStateException] { + blockInfoManager.lockForWriting("block") + } + blockInfoManager.assertBlockIsLockedForWriting("block") + } + } + + test("assertBlockIsLockedForWriting throws exception if block is not locked") { + intercept[SparkException] { + blockInfoManager.assertBlockIsLockedForWriting("block") + } + withTaskId(BlockInfo.NON_TASK_WRITER) { + intercept[SparkException] { + blockInfoManager.assertBlockIsLockedForWriting("block") + } + } + } + + test("downgrade lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.downgradeLock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForReading("block").isDefined) + } + assert(blockInfoManager.get("block").get.readerCount === 2) + assert(blockInfoManager.get("block").get.writerTask === BlockInfo.NO_WRITER) + } + + test("write lock will block readers") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val get1Future = Future { + withTaskId(1) { + blockInfoManager.lockForReading("block") + } + } + val get2Future = Future { + withTaskId(2) { + blockInfoManager.lockForReading("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.unlock("block") + } + assert(Await.result(get1Future, 1.seconds).isDefined) + assert(Await.result(get2Future, 1.seconds).isDefined) + assert(blockInfoManager.get("block").get.readerCount === 2) + } + + test("read locks will block writer") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + blockInfoManager.lockForReading("block") + } + val write1Future = Future { + withTaskId(1) { + blockInfoManager.lockForWriting("block") + } + } + val write2Future = Future { + withTaskId(2) { + blockInfoManager.lockForWriting("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.unlock("block") + } + assert( + Await.result(Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) + val firstWriteWinner = if (write1Future.isCompleted) 1 else 2 + withTaskId(firstWriteWinner) { + blockInfoManager.unlock("block") + } + assert(Await.result(write1Future, 1.seconds).isDefined) + assert(Await.result(write2Future, 1.seconds).isDefined) + } + + test("removing a non-existent block throws IllegalArgumentException") { + withTaskId(0) { + intercept[IllegalArgumentException] { + blockInfoManager.removeBlock("non-existent-block") + } + } + } + + test("removing a block without holding any locks throws IllegalStateException") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + intercept[IllegalStateException] { + blockInfoManager.removeBlock("block") + } + } + } + + test("removing a block while holding only a read lock throws IllegalStateException") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + assert(blockInfoManager.lockForReading("block").isDefined) + intercept[IllegalStateException] { + blockInfoManager.removeBlock("block") + } + } + } + + test("removing a block causes blocked callers to receive None") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val getFuture = Future { + withTaskId(1) { + blockInfoManager.lockForReading("block") + } + } + val writeFuture = Future { + withTaskId(2) { + blockInfoManager.lockForWriting("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.removeBlock("block") + } + assert(Await.result(getFuture, 1.seconds).isEmpty) + assert(Await.result(writeFuture, 1.seconds).isEmpty) + } + + test("releaseAllLocksForTask releases write locks") { + val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 3) + blockInfoManager.releaseAllLocksForTask(0) + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 3fd6fb4560465..a94d8b424d956 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -190,6 +190,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { stores.head.putSingle(blockId, new Array[Byte](blockSize), level) + stores.head.releaseLock(blockId) val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet stores.foreach { _.removeBlock(blockId) } master.removeBlock(blockId) @@ -251,6 +252,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo // Insert a block with 2x replication and return the number of copies of the block def replicateAndGetNumCopies(blockId: String): Int = { store.putSingle(blockId, new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK_2) + store.releaseLock(blockId) val numLocations = master.getLocations(blockId).size allStores.foreach { _.removeBlock(blockId) } numLocations @@ -288,6 +290,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo def replicateAndGetNumCopies(blockId: String, replicationFactor: Int): Int = { val storageLevel = StorageLevel(true, true, false, true, replicationFactor) initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel) + initialStores.head.releaseLock(blockId) val numLocations = master.getLocations(blockId).size allStores.foreach { _.removeBlock(blockId) } numLocations @@ -355,6 +358,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val blockId = new TestBlockId( "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + stores(0).releaseLock(blockId) // Assert that master know two locations for the block val blockLocations = master.getLocations(blockId).map(_.executorId).toSet @@ -367,6 +371,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName") + testStore.releaseLock(blockId) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") @@ -392,6 +397,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo (1 to 10).foreach { i => testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER) + testStore.releaseLock(s"dummy-block-$i") } (1 to 10).foreach { i => testStore.removeBlock(s"dummy-block-$i") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e1b2c9633edca..e4ab9ee0ebb38 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -45,6 +45,8 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { + import BlockManagerSuite._ + var conf: SparkConf = null var store: BlockManager = null var store2: BlockManager = null @@ -66,6 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master): BlockManager = { + val serializer = new KryoSerializer(conf) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, @@ -169,14 +172,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") // Checking whether master knows about the blocks or not assert(master.getLocations("a1").size > 0, "master was not told about a1") @@ -184,10 +187,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", () => null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", () => null: Either[Array[Any], ByteBuffer]) - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") + store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ByteBuffer]) + assert(store.getSingleAndReleaseLock("a1") === None, "a1 not removed from store") + assert(store.getSingleAndReleaseLock("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") assert(master.getLocations("a2").size === 0, "master did not remove a2") } @@ -202,7 +205,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY_2) store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") @@ -215,17 +218,17 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putSingleAndReleaseLock("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory and memory size val memStatus = master.getMemoryStatus.head._2 assert(memStatus._1 == 20000L, "total memory " + memStatus._1 + " should equal 20000") assert(memStatus._2 <= 12000L, "remaining memory " + memStatus._2 + " should <= 12000") - assert(store.getSingle("a1-to-remove").isDefined, "a1 was not in store") - assert(store.getSingle("a2-to-remove").isDefined, "a2 was not in store") - assert(store.getSingle("a3-to-remove").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1-to-remove").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2-to-remove").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3-to-remove").isDefined, "a3 was not in store") // Checking whether master knows about the blocks or not assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") @@ -238,15 +241,15 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeBlock("a3-to-remove") eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a1-to-remove") should be (None) + assert(!store.hasLocalBlock("a1-to-remove")) master.getLocations("a1-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a2-to-remove") should be (None) + assert(!store.hasLocalBlock("a2-to-remove")) master.getLocations("a2-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a3-to-remove") should not be (None) + assert(store.hasLocalBlock("a3-to-remove")) master.getLocations("a3-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { @@ -262,30 +265,30 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) // Putting a1, a2 and a3 in memory. - store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) - store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("nonrddblock", a3, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle(rdd(0, 0)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 0)) should be (None) master.getLocations(rdd(0, 0)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle(rdd(0, 1)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 1)) should be (None) master.getLocations(rdd(0, 1)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("nonrddblock") should not be (None) + store.getSingleAndReleaseLock("nonrddblock") should not be (None) master.getLocations("nonrddblock") should have size (1) } - store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = true) - store.getSingle(rdd(0, 0)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 0)) should be (None) master.getLocations(rdd(0, 0)) should have size 0 - store.getSingle(rdd(0, 1)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 1)) should be (None) master.getLocations(rdd(0, 1)) should have size 0 } @@ -305,54 +308,54 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // insert broadcast blocks in both the stores Seq(driverStore, executorStore).foreach { case s => - s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) - s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) - s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) - s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) + s.putSingleAndReleaseLock(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) } // verify whether the blocks exist in both the stores Seq(driverStore, executorStore).foreach { case s => - s.getLocal(broadcast0BlockId) should not be (None) - s.getLocal(broadcast1BlockId) should not be (None) - s.getLocal(broadcast2BlockId) should not be (None) - s.getLocal(broadcast2BlockId2) should not be (None) + assert(s.hasLocalBlock(broadcast0BlockId)) + assert(s.hasLocalBlock(broadcast1BlockId)) + assert(s.hasLocalBlock(broadcast2BlockId)) + assert(s.hasLocalBlock(broadcast2BlockId2)) } // remove broadcast 0 block only from executors master.removeBroadcast(0, removeFromMaster = false, blocking = true) // only broadcast 0 block should be removed from the executor store - executorStore.getLocal(broadcast0BlockId) should be (None) - executorStore.getLocal(broadcast1BlockId) should not be (None) - executorStore.getLocal(broadcast2BlockId) should not be (None) + assert(!executorStore.hasLocalBlock(broadcast0BlockId)) + assert(executorStore.hasLocalBlock(broadcast1BlockId)) + assert(executorStore.hasLocalBlock(broadcast2BlockId)) // nothing should be removed from the driver store - driverStore.getLocal(broadcast0BlockId) should not be (None) - driverStore.getLocal(broadcast1BlockId) should not be (None) - driverStore.getLocal(broadcast2BlockId) should not be (None) + assert(driverStore.hasLocalBlock(broadcast0BlockId)) + assert(driverStore.hasLocalBlock(broadcast1BlockId)) + assert(driverStore.hasLocalBlock(broadcast2BlockId)) // remove broadcast 0 block from the driver as well master.removeBroadcast(0, removeFromMaster = true, blocking = true) - driverStore.getLocal(broadcast0BlockId) should be (None) - driverStore.getLocal(broadcast1BlockId) should not be (None) + assert(!driverStore.hasLocalBlock(broadcast0BlockId)) + assert(driverStore.hasLocalBlock(broadcast1BlockId)) // remove broadcast 1 block from both the stores asynchronously // and verify all broadcast 1 blocks have been removed master.removeBroadcast(1, removeFromMaster = true, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - driverStore.getLocal(broadcast1BlockId) should be (None) - executorStore.getLocal(broadcast1BlockId) should be (None) + assert(!driverStore.hasLocalBlock(broadcast1BlockId)) + assert(!executorStore.hasLocalBlock(broadcast1BlockId)) } // remove broadcast 2 from both the stores asynchronously // and verify all broadcast 2 blocks have been removed master.removeBroadcast(2, removeFromMaster = true, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - driverStore.getLocal(broadcast2BlockId) should be (None) - driverStore.getLocal(broadcast2BlockId2) should be (None) - executorStore.getLocal(broadcast2BlockId) should be (None) - executorStore.getLocal(broadcast2BlockId2) should be (None) + assert(!driverStore.hasLocalBlock(broadcast2BlockId)) + assert(!driverStore.hasLocalBlock(broadcast2BlockId2)) + assert(!executorStore.hasLocalBlock(broadcast2BlockId)) + assert(!executorStore.hasLocalBlock(broadcast2BlockId2)) } executorStore.stop() driverStore.stop() @@ -363,9 +366,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(2000) val a1 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") assert(master.getLocations("a1").size > 0, "master was not told about a1") master.removeExecutor(store.blockManagerId.executorId) @@ -381,13 +384,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) assert(master.getLocations("a1").size > 0, "master was not told about a1") master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.MEMORY_ONLY) store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") @@ -404,12 +407,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { override def run() { - store.putIterator("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } } val t2 = new Thread { override def run() { - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY) } } val t3 = new Thread { @@ -425,8 +429,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE t2.join() t3.join() - store.dropFromMemory("a1", () => null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -437,9 +441,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list2 = List(new Array[Byte](500), new Array[Byte](1000), new Array[Byte](1500)) val list1SizeEstimate = SizeEstimator.estimate(list1.iterator.toArray) val list2SizeEstimate = SizeEstimator.estimate(list2.iterator.toArray) - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in store") assert(list1Get.get.data.size === 2) @@ -479,8 +486,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store2 = makeBlockManager(8000, "executor2") store3 = makeBlockManager(8000, "executor3") val list1 = List(new Array[Byte](4000)) - store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store2.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") store2.stop() store2 = null @@ -506,18 +515,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, storageLevel) - store.putSingle("a2", a2, storageLevel) - store.putSingle("a3", a3, storageLevel) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") + store.putSingleAndReleaseLock("a1", a1, storageLevel) + store.putSingleAndReleaseLock("a2", a2, storageLevel) + store.putSingleAndReleaseLock("a3", a3, storageLevel) + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1") === None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, storageLevel) - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") + store.putSingleAndReleaseLock("a1", a1, storageLevel) + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3") === None, "a3 was in store") } test("in-memory LRU for partitions of same RDD") { @@ -525,34 +534,34 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") - assert(store.getSingle(rdd(0, 1)).isDefined, "rdd_0_1 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 1)).isDefined, "rdd_0_1 was not in store") // Check that rdd_0_3 doesn't replace them even after further accesses - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") } test("in-memory LRU for partitions of multiple RDDs") { store = makeBlockManager(12000) - store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 2), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(1, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 2), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(1, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // At this point rdd_1_1 should've replaced rdd_0_1 assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store") assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle(rdd(0, 3), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(0, 4), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 3), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 4), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped // when we try to add rdd_0_4. assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store") @@ -567,28 +576,28 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - store.putSingle("a2", a2, StorageLevel.DISK_ONLY) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - assert(store.getSingle("a2").isDefined, "a2 was in store") - assert(store.getSingle("a3").isDefined, "a3 was in store") - assert(store.getSingle("a1").isDefined, "a1 was in store") + store.putSingleAndReleaseLock("a1", a1, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a3", a3, StorageLevel.DISK_ONLY) + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingle) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingleAndReleaseLock) } test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytes) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytesAndReleaseLock) } test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingle) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingleAndReleaseLock) } test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytes) + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytesAndReleaseLock) } def testDiskAndMemoryStorage( @@ -598,9 +607,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, storageLevel) - store.putSingle("a2", a2, storageLevel) - store.putSingle("a3", a3, storageLevel) + store.putSingleAndReleaseLock("a1", a1, storageLevel) + store.putSingleAndReleaseLock("a2", a2, storageLevel) + store.putSingleAndReleaseLock("a3", a3, storageLevel) assert(accessMethod(store)("a2").isDefined, "a2 was not in store") assert(accessMethod(store)("a3").isDefined, "a3 was not in store") assert(store.memoryStore.getValues("a1").isEmpty, "a1 was in memory store") @@ -615,19 +624,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val a3 = new Array[Byte](4000) val a4 = new Array[Byte](4000) // First store a1 and a2, both in memory, and a3, on disk only - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock("a2", a2, StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock("a3", a3, StorageLevel.DISK_ONLY) // At this point LRU should not kick in because a3 is only on disk - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") // Now let's add in a4, which uses both disk and memory; a1 should drop out - store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a1") == None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a4").isDefined, "a4 was not in store") + store.putSingleAndReleaseLock("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) + assert(store.getSingleAndReleaseLock("a1") == None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } test("in-memory LRU with streams") { @@ -635,23 +644,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list2").isDefined, "list2 was not in store") + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list1") === None, "list1 was in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list1").isDefined, "list1 was not in store") + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3") === None, "list1 was in store") + assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } test("LRU with mixed storage levels and streams") { @@ -661,33 +674,37 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) val list4 = List(new Array[Byte](2000), new Array[Byte](2000)) // First store list1 and list2, both in memory, and list3, on disk only - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.putIterator("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val listForSizeEstimate = new ArrayBuffer[Any] listForSizeEstimate ++= list1.iterator val listSize = SizeEstimator.estimate(listForSizeEstimate) // At this point LRU should not kick in because list3 is only on disk - assert(store.get("list1").isDefined, "list1 was not in store") + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list1").isDefined, "list1 was not in store") + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.putIterator("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2").isDefined, "list2 was not in store") + store.putIteratorAndReleaseLock( + "list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) + assert(store.getAndReleaseLock("list1") === None, "list1 was in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list4").isDefined, "list4 was not in store") + assert(store.getAndReleaseLock("list4").isDefined, "list4 was not in store") assert(store.get("list4").get.data.size === 2) } @@ -705,18 +722,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("overly large block") { store = makeBlockManager(5000) - store.putSingle("a1", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") === None, "a1 was in store") - store.putSingle("a2", new Array[Byte](10000), StorageLevel.MEMORY_AND_DISK) + store.putSingleAndReleaseLock("a1", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) + assert(store.getSingleAndReleaseLock("a1") === None, "a1 was in store") + store.putSingleAndReleaseLock("a2", new Array[Byte](10000), StorageLevel.MEMORY_AND_DISK) assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") } test("block compression") { try { conf.set("spark.shuffle.compress", "true") store = makeBlockManager(20000, "exec1") - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") store.stop() @@ -724,7 +742,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.compress", "false") store = makeBlockManager(20000, "exec2") - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 10000, "shuffle_0_0_0 was compressed") store.stop() @@ -732,7 +751,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.broadcast.compress", "true") store = makeBlockManager(20000, "exec3") - store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 1000, "broadcast_0 was not compressed") store.stop() @@ -740,28 +760,29 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.broadcast.compress", "false") store = makeBlockManager(20000, "exec4") - store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock( + BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 10000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") store = makeBlockManager(20000, "exec5") - store.putSingle(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 1000, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") store = makeBlockManager(20000, "exec6") - store.putSingle(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock(rdd(0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 10000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed store = makeBlockManager(20000, "exec7") - store.putSingle("other_block", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock("other_block", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 10000, "other_block was compressed") store.stop() store = null @@ -789,12 +810,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE class UnserializableClass val a1 = new UnserializableClass intercept[java.io.NotSerializableException] { - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) + store.putSingleAndReleaseLock("a1", a1, StorageLevel.DISK_ONLY) } // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1").isEmpty, "a1 should not be in store") + assert(store.getSingleAndReleaseLock("a1").isEmpty, "a1 should not be in store") } } @@ -844,6 +865,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("updated block statuses") { store = makeBlockManager(12000) + store.registerTask(0) val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) @@ -860,7 +882,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 1 updated block (i.e. list1) val updatedBlocks1 = getUpdatedBlocks { - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks1.size === 1) assert(updatedBlocks1.head._1 === TestBlockId("list1")) @@ -868,7 +891,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 1 updated block (i.e. list2) val updatedBlocks2 = getUpdatedBlocks { - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } assert(updatedBlocks2.size === 1) assert(updatedBlocks2.head._1 === TestBlockId("list2")) @@ -876,7 +900,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 2 updated blocks - list1 is kicked out of memory while list3 is added val updatedBlocks3 = getUpdatedBlocks { - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks3.size === 2) updatedBlocks3.foreach { case (id, status) => @@ -890,7 +915,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // 2 updated blocks - list2 is kicked out of memory (but put on disk) while list4 is added val updatedBlocks4 = getUpdatedBlocks { - store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks4.size === 2) updatedBlocks4.foreach { case (id, status) => @@ -905,7 +931,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // No updated blocks - list5 is too big to fit in store and nothing is kicked out val updatedBlocks5 = getUpdatedBlocks { - store.putIterator("list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } assert(updatedBlocks5.size === 0) @@ -929,9 +956,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](2000)) // Tell master. By LRU, only list2 and list3 remains. - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getLocations("list1").size === 0) @@ -945,9 +975,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. - store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) - store.putIterator("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.putIterator("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIteratorAndReleaseLock( + "list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIteratorAndReleaseLock( + "list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIteratorAndReleaseLock( + "list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) // getLocations should return nothing because the master is not informed // getBlockStatus without asking slaves should have the same result @@ -968,9 +1001,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](100)) // insert some blocks - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size @@ -979,9 +1015,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE === 1) // insert some more blocks - store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIteratorAndReleaseLock( + "newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIteratorAndReleaseLock( + "newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIteratorAndReleaseLock( + "newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size @@ -991,7 +1030,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => - store.putIterator(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIteratorAndReleaseLock( + blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } val matchedBlockIds = store.master.getMatchingBlockIds(_ match { case RDDBlockId(1, _) => true @@ -1002,12 +1042,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = makeBlockManager(12000) - store.putSingle(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) - store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. - assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") + assert(store.getSingleAndReleaseLock(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") // According to the same-RDD rule, rdd_1_0 should be replaced here. - store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) + store.putSingleAndReleaseLock(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // rdd_1_0 should have been replaced, even it's not least recently used. assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store") assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") @@ -1086,8 +1126,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. - store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) - store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) assert(memoryStore.currentUnrollMemoryForThisTask === 0) @@ -1098,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. // In the mean time, however, we kicked out someBlock2 before giving up. - store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator @@ -1130,8 +1170,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // would not know how to drop them from memory later. memoryStore.remove("b1") memoryStore.remove("b2") - store.putIterator("b1", smallIterator, memOnly) - store.putIterator("b2", smallIterator, memOnly) + store.putIteratorAndReleaseLock("b1", smallIterator, memOnly) + store.putIteratorAndReleaseLock("b2", smallIterator, memOnly) // Unroll with not enough space. This should succeed but kick out b1 in the process. val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) @@ -1142,7 +1182,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(memoryStore.contains("b3")) assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") - store.putIterator("b3", smallIterator, memOnly) + store.putIteratorAndReleaseLock("b3", smallIterator, memOnly) // Unroll huge block with not enough space. This should fail and kick out b2 in the process. val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, returnValues = true) @@ -1169,8 +1209,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisTask === 0) - store.putIterator("b1", smallIterator, memAndDisk) - store.putIterator("b2", smallIterator, memAndDisk) + store.putIteratorAndReleaseLock("b1", smallIterator, memAndDisk) + store.putIteratorAndReleaseLock("b2", smallIterator, memAndDisk) // Unroll with not enough space. This should succeed but kick out b1 in the process. // Memory store should contain b2 and b3, while disk store should contain only b1 @@ -1183,7 +1223,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b2")) assert(!diskStore.contains("b3")) memoryStore.remove("b3") - store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) + store.putIteratorAndReleaseLock("b3", smallIterator, StorageLevel.MEMORY_ONLY) assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk @@ -1244,6 +1284,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore val blockId = BlockId("rdd_3_10") + store.blockInfoManager.lockNewBlockForWriting( + blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false)) val result = memoryStore.putBytes(blockId, 13000, () => { fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") }) @@ -1263,4 +1305,104 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result.size === 10000) assert(result.data === Right(bytes)) } + + test("read-locked blocks cannot be evicted from the MemoryStore") { + store = makeBlockManager(12000) + val arr = new Array[Byte](4000) + // First store a1 and a2, both in memory, and a3, on disk only + store.putSingleAndReleaseLock("a1", arr, StorageLevel.MEMORY_ONLY_SER) + store.putSingleAndReleaseLock("a2", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a2").isDefined, "a2 was not in store") + // This put should fail because both a1 and a2 should be read-locked: + store.putSingleAndReleaseLock("a3", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a3").isEmpty, "a3 was in store") + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a2").isDefined, "a2 was not in store") + // Release both pins of block a2: + store.releaseLock("a2") + store.releaseLock("a2") + // Block a1 is the least-recently accessed, so an LRU eviction policy would evict it before + // block a2. However, a1 is still pinned so this put of a3 should evict a2 instead: + store.putSingleAndReleaseLock("a3", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a2").isEmpty, "a2 was in store") + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a3").isDefined, "a3 was not in store") + } +} + +private object BlockManagerSuite { + + private implicit class BlockManagerTestUtils(store: BlockManager) { + + def putSingleAndReleaseLock( + block: BlockId, + value: Any, + storageLevel: StorageLevel, + tellMaster: Boolean): Unit = { + if (store.putSingle(block, value, storageLevel, tellMaster)) { + store.releaseLock(block) + } + } + + def putSingleAndReleaseLock(block: BlockId, value: Any, storageLevel: StorageLevel): Unit = { + if (store.putSingle(block, value, storageLevel)) { + store.releaseLock(block) + } + } + + def putIteratorAndReleaseLock( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel): Unit = { + if (store.putIterator(blockId, values, level)) { + store.releaseLock(blockId) + } + } + + def putIteratorAndReleaseLock( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + tellMaster: Boolean): Unit = { + if (store.putIterator(blockId, values, level, tellMaster)) { + store.releaseLock(blockId) + } + } + + def dropFromMemoryIfExists( + blockId: BlockId, + data: () => Either[Array[Any], ByteBuffer]): Unit = { + store.blockInfoManager.lockForWriting(blockId).foreach { info => + val newEffectiveStorageLevel = store.dropFromMemory(blockId, data) + if (newEffectiveStorageLevel.isValid) { + // The block is still present in at least one store, so release the lock + // but don't delete the block info + store.releaseLock(blockId) + } else { + // The block isn't present in any store, so delete the block info so that the + // block can be stored again + store.blockInfoManager.removeBlock(blockId) + } + } + } + + private def wrapGet[T](f: BlockId => Option[T]): BlockId => Option[T] = (blockId: BlockId) => { + val result = f(blockId) + if (result.isDefined) { + store.releaseLock(blockId) + } + result + } + + def hasLocalBlock(blockId: BlockId): Boolean = { + getLocalAndReleaseLock(blockId).isDefined + } + + val getLocalAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.getLocal) + val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) + val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) + val getLocalBytesAndReleaseLock: (BlockId) => Option[ByteBuffer] = wrapGet(store.getLocalBytes) + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index f55b884bc45ce..631d767715256 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -28,7 +28,7 @@ /** * A {@link ManagedBuffer} backed by {@link ByteBuffer}. */ -public final class NioManagedBuffer extends ManagedBuffer { +public class NioManagedBuffer extends ManagedBuffer { private final ByteBuffer buf; public NioManagedBuffer(ByteBuffer buf) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 83d7953aaf700..efa2eeaf4d751 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -46,7 +46,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)) + maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0))) + maybeBlock.nonEmpty } test("withColumn doesn't invalidate cached dataframe") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 11863caffed75..86f02e68e5b0e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -40,7 +40,9 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)) + maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0))) + maybeBlock.nonEmpty } test("cache table") { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index e22e320b17126..3d9c085013ff5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -91,6 +91,8 @@ private[streaming] class BlockManagerBasedBlockHandler( if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") + } else { + blockManager.releaseLock(blockId) } BlockManagerBasedStoreResult(blockId, numRecords) } @@ -189,6 +191,8 @@ private[streaming] class WriteAheadLogBasedBlockHandler( if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") + } else { + blockManager.releaseLock(blockId) } } From 1b39fafa75a162f183824ff2daa61d73b05ebc83 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 25 Feb 2016 20:17:48 -0800 Subject: [PATCH 2037/2089] [SPARK-13361][SQL] Add benchmark codes for Encoder#compress() in CompressionSchemeBenchmark This pr added benchmark codes for Encoder#compress(). Also, it replaced the benchmark results with new ones because the output format of `Benchmark` changed. Author: Takeshi YAMAMURO Closes #11236 from maropu/CompressionSpike. --- .../CompressionSchemeBenchmark.scala | 282 ++++++++++++------ 1 file changed, 193 insertions(+), 89 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 95eb5cf912e2a..0000a5d1efd09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -17,19 +17,13 @@ package org.apache.spark.sql.execution.columnar.compression -import java.nio.ByteBuffer -import java.nio.ByteOrder +import java.nio.{ByteBuffer, ByteOrder} import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} -import org.apache.spark.sql.execution.columnar.BOOLEAN -import org.apache.spark.sql.execution.columnar.INT -import org.apache.spark.sql.execution.columnar.LONG -import org.apache.spark.sql.execution.columnar.NativeColumnType -import org.apache.spark.sql.execution.columnar.SHORT -import org.apache.spark.sql.execution.columnar.STRING +import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} import org.apache.spark.sql.types.AtomicType import org.apache.spark.util.Benchmark import org.apache.spark.util.Utils._ @@ -53,35 +47,70 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { () => rng.sample } - private[this] def runBenchmark[T <: AtomicType]( + private[this] def prepareEncodeInternal[T <: AtomicType]( + count: Int, + tpe: NativeColumnType[T], + supportedScheme: CompressionScheme, + input: ByteBuffer): ((ByteBuffer, ByteBuffer) => ByteBuffer, Double, ByteBuffer) = { + assert(supportedScheme.supports(tpe)) + + def toRow(d: Any) = new GenericInternalRow(Array[Any](d)) + val encoder = supportedScheme.encoder(tpe) + for (i <- 0 until count) { + encoder.gatherCompressibilityStats(toRow(tpe.extract(input)), 0) + } + input.rewind() + + val compressedSize = if (encoder.compressedSize == 0) { + input.remaining() + } else { + encoder.compressedSize + } + + (encoder.compress, encoder.compressionRatio, allocateLocal(4 + compressedSize)) + } + + private[this] def runEncodeBenchmark[T <: AtomicType]( name: String, iters: Int, count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) schemes.filter(_.supports(tpe)).map { scheme => - def toRow(d: Any) = new GenericInternalRow(Array[Any](d)) - val encoder = scheme.encoder(tpe) - for (i <- 0 until count) { - encoder.gatherCompressibilityStats(toRow(tpe.extract(input)), 0) - } - input.rewind() + val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) + val label = s"${getFormattedClassName(scheme)}(${compressionRatio.formatted("%.3f")})" - val label = s"${getFormattedClassName(scheme)}(${encoder.compressionRatio.formatted("%.3f")})" benchmark.addCase(label)({ i: Int => - val compressedSize = if (encoder.compressedSize == 0) { - input.remaining() - } else { - encoder.compressedSize + for (n <- 0L until iters) { + compressFunc(input, buf) + input.rewind() + buf.rewind() } + }) + } + + benchmark.run() + } - val buf = allocateLocal(4 + compressedSize) + private[this] def runDecodeBenchmark[T <: AtomicType]( + name: String, + iters: Int, + count: Int, + tpe: NativeColumnType[T], + input: ByteBuffer): Unit = { + val benchmark = new Benchmark(name, iters * count) + + schemes.filter(_.supports(tpe)).map { scheme => + val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) + val compressedBuf = compressFunc(input, buf) + val label = s"${getFormattedClassName(scheme)}" + + input.rewind() + + benchmark.addCase(label)({ i: Int => val rowBuf = new GenericMutableRow(1) - val compressedBuf = encoder.compress(input, buf) - input.rewind() for (n <- 0L until iters) { compressedBuf.rewind.position(4) @@ -96,16 +125,10 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { benchmark.run() } - def bitDecode(iters: Int): Unit = { + def bitEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * BOOLEAN.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // BOOLEAN Decode: Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 124.98 536.96 1.00 X - // RunLengthEncoding(2.494) 631.37 106.29 0.20 X - // BooleanBitSet(0.125) 1200.36 55.91 0.10 X val g = { val rng = genLowerSkewData() () => (rng().toInt % 2).toByte @@ -113,110 +136,176 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { for (i <- 0 until count) { testData.put(i * BOOLEAN.defaultSize, g()) } - runBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 3 / 4 19300.2 0.1 1.0X + // RunLengthEncoding(2.491) 923 / 939 72.7 13.8 0.0X + // BooleanBitSet(0.125) 359 / 363 187.1 5.3 0.0X + runEncodeBenchmark("BOOLEAN Encode", iters, count, BOOLEAN, testData) + + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 129 / 136 519.8 1.9 1.0X + // RunLengthEncoding 613 / 623 109.4 9.1 0.2X + // BooleanBitSet 1196 / 1222 56.1 17.8 0.1X + runDecodeBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) } - def shortDecode(iters: Int): Unit = { + def shortEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * SHORT.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Decode (Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 376.87 178.07 1.00 X - // RunLengthEncoding(1.498) 831.59 80.70 0.45 X val g1 = genLowerSkewData() for (i <- 0 until count) { testData.putShort(i * SHORT.defaultSize, g1().toShort) } - runBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Decode (Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 426.83 157.23 1.00 X - // RunLengthEncoding(1.996) 845.56 79.37 0.50 X + // SHORT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 6 / 7 10971.4 0.1 1.0X + // RunLengthEncoding(1.510) 1526 / 1542 44.0 22.7 0.0X + runEncodeBenchmark("SHORT Encode (Lower Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 811 / 837 82.8 12.1 1.0X + // RunLengthEncoding 1219 / 1266 55.1 18.2 0.7X + runDecodeBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) + val g2 = genHigherSkewData() for (i <- 0 until count) { testData.putShort(i * SHORT.defaultSize, g2().toShort) } - runBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 7 / 7 10112.4 0.1 1.0X + // RunLengthEncoding(2.009) 1623 / 1661 41.4 24.2 0.0X + runEncodeBenchmark("SHORT Encode (Higher Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 818 / 827 82.0 12.2 1.0X + // RunLengthEncoding 1202 / 1237 55.8 17.9 0.7X + runDecodeBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) } - def intDecode(iters: Int): Unit = { + def intEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * INT.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Decode(Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 325.16 206.39 1.00 X - // RunLengthEncoding(0.997) 1219.44 55.03 0.27 X - // DictionaryEncoding(0.500) 955.51 70.23 0.34 X - // IntDelta(0.250) 1146.02 58.56 0.28 X val g1 = genLowerSkewData() for (i <- 0 until count) { testData.putInt(i * INT.defaultSize, g1().toInt) } - runBenchmark("INT Decode(Lower Skew)", iters, count, INT, testData) // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Decode(Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 1133.45 59.21 1.00 X - // RunLengthEncoding(1.334) 1399.00 47.97 0.81 X - // DictionaryEncoding(0.501) 1032.87 64.97 1.10 X - // IntDelta(0.250) 948.02 70.79 1.20 X + // INT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 18 / 19 3716.4 0.3 1.0X + // RunLengthEncoding(1.001) 1992 / 2056 33.7 29.7 0.0X + // DictionaryEncoding(0.500) 723 / 739 92.8 10.8 0.0X + // IntDelta(0.250) 368 / 377 182.2 5.5 0.0X + runEncodeBenchmark("INT Encode (Lower Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 821 / 845 81.8 12.2 1.0X + // RunLengthEncoding 1246 / 1256 53.9 18.6 0.7X + // DictionaryEncoding 757 / 766 88.6 11.3 1.1X + // IntDelta 680 / 689 98.7 10.1 1.2X + runDecodeBenchmark("INT Decode (Lower Skew)", iters, count, INT, testData) + val g2 = genHigherSkewData() for (i <- 0 until count) { testData.putInt(i * INT.defaultSize, g2().toInt) } - runBenchmark("INT Decode(Higher Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 17 / 19 3888.4 0.3 1.0X + // RunLengthEncoding(1.339) 2127 / 2148 31.5 31.7 0.0X + // DictionaryEncoding(0.501) 960 / 972 69.9 14.3 0.0X + // IntDelta(0.250) 362 / 366 185.5 5.4 0.0X + runEncodeBenchmark("INT Encode (Higher Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 838 / 884 80.1 12.5 1.0X + // RunLengthEncoding 1287 / 1311 52.1 19.2 0.7X + // DictionaryEncoding 844 / 859 79.5 12.6 1.0X + // IntDelta 764 / 784 87.8 11.4 1.1X + runDecodeBenchmark("INT Decode (Higher Skew)", iters, count, INT, testData) } - def longDecode(iters: Int): Unit = { + def longEncodingBenchmark(iters: Int): Unit = { val count = 65536 val testData = allocateLocal(count * LONG.defaultSize) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Decode(Lower Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 1101.07 60.95 1.00 X - // RunLengthEncoding(0.756) 1372.57 48.89 0.80 X - // DictionaryEncoding(0.250) 947.80 70.81 1.16 X - // LongDelta(0.125) 721.51 93.01 1.53 X val g1 = genLowerSkewData() for (i <- 0 until count) { testData.putLong(i * LONG.defaultSize, g1().toLong) } - runBenchmark("LONG Decode(Lower Skew)", iters, count, LONG, testData) // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Decode(Higher Skew): Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 986.71 68.01 1.00 X - // RunLengthEncoding(1.013) 1348.69 49.76 0.73 X - // DictionaryEncoding(0.251) 865.48 77.54 1.14 X - // LongDelta(0.125) 816.90 82.15 1.21 X + // LONG Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 37 / 38 1804.8 0.6 1.0X + // RunLengthEncoding(0.748) 2065 / 2094 32.5 30.8 0.0X + // DictionaryEncoding(0.250) 950 / 962 70.6 14.2 0.0X + // LongDelta(0.125) 475 / 482 141.2 7.1 0.1X + runEncodeBenchmark("LONG Encode (Lower Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 888 / 894 75.5 13.2 1.0X + // RunLengthEncoding 1301 / 1311 51.6 19.4 0.7X + // DictionaryEncoding 887 / 904 75.7 13.2 1.0X + // LongDelta 693 / 735 96.8 10.3 1.3X + runDecodeBenchmark("LONG Decode (Lower Skew)", iters, count, LONG, testData) + val g2 = genHigherSkewData() for (i <- 0 until count) { testData.putLong(i * LONG.defaultSize, g2().toLong) } - runBenchmark("LONG Decode(Higher Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 34 / 35 1963.9 0.5 1.0X + // RunLengthEncoding(0.999) 2260 / 3021 29.7 33.7 0.0X + // DictionaryEncoding(0.251) 1270 / 1438 52.8 18.9 0.0X + // LongDelta(0.125) 496 / 509 135.3 7.4 0.1X + runEncodeBenchmark("LONG Encode (Higher Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 965 / 1494 69.5 14.4 1.0X + // RunLengthEncoding 1350 / 1378 49.7 20.1 0.7X + // DictionaryEncoding 892 / 924 75.2 13.3 1.1X + // LongDelta 817 / 847 82.2 12.2 1.2X + runDecodeBenchmark("LONG Decode (Higher Skew)", iters, count, LONG, testData) } - def stringDecode(iters: Int): Unit = { + def stringEncodingBenchmark(iters: Int): Unit = { val count = 65536 val strLen = 8 val tableSize = 16 val testData = allocateLocal(count * (4 + strLen)) - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // STRING Decode: Avg Time(ms) Avg Rate(M/s) Relative Rate - // ------------------------------------------------------------------------------- - // PassThrough(1.000) 2277.05 29.47 1.00 X - // RunLengthEncoding(0.893) 2624.35 25.57 0.87 X - // DictionaryEncoding(0.167) 2672.28 25.11 0.85 X val g = { val dataTable = (0 until tableSize).map(_ => RandomStringUtils.randomAlphabetic(strLen)) val rng = genHigherSkewData() @@ -227,14 +316,29 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.put(g().getBytes) } testData.rewind() - runBenchmark("STRING Decode", iters, count, STRING, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 56 / 57 1197.9 0.8 1.0X + // RunLengthEncoding(0.893) 4892 / 4937 13.7 72.9 0.0X + // DictionaryEncoding(0.167) 2968 / 2992 22.6 44.2 0.0X + runEncodeBenchmark("STRING Encode", iters, count, STRING, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 2422 / 2449 27.7 36.1 1.0X + // RunLengthEncoding 2885 / 3018 23.3 43.0 0.8X + // DictionaryEncoding 2716 / 2752 24.7 40.5 0.9X + runDecodeBenchmark("STRING Decode", iters, count, STRING, testData) } def main(args: Array[String]): Unit = { - bitDecode(1024) - shortDecode(1024) - intDecode(1024) - longDecode(1024) - stringDecode(1024) + bitEncodingBenchmark(1024) + shortEncodingBenchmark(1024) + intEncodingBenchmark(1024) + longEncodingBenchmark(1024) + stringEncodingBenchmark(1024) } } From 90d07154c2cef3d1095cb3caeafa7003218a3e49 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 25 Feb 2016 21:04:35 -0800 Subject: [PATCH 2038/2089] [SPARK-13028] [ML] Add MaxAbsScaler to ML.feature as a transformer jira: https://issues.apache.org/jira/browse/SPARK-13028 MaxAbsScaler works in a very similar way as MinMaxScaler, but scales in a way that the training data lies within the range [-1, 1] by dividing through the largest maximum value in each feature. The motivation to use this scaling includes robustness to very small standard deviations of features and preserving zero entries in sparse data. Unlike StandardScaler and MinMaxScaler, MaxAbsScaler does not shift/center the data, and thus does not destroy any sparsity. Something similar from sklearn: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html#sklearn.preprocessing.MaxAbsScaler Author: Yuhao Yang Closes #10939 from hhbyyh/maxabs and squashes the following commits: fd8bdcd [Yuhao Yang] add tag and some optimization on fit 648fced [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into maxabs 75bebc2 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into maxabs cb10bb6 [Yuhao Yang] remove minmax 91ef8f3 [Yuhao Yang] ut added 8ab0747 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into maxabs a9215b5 [Yuhao Yang] max abs scaler --- .../spark/ml/feature/MaxAbsScaler.scala | 176 ++++++++++++++++++ .../spark/ml/feature/MaxAbsScalerSuite.scala | 70 +++++++ 2 files changed, 246 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala new file mode 100644 index 0000000000000..15c308b238126 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MaxAbsScaler]] and [[MaxAbsScalerModel]]. + */ +private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to range [-1, 1] by dividing through the largest maximum + * absolute value in each feature. It does not shift/center the data, and thus does not destroy + * any sparsity. + */ +@Experimental +class MaxAbsScaler @Since("2.0.0") (override val uid: String) + extends Estimator[MaxAbsScalerModel] with MaxAbsScalerParams with DefaultParamsWritable { + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("maxAbsScal")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame): MaxAbsScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + val minVals = summary.min.toArray + val maxVals = summary.max.toArray + val n = minVals.length + val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) } + + copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs)).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MaxAbsScaler = defaultCopy(extra) +} + +@Since("1.6.0") +object MaxAbsScaler extends DefaultParamsReadable[MaxAbsScaler] { + + @Since("1.6.0") + override def load(path: String): MaxAbsScaler = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[MaxAbsScaler]]. + * + */ +@Experimental +class MaxAbsScalerModel private[ml] ( + override val uid: String, + val maxAbs: Vector) + extends Model[MaxAbsScalerModel] with MaxAbsScalerParams with MLWritable { + + import MaxAbsScalerModel._ + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + // TODO: this looks hack, we may have to handle sparse and dense vectors separately. + val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) + val reScale = udf { (vector: Vector) => + val brz = vector.toBreeze / maxAbsUnzero.toBreeze + Vectors.fromBreeze(brz) + } + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MaxAbsScalerModel = { + val copied = new MaxAbsScalerModel(uid, maxAbs) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new MaxAbsScalerModelWriter(this) +} + +@Since("1.6.0") +object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { + + private[MaxAbsScalerModel] + class MaxAbsScalerModelWriter(instance: MaxAbsScalerModel) extends MLWriter { + + private case class Data(maxAbs: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.maxAbs) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MaxAbsScalerModelReader extends MLReader[MaxAbsScalerModel] { + + private val className = classOf[MaxAbsScalerModel].getName + + override def load(path: String): MaxAbsScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath) + .select("maxAbs") + .head() + val model = new MaxAbsScalerModel(metadata.uid, maxAbs) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[MaxAbsScalerModel] = new MaxAbsScalerModelReader + + @Since("1.6.0") + override def load(path: String): MaxAbsScalerModel = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala new file mode 100644 index 0000000000000..e083d4713680e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + test("MaxAbsScaler fit basic case") { + val data = Array( + Vectors.dense(1, 0, 100), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(-2, -100)), + Vectors.sparse(3, Array(0), Array(-1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(0.5, 0, 1), + Vectors.dense(1, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(-1, -1)), + Vectors.sparse(3, Array(0), Array(-0.75))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MaxAbsScaler() + .setInputCol("features") + .setOutputCol("scaled") + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + } + + test("MaxAbsScaler read/write") { + val t = new MaxAbsScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } + + test("MaxAbsScalerModel read/write") { + val instance = new MaxAbsScalerModel( + "myMaxAbsScalerModel", Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.maxAbs === instance.maxAbs) + } + +} From f3be369ef7b78944ce9cafc94b19df52772cea58 Mon Sep 17 00:00:00 2001 From: Tommy YU Date: Thu, 25 Feb 2016 21:09:02 -0800 Subject: [PATCH 2039/2089] [SPARK-13033] [ML] [PYSPARK] Add import/export for ml.regression Add export/import for all estimators and transformers(which have Scala implementation) under pyspark/ml/regression.py. yanboliang Please help to review. For doctest, I though it's enough to add one since it's common usage. But I can add to all if we want it. Author: Tommy YU Closes #11000 from Wenpei/spark-13033-ml.regression-exprot-import and squashes the following commits: 3646b36 [Tommy YU] address review comments 9cddc98 [Tommy YU] change base on review and pr 11197 cc61d9d [Tommy YU] remove default parameter set 19535d4 [Tommy YU] add export/import to regression 44a9dc2 [Tommy YU] add import/export for ml.regression --- python/pyspark/ml/regression.py | 34 +++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de4a751a54796..6b994fe9f93b4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -154,7 +154,7 @@ def intercept(self): @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol): + HasWeightCol, MLWritable, MLReadable): """ .. note:: Experimental @@ -172,6 +172,18 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti 0.0 >>> model.boundaries DenseVector([0.0, 1.0]) + >>> ir_path = temp_path + "/ir" + >>> ir.save(ir_path) + >>> ir2 = IsotonicRegression.load(ir_path) + >>> ir2.getIsotonic() + True + >>> model_path = temp_path + "/ir_model" + >>> model.save(model_path) + >>> model2 = IsotonicRegressionModel.load(model_path) + >>> model.boundaries == model2.boundaries + True + >>> model.predictions == model2.predictions + True """ isotonic = \ @@ -237,7 +249,7 @@ def getFeatureIndex(self): return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel): +class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable): """ .. note:: Experimental @@ -663,7 +675,7 @@ class GBTRegressionModel(TreeEnsembleModels): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol): + HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -690,6 +702,20 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | 0.0|(1,[],[])| 0.0| 1.0| +-----+---------+------+----------+ ... + >>> aftsr_path = temp_path + "/aftsr" + >>> aftsr.save(aftsr_path) + >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path) + >>> aftsr2.getMaxIter() + 100 + >>> model_path = temp_path + "/aftsr_model" + >>> model.save(model_path) + >>> model2 = AFTSurvivalRegressionModel.load(model_path) + >>> model.coefficients == model2.coefficients + True + >>> model.intercept == model2.intercept + True + >>> model.scale == model2.scale + True .. versionadded:: 1.6.0 """ @@ -787,7 +813,7 @@ def getQuantilesCol(self): return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel): +class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by AFTSurvivalRegression. From 50e60e36f7775a10cf39338e7c5716578a24d89f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 25 Feb 2016 21:23:41 -0800 Subject: [PATCH 2040/2089] [SPARK-13504] [SPARKR] Add approxQuantile for SparkR ## What changes were proposed in this pull request? Add ```approxQuantile``` for SparkR. ## How was this patch tested? unit tests Author: Yanbo Liang Closes #11383 from yanboliang/spark-13504 and squashes the following commits: 4f17adb [Yanbo Liang] Add approxQuantile for SparkR --- R/pkg/NAMESPACE | 1 + R/pkg/R/generics.R | 7 ++++ R/pkg/R/stats.R | 39 +++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 8 +++++ 4 files changed, 55 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 6a3d63f43f785..636d39e1e9cae 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -111,6 +111,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "approxQuantile", "array_contains", "asc", "ascii", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ab61bce03df23..3db72b57954d7 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -67,6 +67,13 @@ setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) +# @rdname statfunctions +# @export +setGeneric("approxQuantile", + function(x, col, probabilities, relativeError) { + standardGeneric("approxQuantile") + }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 2e8076843f08a..edf72937c633a 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -130,6 +130,45 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), collect(dataFrame(sct)) }) +#' approxQuantile +#' +#' Calculates the approximate quantiles of a numerical column of a DataFrame. +#' +#' The result of this algorithm has the following deterministic bound: +#' If the DataFrame has N elements and if we request the quantile at probability `p` up to error +#' `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank +#' of `x` is close to (p * N). More precisely, +#' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). +#' This method implements a variation of the Greenwald-Khanna algorithm (with some speed +#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 +#' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. +#' +#' @param x A SparkSQL DataFrame. +#' @param col The name of the numerical column. +#' @param probabilities A list of quantile probabilities. Each number must belong to [0, 1]. +#' For example 0 is the minimum, 0.5 is the median, 1 is the maximum. +#' @param relativeError The relative target precision to achieve (>= 0). If set to zero, +#' the exact quantiles are computed, which could be very expensive. +#' Note that values greater than 1 are accepted but give the same result as 1. +#' @return The approximate quantiles at the given probabilities. +#' +#' @rdname statfunctions +#' @name approxQuantile +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) +#' } +setMethod("approxQuantile", + signature(x = "DataFrame", col = "character", + probabilities = "numeric", relativeError = "numeric"), + function(x, col, probabilities, relativeError) { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "approxQuantile", col, + as.list(probabilities), relativeError) + }) + #' sampleBy #' #' Returns a stratified sample without replacement based on the fraction given on each stratum. diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cc118108f61cc..236bae6bded25 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1785,6 +1785,14 @@ test_that("sampleBy() on a DataFrame", { expect_identical(as.list(result[2, ]), list(key = "1", count = 7)) }) +test_that("approxQuantile() on a DataFrame", { + l <- lapply(c(0:99), function(i) { i }) + df <- createDataFrame(sqlContext, l, "key") + quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) + expect_equal(quantiles[[1]], 50) + expect_equal(quantiles[[2]], 80) +}) + test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) expect_equal(grepl("Table not found: blah", retError), TRUE) From 8afe49141d9b6a603eb3907f32dce802a3d05172 Mon Sep 17 00:00:00 2001 From: thomastechs Date: Thu, 25 Feb 2016 22:52:25 -0800 Subject: [PATCH 2041/2089] [SPARK-12941][SQL][MASTER] Spark-SQL JDBC Oracle dialect fails to map string datatypes to Oracle VARCHAR datatype ## What changes were proposed in this pull request? This Pull request is used for the fix SPARK-12941, creating a data type mapping to Oracle for the corresponding data type"Stringtype" from dataframe. This PR is for the master branch fix, where as another PR is already tested with the branch 1.4 ## How was the this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) This patch was tested using the Oracle docker .Created a new integration suite for the same.The oracle.jdbc jar was to be downloaded from the maven repository.Since there was no jdbc jar available in the maven repository, the jar was downloaded from oracle site manually and installed in the local; thus tested. So, for SparkQA test case run, the ojdbc jar might be manually placed in the local maven repository(com/oracle/ojdbc6/11.2.0.2.0) while Spark QA test run. Author: thomastechs Closes #11306 from thomastechs/master. --- docker-integration-tests/pom.xml | 13 +++ .../sql/jdbc/OracleIntegrationSuite.scala | 80 +++++++++++++++++++ .../apache/spark/sql/jdbc/OracleDialect.scala | 5 ++ 3 files changed, 98 insertions(+) create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index 833ca29cd8218..048e58d88869f 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -131,6 +131,19 @@ postgresql test + + nodes to be split in tree + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree. * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, * where nodeIndexInfo stores the index in the group and the * feature subsets (if using feature subsets). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). - * Updated with new non-leaf nodes which are created. + * @param splits Possible splits for all features, indexed (numFeatures)(numSplits). + * @param bins Possible bins for all features, indexed (numFeatures)(numBins). + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where * each value in the array is the data point's node Id * for a corresponding tree. This is used to prevent the need @@ -527,10 +528,10 @@ object DecisionTree extends Serializable with Logging { * Each data point contributes to one node. For each feature, * the aggregate sufficient statistics are updated for the relevant bins. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). - * @param baggedPoint Data point being aggregated. - * @return agg + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return Array of decision tree statistics. */ def binSeqOp( agg: Array[DTStatsAggregator], @@ -563,6 +564,7 @@ object DecisionTree extends Serializable with Logging { /** * Get node index in group --> features indices map, * which is a short cut to find feature indices for a node given node index in group + * * @param treeToNodeToIndexInfo * @return */ @@ -719,9 +721,10 @@ object DecisionTree extends Serializable with Logging { /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * + * @param leftImpurityCalculator Left node aggregates for this (feature, split). + * @param rightImpurityCalculator Right node aggregate for this (feature, split). + * @return Information gain and statistics for split. */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, @@ -771,9 +774,10 @@ object DecisionTree extends Serializable with Logging { /** * Calculate predict value for current node, given stats of any split. * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node + * + * @param leftImpurityCalculator Left node aggregates for a split. + * @param rightImpurityCalculator Right node aggregates for a split. + * @return Predict value and impurity for current node. */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, @@ -788,8 +792,9 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. + * * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) + * @return Tuple for best split: (Split, information gain, prediction at node). */ private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, @@ -955,8 +960,8 @@ object DecisionTree extends Serializable with Logging { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param metadata Learning and dataset metadata + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param metadata Learning and dataset metadata. * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numSplits). @@ -1102,12 +1107,13 @@ object DecisionTree extends Serializable with Logging { * NOTE: Returned number of splits is set based on `featureSamples` and * could be different from the specified `numSplits`. * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. - * @param featureSamples feature values of each sample - * @param metadata decision tree metadata + * + * @param featureSamples Feature values of each sample. + * @param metadata Decision tree metadata. * NOTE: `metadata.numbins` will be changed accordingly - * if there are not enough splits to be found - * @param featureIndex feature index to find splits - * @return array of splits + * if there are not enough splits to be found. + * @param featureIndex Feature index to find splits. + * @return Array of splits. */ private[tree] def findSplitsForContinuousFeature( featureSamples: Array[Double], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1b71256c585bd..d131f5da6c7eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -54,8 +54,9 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to train a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -82,13 +83,14 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to validate a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.4.0") def runWithValidation( @@ -132,7 +134,7 @@ object GradientBoostedTrees extends Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def train( @@ -153,11 +155,11 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. - * @param input training dataset - * @param validationInput validation dataset, ignored if validate is set to false. - * @param boostingStrategy boosting parameters - * @param validate whether or not to use the validation dataset. - * @return a gradient boosted trees model that can be used for prediction + * @param input Training dataset. + * @param validationInput Validation dataset, ignored if validate is set to false. + * @param boostingStrategy Boosting parameters. + * @param validate Whether or not to use the validation dataset. + * @return GradientBoostedTreesModel that can be used for prediction. */ private def boost( input: RDD[LabeledPoint], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 570a76f960796..b7714b382a594 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -53,12 +53,12 @@ import org.apache.spark.util.random.SamplingUtils * random forests]] * * @param strategy The configuration parameters for the random forest algorithm which specify - * the type of algorithm (classification, regression, etc.), feature type + * the type of random forest (classification or regression), feature type * (continuous, categorical), depth of the tree, quantile calculation strategy, * etc. * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and @@ -121,8 +121,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return a random forest model that can be used for prediction + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return RandomForestModel that can be used for prediction. */ def run(input: RDD[LabeledPoint]): RandomForestModel = { @@ -269,12 +270,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -294,25 +295,25 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -358,12 +359,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( @@ -383,24 +384,24 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree. (e.g., depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 6c04403f1ad75..9e3e50192d507 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -34,8 +34,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], * [[org.apache.spark.mllib.tree.impurity.Entropy]]. * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. * (Ignored for regression.) * Default value is 2 (binary classification). @@ -45,10 +45,9 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * number of discrete values they take. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param minInstancesPerNode Minimum number of instances each child must have after split. * Default value is 1. If a split cause left or right child * to have less than minInstancesPerNode, @@ -60,7 +59,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * 256 MB. * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will - * maintain a separate RDD of node Id cache for each row. + * maintain a separate RDD of node Id cache for each row. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. * E.g. 10 means that the cache will get checkpointed every 10 updates. If * the checkpoint directory is not set in diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 0001b60093a69..f7ea466b43291 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -60,8 +60,7 @@ def numTrees(self): @since("1.3.0") def totalNumNodes(self): """ - Get total number of nodes, summed over all trees in the - ensemble. + Get total number of nodes, summed over all trees in the ensemble. """ return self.call("totalNumNodes") @@ -92,8 +91,9 @@ def predict(self, x): transformation or action. Call predict directly on the RDD instead. - :param x: Data point (feature vector), - or an RDD of data points (feature vectors). + :param x: + Data point (feature vector), or an RDD of data points (feature + vectors). """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -108,8 +108,9 @@ def numNodes(self): @since("1.1.0") def depth(self): - """Get depth of tree. - E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ + Get depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). """ return self._java_model.depth() @@ -152,24 +153,37 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ - Train a DecisionTreeModel for classification. - - :param data: Training data: RDD of LabeledPoint. - Labels are integers {0,1,...,numClasses}. - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: Supported values: "entropy" or "gini" - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split - :param minInfoGain: Min info gain required to create a split - :return: DecisionTreeModel + Train a decision tree model for classification. + + :param data: + Training data: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + :param numClasses: + Number of classes for classification. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param impurity: + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + :param maxBins: + Number of bins used for finding splits at each node. + (default: 32) + :param minInstancesPerNode: + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + :param minInfoGain: + Minimum info gain required to create a split. + (default: 0.0) + :return: + DecisionTreeModel. Example usage: @@ -211,23 +225,34 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ - Train a DecisionTreeModel for regression. - - :param data: Training data: RDD of LabeledPoint. - Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature - index to number of categories. - Any feature not in this map is treated as continuous. - :param impurity: Supported values: "variance" - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each - node. - :param minInstancesPerNode: Min number of instances required at - child nodes to create the parent split - :param minInfoGain: Min info gain required to create a split - :return: DecisionTreeModel + Train a decision tree model for regression. + + :param data: + Training data: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param impurity: + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + :param maxBins: + Number of bins used for finding splits at each node. + (default: 32) + :param minInstancesPerNode: + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + :param minInfoGain: + Minimum info gain required to create a split. + (default: 0.0) + :return: + DecisionTreeModel. Example usage: @@ -302,34 +327,44 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): """ - Method to train a decision tree model for binary or multiclass + Train a random forest model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels - should take values {0, 1, ..., numClasses-1}. - :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that - feature n is categorical with k categories indexed - from 0: {0, 1, ..., k-1}. - :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain calculation. - Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. - E.g., depth 0 means 1 leaf node; depth 1 means - 1 internal node + 2 leaf nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting - features - (default: 32) - :param seed: Random seed for bootstrapping and choosing feature - subsets. - :return: RandomForestModel that can be used for prediction + :param data: + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + :param numClasses: + Number of classes for classification. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param numTrees: + Number of trees in the random forest. + :param featureSubsetStrategy: + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + (default: "auto") + :param impurity: + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + :param maxBins: + Maximum number of bins used for splitting features. + (default: 32) + :param seed: + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + :return: + RandomForestModel that can be used for prediction. Example usage: @@ -383,32 +418,40 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ - Method to train a decision tree model for regression. - - :param data: Training dataset: RDD of LabeledPoint. Labels are - real numbers. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain - calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting - features (default: 32) - :param seed: Random seed for bootstrapping and choosing feature - subsets. - :return: RandomForestModel that can be used for prediction + Train a random forest model for regression. + + :param data: + Training dataset: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param numTrees: + Number of trees in the random forest. + :param featureSubsetStrategy: + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + (default: "auto") + :param impurity: + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + :param maxBins: + Maximum number of bins used for splitting features. + (default: 32) + :param seed: + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + :return: + RandomForestModel that can be used for prediction. Example usage: @@ -480,31 +523,37 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ - Method to train a gradient-boosted trees model for - classification. - - :param data: Training dataset: RDD of LabeledPoint. - Labels should take values {0, 1}. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient - boosting. Supported: {"logLoss" (default), - "leastSquaresError", "leastAbsoluteError"}. - :param numIterations: Number of iterations of boosting. - (default: 100) - :param learningRate: Learning rate for shrinking the - contribution of each estimator. The learning rate - should be between in the interval (0, 1]. - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 3) - :param maxBins: maximum number of bins used for splitting - features (default: 32) DecisionTree requires maxBins >= max categories - :return: GradientBoostedTreesModel that can be used for - prediction + Train a gradient-boosted trees model for classification. + + :param data: + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1}. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param loss: + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "logLoss") + :param numIterations: + Number of iterations of boosting. + (default: 100) + :param learningRate: + Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + :param maxBins: + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + :return: + GradientBoostedTreesModel that can be used for prediction. Example usage: @@ -543,30 +592,36 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ - Method to train a gradient-boosted trees model for regression. - - :param data: Training dataset: RDD of LabeledPoint. Labels are - real numbers. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient - boosting. Supported: {"logLoss" (default), - "leastSquaresError", "leastAbsoluteError"}. - :param numIterations: Number of iterations of boosting. - (default: 100) - :param learningRate: Learning rate for shrinking the - contribution of each estimator. The learning rate - should be between in the interval (0, 1]. - (default: 0.1) - :param maxBins: maximum number of bins used for splitting - features (default: 32) DecisionTree requires maxBins >= max categories - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 3) - :return: GradientBoostedTreesModel that can be used for - prediction + Train a gradient-boosted trees model for regression. + + :param data: + Training dataset: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param loss: + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "leastSquaresError") + :param numIterations: + Number of iterations of boosting. + (default: 100) + :param learningRate: + Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + :param maxBins: + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + :return: + GradientBoostedTreesModel that can be used for prediction. Example usage: From 7af0de076f74e975c9235c88b0f11b22fcbae060 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 26 Feb 2016 08:31:55 -0800 Subject: [PATCH 2048/2089] [SPARK-11381][DOCS] Replace example code in mllib-linear-methods.md using include_example ## What changes were proposed in this pull request? This PR replaces example codes in `mllib-linear-methods.md` using `include_example` by doing the followings: * Extracts the example codes(Scala,Java,Python) as files in `example` module. * Merges some dialog-style examples into a single file. * Hide redundant codes in HTML for the consistency with other docs. ## How was the this patch tested? manual test. This PR can be tested by document generations, `SKIP_API=1 jekyll build`. Author: Dongjoon Hyun Closes #11320 from dongjoon-hyun/SPARK-11381. --- docs/mllib-linear-methods.md | 441 +----------------- .../JavaLinearRegressionWithSGDExample.java | 94 ++++ ...avaLogisticRegressionWithLBFGSExample.java | 79 ++++ .../examples/mllib/JavaSVMWithSGDExample.java | 82 ++++ .../linear_regression_with_sgd_example.py | 54 +++ .../logistic_regression_with_lbfgs_example.py | 54 +++ .../streaming_linear_regression_example.py | 62 +++ .../main/python/mllib/svm_with_sgd_example.py | 47 ++ .../LinearRegressionWithSGDExample.scala | 64 +++ .../LogisticRegressionWithLBFGSExample.scala | 69 +++ .../examples/mllib/SVMWithSGDExample.scala | 70 +++ .../StreamingLinearRegressionExample.scala | 58 +++ 12 files changed, 758 insertions(+), 416 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java create mode 100644 examples/src/main/python/mllib/linear_regression_with_sgd_example.py create mode 100644 examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py create mode 100644 examples/src/main/python/mllib/streaming_linear_regression_example.py create mode 100644 examples/src/main/python/mllib/svm_with_sgd_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index aac8f7560a4f8..63665c49bc972 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -170,42 +170,7 @@ error. Refer to the [`SVMWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMWithSGD) and [`SVMModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -// Run training algorithm to build the model -val numIterations = 100 -val model = SVMWithSGD.train(training, numIterations) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Area under ROC = " + auROC) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = SVMModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala %} The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we @@ -216,6 +181,7 @@ variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. {% highlight scala %} + import org.apache.spark.mllib.optimization.L1Updater val svmAlg = new SVMWithSGD() @@ -237,61 +203,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`SVMWithSGD` Java docs](api/java/org/apache/spark/mllib/classification/SVMWithSGD.html) and [`SVMModel` Java docs](api/java/org/apache/spark/mllib/classification/SVMModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.*; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; - -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class SVMClassifier { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD training = data.sample(false, 0.6, 11L); - training.cache(); - JavaRDD test = data.subtract(training); - - // Run training algorithm to build the model. - int numIterations = 100; - final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - } - ); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); - double auROC = metrics.areaUnderROC(); - - System.out.println("Area under ROC = " + auROC); - - // Save and load model - model.save(sc, "myModelPath"); - SVMModel sameModel = SVMModel.load(sc, "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java %} The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we @@ -325,30 +237,7 @@ and make predictions with the resulting model to compute the training error. Refer to the [`SVMWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMWithSGD) and [`SVMModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import SVMWithSGD, SVMModel -from pyspark.mllib.regression import LabeledPoint - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/sample_svm_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = SVMWithSGD.train(parsedData, iterations=100) - -# Evaluating the model on training data -labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) -print("Training Error = " + str(trainErr)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = SVMModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/svm_with_sgd_example.py %} @@ -406,42 +295,7 @@ Then the model is evaluated against the test dataset and saved to disk. Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(10) - .run(training) - -// Compute raw scores on the test set. -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Get evaluation metrics. -val metrics = new MulticlassMetrics(predictionAndLabels) -val precision = metrics.precision -println("Precision = " + precision) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = LogisticRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala %} @@ -454,57 +308,7 @@ Then the model is evaluated against the test dataset and saved to disk. Refer to the [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) and [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MultinomialLogisticRegressionExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(10) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - double precision = metrics.precision(); - System.out.println("Precision = " + precision); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java %}
        @@ -516,30 +320,7 @@ will in the future. Refer to the [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel -from pyspark.mllib.regression import LabeledPoint - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/sample_svm_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LogisticRegressionWithLBFGS.train(parsedData) - -# Evaluating the model on training data -labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) -print("Training Error = " + str(trainErr)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LogisticRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/logistic_regression_with_lbfgs_example.py %}
        @@ -575,36 +356,7 @@ values. We compute the mean squared error at the end to evaluate Refer to the [`LinearRegressionWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/ridge-data/lpsa.data") -val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -}.cache() - -// Building the model -val numIterations = 100 -val stepSize = 0.00000001 -val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) - -// Evaluate model on training examples and compute training error -val valuesAndPreds = parsedData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = LinearRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`. @@ -620,70 +372,7 @@ the Scala snippet provided, is presented below: Refer to the [`LinearRegressionWithSGD` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionWithSGD.html) and [`LinearRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/ridge-data/lpsa.data"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(","); - String[] features = parts[1].split(" "); - double[] v = new double[features.length]; - for (int i = 0; i < features.length - 1; i++) - v[i] = Double.parseDouble(features[i]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - double stepSize = 0.00000001; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - double MSE = new JavaDoubleRDD(valuesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - return Math.pow(pair._1() - pair._2(), 2.0); - } - } - ).rdd()).mean(); - System.out.println("training Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java %}
        @@ -696,29 +385,7 @@ Note that the Python API does not yet support model save/load but will in the fu Refer to the [`LinearRegressionWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.replace(',', ' ').split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/ridge-data/lpsa.data") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData, iterations=100, step=0.00000001) - -# Evaluate the model on training data -valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LinearRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/linear_regression_with_sgd_example.py %}
        @@ -748,108 +415,50 @@ online to the first stream, and make predictions on the second stream. First, we import the necessary classes for parsing our input data and creating the model. -{% highlight scala %} - -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD - -{% endhighlight %} - Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. -{% highlight scala %} - -val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse).cache() -val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) - -{% endhighlight %} +We create our model by initializing the weights to zero and register the streams for training and +testing then start the job. Printing predictions alongside true labels lets us easily see the +result. -We create our model by initializing the weights to 0 - -{% highlight scala %} - -val numFeatures = 3 -val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.zeros(numFeatures)) - -{% endhighlight %} - -Now we register the streams for training and testing and start the job. -Printing predictions alongside true labels lets us easily see the result. - -{% highlight scala %} - -model.trainOn(trainingData) -model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() - -ssc.start() -ssc.awaitTermination() - -{% endhighlight %} - -We can now save text files with data to the training or testing folders. +Finally we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` +the model will update. Anytime a text file is placed in `args(1)` you will see predictions. As you feed more data to the training directory, the predictions will get better! +Here is a complete example: +{% include_example scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala %} +
        First, we import the necessary classes for parsing our input data and creating the model. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.regression import StreamingLinearRegressionWithSGD -{% endhighlight %} - Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. -{% highlight python %} -def parse(lp): - label = float(lp[lp.find('(') + 1: lp.find(',')]) - vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) - return LabeledPoint(label, vec) - -trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() -testData = ssc.textFileStream("/testing/data/dir").map(parse) -{% endhighlight %} - -We create our model by initializing the weights to 0 - -{% highlight python %} -numFeatures = 3 -model = StreamingLinearRegressionWithSGD() -model.setInitialWeights([0.0, 0.0, 0.0]) -{% endhighlight %} +We create our model by initializing the weights to 0. Now we register the streams for training and testing and start the job. -{% highlight python %} -model.trainOn(trainingData) -print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) - -ssc.start() -ssc.awaitTermination() -{% endhighlight %} - We can now save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +and `x1,x2,x3` are the features. Anytime a text file is placed in `sys.argv[1]` +the model will update. Anytime a text file is placed in `sys.argv[2]` you will see predictions. As you feed more data to the training directory, the predictions will get better! +Here a complete example: +{% include_example python/mllib/streaming_linear_regression_example.py %} +
        diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java new file mode 100644 index 0000000000000..3e50118c0d9ec --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +// $example off$ + +/** + * Example for LinearRegressionWithSGD. + */ +public class JavaLinearRegressionWithSGDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // $example on$ + // Load and parse the data + String path = "data/mllib/ridge-data/lpsa.data"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(","); + String[] features = parts[1].split(" "); + double[] v = new double[features.length]; + for (int i = 0; i < features.length - 1; i++) { + v[i] = Double.parseDouble(features[i]); + } + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + double stepSize = 0.00000001; + final LinearRegressionModel model = + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + double MSE = new JavaDoubleRDD(valuesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + return Math.pow(pair._1() - pair._2(), 2.0); + } + } + ).rdd()).mean(); + System.out.println("training Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/javaLinearRegressionWithSGDModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java new file mode 100644 index 0000000000000..9d8e4a90dbc99 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +/** + * Example for LogisticRegressionWithLBFGS. + */ +public class JavaLogisticRegressionWithLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + double precision = metrics.precision(); + System.out.println("Precision = " + precision); + + // Save and load model + model.save(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/javaLogisticRegressionWithLBFGSModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java new file mode 100644 index 0000000000000..720b167b2cadf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.SVMModel; +import org.apache.spark.mllib.classification.SVMWithSGD; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +/** + * Example for SVMWithSGD. + */ +public class JavaSVMWithSGDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD training = data.sample(false, 0.6, 11L); + training.cache(); + JavaRDD test = data.subtract(training); + + // Run training algorithm to build the model. + int numIterations = 100; + final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); + double auROC = metrics.areaUnderROC(); + + System.out.println("Area under ROC = " + auROC); + + // Save and load model + model.save(sc, "target/tmp/javaSVMWithSGDModel"); + SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/python/mllib/linear_regression_with_sgd_example.py b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py new file mode 100644 index 0000000000000..6fbaeff0cd5a0 --- /dev/null +++ b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Linear Regression With SGD Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonLinearRegressionWithSGDExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.replace(',', ' ').split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/ridge-data/lpsa.data") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData, iterations=100, step=0.00000001) + + # Evaluate the model on training data + valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + MSE = valuesAndPreds \ + .map(lambda (v, p): (v - p)**2) \ + .reduce(lambda x, y: x + y) / valuesAndPreds.count() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/pythonLinearRegressionWithSGDModel") + sameModel = LinearRegressionModel.load(sc, "target/tmp/pythonLinearRegressionWithSGDModel") + # $example off$ diff --git a/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py new file mode 100644 index 0000000000000..e030b74ba6b15 --- /dev/null +++ b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Logistic Regression With LBFGS Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel +from pyspark.mllib.regression import LabeledPoint +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/sample_svm_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LogisticRegressionWithLBFGS.train(parsedData) + + # Evaluating the model on training data + labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) + print("Training Error = " + str(trainErr)) + + # Save and load model + model.save(sc, "target/tmp/pythonLogisticRegressionWithLBFGSModel") + sameModel = LogisticRegressionModel.load(sc, + "target/tmp/pythonLogisticRegressionWithLBFGSModel") + # $example off$ diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py new file mode 100644 index 0000000000000..f600496867c11 --- /dev/null +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Streaming Linear Regression Example. +""" +from __future__ import print_function + +# $example on$ +import sys +# $example off$ + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +# $example off$ + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: streaming_linear_regression_example.py ", + file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") + ssc = StreamingContext(sc, 1) + + # $example on$ + def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + + trainingData = ssc.textFileStream(sys.argv[1]).map(parse).cache() + testData = ssc.textFileStream(sys.argv[2]).map(parse) + + numFeatures = 3 + model = StreamingLinearRegressionWithSGD() + model.setInitialWeights([0.0, 0.0, 0.0]) + + model.trainOn(trainingData) + print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + + ssc.start() + ssc.awaitTermination() + # $example off$ diff --git a/examples/src/main/python/mllib/svm_with_sgd_example.py b/examples/src/main/python/mllib/svm_with_sgd_example.py new file mode 100644 index 0000000000000..309ab09cc375a --- /dev/null +++ b/examples/src/main/python/mllib/svm_with_sgd_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import SVMWithSGD, SVMModel +from pyspark.mllib.regression import LabeledPoint +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVMWithSGDExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/sample_svm_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = SVMWithSGD.train(parsedData, iterations=100) + + # Evaluating the model on training data + labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) + print("Training Error = " + str(trainErr)) + + # Save and load model + model.save(sc, "target/tmp/pythonSVMWithSGDModel") + sameModel = SVMModel.load(sc, "target/tmp/pythonSVMWithSGDModel") + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala new file mode 100644 index 0000000000000..669868787e8f0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +// $example off$ + +object LinearRegressionWithSGDExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LinearRegressionWithSGDExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/ridge-data/lpsa.data") + val parsedData = data.map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) + }.cache() + + // Building the model + val numIterations = 100 + val stepSize = 0.00000001 + val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) + + // Evaluate model on training examples and compute training error + val valuesAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() + println("training Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") + val sameModel = LinearRegressionModel.load(sc, "target/tmp/scalaLinearRegressionWithSGDModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala new file mode 100644 index 0000000000000..632a2d537e5bc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object LogisticRegressionWithLBFGSExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionWithLBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training) + + // Compute raw scores on the test set. + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Get evaluation metrics. + val metrics = new MulticlassMetrics(predictionAndLabels) + val precision = metrics.precision + println("Precision = " + precision) + + // Save and load model + model.save(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel") + val sameModel = LogisticRegressionModel.load(sc, + "target/tmp/scalaLogisticRegressionWithLBFGSModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala new file mode 100644 index 0000000000000..b73fe9b2b3faa --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object SVMWithSGDExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("SVMWithSGDExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + // Run training algorithm to build the model + val numIterations = 100 + val model = SVMWithSGD.train(training, numIterations) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Area under ROC = " + auROC) + + // Save and load model + model.save(sc, "target/tmp/scalaSVMWithSGDModel") + val sameModel = SVMModel.load(sc, "target/tmp/scalaSVMWithSGDModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala new file mode 100644 index 0000000000000..0a1cd2d62d5b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +// $example off$ +import org.apache.spark.streaming._ + +object StreamingLinearRegressionExample { + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + System.err.println("Usage: StreamingLinearRegressionExample ") + System.exit(1) + } + + val conf = new SparkConf().setAppName("StreamingLinearRegressionExample") + val ssc = new StreamingContext(conf, Seconds(1)) + + // $example on$ + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse).cache() + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val numFeatures = 3 + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.zeros(numFeatures)) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + // $example off$ + + ssc.stop() + } +} +// scalastyle:on println From 727e78014fd4957e477d62adc977fa4da3e3455d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 26 Feb 2016 17:11:19 +0000 Subject: [PATCH 2049/2089] [MINOR][SQL] Fix modifier order. ## What changes were proposed in this pull request? This PR fixes the order of modifier from `abstract public` into `public abstract`. Currently, when we run `./dev/lint-java`, it shows the error. ``` Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/util/sketch/CountMinSketch.java:[53,10] (modifier) ModifierOrder: 'public' modifier out of order with the JLS suggestions. ``` ## How was this patch tested? ``` $ ./dev/lint-java Checkstyle checks passed. ``` Author: Dongjoon Hyun Closes #11390 from dongjoon-hyun/fix_modifier_order. --- .../main/java/org/apache/spark/util/sketch/CountMinSketch.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 2c9aa93582e6f..40fa20c4a3e37 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -50,7 +50,7 @@ * * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ -abstract public class CountMinSketch { +public abstract class CountMinSketch { public enum Version { /** From 6df1e55a6594ae4bc7882f44af8d230aad9489b4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 26 Feb 2016 09:58:05 -0800 Subject: [PATCH 2050/2089] [SPARK-12313] [SQL] improve performance of BroadcastNestedLoopJoin ## What changes were proposed in this pull request? Currently, BroadcastNestedLoopJoin is implemented for worst case, it's too slow, very easy to hang forever. This PR will create fast path for some joinType and buildSide, also improve the worst case (will use much less memory than before). Before this PR, one task requires O(N*K) + O(K) in worst cases, N is number of rows from one partition of streamed table, it could hang the job (because of GC). In order to workaround this for InnerJoin, we have to disable auto-broadcast, switch to CartesianProduct: This could be workaround for InnerJoin, see https://forums.databricks.com/questions/6747/how-do-i-get-a-cartesian-product-of-a-huge-dataset.html In this PR, we will have fast path for these joins : InnerJoin with BuildLeft or BuildRight LeftOuterJoin with BuildRight RightOuterJoin with BuildLeft LeftSemi with BuildRight These fast paths are all stream based (take one pass on streamed table), required O(1) memory. All other join types and build types will take two pass on streamed table, one pass to find the matched rows that includes streamed part, which require O(1) memory, another pass to find the rows from build table that does not have a matched row from streamed table, which required O(K) memory, K is the number rows from build side, one bit per row, should be much smaller than the memory for broadcast. The following join types work in this way: LeftOuterJoin with BuildLeft RightOuterJoin with BuildRight FullOuterJoin with BuildLeft or BuildRight LeftSemi with BuildLeft This PR also added tests for all the join types for BroadcastNestedLoopJoin. After this PR, for InnerJoin with one small table, BroadcastNestedLoopJoin should be faster than CartesianProduct, we don't need that workaround anymore. ## How was the this patch tested? Added unit tests. Author: Davies Liu Closes #11328 from davies/nested_loop. --- .../plans/physical/broadcastMode.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 14 +- .../joins/BroadcastNestedLoopJoin.scala | 295 +++++++++++++----- .../org/apache/spark/sql/JoinSuite.scala | 11 +- .../sql/execution/joins/InnerJoinSuite.scala | 27 ++ .../sql/execution/joins/OuterJoinSuite.scala | 18 ++ .../sql/execution/joins/SemiJoinSuite.scala | 20 +- 7 files changed, 295 insertions(+), 91 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index c646dcfa11811..e01f69f81359e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -31,5 +31,6 @@ trait BroadcastMode { * IdentityBroadcastMode requires that rows are broadcasted in their original form. */ case object IdentityBroadcastMode extends BroadcastMode { + // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5fdf38c733efc..dd8c96d5fa1d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -253,22 +253,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join( - CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) => execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil - case logical.Join( - left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => + planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil + case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) => execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil + planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // TODO CartesianProduct doesn't support the Left Semi Join - case logical.Join(left, right, joinType, None) if joinType != LeftSemi => + case logical.Join(left, right, Inner, None) => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -286,6 +283,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } else { joins.BuildLeft } + // This join could be very slow or even hang forever joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index e8bd7f69dbab9..d83486df02c87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +27,6 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} - case class BroadcastNestedLoopJoin( left: SparkPlan, right: SparkPlan, @@ -51,125 +51,266 @@ case class BroadcastNestedLoopJoin( } private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) + } else { + // Always put the stream side on left to simplify implementation + UnsafeProjection.create(output, streamed.output ++ broadcast.output) + } } override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { joinType match { + case Inner => + left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case Inner => - // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case - left.output ++ right.output - case x => // TODO support the Left Semi Join + case LeftSemi => + left.output + case x => throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + @transient private lazy val boundCondition = { + if (condition.isDefined) { + newPredicate(condition.get, streamed.output ++ broadcast.output) + } else { + (r: InternalRow) => true + } + } - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") + /** + * The implementation for InnerJoin. + */ + private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() + streamedIter.flatMap { streamedRow => + val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r)) + if (condition.isDefined) { + joinedRows.filter(boundCondition) + } else { + joinedRows + } + } + } + } - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val relation = broadcastedRelation.value + /** + * The implementation for these joins: + * + * LeftOuter with BuildRight + * RightOuter with BuildLeft + */ + private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) + + // Returns an iterator to avoid copy the rows. + new Iterator[InternalRow] { + // current row from stream side + private var streamRow: InternalRow = null + // have found a match for current row or not + private var foundMatch: Boolean = false + // the matched result row + private var resultRow: InternalRow = null + // the next index of buildRows to try + private var nextIndex: Int = 0 - val matchedRows = new CompactBuffer[InternalRow] - val includedBroadcastTuples = new BitSet(relation.length) + private def findNextMatch(): Boolean = { + if (streamRow == null) { + if (!streamedIter.hasNext) { + return false + } + streamRow = streamedIter.next() + nextIndex = 0 + foundMatch = false + } + while (nextIndex < buildRows.length) { + resultRow = joinedRow(streamRow, buildRows(nextIndex)) + nextIndex += 1 + if (boundCondition(resultRow)) { + foundMatch = true + return true + } + } + if (!foundMatch) { + resultRow = joinedRow(streamRow, nulls) + streamRow = null + true + } else { + resultRow = null + streamRow = null + findNextMatch() + } + } + + override def hasNext(): Boolean = { + resultRow != null || findNextMatch() + } + override def next(): InternalRow = { + val r = resultRow + resultRow = null + r + } + } + } + } + + /** + * The implementation for these joins: + * + * LeftSemi with BuildRight + */ + private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + assert(buildSide == BuildRight) + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val resultProj = genResultProjection + if (condition.isDefined) { + streamedIter.filter(l => + buildRows.exists(r => boundCondition(joinedRow(l, r))) + ) + } else { + streamedIter.filter(r => !buildRows.isEmpty) + } + } + } + + /** + * The implementation for these joins: + * + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + */ + private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val streamRdd = streamed.execute() + + val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val matched = new BitSet(buildRows.length) + val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => var i = 0 - var streamRowMatched = false - - while (i < relation.length) { - val broadcastedRow = relation(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() - streamRowMatched = true - includedBroadcastTuples.set(i) - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() - streamRowMatched = true - includedBroadcastTuples.set(i) - case _ => + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matched.set(i) } i += 1 } + } + Seq(matched).toIterator + } - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() - case _ => + val matchedBroadcastRows = matchedBuildRows.fold( + new BitSet(relation.value.length) + )(_ | _) + + if (joinType == LeftSemi) { + assert(buildSide == BuildLeft) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (matchedBroadcastRows.get(i)) { + buf += rel(i).copy() } + i += 1 } - Iterator((matchedRows, includedBroadcastTuples)) + return sparkContext.makeRDD(buf.toSeq) } - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new BitSet(broadcastedRelation.value.size) - )(_ | _) + val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val resultProj = genResultProjection + streamedIter.flatMap { streamedRow => + var i = 0 + var foundMatch = false + val matchedRows = new CompactBuffer[InternalRow] + + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matchedRows += joinedRow.copy() + foundMatch = true + } + i += 1 + } + + if (!foundMatch && joinType == FullOuter) { + matchedRows += joinedRow(streamedRow, nulls).copy() + } + matchedRows.iterator + } + } - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[InternalRow] = { + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 - val rel = broadcastedRelation.value - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - val joinedRow = new JoinedRow - joinedRow.withLeft(leftNulls) - while (i < rel.length) { - if (!allIncludedBroadcastTuples.get(i)) { - buf += resultProj(joinedRow.withRight(rel(i))).copy() - } - i += 1 - } - case (LeftOuter | FullOuter, BuildLeft) => - val joinedRow = new JoinedRow - joinedRow.withRight(rightNulls) - while (i < rel.length) { - if (!allIncludedBroadcastTuples.get(i)) { - buf += resultProj(joinedRow.withLeft(rel(i))).copy() - } - i += 1 - } - case _ => + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 } buf.toSeq } - // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), - sparkContext.makeRDD(broadcastRowsWithNulls) - ).map { row => - // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. - numOutputRows += 1 - row + matchedStreamRows, + sparkContext.makeRDD(notMatchedBroadcastRows) + ) + } + + protected override def doExecute(): RDD[InternalRow] = { + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() + + val resultRdd = (joinType, buildSide) match { + case (Inner, _) => + innerJoin(broadcastedRelation) + case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => + outerJoin(broadcastedRelation) + case (LeftSemi, BuildRight) => + leftSemiJoin(broadcastedRelation) + case _ => + /** + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + */ + defaultJoin(broadcastedRelation) + } + + val numOutputRows = longMetric("numOutputRows") + resultRdd.mapPartitionsInternal { iter => + val resultProj = genResultProjection + iter.map { r => + numOutputRows += 1 + resultProj(r) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 41e27ec46648f..3dab848e7b033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -70,13 +70,14 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", - classOf[CartesianProduct]), + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index b748229e402f5..7eb15249ebbd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -146,6 +146,33 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + + test(s"$testName using CartesianProduct") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + Filter(condition(), CartesianProduct(left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } testInnerJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 22fe8caff265e..0d1c29fe574a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -105,6 +105,24 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } // --- Basic outer joins ------------------------------------------------------------------------ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 5c982885d652f..355f916a97552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements @@ -103,6 +103,24 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } testLeftSemiJoin( From 0598a2b81d1426dd2cf9e6fc32cef345364d18c6 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 26 Feb 2016 12:43:50 -0800 Subject: [PATCH 2051/2089] [SPARK-13499] [SQL] Performance improvements for parquet reader. ## What changes were proposed in this pull request? This patch includes these performance fixes: - Remove unnecessary setNotNull() calls. The NULL bits are cleared already. - Speed up RLE group decoding - Speed up dictionary decoding by decoding NULLs directly into the result. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) In addition to the updated benchmarks, on TPCDS, the result of these changes running Q55 (sf40) is: ``` TPCDS: Best/Avg Time(ms) Rate(M/s) Per Row(ns) --------------------------------------------------------------------------------- q55 (Before) 6398 / 6616 18.0 55.5 q55 (After) 4983 / 5189 23.1 43.3 ``` Author: Nong Li Closes #11375 from nongli/spark-13499. --- .../parquet/UnsafeRowParquetRecordReader.java | 17 +---- .../parquet/VectorizedRleValuesReader.java | 66 ++++++++++++------- .../parquet/ParquetReadBenchmark.scala | 30 ++++----- 3 files changed, 59 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 4576ac2a3222f..9d50cfab3bd3d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -628,7 +628,8 @@ private void readBatch(int total, ColumnVector column) throws IOException { dictionaryIds.reserve(total); } // Read and decode dictionary ids. - readIntBatch(rowId, num, dictionaryIds); + defColumn.readIntegers( + num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); decodeDictionaryIds(rowId, num, column); } else { switch (descriptor.getType()) { @@ -739,18 +740,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { default: throw new NotImplementedException("Unsupported type: " + descriptor.getType()); } - - if (dictionaryIds.numNulls() > 0) { - // Copy the NULLs over. - // TODO: we can improve this by decoding the NULLs directly into column. This would - // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then - // just do the ID remapping as above. - for (int i = 0; i < num; ++i) { - if (dictionaryIds.getIsNull(rowId + i)) { - column.putNull(rowId + i); - } - } - } } /** @@ -769,7 +758,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce // TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 629959a73baf3..b2048c0e397ed 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; @@ -176,11 +175,11 @@ public int readInteger() { * if (this.readInt() == level) { * c[rowId] = data.readInteger(); * } else { - * c[rowId] = nullValue; + * c[rowId] = null; * } */ public void readIntegers(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data, int nullValue) { + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -189,7 +188,6 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readIntegers(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -198,9 +196,7 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putInt(rowId + i, data.readInteger()); - c.putNotNull(rowId + i); } else { - c.putInt(rowId + i, nullValue); c.putNull(rowId + i); } } @@ -223,7 +219,6 @@ public void readBooleans(int total, ColumnVector c, case RLE: if (currentValue == level) { data.readBooleans(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -232,7 +227,6 @@ public void readBooleans(int total, ColumnVector c, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putBoolean(rowId + i, data.readBoolean()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -257,7 +251,6 @@ public void readIntsAsLongs(int total, ColumnVector c, for (int i = 0; i < n; i++) { c.putLong(rowId + i, data.readInteger()); } - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -266,7 +259,6 @@ public void readIntsAsLongs(int total, ColumnVector c, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putLong(rowId + i, data.readInteger()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -289,7 +281,6 @@ public void readBytes(int total, ColumnVector c, case RLE: if (currentValue == level) { data.readBytes(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -298,7 +289,6 @@ public void readBytes(int total, ColumnVector c, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putByte(rowId + i, data.readByte()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -321,7 +311,6 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readLongs(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -330,7 +319,6 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putLong(rowId + i, data.readLong()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -353,7 +341,6 @@ public void readFloats(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readFloats(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -362,7 +349,6 @@ public void readFloats(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putFloat(rowId + i, data.readFloat()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -385,7 +371,6 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level, case RLE: if (currentValue == level) { data.readDoubles(n, c, rowId); - c.putNotNulls(rowId, n); } else { c.putNulls(rowId, n); } @@ -394,7 +379,6 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level, for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { c.putDouble(rowId + i, data.readDouble()); - c.putNotNull(rowId + i); } else { c.putNull(rowId + i); } @@ -416,7 +400,6 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, switch (mode) { case RLE: if (currentValue == level) { - c.putNotNulls(rowId, n); data.readBinary(n, c, rowId); } else { c.putNulls(rowId, n); @@ -425,7 +408,6 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, case PACKED: for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { - c.putNotNull(rowId + i); data.readBinary(1, c, rowId + i); } else { c.putNull(rowId + i); @@ -439,6 +421,40 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, } } + /** + * Decoding for dictionary ids. The IDs are populated into `values` and the nullability is + * populated into `nulls`. + */ + public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readIntegers(n, values, rowId); + } else { + nulls.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + values.putInt(rowId + i, data.readInteger()); + } else { + nulls.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + // The RLE reader implements the vectorized decoding interface when used to decode dictionary // IDs. This is different than the above APIs that decodes definitions levels along with values. @@ -560,12 +576,14 @@ private int readIntLittleEndianPaddedOnBitWidth() { throw new RuntimeException("Unreachable"); } + private int ceil8(int value) { + return (value + 7) / 8; + } + /** * Reads the next group. */ private void readNextGroup() { - Preconditions.checkArgument(this.offset < this.end, - "Reading past RLE/BitPacking stream. offset=" + this.offset + " end=" + this.end); int header = readUnsignedVarInt(); this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; switch (mode) { @@ -576,14 +594,12 @@ private void readNextGroup() { case PACKED: int numGroups = header >>> 1; this.currentCount = numGroups * 8; + int bytesToRead = ceil8(this.currentCount * this.bitWidth); if (this.currentBuffer.length < this.currentCount) { this.currentBuffer = new int[this.currentCount]; } currentBufferIdx = 0; - int bytesToRead = (int)Math.ceil((double)(this.currentCount * this.bitWidth) / 8.0D); - - bytesToRead = Math.min(bytesToRead, this.end - this.offset); int valueIndex = 0; for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index dafc589999c7c..660f0f173a9e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -150,21 +150,21 @@ object ParquetReadBenchmark { /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - SQL Parquet Reader 1350.56 11.65 1.00 X - SQL Parquet MR 1844.09 8.53 0.73 X - SQL Parquet Vectorized 1062.04 14.81 1.27 X + SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Reader 1042 / 1208 15.1 66.2 1.0X + SQL Parquet MR 1544 / 1607 10.2 98.2 0.7X + SQL Parquet Vectorized 674 / 739 23.3 42.9 1.5X */ sqlBenchmark.run() /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - ParquetReader 610.40 25.77 1.00 X - ParquetReader(Batched) 172.66 91.10 3.54 X - ParquetReader(Batch -> Row) 192.28 81.80 3.17 X + Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + ParquetReader 565 / 609 27.8 35.9 1.0X + ParquetReader(Batched) 165 / 174 95.3 10.5 3.4X + ParquetReader(Batch -> Row) 158 / 188 99.3 10.1 3.6X */ parquetReaderBenchmark.run() } @@ -218,12 +218,12 @@ object ParquetReadBenchmark { /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------- - SQL Parquet Reader 1737.94 6.03 1.00 X - SQL Parquet MR 2393.08 4.38 0.73 X - SQL Parquet Vectorized 1442.99 7.27 1.20 X - ParquetReader 1032.11 10.16 1.68 X + SQL Parquet Reader 1381 / 1679 7.6 131.7 1.0X + SQL Parquet MR 2005 / 2177 5.2 191.2 0.7X + SQL Parquet Vectorized 919 / 1044 11.4 87.6 1.5X + ParquetReader 1035 / 1163 10.1 98.7 1.3X */ benchmark.run() } From 391755dc6ed2e156b8df8a530ac8df6ed7ba7f8a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 26 Feb 2016 12:49:16 -0800 Subject: [PATCH 2052/2089] [SPARK-13465] Add a task failure listener to TaskContext ## What changes were proposed in this pull request? TaskContext supports task completion callback, which gets called regardless of task failures. However, there is no way for the listener to know if there is an error. This patch adds a new listener that gets called when a task fails. ## How was the this patch tested? New unit test case and integration test case covering the code path Author: Reynold Xin Closes #11340 from rxin/SPARK-13465. --- .../scala/org/apache/spark/TaskContext.scala | 28 +++++++++++- .../org/apache/spark/TaskContextImpl.scala | 33 +++++++++++--- .../org/apache/spark/scheduler/Task.scala | 5 +++ .../TaskCompletionListenerException.scala | 34 -------------- ...tionListener.scala => taskListeners.scala} | 37 +++++++++++++++- .../spark/JavaTaskCompletionListenerImpl.java | 39 ---------------- .../spark/JavaTaskContextCompileCheck.java | 30 +++++++++++++ .../spark/scheduler/TaskContextSuite.scala | 44 ++++++++++++++++++- project/MimaExcludes.scala | 4 +- 9 files changed, 169 insertions(+), 85 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala rename core/src/main/scala/org/apache/spark/util/{TaskCompletionListener.scala => taskListeners.scala} (51%) delete mode 100644 core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 9f49cf1c4c9bd..bfcacbf229b00 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener} object TaskContext { @@ -106,6 +106,8 @@ abstract class TaskContext extends Serializable { * Adds a (Java friendly) listener to be executed on task completion. * This will be called in all situation - success, failure, or cancellation. * An example use is for HadoopRDD to register a callback to close the input stream. + * + * Exceptions thrown by the listener will result in failure of the task. */ def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext @@ -113,8 +115,30 @@ abstract class TaskContext extends Serializable { * Adds a listener in the form of a Scala closure to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. * An example use is for HadoopRDD to register a callback to close the input stream. + * + * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext + def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + }) + } + + /** + * Adds a listener to be executed on task failure. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ + def addTaskFailureListener(listener: TaskFailureListener): TaskContext + + /** + * Adds a listener to be executed on task failure. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ + def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { + addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error) + }) + } /** * The ID of the stage that this task belong to. diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 1d228b6b86c55..65f6f741f7f4e 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.util._ private[spark] class TaskContextImpl( val stageId: Int, @@ -41,9 +41,12 @@ private[spark] class TaskContextImpl( */ override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators) - // List of callback functions to execute when the task completes. + /** List of callback functions to execute when the task completes. */ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + /** List of callback functions to execute when the task fails. */ + @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] + // Whether the corresponding task has been killed. @volatile private var interrupted: Boolean = false @@ -55,14 +58,30 @@ private[spark] class TaskContextImpl( this } - override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } + override def addTaskFailureListener(listener: TaskFailureListener): this.type = { + onFailureCallbacks += listener this } - /** Marks the task as completed and triggers the listeners. */ + /** Marks the task as completed and triggers the failure listeners. */ + private[spark] def markTaskFailed(error: Throwable): Unit = { + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onFailureCallbacks.reverse.foreach { listener => + try { + listener.onTaskFailure(this, error) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskFailureListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs, Option(error)) + } + } + + /** Marks the task as completed and triggers the completion listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true val errorMsgs = new ArrayBuffer[String](2) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c68d001f280f..d2b8ca90a9899 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -80,7 +80,12 @@ private[spark] abstract class Task[T]( } try { runTask(context) + } catch { case e: Throwable => + // Catch all errors; run task failure callbacks, and rethrow the exception. + context.markTaskFailed(e) + throw e } finally { + // Call the task completion callbacks. context.markTaskCompleted() try { Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala deleted file mode 100644 index f64e069cd1724..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Exception thrown when there is an exception in - * executing the callback in TaskCompletionListener. - */ -private[spark] -class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { - - override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala similarity index 51% rename from core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala rename to core/src/main/scala/org/apache/spark/util/taskListeners.scala index c1b8bf052c0ca..1be31e88ab68e 100644 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -29,5 +29,40 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait TaskCompletionListener extends EventListener { - def onTaskCompletion(context: TaskContext) + def onTaskCompletion(context: TaskContext): Unit +} + + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution encounters an error. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ +@DeveloperApi +trait TaskFailureListener extends EventListener { + def onTaskFailure(context: TaskContext, error: Throwable): Unit +} + + +/** + * Exception thrown when there is an exception in executing the callback in TaskCompletionListener. + */ +private[spark] +class TaskCompletionListenerException( + errorMessages: Seq[String], + val previousError: Option[Throwable] = None) + extends RuntimeException { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + + previousError.map { e => + "\n\nPrevious exception in task: " + e.getMessage + "\n" + + e.getStackTrace.mkString("\t", "\n\t", "") + }.getOrElse("") + } } diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java deleted file mode 100644 index e38bc38949d7c..0000000000000 --- a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark; - -import org.apache.spark.TaskContext; -import org.apache.spark.util.TaskCompletionListener; - - -/** - * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and - * TaskContext is Java friendly. - */ -public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { - - @Override - public void onTaskCompletion(TaskContext context) { - context.isCompleted(); - context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.isRunningLocally(); - context.addTaskCompletionListener(this); - } -} diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 4a918f725dc91..f914081d7d5b2 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -18,6 +18,8 @@ package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskFailureListener; /** * Something to make sure that TaskContext can be used in Java. @@ -32,10 +34,38 @@ public static void test() { tc.isRunningLocally(); tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + tc.addTaskFailureListener(new JavaTaskFailureListenerImpl()); tc.attemptNumber(); tc.partitionId(); tc.stageId(); tc.taskAttemptId(); } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskCompletionListenerImpl implements TaskCompletionListener { + @Override + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.isRunningLocally(); + context.addTaskCompletionListener(this); + } + } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskFailureListenerImpl implements TaskFailureListener { + @Override + public void onTaskFailure(TaskContext context, Throwable error) { + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 850e470ca14d6..c4cf2f9f70755 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -66,6 +66,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(TaskContextSuite.completed === true) } + test("calls TaskFailureListeners after failure") { + TaskContextSuite.lastError = null + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc, List()) { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { + context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error) + sys.error("damn error") + } + } + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + intercept[RuntimeException] { + task.run(0, 0, null) + } + assert(TaskContextSuite.lastError.getMessage == "damn error") + } + test("all TaskCompletionListeners should be called even if some fail") { val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) @@ -80,6 +100,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark verify(listener, times(1)).onTaskCompletion(any()) } + test("all TaskFailureListeners should be called even if some fail") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskFailureListener]) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1")) + context.addTaskFailureListener(listener) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskFailed(new Exception("exception in task")) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskFailure(any(), any()) + + // also need to check failure in TaskFailureListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") { sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks // Check that attemptIds are 0 for all tasks' initial attempts @@ -153,6 +193,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark private object TaskContextSuite { @volatile var completed = false + + @volatile var lastError: Throwable = _ } private case class StubPartition(index: Int) extends Partition diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 14e3c90f1b0d2..165280a1b2e41 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -271,7 +271,9 @@ object MimaExcludes { ) ++ Seq( // SPARK-13220 Deprecate yarn-client and yarn-cluster mode ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler"), + // SPARK-13465 TaskContext. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") ) ++ Seq ( // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") From 1e5fcdf96c0176a11e5f425ba539b6ed629281db Mon Sep 17 00:00:00 2001 From: zlpmichelle Date: Fri, 26 Feb 2016 14:37:44 -0800 Subject: [PATCH 2053/2089] [SPARK-13505][ML] add python api for MaxAbsScaler ## What changes were proposed in this pull request? After SPARK-13028, we should add Python API for MaxAbsScaler. ## How was this patch tested? unit test Author: zlpmichelle Closes #11393 from zlpmichelle/master. --- python/pyspark/ml/feature.py | 75 ++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 67bccfae7a26d..369f3508fda5e 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -28,13 +28,14 @@ from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', - 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', - 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', - 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', - 'ChiSqSelector', 'ChiSqSelectorModel'] + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', + 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinMaxScaler', 'MinMaxScalerModel', + 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', + 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel', + 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', + 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', + 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', 'ChiSqSelector', + 'ChiSqSelectorModel'] @inherit_doc @@ -544,6 +545,66 @@ class IDFModel(JavaModel): """ +@inherit_doc +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to range [-1, 1] by dividing through the largest maximum + absolute value in each feature. It does not shift/center the data, and thus does not destroy + any sparsity. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled") + >>> model = maScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[1.0]| [0.5]| + |[2.0]| [1.0]| + +-----+------+ + ... + + .. versionadded:: 2.0.0 + """ + + @keyword_only + def __init__(self, inputCol=None, outputCol=None): + """ + __init__(self, inputCol=None, outputCol=None) + """ + super(MaxAbsScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) + self._setDefault() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, inputCol=None, outputCol=None): + """ + setParams(self, inputCol=None, outputCol=None) + Sets params for this MaxAbsScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return MaxAbsScalerModel(java_model) + + +class MaxAbsScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MaxAbsScaler`. + + .. versionadded:: 2.0.0 + """ + + @inherit_doc class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): """ From ad615291fe76580ee59e3f48f4efe4627a01409d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 26 Feb 2016 15:11:57 -0800 Subject: [PATCH 2054/2089] [SPARK-13519][CORE] Driver should tell Executor to stop itself when cleaning executor's state ## What changes were proposed in this pull request? When the driver removes an executor's state, the connection between the driver and the executor may be still alive so that the executor cannot exit automatically (E.g., Master will send RemoveExecutor when a work is lost but the executor is still alive), so the driver should try to tell the executor to stop itself. Otherwise, we will leak an executor. This PR modified the driver to send `StopExecutor` to the executor when it's removed. ## How was this patch tested? manual test: increase the worker heartbeat interval to force it's always timeout and the leak executors are gone. Author: Shixiong Zhu Closes #11399 from zsxwing/SPARK-13519. --- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 0a5b09dc0d1fa..d151de5f6a830 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -179,6 +179,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case RemoveExecutor(executorId, reason) => + // We will remove the executor's state and cannot restore it. However, the connection + // between the driver and the executor may be still alive so that the executor won't exit + // automatically, so try to tell the executor to stop itself. See SPARK-13519. + executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) context.reply(true) From f77dc4e1e202942aa8393fb5d8f492863973fe17 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Feb 2016 18:40:00 -0800 Subject: [PATCH 2055/2089] [SPARK-13474][PROJECT INFRA] Update packaging scripts to push artifacts to home.apache.org Due to the people.apache.org -> home.apache.org migration, we need to update our packaging scripts to publish artifacts to the new server. Because the new server only supports sftp instead of ssh, we need to update the scripts to use lftp instead of ssh + rsync. Author: Josh Rosen Closes #11350 from JoshRosen/update-release-scripts-for-apache-home. --- dev/create-release/release-build.sh | 60 +++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 2fd7fcc39ea28..c08b6d7de6fe0 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -23,8 +23,8 @@ usage: release-build.sh Creates build deliverables from a Spark commit. Top level targets are - package: Create binary packages and copy them to people.apache - docs: Build docs and copy them to people.apache + package: Create binary packages and copy them to home.apache + docs: Build docs and copy them to home.apache publish-snapshot: Publish snapshot release to Apache snapshots publish-release: Publish a release to Apache release repo @@ -64,13 +64,16 @@ for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do fi done +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + # Commit ref to checkout when building GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" GPG="gpg --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads @@ -97,7 +100,20 @@ if [ -z "$SPARK_PACKAGE_VERSION" ]; then fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" -USER_HOST="$ASF_USERNAME@people.apache.org" + +function LFTP { + SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" + COMMANDS=$(cat < Date: Fri, 26 Feb 2016 22:35:12 -0800 Subject: [PATCH 2056/2089] [SPARK-13521][BUILD] Remove reference to Tachyon in cluster & release scripts ## What changes were proposed in this pull request? We provide a very limited set of cluster management script in Spark for Tachyon, although Tachyon itself provides a much better version of it. Given now Spark users can simply use Tachyon as a normal file system and does not require extensive configurations, we can remove this management capabilities to simplify Spark bash scripts. Note that this also reduces coupling between a 3rd party external system and Spark's release scripts, and would eliminate possibility for failures such as Tachyon being renamed or the tar balls being relocated. ## How was this patch tested? N/A Author: Reynold Xin Closes #11400 from rxin/release-script. --- docs/configuration.md | 24 ------------------- docs/job-scheduling.md | 3 +-- docs/programming-guide.md | 22 ++--------------- make-distribution.sh | 50 +-------------------------------------- sbin/start-all.sh | 15 ++---------- sbin/start-master.sh | 21 ---------------- sbin/start-slaves.sh | 22 ----------------- sbin/stop-master.sh | 4 ---- sbin/stop-slaves.sh | 5 ---- 9 files changed, 6 insertions(+), 160 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 4b1b00720b006..e9b66238bd189 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -929,30 +929,6 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system.
        - - - - - - - - - - - - - - -
        Property NameDefaultMeaniing
        Property NameDefaultMeaning
        spark.deploy.recoveryMode NONE
        Precision (Postive Predictive Value)Precision (Positive Predictive Value) $PPV=\frac{TP}{TP + FP}$
        $RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$
        Mean Absoloute Error (MAE)Mean Absolute Error (MAE) $MAE=\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$
        spark.history.fs.update.interval 10s - The period at which the the filesystem history provider checks for new or + The period at which the filesystem history provider checks for new or updated logs in the log directory. A shorter interval detects new applications faster, at the expense of more server load re-reading updated applications. As soon as an update has completed, listings of the completed and incomplete applications diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 2d6f7767d93ef..5ebafa40b0c7a 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -629,7 +629,7 @@ class MyClass { } {% endhighlight %} -is equilvalent to writing `rdd.map(x => this.field + x)`, which references all of `this`. To avoid this +is equivalent to writing `rdd.map(x => this.field + x)`, which references all of `this`. To avoid this issue, the simplest way is to copy `field` into a local variable instead of accessing it externally: {% highlight scala %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 35f6caab1700c..b9f64c7ed146e 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -188,7 +188,7 @@ overhead, but at the cost of reserving the Mesos resources for the complete dura application. Coarse-grained is the default mode. You can also set `spark.mesos.coarse` property to true -to turn it on explictly in [SparkConf](configuration.html#spark-properties): +to turn it on explicitly in [SparkConf](configuration.html#spark-properties): {% highlight scala %} conf.set("spark.mesos.coarse", "true") @@ -384,7 +384,7 @@ See the [configuration page](configuration.html) for information on Spark config
      • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
      • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
      • Set constraints are matched with "subset of" semantics i.e. value in the constraint must be a subset of the resource offer's value.
      • -
      • Text constraints are metched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
      • +
      • Text constraints are matched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
      • In case there is no value present as a part of the constraint any offer with the corresponding attribute will be accepted (without value check).
      • spark.mesos.uris (none) - A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. - This applies to both coarse-grain and fine-grain mode. + A comma-separated list of URIs to be downloaded to the sandbox + when driver or executor is launched by Mesos. This applies to + both coarse-grained and fine-grained mode.
        spark.externalBlockStore.blockManagerorg.apache.spark.storage.TachyonBlockManager - Implementation of external block manager (file system) that store RDDs. The file system's URL is set by - spark.externalBlockStore.url. -
        spark.externalBlockStore.baseDirSystem.getProperty("java.io.tmpdir") - Directories of the external block store that store RDDs. The file system's URL is set by - spark.externalBlockStore.url It can also be a comma-separated list of multiple - directories on Tachyon file system. -
        spark.externalBlockStore.urltachyon://localhost:19998 for Tachyon - The URL of the underlying external blocker file system in the external block store. -
        #### Networking diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 95d47794ea337..00b6a18836a54 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -54,8 +54,7 @@ an application to gain back cores on one node when it has work to do. To use thi Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying -the same RDDs. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will -provide another approach to share RDDs. +the same RDDs. ## Dynamic Resource Allocation diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 5ebafa40b0c7a..2f0ed5eca2b2b 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1177,7 +1177,7 @@ that originally created it. In addition, each persisted RDD can be stored using a different *storage level*, allowing you, for example, to persist the dataset on disk, persist it in memory but as serialized Java objects (to save space), -replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-project.org/). +replicate it across nodes. These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), @@ -1218,24 +1218,11 @@ storage levels is: MEMORY_ONLY_2, MEMORY_AND_DISK_2, etc. Same as the levels above, but replicate each partition on two cluster nodes. - - OFF_HEAP (experimental) - Store RDD in serialized format in Tachyon. - Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors - to be smaller and to share a pool of memory, making it attractive in environments with - large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory - in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon - out-of-the-box. Please refer to this page - for the suggested version pairings. - - **Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, -`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, `DISK_ONLY_2` and `OFF_HEAP`.* +`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1259,11 +1246,6 @@ requests from a web application). *All* the storage levels provide full fault to recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. -* In environments with high amounts of memory or multiple applications, the experimental `OFF_HEAP` -mode has several advantages: - * It allows multiple executors to share the same pool of memory in Tachyon. - * It significantly reduces garbage collection costs. - * Cached data is not lost if individual executors crash. ### Removing Data diff --git a/make-distribution.sh b/make-distribution.sh index 327659298e4d8..20998144f0b50 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,11 +32,6 @@ set -x SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" -SPARK_TACHYON=false -TACHYON_VERSION="0.8.2" -TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" - MAKE_TGZ=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -45,7 +40,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for making binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--mvn ] [--with-tachyon]" + cl_options="[--name] [--tgz] [--mvn ]" echo "./make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -69,9 +64,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --with-tachyon) - SPARK_TACHYON=true - ;; --tgz) MAKE_TGZ=true ;; @@ -150,12 +142,6 @@ else echo "Making distribution for Spark $VERSION in $DISTDIR..." fi -if [ "$SPARK_TACHYON" == "true" ]; then - echo "Tachyon Enabled" -else - echo "Tachyon Disabled" -fi - # Build uber fat JAR cd "$SPARK_HOME" @@ -219,40 +205,6 @@ if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi -# Download and copy in tachyon, if requested -if [ "$SPARK_TACHYON" == "true" ]; then - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - - pushd "$TMPD" > /dev/null - echo "Fetching tachyon tgz" - - TACHYON_DL="${TACHYON_TGZ}.part" - if [ $(command -v curl) ]; then - curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" - elif [ $(command -v wget) ]; then - wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" - else - printf "You do not have curl or wget installed. please install Tachyon manually.\n" - exit -1 - fi - - tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" - mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" - cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" - - if [[ `uname -a` == Darwin* ]]; then - # need to run sed differently on osx - nl=$'\n'; sed -i "" -e "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\\$nl export TACHYON_JAR=\$TACHYON_HOME/../lib/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" - else - sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../lib/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" - fi - - popd > /dev/null - rm -rf "$TMPD" -fi - if [ "$MAKE_TGZ" == "true" ]; then TARDIR_NAME=spark-$VERSION-bin-$NAME TARDIR="$SPARK_HOME/$TARDIR_NAME" diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 6217f9bf28e3d..a5d30d274ea6e 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -25,22 +25,11 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -TACHYON_STR="" - -while (( "$#" )); do -case $1 in - --with-tachyon) - TACHYON_STR="--with-tachyon" - ;; - esac -shift -done - # Load the Spark configuration . "${SPARK_HOME}/sbin/spark-config.sh" # Start Master -"${SPARK_HOME}/sbin"/start-master.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-master.sh # Start Workers -"${SPARK_HOME}/sbin"/start-slaves.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-slaves.sh diff --git a/sbin/start-master.sh b/sbin/start-master.sh index 9f2e14dff609f..ce7f17795997e 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -39,21 +39,6 @@ fi ORIGINAL_ARGS="$@" -START_TACHYON=false - -while (( "$#" )); do -case $1 in - --with-tachyon) - if [ ! -e "${SPARK_HOME}"/tachyon/bin/tachyon ]; then - echo "Error: --with-tachyon specified, but tachyon not found." - exit -1 - fi - START_TACHYON=true - ;; - esac -shift -done - . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -73,9 +58,3 @@ fi "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS - -if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}"/tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "${SPARK_HOME}"/tachyon/bin/tachyon format -s - "${SPARK_HOME}"/tachyon/bin/tachyon-start.sh master -fi diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 51ca81e053b70..5bf2b83b42ce4 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -23,21 +23,6 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -START_TACHYON=false - -while (( "$#" )); do -case $1 in - --with-tachyon) - if [ ! -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - echo "Error: --with-tachyon specified, but tachyon not found." - exit -1 - fi - START_TACHYON=true - ;; - esac -shift -done - . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -50,12 +35,5 @@ if [ "$SPARK_MASTER_IP" = "" ]; then SPARK_MASTER_IP="`hostname`" fi -if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" - - # set -t so we can call sudo - SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 -fi - # Launch the slaves "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index e57962bb354d9..14644ea72d43b 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -26,7 +26,3 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 - -if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master -fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 63956377629d6..a57441b52a04a 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -25,9 +25,4 @@ fi . "${SPARK_HOME}/bin/load-spark-env.sh" -# do before the below calls as they exec -if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker -fi - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-slave.sh From 7a0cb4e58728834b49050ce4fae418acc18a601f Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 26 Feb 2016 22:36:32 -0800 Subject: [PATCH 2057/2089] [SPARK-13518][SQL] Enable vectorized parquet scanner by default ## What changes were proposed in this pull request? Change the default of the flag to enable this feature now that the implementation is complete. ## How was this patch tested? The new parquet reader should be a drop in, so will be exercised by the existing tests. Author: Nong Li Closes #11397 from nongli/spark-13518. --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9a50ef77efbd4..1d1e2884414d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -345,12 +345,9 @@ object SQLConf { defaultValue = Some(true), doc = "Enables using the custom ParquetUnsafeRowRecordReader.") - // Note: this can not be enabled all the time because the reader will not be returning UnsafeRows. - // Doing so is very expensive and we should remove this requirement instead of fixing it here. - // Initial testing seems to indicate only sort requires this. val PARQUET_VECTORIZED_READER_ENABLED = booleanConf( key = "spark.sql.parquet.enableVectorizedReader", - defaultValue = Some(false), + defaultValue = Some(true), doc = "Enables vectorized parquet decoding.") val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", From ec0cc75e158335a4110c86cf63c19d7d45167834 Mon Sep 17 00:00:00 2001 From: mark800 Date: Sat, 27 Feb 2016 13:50:37 +0000 Subject: [PATCH 2058/2089] [SPARK-7483][MLLIB] Upgrade Chill to 0.7.2 to support Kryo with FPGrowth It registers more Scala classes, including ListBuffer to support Kryo with FPGrowth. See https://github.com/twitter/chill/releases for Chill's change log. Author: mark800 Closes #11041 from mark800/master. --- dev/deps/spark-deps-hadoop-2.2 | 4 ++-- dev/deps/spark-deps-hadoop-2.3 | 4 ++-- dev/deps/spark-deps-hadoop-2.4 | 4 ++-- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 47482fd3adeeb..266e0ce88d552 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -19,8 +19,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.5.0.jar -chill_2.11-0.5.0.jar +chill-java-0.7.4.jar +chill_2.11-0.7.4.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index dd93abf4f395b..257751a09088f 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -21,8 +21,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.5.0.jar -chill_2.11-0.5.0.jar +chill-java-0.7.4.jar +chill_2.11-0.7.4.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 987a6d7f593f6..061ef2876b55a 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -21,8 +21,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.5.0.jar -chill_2.11-0.5.0.jar +chill-java-0.7.4.jar +chill_2.11-0.7.4.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9f158485d2803..f171d5aa390d2 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -25,8 +25,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.5.0.jar -chill_2.11-0.5.0.jar +chill-java-0.7.4.jar +chill_2.11-0.7.4.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 6cbb9ab9591ce..a7a9dcdef0a9a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -25,8 +25,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.5.0.jar -chill_2.11-0.5.0.jar +chill-java-0.7.4.jar +chill_2.11-0.7.4.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar diff --git a/pom.xml b/pom.xml index 244355b080221..bcf04d48c2f6f 100644 --- a/pom.xml +++ b/pom.xml @@ -147,7 +147,7 @@ 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 - 0.5.0 + 0.7.4 2.4.0 2.0.8 3.1.2 From 3814d0bcf6f1697a94123be4b224cbd7554025a9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 27 Feb 2016 11:41:35 -0800 Subject: [PATCH 2059/2089] [SPARK-13530][SQL] Add ShortType support to UnsafeRowParquetRecordReader JIRA: https://issues.apache.org/jira/browse/SPARK-13530 ## What changes were proposed in this pull request? By enabling vectorized parquet scanner by default, the unit test `ParquetHadoopFsRelationSuite` based on `HadoopFsRelationTest` will be failed due to the lack of short type support in `UnsafeRowParquetRecordReader`. We should fix it. The error exception: [info] ParquetHadoopFsRelationSuite: [info] - test all data types - StringType (499 milliseconds) [info] - test all data types - BinaryType (447 milliseconds) [info] - test all data types - BooleanType (520 milliseconds) [info] - test all data types - ByteType (418 milliseconds) 00:22:58.920 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 124.0 (TID 1949) org.apache.commons.lang.NotImplementedException: Unimplemented type: ShortType at org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader$ColumnReader.readIntBatch(UnsafeRowParquetRecordReader.java:769) at org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader$ColumnReader.readBatch(UnsafeRowParquetRecordReader.java:640) at org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader$ColumnReader.access$000(UnsafeRowParquetRecordReader.java:461) at org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader.nextBatch(UnsafeRowParquetRecordReader.java:224) ## How was this patch tested? The unit test `ParquetHadoopFsRelationSuite` based on `HadoopFsRelationTest` will be [failed](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/52110/consoleFull) due to the lack of short type support in UnsafeRowParquetRecordReader. By adding this support, the test can be passed. Author: Liang-Chi Hsieh Closes #11412 from viirya/add-shorttype-support. --- .../parquet/UnsafeRowParquetRecordReader.java | 3 ++ .../parquet/VectorizedRleValuesReader.java | 34 ++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 9d50cfab3bd3d..e7f0ec2e77895 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -765,6 +765,9 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce } else if (DecimalType.is64BitDecimalType(column.dataType())) { defColumn.readIntsAsLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.ShortType) { + defColumn.readShorts( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { throw new NotImplementedException("Unimplemented type: " + column.dataType()); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index b2048c0e397ed..8613fcae0b805 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -301,6 +301,38 @@ public void readBytes(int total, ColumnVector c, } } + public void readShorts(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + for (int i = 0; i < n; i++) { + c.putShort(rowId + i, (short)data.readInteger()); + } + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putShort(rowId + i, (short)data.readInteger()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + public void readLongs(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; @@ -611,4 +643,4 @@ private void readNextGroup() { throw new ParquetDecodingException("not a valid mode " + this.mode); } } -} \ No newline at end of file +} From d780ed8b5c9156ac8595ac5cbb77d83c65fad64c Mon Sep 17 00:00:00 2001 From: Nong Li Date: Sat, 27 Feb 2016 19:45:57 -0800 Subject: [PATCH 2060/2089] [SPARK-13533][SQL] Fix readBytes in VectorizedPlainValuesReader ## What changes were proposed in this pull request? Fix readBytes in VectorizedPlainValuesReader. This fixes a copy and paste issue. ## How was this patch tested? Ran ParquetHadoopFsRelationSuite which failed before this. Author: Nong Li Closes #11414 from nongli/spark-13533. --- .../datasources/parquet/VectorizedPlainValuesReader.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index bf3283e85329b..57cc28e9f4e05 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -85,7 +85,7 @@ public final void readBytes(int total, ColumnVector c, int rowId) { for (int i = 0; i < total; i++) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putInt(rowId + i, buffer[offset]); + c.putByte(rowId + i, buffer[offset]); offset += 4; } } From 4c5e968db23db41c0bea802819ebd75fad63bc2b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Feb 2016 19:47:31 -0800 Subject: [PATCH 2061/2089] Closes #11413 From cca79fad66c4315b0ed6de59fd87700a540e6646 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 27 Feb 2016 19:51:28 -0800 Subject: [PATCH 2062/2089] [SPARK-13526][SQL] Move SQLContext per-session states to new class ## What changes were proposed in this pull request? This creates a `SessionState`, which groups a few fields that existed in `SQLContext`. Because `HiveContext` extends `SQLContext` we also need to make changes there. This is mainly a cleanup task that will soon pave the way for merging the two contexts. ## How was this patch tested? Existing unit tests; this patch introduces no change in behavior. Author: Andrew Or Closes #11405 from andrewor14/refactor-session. --- project/MimaExcludes.scala | 7 +- .../org/apache/spark/sql/SQLContext.scala | 88 +++++--------- .../apache/spark/sql/UDFRegistration.scala | 5 +- .../spark/sql/internal/SessionState.scala | 111 ++++++++++++++++++ .../spark/sql/test/TestSQLContext.scala | 22 ++-- .../execution/HiveCompatibilitySuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 81 ++----------- .../spark/sql/hive/HiveSessionState.scala | 103 ++++++++++++++++ .../apache/spark/sql/hive/test/TestHive.scala | 30 +++-- .../apache/spark/sql/hive/parquetSuites.scala | 7 +- 10 files changed, 294 insertions(+), 164 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 165280a1b2e41..9ce37fc753c46 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -271,12 +271,17 @@ object MimaExcludes { ) ++ Seq( // SPARK-13220 Deprecate yarn-client and yarn-cluster mode ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler"), + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( // SPARK-13465 TaskContext. ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") ) ++ Seq ( // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ Seq( + // SPARK-13526 Move SQLContext per-session states to new class + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.sql.UDFRegistration.this") ) ++ Seq( // [SPARK-13486][SQL] Move SQLConf into an internal package ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1c24d9e4aeb0a..cb4a6397b261b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,13 +25,12 @@ import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql.{execution => sparkexecution} -import org.apache.spark.sql.catalyst.{InternalRow, _} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ @@ -40,9 +39,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.internal.SQLConf.SQLConfEntry import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ @@ -68,7 +66,7 @@ class SQLContext private[sql]( @transient protected[sql] val cacheManager: CacheManager, @transient private[sql] val listener: SQLListener, val isRootContext: Boolean) - extends org.apache.spark.Logging with Serializable { + extends Logging with Serializable { self => @@ -115,9 +113,27 @@ class SQLContext private[sql]( } /** - * @return Spark SQL configuration + * Per-session state, e.g. configuration, functions, temporary tables etc. */ - protected[sql] lazy val conf = new SQLConf + @transient + protected[sql] lazy val sessionState: SessionState = new SessionState(self) + protected[sql] def conf: SQLConf = sessionState.conf + protected[sql] def catalog: Catalog = sessionState.catalog + protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry + protected[sql] def analyzer: Analyzer = sessionState.analyzer + protected[sql] def optimizer: Optimizer = sessionState.optimizer + protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser + protected[sql] def planner: SparkPlanner = sessionState.planner + protected[sql] def continuousQueryManager = sessionState.continuousQueryManager + protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] = + sessionState.prepareForExecution + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + */ + @Experimental + def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** * Set Spark SQL configuration properties. @@ -179,43 +195,11 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - @transient - lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - - protected[sql] lazy val continuousQueryManager = new ContinuousQueryManager(this) - - @transient - protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) - - @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - - @transient - protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, conf) { - override val extendedResolutionRules = - python.ExtractPythonUDFs :: - PreInsertCastAndRename :: - (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) - - override val extendedCheckRules = Seq( - datasources.PreWriteCheck(catalog) - ) - } - - @transient - protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) - - @transient - protected[sql] val sqlParser: ParserInterface = new SparkQl(conf) - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) - protected[sql] def executeSql(sql: String): - org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = - new sparkexecution.QueryExecution(this, plan) + protected[sql] def executePlan(plan: LogicalPlan) = new QueryExecution(this, plan) /** * Add a jar to SQLContext @@ -299,10 +283,8 @@ class SQLContext private[sql]( * * @group basic * @since 1.3.0 - * TODO move to SQLSession? */ - @transient - val udf: UDFRegistration = new UDFRegistration(this) + def udf: UDFRegistration = sessionState.udf /** * Returns true if the table is currently cached in-memory. @@ -872,25 +854,9 @@ class SQLContext private[sql]( }.toArray } - @transient - protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) - @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) - /** - * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal - * row format conversions as needed. - */ - @transient - protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { - val batches = Seq( - Batch("Subquery", Once, PlanSubqueries(self)), - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(self)) - ) - } - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index de01cbcb0e1b3..d894825632a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.spark.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF @@ -34,9 +35,7 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { - - private val functionRegistry = sqlContext.functionRegistry +class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala new file mode 100644 index 0000000000000..f93a405f77fc7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration} +import org.apache.spark.sql.catalyst.ParserInterface +import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog} +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.{PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.util.ExecutionListenerManager + + +/** + * A class that holds all session-specific state in a given [[SQLContext]]. + */ +private[sql] class SessionState(ctx: SQLContext) { + + // Note: These are all lazy vals because they depend on each other (e.g. conf) and we + // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. + + /** + * SQL-specific key-value configurations. + */ + lazy val conf = new SQLConf + + /** + * Internal catalog for managing table and database states. + */ + lazy val catalog: Catalog = new SimpleCatalog(conf) + + /** + * Internal catalog for managing functions registered by the user. + */ + lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() + + /** + * Interface exposed to the user for registering user-defined functions. + */ + lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) + + /** + * Logical query plan analyzer for resolving unresolved attributes and relations. + */ + lazy val analyzer: Analyzer = { + new Analyzer(catalog, functionRegistry, conf) { + override val extendedResolutionRules = + python.ExtractPythonUDFs :: + PreInsertCastAndRename :: + (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) + + override val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog)) + } + } + + /** + * Logical query plan optimizer. + */ + lazy val optimizer: Optimizer = new SparkOptimizer(ctx) + + /** + * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + */ + lazy val sqlParser: ParserInterface = new SparkQl(conf) + + /** + * Planner that converts optimized logical plans to physical plans. + */ + lazy val planner: SparkPlanner = new SparkPlanner(ctx) + + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. + */ + lazy val prepareForExecution = new RuleExecutor[SparkPlan] { + override val batches: Seq[Batch] = Seq( + Batch("Subquery", Once, PlanSubqueries(ctx)), + Batch("Add exchange", Once, EnsureRequirements(ctx)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)) + ) + } + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + */ + lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + + /** + * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. + */ + lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 28ad7ae64a517..b3e146fba80be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SessionState, SQLConf} /** * A special [[SQLContext]] prepared for testing. @@ -31,16 +31,16 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel new SparkConf().set("spark.sql.testkey", "true"))) } - protected[sql] override lazy val conf: SQLConf = new SQLConf { - - clear() - - override def clear(): Unit = { - super.clear() - - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.foreach { - case (key, value) => setConfString(key, value) + @transient + protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } + } } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d15f883138938..0dc2a95eea70e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -55,7 +55,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Use Hive hash expression instead of the native one - TestHive.functionRegistry.unregisterFunction("hash") + TestHive.sessionState.functionRegistry.unregisterFunction("hash") RuleExecutor.resetTime() } @@ -65,7 +65,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - TestHive.functionRegistry.restore() + TestHive.sessionState.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d511dd685ce75..a9295d31c07bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -40,17 +40,15 @@ import org.apache.hadoop.util.VersionInfo import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{InternalRow, ParserInterface} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck, ResolveDataSource} import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SQLConfEntry import org.apache.spark.sql.internal.SQLConf.SQLConfEntry._ import org.apache.spark.sql.types._ @@ -110,6 +108,16 @@ class HiveContext private[hive]( isRootContext = false) } + @transient + protected[sql] override lazy val sessionState = new HiveSessionState(self) + + protected[sql] override def catalog = sessionState.catalog + + // The Hive UDF current_database() is foldable, will be evaluated by optimizer, + // but the optimizer can't access the SessionState of metadataHive. + sessionState.functionRegistry.registerFunction( + "current_database", (e: Seq[Expression]) => new CurrentDatabase(self)) + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive @@ -442,39 +450,6 @@ class HiveContext private[hive]( setConf(entry.key, entry.stringConverter(value)) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ - @transient - override protected[sql] lazy val catalog = - new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog - - // Note that HiveUDFs will be overridden by functions registered in this context. - @transient - override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this.executionHive) - - // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer - // can't access the SessionState of metadataHive. - functionRegistry.registerFunction( - "current_database", - (expressions: Seq[Expression]) => new CurrentDatabase(this)) - - /* An analyzer that uses the Hive metastore. */ - @transient - override protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, conf) { - override val extendedResolutionRules = - catalog.ParquetConversions :: - catalog.CreateTables :: - catalog.PreInsertionCasts :: - python.ExtractPythonUDFs :: - PreInsertCastAndRename :: - (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) - - override val extendedCheckRules = Seq( - PreWriteCheck(catalog) - ) - } - /** Overridden by child classes that need to set configuration before the client init. */ protected def configure(): Map[String, String] = { // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch @@ -544,37 +519,6 @@ class HiveContext private[hive]( c } - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - } - - @transient - protected[sql] override val sqlParser: ParserInterface = new HiveQl(conf) - - @transient - private val hivePlanner = new SparkPlanner(this) with HiveStrategies { - val hiveContext = self - - override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( - DataSourceStrategy, - HiveCommandStrategy(self), - HiveDDLStrategy, - DDLStrategy, - SpecialLimits, - InMemoryScans, - HiveTableScans, - DataSinks, - Scripts, - Aggregation, - LeftSemiJoin, - EquiJoinSelection, - BasicOperators, - BroadcastNestedLoop, - CartesianProduct, - DefaultJoin - ) - } - private def functionOrMacroDDLPattern(command: String) = Pattern.compile( ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL).matcher(command) @@ -590,9 +534,6 @@ class HiveContext private[hive]( } } - @transient - override protected[sql] val planner = hivePlanner - /** Extends QueryExecution with hive specific features. */ protected[sql] class QueryExecution(logicalPlan: LogicalPlan) extends org.apache.spark.sql.execution.QueryExecution(this, logicalPlan) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala new file mode 100644 index 0000000000000..09f54be04d0c7 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.ParserInterface +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, OverrideCatalog} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.{python, SparkPlanner} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.{SessionState, SQLConf} + + +/** + * A class that holds all session-specific state in a given [[HiveContext]]. + */ +private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) { + + override lazy val conf: SQLConf = new SQLConf { + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + } + + /** + * A metadata catalog that points to the Hive metastore. + */ + override lazy val catalog = new HiveMetastoreCatalog(ctx.metadataHive, ctx) with OverrideCatalog + + /** + * Internal catalog for managing functions registered by the user. + * Note that HiveUDFs will be overridden by functions registered in this context. + */ + override lazy val functionRegistry: FunctionRegistry = { + new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), ctx.executionHive) + } + + /** + * An analyzer that uses the Hive metastore. + */ + override lazy val analyzer: Analyzer = { + new Analyzer(catalog, functionRegistry, conf) { + override val extendedResolutionRules = + catalog.ParquetConversions :: + catalog.CreateTables :: + catalog.PreInsertionCasts :: + python.ExtractPythonUDFs :: + PreInsertCastAndRename :: + (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) + + override val extendedCheckRules = Seq(PreWriteCheck(catalog)) + } + } + + /** + * Parser for HiveQl query texts. + */ + override lazy val sqlParser: ParserInterface = new HiveQl(conf) + + /** + * Planner that takes into account Hive-specific strategies. + */ + override lazy val planner: SparkPlanner = { + new SparkPlanner(ctx) with HiveStrategies { + override val hiveContext = ctx + + override def strategies: Seq[Strategy] = { + ctx.experimental.extraStrategies ++ Seq( + DataSourceStrategy, + HiveCommandStrategy(ctx), + HiveDDLStrategy, + DDLStrategy, + SpecialLimits, + InMemoryScans, + HiveTableScans, + DataSinks, + Scripts, + Aggregation, + LeftSemiJoin, + EquiJoinSelection, + BasicOperators, + BroadcastNestedLoop, + CartesianProduct, + DefaultJoin + ) + } + } + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9d0622bea92c7..a7eca46d1980d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -120,18 +120,25 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - - clear() - - override def clear(): Unit = { - super.clear() - - TestHiveContext.overrideConfs.map { - case (key, value) => setConfString(key, value) + @transient + protected[sql] override lazy val sessionState = new HiveSessionState(this) { + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + override def clear(): Unit = { + super.clear() + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } } } + + override lazy val functionRegistry = { + new TestHiveFunctionRegistry( + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), self.executionHive) + } } /** @@ -454,9 +461,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } - @transient - override protected[sql] lazy val functionRegistry = new TestHiveFunctionRegistry( - org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) } private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index a127cf6e4b7d4..d81c5668226e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -425,7 +425,8 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("Caching converted data source Parquet Relations") { - def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { + val _catalog = catalog + def checkCached(tableIdentifier: _catalog.QualifiedTableName): Unit = { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") @@ -452,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = _catalog.QualifiedTableName("default", "test_insert_parquet") // First, make sure the converted test_parquet is not cached. assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) @@ -492,7 +493,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + tableIdentifier = _catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ From 9e01dcc6446f8648e61062f8afe62589b9d4b5ab Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Feb 2016 17:25:07 -0800 Subject: [PATCH 2063/2089] [SPARK-13529][BUILD] Move network/* modules into common/network-* ## What changes were proposed in this pull request? As the title says, this moves the three modules currently in network/ into common/network-*. This removes one top level, non-user-facing folder. ## How was this patch tested? Compilation and existing tests. We should run both SBT and Maven. Author: Reynold Xin Closes #11409 from rxin/SPARK-13529. --- {network/common => common/network-common}/pom.xml | 0 .../java/org/apache/spark/network/TransportContext.java | 0 .../spark/network/buffer/FileSegmentManagedBuffer.java | 0 .../org/apache/spark/network/buffer/LazyFileRegion.java | 0 .../java/org/apache/spark/network/buffer/ManagedBuffer.java | 0 .../org/apache/spark/network/buffer/NettyManagedBuffer.java | 0 .../org/apache/spark/network/buffer/NioManagedBuffer.java | 0 .../spark/network/client/ChunkFetchFailureException.java | 0 .../apache/spark/network/client/ChunkReceivedCallback.java | 0 .../apache/spark/network/client/RpcResponseCallback.java | 0 .../org/apache/spark/network/client/StreamCallback.java | 0 .../org/apache/spark/network/client/StreamInterceptor.java | 0 .../org/apache/spark/network/client/TransportClient.java | 0 .../spark/network/client/TransportClientBootstrap.java | 0 .../apache/spark/network/client/TransportClientFactory.java | 0 .../spark/network/client/TransportResponseHandler.java | 0 .../org/apache/spark/network/protocol/AbstractMessage.java | 0 .../spark/network/protocol/AbstractResponseMessage.java | 0 .../apache/spark/network/protocol/ChunkFetchFailure.java | 0 .../apache/spark/network/protocol/ChunkFetchRequest.java | 0 .../apache/spark/network/protocol/ChunkFetchSuccess.java | 0 .../java/org/apache/spark/network/protocol/Encodable.java | 0 .../java/org/apache/spark/network/protocol/Encoders.java | 0 .../java/org/apache/spark/network/protocol/Message.java | 0 .../org/apache/spark/network/protocol/MessageDecoder.java | 0 .../org/apache/spark/network/protocol/MessageEncoder.java | 0 .../apache/spark/network/protocol/MessageWithHeader.java | 0 .../org/apache/spark/network/protocol/OneWayMessage.java | 0 .../org/apache/spark/network/protocol/RequestMessage.java | 0 .../org/apache/spark/network/protocol/ResponseMessage.java | 0 .../java/org/apache/spark/network/protocol/RpcFailure.java | 0 .../java/org/apache/spark/network/protocol/RpcRequest.java | 0 .../java/org/apache/spark/network/protocol/RpcResponse.java | 0 .../org/apache/spark/network/protocol/StreamChunkId.java | 0 .../org/apache/spark/network/protocol/StreamFailure.java | 0 .../org/apache/spark/network/protocol/StreamRequest.java | 0 .../org/apache/spark/network/protocol/StreamResponse.java | 0 .../org/apache/spark/network/sasl/SaslClientBootstrap.java | 0 .../java/org/apache/spark/network/sasl/SaslEncryption.java | 0 .../apache/spark/network/sasl/SaslEncryptionBackend.java | 0 .../java/org/apache/spark/network/sasl/SaslMessage.java | 0 .../java/org/apache/spark/network/sasl/SaslRpcHandler.java | 0 .../org/apache/spark/network/sasl/SaslServerBootstrap.java | 0 .../java/org/apache/spark/network/sasl/SecretKeyHolder.java | 0 .../java/org/apache/spark/network/sasl/SparkSaslClient.java | 0 .../java/org/apache/spark/network/sasl/SparkSaslServer.java | 0 .../org/apache/spark/network/server/MessageHandler.java | 0 .../org/apache/spark/network/server/NoOpRpcHandler.java | 0 .../apache/spark/network/server/OneForOneStreamManager.java | 0 .../java/org/apache/spark/network/server/RpcHandler.java | 0 .../java/org/apache/spark/network/server/StreamManager.java | 0 .../spark/network/server/TransportChannelHandler.java | 0 .../spark/network/server/TransportRequestHandler.java | 0 .../org/apache/spark/network/server/TransportServer.java | 0 .../spark/network/server/TransportServerBootstrap.java | 0 .../apache/spark/network/util/ByteArrayWritableChannel.java | 0 .../main/java/org/apache/spark/network/util/ByteUnit.java | 0 .../java/org/apache/spark/network/util/ConfigProvider.java | 0 .../src/main/java/org/apache/spark/network/util/IOMode.java | 0 .../main/java/org/apache/spark/network/util/JavaUtils.java | 0 .../org/apache/spark/network/util/LimitedInputStream.java | 0 .../org/apache/spark/network/util/MapConfigProvider.java | 0 .../main/java/org/apache/spark/network/util/NettyUtils.java | 0 .../spark/network/util/SystemPropertyConfigProvider.java | 0 .../java/org/apache/spark/network/util/TransportConf.java | 0 .../apache/spark/network/util/TransportFrameDecoder.java | 0 .../apache/spark/network/ChunkFetchIntegrationSuite.java | 0 .../test/java/org/apache/spark/network/ProtocolSuite.java | 0 .../spark/network/RequestTimeoutIntegrationSuite.java | 0 .../java/org/apache/spark/network/RpcIntegrationSuite.java | 0 .../src/test/java/org/apache/spark/network/StreamSuite.java | 0 .../java/org/apache/spark/network/TestManagedBuffer.java | 0 .../src/test/java/org/apache/spark/network/TestUtils.java | 0 .../apache/spark/network/TransportClientFactorySuite.java | 0 .../apache/spark/network/TransportResponseHandlerSuite.java | 0 .../spark/network/protocol/MessageWithHeaderSuite.java | 0 .../java/org/apache/spark/network/sasl/SparkSaslSuite.java | 0 .../spark/network/server/OneForOneStreamManagerSuite.java | 0 .../spark/network/util/TransportFrameDecoderSuite.java | 0 .../network-common}/src/test/resources/log4j.properties | 0 {network/shuffle => common/network-shuffle}/pom.xml | 0 .../org/apache/spark/network/sasl/ShuffleSecretManager.java | 0 .../apache/spark/network/shuffle/BlockFetchingListener.java | 0 .../spark/network/shuffle/ExternalShuffleBlockHandler.java | 0 .../spark/network/shuffle/ExternalShuffleBlockResolver.java | 0 .../apache/spark/network/shuffle/ExternalShuffleClient.java | 0 .../apache/spark/network/shuffle/OneForOneBlockFetcher.java | 0 .../apache/spark/network/shuffle/RetryingBlockFetcher.java | 0 .../org/apache/spark/network/shuffle/ShuffleClient.java | 0 .../network/shuffle/mesos/MesosExternalShuffleClient.java | 0 .../network/shuffle/protocol/BlockTransferMessage.java | 0 .../spark/network/shuffle/protocol/ExecutorShuffleInfo.java | 0 .../apache/spark/network/shuffle/protocol/OpenBlocks.java | 0 .../spark/network/shuffle/protocol/RegisterExecutor.java | 0 .../apache/spark/network/shuffle/protocol/StreamHandle.java | 0 .../apache/spark/network/shuffle/protocol/UploadBlock.java | 0 .../network/shuffle/protocol/mesos/RegisterDriver.java | 0 .../org/apache/spark/network/sasl/SaslIntegrationSuite.java | 0 .../spark/network/shuffle/BlockTransferMessagesSuite.java | 0 .../network/shuffle/ExternalShuffleBlockHandlerSuite.java | 0 .../network/shuffle/ExternalShuffleBlockResolverSuite.java | 0 .../spark/network/shuffle/ExternalShuffleCleanupSuite.java | 0 .../network/shuffle/ExternalShuffleIntegrationSuite.java | 0 .../spark/network/shuffle/ExternalShuffleSecuritySuite.java | 0 .../spark/network/shuffle/OneForOneBlockFetcherSuite.java | 0 .../spark/network/shuffle/RetryingBlockFetcherSuite.java | 0 .../spark/network/shuffle/TestShuffleDataContext.java | 0 {network/yarn => common/network-yarn}/pom.xml | 0 .../org/apache/spark/network/yarn/YarnShuffleService.java | 0 .../spark/network/yarn/util/HadoopConfigProvider.java | 0 dev/sparktestsupport/modules.py | 4 ++-- docs/job-scheduling.md | 2 +- .../org/apache/spark/launcher/AbstractCommandBuilder.java | 3 ++- make-distribution.sh | 2 +- pom.xml | 6 +++--- 115 files changed, 9 insertions(+), 8 deletions(-) rename {network/common => common/network-common}/pom.xml (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/TransportContext.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/StreamCallback.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/StreamInterceptor.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/TransportClient.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/TransportClientFactory.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/Encodable.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/Encoders.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/Message.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/RequestMessage.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/RpcFailure.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/RpcRequest.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/RpcResponse.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/StreamFailure.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/StreamRequest.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/protocol/StreamResponse.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SaslMessage.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/MessageHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/RpcHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/StreamManager.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/TransportServer.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/ByteUnit.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/ConfigProvider.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/IOMode.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/JavaUtils.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/LimitedInputStream.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/MapConfigProvider.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/NettyUtils.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/TransportConf.java (100%) rename {network/common => common/network-common}/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/ProtocolSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/StreamSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/TestManagedBuffer.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/TestUtils.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java (100%) rename {network/common => common/network-common}/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java (100%) rename {network/common => common/network-common}/src/test/resources/log4j.properties (100%) rename {network/shuffle => common/network-shuffle}/pom.xml (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java (100%) rename {network/shuffle => common/network-shuffle}/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java (100%) rename {network/shuffle => common/network-shuffle}/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java (100%) rename {network/yarn => common/network-yarn}/pom.xml (100%) rename {network/yarn => common/network-yarn}/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java (100%) rename {network/yarn => common/network-yarn}/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java (100%) diff --git a/network/common/pom.xml b/common/network-common/pom.xml similarity index 100% rename from network/common/pom.xml rename to common/network-common/pom.xml diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/TransportContext.java rename to common/network-common/src/main/java/org/apache/spark/network/TransportContext.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java rename to common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClient.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/Message.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java rename to common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/StreamManager.java rename to common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportServer.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/IOMode.java rename to common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java rename to common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java rename to common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java rename to common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/TransportConf.java rename to common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java rename to common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/StreamSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java rename to common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/common/network-common/src/test/java/org/apache/spark/network/TestUtils.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TestUtils.java rename to common/network-common/src/test/java/org/apache/spark/network/TestUtils.java diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java rename to common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java diff --git a/network/common/src/test/resources/log4j.properties b/common/network-common/src/test/resources/log4j.properties similarity index 100% rename from network/common/src/test/resources/log4j.properties rename to common/network-common/src/test/resources/log4j.properties diff --git a/network/shuffle/pom.xml b/common/network-shuffle/pom.xml similarity index 100% rename from network/shuffle/pom.xml rename to common/network-shuffle/pom.xml diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java similarity index 100% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java diff --git a/network/yarn/pom.xml b/common/network-yarn/pom.xml similarity index 100% rename from network/yarn/pom.xml rename to common/network-yarn/pom.xml diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java similarity index 100% rename from network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java rename to common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java similarity index 100% rename from network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java rename to common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 07c3078e45490..4e04672ad39e0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -473,11 +473,11 @@ def __hash__(self): dependencies=[], source_file_regexes=[ "yarn/", - "network/yarn/", + "common/network-yarn/", ], sbt_test_goals=[ "yarn/test", - "network-yarn/test", + "common/network-yarn/test", ], test_tags=[ "org.apache.spark.tags.ExtendedYarnTest" diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 00b6a18836a54..083c020caa5db 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -88,7 +88,7 @@ In YARN mode, start the shuffle service on each `NodeManager` as follows: 1. Build Spark with the [YARN profile](building-spark.html). Skip this step if you are using a pre-packaged distribution. 2. Locate the `spark--yarn-shuffle.jar`. This should be under -`$SPARK_HOME/network/yarn/target/scala-` if you are building Spark yourself, and under +`$SPARK_HOME/common/network-yarn/target/scala-` if you are building Spark yourself, and under `lib` if you are using a distribution. 2. Add this jar to the classpath of all `NodeManager`s in your cluster. 3. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 68af14397ba81..45c2c008f6dc4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,8 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); + "yarn", "launcher", + "common/network-common", "common/network-shuffle", "common/network-yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/make-distribution.sh b/make-distribution.sh index 20998144f0b50..ac90ea317a6fc 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -169,7 +169,7 @@ cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # This will fail if the -Pyarn profile is not provided # In this case, silence the error and ignore the return code of this command -cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : +cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" diff --git a/pom.xml b/pom.xml index bcf04d48c2f6f..2376e307ced18 100644 --- a/pom.xml +++ b/pom.xml @@ -87,13 +87,13 @@ common/sketch + common/network-common + common/network-shuffle tags core graphx mllib tools - network/common - network/shuffle streaming sql/catalyst sql/core @@ -2442,7 +2442,7 @@ yarn yarn - network/yarn + common/network-yarn From 6dfc4a764c8bcfc24d951239835015da3ed7c29e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 28 Feb 2016 21:16:06 -0800 Subject: [PATCH 2064/2089] [SPARK-13537][SQL] Fix readBytes in VectorizedPlainValuesReader JIRA: https://issues.apache.org/jira/browse/SPARK-13537 ## What changes were proposed in this pull request? In readBytes of VectorizedPlainValuesReader, we use buffer[offset] to access bytes in buffer. It is incorrect because offset is added with Platform.BYTE_ARRAY_OFFSET when initialization. We should fix it. ## How was this patch tested? `ParquetHadoopFsRelationSuite` sometimes (depending on the randomly generated data) will be [failed](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/52136/consoleFull) by this bug. After applying this, the test can be passed. I added a test to `ParquetHadoopFsRelationSuite` with the data which will fail without this patch. The error exception: [info] ParquetHadoopFsRelationSuite: [info] - test all data types - StringType (440 milliseconds) [info] - test all data types - BinaryType (434 milliseconds) [info] - test all data types - BooleanType (406 milliseconds) 20:59:38.618 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 2597.0 (TID 67966) java.lang.ArrayIndexOutOfBoundsException: 46 at org.apache.spark.sql.execution.datasources.parquet.VectorizedPlainValuesReader.readBytes(VectorizedPlainValuesReader.java:88) Author: Liang-Chi Hsieh Closes #11418 from viirya/fix-readbytes. --- .../parquet/VectorizedPlainValuesReader.java | 2 +- .../ParquetHadoopFsRelationSuite.scala | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 57cc28e9f4e05..ee9a7a221bbde 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -85,7 +85,7 @@ public final void readBytes(int total, ColumnVector c, int rowId) { for (int i = 0; i < total; i++) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putByte(rowId + i, buffer[offset]); + c.putByte(rowId + i, Platform.getByte(buffer, offset)); offset += 4; } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 5ce58e898ebec..f2501d7ce3590 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -175,4 +175,37 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + + test(s"SPARK-13537: Fix readBytes in VectorizedPlainValuesReader") { + withTempPath { file => + val path = file.getCanonicalPath + + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", ByteType, nullable = true) + + val data = Seq(Row(1, -33.toByte), Row(2, 0.toByte), Row(3, -55.toByte), Row(4, 56.toByte), + Row(5, 127.toByte), Row(6, -44.toByte), Row(7, 23.toByte), Row(8, -95.toByte), + Row(9, 127.toByte), Row(10, 13.toByte)) + + val rdd = sqlContext.sparkContext.parallelize(data) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) + + val loadedDF = sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } + } } From dd3b5455c61bddce96a94c2ce8f5d76ed4948ea1 Mon Sep 17 00:00:00 2001 From: Rahul Tanwani Date: Sun, 28 Feb 2016 23:16:34 -0800 Subject: [PATCH 2065/2089] [SPARK-13309][SQL] Fix type inference issue with CSV data Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309 Author: Rahul Tanwani Closes #11194 from tanwanirahul/master. --- .../datasources/csv/CSVInferSchema.scala | 18 +++++++++--------- sql/core/src/test/resources/simple_sparse.csv | 5 +++++ .../datasources/csv/CSVInferSchemaSuite.scala | 5 +++++ .../execution/datasources/csv/CSVSuite.scala | 14 +++++++++++++- 4 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/resources/simple_sparse.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index ace8cd7ad864e..7f1ed28046b1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ - private[csv] object CSVInferSchema { /** @@ -48,7 +47,11 @@ private[csv] object CSVInferSchema { tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - StructField(thisHeader, rootType, nullable = true) + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } StructType(structFields) @@ -65,12 +68,8 @@ private[csv] object CSVInferSchema { } def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { - first.zipAll(second, NullType, NullType).map { case ((a, b)) => - val tpe = findTightestCommonType(a, b).getOrElse(StringType) - tpe match { - case _: NullType => StringType - case other => other - } + first.zipAll(second, NullType, NullType).map { case (a, b) => + findTightestCommonType(a, b).getOrElse(NullType) } } @@ -140,6 +139,8 @@ private[csv] object CSVInferSchema { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) // Promote numeric types to the highest of the two and all numeric types to unlimited decimal case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => @@ -150,7 +151,6 @@ private[csv] object CSVInferSchema { } } - private[csv] object CSVTypeCast { /** diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv new file mode 100644 index 0000000000000..02d29cabf95f2 --- /dev/null +++ b/sql/core/src/test/resources/simple_sparse.csv @@ -0,0 +1,5 @@ +A,B,C,D +1,,, +,1,, +,,1, +,,,1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index a1796f1326007..412f1b89beee7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) } + + test("Merging Nulltypes should yeild Nulltype.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 7671bc106610a..5d57d77ab0549 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val emptyFile = "empty.csv" private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" + private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema.fieldNames.size === 1) } - test("DDL test with empty file") { sqlContext.sql(s""" |CREATE TEMPORARY TABLE carsTable @@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(carsCopy, withHeader = true) } } + + test("Schema inference correctly identifies the datatype when data is sparse.") { + val df = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(simpleSparseFile)) + + assert( + df.schema.fields.map(field => field.dataType).deep == + Array(IntegerType, IntegerType, IntegerType, IntegerType).deep) + } } From d81a71357e24160244b6eeff028b0d9a4863becf Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 29 Feb 2016 00:55:51 -0800 Subject: [PATCH 2066/2089] [SPARK-13545][MLLIB][PYSPARK] Make MLlib LogisticRegressionWithLBFGS's default parameters consistent in Scala and Python ## What changes were proposed in this pull request? * The default value of ```regParam``` of PySpark MLlib ```LogisticRegressionWithLBFGS``` should be consistent with Scala which is ```0.0```. (This is also consistent with ML ```LogisticRegression```.) * BTW, if we use a known updater(L1 or L2) for binary classification, ```LogisticRegressionWithLBFGS``` will call the ML implementation. We should update the API doc to clarifying ```numCorrections``` will have no effect if we fall into that route. * Make a pass for all parameters of ```LogisticRegressionWithLBFGS```, others are set properly. cc mengxr dbtsai ## How was this patch tested? No new tests, it should pass all current tests. Author: Yanbo Liang Closes #11424 from yanboliang/spark-13545. --- .../spark/mllib/classification/LogisticRegression.scala | 4 ++++ python/pyspark/mllib/classification.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index c3882606d7dbd..f807b5683c390 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -408,6 +408,10 @@ class LogisticRegressionWithLBFGS * defaults to the mllib implementation. If more than two classes * or feature scaling is disabled, always uses mllib implementation. * Uses user provided weights. + * + * In the ml LogisticRegression implementation, the number of corrections + * used in the LBFGS update can not be configured. So `optimizer.setNumCorrections()` + * will have no effect if we fall into that route. */ override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { run(input, initialWeights, userSuppliedWeights = true) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index b4d54ef61b0e6..53a0df27cace2 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -326,7 +326,7 @@ class LogisticRegressionWithLBFGS(object): """ @classmethod @since('1.2.0') - def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", + def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2", intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -341,7 +341,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: None) :param regParam: The regularizer parameter. - (default: 0.01) + (default: 0.0) :param regType: The type of regularizer used for training our model. Allowed values: @@ -356,7 +356,9 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: False) :param corrections: The number of corrections used in the LBFGS update. - (default: 10) + If a known updater is used for binary classification, + it calls the ml implementation and this parameter will + have no effect. (default: 10) :param tolerance: The convergence tolerance of iterations for L-BFGS. (default: 1e-6) From 99fe8993f51d3c72cd95eb0825b090dd4d4cd2cd Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 29 Feb 2016 12:08:37 +0000 Subject: [PATCH 2067/2089] =?UTF-8?q?[SPARK-12994][CORE]=20It=20is=20not?= =?UTF-8?q?=20necessary=20to=20create=20ExecutorAllocationM=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …anager in local mode Author: Jeff Zhang Closes #10914 from zjffdu/SPARK-12994. --- .../scala/org/apache/spark/SparkContext.scala | 6 +----- .../scala/org/apache/spark/util/Utils.scala | 19 +++++++++++++++++-- .../org/apache/spark/util/UtilsSuite.scala | 3 +++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a1fa266e183e1..0e8b735b923bb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -244,7 +244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec - def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + def isLocal: Boolean = Utils.isLocalMaster(_conf) /** * @return true if context is stopped or in the midst of stopping. @@ -526,10 +526,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) - if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") - } - _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e0c9bf02a1a20..6103a10ccc50e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2195,6 +2195,16 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + + /** + * + * @return whether it is local mode + */ + def isLocalMaster(conf: SparkConf): Boolean = { + val master = conf.get("spark.master", "") + master == "local" || master.startsWith("local[") + } + /** * Return whether dynamic allocation is enabled in the given conf * Dynamic allocation and explicitly setting the number of executors are inherently @@ -2202,8 +2212,13 @@ private[spark] object Utils extends Logging { * the latter should override the former (SPARK-9092). */ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - conf.getBoolean("spark.dynamicAllocation.enabled", false) && - conf.getInt("spark.executor.instances", 0) == 0 + val numExecutor = conf.getInt("spark.executor.instances", 0) + val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) + if (numExecutor != 0 && dynamicAllocationEnabled) { + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + numExecutor == 0 && dynamicAllocationEnabled && + (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) } def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 7c6778b065467..412c0ac9d9be3 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -722,6 +722,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("isDynamicAllocationEnabled") { val conf = new SparkConf() + conf.set("spark.master", "yarn-client") assert(Utils.isDynamicAllocationEnabled(conf) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "false")) === false) @@ -731,6 +732,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "1")) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.executor.instances", "0")) === true) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true"))) } test("encodeFileNameToURIRawPath") { From 236e3c8fbc887e4da4f143cbf533f016f21c10d4 Mon Sep 17 00:00:00 2001 From: vijaykiran Date: Mon, 29 Feb 2016 15:52:41 +0200 Subject: [PATCH 2068/2089] [SPARK-12633][PYSPARK] [DOC] PySpark regression parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the regression module. Also, updated 2 params in classification to read as `Supported values:` to be consistent. closes #10600 Author: vijaykiran Author: Bryan Cutler Closes #11404 from BryanCutler/param-desc-consistent-regression-SPARK-12633. --- python/pyspark/mllib/classification.py | 4 +- python/pyspark/mllib/regression.py | 326 +++++++++++++------------ 2 files changed, 166 insertions(+), 164 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 53a0df27cace2..57106f8690a7d 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -294,7 +294,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: 0.01) :param regType: The type of regularizer used for training our model. - Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) @@ -344,7 +344,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType= (default: 0.0) :param regType: The type of regularizer used for training our model. - Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 4dd7083d79c8c..3b77a6200054f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -37,10 +37,11 @@ class LabeledPoint(object): """ Class that represents the features and labels of a data point. - :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + :param label: + Label for this data point. + :param features: + Vector of features for this point (NumPy array, list, + pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). Note: 'label' and 'features' are accessible as class attributes. @@ -66,8 +67,10 @@ class LinearModel(object): """ A linear model that has a vector of coefficients and an intercept. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. .. versionadded:: 0.9.0 """ @@ -217,19 +220,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): """ - Train a linear regression model with no regularization using Stochastic Gradient Descent. - This solves the least squares regression formulation - - f(weights) = 1/n ||A weights-y||^2 - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @@ -237,47 +229,52 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient - Descent (SGD). - This solves the least squares regression formulation - - f(weights) = 1/(2n) ||A weights - y||^2, - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.0). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization (lasso), - - "l2" for using L2 regularization (ridge), - - None for no regularization - - (default: None) - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Descent (SGD). This solves the least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + + which is the mean squared error. Here the data matrix has n rows, + and the input RDD holds the set of rows of A, each with its + corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.0) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization (default) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -368,56 +365,53 @@ def load(cls, sc, path): class LassoWithSGD(object): """ - Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the L1-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L1-regularization using - Stochastic Gradient Descent. - This solves the l1-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L1-regularization using Stochastic + Gradient Descent. This solves the l1-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -508,56 +502,53 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ - Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the L2-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L2-regularization using - Stochastic Gradient Descent. - This solves the l2-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L2-regularization using Stochastic + Gradient Descent. This solves the l2-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), @@ -572,12 +563,14 @@ class IsotonicRegressionModel(Saveable, Loader): """ Regression model for isotonic regression. - :param boundaries: Array of boundaries for which predictions are - known. Boundaries must be sorted in increasing order. - :param predictions: Array of predictions associated to the - boundaries at the same index. Results of isotonic - regression and therefore monotone. - :param isotonic: indicates whether this is isotonic or antitonic. + :param boundaries: + Array of boundaries for which predictions are known. Boundaries + must be sorted in increasing order. + :param predictions: + Array of predictions associated to the boundaries at the same + index. Results of isotonic regression and therefore monotone. + :param isotonic: + Indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -628,7 +621,8 @@ def predict(self, x): values with the same boundary then the same rules as in 2) are used. - :param x: Feature or RDD of Features to be labeled. + :param x: + Feature or RDD of Features to be labeled. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -657,8 +651,8 @@ def load(cls, sc, path): class IsotonicRegression(object): """ Isotonic regression. - Currently implemented using parallelized pool adjacent violators algorithm. - Only univariate (single feature) algorithm supported. + Currently implemented using parallelized pool adjacent violators + algorithm. Only univariate (single feature) algorithm supported. Sequential PAV implementation based on: @@ -684,8 +678,11 @@ def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. + :param data: + RDD of (label, feature, weight) tuples. + :param isotonic: + Whether this is isotonic (which is default) or antitonic. + (default: True) """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) @@ -721,9 +718,11 @@ def _validate(self, dstream): @since("1.5.0") def predictOn(self, dstream): """ - Make predictions on a dstream. + Use the model to make predictions on batches of data from a + DStream. - :return: Transformed dstream object. + :return: + DStream containing predictions. """ self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) @@ -731,9 +730,11 @@ def predictOn(self, dstream): @since("1.5.0") def predictOnValues(self, dstream): """ - Make predictions on a keyed dstream. + Use the model to make predictions on the values of a DStream and + carry over its keys. - :return: Transformed dstream object. + :return: + DStream containing the input keys and the predictions as values. """ self._validate(dstream) return dstream.mapValues(lambda x: self._model.predict(x)) @@ -742,14 +743,15 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Train or predict a linear regression model on streaming data. Training uses - Stochastic Gradient Descent to update the model based on each new batch of - incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + Train or predict a linear regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model + based on each new batch of incoming data from a DStream + (see `LinearRegressionWithSGD` for model equation). Each batch of data is assumed to be an RDD of LabeledPoints. The number of data points per batch can vary, but the number - of features must be constant. An initial weight - vector must be provided. + of features must be constant. An initial weight vector must + be provided. :param stepSize: Step size for each iteration of gradient descent. From 2f91f5ac0d2a5932169b245e3ef3e19849131277 Mon Sep 17 00:00:00 2001 From: zhuol Date: Mon, 29 Feb 2016 08:37:33 -0600 Subject: [PATCH 2069/2089] [SPARK-13481] Desc order of appID by default for history server page. ## What changes were proposed in this pull request? Now by default, it shows as ascending order of appId. We might prefer to display as descending order by default, which will show the latest application at the top. ## How was this patch tested? Manual tested. See screenshot below: ![desc-sort](https://cloud.githubusercontent.com/assets/11683054/13307473/102f4cf8-db31-11e5-8dd5-391edbf32f0d.png) Author: zhuol Closes #11357 from zhuoliu/13481. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 6195916195e38..167c8020850d5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -149,7 +149,8 @@ $(document).ready(function() { {name: 'seventh'}, {name: 'eighth'}, ], - "autoWidth": false + "autoWidth": false, + "order": [[ 0, "desc" ]] }; var rowGroupConf = { From ac5c635281aa796547e31e20c10d2483469294ee Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 29 Feb 2016 14:51:27 +0000 Subject: [PATCH 2070/2089] [SPARK-13506][MLLIB] Fix the wrong parameter in R code comment in AssociationRulesSuite JIRA: https://issues.apache.org/jira/browse/SPARK-13506 ## What changes were proposed in this pull request? just chang R Snippet Comment in AssociationRulesSuite ## How was this patch tested? unit test passsed Author: Zheng RuiFeng Closes #11387 from zhengruifeng/ars. --- .../org/apache/spark/mllib/fpm/AssociationRulesSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala index 77a2773c36f56..dcb1f398b04b8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -42,6 +42,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { .collect() /* Verify results using the `R` code: + library(arules) transactions = as(sapply( list("r z h k p", "z y x w v u t s", @@ -52,7 +53,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { FUN=function(x) strsplit(x," ",fixed=TRUE)), "transactions") ars = apriori(transactions, - parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2)) arsDF = as(ars, "data.frame") arsDF$support = arsDF$support * length(transactions) names(arsDF)[names(arsDF) == "support"] = "freq" From 916fc34f98dd731f607d9b3ed657bad6cc30df2c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Mar 2016 01:07:45 +0800 Subject: [PATCH 2071/2089] [SPARK-13540][SQL] Supports using nested classes within Scala objects as Dataset element type ## What changes were proposed in this pull request? Nested classes defined within Scala objects are translated into Java static nested classes. Unlike inner classes, they don't need outer scopes. But the analyzer still thinks that an outer scope is required. This PR fixes this issue simply by checking whether a nested class is static before looking up its outer scope. ## How was this patch tested? A test case is added to `DatasetSuite`. It checks contents of a Dataset whose element type is a nested class declared in a Scala object. Author: Cheng Lian Closes #11421 from liancheng/spark-13540-object-as-outer-scope. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++++++- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 23e4709bbd882..876aa0eae0e90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -559,7 +561,13 @@ class Analyzer( } resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + case n: NewInstance + // If this is an inner class of another class, register the outer object in `OuterScopes`. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + if n.outerPointer.isEmpty && + n.cls.isMemberClass && + !Modifier.isStatic(n.cls.getModifiers) => val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 14fc37b64aa36..33df6375e3aad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -621,12 +621,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds.filter(_ => true), Some(1), Some(2), Some(3)) } + + test("SPARK-13540 Dataset of nested class defined in Scala object") { + checkAnswer( + Seq(OuterObject.InnerClass("foo")).toDS(), + OuterObject.InnerClass("foo")) + } } class OuterClass extends Serializable { case class InnerClass(a: String) } +object OuterObject { + case class InnerClass(a: String) +} + case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) From 02aa499dfb71bc9571bebb79e6383842e4f48143 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 29 Feb 2016 09:44:29 -0800 Subject: [PATCH 2072/2089] [SPARK-13509][SPARK-13507][SQL] Support for writing CSV with a single function call https://issues.apache.org/jira/browse/SPARK-13507 https://issues.apache.org/jira/browse/SPARK-13509 ## What changes were proposed in this pull request? This PR adds the support to write CSV data directly by a single call to the given path. Several unitests were added for each functionality. ## How was this patch tested? This was tested with unittests and with `dev/run_tests` for coding style Author: hyukjinkwon Author: Hyukjin Kwon Closes #11389 from HyukjinKwon/SPARK-13507-13509. --- python/pyspark/sql/readwriter.py | 50 +++++++++++++++++++ python/test_support/sql/ages.csv | 4 ++ .../apache/spark/sql/DataFrameWriter.scala | 23 +++++++++ .../datasources/json/JSONOptions.scala | 5 +- .../datasources/text/DefaultSource.scala | 5 +- .../execution/datasources/csv/CSVSuite.scala | 3 +- 6 files changed, 80 insertions(+), 10 deletions(-) create mode 100644 python/test_support/sql/ages.csv diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b1453c637f79e..7f5368d8bdbb2 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -233,6 +233,23 @@ def text(self, paths): paths = [paths] return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(2.0) + def csv(self, paths): + """Loads a CSV file and returns the result as a [[DataFrame]]. + + This function goes through the input once to determine the input schema. To avoid going + through the entire data once, specify the schema explicitly using [[schema]]. + + :param paths: string, or list of strings, for input path(s). + + >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + >>> df.dtypes + [('C0', 'string'), ('C1', 'string')] + """ + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(1.5) def orc(self, path): """Loads an ORC file, returning the result as a :class:`DataFrame`. @@ -448,6 +465,11 @@ def json(self, path, mode=None): * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + You can set the following JSON-specific option(s) for writing JSON files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode)._jwrite.json(path) @@ -476,11 +498,39 @@ def parquet(self, path, mode=None, partitionBy=None): def text(self, path): """Saves the content of the DataFrame in a text file at the specified path. + :param path: the path in any Hadoop supported file system + The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. + + You can set the following option(s) for writing text files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). """ self._jwrite.text(path) + @since(2.0) + def csv(self, path, mode=None): + """Saves the content of the [[DataFrame]] in CSV format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + You can set the following CSV-specific option(s) for writing CSV files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + + >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode)._jwrite.csv(path) + @since(1.5) def orc(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. diff --git a/python/test_support/sql/ages.csv b/python/test_support/sql/ages.csv new file mode 100644 index 0000000000000..18991feda788a --- /dev/null +++ b/python/test_support/sql/ages.csv @@ -0,0 +1,4 @@ +Joe,20 +Tom,30 +Hyukjin,25 + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d6bdd3d825565..093504c765ee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -453,6 +453,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("json").save(path) * }}} * + * You can set the following JSON-specific option(s) for writing JSON files: + *
      • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`).
      • + * * @since 1.4.0 */ def json(path: String): Unit = format("json").save(path) @@ -492,10 +496,29 @@ final class DataFrameWriter private[sql](df: DataFrame) { * df.write().text("/path/to/output") * }}} * + * You can set the following option(s) for writing text files: + *
      • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`).
      • + * * @since 1.6.0 */ def text(path: String): Unit = format("text").save(path) + /** + * Saves the content of the [[DataFrame]] in CSV format at the specified path. + * This is equivalent to: + * {{{ + * format("csv").save(path) + * }}} + * + * You can set the following CSV-specific option(s) for writing CSV files: + *
      • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`).
      • + * + * @since 2.0.0 + */ + def csv(path: String): Unit = format("csv").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 31a95ed461215..e59dbd6b3d438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -48,10 +48,7 @@ private[sql] class JSONOptions( parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 60155b32349a7..8f3f6335e4282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -115,10 +115,7 @@ private[sql] class TextRelation( /** Write path. */ override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = job.getConfiguration - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5d57d77ab0549..3ecbb14f2ea65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -268,9 +268,8 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)) cars.coalesce(1).write - .format("csv") .option("header", "true") - .save(csvDir) + .csv(csvDir) val carsCopy = sqlContext.read .format("csv") From bc65f60ef7c920db756bfe643f7edbdf3593a989 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 29 Feb 2016 10:10:04 -0800 Subject: [PATCH 2073/2089] [SPARK-13544][SQL] Rewrite/Propagate Constraints for Aliases in Aggregate #### What changes were proposed in this pull request? After analysis by Analyzer, two operators could have alias. They are `Project` and `Aggregate`. So far, we only rewrite and propagate constraints if `Alias` is defined in `Project`. This PR is to resolve this issue in `Aggregate`. #### How was this patch tested? Added a test case for `Aggregate` in `ConstraintPropagationSuite`. marmbrus sameeragarwal Author: gatorsmile Closes #11422 from gatorsmile/validConstraintsInUnaryNodes. --- .../catalyst/plans/logical/LogicalPlan.scala | 16 ++++++++++ .../plans/logical/basicOperators.scala | 30 +++++-------------- .../plans/ConstraintPropagationSuite.scala | 15 ++++++++++ 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8095083f336e1..31e775d60f950 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + override protected def validConstraints: Set[Expression] = child.constraints override def statistics: Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5d2a65b716b24..e81a0f9487469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias - */ - private def getAliasedConstraints: Set[Expression] = { - projectList.flatMap { - case a @ Alias(e, _) => - child.constraints.map(_ transform { - case expr: Expression if expr.semanticEquals(e) => - a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet - } - - override def validConstraints: Set[Expression] = { - child.constraints.union(getAliasedConstraints) - } + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(projectList)) } /** @@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = child.constraints.union(splitConjunctivePredicates(condition).toSet) - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - } // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. @@ -442,6 +423,9 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(aggregateExpressions)) + override def statistics: Statistics = { if (groupingExpressions.isEmpty) { Statistics(sizeInBytes = 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 373b1ffa83d23..b68432b1a128f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aggregate") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "c1") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), + resolveColumn(aliasedRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int) From 17a253cbf4712dbeab06c454b5142917a1bba78b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 29 Feb 2016 11:02:45 -0800 Subject: [PATCH 2074/2089] [SPARK-13522][CORE] Executor should kill itself when it's unable to heartbeat to driver more than N times ## What changes were proposed in this pull request? Sometimes, network disconnection event won't be triggered for other potential race conditions that we may not have thought of, then the executor will keep sending heartbeats to driver and won't exit. This PR adds a new configuration `spark.executor.heartbeat.maxFailures` to kill Executor when it's unable to heartbeat to the driver more than `spark.executor.heartbeat.maxFailures` times. ## How was this patch tested? unit tests Author: Shixiong Zhu Closes #11401 from zsxwing/SPARK-13522. --- .../org/apache/spark/executor/Executor.scala | 22 ++++++++++++++++++- .../spark/executor/ExecutorExitCode.scala | 8 +++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a602fcac68a6b..86c121f787461 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -114,6 +114,19 @@ private[spark] class Executor( private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. It means we will retry to send + * heartbeats about 10 minutes because the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) + + /** + * Count the failure times of heartbeat. It should only be acessed in the heartbeat thread. Each + * successful heartbeat will reset it to 0. + */ + private var heartbeatFailures = 0 + startDriverHeartbeater() def launchTask( @@ -461,8 +474,15 @@ private[spark] class Executor( logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } + heartbeatFailures = 0 } catch { - case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + logError(s"Unable to send heartbeats to driver more than $HEARTBEAT_MAX_FAILURES times") + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ea36fb60bd540..99858f785600d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -39,6 +39,12 @@ object ExecutorExitCode { /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 + /** + * Executor is unable to send heartbeats to the driver more than + * "spark.executor.heartbeat.maxFailures" times. + */ + val HEARTBEAT_FAILURE = 56 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -51,6 +57,8 @@ object ExecutorExitCode { // TODO: replace external block store with concrete implementation name case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => "ExternalBlockStore failed to create a local temporary directory." + case HEARTBEAT_FAILURE => + "Unable to send heartbeats to driver." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { From 644dbb641afd337ca39733da5153239cf39cdd81 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 29 Feb 2016 11:52:11 -0800 Subject: [PATCH 2075/2089] [SPARK-13522][CORE] Fix the exit log place for heartbeat ## What changes were proposed in this pull request? Just fixed the log place introduced by #11401 ## How was this patch tested? unit tests. Author: Shixiong Zhu Closes #11432 from zsxwing/SPARK-13522-follow-up. --- core/src/main/scala/org/apache/spark/executor/Executor.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 86c121f787461..a959f200d4cc2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -478,9 +478,10 @@ private[spark] class Executor( } catch { case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) - logError(s"Unable to send heartbeats to driver more than $HEARTBEAT_MAX_FAILURES times") heartbeatFailures += 1 if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) } } From 4bd697da03079c26fd4409dc128dbff28c737701 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 29 Feb 2016 12:59:46 -0800 Subject: [PATCH 2076/2089] [SPARK-13123][SQL] Implement whole state codegen for sort ## What changes were proposed in this pull request? This PR adds support for implementing whole state codegen for sort. Builds heaving on nongli 's PR: https://github.com/apache/spark/pull/11008 (which actually implements the feature), and adds the following changes on top: - [x] Generated code updates peak execution memory metrics - [x] Unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite` ## How was this patch tested? New unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`. Further, all existing sort tests should pass. Author: Sameer Agarwal Author: Nong Li Closes #11359 from sameeragarwal/sort-codegen. --- .../execution/UnsafeExternalRowSorter.java | 9 +- .../org/apache/spark/sql/execution/Sort.scala | 124 ++++++++++++++---- .../sql/execution/WholeStageCodegen.scala | 8 +- .../execution/WholeStageCodegenSuite.scala | 9 ++ .../execution/metric/SQLMetricsSuite.scala | 7 + 5 files changed, 122 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 27ae62f1212f6..0ad0f4976c77a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; -final class UnsafeExternalRowSorter { +public final class UnsafeExternalRowSorter { /** * If positive, forces records to be spilled to disk at the given frequency (measured in numbers @@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) { testSpillFrequency = frequency; } - @VisibleForTesting - void insertRow(UnsafeRow row) throws IOException { + public void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( row.getBaseObject(), @@ -110,8 +109,7 @@ private void cleanupResources() { sorter.cleanupResources(); } - @VisibleForTesting - Iterator sort() throws IOException { + public Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -160,7 +158,6 @@ public UnsafeRow next() { } } - public Iterator sort(Iterator inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 75cb6d1137c35..2ea889ea72c75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -37,7 +39,7 @@ case class Sort( global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) - extends UnaryNode { + extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -50,34 +52,36 @@ case class Sort( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - protected override def doExecute(): RDD[InternalRow] = { - val schema = child.schema - val childOutput = child.output + def createSorter(): UnsafeExternalRowSorter = { + val ordering = newOrdering(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter + } + + protected override def doExecute(): RDD[InternalRow] = { val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") child.execute().mapPartitionsInternal { iter => - val ordering = newOrdering(sortOrder, childOutput) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - // The generator for prefix - val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } + val sorter = createSorter() val metrics = TaskContext.get().taskMetrics() // Remember spill data size of this task before execute this operator so that we can @@ -93,4 +97,74 @@ case class Sort( sortedIterator } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + // Name of sorter variable used in codegen. + private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = ctx.freshName("needToSort") + ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + + + // Initialize the class member variables. This includes the instance of the Sorter and + // the iterator to return sorted rows. + val thisPlan = ctx.addReferenceObj("plan", this) + sorterVariable = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, + s"$sorterVariable = $thisPlan.createSorter();") + val metrics = ctx.freshName("metrics") + ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") + val sortedIterator = ctx.freshName("sortedIter") + ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + + val addToSorter = ctx.freshName("addToSorter") + ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + val outputRow = ctx.freshName("outputRow") + val dataSize = metricTerm(ctx, "dataSize") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + s""" + | if ($needToSort) { + | $addToSorter(); + | Long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $sortedIterator = $sorterVariable.sort(); + | $dataSize.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val colExprs = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs) + + s""" + | // Convert the input attributes to an UnsafeRow and add it to the sorter + | ${code.code} + | $sorterVariable.insertRow(${code.value}); + """.stripMargin.trim + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index afaddcf35775c..cb68ca6ada366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) ${code.trim} } } - """ + """.trim // try to compile, helpful for debug val cleanedSource = CodeFormatter.stripExtraNewLines(source) @@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) // There is an UnsafeRow already s""" |append($row.copy()); - """.stripMargin + """.stripMargin.trim } else { assert(input != null) if (input.nonEmpty) { @@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) s""" |${code.code.trim} |append(${code.value}.copy()); - """.stripMargin + """.stripMargin.trim } else { // There is no columns s""" |append(unsafeRow); - """.stripMargin + """.stripMargin.trim } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9350205d791d7..de371d85d9fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } + + test("Sort should be included in WholeStageCodegen") { + val df = sqlContext.range(3, 0, -1).sort(col("id")) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index c49f2439fce49..5b4f6f1d2461b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("Sort metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) + val df = sqlContext.range(10).sort('id) + testSparkPlanMetrics(df, 2, Map.empty) + } + test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. From c7fccb56cd9260b8d72572e65f8e46a14707b9a5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Feb 2016 13:01:27 -0800 Subject: [PATCH 2077/2089] [SPARK-13478][YARN] Use real user when fetching delegation tokens. The Hive client library is not smart enough to notice that the current user is a proxy user; so when using a proxy user, it fails to fetch delegation tokens from the metastore because of a missing kerberos TGT for the current user. To fix it, just run the code that fetches the delegation token as the real logged in user. Tested on a kerberos cluster both submitting normally and with a proxy user; Hive and HBase tokens are retrieved correctly in both cases. Author: Marcelo Vanzin Closes #11358 from vanzin/SPARK-13478. --- .../spark/deploy/SparkSubmitArguments.scala | 5 ++ .../deploy/yarn/YarnSparkHadoopUtil.scala | 46 ++++++++++++++----- .../yarn/YarnSparkHadoopUtilSuite.scala | 2 +- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 915ef81b4eae3..175756b80b6bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -255,6 +255,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } + + if (proxyUser != null && principal != null) { + SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + } } private def validateKillArguments(): Unit = { @@ -517,6 +521,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | | --proxy-user NAME User to impersonate when submitting the application. + | This argument does not work with --principal / --keytab. | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4c9432dbd6ab0..aef78fdfd4c57 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,7 +18,9 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.lang.reflect.UndeclaredThrowableException import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction import java.util.regex.Matcher import java.util.regex.Pattern @@ -194,7 +196,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { */ def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = { try { - obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName) + obtainTokenForHiveMetastoreInner(conf) } catch { case e: ClassNotFoundException => logInfo(s"Hive class not found $e") @@ -209,8 +211,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { * @param username the username of the principal requesting the delegating token. * @return a delegation token */ - private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration, - username: String): Option[Token[DelegationTokenIdentifier]] = { + private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration): + Option[Token[DelegationTokenIdentifier]] = { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down @@ -225,11 +227,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // Check for local metastore if (metastoreUri.nonEmpty) { - require(username.nonEmpty, "Username undefined") val principalKey = "hive.metastore.kerberos.principal" val principal = hiveConf.getTrimmed(principalKey, "") require(principal.nonEmpty, "Hive principal $principalKey undefined") - logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri") + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") val closeCurrent = hiveClass.getMethod("closeCurrent") try { @@ -238,12 +241,14 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { classOf[String], classOf[String]) val getHive = hiveClass.getMethod("get", hiveConfClass) - // invoke - val hive = getHive.invoke(null, hiveConf) - val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String] - val hive2Token = new Token[DelegationTokenIdentifier]() - hive2Token.decodeFromUrlString(tokenStr) - Some(hive2Token) + doAsRealUser { + val hive = getHive.invoke(null, hiveConf) + val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) + .asInstanceOf[String] + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + Some(hive2Token) + } } finally { Utils.tryLogNonFatalError { closeCurrent.invoke(null) @@ -303,6 +308,25 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } } + /** + * Run some code as the real logged in user (which may differ from the current user, for + * example, when using proxying). + */ + private def doAsRealUser[T](fn: => T): T = { + val currentUser = UserGroupInformation.getCurrentUser() + val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) + + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. + try { + realUser.doAs(new PrivilegedExceptionAction[T]() { + override def run(): T = fn + }) + } catch { + case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e) + } + } + } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index d3acaf229cc85..9202bd892f01b 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -255,7 +255,7 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging hadoopConf.set("hive.metastore.uris", "http://localhost:0") val util = new YarnSparkHadoopUtil assertNestedHiveException(intercept[InvocationTargetException] { - util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") + util.obtainTokenForHiveMetastoreInner(hadoopConf) }) assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf) From 0a4b620f3144d68232eb7914ae05563aab648ced Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 29 Feb 2016 22:24:43 -0800 Subject: [PATCH 2078/2089] [SPARK-13551][MLLIB] Fix wrong comment and remove meanless lines in mllib.JavaBisectingKMeansExample JIRA: https://issues.apache.org/jira/browse/SPARK-13551 ## What changes were proposed in this pull request? Fix wrong comment and remove meanless lines in mllib.JavaBisectingKMeansExample ## How was this patch tested? manual test Author: Zheng RuiFeng Closes #11429 from zhengruifeng/mllib_bkm_je. --- .../spark/examples/mllib/JavaBisectingKMeansExample.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java index 0001500f4fa5a..c600094947d5a 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -33,7 +33,7 @@ // $example off$ /** - * Java example for graph clustering using power iteration clustering (PIC). + * Java example for bisecting k-means clustering. */ public class JavaBisectingKMeansExample { public static void main(String[] args) { @@ -54,9 +54,7 @@ public static void main(String[] args) { BisectingKMeansModel model = bkm.run(data); System.out.println("Compute Cost: " + model.computeCost(data)); - for (Vector center: model.clusterCenters()) { - System.out.println(""); - } + Vector[] clusterCenters = model.clusterCenters(); for (int i = 0; i < clusterCenters.length; i++) { Vector clusterCenter = clusterCenters[i]; From 3c5f5e3b5c4eb69472fdd8124aa9988bd8d933b5 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 29 Feb 2016 23:55:26 -0800 Subject: [PATCH 2079/2089] [SPARK-13550][ML] Add java example for ml.clustering.BisectingKMeans JIRA: https://issues.apache.org/jira/browse/SPARK-13550 ## What changes were proposed in this pull request? Just add a java example for ml.clustering.BisectingKMeans ## How was this patch tested? manual tests were done. (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: Zheng RuiFeng Closes #11428 from zhengruifeng/ml_bkm_je. --- .../ml/JavaBisectingKMeansExample.java | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java new file mode 100644 index 0000000000000..e124c1cf18550 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +// $example on$ +import org.apache.spark.ml.clustering.BisectingKMeans; +import org.apache.spark.ml.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + + +/** + * An example demonstrating a bisecting k-means clustering. + */ +public class JavaBisectingKMeansExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)), + RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)), + RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)), + RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)), + RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)), + RowFactory.create(Vectors.dense(18.9, 20.0, 19.7)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + BisectingKMeans bkm = new BisectingKMeans().setK(2); + BisectingKMeansModel model = bkm.fit(dataset); + + System.out.println("Compute Cost: " + model.computeCost(dataset)); + + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + jsc.stop(); + } +} From 12a2a57e1af21da0aa4275971365d76a8fc84a43 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 1 Mar 2016 14:37:36 +0000 Subject: [PATCH 2080/2089] [SPARK-13592][WINDOWS] fix path of spark-submit2.cmd in spark-submit.cmd ## What changes were proposed in this pull request? This patch fixes the problem that pyspark fails on Windows because pyspark can't find ```spark-submit2.cmd```. ## How was this patch tested? manual tests: I ran ```bin\pyspark.cmd``` and checked if pyspark is launched correctly after this patch is applyed. Author: Masayoshi TSUZUKI Closes #11442 from tsudukim/feature/SPARK-13592. --- bin/spark-submit.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f121b62a53d24..f301606933a95 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C spark-submit2.cmd %* +cmd /V /E /C "%~dp0spark-submit2.cmd" %* From c43899a04e4de18e238a1761bf4fe9f54e182320 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Mar 2016 08:43:02 -0800 Subject: [PATCH 2081/2089] [SPARK-13511] [SQL] Add wholestage codegen for limit JIRA: https://issues.apache.org/jira/browse/SPARK-13511 ## What changes were proposed in this pull request? Current limit operator doesn't support wholestage codegen. This is open to add support for it. In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to count the processed rows. Once the row numbers catches the limit number, we set the variable `stopEarly` of `BufferedRowIterator` newly added in this pr to `true` that indicates we want to stop processing remaining rows. Then when the wholestage codegen framework checks `shouldStop()`, it will stop the processing of the row iterator. Before this, the executed plan for a query `sqlContext.range(N).limit(100).groupBy().sum()` is: TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L]) +- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L]) +- GlobalLimit 100 +- Exchange SinglePartition, None +- LocalLimit 100 +- Range 0, 1, 1, 524288000, [id#5L] After add wholestage codegen support: WholeStageCodegen : +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L]) : +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L]) : +- GlobalLimit 100 : +- INPUT +- Exchange SinglePartition, None +- WholeStageCodegen : +- LocalLimit 100 : +- Range 0, 1, 1, 524288000, [id#40L] ## How was this patch tested? A test is added into BenchmarkWholeStageCodegen. Author: Liang-Chi Hsieh Closes #11391 from viirya/wholestage-limit. --- .../apache/spark/sql/execution/limit.scala | 35 +++++++++++++++++-- .../BenchmarkWholeStageCodegen.scala | 14 ++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cd543d4195286..45175d36d5c9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { /** * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. */ -trait BaseLimit extends UnaryNode { +trait BaseLimit extends UnaryNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 6d6cc0186a962..2d3e34d0e1292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("range/limit/sum") { + val N = 500 << 20 + runBenchmark("range/limit/sum", N) { + sqlContext.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + ignore("stat functions") { val N = 100 << 20 From 5ed48dd84d38dfe621428e164a02e74ddbdbc622 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 1 Mar 2016 08:47:56 -0800 Subject: [PATCH 2082/2089] [SPARK-12811][ML] Estimator for Generalized Linear Models(GLMs) Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS. cc mengxr Author: Yanbo Liang Closes #11136 from yanboliang/spark-12811. --- .../spark/ml/optim/WeightedLeastSquares.scala | 10 +- .../GeneralizedLinearRegression.scala | 577 ++++++++++++++++++ .../ml/regression/LinearRegression.scala | 4 +- .../GeneralizedLinearRegressionSuite.scala | 507 +++++++++++++++ 4 files changed, 1094 insertions(+), 4 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 61b3642131810..55b7510656643 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -156,6 +156,12 @@ private[ml] class WeightedLeastSquares( private[ml] object WeightedLeastSquares { + /** + * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]] + * only supports the number of features is no more than 4096. + */ + val MAX_NUM_FEATURES: Int = 4096 + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares { private var aaSum: DenseVector = _ private def init(k: Int): Unit = { - require(k <= 4096, "In order to take the normal equation approach efficiently, " + - s"we set the max number of features to 4096 but got $k.") + require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to $MAX_NUM_FEATURES but got $k.") this.k = k triK = k * (k + 1) / 2 count = 0L diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala new file mode 100644 index 0000000000000..a850dfee0a452 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import breeze.stats.distributions.{Gaussian => GD} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ + +/** + * Params for Generalized Linear Regression. + */ +private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams + with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol + with HasSolver with Logging { + + /** + * Param for the name of family which is a description of the error distribution + * to be used in the model. + * Supported options: "gaussian", "binomial", "poisson" and "gamma". + * Default is "gaussian". + * @group param + */ + @Since("2.0.0") + final val family: Param[String] = new Param(this, "family", + "The name of family which is a description of the error distribution to be used in the " + + "model. Supported options: gaussian(default), binomial, poisson and gamma.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getFamily: String = $(family) + + /** + * Param for the name of link function which provides the relationship + * between the linear predictor and the mean of the distribution function. + * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * @group param + */ + @Since("2.0.0") + final val link: Param[String] = new Param(this, "link", "The name of link function " + + "which provides the relationship between the linear predictor and the mean of the " + + "distribution function. Supported options: identity, log, inverse, logit, probit, " + + "cloglog and sqrt.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getLink: String = $(link) + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + override def validateParams(): Unit = { + if ($(solver) == "irls") { + setDefault(maxIter -> 25) + } + if (isDefined(link)) { + require(supportedFamilyAndLinkPairs.contains( + Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + + s"with ${$(family)} family does not support ${$(link)} link function.") + } + } +} + +/** + * :: Experimental :: + * + * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]]) + * specified by giving a symbolic description of the linear predictor (link function) and + * a description of the error distribution (family). + * It supports "gaussian", "binomial", "poisson" and "gamma" as family. + * Valid link functions for each family is listed below. The first link function of each family + * is the default one. + * - "gaussian" -> "identity", "log", "inverse" + * - "binomial" -> "logit", "probit", "cloglog" + * - "poisson" -> "log", "identity", "sqrt" + * - "gamma" -> "inverse", "identity", "log" + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) + extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase with Logging { + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("glm")) + + /** + * Sets the value of param [[family]]. + * Default is "gaussian". + * @group setParam + */ + @Since("2.0.0") + def setFamily(value: String): this.type = set(family, value) + setDefault(family -> Gaussian.name) + + /** + * Sets the value of param [[link]]. + * @group setParam + */ + @Since("2.0.0") + def setLink(value: String): this.type = set(link, value) + + /** + * Sets if we should fit the intercept. + * Default is true. + * @group setParam + */ + @Since("2.0.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + + /** + * Sets the maximum number of iterations. + * Default is 25 if the solver algorithm is "irls". + * @group setParam + */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Sets the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Sets the regularization parameter. + * Default is 0.0. + * @group setParam + */ + @Since("2.0.0") + def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is empty, so all instances have weight one. + * @group setParam + */ + @Since("2.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + /** + * Sets the solver algorithm used for optimization. + * Currently only support "irls" which is also the default solver. + * @group setParam + */ + @Since("2.0.0") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "irls") + + override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + val familyObj = Family.fromName($(family)) + val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd + .map { case Row(features: Vector) => + features.size + }.first() + if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { + val msg = "Currently, GeneralizedLinearRegression only supports number of features" + + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." + throw new SparkException(msg) + } + + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd + .map { case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + + if (familyObj == Gaussian && linkObj == Identity) { + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + standardizeFeatures = true, standardizeLabel = true) + val wlsModel = optimizer.fit(instances) + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) + .setParent(this)) + return model + } + + // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, + $(fitIntercept), $(regParam), $(maxIter), $(tol)) + val irlsModel = optimizer.fit(instances) + + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) + .setParent(this)) + model + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra) +} + +@Since("2.0.0") +private[ml] object GeneralizedLinearRegression { + + /** Set of family and link pairs that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyAndLinkPairs = Set( + Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, + Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, + Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, + Gamma -> Inverse, Gamma -> Identity, Gamma -> Log + ) + + /** Set of family names that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + + /** Set of link names that GeneralizedLinearRegression supports. */ + lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + + val epsilon: Double = 1E-16 + + /** + * Wrapper of family and link combination used in the model. + */ + private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + + /** Linear predictor based on given mu. */ + def predict(mu: Double): Double = link.link(family.project(mu)) + + /** Fitted value based on linear predictor eta. */ + def fitted(eta: Double): Double = family.project(link.unlink(eta)) + + /** + * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. + */ + def initialize( + instances: RDD[Instance], + fitIntercept: Boolean, + regParam: Double): WeightedLeastSquaresModel = { + val newInstances = instances.map { instance => + val mu = family.initialize(instance.label, instance.weight) + val eta = predict(mu) + Instance(eta, instance.weight, instance.features) + } + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + standardizeFeatures = true, standardizeLabel = true) + .fit(newInstances) + initialModel + } + + /** + * The reweight function used to update offsets and weights + * at each iteration of [[IterativelyReweightedLeastSquares]]. + */ + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: Instance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + val mu = fitted(eta) + val offset = eta + (instance.label - mu) * link.deriv(mu) + val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (offset, weight) + } + } + } + + /** + * A description of the error distribution to be used in the model. + * @param name the name of the family. + */ + private[ml] abstract class Family(val name: String) extends Serializable { + + /** The default link instance of this family. */ + val defaultLink: Link + + /** Initialize the starting value for mu. */ + def initialize(y: Double, weight: Double): Double + + /** The variance of the endogenous variable's mean, given the value mu. */ + def variance(mu: Double): Double + + /** Trim the fitted value so that it will be in valid range. */ + def project(mu: Double): Double = mu + } + + private[ml] object Family { + + /** + * Gets the [[Family]] object from its name. + * @param name family name: "gaussian", "binomial", "poisson" or "gamma". + */ + def fromName(name: String): Family = { + name match { + case Gaussian.name => Gaussian + case Binomial.name => Binomial + case Poisson.name => Poisson + case Gamma.name => Gamma + } + } + } + + /** + * Gaussian exponential family distribution. + * The default link for the Gaussian family is the identity link. + */ + private[ml] object Gaussian extends Family("gaussian") { + + val defaultLink: Link = Identity + + override def initialize(y: Double, weight: Double): Double = y + + def variance(mu: Double): Double = 1.0 + + override def project(mu: Double): Double = { + if (mu.isNegInfinity) { + Double.MinValue + } else if (mu.isPosInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Binomial exponential family distribution. + * The default link for the Binomial family is the logit link. + */ + private[ml] object Binomial extends Family("binomial") { + + val defaultLink: Link = Logit + + override def initialize(y: Double, weight: Double): Double = { + val mu = (weight * y + 0.5) / (weight + 1.0) + require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" + + s"should be in range (0, 1), but got $mu") + mu + } + + override def variance(mu: Double): Double = mu * (1.0 - mu) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu > 1.0 - epsilon) { + 1.0 - epsilon + } else { + mu + } + } + } + + /** + * Poisson exponential family distribution. + * The default link for the Poisson family is the log link. + */ + private[ml] object Poisson extends Family("poisson") { + + val defaultLink: Link = Log + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Poisson family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = mu + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Gamma exponential family distribution. + * The default link for the Gamma family is the inverse link. + */ + private[ml] object Gamma extends Family("gamma") { + + val defaultLink: Link = Inverse + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Gamma family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = math.pow(mu, 2.0) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * A description of the link function to be used in the model. + * The link function provides the relationship between the linear predictor + * and the mean of the distribution function. + * @param name the name of link function. + */ + private[ml] abstract class Link(val name: String) extends Serializable { + + /** The link function. */ + def link(mu: Double): Double + + /** Derivative of the link function. */ + def deriv(mu: Double): Double + + /** The inverse link function. */ + def unlink(eta: Double): Double + } + + private[ml] object Link { + + /** + * Gets the [[Link]] object from its name. + * @param name link name: "identity", "logit", "log", + * "inverse", "probit", "cloglog" or "sqrt". + */ + def fromName(name: String): Link = { + name match { + case Identity.name => Identity + case Logit.name => Logit + case Log.name => Log + case Inverse.name => Inverse + case Probit.name => Probit + case CLogLog.name => CLogLog + case Sqrt.name => Sqrt + } + } + } + + private[ml] object Identity extends Link("identity") { + + override def link(mu: Double): Double = mu + + override def deriv(mu: Double): Double = 1.0 + + override def unlink(eta: Double): Double = eta + } + + private[ml] object Logit extends Link("logit") { + + override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) + + override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) + } + + private[ml] object Log extends Link("log") { + + override def link(mu: Double): Double = math.log(mu) + + override def deriv(mu: Double): Double = 1.0 / mu + + override def unlink(eta: Double): Double = math.exp(eta) + } + + private[ml] object Inverse extends Link("inverse") { + + override def link(mu: Double): Double = 1.0 / mu + + override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0) + + override def unlink(eta: Double): Double = 1.0 / eta + } + + private[ml] object Probit extends Link("probit") { + + override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu) + + override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu)) + + override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta) + } + + private[ml] object CLogLog extends Link("cloglog") { + + override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) + + override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) + } + + private[ml] object Sqrt extends Link("sqrt") { + + override def link(mu: Double): Double = math.sqrt(mu) + + override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu)) + + override def unlink(eta: Double): Double = math.pow(eta, 2.0) + } +} + +/** + * :: Experimental :: + * Model produced by [[GeneralizedLinearRegression]]. + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegressionModel private[ml] ( + @Since("2.0.0") override val uid: String, + @Since("2.0.0") val coefficients: Vector, + @Since("2.0.0") val intercept: Double) + extends RegressionModel[Vector, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase { + + import GeneralizedLinearRegression._ + + lazy val familyObj = Family.fromName($(family)) + lazy val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + override protected def predict(features: Vector): Double = { + val eta = BLAS.dot(features, coefficients) + intercept + familyAndLink.fitted(eta) + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { + copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) + .setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8f78fd122f343..b4f17b8e28982 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -163,8 +163,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || - $(solver) == "normal") { + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala new file mode 100644 index 0000000000000..8bfa9855ce4ea --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} +import org.apache.spark.mllib.random._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val seed: Int = 42 + @transient var datasetGaussianIdentity: DataFrame = _ + @transient var datasetGaussianLog: DataFrame = _ + @transient var datasetGaussianInverse: DataFrame = _ + @transient var datasetBinomial: DataFrame = _ + @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonIdentity: DataFrame = _ + @transient var datasetPoissonSqrt: DataFrame = _ + @transient var datasetGammaInverse: DataFrame = _ + @transient var datasetGammaIdentity: DataFrame = _ + @transient var datasetGammaLog: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + import GeneralizedLinearRegressionSuite._ + + datasetGaussianIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity"), 2)) + + datasetGaussianLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log"), 2)) + + datasetGaussianInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse"), 2)) + + datasetBinomial = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, + addIntercept = true, nPoints, seed) + + sqlContext.createDataFrame(sc.parallelize(testData, 2)) + } + + datasetPoissonLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "log"), 2)) + + datasetPoissonIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity"), 2)) + + datasetPoissonSqrt = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "sqrt"), 2)) + + datasetGammaInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse"), 2)) + + datasetGammaIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity"), 2)) + + datasetGammaLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log"), 2)) + } + + test("params") { + ParamsSuite.checkParams(new GeneralizedLinearRegression) + val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("generalized linear regression: default params") { + val glr = new GeneralizedLinearRegression + assert(glr.getLabelCol === "label") + assert(glr.getFeaturesCol === "features") + assert(glr.getPredictionCol === "prediction") + assert(glr.getFitIntercept) + assert(glr.getTol === 1E-6) + assert(glr.getWeightCol === "") + assert(glr.getRegParam === 0.0) + assert(glr.getSolver == "irls") + // TODO: Construct model directly instead of via fitting. + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + assert(model.getFamily === "gaussian") + assert(model.getLink === "identity") + } + + test("generalized linear regression: gaussian family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="gaussian", data=data) + print(as.vector(coef(model))) + } + + [1] 2.2960999 0.8087933 + [1] 2.5002642 2.2000403 0.5999485 + + data <- read.csv("path", header=FALSE) + model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0)) + model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0)) + print(as.vector(coef(model1))) + print(as.vector(coef(model2))) + + [1] 0.23069326 0.07993778 + [1] 0.25001858 0.22002452 0.05998789 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=gaussian(link=inverse), data=data) + print(as.vector(coef(model))) + } + + [1] 2.3010179 0.8198976 + [1] 2.4108902 2.2130248 0.6086152 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2960999, 0.8087933), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(0.0, 0.23069326, 0.07993778), + Vectors.dense(0.25001858, 0.22002452, 0.05998789), + Vectors.dense(0.0, 2.3010179, 0.8198976), + Vectors.dense(2.4108902, 2.2130248, 0.6086152)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog), + ("inverse", datasetGaussianInverse))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gaussian family against glmnet") { + /* + R code: + library(glmnet) + data <- read.csv("path", header=FALSE) + label = data$V1 + features = as.matrix(data.frame(data$V2, data$V3)) + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + model <- glmnet(features, label, family="gaussian", intercept=intercept, + lambda=lambda, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + + [1] 0.0000000 2.2961005 0.8087932 + [1] 0.0000000 2.2130368 0.8309556 + [1] 0.0000000 1.7176137 0.9610657 + [1] 2.5002642 2.2000403 0.5999485 + [1] 3.1106389 2.0935142 0.5712711 + [1] 6.7597127 1.4581054 0.3994266 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2961005, 0.8087932), + Vectors.dense(0.0, 2.2130368, 0.8309556), + Vectors.dense(0.0, 1.7176137, 0.9610657), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(3.1106389, 2.0935142, 0.5712711), + Vectors.dense(6.7597127, 1.4581054, 0.3994266)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian") + .setFitIntercept(fitIntercept).setRegParam(regParam) + val model = trainer.fit(datasetGaussianIdentity) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"fitIntercept = $fitIntercept and regParam = $regParam.") + + idx += 1 + } + } + + test("generalized linear regression: binomial family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 + data <- read.csv("path", header=FALSE) + + for (formula in c(f1, f2)) { + model <- glm(formula, family="binomial", data=data) + print(as.vector(coef(model))) + } + + [1] -0.3560284 1.3010002 -0.3570805 -0.7406762 + [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=probit), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2134390 0.7800646 -0.2144267 -0.4438358 + [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=cloglog), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2832198 0.8434144 -0.2524727 -0.5293452 + [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758 + */ + val expected = Seq( + Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762), + Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989), + Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358), + Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850), + Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452), + Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial), + ("cloglog", datasetBinomial))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), + model.coefficients(2), model.coefficients(3)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: poisson family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + print(as.vector(coef(model))) + } + + [1] 0.22999393 0.08047088 + [1] 0.25022353 0.21998599 0.05998621 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2929501 0.8119415 + [1] 2.5012730 2.1999407 0.5999107 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=sqrt), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2958947 0.8090515 + [1] 2.5000480 2.1999972 0.5999968 + */ + val expected = Seq( + Vectors.dense(0.0, 0.22999393, 0.08047088), + Vectors.dense(0.25022353, 0.21998599, 0.05998621), + Vectors.dense(0.0, 2.2929501, 0.8119415), + Vectors.dense(2.5012730, 2.1999407, 0.5999107), + Vectors.dense(0.0, 2.2958947, 0.8090515), + Vectors.dense(2.5000480, 2.1999972, 0.5999968)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity), + ("sqrt", datasetPoissonSqrt))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gamma family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="Gamma", data=data) + print(as.vector(coef(model))) + } + + [1] 2.3392419 0.8058058 + [1] 2.3507700 2.2533574 0.6042991 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2908883 0.8147796 + [1] 2.5002406 2.1998346 0.6000059 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=log), data=data) + print(as.vector(coef(model))) + } + + [1] 0.22958970 0.08091066 + [1] 0.25003210 0.21996957 0.06000215 + */ + val expected = Seq( + Vectors.dense(0.0, 2.3392419, 0.8058058), + Vectors.dense(2.3507700, 2.2533574, 0.6042991), + Vectors.dense(0.0, 2.2908883, 0.8147796), + Vectors.dense(2.5002406, 2.1998346, 0.6000059), + Vectors.dense(0.0, 0.22958970, 0.08091066), + Vectors.dense(0.25003210, 0.21996957, 0.06000215)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), + ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } +} + +object GeneralizedLinearRegressionSuite { + + def generateGeneralizedLinearRegressionInput( + intercept: Double, + coefficients: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + noiseLevel: Double, + family: String, + link: String): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + def rndElement(i: Int) = { + (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + } + val (generator, mean) = family match { + case "gaussian" => (new StandardNormalGenerator, 0.0) + case "poisson" => (new PoissonGenerator(1.0), 1.0) + case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0) + } + generator.setSeed(seed) + + (0 until nPoints).map { _ => + val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray) + val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept + val mu = link match { + case "identity" => eta + case "log" => math.exp(eta) + case "sqrt" => math.pow(eta, 2.0) + case "inverse" => 1.0 / eta + } + val label = mu + noiseLevel * (generator.nextValue() - mean) + // Return LabeledPoints with DenseVector + LabeledPoint(label, features) + } + } +} From c37bbb3a1cbd93c749aaaeca1345817e0c20094f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 1 Mar 2016 09:40:33 -0800 Subject: [PATCH 2083/2089] Closes #11320 Closes #10940 Closes #11302 Closes #11430 Closes #10912 From c27ba0d547a0cd3fd00bb42c76ad971b2d48b4a0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Mar 2016 13:07:04 -0800 Subject: [PATCH 2084/2089] [SPARK-13582] [SQL] defer dictionary decoding in parquet reader ## What changes were proposed in this pull request? This PR defer the resolution from a id of dictionary to value until the column is actually accessed (inside getInt/getLong), this is very useful for those columns and rows that are filtered out. It's also useful for binary type, we will not need to copy all the byte arrays. This PR also change the underlying type for small decimal that could be fit within a Int, in order to use getInt() to lookup the value from IntDictionary. ## How was this patch tested? Manually test TPCDS Q7 with scale factor 10, saw about 30% improvements (after PR #11274). Author: Davies Liu Closes #11437 from davies/decode_dict. --- .../org/apache/spark/sql/types/Decimal.scala | 3 + .../apache/spark/sql/types/DecimalType.scala | 11 ++ .../parquet/UnsafeRowParquetRecordReader.java | 101 +++++------------ .../parquet/VectorizedRleValuesReader.java | 33 ------ .../execution/vectorized/ColumnVector.java | 105 +++++++++++++++--- .../vectorized/ColumnVectorUtils.java | 8 +- .../execution/vectorized/ColumnarBatch.java | 24 +--- .../vectorized/OffHeapColumnVector.java | 58 ++++++---- .../vectorized/OnHeapColumnVector.java | 47 ++++++-- .../parquet/CatalystRowConverter.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 14 +-- .../parquet/CatalystWriteSupport.scala | 8 +- .../parquet/ParquetEncodingSuite.scala | 2 +- .../vectorized/ColumnarBatchBenchmark.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 15 files changed, 221 insertions(+), 203 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ede..6a59e9728a9f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -340,6 +340,9 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + /** Maximum number of decimal digits a Int can represent */ + val MAX_INT_DIGITS = 9 + /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 2e03ddae760b2..9c1319c1c5e6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside a int + */ + def is32BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_INT_DIGITS + case _ => false + } + } + /** * Returns if dt is a DecimalType that fits inside a long */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index e7f0ec2e77895..57dbd7c2ff56f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -257,8 +257,7 @@ private void initializeInternal() throws IOException { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && - primitiveType.getDecimalMetadata().getPrecision() > - CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) { throw new IOException("Decimal with high precision is not supported."); } if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { @@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); int precision = type.getDecimalMetadata().getPrecision(); int scale = type.getDecimalMetadata().getScale(); - Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(), "Unsupported precision."); for (int n = 0; n < num; ++n) { @@ -480,11 +479,6 @@ private final class ColumnReader { */ private boolean useDictionary; - /** - * If useDictionary is true, the staging vector used to decode the ids. - */ - private ColumnVector dictionaryIds; - /** * Maximum definition level for this column. */ @@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException { } int num = Math.min(total, leftInPage); if (useDictionary) { - // Data is dictionary encoded. We will vector decode the ids and then resolve the values. - if (dictionaryIds == null) { - dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(total); - } // Read and decode dictionary ids. + ColumnVector dictionaryIds = column.reserveDictionaryIds(total);; defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - decodeDictionaryIds(rowId, num, column); + decodeDictionaryIds(rowId, num, column, dictionaryIds); } else { + column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + private void decodeDictionaryIds(int rowId, int num, ColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: - if (column.dataType() == DataTypes.IntegerType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ByteType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ShortType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case INT64: - if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case FLOAT: - for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); - } - break; - case DOUBLE: - for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); - } + case BINARY: + column.setDictionary(dictionary); break; case FIXED_LEN_BYTE_ARRAY: - if (DecimalType.is64BitDecimalType(column.dataType())) { + // DecimalType written in the legacy mode + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); @@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { } break; - case BINARY: - // TODO: this is incredibly inefficient as it blows up the dictionary right here. We - // need to do this better. We should probably add the dictionary data to the ColumnVector - // and reuse it across batches. This should mean adding a ByteArray would just update - // the length and offset. - for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); - } - break; - default: throw new NotImplementedException("Unsupported type: " + descriptor.getType()); } @@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IO private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || + DecimalType.is32BitDecimalType(column.dataType())) { defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - defColumn.readIntsAsLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ShortType) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num, VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (DecimalType.is64BitDecimalType(column.dataType())) { + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putInt(rowId + i, + (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 8613fcae0b805..62157389013bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -25,7 +25,6 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c, } } - public void readIntsAsLongs(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { - int left = total; - while (left > 0) { - if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - switch (mode) { - case RLE: - if (currentValue == level) { - for (int i = 0; i < n; i++) { - c.putLong(rowId + i, data.readInteger()); - } - } else { - c.putNulls(rowId, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == level) { - c.putLong(rowId + i, data.readInteger()); - } else { - c.putNull(rowId + i); - } - } - break; - } - rowId += n; - left -= n; - currentCount -= n; - } - } - public void readBytes(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0514252a8e53d..bb0247c2fbedf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,6 +19,10 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -27,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. @@ -157,7 +159,7 @@ public Object[] array() { } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { - list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + list[i] = getUTF8String(i).toString(); } } } else if (dt instanceof CalendarIntervalType) { @@ -204,28 +206,17 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return data.getDecimal(offset + ordinal, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - Array child = data.getByteArray(offset + ordinal); - return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + return data.getUTF8String(offset + ordinal); } @Override public byte[] getBinary(int ordinal) { - ColumnVector.Array array = data.getByteArray(offset + ordinal); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return data.getBinary(offset + ordinal); } @Override @@ -534,12 +525,57 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - public final Array getByteArray(int rowId) { + private Array getByteArray(int rowId) { Array array = getArray(rowId); array.data.loadBytes(array); return array; } + /** + * Returns the decimal for rowId. + */ + public final Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.apply(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(rowId), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + /** + * Returns the UTF8String for rowId. + */ + public final UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return UTF8String.fromBytes(v.getBytes()); + } + } + + /** + * Returns the byte array for rowId. + */ + public final byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return v.getBytes(); + } + } + /** * Append APIs. These APIs all behave similarly and will append data to the current vector. It * is not valid to mix the put and append APIs. The append APIs are slower and should only be @@ -816,6 +852,39 @@ public final int appendStruct(boolean isNull) { */ protected final ColumnarBatch.Row resultStruct; + /** + * The Dictionary for this column. + * + * If it's not null, will be used to decode the value in getXXX(). + */ + protected Dictionary dictionary; + + /** + * Reusable column for ids of dictionary. + */ + protected ColumnVector dictionaryIds; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve a integer column for ids of dictionary. + */ + public ColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = allocate(capacity, DataTypes.IntegerType, + this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2aeef7f2f90fe..681ace3387139 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -22,24 +22,20 @@ import java.util.Iterator; import java.util.List; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.commons.lang.NotImplementedException; - /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly * for debugging or other non-performance critical paths. * These utilities are mostly used to convert ColumnVectors into other formats. */ public class ColumnVectorUtils { - public static String toString(ColumnVector.Array a) { - return new String(a.byteArray, a.byteArrayOffset, a.length); - } - /** * Returns the array data as the java primitive array. * For example, an array of IntegerType will return an int[]. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 070d897a7158c..8a0d7f8b12379 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; @@ -31,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class is the in memory representation of rows as they are streamed through operators. It * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that @@ -193,29 +191,17 @@ public final boolean anyNull() { @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = columns[ordinal].getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + return columns[ordinal].getUTF8String(rowId); } @Override public final byte[] getBinary(int ordinal) { - ColumnVector.Array array = columns[ordinal].getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return columns[ordinal].getBinary(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e38ed051219b7..b06b7f2457b54 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,25 +18,11 @@ import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - - import org.apache.commons.lang.NotImplementedException; -import org.apache.commons.lang.NotImplementedException; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. @@ -171,7 +157,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return Platform.getByte(null, data + rowId); + if (dictionary == null) { + return Platform.getByte(null, data + rowId); + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -199,7 +189,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return Platform.getShort(null, data + 2 * rowId); + if (dictionary == null) { + return Platform.getShort(null, data + 2 * rowId); + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -233,7 +227,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return Platform.getInt(null, data + 4 * rowId); + if (dictionary == null) { + return Platform.getInt(null, data + 4 * rowId); + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -267,7 +265,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return Platform.getLong(null, data + 8 * rowId); + if (dictionary == null) { + return Platform.getLong(null, data + 8 * rowId); + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -301,7 +303,11 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public final float getFloat(int rowId) { - return Platform.getFloat(null, data + rowId * 4); + if (dictionary == null) { + return Platform.getFloat(null, data + rowId * 4); + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } } @@ -336,7 +342,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return Platform.getDouble(null, data + rowId * 8); + if (dictionary == null) { + return Platform.getDouble(null, data + rowId * 8); + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -394,7 +404,7 @@ private final void reserveInternal(int newCapacity) { } else if (type instanceof ShortType) { this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || - type instanceof DateType) { + type instanceof DateType || DecimalType.is32BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3502d31bd1dfa..305e84a86bdc7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,12 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.util.Arrays; + import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.util.Arrays; - /** * A column backed by an in memory JVM array. This stores the NULLs as a byte per value * and a java array for the values. @@ -68,7 +67,6 @@ public final void close() { doubleData = null; } - // // APIs dealing with nulls // @@ -154,7 +152,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return byteData[rowId]; + if (dictionary == null) { + return byteData[rowId]; + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -180,7 +182,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return shortData[rowId]; + if (dictionary == null) { + return shortData[rowId]; + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } @@ -217,7 +223,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return intData[rowId]; + if (dictionary == null) { + return intData[rowId]; + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -253,7 +263,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return longData[rowId]; + if (dictionary == null) { + return longData[rowId]; + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -280,7 +294,13 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { } @Override - public final float getFloat(int rowId) { return floatData[rowId]; } + public final float getFloat(int rowId) { + if (dictionary == null) { + return floatData[rowId]; + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } // // APIs dealing with doubles @@ -309,7 +329,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return doubleData[rowId]; + if (dictionary == null) { + return doubleData[rowId]; + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -377,7 +401,8 @@ private final void reserveInternal(int newCapacity) { short[] newData = new short[newCapacity]; if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); shortData = newData; - } else if (type instanceof IntegerType || type instanceof DateType) { + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 42d89f4bf81d6..8a128b4b61769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter( } protected def decimalFromBinary(value: Binary): Decimal = { - if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { + if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index ab4250d0adbae..6f6340f541ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) case UINT_8 => typeNotSupported() case UINT_16 => typeNotSupported() case UINT_32 => typeNotSupported() @@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() @@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter { // Returns the minimum number of bytes needed to store a decimal with a given `precision`. val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ - - val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ - // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { Math.round( // convert double to long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 3508220c9541f..0252c79d8e143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -33,7 +33,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi writeLegacyParquetFormat match { // Standard mode, 1 <= precision <= 9, writes as INT32 - case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer // Standard mode, 10 <= precision <= 18, writes as INT64 - case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY - case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY case _ => binaryWriterUsingUnscaledBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index cef6b79a094d1..281a2cffa894a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(batch.column(0).getByte(i) == 1) assert(batch.column(1).getInt(i) == 2) assert(batch.column(2).getLong(i) == 3) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + assert(batch.column(3).getUTF8String(i).toString == "abc") i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 8efdf8adb042a..97638a66ab473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -370,7 +370,7 @@ object ColumnarBatchBenchmark { } i = 0 while (i < count) { - sum += column.getByteArray(i).length + sum += column.getUTF8String(i).numBytes() i += 1 } column.reset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 445f311107e33..b3c3e66fbcbd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. val it = batch.rowIterator() @@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) From b0ee7d43730469ad61fdf6b7b75cc1b1efb62c31 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Mar 2016 15:39:13 -0800 Subject: [PATCH 2085/2089] [SPARK-13548][BUILD] Move tags and unsafe modules into common ## What changes were proposed in this pull request? This patch moves tags and unsafe modules into common directory to remove 2 top level non-user-facing directories. ## How was this patch tested? Jenkins should suffice. Author: Reynold Xin Closes #11426 from rxin/SPARK-13548. --- {tags => common/tags}/README.md | 0 {tags => common/tags}/pom.xml | 2 +- .../tags}/src/main/java/org/apache/spark/tags/DockerTest.java | 0 .../src/main/java/org/apache/spark/tags/ExtendedHiveTest.java | 0 .../src/main/java/org/apache/spark/tags/ExtendedYarnTest.java | 0 {unsafe => common/unsafe}/pom.xml | 2 +- .../src/main/java/org/apache/spark/unsafe/KVIterator.java | 0 .../src/main/java/org/apache/spark/unsafe/Platform.java | 0 .../java/org/apache/spark/unsafe/array/ByteArrayMethods.java | 0 .../main/java/org/apache/spark/unsafe/array/LongArray.java | 0 .../java/org/apache/spark/unsafe/bitset/BitSetMethods.java | 0 .../java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 0 .../org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 0 .../java/org/apache/spark/unsafe/memory/MemoryAllocator.java | 0 .../main/java/org/apache/spark/unsafe/memory/MemoryBlock.java | 0 .../java/org/apache/spark/unsafe/memory/MemoryLocation.java | 0 .../org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java | 0 .../main/java/org/apache/spark/unsafe/types/ByteArray.java | 0 .../java/org/apache/spark/unsafe/types/CalendarInterval.java | 0 .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 0 .../test/java/org/apache/spark/unsafe/PlatformUtilSuite.java | 0 .../java/org/apache/spark/unsafe/array/LongArraySuite.java | 0 .../org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java | 0 .../org/apache/spark/unsafe/types/CalendarIntervalSuite.java | 0 .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 0 .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 0 pom.xml | 4 ++-- 27 files changed, 4 insertions(+), 4 deletions(-) rename {tags => common/tags}/README.md (100%) rename {tags => common/tags}/pom.xml (97%) rename {tags => common/tags}/src/main/java/org/apache/spark/tags/DockerTest.java (100%) rename {tags => common/tags}/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java (100%) rename {tags => common/tags}/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java (100%) rename {unsafe => common/unsafe}/pom.xml (98%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/KVIterator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/Platform.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/array/LongArray.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/types/ByteArray.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/types/UTF8String.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java (100%) rename {unsafe => common/unsafe}/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala (100%) diff --git a/tags/README.md b/common/tags/README.md similarity index 100% rename from tags/README.md rename to common/tags/README.md diff --git a/tags/pom.xml b/common/tags/pom.xml similarity index 97% rename from tags/pom.xml rename to common/tags/pom.xml index 3e8e6f6182875..8e702b4fefe8c 100644 --- a/tags/pom.xml +++ b/common/tags/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/main/java/org/apache/spark/tags/DockerTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/unsafe/pom.xml b/common/unsafe/pom.xml similarity index 98% rename from unsafe/pom.xml rename to common/unsafe/pom.xml index 75fea556eeae1..5250014739da2 100644 --- a/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/Platform.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala similarity index 100% rename from unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala rename to common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala diff --git a/pom.xml b/pom.xml index 2376e307ced18..2148379896d35 100644 --- a/pom.xml +++ b/pom.xml @@ -89,7 +89,8 @@ common/sketch common/network-common common/network-shuffle - tags + common/unsafe + common/tags core graphx mllib @@ -99,7 +100,6 @@ sql/core sql/hive docker-integration-tests - unsafe assembly external/twitter external/flume From a640c5b4fbd653919a5897a7b11f16328f2094eb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Mar 2016 17:27:57 -0800 Subject: [PATCH 2086/2089] [SPARK-13598] [SQL] remove LeftSemiJoinBNL ## What changes were proposed in this pull request? Broadcast left semi join without joining keys is already supported in BroadcastNestedLoopJoin, it has the same implementation as LeftSemiJoinBNL, we should remove that. ## How was this patch tested? Updated unit tests. Author: Davies Liu Closes #11448 from davies/remove_bnl. --- .../spark/sql/execution/SparkStrategies.scala | 3 - .../sql/execution/joins/LeftSemiJoinBNL.scala | 80 ------------------- .../org/apache/spark/sql/JoinSuite.scala | 5 +- .../sql/execution/joins/SemiJoinSuite.scala | 9 --- 4 files changed, 2 insertions(+), 95 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd8c96d5fa1d6..0255103b63d81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -71,9 +71,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil - // no predicate can be evaluated by matching hash keys - case logical.Join(left, right, LeftSemi, condition) => - joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala deleted file mode 100644 index df6dac88187cc..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output: Seq[Attribute] = left.output - - /** The Streamed Relation */ - override def left: SparkPlan = streamed - - /** The Broadcast relation */ - override def right: SparkPlan = broadcast - - override def requiredChildDistribution: Seq[Distribution] = { - UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil - } - - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - val relation = broadcastedRelation.value - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < relation.length && !matched) { - if (boundCondition(joinedRow(streamedRow, relation(i)))) { - matched = true - } - i += 1 - } - if (matched) { - numOutputRows += 1 - } - matched - }) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 3dab848e7b033..5b98c11ef2a4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -47,7 +47,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j @@ -67,7 +66,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), - ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]), @@ -465,7 +464,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", - classOf[LeftSemiJoinBNL]), + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 355f916a97552..bc341db5571be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -95,15 +95,6 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using LeftSemiJoinBNL") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - test(s"$testName using BroadcastNestedLoopJoin build left") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => From e42724b12b976b3276accc1132f446fa67f7f981 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 1 Mar 2016 17:34:21 -0800 Subject: [PATCH 2087/2089] [SPARK-13167][SQL] Include rows with null values for partition column when reading from JDBC datasources. Rows with null values in partition column are not included in the results because none of the partition where clause specify is null predicate on the partition column. This fix adds is null predicate on the partition column to the first JDBC partition where clause. Example: JDBCPartition(THEID < 1 or THEID is null, 0),JDBCPartition(THEID >= 1 AND THEID < 2,1), JDBCPartition(THEID >= 2, 2) Author: sureshthalamati Closes #11063 from sureshthalamati/nullable_jdbc_part_col_spark-13167. --- .../datasources/jdbc/JDBCRelation.scala | 8 +++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index ee6373d03e1fd..9e336422d1f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -44,6 +44,12 @@ private[sql] object JDBCRelation { * exactly once. The parameters minValue and maxValue are advisory in that * incorrect values may cause the partitioning to be poor, but no data * will fail to be represented. + * + * Null value predicate is added to the first partition where clause to include + * the rows with null value for the partitions column. + * + * @param partitioning partition information to generate the where clause for each partition + * @return an array of partitions with where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) @@ -66,7 +72,7 @@ private[sql] object JDBCRelation { if (upperBound == null) { lowerBound } else if (lowerBound == null) { - upperBound + s"$upperBound or $column is null" } else { s"$lowerBound AND $upperBound" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index f8a9a95c873ad..30a5e2ea4acd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -171,6 +171,27 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( + "create table test.emp(name TEXT(32) NOT NULL," + + " theid INTEGER, \"Dept\" INTEGER)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('fred', 1, 10)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('mary', 2, null)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('joe ''foo'' \"bar\"', 3, 30)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('kathy', null, null)").executeUpdate() + conn.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE nullparts + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.EMP', user 'testUser', password 'testPass', + |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4') + """.stripMargin.replaceAll("\n", " ")) + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -338,6 +359,23 @@ class JDBCSuite extends SparkFunSuite .collect().length === 3) } + test("Partioning on column that might have null values.") { + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + .collect().length === 4) + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + .collect().length === 4) + // partitioning on a nullable quoted column + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + .collect().length === 4) + } + + test("SELECT * on partitioned table with a nullable partioncolumn") { + assert(sql("SELECT * FROM nullparts").collect().size == 4) + } + test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) From ab05aa61b21a23531e43d082757a40bf2ab750d8 Mon Sep 17 00:00:00 2001 From: GayathriMurali Date: Tue, 1 Mar 2016 20:22:44 -0800 Subject: [PATCH 2088/2089] Allow users to set initial model in logistic regression --- .../org/apache/spark/ml/classification/LogisticRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0d329d2c08d50..9aca2625c77ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -250,7 +250,7 @@ class LogisticRegression @Since("1.2.0") ( private var optInitialModel: Option[LogisticRegressionModel] = None /** @group setParam */ - private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { + def setInitialModel(model: LogisticRegressionModel): this.type = { this.optInitialModel = Some(model) this } From 857f825034a40b246b905361f1b0bdad8635482f Mon Sep 17 00:00:00 2001 From: GayathriMurali Date: Sat, 5 Mar 2016 16:42:56 -0800 Subject: [PATCH 2089/2089] Confusing examples in pyspark SQL docs --- python/pyspark/sql/column.py | 17 ++++++++++++++++- python/pyspark/sql/dataframe.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 3866a49c0b5f6..c0c392c47608e 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -324,7 +324,10 @@ def cast(self, dataType): [Row(ages=u'2'), Row(ages=u'5')] >>> df.select(df.age.cast(StringType()).alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] + + (Note : cast is alias of astype) """ + if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): @@ -337,7 +340,19 @@ def cast(self, dataType): raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) - astype = cast + @ignore_unicode_prefix + @since(1.3) + def astype(self, dataType): + """ Convert the column into type ``dataType``. + + >>> df.select(df.age.astype("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.astype(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + + (Note : astype is alias for cast) + """ + return self.cast(dataType) @since(1.3) def between(self, lowerBound, upperBound): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 76fbb0c9aa4c9..6f22ed289bfdd 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1012,6 +1012,35 @@ def dropDuplicates(self, subset=None): jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sql_ctx) + @since(1.4) + def drop_duplicates(self, subset=None): + """Return a new :class:`DataFrame` with duplicate rows removed, + optionally only considering certain columns. + + :func:`dropDuplicates` is an alias for :func:`drop_duplicates`. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([ \ + Row(name='Alice', age=5, height=80), \ + Row(name='Alice', age=5, height=80), \ + Row(name='Alice', age=10, height=80)]).toDF() + >>> df.drop_duplicates().show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 5| 80|Alice| + | 10| 80|Alice| + +---+------+-----+ + + >>> df.drop_duplicates(['name', 'height']).show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 5| 80|Alice| + +---+------+-----+ + """ + return self.cast(subset) + @since("1.3.1") def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. @@ -1401,8 +1430,6 @@ def toPandas(self): ########################################################################################## groupby = groupBy - drop_duplicates = dropDuplicates - def _to_scala_map(sc, jm): """